jetstream-pytorch
Error Running `run_ray_serve_interleave` with Llama3 8B
I'm receiving an error when attempting to run the following command on a single-host v4 TPU with a 2x2x1 topology:

```bash
ray job submit -- python run_ray_serve_interleave.py \
  --tpu_chips=4 \
  --num_hosts=1 \
  --size=8B \
  --model_name=llama-3 \
  --batch_size=8 \
  --max_cache_length=2048 \
  --tokenizer_path=$tokenizer_path \
  --checkpoint_path=$output_ckpt_dir \
  --quantize_weights=True \
  --quantize_type="int8_per_channel" \
  --quantize_kv_cache=True \
  --sharding_config="default_shardings/llama.yaml"
```

The error is:
```
ray.exceptions.ActorDiedError: The actor died because of an error raised in its creation task, ray::PyTorchRayWorker.__init__() (pid=5137, ip=10.168.0.16, actor_id=243ec964a2f41eae1707d84404000000, repr=<jetstream_pt.ray_worker.PyTorchRayWorker object at 0x7abe4074add0>)
  File "/home/ray/jetstream-pytorch/jetstream_pt/ray_worker.py", line 200, in __init__
    pt_model = model_exportable.Transformer(args, env)
  File "/home/ray/jetstream-pytorch/jetstream_pt/third_party/llama/model_exportable.py", line 192, in __init__
    self.tok_embeddings = Embedding(
  File "/home/ray/jetstream-pytorch/jetstream_pt/layers.py", line 57, in __init__
    table = torch.ones(
  File "/home/ray/anaconda3/lib/python3.10/site-packages/torch/_prims_common/wrappers.py", line 252, in _fn
    result = fn(*args, **kwargs)
  File "/home/ray/anaconda3/lib/python3.10/site-packages/torch/_refs/__init__.py", line 4774, in ones
    size = utils.extract_shape_from_varargs(size)
  File "/home/ray/anaconda3/lib/python3.10/site-packages/torch/_prims_common/__init__.py", line 854, in extract_shape_from_varargs
    validate_shape(shape)  # type: ignore[arg-type]
  File "/home/ray/anaconda3/lib/python3.10/site-packages/torch/_prims_common/__init__.py", line 588, in validate_shape
    validate_dim_length(l)
```
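The traceback is cut off before the actual error message, but `validate_dim_length` rejects any dimension that is negative or not an integer. As a hypothetical illustration (not the actual jetstream-pytorch code path), a bad size reaching `torch.ones` fails the same validation:

```python
import torch

# Hypothetical illustration: a negative or non-integer size passed to
# torch.ones fails shape validation, which is what the traceback shows
# happening inside Embedding.__init__ in layers.py.
torch.ones(-1, 4096)  # RuntimeError: negative dimension
```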
This seems to be related to the logic in `ray_worker.py` that creates the `pt_model`:
```python
env_data.model_type = "llama-2-" + param_size
env_data.num_layers = args.n_layers
env = JetEngineEnvironment(env_data)
pt_model = model_exportable.Transformer(args, env)
```
Should the `model_type` for Llama models be hardcoded to `llama-2`? With `--size=8B` this produces `"llama-2-8B"`, which is not a size Llama 2 ships in (7B/13B/70B), so the environment presumably ends up with an invalid model config, and hence the invalid embedding shape above.
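A minimal sketch of a possible fix, assuming the model family name from `--model_name` (e.g. `llama-3`) is available in scope at that point in `ray_worker.py`; the exact variable name here is hypothetical:

```python
# Hypothetical sketch, not the actual jetstream-pytorch patch: derive
# model_type from the requested model family instead of hardcoding
# "llama-2-". Assumes `model_name` holds the --model_name flag value
# (e.g. "llama-3") and `param_size` the --size value (e.g. "8B").
env_data.model_type = f"{model_name}-{param_size}"  # e.g. "llama-3-8B"
env_data.num_layers = args.n_layers
env = JetEngineEnvironment(env_data)
pt_model = model_exportable.Transformer(args, env)
```

With `--model_name=llama-3 --size=8B` this would yield `llama-3-8B`, matching the flags passed on the command line above.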