axon
Make Axon.Updates state easier to query/work with
With the new container API, we can support structs. Right now the Axon.Updates API uses a tuple to represent optimizer state. It's not bad, but I think it's better to use something that models optimizer state more explicitly. It's basically just a linked list with `transform_state`, which is an Nx.Container that holds the current transform's state, and `next_state`, which holds the next transform's GradientState. I think this gives more flexibility to expand the API if necessary in the future, and makes the recursion a little easier to reason about than the current approach, which just deletes and inserts data in nested tuples.
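To make the linked-list idea concrete, here is a minimal sketch in plain Elixir. The `GradientState` module, its `from_list/1` and `map_states/2` helpers, and the example state maps are all hypothetical and not part of Axon; a real version would presumably also derive `Nx.Container` for the two fields, which is left out here so the snippet runs without dependencies.

```elixir
defmodule GradientState do
  # Hypothetical struct for illustration only. A real version would
  # presumably add something like
  #   @derive {Nx.Container, containers: [:transform_state, :next_state]}
  # which is omitted here to keep the sketch dependency-free.
  defstruct [:transform_state, :next_state]

  # Builds the linked list from a list of per-transform states,
  # first transform at the head, `nil` marking the end of the chain.
  def from_list(states) do
    states
    |> Enum.reverse()
    |> Enum.reduce(nil, fn state, next ->
      %__MODULE__{transform_state: state, next_state: next}
    end)
  end

  # Applies `fun` to every transform state by recursing down the chain.
  def map_states(nil, _fun), do: nil

  def map_states(%__MODULE__{transform_state: state, next_state: next}, fun) do
    %__MODULE__{transform_state: fun.(state), next_state: map_states(next, fun)}
  end
end

# Two made-up transform states chained together.
state = GradientState.from_list([%{count: 0}, %{mu: 0.0}])
bumped = GradientState.map_states(state, &Map.put(&1, :seen, true))
IO.inspect(bumped)
```

The recursion only ever looks at the head and the tail, so adding a field to the struct later would not change how transforms walk the chain.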
Thinking about this, a struct may or may not make sense, but it definitely makes sense to make it easier to access the state of specific transformations. For example, if I want to access the `scale_by_state` state of an optimizer, it would be nice to be able to select it by key or something similar, rather than having to know where it sits in the state tuple.
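A rough illustration of the difference, in plain Elixir. The state contents and the `:scale` key are made up for the example, and the keyed map is only one possible shape for this, not something Axon provides today.

```elixir
# Positional style: the caller has to know that the state of interest
# is the second element of the tuple.
tuple_state = {%{count: 0}, %{mu: 0.0, nu: 0.0}}
by_position = elem(tuple_state, 1)

# Hypothetical keyed style: each transform's state lives under a name,
# so the caller selects it without knowing its position.
keyed_state = %{scale: %{count: 0}, scale_by_state: %{mu: 0.0, nu: 0.0}}
by_key = keyed_state.scale_by_state

IO.inspect(by_position == by_key)
```

Reordering or inserting transforms would change the index needed in the first form, while the key in the second form stays the same.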
Polaris issue