
Fix attn mask ignore logic in training-time trace

Open · zhenglongjiepheonix opened this pull request 6 months ago · 2 comments

This PR fixes a scenario where we want to use a Dynamo trace in training mode. The current attention mask ignore logic creates a problem: the data-dependent branch condition torch.all(attn_mask == 1) causes graph breaks and disables full-graph tracing. The fix is to disable the mask ignore logic whenever we are in tracing mode, regardless of whether we are in the training or inference phase.
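
For illustration, here is a minimal sketch of the kind of guard involved; the helper name and the exact check are assumptions for this example, not the actual transformers internals:

import torch

def should_ignore_attention_mask(attn_mask: torch.Tensor) -> bool:
    # Hypothetical helper; the real logic lives in transformers' attention
    # mask utilities and may differ in detail.
    is_tracing = torch.jit.is_tracing() or torch._dynamo.is_compiling()
    if is_tracing:
        # Skip the data-dependent branch entirely while tracing/compiling:
        # evaluating torch.all(attn_mask == 1) on a runtime tensor would
        # force a graph break and defeat fullgraph=True.
        return False
    # Eager mode: it is safe to inspect the mask values directly.
    return bool(torch.all(attn_mask == 1))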

This enables compilation for training (forward + backward), like this:

import torch
from transformers import LlamaForCausalLM

# config is a LlamaConfig; inputs is a tokenized batch that includes labels
model = LlamaForCausalLM(config).cuda()
model = torch.compile(model, fullgraph=True)
loss = model(**inputs)[0]  # first output element is the loss when labels are passed
loss.backward()
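
Since fullgraph=True makes torch.compile raise an error instead of silently falling back to eager when a graph break occurs, running this snippet also serves as a quick check that the mask logic no longer breaks the trace.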

zhenglongjiepheonix · Aug 11, 2024