tweak torch parameter registration mechanism
This is a follow-up from the discussion in https://github.com/keras-team/keras/pull/19885, where I am trying to make torch and Keras play well together when tracking parameters.

The solution I ended up with:
- Since modules are properly tracked as torch modules, each `torch_params` only holds its own layer's variables. Nested variable resolution is done by torch with `recurse=True` (see the first sketch after this list).
- Changed back to using a parameter list instead of a dict. I did consider keeping the dict for readability, since the key in the torch params could then be `variable.name`, with each layer only tracking the variables it holds. However, the current seed generator creates duplicated variable names. If https://github.com/keras-team/keras/blob/master/keras/src/random/seed_generator.py#L80 were changed to something like `f"{self.name}_generator_state"`, the `ParameterDict` approach would work (see the second sketch below).
- In `_post_track/untrack_variables`, refresh the entire torch params of the layer and its sublayers. This could be changed to not re-create all sublayers if this function ever becomes too slow.
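To make the first point concrete, here is a minimal sketch in plain torch (not the actual Keras implementation; the class names are made up) of each module registering only its own variables in a `ParameterList` while nested parameters are still resolved by torch's recursive traversal:

```python
import torch


class Inner(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Each layer only stores the variables it owns.
        self.torch_params = torch.nn.ParameterList(
            [torch.nn.Parameter(torch.zeros(4))]
        )


class Outer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Only the parent's own variables go into its torch_params.
        self.torch_params = torch.nn.ParameterList(
            [torch.nn.Parameter(torch.ones(2))]
        )
        # The sublayer is a regular submodule, so torch already tracks it;
        # its variables are not duplicated in the parent's torch_params.
        self.inner = Inner()


outer = Outer()
# recurse=True walks submodules, so nested variables are resolved by torch.
for name, _ in outer.named_parameters(recurse=True):
    print(name)
# torch_params.0
# inner.torch_params.0
```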
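The `variable.name` collision can also be shown directly with a plain `torch.nn.ParameterDict`; the prefixed key in the second half is what the proposed `f"{self.name}_generator_state"` naming would give, not current behavior:

```python
import torch

params = torch.nn.ParameterDict()

# Today every seed generator state variable gets the same name
# ("seed_generator_state"), so keying the dict by variable.name
# silently overwrites previous entries.
for name in ["seed_generator_state", "seed_generator_state"]:
    params[name] = torch.nn.Parameter(torch.zeros(2), requires_grad=False)
print(len(params))  # 1 -- the second state clobbered the first

# A name prefixed with the owning generator stays unique.
for name in ["seed_generator_1_generator_state",
             "seed_generator_2_generator_state"]:
    params[name] = torch.nn.Parameter(torch.zeros(2), requires_grad=False)
print(len(params))  # 3
```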
I also added a few torch-specific tests to reflect some of the assumptions and use cases that torch users might have, e.g. using `state_dict` (see the test sketch at the end).
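For example, a test along these lines (a rough sketch of the kind of check, not the exact test added here; it assumes the torch backend is selected) exercises the `state_dict` round trip that torch users expect:

```python
import os

os.environ["KERAS_BACKEND"] = "torch"  # must be set before importing keras

import keras


def test_state_dict_round_trip():
    # With the torch backend a Keras layer is also a torch.nn.Module,
    # so state_dict()/load_state_dict() should behave as torch users expect.
    layer = keras.layers.Dense(4)
    layer.build((None, 3))
    state = layer.state_dict()
    assert len(state) > 0  # kernel and bias should be registered

    # Loading the state into a freshly built layer of the same shape
    # should restore the weights exactly.
    restored = keras.layers.Dense(4)
    restored.build((None, 3))
    restored.load_state_dict(state)
    for a, b in zip(layer.get_weights(), restored.get_weights()):
        assert (a == b).all()
```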