NEKO
Error running `eval.py` on checkpoint created using `--pretrained_lm=gpt2`.
Here are the steps to reproduce:
First, to show that the failure is specific to the `--pretrained_lm` argument, run the train/eval pair below once without `--pretrained_lm=gpt2` in the training arguments, then run the pair again with `--pretrained_lm=gpt2` added to the training arguments.
- Train
python -m pdb train.py \
--training_steps=12 \
--log_eval_freq=4 \
--warmup_steps=1 \
--batch_size=4 \
--eval_episodes=1 \
--activation_fn=gelu \
--save_model \
--save_mode=checkpoint \
--text_prop=1.0 \
--eval_text_log_examples \
--text_datasets=wikitext-2-v1 \
--text_datasets_paths=wikitext \
--disable_cosine_decay
- Evaluate
python -m pdb eval.py \
--model_path=./models/neko-gato-620082/checkpoint_12.pt \
--text_datasets=wikitext-2-v1 \
--text_datasets_paths=wikitext \
--eval_episodes=1
When you run the above with `--pretrained_lm=gpt2` in the training arguments, you get the following error message:
RuntimeError: Error(s) in loading state_dict for GatoPolicy:
Unexpected key(s) in state_dict: "transformer.h.8.ln_1.weight", "transformer.h.8.ln_1.bias", "transformer.h.8.attn.bias", "transformer.h.8.attn.masked_bias", "transformer.h.8.attn.c_attn.weight", "transformer.h.8.attn.c_attn.bias", "transformer.h.8.attn.c_proj.weight", "transformer.h.8.attn.c_proj.bias", "transformer.h.8.ln_2.weight", "transformer.h.8.ln_2.bias", "transformer.h.8.mlp.c_fc.weight", "transformer.h.8.mlp.c_fc.bias", "transformer.h.8.mlp.c_proj.weight", "transformer.h.8.mlp.c_proj.bias", "transformer.h.9.ln_1.weight", "transformer.h.9.ln_1.bias", "transformer.h.9.attn.bias", "transformer.h.9.attn.masked_bias", "transformer.h.9.attn.c_attn.weight", "transformer.h.9.attn.c_attn.bias", "transformer.h.9.attn.c_proj.weight", "transformer.h.9.attn.c_proj.bias", "transformer.h.9.ln_2.weight", "transformer.h.9.ln_2.bias", "transformer.h.9.mlp.c_fc.weight", "transformer.h.9.mlp.c_fc.bias", "transformer.h.9.mlp.c_proj.weight", "transformer.h.9.mlp.c_proj.bias", "transformer.h.10.ln_1.weight", "transformer.h.10.ln_1.bias", "transformer.h.10.attn.bias", "transformer.h.10.attn.masked_bias", "transformer.h.10.attn.c_attn.weight", "transformer.h.10.attn.c_attn.bias", "transformer.h.10.attn.c_proj.weight", "transformer.h.10.attn.c_proj.bias", "transformer.h.10.ln_2.weight", "transformer.h.10.ln_2.bias", "transformer.h.10.mlp.c_fc.weight", "transformer.h.10.mlp.c_fc.bias", "transformer.h.10.mlp.c_proj.weight", "transformer.h.10.mlp.c_proj.bias", "transformer.h.11.ln_1.weight", "transformer.h.11.ln_1.bias", "transformer.h.11.attn.bias", "transformer.h.11.attn.masked_bias", "transformer.h.11.attn.c_attn.weight", "transformer.h.11.attn.c_attn.bias", "transformer.h.11.attn.c_proj.weight", "transformer.h.11.attn.c_proj.bias", "transformer.h.11.ln_2.weight", "transformer.h.11.ln_2.bias", "transformer.h.11.mlp.c_fc.weight", "transformer.h.11.mlp.c_fc.bias", "transformer.h.11.mlp.c_proj.weight", "transformer.h.11.mlp.c_proj.bias".
size mismatch for transformer.wte.weight: copying a param with shape torch.Size([50257, 768]) from checkpoint, the shape in current model is torch.Size([1, 768]).
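For context, the unexpected `transformer.h.8` through `transformer.h.11` keys and the `[50257, 768]` vs `[1, 768]` `wte` mismatch suggest the checkpoint holds the full 12-layer GPT-2 backbone while the eval-time model was built with a smaller from-scratch transformer. A quick way to confirm what the checkpoint actually contains (a hypothetical diagnostic, not part of the repro; it assumes `checkpoint_12.pt` is, or directly contains, the state dict that `eval.py` passes to `load_state_dict`):

```python
import torch

# Hypothetical diagnostic: peek at the saved weights to confirm they are GPT-2-sized.
# Adjust the unpacking if eval.py extracts the state dict from a larger checkpoint dict.
sd = torch.load('./models/neko-gato-620082/checkpoint_12.pt', map_location='cpu')

blocks = {k.split('.')[2] for k in sd if k.startswith('transformer.h.')}
print('transformer blocks in checkpoint:', len(blocks))                       # 12 for a gpt2 backbone
print('wte shape in checkpoint:', tuple(sd['transformer.wte.weight'].shape))  # (50257, 768) for gpt2
```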
Doing some archaeology, we find that at one point `pretrained_lm` was removed from the training args before evaluation:
commit 1640ce0d97f9801695c1b2241ad6c29608e5f1e9
Author: Daniel Lawson <[email protected]>
Date: Wed Jun 28 12:39:12 2023 -0400
added init
---
eval.py | 2 ++
1 file changed, 2 insertions(+)
diff --git a/eval.py b/eval.py
index f463c4e..8eaa8c8 100644
--- a/eval.py
+++ b/eval.py
@@ -24,6 +24,8 @@ def main(args):
     args_path = args.args_path
     training_args = json.load(open(args_path, 'r'))
+    if 'pretrained_lm' in training_args:
+        del training_args['pretrained_lm']
     # update args with eval_args
     for k, v in args.items():
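Whenever the eval-time `GatoPolicy` ends up constructed without the GPT-2 backbone that training used, you get exactly the shape mismatch above. A minimal illustration of that failure mode using Hugging Face `transformers` directly (this only mirrors the error; it is not NEKO's actual construction code, and the 8-layer / `vocab_size=1` config is inferred from the shapes in the traceback):

```python
# Sketch, not NEKO code: loading 12-layer GPT-2 weights into a smaller
# from-scratch model fails the same way (unexpected h.8-h.11 keys, wte size mismatch).
from transformers import GPT2Config, GPT2Model

pretrained = GPT2Model.from_pretrained('gpt2')                 # 12 blocks, wte [50257, 768]
from_scratch = GPT2Model(GPT2Config(vocab_size=1, n_layer=8))  # shapes implied by the error above

try:
    from_scratch.load_state_dict(pretrained.state_dict())      # strict=True by default
except RuntimeError as e:
    print(e)  # Unexpected key(s): "h.8. ..."; size mismatch for wte.weight
```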
That change was later modified so that `pretrained_lm` is only cleared (set to None) when `--lora` was not passed; it is kept when `--lora` was used.
commit ec4a486afc069f02c572c18cf73199223e0f1a8a
Author: Daniel Lawson <[email protected]>
Date: Mon Jul 3 01:47:41 2023 -0400
fixed eval afer lora
---
eval.py | 14 ++++++++++++--
1 file changed, 12 insertions(+), 2 deletions(-)
diff --git a/eval.py b/eval.py
index 9be6c83..0c827eb 100644
--- a/eval.py
+++ b/eval.py
@@ -6,6 +6,8 @@ import time
 import numpy as np
 import torch
+from peft import LoraConfig, TaskType, get_peft_model
+
 from gato.utils.utils import DotDict
 from gato.policy.gato_policy import GatoPolicy
 from gato.envs.setup_env import load_envs
@@ -24,8 +26,8 @@ def main(args):
     args_path = args.args_path
     training_args = json.load(open(args_path, 'r'))
-    if 'pretrained_lm' in training_args:
-        del training_args['pretrained_lm']
+    if not ('lora' in training_args and training_args['lora']):
+        training_args['pretrained_lm'] = None
     # update args with eval_args
     for k, v in args.items():
@@ -72,7 +74,15 @@ def main(args):
         use_patch_pos_encoding=not eval_args.disable_patch_pos_encoding,
         use_pos_encoding=not eval_args.disable_inner_pos_encoding,
         activation_fn=eval_args.activation_fn,
+        pretrained_lm=eval_args.pretrained_lm,
+        flash=eval_args.flash
     )
+
+    if eval_args.get('lora', False):
+        assert eval_args.pretrained_lm is not None, 'Must specify pretrained LM for LORA'
+        peft_config = LoraConfig(task_type=TaskType.CAUSAL_LM, inference_mode=False, r=eval_args.lora_r, lora_alpha=eval_args.lora_alpha, lora_dropout=eval_args.lora_dropout)
+        model.transformer = get_peft_model(model.transformer, peft_config)
+
     model.load_state_dict(gato_checkpoint)
     model = model.to(eval_args.device)
     model.device = eval_args.device
The logic explicitly raises if `lora` and not `pretrained_lm`.
The logic implicitly fails if `pretrained_lm` and not `lora`.
I'm not sure if that's purposeful or accidental. I don't know much about LoRA and how it interacts with the model. I'm guessing it's fine to run `pretrained_lm` without `lora`.
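To make the cases concrete, here is how I read the current `eval.py` behavior, written out as an untested sketch (a paraphrase of the diff above, not a proposed patch; the function name is mine):

```python
# Untested paraphrase of eval.py's effective decision after ec4a486.
# `training_args` is the dict loaded from the training run's args JSON.
def effective_pretrained_lm(training_args):
    lora = bool(training_args.get('lora'))
    pretrained_lm = training_args.get('pretrained_lm')  # e.g. 'gpt2' or None

    if lora:
        # explicit failure path: eval.py asserts a pretrained LM exists before LoRA-wrapping
        assert pretrained_lm is not None, 'Must specify pretrained LM for LORA'
        return pretrained_lm

    # implicit failure path: for non-LoRA runs pretrained_lm is discarded, so a
    # checkpoint trained with --pretrained_lm=gpt2 no longer matches the
    # eval-time model and load_state_dict() raises the error above.
    return None
```

If running `pretrained_lm` without `lora` is in fact supported, returning `pretrained_lm` in the non-LoRA branch too (i.e. never nulling it) should let the eval-time model match the checkpoint, but I haven't tested that.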
Just logging this right now as research/investigation notes to pick back up later.