
How to make it work on Google Cloud TPU?

aicheung opened this issue 1 year ago • 2 comments

Hi. I got some free quota for Google Cloud TPU and have been trying to run the training on it for the past two days. I did the following:

  1. Create a TPU VM with the following config: TPU type v3-8, TPU software version tpu-vm-pt-1.13
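For reference, a TPU VM with this configuration can presumably be created with something along these lines (the TPU name below is a placeholder; the zone is taken from the host name in the error output further down):
gcloud compute tpus tpu-vm create my-alpaca-tpu --zone=europe-west4-a --accelerator-type=v3-8 --version=tpu-vm-pt-1.13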
  2. Install pyenv and install Python 3.8.16
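Assuming the pyenv-virtualenv plugin is available, this step presumably looks something like the following (the environment name matches the one that appears in the traceback below):
pyenv install 3.8.16
pyenv virtualenv 3.8.16 alpaca3816
pyenv activate alpaca3816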
  3. Git clone the project (from my own fork at https://github.com/aicheung/stanford_alpaca; I removed some unused code for Py3.8 and incorporated the updated prompts from alpaca-lora)
  4. Install the requirements:
pip install -r requirements.txt
  5. Install HuggingFace transformers as per the README:
pip install git+https://github.com/huggingface/transformers.git@68d640f7c368bcaaaecfc678f11908ebbd3d6176
  6. Install PyTorch/XLA for TPU:
pip install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-2.0-cp38-cp38-linux_x86_64.whl
  7. Install Git LFS and clone the LLaMA 7B model:
curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | sudo bash
sudo apt-get install git-lfs
git clone https://huggingface.co/decapoda-research/llama-7b-hf
  8. Run training, but the following errors were produced:
$ torchrun --nproc_per_node=1 --master_port=23456 train.py  \
   --model_name_or_path ~/llama-7b-hf     --data_path ./alpaca_data.json     --bf16 True  \
   --output_dir ~/alpaca-7b     --num_train_epochs 3   \
  --per_device_train_batch_size 1     --per_device_eval_batch_size 1   \
  --gradient_accumulation_steps 8     --evaluation_strategy "no" \
    --save_strategy "steps"     --save_steps 2000     --save_total_limit 1  \
   --learning_rate 2e-5     --weight_decay 0.  \
   --warmup_ratio 0.03     --lr_scheduler_type "cosine"     --logging_steps 1 
WARNING:root:XRT configuration not detected. Defaulting to preview PJRT runtime. To silence this warning and continue using PJRT, explicitly set PJRT_DEVICE to a supported device or configure XRT. To disable default device selection, set PJRT_SELECT_DEFAULT_DEVICE=0
WARNING:root:For more information about the status of PJRT, see https://github.com/pytorch/xla/blob/master/docs/pjrt.md
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
Loading checkpoint shards: 100%|████████████████| 33/33 [00:07<00:00,  4.62it/s]
Using pad_token, but it is not set yet.
WARNING:root:Loading data...
WARNING:root:Formatting inputs...
WARNING:root:Tokenizing inputs... This may take some time...
2023-03-19 13:02:38.933699: F tensorflow/tsl/platform/statusor.cc:33] Attempting to fetch value instead of handling error RESOURCE_EXHAUSTED: Error allocating device buffer: Attempting to allocate 172.00M. That was not possible. There are 84.95M free.; (0x0x0_HBM0)
https://symbolize.stripped_domain/r/?trace=7fc53537e00b,7fc53537e08f,7fc3ea57bbff,7fc3ea880643,7fc49ce4efbe,7fc3ea878fbf,8ec834860c78347&map=04ceea301ec570e6abcf4ef3f089f0fde6516664:7fc3e7627000-7fc3fb07e5e0 
*** SIGABRT received by PID 104543 (TID 104543) on cpu 24 from PID 104543; stack trace: ***
PC: @     0x7fc53537e00b  (unknown)  raise
    @     0x7fc3e69fba1a       1152  (unknown)
    @     0x7fc53537e090  1790163424  (unknown)
    @     0x7fc3ea57bc00        400  tsl::internal_statusor::Helper::Crash()
    @     0x7fc3ea880644       1312  xla::PjRtComputationClient::TransferToServer()
    @     0x7fc49ce4efbf  (unknown)  torch_xla::TensorToXlaData()
    @     0x7fc3ea878fc0  (unknown)  (unknown)
    @  0x8ec834860c78348  (unknown)  (unknown)
https://symbolize.stripped_domain/r/?trace=7fc53537e00b,7fc3e69fba19,7fc53537e08f,7fc3ea57bbff,7fc3ea880643,7fc49ce4efbe,7fc3ea878fbf,8ec834860c78347&map=04ceea301ec570e6abcf4ef3f089f0fde6516664:7fc3e7627000-7fc3fb07e5e0,ceee8fa20ddf9c34af43f587221e91de:7fc3d9ad3000-7fc3e6c12840 
E0319 13:02:39.013511  104543 coredump_hook.cc:414] RAW: Remote crash data gathering hook invoked.
E0319 13:02:39.013527  104543 coredump_hook.cc:453] RAW: Skipping coredump since rlimit was 0 at process start.
E0319 13:02:39.013535  104543 client.cc:278] RAW: Coroner client retries enabled (b/136286901), will retry for up to 30 sec.
E0319 13:02:39.013539  104543 coredump_hook.cc:512] RAW: Sending fingerprint to remote end.
E0319 13:02:39.013548  104543 coredump_socket.cc:120] RAW: Stat failed errno=2 on socket /var/google/services/logmanagerd/remote_coredump.socket
E0319 13:02:39.013555  104543 coredump_hook.cc:518] RAW: Cannot send fingerprint to Coroner: [NOT_FOUND] Missing crash reporting socket. Is the listener running?
E0319 13:02:39.013560  104543 coredump_hook.cc:580] RAW: Dumping core locally.
E0319 13:02:39.376998  104543 process_state.cc:784] RAW: Raising signal 6 with default behavior
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: -6) local_rank: 0 (pid: 104543) of binary: /home/aicheung/.pyenv/versions/3.8.16/envs/alpaca3816/bin/python3.8
Traceback (most recent call last):
  File "/home/aicheung/.pyenv/versions/alpaca3816/bin/torchrun", line 8, in <module>
    sys.exit(main())
  File "/home/aicheung/.pyenv/versions/3.8.16/envs/alpaca3816/lib/python3.8/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 346, in wrapper
    return f(*args, **kwargs)
  File "/home/aicheung/.pyenv/versions/3.8.16/envs/alpaca3816/lib/python3.8/site-packages/torch/distributed/run.py", line 794, in main
    run(args)
  File "/home/aicheung/.pyenv/versions/3.8.16/envs/alpaca3816/lib/python3.8/site-packages/torch/distributed/run.py", line 785, in run
    elastic_launch(
  File "/home/aicheung/.pyenv/versions/3.8.16/envs/alpaca3816/lib/python3.8/site-packages/torch/distributed/launcher/api.py", line 134, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/home/aicheung/.pyenv/versions/3.8.16/envs/alpaca3816/lib/python3.8/site-packages/torch/distributed/launcher/api.py", line 250, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
============================================================
train.py FAILED
------------------------------------------------------------
Failures:
  <NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2023-03-19_13:02:42
  host      : t1v-n-00692496-w-0.europe-west4-a.c.alpaca-381010.internal
  rank      : 0 (local_rank: 0)
  exitcode  : -6 (pid: 104543)
  error_file: <N/A>
  traceback : Signal 6 (SIGABRT) received by PID 104543
============================================================

As you can see, I have already reduced the per-device batch size to 1, but it still OOMs on the TPU. Is there any way to solve this, or has anyone had success running the training on TPU? Thanks.
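For context, here is a rough back-of-the-envelope estimate (my own numbers, not from the repo) of why reducing the batch size cannot help: each TPU v3 core has 16 GiB of HBM, and the bf16 weights of a 7B-parameter model alone already fill most of it before any gradients, optimizer state, or activations are allocated, so the model would need to be sharded across cores rather than replicated on each one.
# Illustrative estimate only
params = 7e9                  # ~7B parameters
bf16_bytes_per_param = 2      # bfloat16 uses 2 bytes per parameter
hbm_per_core_gib = 16         # HBM per TPU v3 core
weights_gib = params * bf16_bytes_per_param / 2**30
print(f"bf16 weights: {weights_gib:.1f} GiB of {hbm_per_core_gib} GiB HBM per core")
# roughly 13 GiB of 16 GiB just for the weights, so gradients, optimizer state,
# and activations quickly exhaust the rest, consistent with the ~85 MB of free
# HBM reported in the RESOURCE_EXHAUSTED error above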

aicheung · Mar 19 '23 13:03

I am running into the same issue. It looks like it is defaulting to the PJRT runtime. Not sure how to make it use the XRT mode instead.
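Looking at the warning at the top of the log again, it seems the runtime can at least be selected explicitly via environment variables. Presumably something along these lines would pin it to one mode or the other (untested, and I don't know whether XRT avoids the OOM):
# keep using the PJRT runtime explicitly (this also silences the warning)
export PJRT_DEVICE=TPU
# or fall back to the XRT runtime on a single TPU VM host
unset PJRT_DEVICE
export XRT_TPU_CONFIG="localservice;0;localhost:51011"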

opooladz · Apr 04 '23 18:04

Any update?

jmikedupont2 · Feb 14 '24 17:02