How does batch size work in the KAN.train method?
Hi.
I was wondering whether the KAN.train() method works as expected in the usual deep-learning sense. From the code, it looks like the model is trained only on a small random subset of the overall train set at each step, not on the whole train set. Any comments would be appreciated!
```python
for _ in pbar:

    train_id = np.random.choice(dataset['train_input'].shape[0], batch_size, replace=False)
    test_id = np.random.choice(dataset['test_input'].shape[0], batch_size_test, replace=False)

    if _ % grid_update_freq == 0 and _ < stop_grid_update_step and update_grid:
        self.update_grid_from_samples(dataset['train_input'][train_id].to(device))

    if opt == "LBFGS":
        optimizer.step(closure)

    if opt == "Adam":
        pred = self.forward(dataset['train_input'][train_id].to(device))
        if sglr_avoid == True:
            id_ = torch.where(torch.isnan(torch.sum(pred, dim=1)) == False)[0]
            train_loss = loss_fn(pred[id_], dataset['train_label'][train_id][id_].to(device))
        else:
            train_loss = loss_fn(pred, dataset['train_label'][train_id].to(device))
        reg_ = reg(self.acts_scale)
        loss = train_loss + lamb * reg_
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```
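To illustrate what the sampling above does, here is a small standalone sketch (with made-up sizes, not pykan's defaults): each step draws a fresh random subset of row indices, so although any single step sees only `batch_size` samples, most of the dataset is visited across many steps.

```python
import numpy as np

# Hypothetical sizes for illustration only.
n_train, batch_size, steps = 1000, 128, 50

np.random.seed(0)
seen = set()
for _ in range(steps):
    # Same call pattern as in KAN.train: a fresh subset per step.
    train_id = np.random.choice(n_train, batch_size, replace=False)
    seen.update(train_id.tolist())

# After 50 steps of 128 samples each, nearly all 1000 rows have been drawn.
print(len(seen))
```

So with enough steps the model does effectively see the whole train set, just not all of it within a single step.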
Hi, by default `batch=-1` means the full dataset will be used. If you want a batch size of 128, set `batch=128`.
@KindXiaoming Hi! Thank you for your kind reply. What I meant in the question is that if we set batch=-1, we train on the whole train set, which is the desired behavior. Otherwise, we train on only a batch-sized subset of the whole train set at each step. How, then, can we train the model on a large dataset? Is it enough to use a dataloader as we do in the usual deep-learning setting?
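For what it's worth, a standard epoch-based mini-batch loop with `torch.utils.data.DataLoader` is one way to guarantee every sample is visited. This is only a hedged sketch: `model` is a stand-in `nn.Module` (a real KAN would replace it), the sizes are invented, and the dataset dict keys just mirror the ones used in KAN.train.

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

torch.manual_seed(0)

# Invented toy data, shaped like the dataset dict KAN.train expects.
dataset = {
    'train_input': torch.randn(1024, 2),
    'train_label': torch.randn(1024, 1),
}

# Stand-in model; a real KAN instance would go here instead.
model = torch.nn.Sequential(
    torch.nn.Linear(2, 8), torch.nn.Tanh(), torch.nn.Linear(8, 1)
)

loader = DataLoader(
    TensorDataset(dataset['train_input'], dataset['train_label']),
    batch_size=128, shuffle=True,
)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = torch.nn.MSELoss()

for epoch in range(2):        # multiple passes over the full train set
    for x, y in loader:       # every sample is visited once per epoch
        pred = model(x)
        loss = loss_fn(pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```

Whether this plays well with KAN's grid updates (`update_grid_from_samples`) is a separate question; the sketch only shows how to cover the full dataset per epoch.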
@stupidnubnubnub Hello, I'm having the same problem as you, and I was wondering if you've solved the problem of training a KAN model in batches on a large dataset? I apologize for my English.
@stupidnubnubnub @Papillon-forest Same problem here. I haven't checked the code yet, and I also haven't tried training on GPU. Maybe batch=128 works correctly on GPU?
Same question. I would also like to ask how to train for multiple epochs; is the default setting epoch=1?