
Running the qlora.py test script from the source raises errors

Open shituo123456 opened this issue 1 year ago • 7 comments

Running qlora.py from the source directly raises an error (screenshot attached). After changing `model.child = LoraLinear(100, 200, 10)` to `model.child = LoraLinear(100, 200, 10, 10, 2)`, it errors again (second screenshot attached).

shituo123456 avatar Aug 01 '23 03:08 shituo123456

That `__main__` function is from an old version; you'll need to update it yourself.

1049451037 avatar Aug 01 '23 03:08 1049451037

This is the execution code in qlora.py — how should I change it? I've been working on CV and have only just started with multimodal large models.

```python
import torch
import torch.nn as nn
# LoraLinear is SwissArmyTransformer's LoRA layer; the exact import path
# depends on the sat version installed.

if __name__ == '__main__':
    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            self.child = nn.Linear(100, 200)

        def forward(self, x):
            return self.child(x)

    model = Model()
    torch.save(model.state_dict(), "linear.pt")
    x = torch.randn(2, 100)
    out1 = model(x)
    model.child = LoraLinear(100, 200, 10)
    model.load_state_dict(torch.load("linear.pt"), strict=False)
    out2 = model(x)
    torch.save(model.state_dict(), "lora.pt")
    ckpt = torch.load("lora.pt")
    breakpoint()
    model.load_state_dict(ckpt, strict=False)
    out3 = model(x)
    breakpoint()
```
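For context (an editorial aside, not part of the thread): the extra constructor arguments a LoRA layer takes, such as the rank, come from the low-rank decomposition it adds on top of the frozen base weight. A minimal pure-Python sketch of that forward pass, unrelated to SAT's actual `LoraLinear` implementation:

```python
# Hypothetical LoRA-style forward sketch (NOT SAT's LoraLinear):
# y = W x + (alpha / r) * B (A x), where A is (r x in) and B is (out x r).

def matvec(M, v):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]

def lora_forward(W, A, B, x, alpha=1.0):
    """Frozen base weight W plus a rank-r update B @ A, scaled by alpha / r."""
    r = len(A)  # rank = number of rows of A
    base = matvec(W, x)
    update = matvec(B, matvec(A, x))
    return [b + (alpha / r) * u for b, u in zip(base, update)]
```

With `B` initialized to zeros (the usual LoRA initialization), the layer starts out computing exactly the base `W x`, which is why loading the original `linear.pt` weights with `strict=False` can preserve the output.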

shituo123456 avatar Aug 01 '23 03:08 shituo123456

I don't remember either, it's been too long. Read the source yourself; it's not long.

1049451037 avatar Aug 01 '23 03:08 1049451037

> I don't remember either, it's been too long. Read the source yourself; it's not long.

OK, I'll give it a try first.

shituo123456 avatar Aug 01 '23 03:08 shituo123456

With this change it still errors saying `quant_state` cannot be None. How do I set this `quant_state`?

```python
import torch
import torch.nn as nn
# LoraLinear is SwissArmyTransformer's LoRA layer; the exact import path
# depends on the sat version installed.

if __name__ == '__main__':
    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            self.child = nn.Linear(100, 200)

        def forward(self, x):
            return self.child(x)

    model = Model()
    torch.save(model.state_dict(), "linear.pt")
    x = torch.randn(2, 100)
    out1 = model(x)
    model.child = LoraLinear(nn.Linear, 5, 100, 200, 10, qlora=True)
    model.load_state_dict(torch.load("linear.pt"), strict=False)
    out2 = model(x)
    torch.save(model.state_dict(), "lora.pt")
    ckpt = torch.load("lora.pt")
    breakpoint()
    model.load_state_dict(ckpt, strict=False)
    out3 = model(x)
    breakpoint()
```

(screenshot: quant_state error)

shituo123456 avatar Aug 01 '23 05:08 shituo123456

`quant_state` only exists when running on the GPU. In other words, you need `model = model.cuda()` and `x = x.cuda()`.

Also note that `model.cuda()` can only be called once, otherwise it errors (this is how bitsandbytes implements it; they override the `.cuda()` function, so it's out of my control).

1049451037 avatar Aug 01 '23 05:08 1049451037
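A hedged sketch of the advice above (editorial aside): bitsandbytes 4-bit layers get their `quant_state` only when the weights are first moved to the GPU, and since bitsandbytes quantizes on that first `.cuda()` call, the move must happen exactly once. The helper below is hypothetical, not part of qlora.py:

```python
import torch

def move_to_gpu_once(model, x):
    """Move model and input to the GPU a single time (hypothetical helper).

    Falls back to CPU when no GPU is present; in that case a qlora layer's
    quant_state stays None and the forward pass will still fail.
    """
    if not torch.cuda.is_available():
        return model, x
    model = model.cuda()  # first and only .cuda(): bitsandbytes quantizes here
    x = x.cuda()
    return model, x
```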

> `quant_state` only exists when running on the GPU. In other words, you need `model = model.cuda()` and `x = x.cuda()`.
>
> Also note that `model.cuda()` can only be called once, otherwise it errors (this is how bitsandbytes implements it; they override the `.cuda()` function, so it's out of my control).

Confirmed, `.cuda()` can indeed only be called once; calling `.cuda()` on the LoraLinear ahead of time raises a dimension error. It's working now, thanks a lot for the patient replies.

shituo123456 avatar Aug 01 '23 06:08 shituo123456