pykan
How to support complex data types?
Hi. It's a great project. I want to use KAN in the radar signal processing domain. As you know, radar signals are complex-valued. When I create a dataset with complex data and try to train KAN, it reports the following error:
Traceback (most recent call last):
File "/home/yantao/software/miniforge3/envs/kan/lib/python3.9/runpy.py", line 197, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/home/yantao/software/miniforge3/envs/kan/lib/python3.9/runpy.py", line 87, in _run_code
exec(code, run_globals)
File "/home/yantao/awork/pykan/apps/fmcw/fmcw_app.py", line 39, in <module>
main(args=args)
File "/home/yantao/awork/pykan/apps/fmcw/fmcw_app.py", line 26, in main
FmcwApp.startup(args=args)
File "/home/yantao/awork/pykan/apps/fmcw/fmcw_app.py", line 23, in startup
model.train(dataset, opt="LBFGS", steps=20, lamb=0.001, lamb_entropy=2.)
File "/home/yantao/awork/pykan/kan/KAN.py", line 896, in train
self.update_grid_from_samples(dataset['train_input'][train_id].to(device))
File "/home/yantao/awork/pykan/kan/KAN.py", line 241, in update_grid_from_samples
self.forward(x)
File "/home/yantao/awork/pykan/kan/KAN.py", line 309, in forward
x_numerical, preacts, postacts_numerical, postspline = self.act_fun[l](x)
File "/home/yantao/software/miniforge3/envs/kan/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/yantao/software/miniforge3/envs/kan/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/home/yantao/awork/pykan/kan/KANLayer.py", line 173, in forward
y = coef2curve(x_eval=x, grid=self.grid[self.weight_sharing], coef=self.coef[self.weight_sharing], k=self.k, device=self.device) # shape (size, batch)
File "/home/yantao/awork/pykan/kan/spline.py", line 100, in coef2curve
y_eval = torch.einsum('ij,ijk->ik', coef, B_batch(x_eval, grid, k, device=device))
File "/home/yantao/awork/pykan/kan/spline.py", line 59, in B_batch
B_km1 = B_batch(x[:, 0], grid=grid[:, :, 0], k=k - 1, extend=False, device=device)
File "/home/yantao/awork/pykan/kan/spline.py", line 59, in B_batch
B_km1 = B_batch(x[:, 0], grid=grid[:, :, 0], k=k - 1, extend=False, device=device)
File "/home/yantao/awork/pykan/kan/spline.py", line 59, in B_batch
B_km1 = B_batch(x[:, 0], grid=grid[:, :, 0], k=k - 1, extend=False, device=device)
File "/home/yantao/awork/pykan/kan/spline.py", line 57, in B_batch
value = (x >= grid[:, :-1]) * (x < grid[:, 1:])
RuntimeError: "ge_cpu" not implemented for 'ComplexDouble'
How can I solve this problem? I prepared the development environment with pip install -r requirements, and I noticed that the installed PyTorch is the CPU version. Would switching to the GPU version of PyTorch solve this problem?
Hi, unfortunately pykan doesn't support complex numbers. However, you could try treating the real and imaginary parts separately, so that you end up feeding KAN vectors of twice the original length.
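A minimal sketch of that split, assuming the dataset is a dict of PyTorch tensors keyed like dataset['train_input'] in the traceback above. The helper names complex_to_real, real_to_complex and split_complex_dataset are hypothetical and are not part of pykan's API:

```python
import torch


def complex_to_real(x: torch.Tensor) -> torch.Tensor:
    """Map a complex tensor of shape (batch, n) to a real tensor of shape (batch, 2n).

    Columns 0..n-1 hold the real parts and columns n..2n-1 hold the imaginary parts.
    """
    if not torch.is_complex(x):
        raise TypeError("expected a complex tensor")
    return torch.cat([x.real, x.imag], dim=-1)


def real_to_complex(y: torch.Tensor) -> torch.Tensor:
    """Inverse of complex_to_real: turn (batch, 2n) real features back into (batch, n) complex values."""
    n = y.shape[-1] // 2
    return torch.complex(y[..., :n], y[..., n:])


def split_complex_dataset(dataset: dict) -> dict:
    """Return a new dict in which every complex tensor is split into real/imaginary columns.

    Entries that are already real (e.g. real-valued labels) are passed through unchanged.
    """
    return {
        key: complex_to_real(value) if torch.is_complex(value) else value
        for key, value in dataset.items()
    }


# Usage: 3 complex features per sample become 6 real features.
x = torch.randn(4, 3, dtype=torch.complex128)
dataset = {"train_input": x, "train_label": torch.zeros(4, 1, dtype=torch.float64)}
real_dataset = split_complex_dataset(dataset)
print(real_dataset["train_input"].shape)  # torch.Size([4, 6])
print(real_dataset["train_input"].dtype)  # torch.float64
```

With this layout the KAN's input width becomes 2n (and its output width 2m if the labels are complex too), and real_to_complex can map predictions back to complex values. Keeping the real and imaginary parts in two separate blocks, rather than interleaving them as torch.view_as_real(x).flatten(1) would, is an arbitrary choice. Either layout works as long as training and inference use the same one.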
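pykan doesn't support GPUs out of the box, so I'd suggest debugging on CPUs first, with small-scale datasets. Switching to the GPU version of PyTorch wouldn't fix this error anyway: the failing line in B_batch compares the inputs against the spline grid points with >= and <, and complex numbers have no ordering.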
Is there any particular reason in your domain that you don't want to separate the real and imaginary parts? I know that in some applications people want to constrain the model to holomorphic functions, where complex-valued neural networks are favored. Otherwise, I don't see a strong reason not to simply treat the real and imaginary parts separately and feed them into real-valued neural networks (which are much better optimized than complex-valued ones).
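To make that trade-off concrete, here is the standard identification (general complex analysis, not anything specific to pykan). Writing $z = x + iy$, any complex-valued map is a real-valued map on twice as many coordinates:

$$
f(z) = u(x, y) + i\,v(x, y), \qquad f:\mathbb{C}^n \to \mathbb{C}^m \;\longleftrightarrow\; (u, v):\mathbb{R}^{2n} \to \mathbb{R}^{2m}.
$$

Holomorphy is an extra constraint on $u$ and $v$. For $n = m = 1$ it is the pair of Cauchy–Riemann equations

$$
\frac{\partial u}{\partial x} = \frac{\partial v}{\partial y}, \qquad \frac{\partial u}{\partial y} = -\frac{\partial v}{\partial x},
$$

which a real-valued network fed with $(\operatorname{Re} z, \operatorname{Im} z)$ does not enforce by itself. If the application does not need that constraint, the real-valued formulation loses nothing.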
Something like Earth's gravity field requires a summation of N spherical harmonics. However, I currently can't figure out (1) how to add multiple spherical-harmonic basis functions as symbolic functions, and (2) how to handle complex numbers.
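On point (2), the gravity-field case may not need complex numbers at all. Geodesy usually writes the potential with real spherical harmonics, where $\cos m\lambda$ and $\sin m\lambda$ are the real and imaginary parts of $e^{im\lambda}$. One common form of the expansion truncated at degree $N$ is:

$$
V(r, \varphi, \lambda) = \frac{GM}{r} \sum_{n=0}^{N} \left(\frac{a}{r}\right)^{n} \sum_{m=0}^{n} \bar{P}_{nm}(\sin\varphi)\left[\bar{C}_{nm}\cos m\lambda + \bar{S}_{nm}\sin m\lambda\right]
$$

Here $r$ is the radius, $\varphi$ the geocentric latitude, $\lambda$ the longitude, $a$ a reference radius, $\bar{P}_{nm}$ the fully normalized associated Legendre functions, and $\bar{C}_{nm}, \bar{S}_{nm}$ the Stokes coefficients. Every term is real, so a target built from this series is real-valued and could in principle be fit by a real-valued KAN on inputs $(r, \varphi, \lambda)$. Whether pykan's symbolic step can recover the individual terms is a separate question that this thread does not settle.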
@plyu3 Thanks for the feedback. It's currently not supported, but it sounds like something worth including in the future.