jeta
MAML expects more parameters than are passed
https://github.com/SforAiDl/jeta/blob/e55bf83c9f89c662872f50a1ef8885c1085403ef/jeta/maml.py#L14-L15
The lines above cause an error when running MAML with `OptiTrainer`. `OptiTrainer` calls `maml_adapt`, and its arguments are supplied by `OptiTrainer` itself.
Note: `maml_adapt` is passed to `OptiTrainer` as `adapt_fn`.
https://github.com/SforAiDl/jeta/blob/e55bf83c9f89c662872f50a1ef8885c1085403ef/jeta/opti_trainer.py#L145
The error is thrown because `maml_adapt` expects the two arguments from the first linked lines (`fas` and `maml_lr`), and they are not passed.
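A minimal sketch of the failure, assuming simplified, hypothetical signatures (the real `maml_adapt` and `OptiTrainer` in jeta take different arguments): the trainer calls `adapt_fn` with only the arguments it knows about, so the two extra parameters are never supplied and Python raises a `TypeError`.

```python
# Hypothetical, simplified stand-ins for jeta's maml_adapt and OptiTrainer.
# The real signatures differ; only the calling pattern is reproduced here.

def maml_adapt(params, batch, fas, maml_lr):
    """Toy adaptation: take `fas` steps of size `maml_lr` along `batch`."""
    for _ in range(fas):
        params = params - maml_lr * batch
    return params


class OptiTrainer:
    def __init__(self, adapt_fn):
        self.adapt_fn = adapt_fn

    def step(self, params, batch):
        # The trainer only knows about params and batch, so fas and maml_lr
        # are never passed to adapt_fn.
        return self.adapt_fn(params, batch)


trainer = OptiTrainer(adapt_fn=maml_adapt)
error_message = None
try:
    trainer.step(1.0, 0.5)
except TypeError as err:
    # Python reports the missing positional arguments by name.
    error_message = str(err)
print(error_message)
```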
`fas` could be added to `OptiTrainer`, since this parameter is common to many optimization-based approaches, but `maml_lr` cannot, because it is specific to MAML. A different mechanism is therefore needed to pass it through.
Looks like a bug introduced in #39.
We can add `kwargs` to `OptiTrainer`. That would solve the problem of passing `maml_lr`.
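A minimal sketch of the `kwargs` idea, again using hypothetical, simplified signatures rather than jeta's actual API: `fas` becomes a regular `OptiTrainer` argument because it is shared by many optimization-based methods, while method-specific options such as `maml_lr` are collected in `**kwargs` and forwarded unchanged to `adapt_fn`.

```python
# Hypothetical, simplified sketch; not jeta's actual OptiTrainer API.

def maml_adapt(params, batch, fas, maml_lr):
    """Toy adaptation: take `fas` steps of size `maml_lr` along `batch`."""
    for _ in range(fas):
        params = params - maml_lr * batch
    return params


class OptiTrainer:
    def __init__(self, adapt_fn, fas=1, **kwargs):
        # fas is common to many optimization-based methods, so the trainer owns it.
        self.adapt_fn = adapt_fn
        self.fas = fas
        # Method-specific options (e.g. maml_lr for MAML) are stored as-is.
        self.adapt_kwargs = kwargs

    def step(self, params, batch):
        # Forward the shared argument explicitly and everything else via **kwargs.
        return self.adapt_fn(params, batch, fas=self.fas, **self.adapt_kwargs)


trainer = OptiTrainer(adapt_fn=maml_adapt, fas=2, maml_lr=0.1)
adapted = trainer.step(1.0, 0.5)
print(adapted)
```

With this shape, `maml_lr` never has to become a named `OptiTrainer` parameter, and methods that need no extra options simply pass none. One trade-off: a misspelled keyword is not caught when the trainer is constructed; it only raises a `TypeError` when `adapt_fn` is first called.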