
It seems like the B-spline function becomes ineffective when the input values exceed its [-1, 1] bounds.

Open zgydbc opened this issue 9 months ago • 8 comments

How can we keep the intermediate values within bounds during training so that B-splines can be used during inference?

zgydbc avatar May 09 '24 13:05 zgydbc
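For context, the behavior in the title follows from the B-spline definition itself: each basis function is supported only on its knot span, so outside the grid range every basis evaluates to zero. A minimal sketch with the Cox-de Boor recursion (plain PyTorch, not pykan's implementation; the grid and order are illustrative):

```python
import torch

def bspline_basis(x, grid, k):
    # Cox-de Boor recursion: evaluate all order-k B-spline bases at x.
    # x: (n,) samples; grid: (m,) knot vector; returns (n, m-k-1) values.
    x = x.unsqueeze(-1)
    # order-0 bases: indicator of each knot interval
    B = ((x >= grid[:-1]) & (x < grid[1:])).float()
    for p in range(1, k + 1):
        left = (x - grid[:-(p + 1)]) / (grid[p:-1] - grid[:-(p + 1)]) * B[:, :-1]
        right = (grid[p + 1:] - x) / (grid[p + 1:] - grid[1:-p]) * B[:, 1:]
        B = left + right
    return B

grid = torch.linspace(-1, 1, 6)                       # knots on [-1, 1]
inside = bspline_basis(torch.tensor([0.1]), grid, k=2)
outside = bspline_basis(torch.tensor([2.5]), grid, k=2)
print(inside.sum().item())   # ~1: partition of unity inside the grid
print(outside.sum().item())  # 0: every basis function vanishes outside
```

Once an activation drifts outside the knot range, all basis functions return zero, so the spline contributes nothing regardless of its coefficients.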

Sorry, I accidentally closed it.

zgydbc avatar May 09 '24 13:05 zgydbc

Updating the grid is important during training to address the bounded-range problem; see, e.g., this line in training.

KindXiaoming avatar May 09 '24 22:05 KindXiaoming
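To make the mechanism concrete, here is a rough sketch of what sample-driven grid updating can look like (an illustration, not pykan's actual code; `updated_grid` and the `eps` blend are assumptions): the knots are rebuilt to cover the observed activation range, blending uniform spacing with sample quantiles.

```python
import torch

def updated_grid(samples, num_intervals, eps=0.02):
    # Rebuild knots from observed samples (illustrative, not pykan's code):
    # blend a uniform grid over the sample range with sample quantiles.
    samples, _ = samples.flatten().sort()
    idx = torch.linspace(0, samples.numel() - 1, num_intervals + 1).long()
    adaptive = samples[idx]                                  # quantile knots
    uniform = torch.linspace(samples[0].item(), samples[-1].item(),
                             num_intervals + 1)              # uniform knots
    return eps * uniform + (1 - eps) * adaptive

x = torch.rand(1000) * 5     # activations that drifted outside [-1, 1]
grid = updated_grid(x, num_intervals=5)
print(grid)                  # knots now span roughly [0, 5]
```

After the update the knots span the samples' actual range, so the splines are active again for those inputs.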

Thank you for your response. So, when we get a new sample, if it falls outside the range of the grid used during training, then the corresponding B-splines are invalid, right?

zgydbc avatar May 10 '24 02:05 zgydbc

During training, when we set a batch, each step seems to be equivalent to an iteration; when we don't set a batch, each step is equivalent to an epoch.

zgydbc avatar May 10 '24 02:05 zgydbc

During forward propagation, intermediate values might fall outside the range of the grid. Consequently, the B-splines seem to be invalid even though we adjusted the grid range before eval.

zgydbc avatar May 10 '24 02:05 zgydbc

Why do we need the forward call in the `update_grid_from_samples` function? It appears to have no effect.

```python
def update_grid_from_samples(self, x):
    '''
    update grid from samples

    Args:
    -----
        x : 2D torch.float
            inputs, shape (batch, input dimension)

    Returns:
    --------
        None

    Example
    -------
    >>> model = KAN(width=[2,5,1], grid=5, k=3)
    >>> print(model.act_fun[0].grid[0].data)
    >>> x = torch.rand(100,2)*5
    >>> model.update_grid_from_samples(x)
    >>> print(model.act_fun[0].grid[0].data)
    tensor([-1.0000, -0.6000, -0.2000,  0.2000,  0.6000,  1.0000])
    tensor([0.0128, 1.0064, 2.0000, 2.9937, 3.9873, 4.9809])
    '''
    for l in range(self.depth):
        self.forward(x)
        self.act_fun[l].update_grid_from_samples(self.acts[l])
```

zgydbc avatar May 10 '24 02:05 zgydbc

When we change the range of the grid, we also need to update the coefficients. Doesn't this render our previous learning process meaningless?

zgydbc avatar May 10 '24 02:05 zgydbc
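For what it's worth, changing the grid does not have to discard what was learned: the new coefficients can be fit so that the new spline reproduces the old one on the samples. A toy sketch with degree-1 (piecewise-linear) splines follows; `hat_eval` is a made-up helper, and pykan itself appears to refit B-spline coefficients by least squares, but the principle is the same.

```python
import torch

def hat_eval(grid, coef, x):
    # piecewise-linear spline with values `coef` at the knots `grid`
    x = x.clamp(grid[0].item(), grid[-1].item())
    i = torch.searchsorted(grid, x, right=True).clamp(1, len(grid) - 1)
    t = (x - grid[i - 1]) / (grid[i] - grid[i - 1])
    return (1 - t) * coef[i - 1] + t * coef[i]

old_grid = torch.linspace(-1, 1, 6)
old_coef = torch.randn(6)                 # previously learned coefficients
new_grid = torch.linspace(-1, 1, 11)      # refined grid
# refit: choose new coefficients so the new spline matches the old one
new_coef = hat_eval(old_grid, old_coef, new_grid)

x = torch.rand(100) * 2 - 1
err = (hat_eval(new_grid, new_coef, x) - hat_eval(old_grid, old_coef, x)).abs().max()
print(err.item())    # ~0: previous learning is carried over, not thrown away
```

So the coefficient update transfers the learned function onto the new grid rather than restarting from scratch.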

Why do we need to update the range of the grid frequently? If the ranges of the values in two batches are inconsistent, wouldn't this cause the grid to oscillate? In that case, is the previously learned information still valid?

```python
if _ % grid_update_freq == 0 and _ < stop_grid_update_step and update_grid:
    self.update_grid_from_samples(dataset['train_input'][train_id].to(device))
```

zgydbc avatar May 10 '24 03:05 zgydbc
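Regarding oscillation: the quoted schedule itself limits how much the knots can move, since updates happen only every `grid_update_freq` steps and stop entirely after `stop_grid_update_step`, after which the coefficients train against a fixed grid. A tiny sketch of the schedule (the numbers are made up):

```python
# Which training steps trigger a grid update under the quoted condition
# (grid_update_freq / stop_grid_update_step values are illustrative).
grid_update_freq, stop_grid_update_step, total_steps = 10, 50, 100
update_steps = [step for step in range(total_steps)
                if step % grid_update_freq == 0 and step < stop_grid_update_step]
print(update_steps)   # [0, 10, 20, 30, 40] -- no grid movement after step 40
```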

This is not necessarily a bug, but I believe future variants will be better than this.

KindXiaoming avatar Jul 14 '24 04:07 KindXiaoming