Cannot train Seeker with batch size > 1
Bug description
It seems that the Seeker training command does not support batch size > 1. I ran into an FSDP error when training Seeker-400M with -bs 2.
Reproduction steps
python -m parlai.scripts.multiprocessing_train \
--task projects.seeker.tasks.knowledge,projects.seeker.tasks.dialogue,projects.seeker.tasks.search_query \
--multitask-weights 2,2,1 -bs 2 -vstep 1000 -vmt ppl -vp 5 -vmm min -vme 100000 -lstep 50 \
--init-opt arch/r2c2_base_400M --init-model zoo:seeker/r2c2_base_400M/model \
--model projects.seeker.agents.seeker:ComboFidGoldDocumentAgent --n-docs 5 \
--text-truncate 1000 --label-truncate 128 --truncate 1000 \
--fp16 True -lr 1e-06 --lr-scheduler reduceonplateau --optimizer adamw --save-after-valid True \
--warmup-updates 100 --update-freq 1 --gradient-clip 1.0 --skip-generation True --dropout 0.1 \
--attention-dropout 0.0 --load-from-checkpoint true --ddp-backend zero2 \
--checkpoint-activations true --model-file /tmp/my_seeker_dialogue_model
Expected behavior
I was hoping that training would succeed.
Logs
Command line output:
Asserting FSDP instance is: FullyShardedDataParallel(
world_size=8, flatten_parameters=True, mixed_precision=True,
(_fsdp_wrapped_module): FlattenParamsWrapper(
(_fpw_module): TransformerEncoderLayer_Swappable(
(attention): MultiHeadAttention(
(attn_dropout): Dropout(p=0.0, inplace=False)
(q_lin): Linear(in_features=1024, out_features=1024, bias=True)
(k_lin): Linear(in_features=1024, out_features=1024, bias=True)
(v_lin): Linear(in_features=1024, out_features=1024, bias=True)
(out_lin): Linear(in_features=1024, out_features=1024, bias=True)
)
(norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(ffn): TransformerFFN(
(relu_dropout): Dropout(p=0, inplace=False)
(lin1): Linear(in_features=1024, out_features=4096, bias=True)
(lin2): Linear(in_features=4096, out_features=1024, bias=True)
)
(norm2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
)
ERROR: expected to be in states [<TrainingState.BACKWARD_POST: 4>] but current state is TrainingState.BACKWARD_PRE
2022-04-30 23:01:08,795 CRITICAL | Traceback (most recent call last):
File "/data/kai/ParlAI/parlai/scripts/multiprocessing_train.py", line 45, in multiprocess_train
return single_train.TrainLoop(opt).train()
File "/data/kai/ParlAI/parlai/scripts/train_model.py", line 1000, in train
for _train_log in self.train_steps():
File "/data/kai/ParlAI/parlai/scripts/train_model.py", line 907, in train_steps
world.parley()
File "/data/kai/ParlAI/parlai/core/worlds.py", line 880, in parley
batch_act = self.batch_act(agent_idx, batch_observations[agent_idx])
File "/data/kai/ParlAI/parlai/core/worlds.py", line 848, in batch_act
batch_actions = a.batch_act(batch_observation)
File "/data/kai/ParlAI/parlai/agents/fid/fid.py", line 389, in batch_act
batch_reply = super().batch_act(observations)
File "/data/kai/ParlAI/parlai/core/torch_agent.py", line 2238, in batch_act
output = self.train_step(batch)
File "/data/kai/ParlAI/parlai/core/torch_generator_agent.py", line 736, in train_step
self.backward(loss)
File "/data/kai/ParlAI/parlai/core/torch_agent.py", line 2324, in backward
self.optimizer.backward(loss, update_main_grads=False)
File "/data/kai/ParlAI/parlai/utils/fp16.py", line 194, in backward
loss.backward()
File "/data/kai/miniconda3/envs/parlai/lib/python3.8/site-packages/torch/_tensor.py", line 363, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
File "/data/kai/miniconda3/envs/parlai/lib/python3.8/site-packages/torch/autograd/__init__.py", line 173, in backward
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
SystemError: <built-in method run_backward of torch._C._EngineBase object at 0x7f0ce3a8e9e0> returned NULL without setting an error
Additional context
Not sure if this is a bug or a feature request.
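For context on the `expected to be in states [<TrainingState.BACKWARD_POST: 4>]` assertion: it appears to come from fairscale FSDP's internal training-state machine, which expects a single pre-backward/post-backward hook cycle per wrapped module per backward pass. A minimal, self-contained sketch of the forward/backward cycle FSDP expects, using fairscale's public `FullyShardedDataParallel` API (the toy model, sizes, and rendezvous address are made up for illustration; this is not ParlAI's wrapping code):

```python
import torch
import torch.distributed as dist
import torch.nn as nn
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP


def run(rank: int, world_size: int):
    # Minimal single-node process group, for illustration only.
    dist.init_process_group(
        backend="nccl",
        init_method="tcp://localhost:29500",
        rank=rank,
        world_size=world_size,
    )
    torch.cuda.set_device(rank)

    # Toy stand-in for one wrapped transformer layer; the real runs also
    # enable fp16 mixed precision on top of this.
    model = FSDP(nn.Linear(1024, 1024).cuda(), flatten_parameters=True)
    optim = torch.optim.AdamW(model.parameters(), lr=1e-6)

    x = torch.randn(2, 1024, device="cuda")   # batch size 2
    loss = model(x).pow(2).mean()             # forward gathers the full params
    loss.backward()                           # pre-/post-backward hooks should fire once each
    optim.step()
    optim.zero_grad()
    dist.destroy_process_group()


if __name__ == "__main__":
    torch.multiprocessing.spawn(run, args=(2,), nprocs=2)
```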
Can you report your fairscale version?
The above error message was from fairscale 0.3.7.
Also tried fairscale 0.4.6 and got a similar error:
2022-05-08 14:35:33,820 INFO | training...
rank: 3 | 2022-05-08 14:35:40,808 CRITICAL | Traceback (most recent call last):
File "/data/kai/ParlAI/parlai/scripts/multiprocessing_train.py", line 45, in multiprocess_train
return single_train.TrainLoop(opt).train()
File "/data/kai/ParlAI/parlai/scripts/train_model.py", line 1000, in train
for _train_log in self.train_steps():
File "/data/kai/ParlAI/parlai/scripts/train_model.py", line 907, in train_steps
world.parley()
File "/data/kai/ParlAI/parlai/core/worlds.py", line 880, in parley
batch_act = self.batch_act(agent_idx, batch_observations[agent_idx])
File "/data/kai/ParlAI/parlai/core/worlds.py", line 848, in batch_act
batch_actions = a.batch_act(batch_observation)
File "/data/kai/ParlAI/parlai/agents/fid/fid.py", line 389, in batch_act
batch_reply = super().batch_act(observations)
File "/data/kai/ParlAI/parlai/core/torch_agent.py", line 2239, in batch_act
output = self.train_step(batch)
File "/data/kai/ParlAI/parlai/core/torch_generator_agent.py", line 736, in train_step
self.backward(loss)
File "/data/kai/ParlAI/parlai/core/torch_agent.py", line 2325, in backward
self.optimizer.backward(loss, update_main_grads=False)
File "/data/kai/ParlAI/parlai/utils/fp16.py", line 194, in backward
loss.backward()
File "/data/kai/minicond/envs/parlai/lib/python3.8/site-packages/torch/_tensor.py", line 363, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
File "/data/kai/minicond/envs/parlai/lib/python3.8/site-packages/torch/autograd/__init__.py", line 173, in backward
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
SystemError: <built-in method run_backward of torch._C._EngineBase object at 0x7f0aa7bfc8d0> returned NULL without setting an error
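For reference, the `parlai/utils/fp16.py` frame in both tracebacks is just where the fp16 optimizer wrapper scales the loss and calls `backward()`, which suggests the failure surfaces inside autograd/FSDP's backward hooks rather than in the wrapper itself. A rough sketch of that dynamic loss-scaling pattern (class and attribute names here are illustrative, not ParlAI's actual implementation):

```python
import torch


class LossScalingOptimizer:
    """Illustrative fp16 wrapper (hypothetical names, not parlai.utils.fp16):
    scale the loss before backward, unscale gradients before stepping."""

    def __init__(self, optimizer: torch.optim.Optimizer, loss_scale: float = 2.0 ** 15):
        self.optimizer = optimizer
        self.loss_scale = loss_scale

    def backward(self, loss: torch.Tensor, update_main_grads: bool = False):
        # The call site shown in the traceback: the scaled loss's backward()
        # is what walks through FSDP's pre-/post-backward hooks.
        (loss * self.loss_scale).backward()

    def step(self):
        # Undo the scaling on the accumulated gradients, then step for real.
        for group in self.optimizer.param_groups:
            for p in group["params"]:
                if p.grad is not None:
                    p.grad.div_(self.loss_scale)
        self.optimizer.step()
```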
I wonder if it's multiprocessing_train... Does that work with a transformer/generator?
No. We just tried running this (fairscale 0.4.6):
parlai multiprocessing_train -t projects.seeker.tasks.knowledge,projects.seeker.tasks.dialogue,projects.seeker.tasks.search_query --multitask-weights 2,2,1 -veps 0.25 --attention-dropout 0.0 --batchsize 32 --model transformer/generator --embedding-size 2560 --ffn-size 10240 --variant prelayernorm --n-heads 32 --n-positions 128 --n-encoder-layers 2 --n-decoder-layers 24 --history-add-global-end-token end --delimiter ' ' --dict-tokenizer bytelevelbpe --dropout 0.1 --fp16 True --init-model zoo:blender/reddit_3B/model --dict-file zoo:blender/reddit_3B/model.dict --label-truncate 128 --log_every_n_secs 30 -lr 7e-06 --lr-scheduler reduceonplateau --lr-scheduler-patience 3 --optimizer adam --relu-dropout 0.0 --activation gelu --ddp-backend zero2 --learn-positional-embeddings true --save-after-valid True --text-truncate 128 --truncate 128 --warmup_updates 100 --fp16-impl mem_efficient --update-freq 1 --gradient-clip 0.1 --skip-generation True -vp 10 -vmt ppl -vmm min --tensorboard-log true --model-file /data/kai/modelfiles/test_train_3B/test_train_27B
And got this:
rank: 5 | 11:18:15 | Traceback (most recent call last):
File "/data/kai/ParlAI/parlai/scripts/multiprocessing_train.py", line 45, in multiprocess_train
return single_train.TrainLoop(opt).train()
File "/data/kai/ParlAI/parlai/scripts/train_model.py", line 1000, in train
for _train_log in self.train_steps():
File "/data/kai/ParlAI/parlai/scripts/train_model.py", line 907, in train_steps
world.parley()
File "/data/kai/ParlAI/parlai/core/worlds.py", line 880, in parley
batch_act = self.batch_act(agent_idx, batch_observations[agent_idx])
File "/data/kai/ParlAI/parlai/core/worlds.py", line 848, in batch_act
batch_actions = a.batch_act(batch_observation)
File "/data/kai/ParlAI/parlai/core/torch_agent.py", line 2239, in batch_act
output = self.train_step(batch)
File "/data/kai/ParlAI/parlai/core/torch_generator_agent.py", line 736, in train_step
self.backward(loss)
File "/data/kai/ParlAI/parlai/core/torch_agent.py", line 2325, in backward
self.optimizer.backward(loss, update_main_grads=False)
File "/data/kai/ParlAI/parlai/utils/fp16.py", line 522, in backward
loss.backward()
File "/data/kai/minicond/envs/parlai/lib/python3.8/site-packages/torch/_tensor.py", line 363, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
File "/data/kai/minicond/envs/parlai/lib/python3.8/site-packages/torch/autograd/__init__.py", line 173, in backward
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
File "/data/kai/minicond/envs/parlai/lib/python3.8/site-packages/fairscale/nn/data_parallel/fully_sharded_data_parallel.py", line 1485, in _pre_backward_hook
self._use_full_params()
File "/data/kai/minicond/envs/parlai/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
return func(*args, **kwargs)
File "/data/kai/minicond/envs/parlai/lib/python3.8/site-packages/fairscale/nn/data_parallel/fully_sharded_data_parallel.py", line 1992, in _use_full_params
assert self.has_full_params
AssertionError
Sorry, one more thing. Can you roll back to 0.4.4?
Sorry, it turns out that transformer/generator works fine with bs > 1. We ran into the above error because we had turned off flatten_parameters (which is also strange, but I suppose that is a fairscale problem).
We still couldn't train Seeker with bs > 1 with fairscale 0.4.6. We're trying 0.4.4 now and will report back.
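Side note for anyone comparing configs: the only FSDP knob we had changed was `flatten_parameters`. A construction-only sketch of where that flag lives, assuming fairscale's `FullyShardedDataParallel` and a throwaway single-process group just so it can be instantiated:

```python
import torch
import torch.distributed as dist
import torch.nn as nn
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

# Single-process group only so FSDP can be constructed for illustration;
# the real runs use --ddp-backend zero2 across multiple GPUs.
dist.init_process_group(
    backend="gloo", init_method="tcp://localhost:29501", rank=0, world_size=1
)

device = "cuda" if torch.cuda.is_available() else "cpu"

# Default: every parameter in the wrapped module is flattened into one
# contiguous buffer that FSDP shards across ranks.
flat = FSDP(nn.Linear(1024, 1024).to(device), flatten_parameters=True)

# The setting we had toggled off when we hit `assert self.has_full_params`:
# parameters are sharded individually instead of through the flat buffer.
unflat = FSDP(nn.Linear(1024, 1024).to(device), flatten_parameters=False)
```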
So we just tried training Seeker with fairscale 0.4.4 and got the same error.
I'm able to repro on my end, so I'll try to look into it a bit more and report back here with findings.
Update 1
The model is able to train with multiprocessing_train, --batchsize 2, and only 1 exposed GPU. Bumping up to 2 GPUs, it fails.
Passing command:
CUDA_VISIBLE_DEVICES=0 CUDA_LAUNCH_BLOCKING=1 python -m parlai.scripts.multiprocessing_train --task projects.seeker.tasks.search_query --multitask-weights 2,2,1 -bs 2 -vstep 1000 -vmt ppl -vp 5 -vmm min -vme 100000 -lstep 50 --init-opt arch/r2c2_base_400M --init-model zoo:seeker/r2c2_base_400M/model --model projects.seeker.agents.seeker:ComboFidGoldDocumentAgent --n-docs 5 --text-truncate 1000 --label-truncate 128 --truncate 1000 --fp16 True -lr 1e-06 --lr-scheduler reduceonplateau --optimizer adamw --save-after-valid True --warmup-updates 100 --update-freq 1 --gradient-clip 1.0 --skip-generation True --dropout 0.1 --attention-dropout 0.0 --load-from-checkpoint true --ddp-backend zero2 --checkpoint-activations true --model-file
Failing command:
CUDA_VISIBLE_DEVICES=0,1 CUDA_LAUNCH_BLOCKING=1 python -m parlai.scripts.multiprocessing_train --task projects.seeker.tasks.search_query --multitask-weights 2,2,1 -bs 2 -vstep 1000 -vmt ppl -vp 5 -vmm min -vme 100000 -lstep 50 --init-opt arch/r2c2_base_400M --init-model zoo:seeker/r2c2_base_400M/model --model projects.seeker.agents.seeker:ComboFidGoldDocumentAgent --n-docs 5 --text-truncate 1000 --label-truncate 128 --truncate 1000 --fp16 True -lr 1e-06 --lr-scheduler reduceonplateau --optimizer adamw --save-after-valid True --warmup-updates 100 --update-freq 1 --gradient-clip 1.0 --skip-generation True --dropout 0.1 --attention-dropout 0.0 --load-from-checkpoint true --ddp-backend zero2 --checkpoint-activations true --model-file
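The only difference between the two commands is the number of visible devices. Since multiprocessing_train launches one worker per visible GPU, the passing run has a world size of 1 (no cross-rank sharding) while the failing run has a world size of 2. A bare-bones sketch of that launch pattern (illustrative only, not ParlAI's launcher code):

```python
import torch
import torch.multiprocessing as mp


def worker(rank: int, world_size: int):
    # Each worker would init a process group of size `world_size` and wrap the
    # model in FSDP. With world_size == 1 there is no cross-rank sharding or
    # reduce-scatter, which may be why the single-GPU run gets through backward.
    print(f"starting rank {rank} of {world_size}")


if __name__ == "__main__":
    # One worker per visible GPU: CUDA_VISIBLE_DEVICES=0 -> 1 worker, =0,1 -> 2 workers.
    world_size = torch.cuda.device_count()
    mp.spawn(worker, args=(world_size,), nprocs=world_size)
```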
Update 2
This fails with the standard gold-document FiD agent as well:
CUDA_VISIBLE_DEVICES=0,1 CUDA_LAUNCH_BLOCKING=1 python -m parlai.scripts.multiprocessing_train --task projects.seeker.tasks.knowledge:WoiKnowledgeTeacher --multitask-weights 2,2,1 -bs 2 -vstep 1000 -vmt ppl -vp 5 -vmm min -vme 100000 -lstep 50 --init-opt arch/r2c2_base_400M --init-model zoo:seeker/r2c2_base_400M/model --model parlai.agents.fid.fid:WizIntGoldDocRetrieverFiDAgent --n-docs 5 --text-truncate 1000 --label-truncate 128 --truncate 1000 --fp16 True -lr 1e-06 --lr-scheduler reduceonplateau --optimizer adamw --save-after-valid True --warmup-updates 100 --update-freq 1 --gradient-clip 1.0 --skip-generation True --dropout 0.1 --attention-dropout 0.0 --load-from-checkpoint true --ddp-backend zero2 --checkpoint-activations true --model-file
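For anyone digging further: FiD-style agents expand each example into `--n-docs` encoder inputs and let the decoder attend over the concatenated encoder states, so the encoder-side batch dimension is `batchsize * n_docs` rather than `batchsize`. A rough sketch of that reshaping (a hypothetical helper for illustration, not the ParlAI FiD implementation):

```python
import torch
import torch.nn as nn


def fid_encode(encoder: nn.Module, doc_tokens: torch.Tensor) -> torch.Tensor:
    """Illustrative FiD-style encoding (hypothetical helper, not ParlAI code).

    doc_tokens: [batch, n_docs, seq_len] token ids, one row of docs per example.
    Returns:    [batch, n_docs * seq_len, dim] concatenated encoder states that
                a decoder would cross-attend over.
    """
    bsz, n_docs, seq_len = doc_tokens.shape
    flat = doc_tokens.view(bsz * n_docs, seq_len)    # encode all docs in one pass
    enc_out = encoder(flat)                          # [bsz * n_docs, seq_len, dim]
    return enc_out.view(bsz, n_docs * seq_len, enc_out.size(-1))


# Toy usage: batch of 2 examples, 5 docs each (matching --n-docs 5), seq_len 12.
toy_encoder = nn.Embedding(100, 16)                  # stands in for a real encoder
states = fid_encode(toy_encoder, torch.randint(0, 100, (2, 5, 12)))
assert states.shape == (2, 5 * 12, 16)
```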
Should we try w/ slurm to rule out it being multiprocessing?
> Should we try w/ slurm to rule out it being multiprocessing?

Tried this; it still fails. Something is hanging somewhere...