
attn_implementation does not take effect

Open lk137095576 opened this issue 9 months ago • 0 comments

Reminder

  • [X] I have read the README and searched the existing issues.

Reproduction

no

Expected behavior

Setting flash attention via args fails. In transformers (v4.37.2 and later), `transformers/modeling_utils.py` contains:

```python
@classmethod
def _from_config(cls, config, **kwargs):
    """
    All context managers that the model should be initialized under go here.

    Args:
        torch_dtype (`torch.dtype`, *optional*):
            Override the default `torch.dtype` and load the model under this dtype.
    """
    torch_dtype = kwargs.pop("torch_dtype", None)
    use_flash_attention_2 = kwargs.pop("use_flash_attention_2", False)

    # override default dtype if needed
    dtype_orig = None
    if torch_dtype is not None:
        dtype_orig = cls._set_default_torch_dtype(torch_dtype)

    config = copy.deepcopy(config)  # We do not want to modify the config inplace in _from_config.
    config._attn_implementation = kwargs.pop("attn_implementation", None)
    config = cls._autoset_attn_implementation(
        config, use_flash_attention_2=use_flash_attention_2, check_device_map=False
    )

    ...
```

In the line `config._attn_implementation = kwargs.pop("attn_implementation", None)`, transformers resets `_attn_implementation`, which makes `_configure_attn_implementation` in LLaMA-Factory have no effect.
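For illustration, a minimal sketch (not LLaMA-Factory's code; the tiny LlamaConfig values are arbitrary stand-ins) that reproduces the override under transformers v4.37.2:

```python
from transformers import AutoModelForCausalLM, LlamaConfig

config = LlamaConfig(
    vocab_size=1000, hidden_size=64, intermediate_size=128,
    num_hidden_layers=2, num_attention_heads=4,
)

# Setting the attribute on the config ahead of time (as LLaMA-Factory's
# _configure_attn_implementation does) ...
config._attn_implementation = "flash_attention_2"

# ... is overwritten inside _from_config by
#     config._attn_implementation = kwargs.pop("attn_implementation", None)
# so without an explicit kwarg the model falls back to the autoselected default.
model = AutoModelForCausalLM.from_config(config)
print(model.config._attn_implementation)  # "sdpa" or "eager", not "flash_attention_2"

# Passing attn_implementation as a kwarg survives the pop() and takes effect
# (requires flash-attn to be installed, otherwise _autoset_attn_implementation raises).
model = AutoModelForCausalLM.from_config(config, attn_implementation="flash_attention_2")
print(model.config._attn_implementation)  # "flash_attention_2"
```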

System Info

transformers 4.37.2

Others

no

lk137095576 · May 20 '24 09:05