
cuda::std::complex division is slower than expected

Open NickKarpowicz opened this issue 3 years ago • 9 comments

@jrhemstad edit: This was originally from the Thrust repo about thrust::complex. I pointed @NickKarpowicz at cuda::std::complex as it will eventually replace thrust::complex. @NickKarpowicz reported that cuda::std::complex was even slower than thrust::complex :upside_down_face:

Hi, I noticed that the division of a (real) double-precision number by a thrust::complex isn't as fast as it could be. Maybe this is a case of the compiler not optimizing something it should, but there is an easy workaround. I posted about this on the NVIDIA forum, and they suggested I create an issue here.

It seems to do the operation "literally": first turn the double into a complex number, then divide complex/complex. This can be done more efficiently, in a way that saves a division. I pasted a simple program below that I made to isolate and test this.
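
For reference, the identity the shortcut relies on is

a / (c + d·i) = a · (c − d·i) / (c² + d²)

so with t = a / (c² + d²) the quotient is just (c·t, −d·t): one real division and a handful of multiplies. Promoting a to (a + 0·i) and running the general complex/complex division should cost at least one more division, plus whatever extra work a conforming implementation does to guard against overflow and nan/inf inputs.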

Long story short: if I write a function by hand that does the operation without the additional divide, kernel calls just doing this division ~64 million times average 6.22 ms according to the profiler. Doing it with Thrust's division operator, they take 10.97 ms on average, on a 2080 Super. On a 3060 the numbers are similar: 13.23 ms vs. 23.46 ms. This is compiling on Windows in Visual Studio 2022, CUDA 11.7.

So one can simply overload the / operator for a bit of a speedup, as:

__device__ thrust::complex<double> operator/(double a, thrust::complex<double> b) {
	double divByDenominator = a / (b.real() * b.real() + b.imag() * b.imag());
	return thrust::complex<double>(b.real() * divByDenominator, -b.imag() * divByDenominator);
}

I tried this with floats instead of doubles, and it doesn't seem to matter there. If this turns out to be the case in general and not just something on my end, maybe it's worth putting something like that explicitly in the library, or maybe the compiler is just currently missing something it shouldn't? I'm not sure whether this is a compiler issue or a Thrust issue, but it's most easily worked around when interacting with Thrust, so I put it here...

The code I used for testing is here:

#include "cuda_runtime.h"
#include "device_launch_parameters.h"
#include <stdio.h>
#include <thrust/complex.h>

#define TESTSIZE 64*1048576
#define THREADS_PER_BLOCK 128
#define NLAUNCHES 5

//divide the arrays using thrust standard operator
__global__ void divideWithThrust(double* x, thrust::complex<double>* y, thrust::complex<double>* z) {
	unsigned int i = threadIdx.x + blockIdx.x * blockDim.x;
	z[i] = x[i] / y[i];
}

//divide the arrays by hand
__global__ void divideDZ(double* x, thrust::complex<double>* y, thrust::complex<double>* z) {
	unsigned int i = threadIdx.x + blockIdx.x * blockDim.x;
	double divByDenominator = x[i] / (y[i].real() * y[i].real() + y[i].imag() * y[i].imag());
	z[i] = thrust::complex<double>(y[i].real() * divByDenominator, -y[i].imag() * divByDenominator);
}

//divide the arrays by explicitly turning the double into a complex double
__global__ void divideDZupcast(double* x, thrust::complex<double>* y, thrust::complex<double>* z) {
	unsigned int i = threadIdx.x + blockIdx.x * blockDim.x;
	z[i] = thrust::complex<double>(x[i], 0) / y[i];
}

//float math for comparison
__global__ void divideWithThrustFloat(float* x, thrust::complex<float>* y, thrust::complex<float>* z) {
	unsigned int i = threadIdx.x + blockIdx.x * blockDim.x;
	z[i] = x[i] / y[i];
}

//float by hand for comparison
__global__ void divideFC(float* x, thrust::complex<float>* y, thrust::complex<float>* z) {
	unsigned int i = threadIdx.x + blockIdx.x * blockDim.x;
	float divByDenominator = x[i] / (y[i].real() * y[i].real() + y[i].imag() * y[i].imag());
	z[i] = thrust::complex<float>(y[i].real() * divByDenominator, -y[i].imag() * divByDenominator);
}

//fill arrays
__global__ void initArrays(double* x, thrust::complex<double>* y) {
	unsigned int i = threadIdx.x + blockIdx.x * blockDim.x;
	x[i] = sin(0.1 * i);
	y[i] = thrust::complex<double>(cos(0.2 * i), sin(0.5 * i));
}
__global__ void initArraysFloat(float* x, thrust::complex<float>* y) {
	unsigned int i = threadIdx.x + blockIdx.x * blockDim.x;
	x[i] = sin(0.1 * i);
	y[i] = thrust::complex<float>(cos(0.2 * i), sin(0.5 * i));
}


int main()
{
	//first check with doubles
	double *x;
	thrust::complex<double> *y, *z;
	cudaMalloc(&x, TESTSIZE * sizeof(double));
	cudaMalloc(&y, TESTSIZE * sizeof(thrust::complex<double>));
	cudaMalloc(&z, TESTSIZE * sizeof(thrust::complex<double>));

	//divide by hand
	initArrays<<<TESTSIZE / THREADS_PER_BLOCK, THREADS_PER_BLOCK>>>(x, y);
	for (int i = 0; i < NLAUNCHES; i++) {
		divideDZ<<<TESTSIZE / THREADS_PER_BLOCK, THREADS_PER_BLOCK>>>(x, y, z);
	}

	//divide with thrust
	initArrays<<<TESTSIZE / THREADS_PER_BLOCK, THREADS_PER_BLOCK>>>(x, y);
	for (int i = 0; i < NLAUNCHES; i++) {
		divideWithThrust<<<TESTSIZE / THREADS_PER_BLOCK, THREADS_PER_BLOCK>>>(x, y, z);
	}

	//divide by turning double into complex explicitly
	initArrays<<<TESTSIZE / THREADS_PER_BLOCK, THREADS_PER_BLOCK>>>(x, y);
	for (int i = 0; i < NLAUNCHES; i++) {
		divideDZupcast<<<TESTSIZE / THREADS_PER_BLOCK, THREADS_PER_BLOCK>>>(x, y, z);
	}

	cudaFree(x);
	cudaFree(y);
	cudaFree(z);


	//compare float division
	float *xf;
	thrust::complex<float> *yf, * zf;
	cudaMalloc(&xf, TESTSIZE * sizeof(float));
	cudaMalloc(&yf, TESTSIZE * sizeof(thrust::complex<float>));
	cudaMalloc(&zf, TESTSIZE * sizeof(thrust::complex<float>));

	initArraysFloat<<<TESTSIZE / THREADS_PER_BLOCK, THREADS_PER_BLOCK>>>(xf, yf);
	for (int i = 0; i < NLAUNCHES; i++) {
		divideFC<<<TESTSIZE / THREADS_PER_BLOCK, THREADS_PER_BLOCK>>>(xf, yf, zf);
	}

	initArraysFloat<<<TESTSIZE / THREADS_PER_BLOCK, THREADS_PER_BLOCK>>>(xf, yf);
	for (int i = 0; i < NLAUNCHES; i++) {
		divideWithThrustFloat<<<TESTSIZE / THREADS_PER_BLOCK, THREADS_PER_BLOCK>>>(xf, yf, zf);
	}

	cudaFree(xf);
	cudaFree(yf);
	cudaFree(zf);

	return 0;
}
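
For anyone reproducing this without a profiler: the timings above came from the profiler, but a minimal CUDA-event harness could be dropped into main() after the allocations instead. This is a sketch, not part of the original measurement; it reuses x, y, z and the launch configuration from the listing above.

	//time NLAUNCHES back-to-back launches of one kernel with CUDA events
	cudaEvent_t start, stop;
	cudaEventCreate(&start);
	cudaEventCreate(&stop);
	cudaEventRecord(start);
	for (int i = 0; i < NLAUNCHES; i++) {
		divideDZ<<<TESTSIZE / THREADS_PER_BLOCK, THREADS_PER_BLOCK>>>(x, y, z);
	}
	cudaEventRecord(stop);
	cudaEventSynchronize(stop);
	float milliseconds = 0.0f;
	cudaEventElapsedTime(&milliseconds, start, stop);
	printf("divideDZ average: %f ms\n", milliseconds / NLAUNCHES);
	cudaEventDestroy(start);
	cudaEventDestroy(stop);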

NickKarpowicz avatar Sep 14 '22 09:09 NickKarpowicz

Hey @NickKarpowicz! Thanks for taking the time to write up this issue.

Would you mind trying with cuda::std::complex from libcu++?

Example: https://godbolt.org/z/jnr815TM9

Our goal is to replace the thrust::complex implementation with cuda::std::complex; we just haven't gotten around to it yet :)

I just filed https://github.com/NVIDIA/cccl/issues/819 since it doesn't look like we had an issue tracking this effort yet.

jrhemstad avatar Sep 14 '22 19:09 jrhemstad

Hi @jrhemstad, thanks for getting back to me on it. I just tried it using cuda::std::complex in the code below, and somehow it's even slower:

  • hand-written: 6.18 ms
  • thrust::complex: 10.97 ms
  • cuda::std::complex: 21.27 ms

If I use float instead of double, they're all the same, 2.54 ms. The difference only seems to be how doubles are handled. Unfortunately the godbolt output for a kernel launch is a bit hard for me to parse 😅

#include "cuda_runtime.h"
#include "device_launch_parameters.h"
#include <thrust/complex.h>
#include <cuda/std/complex>
#define TESTSIZE 64*1048576
#define THREADS_PER_BLOCK 128
#define NLAUNCHES 5

//divide the arrays using cuda::std::complex operator
__global__ void divideWithCudaStd(double* x, cuda::std::complex<double>* y, cuda::std::complex<double>* z) {
	unsigned int i = threadIdx.x + blockIdx.x * blockDim.x;
	z[i] = x[i] / y[i];
}

//divide the arrays using thrust::complex operator
__global__ void divideWithThrust(double* x, thrust::complex<double>* y, thrust::complex<double>* z) {
	unsigned int i = threadIdx.x + blockIdx.x * blockDim.x;
	z[i] = x[i] / y[i];
}

//divide the arrays by hand
__global__ void divideDZ(double* x, cuda::std::complex<double>* y, cuda::std::complex<double>* z) {
	unsigned int i = threadIdx.x + blockIdx.x * blockDim.x;
	double divByDenominator = x[i] / (y[i].real() * y[i].real() + y[i].imag() * y[i].imag());
	z[i] = cuda::std::complex<double>(y[i].real() * divByDenominator, -y[i].imag() * divByDenominator);
}

//fill arrays
__global__ void initArrays(double* x, cuda::std::complex<double>* y) {
	unsigned int i = threadIdx.x + blockIdx.x * blockDim.x;
	x[i] = sin(0.1 * i);
	y[i] = cuda::std::complex<double>(cos(0.2 * i), sin(0.5 * i));
}

int main(){
	double *x;
	cuda::std::complex<double> *y, *z;
	cudaMalloc(&x, TESTSIZE * sizeof(double));
	cudaMalloc(&y, TESTSIZE * sizeof(cuda::std::complex<double>));
	cudaMalloc(&z, TESTSIZE * sizeof(cuda::std::complex<double>));

	//divide by hand
	initArrays<<<TESTSIZE / THREADS_PER_BLOCK, THREADS_PER_BLOCK>>>(x, y);
	for (int i = 0; i < NLAUNCHES; i++) {
		divideDZ<<<TESTSIZE / THREADS_PER_BLOCK, THREADS_PER_BLOCK>>>(x, y, z);
	}
 
	//divide with thrust
	initArrays<<<TESTSIZE / THREADS_PER_BLOCK, THREADS_PER_BLOCK>>>(x, y);
	for (int i = 0; i < NLAUNCHES; i++) {
		divideWithThrust<<<TESTSIZE / THREADS_PER_BLOCK, THREADS_PER_BLOCK>>>(x, (thrust::complex<double>*)y, (thrust::complex<double>*)z);
	}

	//divide with cuda::std
	initArrays<<<TESTSIZE / THREADS_PER_BLOCK, THREADS_PER_BLOCK>>>(x, y);
	for (int i = 0; i < NLAUNCHES; i++) {
		divideWithCudaStd<<<TESTSIZE / THREADS_PER_BLOCK, THREADS_PER_BLOCK>>>(x, y, z);
	}

	cudaFree(x);
	cudaFree(y);
	cudaFree(z);

	return 0;
}

NickKarpowicz avatar Sep 14 '22 20:09 NickKarpowicz

If I use float instead of double

~Out of curiosity, are you using libcu++ from the CTK? Or otherwise which version of libcu++ are you using?~

~There was a recent-ish issue where there was an explicit cast to double in the isnan code that ended up causing the whole code path to be done in fp64 instead of fp32, which is a lot slower.~

~See https://github.com/NVIDIA/libcudacxx/commit/ef7458cae4a28bc090675bf5a7667fbfb913b426#diff-d54e254855b5ea60264d8f60fc95514daaee3440f9245c3e52a38491e8d2247cL599~

~I don't think this ended up in CUDA 11.7, so you'd need to use libcu++ from GH.~

Edit: Nevermind. I misunderstood the direction you were going. You were testing double and it's slow, but float is fast.

jrhemstad avatar Sep 14 '22 21:09 jrhemstad

This isn't the first issue raised about the performance of cuda::std::complex :smile: See also https://github.com/NVIDIA/libcudacxx/issues/306

The implementation is just a straight port of libc++'s implementation and has not been studied or tuned extensively, so it is not likely to be speed-of-light.

The Thrust/libcu++ team are not experts on all the nuances of complex arithmetic, so we've already pinged our internal Math libraries team (who are experts on this stuff) to help us out in making sure our complex<T> implementation is speed-of-light.

We are unlikely to do anything with the thrust::complex implementation and will be focusing our efforts on the cuda::std::complex implementation.

If you don't mind, I'd like to migrate this issue to the libcu++ repo.

jrhemstad avatar Sep 14 '22 21:09 jrhemstad

Sure, it would make sense to migrate it if that's where the more active development is at the moment! Thanks for working on this

NickKarpowicz avatar Sep 14 '22 21:09 NickKarpowicz

This is something we should also investigate during NVIDIA/thrust#338

miscco avatar Feb 23 '23 09:02 miscco

This appears to still be the case in CUDA 12.8 update 1.

nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2025 NVIDIA Corporation
Built on Fri_Feb_21_20:42:46_Pacific_Standard_Time_2025
Cuda compilation tools, release 12.8, V12.8.93
Build cuda_12.8.r12.8/compiler.35583870_0

On my system with a 3050 Ti Laptop GPU, running the latest code above from @NickKarpowicz gives these kernel execution times in Nsight Systems:

  • divideWithCudaStd: 48 ms
  • divideWithThrust: 25 ms
  • divideDZ: 16 ms

It appears that cuda::std::complex is still much slower than thrust::complex.

Is there any plan to improve the performance of cuda::std::complex even more?

cmey avatar Apr 14 '25 15:04 cmey

It appears that cuda::std::complex is still much slower than thrust::complex.

Is there any plan to improve the performance of cuda::std::complex even more?

The issue is on the conformance side of things. In the meantime, we do provide LIBCUDACXX_ENABLE_SIMPLIFIED_COMPLEX_OPERATIONS if you are OK with a fast but technically not 100% conforming implementation. (It all relates to the handling of nan/inf inputs, so it's fine for common use cases.)
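
For example (a minimal sketch; the macro just needs to be visible before the libcu++ header is included, or it can be passed via -D on the nvcc command line):

//opt in to the simplified (non-conforming for nan/inf) complex operations
#define LIBCUDACXX_ENABLE_SIMPLIFIED_COMPLEX_OPERATIONS
#include <cuda/std/complex>

//or equivalently: nvcc -DLIBCUDACXX_ENABLE_SIMPLIFIED_COMPLEX_OPERATIONS main.cu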

miscco avatar Apr 14 '25 15:04 miscco

Thank you for the suggestion to use that define. An implementation that isn't 100% conforming with respect to nan and inf would indeed not be a problem for us.

I've compiled the same code above, this time with nvcc -DLIBCUDACXX_ENABLE_SIMPLIFIED_COMPLEX_OPERATIONS main.cu, but the performance did not change much, if at all (the difference is probably just measurement noise):

  • divideWithCudaStd: 46 ms
  • divideWithThrust: 25 ms
  • divideDZ: 16 ms

Did I use LIBCUDACXX_ENABLE_SIMPLIFIED_COMPLEX_OPERATIONS incorrectly?

I've also tried adding --use_fast_math to nvcc, with no measurable change to the performance.

Is there any other existing option or future plan to improve the performance of cuda::std::complex even more?

cmey avatar Apr 17 '25 08:04 cmey