Eric Buehler
As discussed in #2361, our current argsort implementation does not work on CUDA for large vectors because we use a bitonic sort implementation, which requires shared memory. For some n...
If the vector length is large, the error `CUDA_INVALID_VALUE` is returned:

```rust
use candle_core::{DType, Device, Tensor};

fn main() {
    let a = Tensor::zeros(
        32000,
        DType::F32,
        &Device::cuda_if_available(0).unwrap(),
    )
    .unwrap();
    dbg!(&a.arg_sort_last_dim(true));
}
```
...
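For context on why the current kernel is length-limited: a bitonic sorting network repeatedly compare-exchanges elements at power-of-two strides, so the GPU kernel keeps the whole segment resident in shared memory, which has a fixed per-block size. A minimal CPU sketch of the same network (illustrative only, not the candle kernel):

```rust
// Textbook in-place bitonic sort for a power-of-two length slice.
// On the GPU, each (k, j) stage's compare-exchange pattern touches the
// whole segment, which is why the kernel stages it in shared memory --
// and why large n exceeds the shared-memory budget.
fn bitonic_sort(v: &mut [f32]) {
    let n = v.len();
    assert!(n.is_power_of_two(), "bitonic sort needs a power-of-two length");
    let mut k = 2;
    while k <= n {
        let mut j = k / 2;
        while j > 0 {
            for i in 0..n {
                let l = i ^ j; // partner index at stride j
                if l > i {
                    // Direction alternates per k-sized block.
                    let ascending = i & k == 0;
                    if (ascending && v[i] > v[l]) || (!ascending && v[i] < v[l]) {
                        v.swap(i, l);
                    }
                }
            }
            j /= 2;
        }
        k *= 2;
    }
}
```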
Currently, the `GgmlDType` only supports F16 and not BF16. This PR introduces support for the BF16 type. I would appreciate a check if this looks good! I have tested with...
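For reference, BF16 is simply the upper 16 bits of an IEEE-754 f32, which keeps the conversion kernels trivial. A standalone sketch of the bit-level round trip (not candle's implementation, which uses the `half` crate's types):

```rust
// BF16 keeps the f32 sign, 8 exponent bits, and the top 7 mantissa
// bits, so widening bf16 -> f32 is a 16-bit left shift of the raw bits.
fn bf16_to_f32(bits: u16) -> f32 {
    f32::from_bits((bits as u32) << 16)
}

// Truncating (round-toward-zero) narrowing for simplicity; production
// implementations usually round to nearest even instead.
fn f32_to_bf16(x: f32) -> u16 {
    (x.to_bits() >> 16) as u16
}
```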
Currently, we perform an unnecessary device-to-host (dtoh) copy when dequantizing to f16/f32 on CUDA. We can instead add a simple cast kernel to ensure that we keep the...
Benchmarks did not compile.
Motivation: The current `QTensor::quantize` quantizes the `src` tensor onto the same device as `src`. This behavior is OK for most use cases, but there is a specific condition where this...
Adds the aforementioned methods to Device. The `Device::best_device` has the same functionality as `candle_examples::device`, and this PR changes `candle_examples::device` to use `best_device`. `metal_if_available` has been added for parity with `cuda_if_available`.
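The fallback order described can be sketched generically; the enum and availability checks below are illustrative stand-ins, not candle's actual types or API:

```rust
// Illustrative stand-ins for candle's device selection logic.
#[derive(Debug, PartialEq)]
enum Device {
    Cuda(usize),
    Metal(usize),
    Cpu,
}

// Assume neither backend is available in this sketch.
fn cuda_available() -> bool { false }
fn metal_available() -> bool { false }

// Mirrors the described behavior: prefer CUDA, then Metal,
// then fall back to the CPU.
fn best_device(ordinal: usize) -> Device {
    if cuda_available() {
        Device::Cuda(ordinal)
    } else if metal_available() {
        Device::Metal(ordinal)
    } else {
        Device::Cpu
    }
}
```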
Reasoning:
1) We use lots of elementwise operations: [masked_fill in every layer](https://github.com/huggingface/candle/blob/2be9bd211e34333b605695242896903231ab26da/candle-transformers/src/models/llama.rs#L328-L341), [elementwise addition and division](https://github.com/huggingface/candle/blob/main/candle-transformers/src/models/mistral.rs#L275-L283) in our attention implementations.
2) GEMM APIs like cuBLAS's [gemm](https://docs.nvidia.com/cuda/cublas/#cublas-level-3-function-reference) provide alpha and beta...
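To make the alpha/beta point concrete: GEMM APIs compute `C = alpha * A @ B + beta * C`, so a scale and an elementwise accumulate can be fused into the matmul instead of running as separate elementwise kernels. A naive CPU illustration of that form (row-major, not the cuBLAS signature):

```rust
// Naive row-major GEMM in the C = alpha * A @ B + beta * C form.
// A is m x k, B is k x n, C is m x n. With beta != 0 the elementwise
// add rides along with the matmul; with alpha != 1 so does the scale.
fn gemm(
    m: usize, n: usize, k: usize,
    alpha: f32, a: &[f32], b: &[f32],
    beta: f32, c: &mut [f32],
) {
    for i in 0..m {
        for j in 0..n {
            let mut acc = 0.0f32;
            for p in 0..k {
                acc += a[i * k + p] * b[p * n + j];
            }
            c[i * n + j] = alpha * acc + beta * c[i * n + j];
        }
    }
}
```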
These are a few utility functions which are often useful. Both implementations do not require operations on the CPU. I plan on following up this PR with one for bitwise...
This PR improves compatibility with older GPUs whose compute capability (CC) is less than 610. Refs #2348.