pykan icon indicating copy to clipboard operation
pykan copied to clipboard

Request for advice on tuning hyperparameters

Open zhongjingjogy opened this issue 9 months ago • 1 comments

I am trying to tune the hyperparameters with Optuna, which requires defining an objective to optimize. I am currently considering the AICc, but I would like advice on how to reasonably and conveniently evaluate the corresponding model complexity, for example how to count the number of model parameters in use. A roadmap for doing this is given below, and I would greatly appreciate any instructive suggestions.

from kan import KAN, create_dataset
import torch
import matplotlib.pyplot as plt
import numpy as np
import optuna

def f(x):
    """Piecewise target: constant up to 0.5, then a fast sine.

    For elements of ``x`` strictly greater than 0.5 the value is
    ``sin(20 * x)``; every other element gets the constant ``sin(20 * 0.5)``,
    so the two pieces meet at the threshold.

    Args:
        x: Torch tensor of sample locations (compared and passed to
            ``torch.sin`` elementwise).

    Returns:
        Tensor produced by ``torch.where`` with one value per element of ``x``.
    """
    threshold = 0.5
    # Scalar value used for the flat part of the curve.
    plateau = np.sin(20.0 * threshold)
    oscillation = torch.sin(20.0 * x)
    return torch.where(x > threshold, oscillation, plateau)

# Sample the 1-D target f over [0, 1] into train/test splits of 2000 points each.
dataset = create_dataset(
    f,
    n_var=1,
    ranges=[0, 1],
    train_num=2000,
    test_num=2000,
)
print("Dataset created")

def create_model(trial):
    """Build a KAN with one hidden layer from a hyperparameter mapping.

    Args:
        trial: Mapping that is indexed with the keys ``"neuron_num"``
            (hidden-layer width), ``"grid"`` and ``"k"``; the latter two are
            forwarded unchanged to the ``KAN`` constructor.

    Returns:
        A ``KAN`` of width ``[1, neuron_num, 1]`` constructed with ``seed=0``.
    """
    hidden_width = trial["neuron_num"]
    return KAN(
        width=[1, hidden_width, 1],
        grid=trial["grid"],
        k=trial["k"],
        seed=0,
    )

def objective(trial):
    """Optuna objective: train one KAN and score it by averaged train/test MSE.

    Args:
        trial: Trial object providing ``suggest_int(name, low, high)``; it is
            used to sample ``neuron_num`` in [1, 10], ``grid`` in [2, 10] and
            ``k`` in [1, 3].

    Returns:
        float: ``0.5 * train_mse + 0.5 * test_mse`` for the trained model,
        computed on the module-level ``dataset``.
    """
    # Collect the sampled values explicitly and hand them to create_model,
    # instead of assigning unused locals and relying on trial.params.
    params = {
        "neuron_num": trial.suggest_int("neuron_num", 1, 10),
        "grid": trial.suggest_int("grid", 2, 10),
        "k": trial.suggest_int("k", 1, 3),
    }
    model = create_model(params)

    # NOTE(review): newer pykan releases appear to expose this training entry
    # point as `fit` -- confirm against the installed version.
    model.train(dataset, opt="Adam", steps=500)

    def mean_squared_error(split):
        """Return the MSE of ``model`` on the ``split`` ('train' or 'test')."""
        predictions = model(dataset[f"{split}_input"])
        residuals = dataset[f"{split}_label"] - predictions
        return torch.mean(residuals ** 2).item()

    # Only scalar scores are needed here, so skip autograd bookkeeping.
    with torch.no_grad():
        train_score = mean_squared_error("train")
        test_score = mean_squared_error("test")

    # TODO
    # 1. Calculate the number of parameters used in the model
    # 2. Calculate the RSS of the model
    # 3. Calculate the AIC or AICc of the model
    # 4. Return the AIC or AICc as the objective value

    return train_score * 0.5 + test_score * 0.5

# Trials are stored in a local SQLite file; with load_if_exists=True an
# existing study named "kan" in that storage is reused rather than recreated.
study = optuna.create_study(
    study_name="kan",
    storage="sqlite:///kan.db",
    direction="minimize",
    load_if_exists=True,
)
study.optimize(objective, n_trials=20)

print(study.best_params)

zhongjingjogy avatar May 22 '24 15:05 zhongjingjogy