MXFusion
Inference.run throws when passing Variables of type PARAMETER as 'constants' argument
Describe the bug
When Variables of type PARAMETER are passed as the 'constants' argument to the constructor of Inference (or its child classes), a line in Inference's run method throws. The Variables that were passed as 'constants' are in self._var_trans but not in the 'kw' input argument.
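To make the failure mode concrete, here is a toy model of the mismatch described above. It assumes that run looks up every key of self._var_trans in 'kw'. ToyInference and its method bodies are hypothetical plain-Python stand-ins, not MXFusion's real classes.

```python
# Hypothetical, simplified model of the reported failure. These are plain
# Python stand-ins, not MXFusion's actual Inference implementation.


class ToyInference:
    """Toy stand-in for an inference object holding parameter transformations."""

    def __init__(self, parameters, constants):
        # Every PARAMETER gets a transformation entry, including those the
        # user also passed as constants (the situation in the bug report).
        self._var_trans = {uuid: (lambda x: x) for uuid in parameters}
        self.constants = dict(constants)

    def run(self, **kw):
        # Look up every transformed variable in the keyword inputs.
        # A PARAMETER passed as a constant never appears in kw, so this raises.
        return {uuid: trans(kw[uuid]) for uuid, trans in self._var_trans.items()}


inference = ToyInference(parameters=["w", "b"], constants={"b": 0.5})
try:
    inference.run(w=1.0)
except KeyError as err:
    print("run failed, missing key:", err)
```

In this toy model, "b" is both a PARAMETER and a constant, so its lookup in 'kw' fails with a KeyError.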
Expected behavior
Execution does not throw, and Variables passed as the 'constants' argument to the Inference constructor are treated as constants and not optimized during training, even if they are of type PARAMETER.
Desktop:
- OS: Ubuntu 18.04.2
- Python version: 3.6
- MXNet version: 1.4.1
- MXFusion version: 0.3.1
- MXNet context: CPU
Thanks for finding this!
The first question I have for this is whether it even makes sense to pass PARAMETER type variables to the 'constants' argument in Inference? It may not, in which case we should clarify the interface here in some way to make that obvious.
If it does make sense, the fix looks like adding a check in the prepare_executor method of InferenceAlgorithm so that 'constants' are not added to self._var_trans. Constants shouldn't need to be transformed anyway, and we can just read them from the constants dictionary. That said, I think this is a slightly odd interpretation of the transformation interface.
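The proposed check could be sketched roughly as follows. This assumes the only change is to skip constants when registering transformations and to read their values from the constants dictionary at run time. ToyInferenceFixed, trainable and the method bodies are hypothetical, and are not MXFusion's actual prepare_executor.

```python
# Hypothetical sketch of the proposed fix. Names loosely mirror the discussion
# (prepare_executor, _var_trans, constants) but this is not MXFusion code.


class ToyInferenceFixed:
    def __init__(self, parameters, constants):
        self.constants = dict(constants)
        self._var_trans = {}
        self.prepare_executor(parameters)

    def prepare_executor(self, parameters):
        for uuid in parameters:
            # The proposed check: constants need no transformation.
            if uuid in self.constants:
                continue
            self._var_trans[uuid] = lambda x: x

    def trainable(self):
        # In this toy model, only variables with a transformation are optimized.
        return sorted(self._var_trans)

    def run(self, **kw):
        # Constants come straight from the constants dictionary.
        values = dict(self.constants)
        for uuid, trans in self._var_trans.items():
            values[uuid] = trans(kw[uuid])
        return values


inference = ToyInferenceFixed(parameters=["w", "b"], constants={"b": 0.5})
print(inference.run(w=1.0))   # {'b': 0.5, 'w': 1.0}
print(inference.trainable())  # ['w']
```

The design choice here matches the concern in the comment. Skipping constants keeps them out of the optimized set, but it means _var_trans no longer covers every PARAMETER.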
In the meantime, can you work around any issue this causes by simply redefining the Variable in your model as a Constant?
Thanks @meissnereric. This issue is not blocking me. Defining the Variable as CONSTANT does not work because of Issue 192, but setting the grad_req property of the Inference Parameter to 'null' does the job.
I think there is value in supporting a PARAMETER passed as the 'constants' argument of Inference, since several use cases could benefit from it. If it is not supported, I suggest (1) making it very clear in the documentation that it is not supported, (2) describing an alternative, and (3) issuing a warning when a PARAMETER is passed as a constant, to make debugging easier.
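Suggestion (3) could look something like the check below, run when the constants are received. VariableType here is a local stand-in enum, and check_constants is a hypothetical helper. Neither is imported from MXFusion.

```python
import warnings
from enum import Enum


class VariableType(Enum):
    # Local stand-in for the variable kinds discussed in this issue.
    CONSTANT = 0
    PARAMETER = 1
    RANDVAR = 2


def check_constants(constants, var_types):
    """Warn for each PARAMETER found in constants and return their names."""
    flagged = []
    for name in constants:
        if var_types.get(name) is VariableType.PARAMETER:
            warnings.warn(
                f"Variable {name!r} is a PARAMETER but was passed in 'constants'; "
                "it may not be treated as a constant.",
                UserWarning,
                stacklevel=2,
            )
            flagged.append(name)
    return flagged


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    flagged = check_constants(
        {"b": 0.5, "c": 1.0},
        {"b": VariableType.PARAMETER, "c": VariableType.CONSTANT},
    )
print(flagged)  # ['b']
```

A warning, rather than an exception, keeps existing code running while still pointing users at the unsupported combination.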