pykan
Trying a coef2curve example
Hi. Thanks for making the KAN code available. I've been trying to understand how it all works, so I tried running `coef2curve` to get an idea of how the splines are calculated. However, I couldn't get it to run.
```python
import math

import numpy as np
import torch

# Assumes coef2curve (and the B_batch helper it calls) were defined in an earlier notebook cell.

def exp_sin(x):
    return np.exp(np.sin(math.pi * x))

train_xs = np.arange(-5, 5, 0.25)
train_ys = exp_sin(train_xs)
t_xs = torch.reshape(torch.from_numpy(train_xs), (1, len(train_xs)))

num_spline = 1
num_sample = len(train_xs)
num_grid_interval = 10
k = 3
grids = torch.einsum('i,j->ij', torch.ones(num_spline,), torch.linspace(-1, 1, steps=num_grid_interval + 1))
coef = torch.normal(0, 1, size=(num_spline, num_grid_interval + k))

print(f"Shape of t_xs: {t_xs.shape}")
print(f"Shape of grids: {grids.shape}")
print(f"Shape of coef: {coef.shape}")

t_ys = coef2curve(t_xs, grids, coef, k=k)
```
This barfs with an error related to shapes:
```
Shape of t_xs: torch.Size([1, 40])
Shape of grids: torch.Size([1, 11])
Shape of coef: torch.Size([1, 13])
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-45-2b0233d6ec51> in <cell line: 16>()
     14
     15
---> 16 t_ys = coef2curve(t_xs, grids, coef, k=k)

<ipython-input-27-aae4645f44ba> in coef2curve(x_eval, grid, coef, k, device)
     90     if coef.dtype != x_eval.dtype:
     91         coef = coef.to(x_eval.dtype)
---> 92     y_eval = torch.einsum('ij,ijk->ik', coef, B_batch(x_eval, grid, k, device=device))
     93     return y_eval
     94

/usr/local/lib/python3.10/dist-packages/torch/functional.py in einsum(*args)
    383         # the path for contracting 0 or 1 time(s) is already optimized
    384         # or the user has disabled using opt_einsum
--> 385         return _VF.einsum(equation, operands)  # type: ignore[attr-defined]
    386
    387     path = None

RuntimeError: einsum(): subscript j has size 7 for operand 1 which does not broadcast with previously seen size 13
```
I don't understand the code well enough to grasp what's going wrong here. It seems like there's some mismatch between the dimensions of the grid and the coefficients? I think I've stuck close to a copy-paste of your example, so I'm unsure why it doesn't work.
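The two sizes in the error match the standard B-spline basis count. With n knots, order-k B-splines give n - k - 1 basis functions: the raw grid here has n = 11 (G = 10 intervals), which gives 10 - 3 = 7, while padding it with k extra knots on each side gives n = 17 and 10 + 3 = 13, the size the example's `coef` assumes. One possible reading, which is an assumption and not confirmed against pykan's source, is that the `B_batch` in use evaluated the basis on the un-extended grid. The sketch below uses plain Cox–de Boor recursion in PyTorch to show both counts and the same `'ij,ijk->ik'` contraction; `extend_grid_sketch`, `bspline_basis_sketch` and `coef2curve_sketch` are hypothetical helpers written for illustration, not pykan's API.

```python
import torch


def extend_grid_sketch(grid: torch.Tensor, k: int) -> torch.Tensor:
    """Hypothetical helper: pad a uniform grid with k extra knots on each side.

    grid: (num_spline, G + 1) -> returns (num_spline, G + 1 + 2k).
    """
    h = (grid[:, -1:] - grid[:, :1]) / (grid.shape[1] - 1)  # uniform knot spacing
    left = grid[:, :1] - h * torch.arange(k, 0, -1, dtype=grid.dtype)
    right = grid[:, -1:] + h * torch.arange(1, k + 1, dtype=grid.dtype)
    return torch.cat([left, grid, right], dim=1)


def bspline_basis_sketch(x: torch.Tensor, knots: torch.Tensor, k: int) -> torch.Tensor:
    """Hypothetical helper: Cox-de Boor recursion (not pykan's B_batch).

    x: (num_spline, num_sample), knots: (num_spline, n_knots).
    Returns (num_spline, n_knots - 1 - k, num_sample), one row per basis function.
    """
    t = knots.unsqueeze(2)  # (S, n_knots, 1)
    xs = x.unsqueeze(1)     # (S, 1, N)
    # Order 0: indicator of the half-open interval [t_i, t_{i+1})
    B = ((xs >= t[:, :-1]) & (xs < t[:, 1:])).to(x.dtype)
    for p in range(1, k + 1):
        # B_{i,p} = (x - t_i)/(t_{i+p} - t_i) B_{i,p-1} + (t_{i+p+1} - x)/(t_{i+p+1} - t_{i+1}) B_{i+1,p-1}
        left = (xs - t[:, :-(p + 1)]) / (t[:, p:-1] - t[:, :-(p + 1)]) * B[:, :-1]
        right = (t[:, p + 1:] - xs) / (t[:, p + 1:] - t[:, 1:-p]) * B[:, 1:]
        B = left + right  # each step drops one basis function
    return B


def coef2curve_sketch(x, grid, coef, k, extend=True):
    """Hypothetical stand-in: evaluate sum_j coef[i, j] * B_j(x[i, :])."""
    knots = extend_grid_sketch(grid, k) if extend else grid
    B = bspline_basis_sketch(x, knots.to(x.dtype), k)  # (S, num_basis, N)
    # coef's second dimension must equal num_basis for this contraction
    return torch.einsum('ij,ijk->ik', coef.to(x.dtype), B)


num_spline, G, k = 1, 10, 3
grid = torch.linspace(-1, 1, steps=G + 1).unsqueeze(0)  # (1, 11)
x = torch.linspace(-0.9, 0.9, steps=40).unsqueeze(0)     # (1, 40)
print(bspline_basis_sketch(x, grid, k).shape[1])                          # 7  (= G - k)
print(bspline_basis_sketch(x, extend_grid_sketch(grid, k), k).shape[1])  # 13 (= G + k)
coef = torch.randn(num_spline, G + k)                                    # (1, 13)
print(coef2curve_sketch(x, grid, coef, k).shape)                         # torch.Size([1, 40])
```

In this sketch, points outside the knot span get all-zero basis values, so inputs far outside [-1, 1] (such as much of the -5 to 5 range used above) would evaluate to 0.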