
Model.py simplifications

Open Andrei-Aksionov opened this issue 1 year ago • 2 comments

Accidentally messed up the PR and the branch, so let's try one more time.

I really don't like making such big PRs, but I don't want to bombard you with small ones either.

Here are some changes (code simplifications) for the model.py file:

  1. We don't need a separate function for the GELU activation, since it's already implemented in PyTorch. To get the same behaviour we just need to provide approximate='tanh' (a short sketch is at the end of this comment).
  2. To have a LayerNorm with an additional bias argument, we can simply inherit from nn.LayerNorm and disable the bias in the __init__ method if needed. In my opinion it looks cleaner, plus it preserves the functionality of PyTorch's implementation (also sketched below).
  3. An additional head_size argument for the CausalSelfAttention class, in case we don't want head_size = n_embd // n_head. If it's not provided (the default), the previous behavior is kept. It adds one more parameter to finetune and the shapes become more self-explanatory (see the sketch at the end). The intuition came from this article. Note: right now it's not possible to provide head_size as a CLI argument. The expected logic is to have this value (as a global in train.py) set to None by default and then override it, but the type checks in configurator.py don't allow it (NoneType will not match the int type). Perhaps I'll make a PR for that later.
  4. The 'bias' buffer in CausalSelfAttention can be created directly with 4 dimensions. Alternatively, it can be created with 2 dimensions and we can just leverage broadcasting in the forward method (illustrated at the end of this comment).
  5. configure_optimizers method simplifications. If we traverse only the parameters and retrieve each module from the parameter's name, we don't need those comments that might confuse people. By the way, why do we need two separate lists (whitelist and blacklist)? Since the goal is to have every parameter in either decay or no_decay, the code can be further simplified:
from functools import reduce  # needs to be imported at the top of model.py

decay, no_decay = set(), set()
for pn, _ in self.named_parameters():
    # get the parent module by the parameter's name
    module = reduce(lambda module, key: getattr(module, key), pn.split(".")[:-1], self)
    if pn.endswith('weight') and isinstance(module, nn.Linear):
        decay.add(pn)
    else:
        no_decay.add(pn)

and then we don't need to check the sets' intersection and union.
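
And to round off point 5: the two sets then plug straight into the optimizer groups, pretty much as the tail of configure_optimizers already does (a sketch, assuming the current weight_decay / learning_rate / betas arguments):

param_dict = {pn: p for pn, p in self.named_parameters()}
optim_groups = [
    # parameters that get weight decay (Linear weights)
    {"params": [param_dict[pn] for pn in sorted(decay)], "weight_decay": weight_decay},
    # everything else: biases, LayerNorm and embedding weights
    {"params": [param_dict[pn] for pn in sorted(no_decay)], "weight_decay": 0.0},
]
optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas)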
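
Here is also a minimal standalone sketch of point 1 (assuming the hand-written GELU in model.py follows the usual GPT-2 tanh formula; variable names are just for illustration):

import math
import torch
import torch.nn as nn

def new_gelu(x):
    # the hand-written tanh approximation of GELU, as used in the GPT-2 codebase
    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))

x = torch.randn(2, 4)
gelu = nn.GELU(approximate='tanh')
# the built-in module with the tanh approximation gives the same values
assert torch.allclose(new_gelu(x), gelu(x), atol=1e-6)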
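
For point 2, roughly what I have in mind (just a sketch, not a drop-in replacement; the real class can keep whatever extra arguments model.py needs):

import torch
import torch.nn as nn

class LayerNorm(nn.LayerNorm):
    """PyTorch's LayerNorm, but the bias can be switched off."""

    def __init__(self, ndim, bias=True):
        super().__init__(ndim)
        if not bias:
            # set the registered bias parameter to None; F.layer_norm accepts bias=None
            self.bias = None

ln = LayerNorm(8, bias=False)
y = ln(torch.randn(2, 8))  # forward pass works without the bias term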
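
For point 3, a sketch of the constructor change (I pass plain ints here instead of the config object just to keep the snippet self-contained; only the lines relevant to the proposal are shown):

import torch.nn as nn

class CausalSelfAttention(nn.Module):
    def __init__(self, n_embd, n_head, head_size=None, bias=True):
        super().__init__()
        # fall back to the current behavior when head_size is not provided
        self.head_size = head_size if head_size is not None else n_embd // n_head
        self.n_head = n_head
        # key, query, value projections for all heads in a single linear layer
        self.c_attn = nn.Linear(n_embd, 3 * n_head * self.head_size, bias=bias)
        # projection from the concatenated heads back to the embedding size
        self.c_proj = nn.Linear(n_head * self.head_size, n_embd, bias=bias)

attn = CausalSelfAttention(n_embd=768, n_head=12)                      # head_size = 64, as before
attn_wide = CausalSelfAttention(n_embd=768, n_head=12, head_size=96)   # decoupled head size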
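
For point 4, a small standalone illustration of the two options for the causal mask (block_size and the tensor shapes are made up for the example):

import torch

block_size = 8
T = 4  # current sequence length

# option A: create the mask directly with 4 dimensions, as model.py does now
mask_4d = torch.tril(torch.ones(block_size, block_size)).view(1, 1, block_size, block_size)

# option B: keep it 2-dimensional and let broadcasting add the batch and head dims
mask_2d = torch.tril(torch.ones(block_size, block_size))

att = torch.randn(2, 3, T, T)  # (batch, n_head, T, T) attention scores
out_a = att.masked_fill(mask_4d[:, :, :T, :T] == 0, float('-inf'))
out_b = att.masked_fill(mask_2d[:T, :T] == 0, float('-inf'))  # broadcasts over the first two dims
assert torch.equal(out_a, out_b)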

Andrei-Aksionov · Mar 10 '23 14:03