
Would the cache mechanism result in an error while training?


I encountered a problem when fine-tuning the model. I believe the issue lies in the cache mechanism of the Attention(nn.Module) class: the k and v caches written during the previous forward pass remain attached to that pass's computation graph. Because their grad_fn is still retained, the next backward pass fails with the error message

RuntimeError: Trying to backward through the graph a second time.
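For context, the following is a minimal toy reproduction of the same failure mode, separate from the llama code and using made-up tensors: writing an activation into a persistent buffer in-place records a grad_fn on the buffer, so the next iteration's backward() reaches back into the already-freed graph.

import torch

w = torch.ones(1, requires_grad=True)
cache = torch.zeros(4)                  # persistent buffer, analogous to self.cache_k

for step in range(2):
    xk = w * torch.ones(1)              # activation produced in the current step
    cache[step] = xk                    # in-place write records a grad_fn on cache
    loss = (cache * w).sum()
    # The second iteration raises "RuntimeError: Trying to backward through the
    # graph a second time", because cache still links into step 0's freed graph.
    loss.backward()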

class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        ...
        # persistent k/v caches live on the module across forward passes
        self.cache_k = torch.zeros(...).cuda()
        self.cache_v = torch.zeros(...).cuda()

    def forward(...):
        ...
        self.cache_k = self.cache_k.to(xq)
        self.cache_v = self.cache_v.to(xq)
        # in-place writes of xk/xv attach the caches to the current graph
        self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
        self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv

I have now revised the forward function as follows:

class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        ...

    def forward(...):
        ...
        # detach the caches from the previous iteration's graph before reuse
        self.cache_k = self.cache_k.to(xq).detach()
        self.cache_v = self.cache_v.to(xq).detach()

I use .detach() to separate the cache from the previous computation graph, and the model can then be fine-tuned (forward and backward, batch by batch).
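Applied to the toy reproduction above (again just a sketch, not the actual fine-tuning loop), the same .detach() idea looks like this: detaching the buffer before writing into it confines each backward pass to the current iteration's graph, while the current step's write is still recorded.

import torch

w = torch.ones(1, requires_grad=True)
cache = torch.zeros(4)                  # persistent buffer, as in the sketch above

for step in range(2):
    xk = w * torch.ones(1)
    cache = cache.detach()              # drop the grad_fn carried over from the
                                        # previous iteration, mirroring the revision
    cache[step] = xk                    # this write is still recorded, so gradients
                                        # reach the current step's xk
    loss = (cache * w).sum()
    loss.backward()                     # no longer touches a freed graph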

Do you think this revision is correct?

HeguangtongchenWZQ, Oct 07 '23