GPTQ refactor [WIP]
Fixed the device swapping; the previous code was extremely slow without it.
Changed the point where we get tensor ids and check the name so we can safely do device swapping. (See the sketch below for the general pattern.)
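The pattern is roughly the one sketched here: keep the model on CPU, move just the tensors for the layer being processed onto the accelerator for the heavy math, and move the result back afterwards. The helper name and the `quantize_fn` hook are illustrative, not the actual torchao API.

```python
import torch

def with_device_swap(weight, inputs, quantize_fn,
                     work_device="cuda", store_device="cpu"):
    # Hypothetical helper: move one layer's tensors to the work device, run the
    # expensive quantization step there, then move the result back so only one
    # layer's worth of data sits on the GPU at a time.
    if not torch.cuda.is_available():
        work_device = store_device
    w = weight.to(work_device)
    xs = [x.to(work_device) for x in inputs]
    q = quantize_fn(w, xs)  # heavy GPTQ math happens on work_device
    return q.to(store_device)

# Trivial stand-in for the real quantization step, just to show the flow.
w = torch.randn(128, 128)
xs = [torch.randn(16, 128) for _ in range(4)]
qw = with_device_swap(w, xs, lambda w, xs: torch.round(w))
```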
Fixed an issue where the MultiTensor weight would end up as a size-n MultiTensor afterwards.
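Roughly, the invariant I wanted is the one in this hypothetical helper (not the real MultiTensor code): if the weight got expanded to n identical copies so it could be grouped with n activation tensors, it should collapse back to a single entry once the pass is done.

```python
import torch

def collapse_if_uniform(values):
    # Hypothetical cleanup: if the weight was padded out to n identical copies
    # while being grouped with n activation tensors, shrink it back to a single
    # entry so the MultiTensor doesn't stay size n after the pass.
    if len(values) > 1 and all(torch.equal(values[0], v) for v in values[1:]):
        return values[:1]
    return values

vals = [torch.ones(4, 4)] * 3
assert len(collapse_if_uniform(vals)) == 1
```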
data_ptr() does not appear to be a unique identifier: I hit cases where two MultiTensors had the same data_ptr(), which was surprising. Switched to id(), which seems to work better.
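A minimal illustration of why data_ptr() is a shaky key for plain tensors (and presumably for similar reasons with MultiTensor): it identifies storage, not Python objects, so views and detached tensors collide, and freed storage can be handed back out by the allocator.

```python
import torch

w = torch.randn(4, 4)
a = w.detach()    # new tensor object, same storage
b = w.view(4, 4)  # another view over the same storage

# data_ptr() identifies the underlying storage, so views/detached tensors
# collide; a freed tensor's address can also be reused for a new one.
print(a.data_ptr() == b.data_ptr())  # True  -> not a safe dictionary key

# id() is unique per live Python object, so it works as long as the objects
# we're tracking stay alive for the duration of the pass.
print(id(a) == id(b))                # False
```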
Incidentally, changed the id_to_paramter_name lookup so that it maps ids to names after we convert things to MultiTensors, which makes it easier to identify parameters at runtime. The alternative was giving me issues: creating a MultiTensor param requires a detach operation, which changes the id, so it's easiest to build the lookup table after all the changes are made.
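A sketch of the ordering issue, using a plain detach() as a stand-in for the MultiTensor conversion (names here are illustrative): ids recorded before the swap don't match the parameter objects the model actually holds afterwards, so the lookup has to be built after the swap.

```python
import torch
import torch.nn as nn

def swap_params(model):
    # Stand-in for the MultiTensor conversion: wrapping a param goes through a
    # detach(), which yields a brand-new Python object with a brand-new id().
    for mod in model.modules():
        for pname, p in list(mod.named_parameters(recurse=False)):
            setattr(mod, pname, nn.Parameter(p.detach(), requires_grad=False))

model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 4))

p_before = model[0].weight
swap_params(model)
# The original object is still referenced here, so its id can't be reused:
assert id(p_before) != id(model[0].weight)

# Build the lookup only after all params have been swapped, so the recorded
# ids are the ones we'll actually see at runtime.
id_to_parameter_name = {id(p): name for name, p in model.named_parameters()}
```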
There's a lot of test code in here, added just to get a faster debug cycle, that needs to be removed, though it may make debugging easier in the meantime.
Ran into an issue with shapes when loading the final state dict back into the model, which I can debug, but it's super late. Everything else seems to be working more or less correctly. Didn't want to overwrite the other GPTQ refactor PR because it says this one is behind.