Fix attn mask ignore logic in training-time trace
This PR fixes a scenario where we want to use dynamo tracing in training mode. The current attn mask ignore logic introduces a data-dependent branch condition, torch.all(attn_mask==1), which causes graph breaks and disables full-graph tracing. The fix is to disable the mask ignore logic whenever we are in tracing mode, regardless of whether we are in the training or inference phase.
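
A minimal sketch of the idea (not the actual diff; the helper name and signature are hypothetical): the all-ones shortcut only runs when nothing is tracing the model, so a compiled graph never hits the data-dependent branch.

```python
import torch

def _maybe_drop_all_ones_mask(attention_mask):
    # Hypothetical helper for illustration: keep the mask whenever jit/dynamo
    # is tracing, because branching on torch.all(attention_mask == 1) there
    # forces a graph break.
    is_tracing = torch.jit.is_tracing() or torch._dynamo.is_compiling()
    if not is_tracing and torch.all(attention_mask == 1):
        # Eager mode only: an all-ones mask carries no information, so it can
        # be dropped to let SDPA take its unmasked fast path.
        return None
    return attention_mask
```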
This enables compilation for training (forward + backward) like this:
```python
import torch
from transformers import LlamaForCausalLM

# `config` and `inputs` (including labels) are assumed to be defined elsewhere
model = LlamaForCausalLM(config).cuda()
model = torch.compile(model, fullgraph=True)  # fullgraph=True: error out instead of falling back on graph breaks
loss = model(**inputs)[0]                     # forward pass; output[0] is the loss when labels are provided
loss.backward()                               # backward pass runs through the compiled graph as well
```