pgmpy
Estimator.get_parameters doesn't return torch tensors when using joblib with backend == torch
In [1]: from pgmpy.global_vars import config
In [2]: config.set_backend('torch')
In [3]: config.get_backend()
Out[3]: 'torch'
In [4]: from pgmpy.utils import get_example_model
In [5]: model = get_example_model('alarm')
In [6]: df = model.simulate(int(1e4))
In [7]: model.fit(df)
WARNING:pgmpy:Replacing existing CPD for HISTORY
WARNING:pgmpy:Replacing existing CPD for CVP
WARNING:pgmpy:Replacing existing CPD for PCWP
WARNING:pgmpy:Replacing existing CPD for HYPOVOLEMIA
WARNING:pgmpy:Replacing existing CPD for LVEDVOLUME
WARNING:pgmpy:Replacing existing CPD for LVFAILURE
WARNING:pgmpy:Replacing existing CPD for STROKEVOLUME
WARNING:pgmpy:Replacing existing CPD for ERRLOWOUTPUT
WARNING:pgmpy:Replacing existing CPD for HRBP
WARNING:pgmpy:Replacing existing CPD for HREKG
WARNING:pgmpy:Replacing existing CPD for ERRCAUTER
WARNING:pgmpy:Replacing existing CPD for HRSAT
WARNING:pgmpy:Replacing existing CPD for INSUFFANESTH
WARNING:pgmpy:Replacing existing CPD for ANAPHYLAXIS
WARNING:pgmpy:Replacing existing CPD for TPR
WARNING:pgmpy:Replacing existing CPD for EXPCO2
WARNING:pgmpy:Replacing existing CPD for KINKEDTUBE
WARNING:pgmpy:Replacing existing CPD for MINVOL
WARNING:pgmpy:Replacing existing CPD for FIO2
WARNING:pgmpy:Replacing existing CPD for PVSAT
WARNING:pgmpy:Replacing existing CPD for SAO2
WARNING:pgmpy:Replacing existing CPD for PAP
WARNING:pgmpy:Replacing existing CPD for PULMEMBOLUS
WARNING:pgmpy:Replacing existing CPD for SHUNT
WARNING:pgmpy:Replacing existing CPD for INTUBATION
WARNING:pgmpy:Replacing existing CPD for PRESS
WARNING:pgmpy:Replacing existing CPD for DISCONNECT
WARNING:pgmpy:Replacing existing CPD for MINVOLSET
WARNING:pgmpy:Replacing existing CPD for VENTMACH
WARNING:pgmpy:Replacing existing CPD for VENTTUBE
WARNING:pgmpy:Replacing existing CPD for VENTLUNG
WARNING:pgmpy:Replacing existing CPD for VENTALV
WARNING:pgmpy:Replacing existing CPD for ARTCO2
WARNING:pgmpy:Replacing existing CPD for CATECHOL
WARNING:pgmpy:Replacing existing CPD for HR
WARNING:pgmpy:Replacing existing CPD for CO
WARNING:pgmpy:Replacing existing CPD for BP
In [8]: model.cpds[0].values
Out[8]: array([0.9912, 0.0088])
The CPD values in the example above should be a torch tensor, but a NumPy array is returned instead. This happens only when n_jobs != 1 is passed to the get_parameters method. A temporary fix has been applied for now.