
[BUG] RuntimeError: Invalid device string: 'bfloat16' with transformers v4.40.1 and save_strategy="epoch"

Open · OAHC2022 opened this issue on Apr 30, 2024

While fine-tuning the unsloth/codellama-7b model with transformers v4.40.1 and save_strategy="epoch", I encountered the following error:

line 540, in LlamaModel_fast_forward
    inputs_embeds = inputs_embeds.to(self.config.torch_dtype)
RuntimeError: Invalid device string: 'bfloat16'
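
As a quick aside on the error message itself: torch.Tensor.to interprets a bare string argument as a device name rather than a dtype, which is why passing the string "bfloat16" is reported as an invalid device. A minimal illustration, independent of unsloth or transformers:

import torch

x = torch.zeros(2)
x.to(torch.bfloat16)  # fine: casts the tensor to bfloat16
x.to("bfloat16")      # RuntimeError: Invalid device string: 'bfloat16'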

Upon examining the code, I identified the problematic line at this GitHub location:

# Embed positions
if inputs_embeds is None:
    inputs_embeds = self.embed_tokens(input_ids)

inputs_embeds = inputs_embeds.to(self.config.torch_dtype)

It appears that during checkpoint saving while fine-tuning, self.config.torch_dtype is incorrectly set to the string "bfloat16" instead of torch.bfloat16. Here's a simple fix I implemented:

# Embed positions
if inputs_embeds is None:
    inputs_embeds = self.embed_tokens(input_ids)

# My modification: the config may carry the dtype as a string, so convert it back to a torch.dtype
if self.config.torch_dtype == "bfloat16":
    self.config.torch_dtype = torch.bfloat16
inputs_embeds = inputs_embeds.to(self.config.torch_dtype)
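
A variant of the same guard that avoids overwriting the config in place is to resolve the dtype locally just before the cast. This is only a sketch of the idea, not unsloth's actual code:

# Sketch: resolve a possible string dtype locally instead of mutating self.config
dtype = self.config.torch_dtype
if isinstance(dtype, str):
    dtype = getattr(torch, dtype)  # e.g. "bfloat16" -> torch.bfloat16
inputs_embeds = inputs_embeds.to(dtype)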

OAHC2022 avatar Apr 30 '24 21:04 OAHC2022

Oh thanks for that!! Will add your fix in! Thanks!

danielhanchen avatar May 01 '24 18:05 danielhanchen

Thank you!

OAHC2022 avatar May 06 '24 17:05 OAHC2022

Is this fixed? I'm still hitting the same bug.

johnsonice avatar Jun 23 '24 01:06 johnsonice

@johnsonice Could you try updating Unsloth as described in the wiki (https://github.com/unslothai/unsloth/wiki)?

danielhanchen avatar Jul 01 '24 00:07 danielhanchen

@danielhanchen I can confirm that I'm still experiencing this bug after updating.

pip uninstall unsloth -y
pip install --upgrade --no-cache-dir "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"

chawins avatar Aug 15 '24 23:08 chawins

@danielhanchen I'm having this same issue as of commit 976d11a10d54383aeb7a692c69e01151a20bfd72. I hit it every other time I run my fine-tuning script: it alternates between running cleanly and throwing this error. I saw that this was previously solved with if statements that check for the dtype's name as a string and replace it with the corresponding dtype object. Why not use a mapping of dtype strings to their dtype objects?

A naive approach could look like the following:

Original:

if self.config.torch_dtype == "float32":
    self.config.torch_dtype = torch.float32
elif self.config.torch_dtype == "bfloat16":
    self.config.torch_dtype = torch.bfloat16
elif self.config.torch_dtype == "float16":
    self.config.torch_dtype = torch.float16
inputs_embeds = inputs_embeds.to(self.config.torch_dtype)

To this:

DTYPE_MAP = {  # probably good to set as a class attribute
    "float32": torch.float32,
    torch.float32: torch.float32,
    "float16": torch.float16,
    torch.float16: torch.float16,
    "bfloat16": torch.bfloat16,
    torch.bfloat16: torch.bfloat16,
}

inputs_embeds = inputs_embeds.to(DTYPE_MAP[self.config.torch_dtype])

An enum might be a better solution:

from enum import Enum

import torch


class DtypeMap(Enum):
    float32: torch.dtype = torch.float32
    fp32: torch.dtype = float32
    float16: torch.dtype = torch.float16
    fp16: torch.dtype = float16
    bfloat16: torch.dtype = torch.bfloat16
    bf16: torch.dtype = bfloat16

    @classmethod
    def get_dtype(cls, _v) -> torch.dtype:
        if isinstance(_v, str):
            return cls[_v].value
        elif isinstance(_v, torch.dtype):
            return _v
        else:
            raise TypeError(f"{type(_v).__name__}")

DtypeMap.get_dtype(torch.bfloat16) == torch.bfloat16 # True
DtypeMap.get_dtype("bfloat16") == torch.bfloat16 # True
DtypeMap.get_dtype("bf16") == torch.bfloat16 # True; may be useful for `TrainingArguments`
# List all non-aliased dtype names:
[dt.name for dt in DtypeMap]  # ['float32', 'float16', 'bfloat16']
# List all torch.dtype dtypes defined in the enum:
[dt.value for dt in DtypeMap]  # [torch.float32, torch.float16, torch.bfloat16]
# `.to()` now works without if/elif/else branches cluttering methods and without mutating `config` state
inputs_embeds = inputs_embeds.to(DtypeMap.get_dtype(self.config.torch_dtype))

SVHawk13 avatar Sep 01 '24 19:09 SVHawk13

I can add your mapping idea!

danielhanchen avatar Sep 02 '24 07:09 danielhanchen

I hit this same issue just now, but on this line:

https://github.com/unslothai/unsloth/blob/f1951c0f6d3e1f184af93e5d8f5eff6e7834e4b5/unsloth/models/llama.py#L961C9-L961C52

I don't know if this is the best place to fix it, but I changed it to logits = logits.to(__DTYPE_MAP[self.config.torch_dtype]) and it worked.

llllvvuu avatar Sep 18 '24 00:09 llllvvuu