pykan
It seems like the B-spline function becomes ineffective when the input values fall outside its [-1, 1] bounds.
How can we keep the intermediate values within bounds during training so that B-splines can be used during inference?
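For example (a hedged sketch: it uses the same public API as the docstring quoted later in this thread, and the residual-branch detail comes from the KAN paper, where each activation is roughly w_b·silu(x) + w_s·spline(x); exact module layout may differ between pykan versions):

```python
import torch
from kan import KAN

model = KAN(width=[2, 5, 1], grid=5, k=3)  # default grid spans [-1, 1]
x_in  = torch.rand(100, 2) * 2 - 1         # samples inside the grid
x_out = torch.rand(100, 2) * 10 + 5        # samples far outside the grid
# outside [-1, 1] every B-spline basis function is zero, so only the
# residual SiLU branch of each activation still responds to the input
print(model(x_in).std().item(), model(x_out).std().item())
```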
Sorry, I accidentally closed it.
Updating the grid during training is important for addressing the bounded-range problem; see, e.g., the line from the training loop quoted later in this thread.
Thank you for your response. So, when we get a new sample that falls outside the range of the grid used during training, the corresponding B-splines are invalid, right?
During training, when we set a batch size, each step seems to be equivalent to one iteration; when we don't set a batch size (i.e., full batch), each step is equivalent to one epoch.
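If I'm reading the training loop right (hedged: `batch` is a keyword of pykan's `train` method, `fit` in newer releases, with -1 meaning the full training set):

```python
model.train(dataset, steps=100, batch=-1)  # full batch: each step is one epoch
model.train(dataset, steps=100, batch=32)  # mini-batch: each step is one iteration
```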
During forward propagation, intermediate values might fall outside the range of the grid. Consequently, the B-splines seem to be invalid even though we adjusted the grid range before evaluation.
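One way to check this concern empirically (a hedged sketch: it relies on `model.acts` and `model.act_fun[l].grid`, the same attributes the docstring quoted below uses, and `x_test` is a hypothetical held-out batch):

```python
model(x_test)                      # the forward pass populates model.acts
for l in range(model.depth):
    a = model.acts[l]              # activations entering layer l
    g = model.act_fun[l].grid      # that layer's spline grid
    print(l, float(a.min()), float(a.max()),   # activation range ...
          float(g.min()), float(g.max()))      # ... vs. grid range
```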
Why do we need the `forward` call in the `update_grid_from_samples` function? It appears to have no effect.

```python
def update_grid_from_samples(self, x):
    '''
    update grid from samples

    Args:
    -----
        x : 2D torch.float
            inputs, shape (batch, input dimension)

    Returns:
    --------
        None

    Example
    -------
    >>> model = KAN(width=[2,5,1], grid=5, k=3)
    >>> print(model.act_fun[0].grid[0].data)
    >>> x = torch.rand(100,2)*5
    >>> model.update_grid_from_samples(x)
    >>> print(model.act_fun[0].grid[0].data)
    tensor([-1.0000, -0.6000, -0.2000, 0.2000, 0.6000, 1.0000])
    tensor([0.0128, 1.0064, 2.0000, 2.9937, 3.9873, 4.9809])
    '''
    for l in range(self.depth):
        # the forward pass is not a no-op: it (re)populates self.acts with
        # every layer's activations, and it must be re-run on each iteration
        # because updating layer l's grid changes the inputs that reach the
        # layers after it
        self.forward(x)
        self.act_fun[l].update_grid_from_samples(self.acts[l])
```
When we change the range of the grid, we also need to update the coefficients. Doesn't this render our previous learning process meaningless?
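If I read the code correctly, the previous learning is not thrown away: when the grid moves, pykan evaluates the current spline at the sample points and least-squares-fits new coefficients on the new grid (`coef2curve` followed by `curve2coef` in `spline.py`), so the learned shape is approximately preserved. Here is a minimal standalone sketch of that refit step; the `bspline_basis` helper, the knot layout, and the random stand-in coefficients are illustrative assumptions, not pykan's actual code:

```python
import torch

def bspline_basis(x, knots, k):
    # Cox-de Boor recursion (illustrative helper, not pykan's B_batch).
    # x: (n,), knots: (m,) strictly increasing -> basis: (n, m - k - 1)
    x = x.unsqueeze(-1)
    b = ((x >= knots[:-1]) & (x < knots[1:])).float()
    for p in range(1, k + 1):
        left = (x - knots[:-(p + 1)]) / (knots[p:-1] - knots[:-(p + 1)]) * b[:, :-1]
        right = (knots[p + 1:] - x) / (knots[p + 1:] - knots[1:-p]) * b[:, 1:]
        b = left + right
    return b

k = 3
# old grid: 5 intervals on [-1, 1], extended by k knots on each side
old_knots = torch.linspace(-1 - k * 0.4, 1 + k * 0.4, 5 + 2 * k + 1)
coef_old = torch.randn(old_knots.numel() - k - 1)  # stand-in for learned coefficients

# activations have drifted to roughly [0, 5] -> same resolution, new location
new_knots = torch.linspace(0 - k * 1.0, 5 + k * 1.0, 5 + 2 * k + 1)

# sample the OLD curve at the new activation locations, then refit on the new grid
xs = torch.rand(500) * 5
y_old = bspline_basis(xs, old_knots, k) @ coef_old
coef_new = torch.linalg.lstsq(
    bspline_basis(xs, new_knots, k), y_old.unsqueeze(1)
).solution.squeeze(1)
# the new spline reproduces the old one where their grids overlap, so the
# grid change does not discard the learned shape
```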
Why do we need to update the range of the grid frequently? If the ranges of the values in two batches are inconsistent, wouldn't this cause the grid to oscillate? In that case, is the previously learned information still valid?
```python
if _ % grid_update_freq == 0 and _ < stop_grid_update_step and update_grid:
    self.update_grid_from_samples(dataset['train_input'][train_id].to(device))
```
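For what it's worth, the surrounding loop exposes knobs that make the updates rarer or disable them entirely, which should limit the oscillation mentioned above (hedged: I'm going by the `train` signature, `fit` in newer releases, and the names may vary between versions):

```python
model.train(dataset, opt="LBFGS", steps=50,
            update_grid=True,           # set False to freeze the grid entirely
            grid_update_num=10,         # total number of grid updates
            stop_grid_update_step=50)   # no grid updates after this step
```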
This is not necessarily a bug, but I believe future variants will handle this better.