Dynamic Spline Fixing for Critical Inputs During Iterative Training in KAN Models

Open andrewrgarcia opened this issue 5 months ago • 0 comments

Enhancement Request: Dynamic Spline Fixing for Critical Inputs During Iterative Training in KAN Models

Description:

The pyKAN framework supports iterative training, pruning, and refinement of Kolmogorov-Arnold Networks (KANs). However, there is a need to fix the spline coefficients (constants) associated with specific inputs at chosen stages of training while the rest of the network continues to adapt. This matters in scenarios where critical inputs should keep their learned spline representations while the rest of the model remains flexible for further refinement.

Problem Statement:

During the training of deep KAN models, specific inputs (e.g., x1, x2) may be identified as critical at various stages. Once their corresponding splines are learned, these splines should be fixed to prevent further alteration in subsequent training phases. However, the rest of the network, including other inputs and splines, should remain adaptable to ensure the model can continue refining its overall performance.

Proposed Solution:

  • Introduce a feature that allows users to fix spline constants for specific inputs during intermediate training phases.
  • This feature should enable users to lock the learned splines for these inputs at any stage while still allowing the model to prune, retrain, and adapt other parts of the network.
  • The solution should integrate smoothly with existing KAN functionality, such as pruning and continuing training after spline fixing.
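
To make the requested behavior concrete, here is a minimal pure-Python sketch of the idea: once an input is marked as fixed, updates to its spline coefficients are skipped while all other splines keep training. This is a toy additive model with piecewise-linear splines, not pyKAN code. The class `ToyAdditiveSplines`, the helper `hat_weights`, and this version of `fix_splines_for_inputs` are hypothetical names chosen for illustration only.

```python
# Toy illustration only: a pure-Python additive model, not pyKAN code.
# Each input i has its own piecewise-linear "spline" given by knot values coef[i].

def hat_weights(x, n_knots):
    """Return (knot index, weight) pairs for the two knots surrounding x in [0, 1]."""
    pos = min(max(x, 0.0), 1.0) * (n_knots - 1)
    left = min(int(pos), n_knots - 2)
    frac = pos - left
    return [(left, 1.0 - frac), (left + 1, frac)]


class ToyAdditiveSplines:
    def __init__(self, n_inputs, n_knots=5):
        self.n_knots = n_knots
        self.coef = [[0.0] * n_knots for _ in range(n_inputs)]
        self.frozen = set()  # indices of inputs whose splines are locked

    def fix_splines_for_inputs(self, inputs):
        """Hypothetical analogue of the proposed method: lock these inputs' splines."""
        self.frozen.update(inputs)

    def predict(self, x):
        # Sum of one spline per input, evaluated by linear interpolation.
        return sum(w * self.coef[i][k]
                   for i, xi in enumerate(x)
                   for k, w in hat_weights(xi, self.n_knots))

    def fit(self, data, steps=50, lr=0.1):
        # Plain stochastic gradient descent on squared error.
        for _ in range(steps):
            for x, y in data:
                err = self.predict(x) - y
                for i, xi in enumerate(x):
                    if i in self.frozen:
                        continue  # masked: no update reaches a fixed spline
                    for k, w in hat_weights(xi, self.n_knots):
                        self.coef[i][k] -= lr * err * w


# Target y = x1 + 2 * x2 sampled on a 5 x 5 grid in [0, 1]^2.
data = [((a / 4, b / 4), a / 4 + 2 * (b / 4)) for a in range(5) for b in range(5)]

model = ToyAdditiveSplines(n_inputs=2)
model.fit(data)                      # initial training phase
model.fix_splines_for_inputs([0])    # lock the spline for x1
snapshot = list(model.coef[0])
model.fit(data)                      # further training; x1's spline is untouched
assert model.coef[0] == snapshot
```

In a gradient-based framework the same effect would typically come from excluding the fixed coefficients from the optimizer or zeroing their gradients. How that maps onto pyKAN's internal parameters is left open here.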

Example Usage:

# Initial training phase
results = model.fit(dataset, opt="LBFGS", steps=20, lamb=0.01, lamb_entropy=10.)
plot_training_curves(results, filename='curve_first_train.pdf')

# Fix spline values for specific inputs x1 and x2 after initial training
fixed_inputs = [0, 1]  # Indices for x1 and x2
model.fix_splines_for_inputs(fixed_inputs)

# Intermediate pruning phase (prune returns the pruned model, so reassign it)
model = model.prune(node_th=1e-1)
model.plot()

# Further training phase while keeping fixed splines unchanged
results = model.fit(dataset, opt="LBFGS", steps=50)
plot_training_curves(results, filename='curve_post_prune_train.pdf')
model.plot()

# Additional spline fixing if necessary after intermediate training
additional_fixed_inputs = [2, 3]
model.fix_splines_for_inputs(additional_fixed_inputs)

# Continue training with newly fixed splines
results = model.fit(dataset, opt="LBFGS", steps=30)
plot_training_curves(results, filename='curve_final_train.pdf')
model.plot()

Benefits:

  • Incremental Model Refinement: Preserves critical learned splines at each stage of training while the rest of the model continues to be refined.
  • Flexibility in Training: Ensures that critical inputs maintain their influence while the rest of the model remains flexible for continued learning.
  • Advanced Control: Provides advanced control over the training process, enabling more sophisticated modeling techniques that leverage the iterative and flexible nature of KANs.

Request:

Could this enhancement be implemented to support more flexible and controlled training processes in KAN models?

andrewrgarcia, Aug 27 '24 19:08