vector-quantize-pytorch
LSQ half precision problem #116
I believe there is a similar problem to #116.
```
  File "/tmp/ray/session_2024-06-30_09-41-50_254745_1/runtime_resources/pip/ed0d17a5f9a959a3d03116db0bba20a6c15cac27/virtualenv/lib/python3.10/site-packages/vector_quantize_pytorch/lookup_free_quantization.py", line 273, in forward
    x = self.project_in(x)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 116, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 must have the same dtype, but got Float and BFloat16
```
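The error says `mat1` (the input) is float32 while `mat2` (the `project_in` weight) is bfloat16, so the matmul inside `F.linear` refuses to run. A minimal sketch of the mismatch and the usual fix, using a plain `nn.Linear` as a stand-in for the quantizer's `project_in` layer:

```python
import torch
import torch.nn as nn

# Stand-in for project_in: a linear layer whose weights were cast to
# bfloat16 (as FSDP mixed precision does to the whole module).
proj = nn.Linear(4, 8).to(torch.bfloat16)
x = torch.randn(2, 4)  # input left in float32

try:
    proj(x)  # dtype mismatch: float32 input vs bfloat16 weight
except RuntimeError as e:
    print(e)  # "mat1 and mat2 must have the same dtype ..."

# Fix: cast the input to the layer's parameter dtype before the matmul.
out = proj(x.to(proj.weight.dtype))
print(out.dtype)  # torch.bfloat16
```

The same idea applies upstream: ensure the tensor handed to the quantizer matches the dtype its parameters were cast to (or vice versa).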
Thanks!
PS: I'd like to keep this module in float32 inside a bfloat16 FSDP-wrapped model, but I don't know how to do that.
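One common pattern for keeping a numerically sensitive submodule in float32 inside a lower-precision model is a small wrapper that pins the submodule's parameters to float32 and casts activations at its boundary. A sketch under assumptions: `FP32Wrapper` is a hypothetical helper (not part of vector-quantize-pytorch), and it assumes the wrapped module returns a single tensor (the LFQ forward returns a tuple, so you would cast each tensor element back individually):

```python
import torch
import torch.nn as nn

class FP32Wrapper(nn.Module):
    """Run a submodule in float32 inside a lower-precision model.

    Hypothetical helper: parameters stay float32, inputs are upcast on
    entry, and outputs are cast back to the caller's dtype on exit.
    """
    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module.float()  # pin parameters to float32

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        orig_dtype = x.dtype
        # Disable autocast so the inner matmuls really run in float32.
        with torch.autocast(device_type=x.device.type, enabled=False):
            out = self.module(x.float())
        return out.to(orig_dtype)

# Usage sketch with a stand-in layer (replace with the quantizer):
layer = nn.Linear(4, 4).to(torch.bfloat16)
wrapped = FP32Wrapper(layer)
y = wrapped(torch.randn(2, 4, dtype=torch.bfloat16))
print(y.dtype)                        # torch.bfloat16
print(wrapped.module.weight.dtype)    # torch.float32
```

For FSDP specifically, you would additionally need to keep the wrapped quantizer out of the bfloat16 `MixedPrecision` policy, e.g. by wrapping it as its own FSDP unit with a float32 (or no) mixed-precision config; the exact mechanism varies by torch version, so check the FSDP docs for the release you are on.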