CUDA backend: unary ops
This PR is split from https://github.com/ml-explore/mlx/pull/1983.
This PR implements unary ops for the CUDA backend.
- A `cucomplex_math.cuh` file is added to implement arithmetic operators for `cuComplex` (a sketch of the idea follows this list).
- In `fp16_math.cuh`, arithmetic functions are added to support both floats and halfs. For functions where CUDA does not provide a builtin half version, halfs are converted to floats (see the sketch below).
- A custom iterator class `general_iterator` is added to allow passing strided inputs to `thrust::transform`; thrust calls this a "fancy iterator" (sketched below).
- A `const_param` utility is added to convert a `std::vector` to a fixed-size `cuda::std::array`, which is used to pass shapes and strides, since CUDA does not allow passing dynamically-sized arguments to kernels. As a result a hard-coded `MAX_NDIM` is introduced, which I set to 8; PyTorch sets it to 25 (sketched below).
- For deciding which types are supported by each op, I use a constexpr `supports_unary_op` function, which prevents unsupported combinations of types and ops from being compiled (sketched below).
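
For the first point, here is a minimal sketch of the kind of operator overloads such a header provides, built only from the standard `<cuComplex.h>` helpers; the actual file in the PR covers more operators and both precisions.

```cuda
#include <cuComplex.h>

// cuComplex has no built-in operators, so arithmetic must be spelled out.
// These overloads wrap the cuCaddf/cuCmulf helpers from <cuComplex.h>.
__host__ __device__ inline cuComplex operator+(cuComplex a, cuComplex b) {
  return cuCaddf(a, b);
}

__host__ __device__ inline cuComplex operator*(cuComplex a, cuComplex b) {
  return cuCmulf(a, b);
}

__host__ __device__ inline cuComplex operator-(cuComplex a) {
  return make_cuComplex(-cuCrealf(a), -cuCimagf(a));
}
```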
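For `fp16_math.cuh`, the float round-trip pattern might look like the following sketch (the exact set of functions in the PR differs; this assumes an architecture with half intrinsics, i.e. sm_53 or newer):

```cuda
#include <cuda_fp16.h>

// Where a half builtin exists (e.g. hexp), it can be used directly.
__device__ inline __half exp(__half x) {
  return hexp(x);
}

// CUDA provides no half-precision erf, so fall back to the float version
// by converting into and out of __half.
__device__ inline __half erf(__half x) {
  return __float2half(erff(__half2float(x)));
}
```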
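The `general_iterator` in the PR handles arbitrary N-D shapes and strides; as a simpler stand-in, the same "fancy iterator" idea can be shown for a single fixed stride by composing thrust's stock iterators (`stride_functor` and `make_strided` are hypothetical names for illustration, not the PR's API):

```cuda
#include <thrust/execution_policy.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/permutation_iterator.h>
#include <thrust/iterator/transform_iterator.h>
#include <thrust/transform.h>

// Maps a linear index i to the strided offset i * stride.
struct stride_functor {
  long stride;
  __host__ __device__ long operator()(long i) const { return i * stride; }
};

// Builds an iterator that visits it[0], it[stride], it[2 * stride], ...
template <typename Iterator>
auto make_strided(Iterator it, long stride) {
  auto offsets = thrust::make_transform_iterator(
      thrust::make_counting_iterator<long>(0), stride_functor{stride});
  return thrust::make_permutation_iterator(it, offsets);
}

// Usage: apply `op` to every second element of `in`, writing n outputs:
//   auto first = make_strided(in, 2);
//   thrust::transform(thrust::device, first, first + n, out, op);
```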
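For the `const_param` utility, a minimal sketch assuming `int64_t` elements (the real helper is presumably more generic):

```cuda
#include <cuda/std/array>
#include <cstdint>
#include <vector>

constexpr int MAX_NDIM = 8;  // hard-coded bound chosen in this PR

// Copies a runtime-sized vector into a fixed-size cuda::std::array so
// shapes/strides can be passed to a kernel by value.
inline cuda::std::array<int64_t, MAX_NDIM> const_param(
    const std::vector<int64_t>& vec) {
  cuda::std::array<int64_t, MAX_NDIM> out = {};
  for (size_t i = 0; i < vec.size() && i < MAX_NDIM; ++i) {
    out[i] = vec[i];
  }
  return out;
}
```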
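Finally, for the `supports_unary_op` gate, a sketch with made-up op tags (`Abs` and `LogicalNot` are placeholders, not necessarily the PR's names):

```cuda
#include <type_traits>

struct Abs {};         // placeholder op tags for illustration
struct LogicalNot {};

// Compile-time check of which (op, input, output) combinations are valid.
template <typename Op, typename In, typename Out>
constexpr bool supports_unary_op() {
  if (std::is_same_v<Op, LogicalNot>) {
    return std::is_same_v<In, bool> && std::is_same_v<Out, bool>;
  }
  if (std::is_same_v<Op, Abs>) {
    return std::is_arithmetic_v<In> && std::is_same_v<In, Out>;
  }
  return false;
}

// Dispatch sites can then use
//   if constexpr (supports_unary_op<Op, In, Out>()) { /* launch kernel */ }
// so kernels for unsupported combinations are never instantiated.
```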