
Error running `eval.py` on checkpoint created using `--pretrained_lm=gpt2`.


Here are the steps to reproduce:

First, to show that the failure is specifically related to the --pretrained_lm argument, run the train/eval pair below once without --pretrained_lm=gpt2 in the training arguments, then run the pair again with --pretrained_lm=gpt2 added to the training command.

  1. Train
python -m pdb train.py \
    --training_steps=12 \
    --log_eval_freq=4 \
    --warmup_steps=1 \
    --batch_size=4 \
    --eval_episodes=1 \
    --activation_fn=gelu \
    --save_model \
    --save_mode=checkpoint \
    --text_prop=1.0 \
    --eval_text_log_examples \
    --text_datasets=wikitext-2-v1 \
    --text_datasets_paths=wikitext \
    --disable_cosine_decay
  2. Evaluate
python -m pdb eval.py \
    --model_path=./models/neko-gato-620082/checkpoint_12.pt \
    --text_datasets=wikitext-2-v1 \
    --text_datasets_paths=wikitext \
    --eval_episodes=1

When you run the pair with --pretrained_lm=gpt2 added to the training arguments, eval.py fails with the following error:

RuntimeError: Error(s) in loading state_dict for GatoPolicy:
        Unexpected key(s) in state_dict: "transformer.h.8.ln_1.weight", "transformer.h.8.ln_1.bias", "transformer.h.8.attn.bias", "transformer.h.8.attn.masked_bias", "transformer.h.8.attn.c_attn.weight", "transformer.h.8.attn.c_attn.bias", "transformer.h.8.attn.c_proj.weight", "transformer.h.8.attn.c_proj.bias", "transformer.h.8.ln_2.weight", "transformer.h.8.ln_2.bias", "transformer.h.8.mlp.c_fc.weight", "transformer.h.8.mlp.c_fc.bias", "transformer.h.8.mlp.c_proj.weight", "transformer.h.8.mlp.c_proj.bias", "transformer.h.9.ln_1.weight", "transformer.h.9.ln_1.bias", "transformer.h.9.attn.bias", "transformer.h.9.attn.masked_bias", "transformer.h.9.attn.c_attn.weight", "transformer.h.9.attn.c_attn.bias", "transformer.h.9.attn.c_proj.weight", "transformer.h.9.attn.c_proj.bias", "transformer.h.9.ln_2.weight", "transformer.h.9.ln_2.bias", "transformer.h.9.mlp.c_fc.weight", "transformer.h.9.mlp.c_fc.bias", "transformer.h.9.mlp.c_proj.weight", "transformer.h.9.mlp.c_proj.bias", "transformer.h.10.ln_1.weight", "transformer.h.10.ln_1.bias", "transformer.h.10.attn.bias", "transformer.h.10.attn.masked_bias", "transformer.h.10.attn.c_attn.weight", "transformer.h.10.attn.c_attn.bias", "transformer.h.10.attn.c_proj.weight", "transformer.h.10.attn.c_proj.bias", "transformer.h.10.ln_2.weight", "transformer.h.10.ln_2.bias", "transformer.h.10.mlp.c_fc.weight", "transformer.h.10.mlp.c_fc.bias", "transformer.h.10.mlp.c_proj.weight", "transformer.h.10.mlp.c_proj.bias", "transformer.h.11.ln_1.weight", "transformer.h.11.ln_1.bias", "transformer.h.11.attn.bias", "transformer.h.11.attn.masked_bias", "transformer.h.11.attn.c_attn.weight", "transformer.h.11.attn.c_attn.bias", "transformer.h.11.attn.c_proj.weight", "transformer.h.11.attn.c_proj.bias", "transformer.h.11.ln_2.weight", "transformer.h.11.ln_2.bias", "transformer.h.11.mlp.c_fc.weight", "transformer.h.11.mlp.c_fc.bias", "transformer.h.11.mlp.c_proj.weight", "transformer.h.11.mlp.c_proj.bias".
        size mismatch for transformer.wte.weight: copying a param with shape torch.Size([50257, 768]) from checkpoint, the shape in current model is torch.Size([1, 768]).
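
The shapes in that message line up with a plain GPT-2 architecture mismatch: the checkpoint holds all 12 GPT-2 blocks and the full 50257-token embedding, while the model rebuilt by eval.py has only blocks 0-7 and a 1-row embedding. The following standalone sketch (my own illustration using Hugging Face transformers, not NEKO's GatoPolicy, so the key names lack the transformer. prefix) reproduces the same class of error:

# Standalone illustration of the same failure mode (not NEKO code): load a
# full GPT-2 state_dict into a smaller, freshly initialized model.
from transformers import GPT2Config, GPT2Model

full = GPT2Model.from_pretrained("gpt2")                 # 12 blocks, wte: [50257, 768]
small = GPT2Model(GPT2Config(n_layer=8, vocab_size=1))   # 8 blocks,  wte: [1, 768]

try:
    small.load_state_dict(full.state_dict())
except RuntimeError as e:
    # Unexpected key(s) for h.8 ... h.11 plus a size mismatch for wte.weight,
    # analogous to the GatoPolicy error above.
    print(e)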

Doing some archaeology, we find that at one point pretrained_lm was removed from the saved training args before evaluation:

commit 1640ce0d97f9801695c1b2241ad6c29608e5f1e9
Author: Daniel Lawson <[email protected]>
Date:   Wed Jun 28 12:39:12 2023 -0400

    added init
---
 eval.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/eval.py b/eval.py
index f463c4e..8eaa8c8 100644
--- a/eval.py
+++ b/eval.py
@@ -24,6 +24,8 @@ def main(args):
         args_path = args.args_path

     training_args = json.load(open(args_path, 'r'))
+    if 'pretrained_lm' in training_args:
+        del training_args['pretrained_lm']

     # update args with eval_args
     for k, v in args.items():

That check was later modified so that pretrained_lm is cleared (set to None) only when --lora was not used during training; in other words, the pretrained LM is kept only for LoRA checkpoints:

commit ec4a486afc069f02c572c18cf73199223e0f1a8a
Author: Daniel Lawson <[email protected]>
Date:   Mon Jul 3 01:47:41 2023 -0400

    fixed eval afer lora
---
 eval.py | 14 ++++++++++++--
 1 file changed, 12 insertions(+), 2 deletions(-)

diff --git a/eval.py b/eval.py
index 9be6c83..0c827eb 100644
--- a/eval.py
+++ b/eval.py
@@ -6,6 +6,8 @@ import time
 import numpy as np
 import torch

+from peft import LoraConfig, TaskType, get_peft_model
+
 from gato.utils.utils import DotDict
 from gato.policy.gato_policy import GatoPolicy
 from gato.envs.setup_env import load_envs
@@ -24,8 +26,8 @@ def main(args):
         args_path = args.args_path

     training_args = json.load(open(args_path, 'r'))
-    if 'pretrained_lm' in training_args:
-        del training_args['pretrained_lm']
+    if not ('lora' in training_args and training_args['lora']):
+        training_args['pretrained_lm'] = None

     # update args with eval_args
     for k, v in args.items():
@@ -72,7 +74,15 @@ def main(args):
         use_patch_pos_encoding=not eval_args.disable_patch_pos_encoding,
         use_pos_encoding=not eval_args.disable_inner_pos_encoding,
         activation_fn=eval_args.activation_fn,
+        pretrained_lm=eval_args.pretrained_lm,
+        flash=eval_args.flash
     )
+
+    if eval_args.get('lora', False):
+        assert eval_args.pretrained_lm is not None, 'Must specify pretrained LM for LORA'
+        peft_config = LoraConfig(task_type=TaskType.CAUSAL_LM, inference_mode=False, r=eval_args.lora_r, lora_alpha=eval_args.lora_alpha, lora_dropout=eval_args.lora_dropout)
+        model.transformer = get_peft_model(model.transformer, peft_config)
+
     model.load_state_dict(gato_checkpoint)
     model = model.to(eval_args.device)
     model.device = eval_args.device

The logic explicitly raises if lora is set and pretrained_lm is not.

The logic implicitly fails if pretrained_lm is set and lora is not: training_args['pretrained_lm'] is reset to None, so the GatoPolicy rebuilt for evaluation has a smaller transformer (blocks 0-7 and a 1-row token embedding) than the 12-block, full-vocabulary GPT-2 whose weights are stored in the checkpoint, and load_state_dict raises the error above.

I'm not sure whether that's purposeful or accidental. I don't know much about LoRA and how it interacts with the model, but I'm guessing it's fine to run pretrained_lm without LoRA.
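
If evaluating a non-LoRA checkpoint trained with --pretrained_lm is supposed to work, one possible direction (an untested sketch with a hypothetical helper name, based on the eval.py snippets quoted above) is to stop clearing pretrained_lm based on the LoRA flag and keep whatever the checkpoint was trained with, so the rebuilt GatoPolicy matches the saved state_dict:

import json

def load_training_args(args_path):
    # Hypothetical helper (untested sketch). Current eval.py does:
    #     if not ('lora' in training_args and training_args['lora']):
    #         training_args['pretrained_lm'] = None
    # which discards the GPT-2 architecture for non-LoRA runs and triggers
    # the load_state_dict error above. Keeping the trained value lets
    # GatoPolicy(pretrained_lm=eval_args.pretrained_lm, ...) rebuild the
    # same 12-block GPT-2 transformer the checkpoint contains.
    with open(args_path, 'r') as f:
        training_args = json.load(f)
    training_args.setdefault('pretrained_lm', None)  # only default if the key is absent
    return training_args

This would leave the LoRA branch untouched, since that path already keeps pretrained_lm and wraps model.transformer with get_peft_model before load_state_dict.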

Just logging this right now as research/investigation notes to pick back up later.
