
CUDA backend: unary ops

Open zcbenz opened this pull request 8 months ago • 0 comments

This PR is split from https://github.com/ml-explore/mlx/pull/1983.

This PR implements unary ops for the CUDA backend.

  • A cucomplex_math.cuh file is added that implements arithmetic operators for cuComplex (a sketch of the idea follows after this list).
  • fp16_math.cuh adds arithmetic functions supporting both floats and halfs. For functions where CUDA does not provide a builtin, halfs are converted to floats and back (see the fallback sketch below).
  • A custom iterator class general_iterator is added to allow passing strided input to thrust::transform (Thrust calls these "fancy iterators"); a simplified illustration follows the list.
  • A const_param utility is added to convert std::vector to a fixed-size cuda::std::array, which is used to pass shapes and strides, since CUDA does not allow passing dynamically sized arguments to kernels. As a result a hard-coded MAX_NDIM is introduced; I set it to 8 (PyTorch uses 25). A hypothetical sketch appears below.
  • To decide which types each op supports, I use a constexpr supports_unary_op function, which prevents compiling unsupported combinations of types and ops (illustrated in the last sketch below).
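
For the cuComplex operators, a minimal sketch of the kind of overloads such a header can provide, delegating to the cuCaddf/cuCmulf helpers that cuComplex.h already ships; the actual file covers more operators and the double-precision variants.

```cpp
#include <cuComplex.h>

// Arithmetic operators for cuComplex (single-precision complex), forwarding
// to the builtin cuC* helpers so they work in both host and device code.
__host__ __device__ inline cuComplex operator+(cuComplex a, cuComplex b) {
  return cuCaddf(a, b);
}
__host__ __device__ inline cuComplex operator-(cuComplex a, cuComplex b) {
  return cuCsubf(a, b);
}
__host__ __device__ inline cuComplex operator*(cuComplex a, cuComplex b) {
  return cuCmulf(a, b);
}
__host__ __device__ inline cuComplex operator/(cuComplex a, cuComplex b) {
  return cuCdivf(a, b);
}
```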
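
The float fallback for halfs could look roughly like this; the function names and the exact set of ops in fp16_math.cuh are assumptions, log1p is just one example of a function CUDA has no __half builtin for.

```cpp
#include <cuda_fp16.h>

// exp has a builtin __half implementation (hexp, sm_53+), so forward to it.
__device__ inline __half exp(__half x) { return hexp(x); }

// log1p has no __half builtin, so compute in float and convert back.
__device__ inline __half log1p(__half x) {
  return __float2half(log1pf(__half2float(x)));
}
```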
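
The PR's general_iterator handles arbitrary shapes and strides; as a simpler illustration of the same "fancy iterator" idea, Thrust's built-in permutation and transform iterators can already feed a strided input into thrust::transform. This is not the PR's implementation, only a sketch of the concept.

```cpp
#include <thrust/device_vector.h>
#include <thrust/functional.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/permutation_iterator.h>
#include <thrust/iterator/transform_iterator.h>
#include <thrust/transform.h>

// Maps a linear index i to a strided offset i * stride.
struct StrideIndex {
  int stride;
  __host__ __device__ int operator()(int i) const { return i * stride; }
};

int main() {
  thrust::device_vector<float> in(64, 1.0f);
  thrust::device_vector<float> out(16);
  // Iterator over indices 0, 4, 8, ... used to view every 4th input element.
  auto idx = thrust::make_transform_iterator(thrust::make_counting_iterator(0),
                                             StrideIndex{4});
  auto strided = thrust::make_permutation_iterator(in.begin(), idx);
  // Apply a unary op (negate) over the strided view.
  thrust::transform(strided, strided + 16, out.begin(), thrust::negate<float>());
  return 0;
}
```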
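
A hypothetical sketch of the const_param idea: the name, element handling, and kernel signature below are assumptions, but they show why a compile-time MAX_NDIM bound is needed when shapes and strides are passed by value.

```cpp
#include <cuda/std/array>
#include <cassert>
#include <cstdint>
#include <vector>

constexpr int MAX_NDIM = 8;  // hard-coded bound; PyTorch uses 25

// Copy a runtime-sized vector into a fixed-size array that can be passed
// to a kernel by value (kernel parameters must have a static size).
template <typename T>
cuda::std::array<T, MAX_NDIM> const_param(const std::vector<T>& vec) {
  assert(vec.size() <= MAX_NDIM);
  cuda::std::array<T, MAX_NDIM> out = {};
  for (size_t i = 0; i < vec.size(); ++i) {
    out[i] = vec[i];
  }
  return out;
}

// Example of a kernel that takes strides by value as a fixed-size array.
__global__ void copy_strided(const float* in, float* out,
                             cuda::std::array<int64_t, MAX_NDIM> strides,
                             int ndim);
```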
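
The compile-time gating could be done along these lines; the op tags and the specific type rules here are illustrative, not the PR's actual tables.

```cpp
#include <cuda_fp16.h>
#include <type_traits>

// Illustrative op tags.
struct Exp {};
struct LogicalNot {};

// Compile-time check of whether `Op` is defined for input type `T`.
template <typename Op, typename T>
constexpr bool supports_unary_op() {
  if (std::is_same_v<Op, LogicalNot>) {
    return std::is_same_v<T, bool>;
  }
  if (std::is_same_v<Op, Exp>) {
    return std::is_same_v<T, float> || std::is_same_v<T, double> ||
           std::is_same_v<T, __half>;
  }
  return false;
}

// Dispatch guarded with `if constexpr`, so kernels for unsupported
// type/op combinations are never instantiated.
template <typename Op, typename T>
void unary_op(const T* in, T* out, size_t n) {
  if constexpr (supports_unary_op<Op, T>()) {
    // ... launch the CUDA kernel for Op on T ...
  } else {
    // Unsupported combination: report an error instead of compiling a kernel.
  }
}
```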

zcbenz, May 07 '25