Issues restoring checkpoint of struct.dataclass w/ FrozenDict attr

Open rwightman opened this issue 5 years ago • 7 comments

Working from a modified ImageNet Linen example, I've added two state attributes for the Polyak-averaged EMA values, like so:

from typing import Any

import flax

@flax.struct.dataclass
class TrainState:
    step: int
    optimizer: flax.optim.Optimizer
    model_state: Any
    dynamic_scale: flax.optim.DynamicScale
    ema_params: flax.core.FrozenDict = None  # lazy init on first step
    ema_model_state: flax.core.FrozenDict = None  # lazy init on first step

Restoring a checkpoint with that state causes an error because the FrozenDicts are restored as plain dicts. I'm not sure whether this is a bug or a feature request (i.e., is this expected?). I noticed there is a registration function for restoring state dicts, and FrozenDict is among the registered types; should that not cover this case? Or should I wrap my EMA state in another class and register my own state-dict restore function that freezes the dicts?

I'm currently doing this hack after restore to work around the issue...

    if step_offset > 0:
        state = state.replace(
            ema_params=flax.core.freeze(state.ema_params),
            ema_model_state=flax.core.freeze(state.ema_model_state))

rwightman, Nov 25 '20 21:11

I think most likely this is simply an oversight -- dataclasses should restore FrozenDicts. We should probably have a unit test serializing and deserializing a FrozenDict to make sure it comes out as a FrozenDict as well.

Roping in @jheek as he's been doing some changes to FrozenDict lately.

avital, Nov 26 '20 12:11

@jheek -- if this fix doesn't require deep understanding, maybe it's best to mark it as "pull requests encouraged", as it looks like @rwightman has a workaround for now.

avital, Nov 26 '20 12:11

This is actually not a trivial problem. We restore each attribute based on the type of the original value in the target, but when that value is None we cannot deduce the type. I think this would just work if you use init_train_state = TrainState(..., ema_params=FrozenDict(), ema_model_state=FrozenDict())
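To illustrate why a None target loses the type, here is a toy model of type-driven restoration. It is a sketch, not Flax's implementation: `restore` and `_RESTORERS` are hypothetical names, and the standard library's `MappingProxyType` stands in for `FrozenDict`.

```python
from types import MappingProxyType

def _restore_frozen(target, state):
    # Rebuild entry by entry, using each target value to pick the restored type.
    return MappingProxyType({k: restore(v, state[k]) for k, v in target.items()})

# Hypothetical registry: the target's type selects how raw state is rebuilt.
_RESTORERS = {MappingProxyType: _restore_frozen}

def restore(target, state):
    restorer = _RESTORERS.get(type(target))
    if restorer is None:
        # Unregistered types, including type(None), get the raw state back as-is.
        return state
    return restorer(target, state)

raw = {"w": 1.0}
from_none = restore(None, raw)                            # stays a plain dict
from_frozen = restore(MappingProxyType({"w": 0.0}), raw)  # becomes read-only
```

With a None target there is no type to dispatch on, so the plain dict passes through untouched, which matches the symptom reported at the top of this issue.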

jheek, Nov 27 '20 09:11

@jheek ah, ok... so in this case, yeah, the empty dict is falsy, so it acts as a sufficient 'not-initialized' value and the rest of my lazy logic should still work.

rwightman, Nov 27 '20 17:11

@jheek I thought this would be a quick and easy fix but ended up going down a rabbit hole. The idea doesn't work.

Because of the way the FrozenDict restore works, you cannot restore a FrozenDict when the state and the target have different keys; in this case the target has no keys at all. For the typical use case I guess that makes sense, but here it's a non-obvious, silent failure (you just end up restoring an empty dict when there was a populated one in the checkpoint)... confusing.
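The silent failure described above can be modeled in a few lines. This is a toy sketch of the behavior reported in this thread, not Flax's code, and other Flax versions may behave differently; `restore_frozen` is a hypothetical name and `MappingProxyType` stands in for `FrozenDict`.

```python
from types import MappingProxyType

def restore_frozen(target, state):
    # Only keys present in the target are looked up in the saved state.
    return MappingProxyType({key: state[key] for key in target})

saved = {"kernel": [1.0, 2.0], "bias": [0.5]}
empty_target = MappingProxyType({})

# No keys in the target means nothing is copied, and no error is raised.
restored = restore_frozen(empty_target, saved)
```

Since the loop is driven by the target's keys, an empty target can never pull anything out of the checkpoint, however much was saved.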

EDIT: So I guess I'm back at the cleanest path forward being to implement my own EmaUpdater() class that holds the two optional FrozenDicts and is falsy by default when not initialized with params + state. I'd write my own to/from state dict methods for that class, register them, and avoid calling the normal FrozenDict methods...

rwightman, Nov 27 '20 23:11

Yes, I think that is cleaner. Another alternative, if you don't want to register a bunch of classes, is to use restore_checkpoint(target=None), which gives you the raw state dict. You can then "pre-process" this state dict however you like.
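A minimal sketch of such a pre-processing step, assuming the raw state dict is a plain nested dict. `freeze_fields` is a hypothetical helper; `MappingProxyType` is only a shallow, standard-library stand-in, and with Flax you would presumably pass `flax.core.freeze` as the `freeze` argument instead.

```python
from types import MappingProxyType

def freeze_fields(raw_state, field_names, freeze=MappingProxyType):
    """Return a copy of raw_state with the named dict fields frozen."""
    out = dict(raw_state)
    for name in field_names:
        value = out.get(name)
        if isinstance(value, dict):
            # Fields saved as None (never initialized) are left untouched.
            out[name] = freeze(value)
    return out

# Stand-in for the raw dict that restore_checkpoint(target=None) would return.
raw = {"step": 10, "ema_params": {"w": 1.0}, "ema_model_state": None}
fixed = freeze_fields(raw, ["ema_params", "ema_model_state"])
```

The result can then be used to rebuild the train state by hand, which keeps the None handling in your own code rather than in the restore machinery.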

jheek, Nov 30 '20 09:11

@jheek I created an EmaState dataclass and got that working in a less hacky fashion. I still have a bit of an issue: allowing training to start with EMA active and then disabling it, or the other way around, seems to require custom serialization because of the way None values are handled.

So, a question about handling None in either the serialized state or the target: is the current behaviour ever correct or useful? On deserialization, if a target value is None, the type isn't determined properly and the dict of serialized state is just dumped into the target. If the state value is None, it crashes trying to iterate over the state.

Wouldn't None be more useful as a 'do not deserialize' sentinel? If a target value is None, it would not try to deserialize and restore that field. If the serialized state for that field is None, it would leave the current target unchanged. That seems a much more useful pattern, one that would allow some natural forward/backward checkpoint compatibility or toggling of active state fields between sessions.

If there is a good reason not to use None for the described functionality, would an explicit type make sense?

@struct.dataclass
class TrainState:
    blahblah: SomeOtherState = _UNUSED()  # _UNUSED: hypothetical "field not in use" sentinel type
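The None-as-sentinel semantics proposed above can be sketched as follows. This models the suggestion only, not current Flax behaviour; `restore_field` and `overwrite` are hypothetical names.

```python
def restore_field(target, state, restore):
    """Proposed semantics: None on either side means 'skip this field'."""
    if target is None:
        # Field is disabled in the current session: ignore whatever was saved.
        return None
    if state is None:
        # Nothing was saved for this field: keep the current value.
        return target
    return restore(target, state)

def overwrite(target, state):
    # Trivial stand-in for a real type-aware restore function.
    return state

disabled = restore_field(None, {"w": 1.0}, overwrite)
unchanged = restore_field({"w": 0.0}, None, overwrite)
restored = restore_field({"w": 0.0}, {"w": 1.0}, overwrite)
```

An explicit sentinel type like the `_UNUSED()` idea would follow the same shape, with the two `is None` checks replaced by `isinstance` checks against that type.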

rwightman, Dec 01 '20 06:12