Rename the train method to fit to avoid confusion with PyTorch's built-in train method.

Open linkedlist771 opened this issue 1 year ago • 13 comments

The train method in the KAN class defines the main training logic for the model. In PyTorch, however, the name "train" is used for switching a module between training and evaluation modes, as in model.train(). With the current implementation, a user who calls model.train() explicitly gets an error. To avoid confusion and improve code clarity, this PR renames the train method to fit.
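The clash described above can be sketched without PyTorch installed. The snippet below is a minimal illustration, not pykan's code: `Module` is a stdlib-only stand-in that assumes `torch.nn.Module.train(mode)` sets a flag and recurses into child modules, and `ToyKAN`, `MyNet`, `add_child`, and `mode_switch_error` are hypothetical names introduced only for this example.

```python
class Module:
    """Stdlib-only stand-in for torch.nn.Module's train/eval behaviour (assumption)."""

    def __init__(self):
        self.training = True
        self._children = []

    def add_child(self, child):
        # Hypothetical helper; real PyTorch registers children on attribute assignment.
        self._children.append(child)
        return child

    def train(self, mode=True):
        # Set the flag on this module, then recurse into every child.
        self.training = mode
        for child in self._children:
            child.train(mode)
        return self

    def eval(self):
        return self.train(False)


class ToyKAN(Module):
    """Hypothetical model whose train() is a training loop, as in this issue."""

    def train(self, dataset, steps=100):
        # Shadows Module.train: expects a dataset dict, not a bool.
        return {"steps": steps, "n_samples": len(dataset["train_input"])}


class MyNet(Module):
    def __init__(self):
        super().__init__()
        self.output_module = self.add_child(ToyKAN())


def mode_switch_error(model):
    """Return the TypeError raised by model.eval(), or None if the call succeeds."""
    try:
        model.eval()
    except TypeError as exc:
        return exc
    return None
```

In this sketch, `mode_switch_error(MyNet())` returns a `TypeError`, because the parent's mode switch reaches the child and passes `False` where the training loop expects a dataset. A plain `Module()` switches modes without error.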

The fit method better conveys the purpose of the function, which is to train the model using the provided dataset and hyperparameters. This change ensures that the method name does not clash with PyTorch's built-in functionality and makes it clear that it is a custom model training function.

The renaming is a minor change and does not affect the functionality of the method. All occurrences of train within the method have been replaced with fit, and any calls to the method in other parts of the codebase have been updated accordingly.

This change improves the readability and maintainability of the codebase by following a more consistent and intuitive naming convention.

linkedlist771 avatar May 16 '24 10:05 linkedlist771

I've been facing issues setting up the GPU as well

ChrisD-7 avatar May 16 '24 13:05 ChrisD-7

I've been facing issues setting up the GPU as well

what do you mean?

linkedlist771 avatar May 16 '24 14:05 linkedlist771

I tried loading the device file on Colab, but it didn't run.

By when do you think they'll approve the model.fit change?

ChrisD-7 avatar May 16 '24 15:05 ChrisD-7

Not sure about this, but as a temporary workaround you can patch the method like this (call the patch function before you use the kan library):

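from kan import KAN

def patch_kan_train_function():
    # Replace KAN.train (the training loop) with a no-op so that a parent
    # module's .train()/.eval() call no longer fails on the KAN submodule.
    # Note: this no-op does not toggle the submodule's training flag.
    def patch_train(self, mode=True):
        return self  # nn.Module.train also returns the module

    KAN.train = patch_train

patch_kan_train_function()  # Call once, before using the kan library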

linkedlist771 avatar May 16 '24 15:05 linkedlist771

and then use it as model.fit moving on right?

ChrisD-7 avatar May 16 '24 15:05 ChrisD-7

and then use it as model.fit moving on right?

In my specific use case, the KAN model is integrated as a component within my larger neural network model. I don't directly utilize the train method provided by the KAN class. Instead, I opt for a more customized approach, explicitly managing the training process using PyTorch's loss.backward() and optimizer.step() functions.

import torch.nn as nn
from kan import KAN

class MyNet(nn.Module):
    def __init__(self):
        super().__init__()  # Required before registering submodules
        ...
        self.output_module = KAN(...)  # Integrate KAN as a component

# Training loop (custom, explicit control)
for train_batch in train_dataloader:
    ...
    loss.backward()  # Compute gradients
    optim.step()     # Update model parameters
    ...

If you prefer a Keras-style model.fit() API, you would typically rename your custom train function to fit. However, this isn't necessary in my current approach, as I'm directly controlling the training loop.

linkedlist771 avatar May 16 '24 15:05 linkedlist771

This is what I get when running their device example on Google Colab: [image]

Their Code Example: https://github.com/KindXiaoming/pykan/blob/master/tutorials/API_10_device.ipynb

ChrisD-7 avatar May 18 '24 17:05 ChrisD-7

@KindXiaoming Could you please review this? I think this PR is fairly important for those who want to integrate KAN into their own modules.

linkedlist771 avatar May 20 '24 06:05 linkedlist771

Hi @linkedlist771, thank you for your message. As I explained in an issue (forgot which), tutorials use train() all the time, so it would be too confusing to switch to another API. Also, one could say this is even a deliberate design, since you need to manually update the grid, so it is not just plug-and-play (will improve the error message though). I do understand that users want plug-and-play; I think any KAN variants that do not require grid update can be safely and directly used plug-and-play.

KindXiaoming avatar May 20 '24 12:05 KindXiaoming

Could u also check the gpu issue @KindXiaoming

ChrisD-7 avatar May 20 '24 12:05 ChrisD-7

Hi @ChrisD-7, can you pull the latest version and see if the GPU issue still persists? Apple GPU (MPS) support seems not to be solved yet, but I don't see more issues regarding CUDA.

KindXiaoming avatar May 20 '24 13:05 KindXiaoming

Hi @linkedlist771, thank you for your message. As I explained in an issue (forgot which), tutorials use train() all the time, so it would be too confusing to switch to another API. Also, one could say this is even a deliberate design, since you need to manually update the grid, so it is not just plug-and-play (will improve the error message though). I do understand that users want plug-and-play; I think any KAN variants that do not require grid update can be safely and directly used plug-and-play.

This concern is valid, but for plug-and-play scenarios, I believe a warning would be sufficient instead of raising an exception. It might complicate the train function if we check the invoking function to determine the user's intent (whether to switch to the model's training mode or simply train the model). However, if necessary, I can create a slightly more complex train function to handle this issue.
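One way the "slightly more complex train function" could look is sketched below. This is a hedged illustration under stated assumptions, not pykan's implementation: `Module` is a stdlib-only stand-in for the mode-switching part of `torch.nn.Module`, `ToyKAN` and its `fit` method are hypothetical, and the rule used to guess the caller's intent (no arguments, a lone bool, or only a `mode` keyword means "switch mode") is an assumption of this sketch.

```python
import warnings


class Module:
    """Stdlib-only stand-in for torch.nn.Module's mode switching (assumption)."""

    def __init__(self):
        self.training = True

    def train(self, mode=True):
        self.training = mode
        return self

    def eval(self):
        return self.train(False)


class ToyKAN(Module):
    """Hypothetical model; fit() holds the training loop."""

    def fit(self, dataset, steps=100):
        # Placeholder for the real optimisation loop.
        return {"steps": steps, "n_samples": len(dataset["train_input"])}

    def train(self, *args, **kwargs):
        # Guess the intent from the arguments: PyTorch-style calls pass
        # nothing, a single bool, or only mode=<bool>.
        if not args and set(kwargs) <= {"mode"}:
            mode = kwargs.get("mode", True)
        elif len(args) == 1 and isinstance(args[0], bool) and not kwargs:
            mode = args[0]
        else:
            # Anything else is treated as a request to run the training loop.
            return self.fit(*args, **kwargs)
        warnings.warn(
            "train() was called as a mode switch; use fit() to run training.",
            UserWarning,
            stacklevel=2,
        )
        return super().train(mode)
```

With this dispatch, `model.eval()` and `model.train(False)` emit a warning and flip the flag instead of raising, while `model.train(dataset, steps=5)` still runs the training path. The trade-off is that the method now has two meanings, which is the complication mentioned above.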

linkedlist771 avatar May 20 '24 13:05 linkedlist771

@KindXiaoming I tried pulling it again. Is it an issue with the Colab environment? I'm trying to run it for a classification model and ran into a RAM issue, so I tried moving to the GPU and hit this issue.

image

ChrisD-7 avatar May 20 '24 15:05 ChrisD-7

Closed as completed.

linkedlist771 avatar Jun 17 '24 14:06 linkedlist771