
Trying a coef2curve example

Open drdozer opened this issue 8 months ago • 1 comments

Hi. Thanks for making the KAN code available. I've been trying to understand how it all works, and was trying to run coef2curve to get an idea of how the splines are calculated. However, I couldn't get it to run.

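import math

import numpy as np
import torch

# coef2curve is defined in an earlier notebook cell (see the traceback below).

def exp_sin(x):
  return np.exp(np.sin(math.pi * x))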

train_xs = np.arange(-5,5,0.25)
train_ys = exp_sin(train_xs)

t_xs = torch.reshape(torch.from_numpy(train_xs), (1,len(train_xs)))

num_spline = 1
num_sample = len(train_xs)
num_grid_interval = 10
k = 3
grids = torch.einsum('i,j->ij', torch.ones(num_spline,), torch.linspace(-1,1,steps=num_grid_interval+1))
coef = torch.normal(0,1,size=(num_spline, num_grid_interval+k))

print(f"Shape of t_xs: {t_xs.shape}")
print(f"Shape of grids: {grids.shape}")
print(f"Shape of coef: {coef.shape}")


t_ys = coef2curve(t_xs, grids, coef, k=k)

This fails with a shape-related error:

Shape of t_xs: torch.Size([1, 40])
Shape of grids: torch.Size([1, 11])
Shape of coef: torch.Size([1, 13])
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-45-2b0233d6ec51> in <cell line: 16>()
     14 
     15 
---> 16 t_ys = coef2curve(t_xs, grids, coef, k=k)

1 frames
<ipython-input-27-aae4645f44ba> in coef2curve(x_eval, grid, coef, k, device)
     90     if coef.dtype != x_eval.dtype:
     91         coef = coef.to(x_eval.dtype)
---> 92     y_eval = torch.einsum('ij,ijk->ik', coef, B_batch(x_eval, grid, k, device=device))
     93     return y_eval
     94 

/usr/local/lib/python3.10/dist-packages/torch/functional.py in einsum(*args)
    383         # the path for contracting 0 or 1 time(s) is already optimized
    384         # or the user has disabled using opt_einsum
--> 385         return _VF.einsum(equation, operands)  # type: ignore[attr-defined]
    386 
    387     path = None

RuntimeError: einsum(): subscript j has size 7 for operand 1 which does not broadcast with previously seen size 13

I don't understand the code well enough to grasp what's going wrong here. It looks like a mismatch between the sizes of the grid and the coefficients: the einsum expects 13 coefficients per spline, but the basis returned by B_batch only has 7. I think I've stuck close to a copy-paste of your example, so I'm unsure why it doesn't work.

drdozer, Jun 18 '24 11:06
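
The two sizes in the error message are consistent with a difference in knot counts. A degree-k B-spline basis on n knots has n - k - 1 functions, so the 11-point grid on its own gives 11 - 3 - 1 = 7, while the same grid padded with k extra knots on each side gives 17 - 3 - 1 = 13, which equals num_grid_interval + k. The sketch below only illustrates that count with a plain-Python Cox-de Boor recursion. It is not pykan's B_batch, and extend_knots and bspline_basis are hypothetical helper names. Whether the B_batch pasted into the notebook actually skipped the grid extension is an assumption that would need checking against that cell.

```python
def extend_knots(knots, k):
    """Pad a uniform knot vector with k extra knots on each side (same spacing)."""
    h = (knots[-1] - knots[0]) / (len(knots) - 1)
    left = [knots[0] - h * i for i in range(k, 0, -1)]
    right = [knots[-1] + h * i for i in range(1, k + 1)]
    return left + list(knots) + right


def bspline_basis(x, knots, k):
    """Evaluate all degree-k B-spline basis functions at a scalar x (Cox-de Boor)."""
    # Degree 0: indicator of each half-open knot interval [t_i, t_{i+1}).
    basis = [1.0 if knots[i] <= x < knots[i + 1] else 0.0
             for i in range(len(knots) - 1)]
    for d in range(1, k + 1):
        # Each degree step blends neighbouring functions, so the list shrinks by one.
        basis = [
            (x - knots[i]) / (knots[i + d] - knots[i]) * basis[i]
            + (knots[i + d + 1] - x) / (knots[i + d + 1] - knots[i + 1]) * basis[i + 1]
            for i in range(len(basis) - 1)
        ]
    return basis


G, k = 10, 3                                     # values from the post above
grid = [-1 + 2 * i / G for i in range(G + 1)]    # 11 knots on [-1, 1]

n_plain = len(bspline_basis(0.05, grid, k))                 # 11 - 3 - 1
n_ext = len(bspline_basis(0.05, extend_knots(grid, k), k))  # 17 - 3 - 1
print(n_plain, n_ext)  # prints: 7 13
```

If the missing extension is the cause, either passing coefficients of shape (num_spline, 7) or making sure the grid is extended before the basis is evaluated should make the einsum shapes agree. Which of the two matches the intended pykan behaviour depends on the library version in use.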