Separate out the biases
This splits the single bias config into 3 separate bias configs: QKV bias, attention projection bias, and MLP bias. This is needed to implement Grok, for example, which uses a QKV bias but no MLP bias.
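A minimal sketch of what the split flags could look like on a dataclass config. This is a hypothetical stand-in, not litgpt's actual Config. The three bias field names are taken from the snippet later in this thread, while `lm_head_bias` is an assumed name for the LM head bias that is mentioned further down. Only the bias flags are shown.

```python
from dataclasses import dataclass


@dataclass
class Config:
    # Hypothetical minimal stand-in for a model config; only the bias flags are shown.
    attn_qkv_bias: bool = False   # bias on the fused query/key/value projection
    attn_proj_bias: bool = False  # bias on the attention output projection
    mlp_bias: bool = False        # bias on the MLP linear layers
    lm_head_bias: bool = False    # assumed name: bias on the final LM head


# Grok-style setup as described above: QKV bias on, MLP bias off.
# attn_proj_bias is left at its default here because the thread does not specify it.
grok_like = Config(attn_qkv_bias=True, mlp_bias=False)
```

With a single `bias` flag, a combination like this (QKV bias without MLP bias) cannot be expressed at all.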
We might want to revive https://github.com/Lightning-AI/litgpt/pull/878 if we are doing this. What do you prefer?
I had no idea. Yeah, then let's revive @Andrei-Aksionov's #878.
Only one test fails:
tests/test_config_hub.py::test_config_help[litgpt/pretrain.py-https://raw.githubusercontent.com/Lightning-AI/litgpt/main/config_hub/pretrain/tinystories.yaml]
The reason is that in the main branch, the yaml file contains the old bias notation. After the PR is merged, this failure should go away automagically.
This also signals a breaking change. Can you add backwards-compatibility code to Config as we had in the past for other arguments?
Sure. But previously we handled this in the .from_* methods, where legacy args were passed in via **kwargs.
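Roughly, that earlier pattern translates a legacy key before the dataclass constructor ever sees it. The sketch below assumes a hypothetical `from_kwargs` classmethod (not a real litgpt method) and the split bias fields from this PR. Explicit new-style values take precedence over the legacy `bias` flag, which is a design choice of this sketch, not something decided in the thread.

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class Config:
    # Hypothetical minimal config with the split bias flags.
    attn_qkv_bias: bool = False
    attn_proj_bias: bool = False
    mlp_bias: bool = False

    @classmethod
    def from_kwargs(cls, **kwargs: Any) -> "Config":
        # Hypothetical constructor: translate the legacy `bias` key into the
        # new fields before calling the generated dataclass __init__.
        if "bias" in kwargs:
            bias = kwargs.pop("bias")
            for name in ("attn_qkv_bias", "attn_proj_bias", "mlp_bias"):
                # setdefault keeps any explicitly passed new-style value.
                kwargs.setdefault(name, bias)
        return cls(**kwargs)
```

For example, `Config.from_kwargs(bias=True)` enables all three biases, while `Config.from_kwargs(bias=True, mlp_bias=False)` keeps the MLP bias off. This only helps callers that go through the classmethod; a CLI parser that validates keys against the dataclass fields never reaches it.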
Now the issue comes from jsonargparse, which checks the provided args against the fields declared in the Config class.
So one solution might be to keep a legacy bias field and, in __post_init__, propagate its value to all biases except lm_head.
But I don't like it. Maybe there is a better solution.
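For reference, the legacy-field idea could look roughly like this. It is a sketch under assumptions: the config is a plain dataclass, `lm_head_bias` is an assumed field name, and `None` is used as a sentinel meaning "legacy flag not given". Here the legacy value overrides the new fields whenever it is set. One plausible drawback is that `bias` stays visible as a public field on the config.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Config:
    # Hypothetical minimal config with the split bias flags.
    attn_qkv_bias: bool = False
    attn_proj_bias: bool = False
    mlp_bias: bool = False
    lm_head_bias: bool = False  # assumed name for the LM head bias
    # Legacy field kept only for backwards compatibility; None means "not given".
    bias: Optional[bool] = None

    def __post_init__(self) -> None:
        if self.bias is not None:
            # Propagate the legacy value to every bias except the LM head.
            self.attn_qkv_bias = self.bias
            self.attn_proj_bias = self.bias
            self.mlp_bias = self.bias
```

Because `bias` is a declared field, a parser that checks keys against the dataclass fields would accept an old yaml file containing `bias: true`.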
I also tried dealing with it in the __init__ method. A quick example looks like this:
```python
from dataclasses import fields
from typing import Any


def __init__(self, **kwargs: Any) -> None:
    # Assign every kwarg that matches a declared dataclass field.
    names = {f.name for f in fields(self)}
    for arg_name in list(kwargs):
        if arg_name in names:
            setattr(self, arg_name, kwargs.pop(arg_name))
    # deal with legacy args
    # ...
    if "bias" in kwargs:
        bias = kwargs.pop("bias")
        self.attn_qkv_bias = bias
        self.attn_proj_bias = bias
        self.mlp_bias = bias
    # Anything left over is an unknown argument.
    if kwargs:
        raise ValueError(f"Non empty kwargs: {kwargs}")
    self.__post_init__()
```
But it fails with:
Validation failed: No action for key "name" to check its value.
I'll investigate it tomorrow, but maybe you know a better way?
Argh, good point. This needs to be solved at the CLI level, but I'm not sure of the best way to do it. Opened https://github.com/omni-us/jsonargparse/issues/479 to ask.