trlx
MPT is not working
🐛 Describe the bug
When running the following code:
import trlx
trainer = trlx.train(
    "mosaicml/mpt-7b",
    samples=[
        ['Question: 1 + 2 Answer:', '3'],
        ['Question: Solve this equation: ∀n>0, s=2, sum(n ** -s). Answer:', '(pi ** 2)/ 6']
    ]
)
a ValueError is raised:
Traceback (most recent call last):
  File "--/rl-llm/train.py", line 14, in <module>
    trainer = trlx.train(
  File "--/miniconda3/envs/rl/lib/python3.10/site-packages/trlx/trlx.py", line 92, in train
    trainer = get_trainer(config.train.trainer)(
  File "--/miniconda3/envs/rl/lib/python3.10/site-packages/trlx/trainer/accelerate_sft_trainer.py", line 32, in __init__
    super().__init__(config, **kwargs)
  File "--/miniconda3/envs/rl/lib/python3.10/site-packages/trlx/trainer/accelerate_base_trainer.py", line 66, in __init__
    self.model = self.setup_model()
  File "--/miniconda3/envs/rl/lib/python3.10/site-packages/trlx/trainer/accelerate_base_trainer.py", line 161, in setup_model
    freeze_bottom_causal_layers(model.base_model, self.config.model.num_layers_unfrozen)
  File "--/miniconda3/envs/rl/lib/python3.10/site-packages/trlx/utils/modeling.py", line 24, in freeze_bottom_causal_layers
    hidden_layers = hf_get_decoder_blocks(model)
  File "--/miniconda3/envs/rl/lib/python3.10/site-packages/trlx/utils/modeling.py", line 148, in hf_get_decoder_blocks
    return findattr(model, hidden_layers_attrs)
  File "--/miniconda3/envs/rl/lib/python3.10/site-packages/trlx/utils/modeling.py", line 96, in findattr
    raise ValueError(f"Could not find an attribute from `{attrs}` in `{obj}`")
ValueError: Could not find an attribute from `('h', 'layers', 'model.layers', 'decoder.layers', 'transformer.h', 'transformer.blocks', 'model.decoder.layers', 'gpt_neox.layers', 'decoder.block')` in `MptModel(
  (wte): Embedding(50432, 4096)
  (blocks): ModuleList(
    (0-31): 32 x MptBlock(
      (norm_1): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
      (attn): MptAttention(
        (Wqkv): Linear(in_features=4096, out_features=12288, bias=False)
        (out_proj): Linear(in_features=4096, out_features=4096, bias=False)
      )
      (norm_2): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
      (ffn): MptMLP(
        (up_proj): Linear(in_features=4096, out_features=16384, bias=False)
        (act): GELU(approximate='none')
        (down_proj): Linear(in_features=16384, out_features=4096, bias=False)
      )
      (resid_attn_dropout): Dropout(p=0, inplace=False)
    )
  )
  (norm_f): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
)
I'm not sure what is going on, since #546 supposedly fixed it.
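One possible reading of the error, going only by the traceback: the object passed to the lookup is the bare MptModel, which exposes its layers directly as `blocks`, while the candidate list contains `transformer.blocks` but no plain `blocks`. The sketch below reproduces that mismatch without trlx or torch. `find_first_attr` and `fake_mpt_base` are hypothetical stand-ins written for illustration, not trlx's actual `findattr` or a real model, and the extra `"blocks"` entry is an assumption about what a fix might look like.

```python
from functools import reduce
from types import SimpleNamespace


def find_first_attr(obj, candidates):
    """Return the first dotted attribute path in `candidates` that resolves on `obj`.

    Hypothetical stand-in for the lookup in the traceback, not trlx's code.
    """
    for path in candidates:
        try:
            # Walk "a.b.c" one attribute at a time
            return reduce(getattr, path.split("."), obj)
        except AttributeError:
            continue
    raise ValueError(f"Could not find an attribute from `{candidates}` in `{obj}`")


# Stand-in for the MptModel printed in the error: layers sit directly under `blocks`
fake_mpt_base = SimpleNamespace(
    wte="embedding", blocks=["block0", "block1"], norm_f="layernorm"
)

# Candidate paths copied from the error message
candidates = (
    "h", "layers", "model.layers", "decoder.layers", "transformer.h",
    "transformer.blocks", "model.decoder.layers", "gpt_neox.layers", "decoder.block",
)

try:
    find_first_attr(fake_mpt_base, candidates)
    lookup_failed = False
except ValueError:
    lookup_failed = True

# Assumption: adding a plain "blocks" path would let the lookup succeed
found = find_first_attr(fake_mpt_base, candidates + ("blocks",))
```

If this reading is right, the lookup fails on the stand-in for the same reason it fails on `model.base_model`, and it resolves once a plain `blocks` path is among the candidates. Whether that is the appropriate change in trlx itself is for the maintainers to confirm.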
I installed trlx with
pip install -U git+https://github.com/CarperAI/trlx.git
and
git clone https://github.com/CarperAI/trlx.git
cd trlx
pip install torch --extra-index-url https://download.pytorch.org/whl/cu118
pip install -e .
It fails with both installation methods.
Which trlX version are you using?
0.7.0
Additional system and package information
Linux