safetensors
Fix byteswap for BF16 tensor on big-endian machine
What does this PR do?
This PR fixes an incorrect data swap for BF16 tensors on big-endian machines. It uses Storage.byteswap(datatype) in PyTorch instead of numpy.byteswap(), because numpy does not support BF16; the unnecessary data conversions between BF16 and F16 caused incorrect results.
Fixes #448
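For context, numpy has no native bfloat16 dtype, so a BF16 tensor cannot even be handed to numpy without first reinterpreting its bytes as another 16-bit type. A minimal illustration (values arbitrary):

```python
import torch

t = torch.randn(4, dtype=torch.bfloat16)
try:
    # torch refuses the conversion because numpy has no bfloat16 dtype,
    # so numpy.byteswap() cannot operate on this data directly.
    t.numpy()
except TypeError as err:
    print("bf16 -> numpy fails:", err)
```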
@Narsil Do you have time to look at this PR?
@Narsil Would it be possible to look at this PR?
Can we kindly ask someone to review and merge this PR?
Hi @kiszk,
Sorry about the delay, I haven't been able to check every place I'm being pinged on (and far from it).
I understand the issue and the fix. That said, the fix introduces 2 issues imho:
- It limits byteswap to torch 2.1, which is not required for any dtype other than bf16.
- It introduces a clone operation, which is something I've actively tried to avoid (given the size of the objects, cloning is super costly).
I'll try to take a look at how I can improve that solution if you're ok with my assumptions. Also there is a test that was supposed to catch that: https://github.com/huggingface/safetensors/blob/7d29f617a33251a8b5f9e8228c004edc3a730f4a/bindings/python/tests/test_pt_comparison.py#L53
I am guessing it didn't because it used zeros instead of randn, and the zero representation in bfloat16 has the same bytes regardless of byte order :( Therefore we could just update that test.
Managed to trigger the issue https://github.com/huggingface/safetensors/actions/runs/10109667229/job/27958053499?pr=507 on old tests with random values.
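A hedged sketch of the kind of test change being discussed; the test name, shapes, and fixture here are illustrative, not the repository's actual code:

```python
import torch
from safetensors.torch import save_file, load_file

def test_bf16_roundtrip(tmp_path):  # illustrative name, not the real test
    # randn produces byte patterns that differ between byte orders,
    # unlike zeros, whose bytes are identical either way.
    tensors = {"w": torch.randn(4, 4, dtype=torch.bfloat16)}
    path = str(tmp_path / "model.safetensors")
    save_file(tensors, path)
    reloaded = load_file(path)
    assert torch.equal(tensors["w"], reloaded["w"])
```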
@Narsil, thank you for your response while you were super busy.
Your points are correct:
- Since versions prior to PyTorch 2.1 do not support big-endian correctly, I chose an implementation that throws an exception. We could implement the same feature in Rust (Python was too slow) in safetensors, since this swap function only exists in 2.1 or later.
- I chose the conservative implementation to avoid a destructive operation. It is fine if we can avoid the clone operation.
Regarding the test, I think it uses 1.0 instead of zero. Anyway, I just realized this code does not update the original model. I agree that we have to update the test code.
I will work to update some of them if you are still super busy.
@kiszk Can you check the latest implementation?
I think numpy byteswap can still work, since bf16 byteswap and f16 byteswap are the same. The only thing necessary is those .view(..) calls, which just reinterpret the arrays.
No issue with torch version, no clone, faster byteswap.
Could you confirm the branch works for you (all the tests should show the issues now)?
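A minimal sketch of that idea, assuming a contiguous CPU tensor (the helper name is made up; the real change lives in the bindings):

```python
import torch

def byteswap_bf16_(t: torch.Tensor) -> torch.Tensor:
    # Reinterpret (not cast) the bf16 buffer as f16 so numpy can see it,
    # then swap the bytes in place: no dtype conversion, no clone.
    t.view(torch.float16).numpy().byteswap(inplace=True)
    return t
```

Since view only reinterprets the storage, the bytes that get swapped are exactly the bytes that came off disk.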
Sure, I will check the latest main soon.
The latest main branch does not work for me.
As you said, byteswap is the same between f16 and bf16. IMHO, the problem may come from the data conversion from bf16 to f16 and from f16 back to bf16. We have to do the byteswap without any data conversion.
In detail, on a big-endian platform, an element of a bf16 tensor initially holds an uninterpretable value, since it is stored as little-endian bf16. In the current implementation, that uninterpretable value is read as bf16 and then converted from bf16 to f16; the resulting f16 value is meaningless. It is then byteswapped for big-endian in numpy. After the byteswap, the value is still not correct when read back as bf16 on a big-endian platform.
@Narsil If I make mistakes, could you please let me know?
@kiszk I didn't say main branch, sorry I was referring to https://github.com/huggingface/safetensors/pull/507
Superseded by https://github.com/huggingface/safetensors/pull/507
Basically the issue is that the previous code was doing (bf16) -> to -> (f16) -> byteswap -> to -> (bf16).
to is a cast operator: it changes the underlying bits of the tensors, which is not desired.
Changing it to view (which is torch's reinterpret operator) does not change the bits anymore, meaning the byteswap is valid.
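A small illustration of the difference (the printed integers are just the raw 16-bit patterns; the value is arbitrary):

```python
import torch

x = torch.tensor([1.5], dtype=torch.bfloat16)

cast = x.to(torch.float16)             # converts the value: new bit pattern
reinterpreted = x.view(torch.float16)  # same 16 bits, merely read as f16

print(x.view(torch.int16))              # original bf16 bits
print(reinterpreted.view(torch.int16))  # identical to the line above
print(cast.view(torch.int16))           # different bits: the f16 encoding of 1.5
```

Because view leaves the bytes untouched, the numpy byteswap operates on exactly the bytes that came off disk, which is what makes the fix valid.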
Closing this assuming the bug is fixed. Let's reopen something if it's not.