AMD support
Can you add AMD support? Thanks!
@userbox020 AMD will work if Triton works on ROCm and if Flash Attention and Xformers work. Also, bitsandbytes does not work on AMD (yet), though possibly it will in the future.
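A quick way to see which of these pieces are even present in an environment is sketched below. It assumes the usual import names (`torch`, `triton`, `flash_attn`, `xformers`, `bitsandbytes`), and the helper name `report_dependencies` is hypothetical. It only checks that each module can be located. It does not prove that the module's kernels actually run on ROCm.

```python
import importlib.util

# Import names assumed for the packages discussed in this thread.
REQUIRED = ["torch", "triton", "flash_attn", "xformers", "bitsandbytes"]


def report_dependencies(names=REQUIRED):
    """Return {module_name: bool} telling whether each module can be found.

    importlib.util.find_spec only locates a top-level module without
    importing it. A package that is installed but broken on ROCm will
    therefore still be reported as True.
    """
    return {name: importlib.util.find_spec(name) is not None for name in names}


if __name__ == "__main__":
    for name, found in report_dependencies().items():
        print(f"{name:>14}: {'found' if found else 'missing'}")
```

A "missing" entry tells you what to install next. A "found" entry still needs a real training run to confirm that it works on the GPU.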
@danielhanchen will this help, bro?
https://github.com/ROCm/triton/tree/triton-mlir
What would be the next step?
@userbox020 Oh, if Triton works on AMD, then the next step is Flash Attention. If those two work, then I guess Unsloth will work (hopefully).
@danielhanchen which version of Flash Attention is needed? There is a version with ROCm support, but it's a little bit behind.
https://github.com/ROCm/flash-attention
@lufixSch Oh, I think it is this one? Uncertain though. Some members on our Discord server were trying to get AMD to work, and I think some were successful, but then bitsandbytes became an issue.
Any progress here? I would love to try this on my 7900 XTX.
@danielhanchen I was able to get Unsloth to train fine with DPO using a Docker container on an AMD system (4x Instinct MI100), but SFT throws this error:
RuntimeError: Triton Error [HIP]: Code: 1, Messsage: invalid argument

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/app/model/finetune.py", line 97, in <module>
    trainer_stats = trainer.train()
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/trl/trainer/sft_trainer.py", line 361, in train
    output = super().train(*args, **kwargs)
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/transformers/trainer.py", line 1859, in train
    return inner_training_loop(
  File "<string>", line 361, in _fast_inner_training_loop
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/transformers/trainer.py", line 3138, in training_step
    loss = self.compute_loss(model, inputs)
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/transformers/trainer.py", line 3161, in compute_loss
    outputs = model(**inputs)
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/accelerate/utils/operations.py", line 825, in forward
    return model_forward(*args, **kwargs)
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/accelerate/utils/operations.py", line 813, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/amp/autocast_mode.py", line 16, in decorate_autocast
    return func(*args, **kwargs)
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/unsloth/models/llama.py", line 882, in PeftModelForCausalLM_fast_forward
    return self.base_model(
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/peft/tuners/tuners_utils.py", line 161, in forward
    return self.model.forward(*args, **kwargs)
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/unsloth/models/llama.py", line 847, in _CausalLM_fast_forward
    loss = fast_cross_entropy_loss(
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/unsloth/kernels/cross_entropy_loss.py", line 274, in fast_cross_entropy_loss
    loss = Fast_CrossEntropyLoss.apply(
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/autograd/function.py", line 539, in apply
    return super().apply(*args, **kwargs) # type: ignore[misc]
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/unsloth/kernels/cross_entropy_loss.py", line 219, in forward
    _chunked_cross_entropy_forward[(n_rows, n_chunks,)](
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/triton/runtime/jit.py", line 581, in run
    bin.c_wrapper(
SystemError: <built-in function launch> returned a result with an error set
Edit: "train fine" is a bit of an overstatement. Does it train? Yes. Is its output coherent? No. Here is what it generates:
<|begin_of_text|><|start_header_id|>assistant<|end_header_id|>
ศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจศจ
Hmm, sadly AMD is an issue, and I don't have an AMD GPU to even try adding support for it. Apologies :(
OK, so there is now:
https://github.com/ROCm/bitsandbytes
and
https://github.com/ROCm/flash-attention
and
pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.1
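Before building anything else on top of the nightly ROCm wheel, it may help to confirm that the installed `torch` really is a ROCm build. The sketch below relies on `torch.version.hip`, which is a version string on ROCm builds and `None` otherwise, and on `torch.cuda.is_available()`, since ROCm builds expose AMD GPUs through the `torch.cuda` API. The helper name `torch_backend` is hypothetical.

```python
def torch_backend():
    """Return a short description of the GPU backend the installed torch targets."""
    try:
        import torch
    except ImportError:
        return "torch not installed"

    # ROCm wheels set torch.version.hip to a version string; other builds leave it None.
    hip = getattr(torch.version, "hip", None)
    cuda = getattr(torch.version, "cuda", None)
    if hip:
        backend = f"ROCm/HIP {hip}"
    elif cuda:
        backend = f"CUDA {cuda}"
    else:
        backend = "CPU-only build"

    # On ROCm builds, AMD GPUs are still reported through the torch.cuda namespace.
    return f"{backend}, GPU visible: {torch.cuda.is_available()}"


if __name__ == "__main__":
    print(torch_backend())
```

If this reports a CUDA or CPU-only build, a different torch wheel was probably pulled in, for example by another package's dependencies. That wheel would need to be replaced before any ROCm kernels can run.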
What else is needed? Or should I start mucking around? I have a working install for inference and don't want to mess it up if something still has no ROCm support.
Following up on this: ROCm seems to have all the support needed at this point:
https://rocm.docs.amd.com/en/latest/how-to/llm-fine-tuning-optimization/overview.html
https://rocm.docs.amd.com/en/latest/how-to/rocm-for-ai/index.html
So is there any chance of an update, or could you guide me on how to make it happen?
I would also like to try getting ROCm working with Unsloth. I was able to get torchtune working pretty easily with ROCm, which is the next best thing. It would be nice to try Unsloth next.
Sorry, it's not technically on our roadmap yet :( @erasmus74 If those work, try installing Unsloth with no dependencies and see if it runs.
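After installing Unsloth without its dependencies (for example with pip's `--no-deps` flag, so the ROCm builds of torch, Triton, Flash Attention, and bitsandbytes are not replaced), a first check is whether it imports at all. The sketch below is a minimal import smoke test. The helper name `smoke_test` is hypothetical. A successful import says nothing about whether the Triton kernels run correctly, as the MI100 report above shows.

```python
import importlib
import traceback


def smoke_test(module_name="unsloth"):
    """Try importing a module and return (ok, message)."""
    try:
        importlib.import_module(module_name)
    except Exception as exc:  # imports can fail with ImportError, RuntimeError, etc.
        # Keep only the final "ExceptionType: message" line for a compact report.
        message = "".join(traceback.format_exception_only(type(exc), exc)).strip()
        return False, message
    return True, f"{module_name} imported successfully"
```

Calling `smoke_test()` returns `(False, ...)` together with the exception text if the import fails. That text is usually enough to tell which dependency is still missing or incompatible.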