ring-flash-attention
Bug when using zigzag_ring_flash_attn: RuntimeError: Number of requests do not match number of collectives
Hi, I hit the error above when using EasyContext's zigzag_ring_flash_attn mode. All of my data has been grouped by length to 32768 + 1 tokens (following https://github.com/jzhang38/EasyContext/issues/31#issue-2308064466).
Training runs fine in data-parallel mode, but it fails with this error under sequence parallelism.
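For what it's worth, my understanding is that zigzag sharding splits each sequence into 2 * world_size chunks, so the length actually fed to the model should divide evenly by that. Below is a tiny sanity check I used to reason about my setup; `check_zigzag_shardable` and the example `world_size` are mine, not part of EasyContext or ring-flash-attention:

```python
# Illustrative check (not from EasyContext): zigzag sharding assumes the
# training sequence length splits evenly into 2 * world_size chunks.
def check_zigzag_shardable(seq_length: int, world_size: int) -> None:
    chunks = 2 * world_size
    if seq_length % chunks != 0:
        raise ValueError(
            f"seq_length={seq_length} is not divisible by 2 * world_size={chunks}"
        )


# My samples are grouped to 32768 + 1 tokens and then sliced to 32768 for training.
check_zigzag_shardable(32768, world_size=8)  # example world size; 32768 % 16 == 0
```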
Code:

```python
def main(args):
    if args.output_dir:
        os.makedirs(args.output_dir, exist_ok=True)

    if args.wandb:
        import wandb

        wandb.login()

    set_seed(args.seed)

    timeout = InitProcessGroupKwargs(timeout=timedelta(seconds=1_000_000))
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulate_every,
        mixed_precision="bf16",
        log_with="wandb" if args.wandb else None,
        kwargs_handlers=[timeout],
        # fsdp_plugin=fsdp_plugin,
    )
    accelerator.init_trackers(
        project_name=args.wandb,
        init_kwargs={"wandb": {"name": args.output_dir.split("/")[-1]}},
    )
    accelerator.print(f"Total GPUS: {accelerator.num_processes}")

    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        device_map=accelerator.device,
        torch_dtype=torch.bfloat16,
        rope_theta=args.rope_theta,
        _attn_implementation="flash_attention_2",
    )
    # tokenizer = AutoTokenizer.from_pretrained(
    #     args.model,
    #     trust_remote_code=True,
    #     # llama does not support the fast tokenizer
    # )

    try:
        train_dataset = load_dataset(args.dataset)
    except:
        train_dataset = load_from_disk(args.dataset)
    if isinstance(train_dataset, DatasetDict):
        train_dataset = train_dataset["train"]
    # train_dataset = QwenSFTDataset(args.dataset, tokenizer, args)

    assert isinstance(
        model, (transformers.LlamaForCausalLM, transformers.MistralForCausalLM)
    ), "Only support llama and mistral model"
    model_type = (
        "llama" if isinstance(model, transformers.LlamaForCausalLM) else "mistral"
    )
    apply_seq_parallel_monkey_patch(args.parallel_mode, model_type)

    if "input_ids" not in train_dataset.column_names:
        raise RuntimeError("Dataset must include an `input_ids` feature")
    # remove everything that is not input_ids
    to_remove = [col for col in train_dataset.column_names if col != "input_ids"]
    train_dataset = train_dataset.remove_columns(to_remove)
    train_dataset = train_dataset.shuffle(seed=args.seed)
    print("Dataset Size:", len(train_dataset))
    train_loader = DataLoader(
        train_dataset,
        collate_fn=default_data_collator,
        shuffle=True,
        batch_size=args.batch_size,
    )

    if args.learning_rate != 2e-5:
        accelerator.print(
            "Warning: You also need to modify accelerate_configs/zero3_offload.json to change the learning rate"
        )
    optim = DummyOptim(model.parameters(), lr=args.learning_rate)
    scheduler = DummyScheduler(
        optim,
        num_training_steps=args.max_train_steps,
        total_num_steps=args.max_train_steps,
    )
    model, optim, scheduler = accelerator.prepare(model, optim, scheduler)
    train_loader = prepare_dataloader(args.parallel_mode, train_loader, accelerator)
    model.gradient_checkpointing_enable()
    accelerator.register_for_checkpointing(scheduler)
    accelerator.print(f"Max train steps: {args.max_train_steps}")
    progress_bar = tqdm(
        range(args.max_train_steps), disable=not accelerator.is_local_main_process
    )
    completed_steps = 0

    model.train()
    loss_func = CrossEntropyLoss(inplace_backward=True)
    for step, batch in enumerate(train_loader):
        input_ids = batch["input_ids"][..., : args.seq_length + 1][..., :-1]
        target_ids = batch["input_ids"][..., : args.seq_length + 1][..., 1:]
        position_ids = (
            torch.arange(args.seq_length).unsqueeze(0).expand(input_ids.shape[0], -1)
        )
        # shard the input_ids according to the world size and rank according to zig zag attention
        # print(input_ids.shape, position_ids.shape)  # these values must be equal
        prepared = prepare_seq_parallel_inputs(
            args.parallel_mode,
            input_ids,
            position_ids,
            target_ids,
            accelerator.process_index,
            accelerator.num_processes,
            accelerator.device,
        )
        local_input_ids = prepared["local_input_ids"]
        local_position_ids = prepared["local_position_ids"]
        local_target_ids = prepared["local_target_ids"]

        loss_log = None
        with accelerator.accumulate(model):
            logits = model(
                local_input_ids,
                position_ids=local_position_ids,
            ).logits
            loss = loss_func(
                logits.reshape(-1, logits.shape[-1]), local_target_ids.reshape(-1)
            )
            accelerator.backward(loss)

            if accelerator.sync_gradients:
                # Pay attention here: when any seq parallel algo is turned on,
                # logging `loss` directly would only log the very first chunk's loss,
                # and what the first chunk is depends on how the sequence is sharded.
                # For zig zag attention, the first chunk contains the leftmost and rightmost tokens,
                # so you cannot compare the (logged) loss of dist attention and zigzag ring attention.
                # loss_log = {"loss": loss.item(), "ppl": math.exp(loss.item())}

                # We now gather the loss to verify that ring attention and dist flash attention
                # produce the same loss; this may slow down training.
                gathered_loss = accelerator.reduce(loss.clone().detach(), "mean")
                loss_log = {
                    "loss": gathered_loss.item(),
                    "ppl": math.exp(gathered_loss.item()),
                }
                accelerator.log(loss_log, step=completed_steps)

            optim.step()
            scheduler.step()
            optim.zero_grad()

        if accelerator.sync_gradients:
            progress_bar.update(1)
            if loss_log is not None:
                progress_bar.set_postfix(loss_log)
            completed_steps += 1

        if completed_steps >= args.max_train_steps:
            break

    accelerator.print(f"Training Finished")
    accelerator.end_training()

    if args.output_dir is not None:
        accelerator.print(f"Saving model to {args.output_dir}")
        accelerator.wait_for_everyone()
        state_dict = accelerator.get_state_dict(model)
        accelerator.unwrap_model(model).save_pretrained(
            f"{args.output_dir}",
            is_main_process=accelerator.is_main_process,
            save_function=accelerator.save,
            state_dict=state_dict,
        )
        accelerator.print(f"Saving Finished")
```