Error occurs when setting gradient_checkpointing = True during stage 2
File "./Moore-AnimateAnyone/src/models/mutual_self_attention.py", line 154, in
In stage 2, when I turn on gradient_checkpointing = True, the above error appears. If I turn it off, the error disappears. Why is this?
Hi, did you ever solve this problem?
The problem is with positional arguments:
Try replacing this code in src/models/transformer_3d.py (lines 157-168):
hidden_states = torch.utils.checkpoint.checkpoint(
    create_custom_forward(block),
    hidden_states,
    encoder_hidden_states=encoder_hidden_states,
    timestep=timestep,
    attention_mask=None,
    video_length=video_length,
    self_attention_additional_feats=self_attention_additional_feats,
    mode=mode,
)
with this one:
def wrapper_fn(block, *args):
    return block(
        args[0],
        encoder_hidden_states=args[1],
        timestep=args[2],
        attention_mask=args[3],
        video_length=args[4],
        self_attention_additional_feats=args[5],
        mode=args[6],
    )
hidden_states = torch.utils.checkpoint.checkpoint(
    wrapper_fn,
    block,
    hidden_states,
    encoder_hidden_states,
    timestep,
    None,
    video_length,
    self_attention_additional_feats,
    mode,
)
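To illustrate why this works: the reentrant implementation of torch.utils.checkpoint.checkpoint only tracks tensors passed positionally, so keyword arguments cause the error above. Below is a minimal, self-contained sketch of the same wrapper pattern with a toy Block class (the Block, its layers, and the tensor shapes here are invented for illustration, not the actual transformer_3d.py code):

```python
# Minimal sketch of the positional-wrapper fix for gradient checkpointing.
# Block, proj, and the shapes are hypothetical stand-ins for the real module.
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class Block(nn.Module):
    """Toy block whose forward takes a keyword argument, like the real one."""

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

    def forward(self, hidden_states, encoder_hidden_states=None):
        out = self.proj(hidden_states)
        if encoder_hidden_states is not None:
            out = out + encoder_hidden_states
        return out


def wrapper_fn(block, *args):
    # Re-expand the purely positional checkpoint args back into
    # the block's keyword signature.
    return block(args[0], encoder_hidden_states=args[1])


block = Block()
x = torch.randn(2, 4, requires_grad=True)
ctx = torch.randn(2, 4)

# Every tensor is passed positionally, so checkpointing can track it.
out = checkpoint(wrapper_fn, block, x, ctx, use_reentrant=True)
out.sum().backward()
print(x.grad is not None)  # gradients flow through the checkpointed block
```

Calling block(x, encoder_hidden_states=ctx) through checkpoint directly with keyword arguments would not track ctx in the reentrant path; routing everything positionally through wrapper_fn avoids that.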