Lambda-function attributes of Symbolic_KANLayer are not serializable with pickle (torch.save)
Hello,
The new Symbolic_KANLayer has lambda functions as attributes, which makes it hard to save model checkpoints. torch.save uses pickle by default, and pickle cannot serialize lambda functions (or other functions defined locally inside a method) unless one switches to a serialization package such as dill (please see this for a concrete example).
I don't know whether Symbolic_KANLayer will be used for large-scale experiments in the future. If so, I think it is advisable to remove every attribute of Symbolic_KANLayer that is a lambda function, so that this kind of problem is ruled out by construction.
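To illustrate why removing the lambdas would help, here is a minimal sketch using only the standard-library pickle module. The classes `WithLambda` and `WithModuleFunction` and the helper `try_pickle` are hypothetical stand-ins, not pykan code; `zero_fun` here only mimics the name in the error message. The idea is that pickle stores functions by their importable qualified name, so a function defined at module level round-trips, while a lambda created inside `__init__` does not.

```python
import pickle


def zero_fun(x):
    """Module-level function: pickle stores it by qualified name, so it round-trips."""
    return x * 0


class WithLambda:
    def __init__(self):
        # A lambda created inside __init__ is a local object with no importable name.
        self.fn = lambda x: x * 0


class WithModuleFunction:
    def __init__(self):
        # Referencing a top-level function keeps the instance picklable.
        self.fn = zero_fun


def try_pickle(obj):
    """Return True if obj survives a pickle round-trip and its fn still works."""
    try:
        restored = pickle.loads(pickle.dumps(obj))
    except (AttributeError, pickle.PicklingError):
        # Older Pythons raise AttributeError for local objects; newer ones may
        # raise PicklingError instead, so both are treated as "not picklable".
        return False
    return restored.fn(3) == 0


print(try_pickle(WithLambda()))
print(try_pickle(WithModuleFunction()))
```

The first call reports that the object cannot be pickled, while the second succeeds. Applying the same change to Symbolic_KANLayer (module-level functions, or objects like `functools.partial` wrapping them, instead of lambdas) should let plain `torch.save` work without dill, assuming no other attribute is unpicklable.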
For reference, the error was:
AttributeError: Can't pickle local object 'Symbolic_KANLayer.__init__.<locals>.<lambda>'. Did you mean: '_return_value'?
See also: AttributeError: Can't pickle local object <locals>.<lambda>
Minimal example to reproduce this behavior:
```python
import torch
import kan

torch.save(kan.KAN([5,5]), 'test.pt')
```
This fails with: AttributeError: Can't pickle local object 'Symbolic_KANLayer.__init__.<locals>.zero_fun'
A workaround is to use the dill package:
```python
import torch
import kan
import dill

# dill can serialize lambdas and locally defined functions that pickle rejects
torch.save(kan.KAN([5,5]), 'test.pt', pickle_module=dill)
```
This works.