IndexError: `InlinedVector::at(size_type) const` failed bounds check
Please describe the bug
IndexError: `InlinedVector::at(size_type) const` failed bounds check
System information and environment
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04, docker):
- Python version: 3.8.10
- CUDA version: 11.3
- NCCL version: 2.9
- cupy version: 11.3
- GPU model and memory: 2 × A100 (80 GB)
- Alpa version: 0.2.3
- TensorFlow version: 2.8.0
- JAX version: 0.3.22
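For reference, the versions above were collected with a short snippet along these lines (the pip package names, e.g. cupy-cuda113, are assumptions about my environment rather than anything required by Alpa):

from importlib import metadata  # Python 3.8+

# Print the installed versions of the packages listed above (hypothetical helper).
for pkg in ["alpa", "jax", "jaxlib", "flax", "tensorflow", "ray", "cupy-cuda113"]:
    try:
        print(f"{pkg}: {metadata.version(pkg)}")
    except metadata.PackageNotFoundError:
        print(f"{pkg}: not installed")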
To Reproduce
Steps to reproduce the behavior:
1. Training a LLaMA model implemented in Flax produces the following error.
2. See the error below.
2023-09-24 12:29:49,782 INFO worker.py:1342 -- Connecting to existing Ray cluster at address: 10.233.115.148:6379...
2023-09-24 12:29:49,795 INFO worker.py:1528 -- Connected to Ray cluster.
Training/epoch 0: 0%| | 0/7473 [00:01<?, ?it/s]
Traceback (most recent call last):
File "./Trainer/train_ray_batch.py", line 149, in InlinedVector::at(size_type) const
failed bounds check
The stack trace below excludes JAX-internal frames. The preceding is the original exception that occurred, unmodified.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "./Trainer/train_ray_batch.py", line 149, in InlinedVector::at(size_type) const
failed bounds check
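If it helps, I can re-run with JAX's traceback filtering disabled so the JAX/Alpa-internal frames hidden above are shown; a sketch, assuming the jax_traceback_filtering option is available in jax 0.3.22:

import jax

# Show JAX/Alpa-internal frames instead of the filtered stack above.
jax.config.update("jax_traceback_filtering", "off")

# Alternatively, via the environment when launching the script:
#   JAX_TRACEBACK_FILTERING=off python ./Trainer/train_ray_batch.py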
Screenshots
Code snippet to reproduce the problem

# Imports, the LLaMA model/tokenizer wrappers, and the GSM dataset helpers are
# defined elsewhere in the project and are omitted in this report.

@alpa.parallelize(batch_argnums=(1, 2, 3, 4))
def train_step(state, seq, seq_mask, labels, labels_mask):
    def train_forward(params):
        # seq, seq_mask, labels, labels_mask = data_batch
        position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(seq).shape[-1]), seq.shape)
        outputs = state.apply_fn(
            params,
            seq,
            seq_mask,
            position_ids,
            deterministic=False,
            return_dict=False,
        )
        logits = outputs[0]
        loss = cross_entropy_loss(logits, labels, mask=labels_mask)
        return loss

    dynamic_scale = state.dynamic_scale
    if dynamic_scale:
        grad_fn = dynamic_scale.value_and_grad(train_forward)
        dynamic_scale, is_fin, loss, grads = grad_fn(state.params)

    new_state = state.apply_gradients(grads=grads)

    if dynamic_scale:
        # Keep the update only where all gradients are finite; otherwise fall back
        # to the previous optimizer state, params, and master copy.
        new_state = new_state.replace(
            opt_state=jax.tree_map(
                functools.partial(jnp.where, is_fin),
                new_state.opt_state, state.opt_state),
            params=jax.tree_map(
                functools.partial(jnp.where, is_fin),
                new_state.params, state.params),
            master_copy=jax.tree_map(
                functools.partial(jnp.where, is_fin),
                new_state.master_copy, state.master_copy),
            dynamic_scale=dynamic_scale)
    return new_state, loss


def main() -> None:
    global llama_model
    alpa.init(cluster="ray")
    lr = 0.001
    batch_size = 1
    max_len = 640
    n_epochs = 7
    load_pretrained_model = False
    ckpt_dir = "./JAX_model/7B"

    # prepare dataset
    tokenizer = LLaMATokenizer("./JAX_model/tokenizer.model")
    dataset = GSMDataset(split='train')
    collate_fn = partial(gsm_collate_fn_train, tokenizer=tokenizer, max_len=max_len)
    dataloader = LlamaDataLoader(dataset, batch_size, collate_fn)

    # set config
    if load_pretrained_model:
        with open(Path(ckpt_dir) / "params.json", "r") as f:
            config_params = json.loads(f.read())
        config_params.update({'vocab_size': len(tokenizer), 'max_seq_len': max_len})
        llama_config = LLaMAConfig(**config_params)
    else:
        llama_config = LLaMAConfig(num_hidden_layers=4)
    llama_model = LLaMAForCausalLMModule(llama_config)

    # init model
    input_ids = jnp.ones((batch_size, max_len))
    attention_mask = jnp.ones_like(input_ids)
    position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
    params = llama_model.init(input_ids, attention_mask, position_ids,
                              return_dict=False, init_cache=False)
    if load_pretrained_model:
        param = restore(Path(ckpt_dir) / "consolidated.nra", replace_keys=False)
        params['param'] = param

    n_steps = math.ceil(len(dataloader))
    schedule = warmup_cosine_decay_schedule(
        init_value=0.,
        peak_value=lr,
        warmup_steps=n_steps,
        decay_steps=n_steps + 1,
        end_value=lr,
    )
    optimizer = adamw(learning_rate=schedule)

    use_master_copy = True
    dynamic_scale = DynamicScale()
    alpa.global_config.flax_always_use_fp16_embedding = True
    state = TrainState.create(apply_fn=llama_model.run, params=params, tx=optimizer,
                              dynamic_scale=dynamic_scale, use_master_copy=use_master_copy)

    for epoch in range(n_epochs):
        with tqdm(dataloader) as tepoch:
            tepoch.set_description(f"Training/epoch {epoch}")
            for batch in tepoch:
                seq, seq_mask, labels, labels_mask = batch
                state, loss = train_step(state, seq, seq_mask, labels, labels_mask)


if __name__ == '__main__':
    main()
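To narrow down whether the bounds-check failure comes from Alpa's parallelization or from the model code itself, a single step can also be run without @alpa.parallelize. A minimal sketch, reusing state, cross_entropy_loss, and one (seq, seq_mask, labels, labels_mask) batch from the loop above:

# Sketch only: run one unparallelized step with plain JAX to check the model code.
import jax
import jax.numpy as jnp

def plain_loss(params, seq, seq_mask, labels, labels_mask):
    position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(seq).shape[-1]), seq.shape)
    logits = state.apply_fn(params, seq, seq_mask, position_ids,
                            deterministic=False, return_dict=False)[0]
    return cross_entropy_loss(logits, labels, mask=labels_mask)

loss, grads = jax.value_and_grad(plain_loss)(state.params, seq, seq_mask, labels, labels_mask)
print("single-step loss without Alpa:", loss)

If this step succeeds, the error is most likely triggered inside alpa.parallelize rather than by the Flax model.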
Additional information