How to make it work on Google Cloud TPU?
Hi. I got some free quota for Google Cloud TPU and have spent the past two days trying to run the training on it. I did the following:
- Create a TPU VM with the following config: TPU type: v3-8, TPU software version: tpu-vm-pt-1.13
- Install pyenv and Python 3.8.16
- Git clone the project (from my own fork at https://github.com/aicheung/stanford_alpaca; I removed some unused code so it runs on Python 3.8 and incorporated the updated prompts from alpaca-lora)
- Install the requirements:
pip install -r requirements.txt
- Install Hugging Face Transformers as per the README:
pip install git+https://github.com/huggingface/transformers.git@68d640f7c368bcaaaecfc678f11908ebbd3d6176
- Install pytorch-xla for TPU:
pip install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-2.0-cp38-cp38-linux_x86_64.whl
- Install Git LFS and clone the LLaMA 7B model:
curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | sudo bash
sudo apt-get install git-lfs
git clone https://huggingface.co/decapoda-research/llama-7b-hf
- Run training, but the following errors were produced:
$ torchrun --nproc_per_node=1 --master_port=23456 train.py \
--model_name_or_path ~/llama-7b-hf --data_path ./alpaca_data.json --bf16 True \
--output_dir ~/alpaca-7b --num_train_epochs 3 \
--per_device_train_batch_size 1 --per_device_eval_batch_size 1 \
--gradient_accumulation_steps 8 --evaluation_strategy "no" \
--save_strategy "steps" --save_steps 2000 --save_total_limit 1 \
--learning_rate 2e-5 --weight_decay 0. \
--warmup_ratio 0.03 --lr_scheduler_type "cosine" --logging_steps 1
WARNING:root:XRT configuration not detected. Defaulting to preview PJRT runtime. To silence this warning and continue using PJRT, explicitly set PJRT_DEVICE to a supported device or configure XRT. To disable default device selection, set PJRT_SELECT_DEFAULT_DEVICE=0
WARNING:root:For more information about the status of PJRT, see https://github.com/pytorch/xla/blob/master/docs/pjrt.md
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
Loading checkpoint shards: 100%|████████████████| 33/33 [00:07<00:00, 4.62it/s]
Using pad_token, but it is not set yet.
WARNING:root:Loading data...
WARNING:root:Formatting inputs...
WARNING:root:Tokenizing inputs... This may take some time...
2023-03-19 13:02:38.933699: F tensorflow/tsl/platform/statusor.cc:33] Attempting to fetch value instead of handling error RESOURCE_EXHAUSTED: Error allocating device buffer: Attempting to allocate 172.00M. That was not possible. There are 84.95M free.; (0x0x0_HBM0)
https://symbolize.stripped_domain/r/?trace=7fc53537e00b,7fc53537e08f,7fc3ea57bbff,7fc3ea880643,7fc49ce4efbe,7fc3ea878fbf,8ec834860c78347&map=04ceea301ec570e6abcf4ef3f089f0fde6516664:7fc3e7627000-7fc3fb07e5e0
*** SIGABRT received by PID 104543 (TID 104543) on cpu 24 from PID 104543; stack trace: ***
PC: @ 0x7fc53537e00b (unknown) raise
@ 0x7fc3e69fba1a 1152 (unknown)
@ 0x7fc53537e090 1790163424 (unknown)
@ 0x7fc3ea57bc00 400 tsl::internal_statusor::Helper::Crash()
@ 0x7fc3ea880644 1312 xla::PjRtComputationClient::TransferToServer()
@ 0x7fc49ce4efbf (unknown) torch_xla::TensorToXlaData()
@ 0x7fc3ea878fc0 (unknown) (unknown)
@ 0x8ec834860c78348 (unknown) (unknown)
https://symbolize.stripped_domain/r/?trace=7fc53537e00b,7fc3e69fba19,7fc53537e08f,7fc3ea57bbff,7fc3ea880643,7fc49ce4efbe,7fc3ea878fbf,8ec834860c78347&map=04ceea301ec570e6abcf4ef3f089f0fde6516664:7fc3e7627000-7fc3fb07e5e0,ceee8fa20ddf9c34af43f587221e91de:7fc3d9ad3000-7fc3e6c12840
E0319 13:02:39.013511 104543 coredump_hook.cc:414] RAW: Remote crash data gathering hook invoked.
E0319 13:02:39.013527 104543 coredump_hook.cc:453] RAW: Skipping coredump since rlimit was 0 at process start.
E0319 13:02:39.013535 104543 client.cc:278] RAW: Coroner client retries enabled (b/136286901), will retry for up to 30 sec.
E0319 13:02:39.013539 104543 coredump_hook.cc:512] RAW: Sending fingerprint to remote end.
E0319 13:02:39.013548 104543 coredump_socket.cc:120] RAW: Stat failed errno=2 on socket /var/google/services/logmanagerd/remote_coredump.socket
E0319 13:02:39.013555 104543 coredump_hook.cc:518] RAW: Cannot send fingerprint to Coroner: [NOT_FOUND] Missing crash reporting socket. Is the listener running?
E0319 13:02:39.013560 104543 coredump_hook.cc:580] RAW: Dumping core locally.
E0319 13:02:39.376998 104543 process_state.cc:784] RAW: Raising signal 6 with default behavior
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: -6) local_rank: 0 (pid: 104543) of binary: /home/aicheung/.pyenv/versions/3.8.16/envs/alpaca3816/bin/python3.8
Traceback (most recent call last):
File "/home/aicheung/.pyenv/versions/alpaca3816/bin/torchrun", line 8, in <module>
sys.exit(main())
File "/home/aicheung/.pyenv/versions/3.8.16/envs/alpaca3816/lib/python3.8/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 346, in wrapper
return f(*args, **kwargs)
File "/home/aicheung/.pyenv/versions/3.8.16/envs/alpaca3816/lib/python3.8/site-packages/torch/distributed/run.py", line 794, in main
run(args)
File "/home/aicheung/.pyenv/versions/3.8.16/envs/alpaca3816/lib/python3.8/site-packages/torch/distributed/run.py", line 785, in run
elastic_launch(
File "/home/aicheung/.pyenv/versions/3.8.16/envs/alpaca3816/lib/python3.8/site-packages/torch/distributed/launcher/api.py", line 134, in __call__
return launch_agent(self._config, self._entrypoint, list(args))
File "/home/aicheung/.pyenv/versions/3.8.16/envs/alpaca3816/lib/python3.8/site-packages/torch/distributed/launcher/api.py", line 250, in launch_agent
raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
============================================================
train.py FAILED
------------------------------------------------------------
Failures:
<NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
time : 2023-03-19_13:02:42
host : t1v-n-00692496-w-0.europe-west4-a.c.alpaca-381010.internal
rank : 0 (local_rank: 0)
exitcode : -6 (pid: 104543)
error_file: <N/A>
traceback : Signal 6 (SIGABRT) received by PID 104543
============================================================
As you can see, I have already reduced the batch size to 1, but it still runs out of memory (OOM) on the TPU. Is there any way to solve this, or has anyone had success running the training on a TPU? Thanks.
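For anyone debugging the same setup, here is a minimal sanity check (my own suggestion, not part of the steps above) that confirms the torch_xla wheel actually sees the TPU and reports how much HBM a single core has free. The `xm.get_memory_info` call and its kilobyte units are my assumption based on the torch_xla API of this release line.

```bash
# Sanity check (suggestion, not from the original report): confirm torch_xla
# sees the TPU and print free/total HBM on the core this process owns.
# xm.get_memory_info reports values in kilobytes; treat the exact API as an assumption.
python -c "import torch_xla.core.xla_model as xm; d = xm.xla_device(); print(d, xm.get_memory_info(d))"
```

Rough arithmetic (my assumption, not from the log): 7B parameters in bf16 are roughly 13 GiB of weights, while a single v3 core has about 16 GiB of HBM, so the weights plus AdamW optimizer states cannot fit on one core without some form of sharding. That would explain why reducing the batch size to 1 does not help.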
I am running into the same issue. It looks like it defaults to PJRT mode. Not sure how to make it use XRT mode instead.
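The warning at the top of the log says to either set PJRT_DEVICE explicitly or configure XRT. Below is a hedged sketch of both options; the XRT_TPU_CONFIG value is the single-host TPU VM setting from the torch_xla 1.13 documentation and should be treated as an assumption.

```bash
# Option 1 (assumption: single-host TPU VM XRT config from the torch_xla 1.13
# docs): configure XRT so the preview PJRT runtime is not selected.
export XRT_TPU_CONFIG="localservice;0;localhost:51011"

# Option 2: stay on PJRT but set the device explicitly to silence the warning.
# export PJRT_DEVICE=TPU

# Then relaunch training with the same torchrun arguments as above.
```

Note that switching runtimes alone may not fix the crash: the RESOURCE_EXHAUSTED error suggests the model simply does not fit in one core's HBM, so sharding (or a smaller model) is likely needed either way.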
Any update?