How to load bfloat (float16) weight into torchsharp model #1204

Open
LittleLittleCloud opened this issue Jan 21, 2024 · 10 comments

Comments

@LittleLittleCloud

The current Python conversion script converts each tensor to a NumPy array before writing it to file. However, because NumPy does not support the bf16 type, the conversion script fails if the model weights contain bf16 tensors.
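For context: NumPy has no bfloat16 dtype, so calling `.numpy()` on a bf16 tensor raises an error. bf16 is simply the upper 16 bits of an IEEE-754 float32 (same sign bit and 8-bit exponent, 7-bit mantissa), so the raw bytes can still be produced without NumPy support. With PyTorch, reinterpreting the data via `tensor.view(torch.int16)` before calling `.numpy()` should give the same bytes. Below is a minimal pure-Python sketch of the bit-level conversion, assuming round-to-nearest-even. The function names are hypothetical and not part of any library:

```python
import struct


def float_to_bf16_bits(x: float) -> int:
    """Round x to bfloat16 and return its 16-bit pattern.

    x is first packed as float32 (struct raises OverflowError if it is out of
    float32 range); the low 16 mantissa bits are then rounded to nearest-even.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    if (bits & 0x7F800000) == 0x7F800000 and (bits & 0x007FFFFF):
        # NaN: keep sign/exponent and force the quiet-NaN mantissa bit
        return ((bits >> 16) | 0x0040) & 0xFFFF
    # Round-to-nearest-even on the 16 bits that are discarded
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    return ((bits + rounding_bias) >> 16) & 0xFFFF


def bf16_bits_to_float(b: int) -> float:
    """Widen a bf16 bit pattern back to a Python float (exact)."""
    return struct.unpack("<f", struct.pack("<I", (b & 0xFFFF) << 16))[0]


def pack_bf16(values: list[float]) -> bytes:
    """Pack values as little-endian bf16, 2 bytes per element."""
    return struct.pack(f"<{len(values)}H", *(float_to_bf16_bits(v) for v in values))
```

Writing 2 bytes per element this way keeps the exported file at the same size as the bf16 checkpoint instead of doubling it.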

My current workaround is to save the model weights as f32 and set bf16 as the default dtype before running inference. However, this nearly doubles the size of the exported weights. So I wonder whether it's possible to 1) add a function to the Python conversion script that saves bf16 weights, and 2) perhaps add support for loading directly from a PyTorch checkpoint file or the Hugging Face .safetensors format, to simplify loading model weights.

@lintao185

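You can store the tensors in a binary format: each tensor is written as a LEB128-encoded header (dtype code, number of dimensions, size of each dimension) followed by the raw tensor bytes.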

import torch


def encode(writer, value: int) -> None:
    # Unsigned LEB128: 7 bits per byte, high bit set on every byte except the last
    if value < 0:
        raise NotImplementedError("LEB128 encoding of negative numbers is not implemented")

    while True:
        num = value & 127
        value >>= 7
        if value != 0:
            # more bytes follow: set the continuation bit and write the current byte
            writer.write(bytes([num | 128]))
        else:
            # value is now 0, so num is the final byte (this also handles value == 0)
            writer.write(bytes([num]))
            break


def save_tensor_to_binary(tensor: torch.Tensor, binary_file):
    # Handle the device first: .numpy() needs a CPU tensor that does not require grad
    tensor = tensor.detach().cpu()

    # dtype codes follow TorchSharp's ScalarType numbering
    match tensor.dtype:
        case torch.float16:
            dtype_code = 5
        case torch.float32:
            dtype_code = 6
        case torch.float64:
            dtype_code = 7
        case torch.bfloat16:
            dtype_code = 15
        case _:
            raise NotImplementedError(f"unsupported dtype: {tensor.dtype}")

    # write the dtype code
    encode(binary_file, dtype_code)
    # write the number of dimensions
    encode(binary_file, len(tensor.shape))
    # write the size of each dimension
    for dim in tensor.shape:
        encode(binary_file, dim)
    # convert the tensor contents to bytes and write them;
    # NumPy has no bfloat16, so reinterpret bf16 data as int16 (same bytes)
    if tensor.dtype == torch.bfloat16:
        tensor = tensor.view(torch.int16)
    binary_file.write(tensor.numpy().tobytes())


def save_tensor(tensor: torch.Tensor | list[torch.Tensor], file_path: str):
    with open(file_path, 'wb') as binary_file:
        if isinstance(tensor, list):
            # store type 2: a list of tensors
            encode(binary_file, 2)
            encode(binary_file, len(tensor))
            for t in tensor:
                save_tensor_to_binary(t.double(), binary_file)
        else:
            # store type 1: a single tensor
            encode(binary_file, 1)
            save_tensor_to_binary(tensor.double(), binary_file)
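Note that `save_tensor` writes a store-type marker and no parameter names, while `LoadStateDict` in the next comment reads an entry count followed by (key, tensor) pairs, reading each key with .NET's `BinaryReader.ReadString`. That method expects a 7-bit encoded byte count (the same LEB128 layout as `Decode`) followed by UTF-8 bytes. Here is a minimal sketch of a matching writer, assuming each tensor has already been serialized to bytes. `write_dotnet_string` and `save_state_dict_payloads` are hypothetical names, and the payloads in the usage lines are placeholders:

```python
import io


def encode(writer, value: int) -> None:
    """Unsigned LEB128, the varint layout the C# Decode reads."""
    if value < 0:
        raise ValueError("negative values are not supported")
    while True:
        byte = value & 0x7F
        value >>= 7
        if value:
            writer.write(bytes([byte | 0x80]))  # more bytes follow
        else:
            writer.write(bytes([byte]))  # final byte
            return


def write_dotnet_string(writer, text: str) -> None:
    """Write a string like .NET BinaryWriter.Write(string):
    7-bit encoded byte length, then the UTF-8 bytes."""
    data = text.encode("utf-8")
    encode(writer, len(data))
    writer.write(data)


def save_state_dict_payloads(writer, entries: dict[str, bytes]) -> None:
    """Write the entry count, then each key followed by its tensor bytes."""
    encode(writer, len(entries))
    for key, payload in entries.items():
        write_dotnet_string(writer, key)
        writer.write(payload)


# Usage with placeholder payloads standing in for serialized tensors
buf = io.BytesIO()
save_state_dict_payloads(buf, {"layer.weight": b"\x01\x02", "layer.bias": b"\x03"})
demo_bytes = buf.getvalue()
```

With PyTorch, each payload could be produced by calling `save_tensor_to_binary` on an `io.BytesIO` and taking its `getvalue()`.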

@lintao185

    /// <summary>
    /// Loads model parameters
    /// </summary>
    /// <param name="dict">Parameter dictionary</param>
    /// <param name="location">Location (path) of the parameter file</param>
    public static void LoadStateDict(this Dictionary<string, Tensor> dict, string location)
    {
        using FileStream stream = File.OpenRead(location);
        using BinaryReader reader = new BinaryReader(stream);
        // Layout: number of entries, then (key, tensor) pairs
        var num = reader.Decode();
        for (int i = 0; i < num; i++)
        {
            var key = reader.ReadString();
            var tensor = TensorExtensionMethods.Load(reader, skip: false);
            dict.Add(key, tensor);
        }
    }
    /// <summary>
    /// Loads a list of tensors
    /// </summary>
    /// <param name="tensors">Tensor list</param>
    /// <param name="location">File location</param>
    public static void LoadTensorList(this List<Tensor> tensors, string location)
    {
        using FileStream stream = File.OpenRead(location);
        using BinaryReader reader = new BinaryReader(stream);
        // Store type 2 marks a tensor list (see save_tensor in the Python script)
        var storeType = reader.Decode();
        if (storeType != 2)
        {
            throw new Exception($"File {location} does not contain a tensor list");
        }
        var num = reader.Decode();
        for (int i = 0; i < num; i++)
        {
            var tensor = TensorExtensionMethods.Load(reader, skip: false);
            tensors.Add(tensor);
        }
    }

    /// <summary>
    /// Decode a long value from a binary reader (unsigned LEB128)
    /// </summary>
    /// <param name="reader">A BinaryReader instance used for input.</param>
    /// <returns>The decoded value</returns>
    public static long Decode(this BinaryReader reader)
    {
        long num = 0L;
        int num2 = 0;
        while (true)
        {
            long num3 = reader.ReadByte();
            num += (num3 & 0x7F) << num2 * 7;
            if ((num3 & 0x80) == 0L)
            {
                break;
            }

            num2++;
        }

        return num;
    }
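To inspect these files from the Python side, the C# `Decode` above can be mirrored in Python. Here is a minimal sketch that reads one unsigned LEB128 value from a binary file object, plus a helper that reads a tensor header (dtype code, number of dimensions, each dimension) in the order the Python writer emits it. `decode` and `read_header` are hypothetical helper names:

```python
import io


def decode(reader) -> int:
    """Read one unsigned LEB128 value, mirroring the C# Decode extension."""
    result = 0
    shift = 0
    while True:
        chunk = reader.read(1)
        if not chunk:
            raise EOFError("unexpected end of stream while decoding LEB128")
        byte = chunk[0]
        # The low 7 bits carry data; shift them into place
        result |= (byte & 0x7F) << shift
        # A clear high bit means this was the last byte
        if (byte & 0x80) == 0:
            return result
        shift += 7


def read_header(reader) -> tuple[int, list[int]]:
    """Read a tensor header: dtype code, ndim, then each dimension size."""
    dtype_code = decode(reader)
    ndim = decode(reader)
    shape = [decode(reader) for _ in range(ndim)]
    return dtype_code, shape
```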

@LittleLittleCloud
Author

LittleLittleCloud commented Jan 21, 2024

Perfect, let me try it! Thanks @lintao185

Update

Hey @lintao185, I tried your solution, and there seem to be two problems:

The first problem is in the Python code: each tensor is converted to double before being written to the binary file, so the exported model is four times larger than it would be in bfloat16 format (8 bytes versus 2 bytes per element). After exporting, the Llama 2 model grows to ~50 GB, while the Python checkpoint is ~13 GB.
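The factor of four follows directly from bytes per element. Here is a quick back-of-the-envelope sketch, assuming roughly 6.7 billion parameters for Llama 2 7B and ignoring file headers and metadata. `weights_size_gib` is a hypothetical helper:

```python
# Bytes per element for the dtypes discussed in this thread
BYTES_PER_ELEMENT = {"bfloat16": 2, "float16": 2, "float32": 4, "float64": 8}


def weights_size_gib(num_params: int, dtype: str) -> float:
    """Approximate size of the raw weights in GiB (no headers/metadata)."""
    return num_params * BYTES_PER_ELEMENT[dtype] / 2**30


# Assumption: roughly 6.7 billion parameters for Llama 2 7B
NUM_PARAMS = 6_700_000_000
for dtype in ("bfloat16", "float32", "float64"):
    print(f"{dtype}: {weights_size_gib(NUM_PARAMS, dtype):.1f} GiB")
```

This gives about 12.5 GiB for bfloat16, 25 GiB for float32 and 50 GiB for float64, in line with the sizes reported above.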

The second problem is in TensorExtensionMethods.Load, which seems to read sizeof(dtype) multiplied by the number of elements in the shape. This can cause a loading error when the element type is encoded as bfloat16 but the array actually saved is double.
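To illustrate the second problem: a reader that trusts the header consumes element size × element count bytes. If the header and the payload disagree, every later read starts in the middle of a tensor. Below is a hedged Python sketch of that consistency check. The dtype codes are assumed to follow TorchSharp's ScalarType numbering (5 = Float16, 6 = Float32, 7 = Float64, 15 = BFloat16), and the function names are hypothetical:

```python
from math import prod

# Element sizes in bytes, keyed by assumed TorchSharp ScalarType codes
ELEMENT_SIZE = {5: 2, 6: 4, 7: 8, 15: 2}


def expected_payload_bytes(dtype_code: int, shape: list[int]) -> int:
    """Bytes a header-driven reader consumes: element size x element count."""
    # prod([]) == 1, so a 0-dim (scalar) tensor holds one element
    return ELEMENT_SIZE[dtype_code] * prod(shape)


def payload_matches(dtype_code: int, shape: list[int], payload: bytes) -> bool:
    """True if the written data is exactly as long as the header promises."""
    return len(payload) == expected_payload_bytes(dtype_code, shape)
```

For a 2×3 tensor whose header says bf16 (12 bytes) but whose data was written as double (48 bytes), the check fails. A real reader would then misinterpret the remaining 36 bytes as the start of the next entry.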

@lintao185

Yes, indeed: you could change those calls to save_tensor_to_binary(tensor, binary_file). Note that the conversion to double was originally intended to improve compatibility. Alternatively, you could load the tensors into TorchSharp and then save the parameters again using TorchSharp's official native APIs.

@phizch

phizch commented Jan 23, 2024

Do you have the .ckpt file and want to load it? Or do you want to convert it to a file that can be read by the built-in methods in TorchSharp? If it's the former, I've just created a tool to load ckpt files directly, though I haven't tested it on BF16 yet. If you want, I can clean it up a bit and create a gist... and also test with bf16 :-)

It relies on the Razorvine.Pickle library to unpickle the data.pkl stored in the ckpt archive.
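For context on what such a tool has to parse: since PyTorch 1.6, `torch.save` writes a zip archive by default. Inside a top-level directory, `data.pkl` holds the pickled object graph (tensor metadata, with storages referenced by key) and `data/<key>` holds each storage's raw bytes. Here is a minimal standard-library sketch that only lists those entries and does not unpickle anything. Unpickling `data.pkl` also needs stand-ins for torch classes such as `torch._utils._rebuild_tensor_v2`, which is the part a library like Razorvine.Pickle handles on the .NET side. `list_checkpoint_entries` is a hypothetical helper, and the demo archive is synthetic:

```python
import io
import zipfile


def list_checkpoint_entries(source) -> tuple[str, list[str]]:
    """Return the data.pkl path and the raw storage entries of a torch.save zip.

    `source` is a path or a binary file object. Raises ValueError if the
    archive does not look like the zip-based torch.save layout.
    """
    with zipfile.ZipFile(source) as archive:
        names = archive.namelist()
    pickles = [n for n in names if n.endswith("/data.pkl")]
    if len(pickles) != 1:
        raise ValueError("expected exactly one <prefix>/data.pkl entry")
    prefix = pickles[0][: -len("data.pkl")]
    # Storage blobs live under <prefix>data/<key>
    storages = sorted(n for n in names if n.startswith(prefix + "data/"))
    return pickles[0], storages


# Synthetic archive with the same layout, for illustration only
demo = io.BytesIO()
with zipfile.ZipFile(demo, "w") as zf:
    zf.writestr("model/data.pkl", b"placeholder")
    zf.writestr("model/data/0", b"\x00" * 8)
    zf.writestr("model/data/1", b"\x00" * 4)
    zf.writestr("model/version", b"3\n")
demo.seek(0)
demo_pkl, demo_storages = list_checkpoint_entries(demo)
```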

@LittleLittleCloud
Author

LittleLittleCloud commented Jan 23, 2024

@phizch The .ckpt file I want to load is llama-2-7b. I'm not sure if I can share it here because of licensing, but you can easily download it following this guidance.

Below are the steps I want to take. Essentially, I want to load directly from the .ckpt to avoid manually converting the .ckpt file to TorchSharp format.

  • from the .ckpt, load all tensors, including their data, name, type and shape (I don't know the details of the .ckpt format, so I'm not sure whether all of this information is available, but it would help me load the weights into the Llama model built with TorchSharp)
  • after loading all tensors, create a state_dict similar to the one built by the loading function below, and load it into the TorchSharp Llama 2 model

Also, here's the link to the loading function I currently use to load the model weights. It's based on @lintao185's solution (thanks, BTW) and requires a separate conversion from the Llama 2 ckpt to TorchSharp format, which I'd like to get rid of.

Thanks in advance for any solution or help!

@shaltielshmid
Contributor

shaltielshmid commented Jan 23, 2024

@LittleLittleCloud I haven't tried it myself, but have you tried loading using TorchSharp.PyBridge?

You can install it using nuget:

Install-Package TorchSharp.PyBridge

And then you can load in the PyTorch weights without applying any conversions:

model.load_py("/path/to/ckpt");

(This should work with the regular pytorch checkpoints, not SafeTensors.)

@LittleLittleCloud
Author

LittleLittleCloud commented Jan 23, 2024

@shaltielshmid I just tried your package and your solution works like a charm.

Here are the steps I took, in case anyone else encounters a similar problem:

step 1

In Python, save the state_dict to disk. The .ckpt contains some extra metadata, so it can't be loaded directly into the TorchSharp model; you need to save the state dict instead. (Maybe there's a way to extract the state_dict from the .ckpt in C#?)

import torch

# use bf16 as the default dtype
# this is a requirement if you want to save llama weight in bfloat16
torch.set_default_dtype(torch.bfloat16)

# some code to load transformer

# save model state dict
with open(llama_torchsharp_weights_path, 'wb') as f:
    torch.save(model.state_dict(keep_vars=False), f)
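Regarding extracting the state dict from a .ckpt that carries extra metadata: checkpoints saved by training code often wrap the weights in a larger dictionary. Below is a hedged sketch that unwraps them before re-saving. The wrapper key names (PyTorch Lightning, for example, uses "state_dict") and the optional name prefix are assumptions about a particular file, so check your checkpoint's keys first. `extract_state_dict` is a hypothetical helper:

```python
from collections.abc import Mapping

# Common wrapper keys; assumptions, not a standard
CANDIDATE_KEYS = ("state_dict", "model", "model_state_dict")


def extract_state_dict(checkpoint: Mapping, strip_prefix: str = "") -> dict:
    """Return the weights mapping from a (possibly wrapped) checkpoint dict."""
    weights = checkpoint
    for key in CANDIDATE_KEYS:
        if key in checkpoint and isinstance(checkpoint[key], Mapping):
            weights = checkpoint[key]
            break
    # Optionally drop a name prefix such as "model." added by a wrapper module
    result = {}
    for name, value in weights.items():
        if strip_prefix and name.startswith(strip_prefix):
            name = name[len(strip_prefix):]
        result[name] = value
    return result
```

With PyTorch, this would be used as `torch.save(extract_state_dict(torch.load(path, map_location="cpu")), out_path)`. A plain state dict such as the original Llama checkpoint passes through unchanged.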

step 2

In C#:

// create transformer
transformer.load_py(llama_torchsharp_weights_path);

Here are the resulting model sizes (consolidate.0.pth is the original Llama checkpoint, llama-2-7b.pt is the weight file converted by the export script with bfloat16 saved as float, and llama-2-7b-2.pt is the weight file saved for TorchSharp.PyBridge):

[Screenshot: file sizes of consolidate.0.pth, llama-2-7b.pt and llama-2-7b-2.pt]

It also seems that TorchSharp.PyBridge depends on TorchSharp > 0.105, and I can't find official Linux CUDA 11.* runtime support for that version on NuGet. I've created another issue to track this. @dotnet/torchsharp-admin, could you help me out there?

@shaltielshmid
Contributor

TorchSharp.PyBridge depends on features that were only added in version 0.101.5 of TorchSharp.

However, since the TorchSharp package already includes the CUDA binaries, you can update the package even if you don't have CUDA 12.x installed on your machine.

@NiklasGustafsson
Contributor

But you do need drivers that are CUDA 12 compatible.
