MXFusion
Possible issue in forward sampling
Describe the bug
When doing forward sampling, my code hits the following line:
https://github.com/amzn/MXFusion/blob/9e8e0a096504a76bc5d6bc9d4509965eda14780c/mxfusion/models/factor_graph.py#L427
This looks like a bug to me, as the model class never has a `graph` attribute at the top level (or does it?).
To Reproduce
Steps to reproduce the behavior:
1. Build a model such as a BNN (Bayesian neural network).
2. Run inference on the BNN on task A.
3. Fine-tune for task A using the posteriors from step 2.
4. Predict on task A using the fine-tuned model from step 3.
5. Run inference on the BNN on task B, using the posteriors from step 2.
6. Fine-tune for task A using the posteriors from step 5.
7. Predict on task A using the fine-tuned model from step 6.
8. Fine-tune for task B using the posteriors from step 5.
9. Predict on task B using the fine-tuned model from step 8.
10. See the error.
Desktop:
- OS: OSX
- Python version 3.6
- MXNet version 1.3.0
- MXFusion version 0.2.2
- MXNet context CPU
- MXNet dtype float32
This is a bug on that line of code. It should probably look something like this:
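```python
# Replicate the leaf variable, then attach it to the new model's
# components_graph, since the model has no top-level `graph` attribute.
new_leaf = v.replicate(var_map=var_map,
                       replication_function=lambda x: ('recursive', 'recursive'))
new_leaf.graph = new_model.components_graph
```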
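To make the failure mode concrete, here is a minimal, self-contained Python sketch. `ToyVariable`, `ToyModel`, `attach_leaf_buggy` and `attach_leaf_fixed` are hypothetical stand-ins, not MXFusion classes. The sketch assumes, as the report suggests, that the flagged line reads a top-level `graph` attribute from the new model, while only `components_graph` actually exists on it.

```python
class ToyVariable:
    """Hypothetical stand-in for a model variable that belongs to a graph."""

    def __init__(self, name, graph=None):
        self.name = name
        self.graph = graph

    def replicate(self, var_map, replication_function):
        # Copy the variable once and cache the copy in var_map, loosely
        # mirroring the replicate(var_map=..., replication_function=...) call.
        # replication_function is accepted but ignored in this toy.
        if self.name not in var_map:
            var_map[self.name] = ToyVariable(self.name)
        return var_map[self.name]


class ToyModel:
    """Hypothetical stand-in: has components_graph but no top-level graph."""

    def __init__(self):
        self.components_graph = {}  # placeholder for the model's component graph


def attach_leaf_buggy(v, new_model, var_map):
    new_leaf = v.replicate(var_map=var_map,
                           replication_function=lambda x: ('recursive', 'recursive'))
    new_leaf.graph = new_model.graph  # ToyModel has no `graph`: AttributeError
    return new_leaf


def attach_leaf_fixed(v, new_model, var_map):
    new_leaf = v.replicate(var_map=var_map,
                           replication_function=lambda x: ('recursive', 'recursive'))
    new_leaf.graph = new_model.components_graph  # the suggested fix
    return new_leaf


if __name__ == "__main__":
    model = ToyModel()
    v = ToyVariable("x")
    try:
        attach_leaf_buggy(v, model, {})
    except AttributeError as err:
        print("buggy version raised:", err)
    leaf = attach_leaf_fixed(v, model, {})
    print("fixed leaf uses components_graph:", leaf.graph is model.components_graph)
```

The buggy variant fails as soon as it touches `new_model.graph`, whereas the fixed variant attaches the replicated leaf to the graph object the model actually owns.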