
How to load bfloat16 weights into a TorchSharp model

Open • LittleLittleCloud opened this issue on Jan 21 '24 • 10 comments

The current Python conversion script converts each tensor to a NumPy array before writing it to the file. Since NumPy has no bfloat16 dtype, the script fails if the model weights contain bf16 tensors.

My current workaround is to save the model weights as f32 and set bf16 as the default dtype before running inference. However, this roughly doubles the size of the exported weights. So I wonder if it's possible to 1) add support for saving bf16 weights to the Python conversion script, and 2) maybe add support for loading directly from a PyTorch checkpoint file or the Hugging Face .safetensors format, to make loading model weights easier.
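For a concrete view of the size argument: bfloat16 keeps the sign, the 8-bit exponent, and the top 7 mantissa bits of a float32, so a bf16 value is essentially the upper half of the float32 bit pattern. The pure-Python sketch below (no NumPy or PyTorch; the helper names are hypothetical, and NaN is not special-cased) converts floats to bf16 bit patterns with round-to-nearest-even and compares payload sizes. This shows why storing the weights as f32 doubles the file.

```python
import struct


def float32_to_bf16_bits(x: float) -> int:
    """Return the bfloat16 bit pattern for x, rounding the float32 bits to nearest even.

    Sketch only: NaN inputs are not special-cased.
    """
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    # Add a bias so that dropping the low 16 bits rounds to nearest, ties to even
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    return ((bits + rounding_bias) >> 16) & 0xFFFF


def bf16_bits_to_float32(b: int) -> float:
    """Widen a bfloat16 bit pattern back to float32 by appending 16 zero bits."""
    (x,) = struct.unpack("<f", struct.pack("<I", b << 16))
    return x


values = [1.0, -2.5, 3.140625, 0.1]
bf16_payload = b"".join(struct.pack("<H", float32_to_bf16_bits(v)) for v in values)
f32_payload = b"".join(struct.pack("<f", v) for v in values)
print(len(bf16_payload), len(f32_payload))  # 8 16
```

Values such as 0.1 are not exactly representable in bf16, so a round trip returns a nearby value (0.10009765625). That precision loss is the trade-off for halving the storage.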

LittleLittleCloud • Jan 21 '24

You can store the tensors in a binary format:

import torch


def encode(writer, value: int) -> None:
    """Write a non-negative integer as unsigned LEB128 (7 bits per byte, low bits first)."""
    if value < 0:
        raise NotImplementedError("LEB128 encoding of negative numbers is not implemented")

    while True:
        num = value & 0x7F
        value >>= 7
        if value != 0:
            # More bytes follow: set the continuation bit and write the current byte
            writer.write(bytes([num | 0x80]))
        else:
            # Last byte: continuation bit clear (this also covers value == 0)
            writer.write(bytes([num]))
            break


def save_tensor_to_binary(tensor: torch.Tensor, binary_file) -> None:
    # Handle the device first: move to the CPU (and detach so .numpy() works)
    tensor = tensor.detach().cpu()

    # TorchSharp ScalarType codes
    match tensor.dtype:
        case torch.float16:
            dtype_code = 5
        case torch.float32:
            dtype_code = 6
        case torch.float64:
            dtype_code = 7
        case _:
            raise NotImplementedError(f"unsupported dtype: {tensor.dtype}")

    # Write the data type
    encode(binary_file, dtype_code)
    # Write the number of dimensions
    encode(binary_file, tensor.dim())
    # Write the size of each dimension
    for dim in tensor.shape:
        encode(binary_file, dim)
    # Convert the tensor contents to bytes and write them
    binary_file.write(tensor.numpy().tobytes())


def save_tensor(tensor: torch.Tensor | list[torch.Tensor], file_path: str) -> None:
    with open(file_path, 'wb') as binary_file:
        if isinstance(tensor, list):
            encode(binary_file, 2)  # store type 2: list of tensors
            encode(binary_file, len(tensor))
            for t in tensor:
                save_tensor_to_binary(t.double(), binary_file)
        else:
            encode(binary_file, 1)  # store type 1: single tensor
            save_tensor_to_binary(tensor.double(), binary_file)
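The header written by `encode` is unsigned LEB128. This is the same variable-length encoding that the C# `Decode` method below reads back, and that .NET's `BinaryReader.ReadString` uses for its length prefix. As a quick sanity check, here is a standalone sketch (helper names are hypothetical) that encodes a few values, including 0, and decodes them again.

```python
import io


def leb128_encode(value: int) -> bytes:
    """Unsigned LEB128: 7 payload bits per byte, least significant group first."""
    if value < 0:
        raise ValueError("only non-negative values are supported")
    out = bytearray()
    while True:
        byte = value & 0x7F
        value >>= 7
        if value:
            out.append(byte | 0x80)  # continuation bit: more bytes follow
        else:
            out.append(byte)  # final byte
            return bytes(out)


def leb128_decode(stream: io.BufferedIOBase) -> int:
    """Accumulate 7 bits per byte until a byte with the high bit clear is read."""
    result = 0
    shift = 0
    while True:
        raw = stream.read(1)
        if not raw:
            raise EOFError("truncated LEB128 value")
        byte = raw[0]
        result |= (byte & 0x7F) << shift
        if not byte & 0x80:
            return result
        shift += 7


for n in (0, 127, 128, 300, 624485):
    encoded = leb128_encode(n)
    assert leb128_decode(io.BytesIO(encoded)) == n
    print(n, encoded.hex())  # e.g. 300 -> ac02, 624485 -> e58e26
```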

HCareLou • Jan 21 '24

using System.Collections.Generic;
using System.IO;
using TorchSharp;
using static TorchSharp.torch;

// Hypothetical class name: these extension methods just need to live in a static class.
public static class TensorFileExtensions
{
    /// <summary>
    /// Loads model parameters.
    /// </summary>
    /// <param name="dict">Parameter dictionary to fill</param>
    /// <param name="location">Path of the parameter file</param>
    public static void LoadStateDict(this Dictionary<string, Tensor> dict, string location)
    {
        using FileStream stream = File.OpenRead(location);
        using BinaryReader reader = new BinaryReader(stream);
        // Layout: LEB128 entry count, then (length-prefixed key string, tensor) pairs
        var num = reader.Decode();
        for (int i = 0; i < num; i++)
        {
            var key = reader.ReadString();
            var tensor = TensorExtensionMethods.Load(reader, skip: false);
            dict.Add(key, tensor);
        }
    }

    /// <summary>
    /// Loads a list of tensors.
    /// </summary>
    /// <param name="tensors">List to fill with the loaded tensors</param>
    /// <param name="location">File path</param>
    public static void LoadTensorList(this List<Tensor> tensors, string location)
    {
        using FileStream stream = File.OpenRead(location);
        using BinaryReader reader = new BinaryReader(stream);
        // Store type 2 marks a tensor list (written by save_tensor in the Python script)
        var storeType = reader.Decode();
        if (storeType != 2)
        {
            throw new InvalidDataException($"{location} does not contain a tensor list");
        }
        var num = reader.Decode();
        for (int i = 0; i < num; i++)
        {
            var tensor = TensorExtensionMethods.Load(reader, skip: false);
            tensors.Add(tensor);
        }
    }

    /// <summary>
    /// Decodes an unsigned LEB128 value from a binary reader.
    /// </summary>
    /// <param name="reader">A BinaryReader instance used for input.</param>
    /// <returns>The decoded value</returns>
    public static long Decode(this BinaryReader reader)
    {
        long num = 0L;
        int shift = 0;
        while (true)
        {
            long b = reader.ReadByte();
            num |= (b & 0x7F) << (shift * 7);
            if ((b & 0x80) == 0L)
            {
                break;
            }
            shift++;
        }
        return num;
    }
}
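To debug files written by the Python script without going through TorchSharp, you can read back the layout that `LoadTensorList` expects in plain Python. This is a sketch under two assumptions. First, the file was written on a little-endian machine, because `tobytes()` uses native byte order. Second, only the float16/32/64 codes from `save_tensor_to_binary` appear. The function names are hypothetical.

```python
import io
import struct

# struct format character and element size for the dtype codes written by
# save_tensor_to_binary (5 = float16, 6 = float32, 7 = float64)
DTYPES = {5: ("e", 2), 6: ("f", 4), 7: ("d", 8)}


def read_leb128(stream) -> int:
    """Unsigned LEB128, the same scheme as the C# Decode method."""
    result, shift = 0, 0
    while True:
        byte = stream.read(1)[0]
        result |= (byte & 0x7F) << shift
        if not byte & 0x80:
            return result
        shift += 7


def read_tensor(stream) -> tuple[list[int], list[float]]:
    """Read one (dtype code, ndims, dims..., raw data) record; return (shape, flat values)."""
    fmt, size = DTYPES[read_leb128(stream)]
    ndims = read_leb128(stream)
    shape = [read_leb128(stream) for _ in range(ndims)]
    count = 1
    for d in shape:
        count *= d
    data = stream.read(count * size)
    if len(data) != count * size:
        raise EOFError("tensor payload is shorter than its header claims")
    return shape, list(struct.unpack(f"<{count}{fmt}", data))


def read_tensor_list(stream) -> list[tuple[list[int], list[float]]]:
    if read_leb128(stream) != 2:
        raise ValueError("file does not contain a tensor list")
    return [read_tensor(stream) for _ in range(read_leb128(stream))]


# Store type 2, one tensor, dtype 6 (float32), 2 dims of sizes 2 and 3, then 6 floats
buf = io.BytesIO(bytes([2, 1, 6, 2, 2, 3]) + struct.pack("<6f", *[float(i) for i in range(6)]))
print(read_tensor_list(buf))  # [([2, 3], [0.0, 1.0, 2.0, 3.0, 4.0, 5.0])]
```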

HCareLou • Jan 21 '24

Perfect, let me try it! Thanks @lintao185

Update

Hey @lintao185, I tried your solution, and there seem to be two problems.

The first problem is in the Python code: each tensor is converted to double before being written to the binary file, so the exported model is four times larger than it would be in bfloat16 format. (After exporting, the Llama 2 model grows to ~50GB, while the Python checkpoint is ~13GB.)

The second problem is in TensorExtensionMethods.Load. It reads sizeof(dtype) * (number of elements) bytes based on the header. This can cause a loading error if the element type is recorded as bfloat16 but the stored array is actually double.
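The second problem comes down to arithmetic. The loader reads exactly sizeof(dtype) × number-of-elements bytes after the header. If the header claims bfloat16 (2 bytes per element) while the data was written as double (8 bytes), the loader consumes only a quarter of the payload, and the next header is then parsed out of leftover float data. Here is a minimal sketch of that calculation. It assumes TorchSharp's ScalarType codes (5 float16, 6 float32, 7 float64, and 15 bfloat16, which mirrors libtorch's numbering). The function name and the 4096 × 4096 shape are illustrative only.

```python
from math import prod

# Bytes per element, keyed by TorchSharp ScalarType code
# (assumption: 15 = BFloat16, following libtorch's c10::ScalarType numbering)
ELEMENT_SIZE = {5: 2, 6: 4, 7: 8, 15: 2}


def expected_payload_bytes(dtype_code: int, shape: list[int]) -> int:
    """Raw bytes a loader reads after the header: sizeof(dtype) * numel (1 for a scalar)."""
    return ELEMENT_SIZE[dtype_code] * prod(shape)


shape = [4096, 4096]
written_as_double = expected_payload_bytes(7, shape)
read_as_bf16 = expected_payload_bytes(15, shape)
print(written_as_double // read_as_bf16)  # 4
```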

LittleLittleCloud • Jan 21 '24

Yes, indeed, you could change it to save_tensor_to_binary(tensor, binary_file). It's worth noting that the conversion to double was initially intended for enhanced compatibility. As an alternative, you could experiment with loading the tensor into TorchSharp and subsequently saving a version of the parameters using native APIs officially offered by TorchSharp.
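For bfloat16 checkpoints, dropping `.double()` alone isn't enough. `tensor.numpy()` raises for bfloat16 tensors because NumPy has no such dtype, and the `match` in `save_tensor_to_binary` has no bfloat16 case. One possible workaround is to reinterpret the bf16 storage as int16 before calling `.numpy()`. That preserves the exact bit patterns at 2 bytes per element. Treat this as a sketch rather than a tested exporter: dtype code 15 is an assumption (TorchSharp's `ScalarType.BFloat16`, mirroring libtorch), and the helper names are hypothetical.

```python
import io

import torch

BFLOAT16_CODE = 15  # assumption: TorchSharp's ScalarType.BFloat16


def bf16_tensor_bytes(tensor: torch.Tensor) -> bytes:
    """Raw bfloat16 payload that bypasses NumPy's missing bf16 dtype.

    Viewing the storage as int16 keeps the exact 16-bit patterns; .numpy()
    works on int16, and tobytes() emits them in native byte order.
    """
    t = tensor.detach().cpu().contiguous()
    if t.dtype != torch.bfloat16:
        raise TypeError(f"expected bfloat16, got {t.dtype}")
    return t.view(torch.int16).numpy().tobytes()


def _write_leb128(f, value: int) -> None:
    if value < 0:
        raise ValueError("only non-negative values are supported")
    while True:
        byte = value & 0x7F
        value >>= 7
        f.write(bytes([byte | 0x80] if value else [byte]))
        if not value:
            return


def write_bf16_tensor(tensor: torch.Tensor, binary_file) -> None:
    """Same header layout as save_tensor_to_binary, with a bf16 payload."""
    payload = bf16_tensor_bytes(tensor)
    _write_leb128(binary_file, BFLOAT16_CODE)
    _write_leb128(binary_file, tensor.dim())
    for dim in tensor.shape:
        _write_leb128(binary_file, dim)
    binary_file.write(payload)


w = torch.tensor([[1.0, -2.5], [0.5, 3.0]], dtype=torch.bfloat16)
buf = io.BytesIO()
write_bf16_tensor(w, buf)
print(len(buf.getvalue()))  # 4 header bytes + 4 elements * 2 bytes = 12
```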

HCareLou • Jan 22 '24

Do you have the .ckpt file and want to load it directly? Or do you want to convert it to a file that TorchSharp's built-in methods can read? If it's the former, I've just created a tool to load ckpt files directly, though I haven't tested it on BF16 yet. If you want, I can clean it up a bit, create a gist, and also test it with bf16 :-)

It relies on the Razorvine.Pickle library to unpickle the data.pkl stored in the ckpt archive.
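For context on what such a loader has to handle: a checkpoint written by `torch.save` (PyTorch 1.6 and later) is a zip archive. `data.pkl` holds the pickled object graph (dict keys, dtypes, shapes, strides). Each tensor's raw bytes live in a separate `data/<key>` entry that the pickle references by key, and an unpickler such as Razorvine.Pickle has to resolve those references. Here is a minimal sketch that lists those entries using only the standard library. The function name is hypothetical, and it assumes the single top-level folder that `torch.save` creates.

```python
import zipfile


def describe_checkpoint(path) -> tuple[str, list[str]]:
    """Return the pickle entry and storage record names of a zip-format checkpoint.

    `path` may be a file path or a binary file object. Assumes the torch.save
    layout: <prefix>/data.pkl plus one <prefix>/data/<key> entry per storage.
    """
    with zipfile.ZipFile(path) as zf:
        names = zf.namelist()
    pkl = next(n for n in names if n.endswith("/data.pkl"))
    prefix = pkl[: -len("data.pkl")]
    storages = sorted(n for n in names if n.startswith(prefix + "data/"))
    return pkl, storages
```

On a real checkpoint, the number of `data/<key>` entries should match the number of distinct storages referenced by the pickled tensors.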

phizch • Jan 23 '24

@phizch The .ckpt file I want to load is llama-2-7b. I'm not sure I can share it here because of licensing, but you can easily download it by following this guidance.

Below are the steps I want to take. Essentially, I want to load directly from the .ckpt to avoid manually converting it to the TorchSharp format.

  • From the .ckpt, load all tensors, including their data, name, type, and shape. (I don't know the details of the .ckpt format, so I'm not sure whether that information is available, but it would help me load the weights into a Llama model built with TorchSharp.)
  • After loading all tensors, create a state_dict, similar to the loading function below, and load it into the TorchSharp Llama 2 model.
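On the Python side, all of that information is recoverable: `torch.load` rebuilds each tensor with its name (the dict key), dtype, and shape. Here is a minimal sketch. It assumes the checkpoint is a flat dict of tensors, so a wrapped checkpoint would need its inner dict extracted first, and the function name is hypothetical. `weights_only=True` needs a reasonably recent PyTorch (1.13+).

```python
import torch


def list_checkpoint_tensors(path) -> dict[str, torch.Tensor]:
    """Load a checkpoint on the CPU and print name, dtype, and shape for every tensor.

    Assumes the file holds a flat dict of tensors (a plain state_dict);
    non-tensor entries are skipped.
    """
    obj = torch.load(path, map_location="cpu", weights_only=True)
    state_dict = {k: v for k, v in obj.items() if isinstance(v, torch.Tensor)}
    for name, t in state_dict.items():
        print(f"{name}: {t.dtype} {tuple(t.shape)}")
    return state_dict
```

Because `torch.load` keeps each tensor's dtype, bf16 weights stay bf16 here, with no float32 round trip.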

Also, here's the link to the loading function I currently use to load the model weights. It's based on @lintao185's solution (thanks, BTW) and requires a separate conversion from the Llama 2 ckpt to the TorchSharp format, which I'd like to get rid of.

Thanks in advance for any solution or help!

LittleLittleCloud • Jan 23 '24

@LittleLittleCloud I haven't tried it myself, but have you tried loading using TorchSharp.PyBridge?

You can install it using nuget:

Install-Package TorchSharp.PyBridge

And then you can load in the PyTorch weights without applying any conversions:

model.load_py("/path/to/ckpt");

(This should work with the regular pytorch checkpoints, not SafeTensors.)

shaltielshmid • Jan 23 '24

@shaltielshmid I just tried your package and your solution works like a charm.

Here are the steps I took, in case anyone else runs into a similar problem:

Step 1

In Python, save the state_dict to disk. The .ckpt contains some extra metadata, so it can't be loaded directly into a TorchSharp model; you need to save the state dict instead. (Maybe there's a way to extract the state_dict from a .ckpt in C#?)

import torch

# use bf16 as the default dtype
# this is required if you want to save the Llama weights in bfloat16
torch.set_default_dtype(torch.bfloat16)

# some code to load the transformer into `model`

# save the model state dict (llama_torchsharp_weights_path is the output file path)
with open(llama_torchsharp_weights_path, 'wb') as f:
    torch.save(model.state_dict(keep_vars=False), f)
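Since step 1 is meant to keep the weights at 2 bytes per element, it may be worth checking the state dict before saving it. Here is a small sketch (the helper is hypothetical). It rejects any floating-point tensor that is not bfloat16 and returns the raw payload size. That size should be close to the size of the saved file, because `torch.save` adds only pickle and zip overhead.

```python
import torch


def check_bf16_state_dict(state_dict: dict[str, torch.Tensor]) -> int:
    """Raise if any floating-point tensor is not bfloat16; return the payload size in bytes."""
    total = 0
    for name, t in state_dict.items():
        if t.is_floating_point() and t.dtype != torch.bfloat16:
            raise TypeError(f"{name} is {t.dtype}, expected torch.bfloat16")
        # Non-float entries (e.g. integer buffers) are counted but not rejected
        total += t.numel() * t.element_size()
    return total
```

Usage would be along the lines of `check_bf16_state_dict(model.state_dict())` right before the `torch.save` call.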

Step 2

In C#:

// create the transformer, then load the saved state dict
// (load_py is an extension method from TorchSharp.PyBridge)
transformer.load_py(llama_torchsharp_weights_path);

Here are the resulting file sizes: consolidate.0.pth is the original Llama checkpoint, llama-2-7b.pt is the weight file converted by the export script (where bfloat16 is saved as float), and llama-2-7b-2.pt is the state dict saved in step 1 and loaded with TorchSharp.PyBridge.

[Screenshot: file sizes of the three weight files, not reproduced]
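You can sanity-check the file sizes above with arithmetic. The sketch below counts the parameters of a Llama-style decoder from its hyperparameters and converts the count to sizes at 2, 4, and 8 bytes per element. The formula assumes no biases and an untied output projection, as in Llama 2, and it ignores pickle and zip overhead. The function name is hypothetical.

```python
def llama_param_count(vocab: int, dim: int, n_layers: int, ffn_hidden: int) -> int:
    """Parameters of a Llama-style decoder without biases.

    Per layer: 4 attention projections (dim x dim), 3 feed-forward matrices
    (dim x ffn_hidden), and 2 RMSNorm weight vectors. Plus the token embedding,
    the untied output projection, and the final norm.
    """
    per_layer = 4 * dim * dim + 3 * dim * ffn_hidden + 2 * dim
    return n_layers * per_layer + 2 * vocab * dim + dim


# Llama 2 7B hyperparameters: dim 4096, 32 layers, vocab 32000, FFN hidden 11008
n = llama_param_count(vocab=32000, dim=4096, n_layers=32, ffn_hidden=11008)
for name, nbytes in (("bfloat16", 2), ("float32", 4), ("float64", 8)):
    print(f"{name}: {n * nbytes / 2**30:.1f} GiB")
```

With these values the count is 6,738,415,616 parameters. That works out to roughly 12.6 GiB in bfloat16 and 50.2 GiB in float64, consistent with the ~13GB checkpoint and the ~50GB double-precision export mentioned earlier in the thread.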

However, TorchSharp.PyBridge seems to depend on TorchSharp > 0.105, and I can't find an official Linux CUDA 11.x runtime package for it on NuGet. I created another issue to track this. @dotnet/torchsharp-admin, could you help me out there?

LittleLittleCloud • Jan 23 '24

TorchSharp.PyBridge depends on features that were added only in version 0.101.5 of TorchSharp.

But, since the TorchSharp package includes the cuda binaries already, you can update the package even if you don't have CUDA 12.X on your machine.

shaltielshmid • Jan 23 '24

But you do need drivers that are CUDA 12 compatible.

NiklasGustafsson • Jan 23 '24

@LittleLittleCloud -- still an issue?

NiklasGustafsson • Jun 20 '24

Not any more!

LittleLittleCloud • Jun 20 '24