Separate out the biases

Open rasbt opened this issue 1 year ago • 6 comments

This separates the single bias config into 3 separate bias configs: QKV bias, attention projection bias, and MLP bias. This would be necessary to implement Grok, for example, which uses a QKV bias but no MLP bias.
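As a rough illustration of the split, here is a minimal sketch. The class name `BiasConfig` and the `False` defaults are assumptions made for this example; the three field names are the ones used later in this thread, and the real Config has many more fields.

```python
from dataclasses import dataclass


@dataclass
class BiasConfig:
    # Hypothetical stand-in for the real Config; only the bias flags are shown.
    attn_qkv_bias: bool = False   # bias on the query/key/value projection
    attn_proj_bias: bool = False  # bias on the attention output projection
    mlp_bias: bool = False        # bias on the MLP linear layers


# A model with a QKV bias but no MLP bias, which a single shared flag cannot express.
qkv_only = BiasConfig(attn_qkv_bias=True)
```

With a single shared flag, turning on the QKV bias would also turn on the MLP bias, so this combination could not be represented.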

rasbt avatar Mar 18 '24 17:03 rasbt

We might want to revive https://github.com/Lightning-AI/litgpt/pull/878 if we are doing this. What do you prefer?

carmocca avatar Mar 18 '24 17:03 carmocca

I had no idea. Yeah, then let's revive @Andrei-Aksionov's #878

rasbt avatar Mar 18 '24 17:03 rasbt

Only one test fails:

tests/test_config_hub.py::test_config_help[litgpt/pretrain.py-https://raw.githubusercontent.com/Lightning-AI/litgpt/main/config_hub/pretrain/tinystories.yaml]

The reason is that in the main branch, the yaml file contains the old bias notation. Once the PR is merged, this failure should go away on its own.

Andrei-Aksionov avatar Mar 21 '24 20:03 Andrei-Aksionov

> The reason is that in the main branch, the yaml file contains the old bias notation.

This also signals a breaking change. Can you add backwards-compatibility code to Config as we had in the past for other arguments?

carmocca avatar Mar 25 '24 03:03 carmocca

> This also signals a breaking change. Can you add backwards-compatibility code to Config as we had in the past for other arguments?

Sure. But previously we handled this in the .from_* methods, where legacy args were passed via **kwargs. Now the issue comes from jsonargparse, which checks the provided args against the fields listed in the Config class. So one solution might be to add a legacy bias field and, in __post_init__, copy its value into all biases except lm_head. But I don't like it. Maybe there is a better solution.
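For reference, a minimal sketch of that legacy-field idea. This is a trimmed-down, hypothetical Config: the `lm_head_bias` name and all default values are assumptions for illustration, not the actual litgpt definitions.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Config:
    # Hypothetical, trimmed-down Config; the real class has many more fields.
    bias: Optional[bool] = None  # legacy flag, accepted only so old configs still parse
    attn_qkv_bias: bool = True
    attn_proj_bias: bool = True
    mlp_bias: bool = True
    lm_head_bias: bool = False  # assumed name; not affected by the legacy flag

    def __post_init__(self) -> None:
        if self.bias is not None:
            # Copy the legacy value into every bias except lm_head.
            self.attn_qkv_bias = self.bias
            self.attn_proj_bias = self.bias
            self.mlp_bias = self.bias


old_style = Config(bias=False)  # what an old yaml file would produce
new_style = Config(attn_qkv_bias=True, mlp_bias=False)
```

Because bias stays a declared field, a field-by-field check of the provided args would still accept old configs. The downside is that the deprecated name remains part of the public Config.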

I also tried dealing with it in the __init__ method. A quick example would look like this:

    def __init__(self, **kwargs: Any) -> None:
        names = {f.name for f in fields(self)}
        for arg_name in list(kwargs):
            if arg_name in names:
                setattr(self, arg_name, kwargs.pop(arg_name))

        # deal with legacy args
        # ...
        if "bias" in kwargs:
            bias = kwargs.pop("bias")
            self.attn_qkv_bias = bias
            self.attn_proj_bias = bias
            self.mlp_bias = bias

        if kwargs != {}:
            raise ValueError(f"Non empty kwargs: {kwargs}")

        self.__post_init__()

But it throws

Validation failed: No action for key "name" to check its value.

I'll investigate it tomorrow, but maybe you know a better way?

Andrei-Aksionov avatar Mar 26 '24 18:03 Andrei-Aksionov

Argh, good point. This needs to be solved at the CLI level, but I'm not sure of the best way to do it. Opened https://github.com/omni-us/jsonargparse/issues/479 to ask.

carmocca avatar Mar 26 '24 19:03 carmocca