
Fix byteswap for BF16 tensor on big-endian machine

Open kiszk opened this issue 1 year ago • 10 comments

What does this PR do?

This PR fixes an incorrect byte swap of BF16 tensor data on big-endian machines. It uses PyTorch's Storage.byteswap(dtype) instead of numpy.byteswap(), because numpy does not support BF16; the unnecessary data conversions between BF16 and F16 caused incorrect results.
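For illustration, a minimal sketch of that approach (the helper name is hypothetical and this is not the PR's actual code; it assumes torch >= 2.1, where UntypedStorage.byteswap(dtype) is available and swaps in place):

```python
import torch

def byteswap_bf16_storage(t: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper illustrating the approach: swap the bytes of the
    # underlying storage directly, so the bf16 payload is never cast to
    # another dtype. Assumes torch >= 2.1, where UntypedStorage.byteswap(dtype)
    # exists and operates in place.
    t.untyped_storage().byteswap(t.dtype)
    return t
```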

Fixes #448

kiszk avatar Mar 14 '24 15:03 kiszk

@Narsil Do you have time to look at this PR?

kiszk avatar Mar 18 '24 16:03 kiszk

@Narsil Do you have time to look at this PR?

kiszk avatar Apr 10 '24 14:04 kiszk

@Narsil Do you have time to look at this PR?

kiszk avatar Apr 24 '24 15:04 kiszk

@Narsil Do you have time to look at this PR?

kiszk avatar May 16 '24 02:05 kiszk

@Narsil Do you have time to look at this PR?

kiszk avatar Jun 05 '24 02:06 kiszk

@Narsil Would it be possible to look at this PR?

kiszk avatar Jul 01 '24 17:07 kiszk

Can we kindly ask someone to review and merge this PR?

abalib avatar Jul 22 '24 15:07 abalib

Hi @kiszk,

Sorry about the delay, I haven't been able to check every place I'm being pinged on (and far from it).

I understand the issue and the fix. From what I understand, the fix introduces two issues imho:

  • It limits byteswap to torch >= 2.1, which is not required for any dtype other than bf16.
  • It introduces a clone operation, which I've actively tried to avoid (given the size of these objects, cloning is very costly).

I'll try to take a look at how I can improve that solution if you're ok with my assumptions. Also there is a test that was supposed to catch that: https://github.com/huggingface/safetensors/blob/7d29f617a33251a8b5f9e8228c004edc3a730f4a/bindings/python/tests/test_pt_comparison.py#L53

I am guessing it didn't because it used zeros instead of randn, and the zero representation in bfloat16 is the same bytes regardless of byte order :( Therefore we could just update that test.
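For illustration, a sketch of what an updated test could look like (hypothetical, not the actual test in test_pt_comparison.py):

```python
import torch
from safetensors.torch import save_file, load_file

def test_bf16_roundtrip(tmp_path):
    # Hypothetical shape of the updated test: random bf16 values make a wrong
    # byteswap visible, whereas all-zero payloads are unchanged by a swap.
    tensors = {"w": torch.randn(4, 4, dtype=torch.bfloat16)}
    path = str(tmp_path / "bf16.safetensors")
    save_file(tensors, path)
    reloaded = load_file(path)
    assert torch.equal(tensors["w"], reloaded["w"])
```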

Narsil avatar Jul 26 '24 10:07 Narsil

Managed to trigger the issue https://github.com/huggingface/safetensors/actions/runs/10109667229/job/27958053499?pr=507 on old tests with random values.

Narsil avatar Jul 26 '24 10:07 Narsil

@Narsil, thank you for your response while you were super busy.

Your points are correct:

  1. Since PyTorch versions prior to 2.1 do not support big-endian correctly, I chose to throw an exception in that case. Because this swap function only exists in 2.1 or later, we could implement the same feature in Rust inside safetensors (Python was too slow).
  2. I chose a conservative implementation to avoid a destructive in-place operation. It is fine if we can avoid the clone operation.

Regarding the test, I think it uses 1.0 instead of zero. Anyway, I just realized that this code does not update the original model. I agree that we have to update the test code.

I will work on updating some of them if you are still super busy.

kiszk avatar Jul 26 '24 15:07 kiszk

@kiszk Can you check the latest implementation?

I think the numpy byteswap can still work, since the bf16 and f16 byteswaps are the same. The only thing necessary is those .view(..) calls, which just reinterpret the arrays.

No issue with torch version, no clone, faster byteswap.
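For illustration, a rough sketch of that view-based idea (not the exact code from the branch):

```python
import torch

def byteswap_bf16_via_numpy(t: torch.Tensor) -> torch.Tensor:
    # Sketch of the reinterpret-then-swap idea: view() only reinterprets bits,
    # so no bf16 <-> f16 value conversion happens; numpy then swaps the two
    # bytes of each element, which is the same operation for f16 and bf16.
    as_f16 = t.view(torch.float16)                        # same bits, numpy-friendly dtype
    swapped = torch.from_numpy(as_f16.numpy().byteswap()) # byteswap in numpy
    return swapped.view(torch.bfloat16)                   # reinterpret back as bf16
```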

Could you confirm the branch works for you? (All the tests should surface the issue now.)

Narsil avatar Jul 30 '24 17:07 Narsil

Sure, I will check the latest main soon.

kiszk avatar Jul 30 '24 17:07 kiszk

The latest main branch does not work for me.

As you said, the byteswap is the same for f16 and bf16. IMHO, the problem may come from the data conversion from bf16 to f16 and back from f16 to bf16. We have to do the byteswap without any data conversion.

In detail: on a big-endian platform, an element of a bf16 tensor initially holds an uninterpretable value, since its bytes are bf16 in little-endian order. The current implementation interprets that value as bf16 and then converts it from bf16 to f16, producing a meaningless f16 value. That value is then byteswapped for big-endian in numpy, and even after the byteswap the result is still not a correct bf16 value on a big-endian platform.

kiszk avatar Jul 30 '24 18:07 kiszk

@Narsil If I am making a mistake, could you please let me know?

kiszk avatar Jul 30 '24 18:07 kiszk

@kiszk I didn't say the main branch, sorry, I was referring to https://github.com/huggingface/safetensors/pull/507

Narsil avatar Jul 31 '24 07:07 Narsil

Superseded by https://github.com/huggingface/safetensors/pull/507

Basically the issue is that the previous code was doing (bf16) -> to -> (f16) -> byteswap -> to -> (bf16).

to is a cast operator, which changes the underlying bits of the tensors, and that is not desired. Switching to view (torch's reinterpret operator) leaves the bits unchanged, so the byteswap is valid.
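For illustration, a tiny example of the difference (not code from the PR):

```python
import torch

x = torch.tensor([1.5], dtype=torch.bfloat16)

cast = x.to(torch.float16)            # new bits encoding the value 1.5 as f16
reinterpreted = x.view(torch.float16) # same bits as the bf16 tensor, relabeled

# The bit patterns differ: only the reinterpreted tensor can be byteswapped
# and viewed back as bf16 without corrupting the data.
print(cast.view(torch.int16), reinterpreted.view(torch.int16))
```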

Closing this assuming the bug is fixed. Let's reopen something if it's not.

Narsil avatar Jul 31 '24 08:07 Narsil