
tweak torch parameter registration mechanism

Open · haohuanw opened this issue 8 months ago · 26 comments

This is a follow-up to the discussion in https://github.com/keras-team/keras/pull/19885, where I am trying to make torch and Keras play well together when tracking parameters.

The solution I ended up with:

  1. Since submodules are already tracked by torch's module machinery, each torch_params only stores the layer's own variables. Nested variable resolution is done by torch with recurse=True (see the sketch after this list).
  2. Change back to using a parameter list instead of a dict. I did consider keeping the dict for readability, since the key in the torch params could then be variable.name once only the current layer's own variables are tracked. However, the seed generator currently creates duplicate variable names. If https://github.com/keras-team/keras/blob/master/keras/src/random/seed_generator.py#L80 were changed to something like f"{self.name}_generator_state", it would work with the ParameterDict approach.
  3. In _post_track/untrack_variables, refresh the entire torch params of the layer and its sublayers. This could be changed to avoid re-creating all sublayers if this function ever becomes too slow.
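
As a rough illustration of points 1 and 3 — plain torch modules standing in for Keras layers, with Inner/Outer as made-up names rather than anything from the PR — torch's own module tracking resolves nested parameters once each module registers only its own variables:

```python
import torch

class Inner(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # each module registers only the variables it owns
        self.torch_params = torch.nn.ParameterList(
            [torch.nn.Parameter(torch.zeros(4))]
        )

class Outer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.torch_params = torch.nn.ParameterList(
            [torch.nn.Parameter(torch.ones(2))]
        )
        # the submodule is tracked automatically by torch
        self.inner = Inner()

outer = Outer()
# torch resolves nested parameters itself (recurse=True is the default)
print(len(list(outer.parameters(recurse=True))))  # 2
# each ParameterList holds only that module's own variables
print(len(list(outer.torch_params)))              # 1
# state_dict() exposes the nesting the way torch users expect
print(sorted(outer.state_dict().keys()))
# ['inner.torch_params.0', 'torch_params.0']
```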

I also added a few torch-specific tests to reflect some of the assumptions and use cases a torch user might have, e.g. using state_dict(). A minimal sketch of the kind of behavior such a test pins down is below, reusing the toy Outer module from the sketch above rather than the actual Keras layers:
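
```python
import torch

def test_state_dict_round_trip():
    # hypothetical test; the real tests exercise keras layers
    # on the torch backend, not this toy Outer module
    src, dst = Outer(), Outer()
    with torch.no_grad():
        for p in src.parameters():
            p.add_(1.0)  # perturb so the copy is observable
    # loading a state_dict should restore all nested parameters
    dst.load_state_dict(src.state_dict())
    for a, b in zip(src.parameters(), dst.parameters()):
        assert torch.equal(a, b)
```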

haohuanw — Jun 23 '24 21:06