Add ROCm support

Open arlo-phoenix opened this issue 9 months ago • 38 comments

Inspired by the llama.cpp ROCm port, I decided to try and use a similar approach for bitsandbytes and worked through the different hipified cuda functions/classes and just redefine them with the HIP equivalents. This only happens if BNB_USE_HIP is set and merging this shouldn't affect the CUDA code at all. It's also easier to maintain than keeping a parallel hip code base alive.

This PR adds the target hip to make and works with the most recent version (0.42.0) with ROCm 5.6+ (6.0 included). For installing just do

# Your ROCM_TARGET can be found with rocminfo | grep gfx
ROCM_TARGET=gfx1030 make hip
pip install .

It won't pass all tests, as some are igemm- or CUDA-specific, but all optimizers work in both 8-bit and 32-bit. I have also used this a lot for 4-bit llama inference, which works as well. Besides those, the tests that fail are test_autograd.py and anything with double_quant in its name; I assume those also have to do with matrix multiplication and are expected to fail.

Besides that, igemm / matrix core support for the more recent AMD GPUs is still impossible because of missing instructions in hipBLASLt. There is also an official fork which tries to enable it, but it doesn't seem finished yet. If you want to use the official fork without hipBLASLt, @IMbackK provided a patch for it which should work on all ROCm-supported GPUs.

I'm making this a draft for now, as it is still not well tested and I haven't really updated the documentation yet. From an actual code standpoint not much will change on my side as I only own a gfx1030 GPU and thus can't test igemm support.

Closes #47, closes #107, closes #681

arlo-phoenix avatar Sep 08 '23 13:09 arlo-phoenix

https://github.com/TimDettmers/bitsandbytes/blob/a13af4542f5a163d8e9eacb0b8b2a5b26d9e2b15/Makefile#L19-L23

You have a logic bug there: if ROCM_HOME is empty, it skips the ROCM_TARGET check and tries to compile anyway (and fails later due to missing libraries/headers).

AloneLiberty avatar Sep 23 '23 00:09 AloneLiberty

Tested on Ubuntu 22.04.3 with RX 7900 XTX and passed with the following command: ROCM_HOME=/opt/rocm ROCM_TARGET=gfx1100 make hip

Without ROCM_HOME make cannot find proper headers to build.

zoyamd avatar Sep 26 '23 17:09 zoyamd

I get this error: fatal error: 'hipblaslt/hipblaslt.h' file not found which I can't figure out

CorvetteCole avatar Oct 05 '23 06:10 CorvetteCole

Can compile using the official rocm pytorch docker image and then pip install on the original desktop environment and it works just fine though

CorvetteCole avatar Oct 05 '23 13:10 CorvetteCole

https://github.com/TimDettmers/bitsandbytes/blob/a13af4542f5a163d8e9eacb0b8b2a5b26d9e2b15/Makefile#L19-L23

You have a logic bug there: if ROCM_HOME is empty, it skips the ROCM_TARGET check and tries to compile anyway (and fails later due to missing libraries/headers).

I'm not really used to working with Makefiles directly. If ROCM_HOME is empty it should skip that check, since it doesn't matter what target you set: the build won't be able to find libraries / binaries anyway. I just tried to roughly copy the CUDA checks, since those also only take one parameter. I agree it should throw an error, I'll see what I can do, thanks.

Tested on Ubuntu 22.04.3 with RX 7900 XTX and passed with the following command: ROCM_HOME=/opt/rocm ROCM_TARGET=gfx1100 make hip

Without ROCM_HOME make cannot find proper headers to build.

If ROCM_HOME isn't already set, the Makefile tries to find it automatically with the command which hipcc | rev | cut -d'/' -f4- | rev. Could you check which hipcc? For me (in the official docker image) it's at /opt/rocm/hip/bin/hipcc. I assume it probably finds the other hipcc link at /opt/rocm/bin/hipcc on your system, which would require using -f3- in that cut. An alternative to that command could be to just default to /opt/rocm, since that seems to be the usual spot.
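For illustration, the lookup described above could be sketched in Python roughly as follows. This is only a sketch, not what the Makefile does: the function names are hypothetical, and it assumes hipcc lives either under `<root>/bin` or `<root>/hip/bin`, with `/opt/rocm` as the last-resort default.

```python
import os
import shutil
from pathlib import PurePosixPath

DEFAULT_ROCM_HOME = "/opt/rocm"  # assumed fallback, the "usual spot" mentioned above


def rocm_home_from_hipcc(hipcc_path: str) -> str:
    """Derive the ROCm root from a hipcc path (hypothetical helper)."""
    parts = PurePosixPath(hipcc_path).parts
    # Drop the executable name and its bin/ directory.
    if len(parts) >= 2 and parts[-2] == "bin":
        parts = parts[:-2]
    # /opt/rocm/hip/bin/hipcc has one extra "hip" level compared to /opt/rocm/bin/hipcc.
    if parts and parts[-1] == "hip":
        parts = parts[:-1]
    return str(PurePosixPath(*parts))


def resolve_rocm_home(env=None, which=shutil.which) -> str:
    """Prefer an explicit ROCM_HOME, then the hipcc location, then the default."""
    env = os.environ if env is None else env
    explicit = env.get("ROCM_HOME")
    if explicit:
        return explicit
    hipcc = which("hipcc")
    if hipcc:
        return rocm_home_from_hipcc(hipcc)
    return DEFAULT_ROCM_HOME
```

Handling both layouts in one place avoids having to choose between `-f4-` and `-f3-` depending on which hipcc link is found first.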

I get this error: fatal error: 'hipblaslt/hipblaslt.h' file not found which I can't figure out

hipblaslt is available since ROCm 5.6 (which is why I called my fork that). If you are below that, this fork won't work and I'd recommend another one that can be found in the linked issues (no 4-bit support though). I don't know if ROCm 5.7 changed anything. From what I read, 6.0 is going to be the release where stuff isn't backwards compatible anymore, so I don't think that should already be the case.

arlo-phoenix avatar Oct 07 '23 14:10 arlo-phoenix

FWIW, Arch Linux's ROCM doesn't seem to be distributing hipblaslt (yet), so I am also getting missing hipblaslt.h errors on 5.6 there. It doesn't seem to be a straightforward (for me) thing to build, either.

person4268 avatar Oct 08 '23 08:10 person4268

FWIW, Arch Linux's ROCM doesn't seem to be distributing hipblaslt (yet), so I am also getting missing hipblaslt.h errors on 5.6 there. It doesn't seem to be a straightforward (for me) thing to build, either.

You don't need the actual library, just the headers, as igemm is still disabled for this anyway (they were needed for this to compile though; I could've probably just made some placeholder defines so it compiles, but since I had the header I decided not to). The headers can be found under https://github.com/ROCmSoftwarePlatform/hipBLASLt/tree/develop/library/include and you could just put those in /opt/rocm/include. I don't know if they have any other dependencies, but other than that it should work.

arlo-phoenix avatar Oct 08 '23 11:10 arlo-phoenix

That worked (at least as far as being able to execute python -m bitsandbytes, we'll see if it works once the model I'm trying to use finishes downloading), thanks. Do note that I had to create hipblaslt-export.h and hipblaslt-version.h by hand, as those are generated at build time. (I just copied the ones from hipblas and pretty much just did s/HIPBLAS/HIPBLASLT/g in both upper and lower case.)
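A small Python sketch of that substitution, in case someone wants to script it. The helper names are hypothetical, and the header text in the usage note is illustrative rather than copied from a real hipblas header.

```python
import re
from pathlib import Path


def hipblas_to_hipblaslt(text: str) -> str:
    """Rename hipblas identifiers to hipblaslt in both upper and lower case."""
    # The negative lookahead keeps already converted names from turning into HIPBLASLTLT.
    text = re.sub(r"HIPBLAS(?!LT)", "HIPBLASLT", text)
    text = re.sub(r"hipblas(?!lt)", "hipblaslt", text)
    return text


def convert_header(src: str, dst: str) -> None:
    """Read a hipblas header from src and write the renamed copy to dst."""
    Path(dst).write_text(hipblas_to_hipblaslt(Path(src).read_text()))
```

For example, `hipblas_to_hipblaslt("#ifndef HIPBLAS_EXPORT_H")` gives `"#ifndef HIPBLASLT_EXPORT_H"`, and running the function a second time leaves the text unchanged.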

person4268 avatar Oct 09 '23 20:10 person4268

How do you deal with multiple GPUs of different targets? I have both MI100s (gfx908) and W6800s (gfx1030) in my machine. Can I use ROCM_TARGET=gfx908;gfx1030?

ccbadd avatar Oct 31 '23 15:10 ccbadd

Any joy with this patch?

bog-dan-ro avatar Dec 11 '23 13:12 bog-dan-ro

Hi, is there any update on this? Would love to have this merged!

fakerybakery avatar Dec 16 '23 00:12 fakerybakery

@arlo-phoenix pls add support for ROCM 5.7

Wintoplay avatar Dec 16 '23 01:12 Wintoplay

Hi, is there any update on this? Would love to have this merged!

Since there now seems to be an official plan to extend support to multiple platforms / hardware targets, this will probably have to adjust if it's going to be merged. I personally want to wait for ROCm 6.0 since that might break stuff. And even if this gets merged, it would just be basic support without matrix cores, as I don't have a more recent GPU. And even if I did, the hipblasLt project (at least according to their docs) only officially supports gfx90a, gfx94x (I assume it's gonna expand to MI300 as well, but I haven't seen anything about RDNA 3). The previously missing instructions from hipblasLt are now there (even if they don't respond to issues .-., had to grep), so it's not impossible to add, but as I said I can't test it, so it'd be nice if someone else did that.

@arlo-phoenix pls add support for ROCM 5.7

While ROCm is annoying with having to compile for each arch, nothing should've broken between ROCm 5.6 and 5.7 (will likely happen with 6.0).

Since many people still have problems with building this: I added an /opt/rocm fallback on the rocm branch for this PR. I also updated the main branch of my fork, which now includes a hipblaslt-compat header, so you don't actually need hipblaslt, as a lot of distros don't distribute it yet (I just tried to build this on an Arch system). That header was the only reason I named my fork so specifically, so these build instructions should work even before ROCm 5.6 and on a lot more systems:

git clone https://github.com/arlo-phoenix/bitsandbytes-rocm-5.6
cd bitsandbytes-rocm-5.6

ROCM_TARGET=gfx1030 make hip
pip install .

At least on my system you don't even need to set the ROCM_TARGET as it will just build it for all targets. I still recommend it for a faster build process. For finding it use

/opt/rocm/bin/rocminfo | grep gfx

or just look for your GPU under https://www.llvm.org/docs/AMDGPUUsage.html#processors
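As a rough sketch, the same lookup can be scripted by scanning the rocminfo output for gfx names. The function names are hypothetical, the default rocminfo path is the one from the command above, and the sample text in the usage note is illustrative, not captured output.

```python
import re
import subprocess
from typing import List

# LLVM target names look like gfx906, gfx90a, gfx1030, ...
GFX_PATTERN = re.compile(r"\bgfx[0-9a-f]+\b")


def parse_gfx_targets(rocminfo_output: str) -> List[str]:
    """Return the distinct gfx names in order of first appearance."""
    seen: List[str] = []
    for name in GFX_PATTERN.findall(rocminfo_output):
        if name not in seen:
            seen.append(name)
    return seen


def detect_gfx_targets(rocminfo: str = "/opt/rocm/bin/rocminfo") -> List[str]:
    """Run rocminfo and extract the targets (requires a working ROCm install)."""
    result = subprocess.run([rocminfo], capture_output=True, text=True, check=True)
    return parse_gfx_targets(result.stdout)
```

For instance, `parse_gfx_targets("Name: gfx1030\nName: amdgcn-amd-amdhsa--gfx1030\n")` collapses the two mentions into `["gfx1030"]`.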

Edit: Just noticed 6.0 was already out .-., I'll update this once the official docker images are updated as well

arlo-phoenix avatar Dec 16 '23 11:12 arlo-phoenix

@arlo-phoenix rocm 6.0 docker has arrived https://hub.docker.com/r/rocm/pytorch/tags

Wintoplay avatar Dec 21 '23 03:12 Wintoplay

@arlo-phoenix rocm 6.0 docker has arrived https://hub.docker.com/r/rocm/pytorch/tags

Thanks for the info! I only tested the basic stuff and will probably only test further after the holidays, but it still compiles, 4-bit works and all optimizers work as well (at least according to pytest). So, summed up, ROCm 6.0 breaks nothing in this after all.

@Titus-von-Koeller, I only skimmed through #898, but from what I see the idea is to add the ability to have different backends with one of them being the current implementation now under a CudaBackend. From my perspective this won't really change this PR that much then (only gotta move some checks) since there isn't really a need for a separate backend for HIP and AMD GPU's should just use the CudaBackend as well.

  • One improvement could be moving the defines to a separate header hip-compat.h so it's better separated.
  • The Makefile definitely still needs work, as already said never worked with them directly
  • If there is a move towards a CMakeFile for Windows Support (I think there are several PR's) I could try to make this work with CMake. Should be easier to add good integration that doesn't bother Cuda compilation as I'm more experienced with that

arlo-phoenix avatar Dec 21 '23 18:12 arlo-phoenix

One thing I should note about this PR is that since it does not support wave64, it should really refuse to compile on those devices, or assert at run time. Right now it produces incorrect results.

All AMD AI/compute focused GPUs are wave64 only (i.e. mi25, mi50, mi100, mi210, all the way to the latest mi300). It's only consumer GPUs newer than Radeon VII that can do both wave64 and wave32, so this PR excludes the very GPUs that are best suited to be used in ML.
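One way such a guard could look, sketched in Python. This is a heuristic under a stated assumption, not code from the PR: it treats every target with a three-character suffix (gfx900, gfx906, gfx908, gfx90a, gfx94x) as wave64 only and every gfx10xx/gfx11xx target as wave32 capable. The function names are hypothetical.

```python
def is_wave64_only(gfx_target: str) -> bool:
    """Heuristic: GCN/CDNA targets (gfx9xx and older) only support wave64."""
    name = gfx_target.strip().lower()
    if not name.startswith("gfx") or len(name) <= 3:
        raise ValueError(f"not a gfx target name: {gfx_target!r}")
    # RDNA targets carry a four-character suffix (gfx1030, gfx1100, ...).
    return len(name[3:]) < 4


def require_wave32(gfx_target: str) -> None:
    """Refuse targets the wave32-only kernels would compute wrong results on."""
    if is_wave64_only(gfx_target):
        raise RuntimeError(
            f"{gfx_target} is wave64 only; these kernels assume a wavefront size of 32"
        )
```

Failing early like this trades device coverage for correctness, which seems preferable to silently wrong results.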

IMbackK avatar Dec 22 '23 14:12 IMbackK

Hello, any news about 6.0 update?

purefire avatar Dec 28 '23 07:12 purefire

Seconding this ^ Mi300 and H100 are both battling at the moment, so would like to use my 7900xtx!

Iron-Bound avatar Dec 29 '23 01:12 Iron-Bound

@purefire

Hello, any news about 6.0 update?

Status quo is still

Thanks for the info! Only tested the basic stuff and will probably only further test after the holidays, but it still compiles, 4 bit works and all optimizers also all work (at least according to pytest). So summed up ROCm 6.0 breaks nothing in this after all.

I'm not finetuning anything atm, but since it still compiles and the tests succeed, it should still work as expected. The only things I expected to break were the Makefile, or some includes or defines becoming deprecated, but I didn't see anything.


@Iron-Bound

Seconding this ^ Mi300 and H100 are both battling at the moment, so would like to use my 7900xtx!

7900XTX should work, it's wavefront 64 that doesn't work and 7900XTX has the normal wavefront size 32. It would not become a battle here though as this doesn't support hipblaslt yet meaning no matrix cores are used and so the 7900XTX /MI300 wouldn't perform well at all. This isn't something I can implement/test myself so someone else will need to do that. The changes shouldn't be too large, just a small python check if the hip device supports hipblaslt where gemm support is checked and adjusting the Makefile to actually use the library.
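The "small python check" could look something like the sketch below. The prefix list is an assumption taken from the targets named earlier in this thread as officially supported by hipblasLt (gfx90a, gfx94x), and the function name is hypothetical.

```python
# Assumed from this thread, not from an authoritative support matrix.
HIPBLASLT_TARGET_PREFIXES = ("gfx90a", "gfx94")


def supports_hipblaslt(gfx_target: str) -> bool:
    """Return True if the target is on the assumed hipblasLt support list."""
    # Prefix matching also tolerates feature suffixes such as ":xnack-".
    return gfx_target.strip().lower().startswith(HIPBLASLT_TARGET_PREFIXES)
```

A caller would use this next to the existing gemm capability check and keep the non-hipblasLt path for everything else.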


@IMbackK

One thing I should note about this PR is that since it does not support wave64, it should really refuse to compile on those devices, or assert at run time. Right now it produces incorrect results.

All AMD AI/compute focused GPUs are wave64 only (i.e. mi25, mi50, mi100, mi210, all the way to the latest mi300). It's only consumer GPUs newer than Radeon VII that can do both wave64 and wave32, so this PR excludes the very GPUs that are best suited to be used in ML.

That's interesting. I didn't find anything last time because I didn't bother looking into large architecture description PDFs just to find a wavefront size, but you are right. Then it's a bit more important; I assumed it was only CDNA1 that just supported wavefront size 64.

I'll try to think of a good way to include them anyway. The wavefront size override doesn't actually affect how everything is executed; it's just that some compile-time asserts are no longer triggered (from what I remember). The define override should still be removed, or only be applied if something like FORCE_WAVEFRONT32 is set. I'll see if I can just trigger a trap in device code for the unsupported functions so it compiles. If that works, it would be enough to add a one-time warning with a fallback in the affected block size functions that uses the next larger block size, or throws an exception if that doesn't work. The problem with the second solution is that most projects use the smallest BLOCK_SIZE for 4-bit stuff, which means that e.g. https://github.com/TimDettmers/bitsandbytes/blob/f63abb5a0d0bc971d28972ba890a9e59596caac4/csrc/kernels.cu#L3976 for FP4 is called, which doesn't work with the larger wavefront size of 64. So the fallback (or an exception, if the tensor size makes a larger block size impossible) would need to go here: https://github.com/TimDettmers/bitsandbytes/blob/f63abb5a0d0bc971d28972ba890a9e59596caac4/bitsandbytes/functional.py#L690

Not experienced at all with that so no idea if that even works, but if it doesn't work to just use the next larger BLOCK_SIZE we can always just throw an exception and nothing that doesn't work would be called. Same would need to be done for dequantize.
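To make the fallback idea concrete, here is a minimal Python sketch. Everything in it is an assumption for illustration: the list of block sizes, the idea that only block size 64 is affected on wave64, and the helper name. It is not code from functional.py.

```python
import warnings

# Assumed set of block sizes accepted by the blockwise quantization functions.
SUPPORTED_BLOCKSIZES = (64, 128, 256, 512, 1024, 2048, 4096)

# Assumed smallest block size that works for each wavefront/warp size.
MIN_BLOCKSIZE_BY_WARP = {32: 64, 64: 128}


def effective_blocksize(requested: int, warp_size: int) -> int:
    """Return the requested block size, bumped up if the device can't run it."""
    if requested not in SUPPORTED_BLOCKSIZES:
        raise ValueError(f"unsupported blocksize {requested}")
    if warp_size not in MIN_BLOCKSIZE_BY_WARP:
        raise ValueError(f"unsupported warp size {warp_size}")
    minimum = MIN_BLOCKSIZE_BY_WARP[warp_size]
    if requested >= minimum:
        return requested
    # By default the warnings module reports this once per call site.
    warnings.warn(
        f"blocksize {requested} is not supported with warp size {warp_size}; "
        f"using {minimum} instead"
    )
    return minimum
```

The same helper would have to be applied on the dequantize side, since both must agree on the block size that was actually used.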


arlo-phoenix avatar Dec 30 '23 12:12 arlo-phoenix

For the makefile check we could do a basic check from the gfx version, also avoids importing anything.

Other option would be to call rocminfo or clinfo

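# Pick up PYTORCH_ROCM_ARCH from the environment.
PYTORCH_ROCM_ARCH := $(shell echo $$PYTORCH_ROCM_ARCH)

# Report whether gfx1100 is part of the arch list.
# The ifeq is evaluated when make parses the file, not when the recipe runs.
check-rocm-arch:
ifeq ($(findstring gfx1100,$(PYTORCH_ROCM_ARCH)),gfx1100)
	@echo "gfx1100 is included in PYTORCH_ROCM_ARCH."
else
	@echo "gfx1100 is not included in PYTORCH_ROCM_ARCH."
endif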

Iron-Bound avatar Dec 30 '23 13:12 Iron-Bound

@arlo-phoenix if you need a wave64 device to test with, I have a mi25 I'm not using sitting on a shelf, I could give it to you for free. Despite AMD's documentation, gfx900/mi25 is still fully supported and working in ROCm 6 and is still enabled in official builds, so it would work fine as a test bed.

IMbackK avatar Dec 31 '23 16:12 IMbackK

@arlo-phoenix if you need a wave64 device to test with, I have a mi25 I'm not using sitting on a shelf, I could give it to you for free. Despite AMD's documentation, gfx900/mi25 is still fully supported and working in ROCm 6 and is still enabled in official builds, so it would work fine as a test bed.

First off Happy New Year.

@IMbackK Thanks for the offer, but RDNA 2 supports wave64 and even defaults to it (which is why I forced it to wave32) and I'd only need to test with a specific device to test the hacky solution of just keeping it forced so the asserts don't trigger which I don't even wanna keep.

Instead I'll likely open a separate PR (will take a while, maybe before March) next to this preprocessor solution where we can more easily add wave64 traps without preprocessor hell in the CUDA section and also look into properly supporting hipblaslt. I previously disliked this, but now thinking it over and with devices becoming separated I think a HIPDevice is the cleaner solution after all, since the HIPIFIED code would not have to be kept up to date to the CUDA solution so that it doesn't break, but we could just catch up to the Cuda features manually. This will lead to a lot of code duplication, but imo that's better for wider support. What led to this thought change: Looking deeper into llama.cpp with ROCm support it is very confusing imo and often breaks using the preprocessors. A main collaborator there even called it a mistake and since this requires different handling for wave64 this would also become quite confusing here.

For the meantime this still works and if someone has a MI100 or above it should not be impossible to add hipblaslt support yourself if you need it earlier (see earlier discussion). I sadly won't work on this for a while until then since I'm preoccupied with uni/exams. Till then the common device abstraction will likely solidify or even be merged so I can base the HIPDevice on that. I'll keep this PR open for now, but this will likely not be the one to be merged.

arlo-phoenix avatar Jan 01 '24 10:01 arlo-phoenix

Note that for hipblaslt you need at least gfx90a, not gfx908. I do have access to mi100s, but that won't help here since that's gfx908. The matrix instructions in mi100 are different from those in gfx90a, and afaik there is little support for the gfx908 implementation.

IMbackK avatar Jan 01 '24 10:01 IMbackK

Thank you so much for this contribution. I am sorry that it took so long to reply and look at this. We are currently working on integrating different devices other than CUDA. We will discuss internally and will get back to you.

TimDettmers avatar Jan 02 '24 07:01 TimDettmers

Thank you so much for this contribution. I am sorry that it took so long to reply and look at this. We are currently working on integrating different devices other than CUDA. We will discuss internally and will get back to you.

No worries. This PR is still quite a bit hacky anyway, but this would be the simplest way to add HIP support. PyTorch also recognizes a HIP device as a CUDA device, so the detection of a HIP device needs to be done above that.

If we go with this, I still need to remove the forcing of wave32 (it can just be removed, but then it would work on fewer devices; I believe all RDNA2 devices have issues with it), and someone else who actually knows how to work with Makefiles would need to rework my Makefile changes. I can try fixing it myself, but I don't know how to make checks like the CUDA_HOME check only execute for a list of targets. I believe that problem should come up with other ports as well.
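A minimal sketch of that detection order in Python. It relies on `torch.version.hip`, which is `None` on CUDA builds of PyTorch and a version string on ROCm builds; the function name and the returned labels are hypothetical. The version object is passed in so the logic can be exercised without a GPU.

```python
from types import SimpleNamespace


def detect_gpu_backend(torch_version, cuda_available: bool) -> str:
    """Classify the backend, checking for HIP before treating the device as CUDA."""
    if not cuda_available:
        return "cpu"
    # ROCm builds report their devices through the CUDA API as well,
    # so the HIP check has to come first.
    if getattr(torch_version, "hip", None) is not None:
        return "hip"
    return "cuda"


# Stand-ins for torch.version on a ROCm build and on a CUDA build.
rocm_build = SimpleNamespace(hip="6.0.0", cuda=None)
cuda_build = SimpleNamespace(hip=None, cuda="12.1")
```

With PyTorch installed, the call would be `detect_gpu_backend(torch.version, torch.cuda.is_available())`.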

I also don't think this is really that high of a risk besides the Makefile for current CUDA devices. The only changes that should affect CUDA devices are under include/Algo-Direct2.h, maybe the header order matters, but I doubt that and it could be verified quickly if someone with CUDA tries to compile this.


It should also be evaluated if we want to create a custom HIPDevice after all. The advantage I see is that it might be possible to use several different devices at some point, but didn't look too deep into the abstraction PR. HIPIFYING the code wouldn't take long and it might be a bother to put ROCm Capability checks in a combined .cu file once we try to support wave64 and hipblaslt. The same applies to the python part where capabilities are checked

The disadvantage is obviously lots of code duplication. We can share the non .cu/.cuh headers like Algo-Direct with small changes, but the CUDADevice in Python and everything in the CUDA files would pretty much just be copy-pasted. It also means that once the CUDA code is updated, everything would need to be hipified again or manually copied over using diffs.

A third option is adding a HIPDevice in python using the libbitsandbytes_hip_nohipblaslt.so, but using preprocessors in the library code for compiling.

I'm personally for a separate HIPDevice (if the code isn't too much, I don't think this would contain a lot but setting capabilities) and unsure between the preprocessors and hipified code. Since it might bother cuda development I slightly lean towards just hipifying it after all. With the device abstraction defining what is supported and more devices being planned I don't think it would need to be updated too frequently and would look cleaner. It would also allow optimizing for HIP devices in the future.

arlo-phoenix avatar Jan 02 '24 12:01 arlo-phoenix

Note that for hipblaslt you need at least gfx90a, not gfx908. I do have access to mi100s, but that won't help here since that's gfx908. The matrix instructions in mi100 are different from those in gfx90a, and afaik there is little support for the gfx908 implementation.

Is it an actual hardware issue where the instructions are different? It seems weird to me that the first device with matrix cores does not actually work with the libraries. I looked through the source a bit and the ASM code is here. I have no idea how it works, but there are folders for "Grid Based" for Navi31 and Navi32. I doubt they will work without the PRO variant GPUs though, since the chip names for desktop RDNA have stuff like XT behind them. If the target for ROCm is the same, it might work on RDNA3 though. As said, I won't add hipblaslt support before it's decided how the AMD device will be integrated, but someone could check if they can run some examples.

arlo-phoenix avatar Jan 02 '24 12:01 arlo-phoenix

I don't know, but I do suspect that something is wrong with the CDNA1 matrix instructions, given that as far as I can tell ROCm uses them nowhere. The Navi kernels should work fine on desktop RDNA. No ROCm component ever deals in marketing names, only LLVM targets are used, so there is zero crippling of desktop parts aside from the firmware-enforced f64 nerf on vega20/gfx906.

IMbackK avatar Jan 02 '24 16:01 IMbackK

It should also be evaluated if we want to create a custom HIPDevice after all. The advantage I see is that it might be possible to use several different devices at some point, but didn't look too deep into the abstraction PR. HIPIFYING the code wouldn't take long and it might be a bother to put ROCm Capability checks in a combined .cu file once we try to support wave64 and hipblaslt. The same applies to the python part where capabilities are checked

I would like to note that adding wave64 support is a good thing for the CUDA code too. While all current and past CUDA devices have 32-wide warps, theoretically you should query warpSize from cudaDeviceProp, since Nvidia does consider devices with other warp sizes a future possibility. Hipify also translates this just fine, so the wave32/64 handling could (should) be implemented on the CUDA side.

I would also like to float the final but possibly unpopular possibility: just convert everything to HIP. Using the CUDA HIP implementation is significantly cleaner and less finicky than using hipify, while providing a better set of test macros to filter for device features than CUDA.

IMbackK avatar Jan 02 '24 16:01 IMbackK

I've had success with my 7900xtx with bfloat16 acceleration, so I'd recommend we use that as a first target. The asm here can also work until hipblaslt is more trustworthy: https://github.com/tinygrad/tinygrad/blob/master/extra/assembly/assembly_rdna.py

Iron-Bound avatar Jan 09 '24 00:01 Iron-Bound

@IMbackK my latest commit tries to adjust the kQuantizeBlockwise function to work with wavefront 64 (just only use 1 load, etc. per thread instead of 2). Didn't really test much, but tests succeeded for FP16 (might cause issues with nf4,fp4). I believe test_generation.py uses the 64 BLOCK_SIZE variant as well which would trigger the adjusted code and those succeeded. Even if my code doesn't work, I believe it shouldn't be too hard to edit that function to work anyways under WAVEFRONT/warp size 64. This makes integration a lot easier and wouldn't require a bunch of ifdefs or guards in python code.

Sadly this is very hard to test, as without igemm support most tests don't just fail but crash. I can run them one by one, but this is a bit annoying. I'll see if there's a pytest setting to just continue, or if this is easily adjustable.
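One workaround, sketched below, is to launch every test file in its own Python process so that a crash only takes down that one run. This is a generic sketch with hypothetical helper names, not a pytest feature; the test file names in the usage note are the ones mentioned in this thread.

```python
import subprocess
import sys
from typing import Dict, List


def pytest_commands(test_files: List[str]) -> Dict[str, List[str]]:
    """Build one pytest invocation per test file."""
    return {path: [sys.executable, "-m", "pytest", "-q", path] for path in test_files}


def run_isolated(commands: Dict[str, List[str]]) -> Dict[str, int]:
    """Run each command in its own process and collect the exit codes.

    On POSIX a process killed by a signal (for example a segfault) shows up
    as a negative return code instead of aborting the whole loop.
    """
    results = {}
    for name, cmd in commands.items():
        results[name] = subprocess.run(cmd).returncode
    return results
```

Usage would be along the lines of `run_isolated(pytest_commands(["tests/test_autograd.py", "tests/test_generation.py"]))`, then listing every entry with a non-zero code.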

@Iron-Bound I found this https://github.com/ROCmSoftwarePlatform/bitsandbytes/tree/rocm_enabled linked on some huggingface docs. That fork is almost as old as mine, I wonder why they never advertised it, but they link hipblaslt and don't disable it from what I saw in the makefile. Like the output library is called libbitsandbytes_hip_nohipblaslt, but didn't actually see it disabled so it might be worth to give this a try. I also don't think tinygrad is really gonna help, I have no clue how hipblaslt works internally and don't plan on researching to replicate functions, so the only thing I can do is use functions they provide to replace the cublaslt functions.

arlo-phoenix avatar Jan 12 '24 17:01 arlo-phoenix