llama-recipes
Does the cache mechanism cause an error during training?
I ran into a problem when fine-tuning the model. I believe the issue lies in the cache mechanism of the Attention(nn.Module) class: during the previous forward pass, the k and v caches are stored as part of the old computation graph, and because their grad_fn is still retained, the next backward pass fails with:
RuntimeError: Trying to backward through the graph a second time.
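For reference, the same failure mode can be reproduced outside the model with a small toy loop (a hypothetical sketch, not the llama-recipes code): a tensor carried across iterations keeps the grad_fn of the previous step, so the second backward() walks back into the already-freed graph.

    import torch

    w = torch.randn(3, requires_grad=True)   # stands in for the model weights
    cache = torch.zeros(3)                   # persistent state, like cache_k / cache_v

    for step in range(2):
        x = torch.randn(3)
        cache = (cache + x) * w              # cache is now part of this step's graph
        loss = cache.sum()
        loss.backward()                      # step 1 raises: RuntimeError: Trying to
                                             # backward through the graph a second time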
class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        ...
        self.cache_k = torch.zeros(...).cuda()
        self.cache_v = torch.zeros(...).cuda()

    def forward(...):
        ...
        self.cache_k = self.cache_k.to(xq)
        self.cache_v = self.cache_v.to(xq)
        self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
        self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv
I then revised forward as follows:
class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        ...

    def forward(...):
        ...
        self.cache_k = self.cache_k.to(xq).detach()
        self.cache_v = self.cache_v.to(xq).detach()
I use .detach() to separate the cache from the previous computation graph, and with this change the model can be fine-tuned (forward and backward, batch by batch).
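Applied to the toy loop above, the same detach() pattern removes the error, because the carried tensor re-enters the next step as a constant with no grad_fn:

    import torch

    w = torch.randn(3, requires_grad=True)
    cache = torch.zeros(3)

    for step in range(2):
        x = torch.randn(3)
        cache = (cache.detach() + x) * w     # detach() cuts the link to the old graph
        loss = cache.sum()
        loss.backward()                      # succeeds on every step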
Do you think this is the right revision?