Rename the train method to fit to avoid confusion with PyTorch's built-in train method.
The train method in the KAN class is used to define the main training logic for the model. However, in PyTorch, the name "train" is conventionally used for switching the model between training and evaluation states, as in model.train(). In the current implementation, if the user calls model.train() explicitly, it raises an error. To avoid confusion and improve code clarity, this PR renames the train method to fit.
The fit method better conveys the purpose of the function, which is to train the model using the provided dataset and hyperparameters. This change ensures that the method name does not clash with PyTorch's built-in functionality and makes it clear that it is a custom model training function.
The renaming is a minor change and does not affect the functionality of the method. All occurrences of train within the method have been replaced with fit, and any calls to the method in other parts of the codebase have been updated accordingly.
This change improves the readability and maintainability of the codebase by following a more consistent and intuitive naming convention.
been facing issues loading setting up the GPU as well
what do you mean?
tried loading the device file on colab didn't run
by when do u think they'll approve on the model.fit change?
Not sure about this, but you can monkey-patch the function as a temporary workaround (call this before you use the kan lib):

```python
from kan import KAN

def patch_kan_train_function():
    # Replace KAN's custom train() with a no-op so that calls to
    # model.train() no longer raise
    def patch_train(*args, **kwargs):
        pass
    KAN.train = patch_train
```
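To see why the patch helps, here is a torch-free sketch of the name clash it works around. `BaseModule` stands in for torch.nn.Module (whose train() just toggles a training flag) and `OldKAN` stands in for a class that overrides train() with its own fitting logic; both classes are illustrative stand-ins, not the real pykan or PyTorch APIs.

```python
class BaseModule:
    """Stand-in for torch.nn.Module."""
    def __init__(self):
        self.training = True

    def train(self, mode=True):
        # nn.Module-style behavior: just flip the training flag
        self.training = mode
        return self


class OldKAN(BaseModule):
    """Stand-in for a KAN class that overrides train() with fitting logic."""
    def train(self, dataset=None, steps=20):
        # The override expects a dataset, so the usual nn.Module-style
        # call model.train() with no arguments fails
        if dataset is None:
            raise TypeError("train() expects a dataset")
        return {"steps": steps}


def patch_kan_train_function():
    # Replace the clashing override with a harmless no-op,
    # mirroring the patch suggested above
    def patch_train(*args, **kwargs):
        pass
    OldKAN.train = patch_train


model = OldKAN()
patch_kan_train_function()
model.train()  # no longer raises after the patch
```

The trade-off is that after patching, train() no longer toggles the training flag either, so this really is a stopgap rather than a fix.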
and then use it as model.fit moving on right?
In my specific use case, the KAN model is integrated as a component within my larger neural network model. I don't directly utilize the train method provided by the KAN class. Instead, I opt for a more customized approach, explicitly managing the training process using PyTorch's loss.backward() and optimizer.step() functions.
```python
class MyNet(nn.Module):
    def __init__(self):
        ...
        self.output_module = KAN(...)  # Integrate KAN as a component

# Training loop (custom, explicit control)
for train_batch in train_dataloader:
    ...
    loss.backward()  # Compute gradients
    optim.step()     # Update model parameters
    ...
```
If you prefer a Keras-style model.fit() API, you would typically rename the custom train function to fit. However, this isn't necessary in my current approach, since I'm directly controlling the training loop.
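The explicit-control pattern above can be sketched end to end without torch. Everything here (`KANLikeComponent`, `MyNet`'s internals, the `SGD` class, the toy data) is an illustrative stand-in for the real pykan/PyTorch objects, not their actual APIs:

```python
class KANLikeComponent:
    """Stand-in for the KAN output module: a single learnable weight."""
    def __init__(self):
        self.w = 0.0
        self.grad = 0.0

    def forward(self, x):
        return self.w * x


class MyNet:
    """Larger model embedding the component, as in self.output_module = KAN(...)."""
    def __init__(self):
        self.output_module = KANLikeComponent()

    def forward(self, x):
        return self.output_module.forward(x)


class SGD:
    """Minimal optimizer: step() applies each parameter's stored gradient."""
    def __init__(self, params, lr=0.1):
        self.params = params
        self.lr = lr

    def step(self):
        for p in self.params:
            p.w -= self.lr * p.grad


model = MyNet()
optim = SGD([model.output_module], lr=0.1)

# Toy dataset for the target function y = 2x; the explicit loop mirrors
# the loss.backward() / optim.step() split in the PyTorch snippet above
data = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0)]
for epoch in range(50):
    for x, y in data:
        pred = model.forward(x)
        # "backward": analytic gradient of the squared error w.r.t. w
        model.output_module.grad = 2 * (pred - y) * x
        optim.step()

print(round(model.output_module.w, 2))  # converges toward 2.0
```

The point of the pattern is that the KAN component never needs its own train()/fit() entry point: the outer loop owns the optimization, so the rename debate does not affect this integration style.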
This is me trying to run this code on Google Colab with their device example
Their Code Example: https://github.com/KindXiaoming/pykan/blob/master/tutorials/API_10_device.ipynb
@KindXiaoming Could you please review this? I think this PR is fairly important for those who want to integrate KAN into their own modules.
Hi @linkedlist771 , thank you for your message. As I explained in an issue (forgot which), the tutorials use train() throughout, so switching to another API would be too confusing. One could even say this design is deliberate, since you need to manually update the grid rather than use the model plug-and-play (I will improve the error message, though). I do understand that users want plug-and-play; I think any KAN variant that does not require a grid update can be safely and directly used plug-and-play.
Could you also check the GPU issue @KindXiaoming
hi @ChrisD-7, can you pull the latest version and see if the GPU issue still persists? Apple GPU (MPS) seems not solved yet, but I don't see more issues regarding CUDA.
> Hi @linkedlist771 , thank you for your message. As I explained in an issue (forgot which), tutorials use train() all the time, so it would be too confusing to switch to another API. Also, one can say this is even deliberate to be designed as such, since you need to manually update grid, not just plug-and-play (will improve the error message though). I do understand that users want plug-and-play, I think any KAN variants that do not require grid update can be safely and directly plug-and-play.
This concern is valid, but for plug-and-play scenarios, I believe a warning would be sufficient instead of raising an exception. It might complicate the train function if we check the invoking function to determine the user's intent (whether to switch to the model's training mode or simply train the model). However, if necessary, I can create a slightly more complex train function to handle this issue.
@KindXiaoming I tried pulling it again; could it be something with the Colab environment? I'm trying to run a classification model: I first hit a RAM issue, so I tried moving to the GPU and then ran into this issue.
Closed as completed.