pytorch_tabular
add custom loss, optim, metrics for model_sweep
Add support for custom loss, metrics, and optimizers in model_sweep
Fixes #544
- Custom loss functions, metrics, and optimizers can now be passed to model_sweep in the same way as to tabular_model.fit(), through the new custom_fit_params argument. custom_fit_params expects a dictionary specifying the custom loss, metrics, and/or optimizer.
- Minimal code changes; fully backward compatible.
- Updated corresponding tests.
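The forwarding idea in the bullets above can be sketched in plain Python. This is a minimal, hypothetical illustration, not the pytorch_tabular implementation: `toy_sweep` and `ToyModel` are made-up names. The only idea taken from this PR is that a single `custom_fit_params` dictionary is unpacked into every model's `fit()` call, and that leaving it out (`None`) keeps existing callers working unchanged.

```python
from typing import Any, Callable, Dict, List, Optional


class ToyModel:
    """Hypothetical stand-in for a model whose fit() accepts optional overrides."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.fit_kwargs: Dict[str, Any] = {}

    def fit(
        self,
        train: List[float],
        loss: Optional[Callable] = None,
        metrics: Optional[List[Callable]] = None,
        optimizer: Any = None,
    ) -> None:
        # Record the overrides so the caller can inspect what each model received.
        self.fit_kwargs = {"loss": loss, "metrics": metrics, "optimizer": optimizer}


def toy_sweep(
    model_names: List[str],
    train: List[float],
    custom_fit_params: Optional[Dict[str, Any]] = None,
) -> List[ToyModel]:
    # Defaulting to None (not {}) avoids a shared mutable default and keeps
    # calls that omit the argument behaving exactly as before.
    fit_params = dict(custom_fit_params or {})
    models = []
    for name in model_names:
        model = ToyModel(name)
        # The same overrides are forwarded unchanged to every model in the sweep.
        model.fit(train, **fit_params)
        models.append(model)
    return models
```

Because the dictionary is unpacked as keyword arguments, a misspelled key surfaces immediately as a `TypeError` from `fit()` rather than being silently ignored.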
Example usage
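import torch
from torch import nn
from pytorch_tabular import model_sweep


class CustomLoss(nn.Module):
    """Scaled mean fourth-power error (penalizes large residuals heavily)."""

    def forward(self, inputs, targets):
        return 100 * torch.mean((inputs - targets) ** 4)


def custom_metric(y_hat, y):
    # Mean signed error (bias) between predictions and targets.
    return (y_hat - y).mean()


# train, val, data_config, optimizer_config and trainer_config are assumed
# to be defined beforehand, as in a regular TabularModel workflow.
sweep_df, best_model = model_sweep(
    task="regression",
    train=train,
    test=val,
    data_config=data_config,
    optimizer_config=optimizer_config,
    trainer_config=trainer_config,
    model_list="lite",
    # Forwarded to every model's fit() call in the sweep.
    custom_fit_params={
        "loss": CustomLoss(),
        "metrics": [custom_metric],
        "metrics_prob_inputs": [True],
        "optimizer": torch.optim.Adagrad,
    },
)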
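Since one malformed dictionary would make every model in the sweep fail, a pre-flight check before calling model_sweep can save a long run. The helper below is a hypothetical sketch, not part of pytorch_tabular. Its allow-list contains only the keys used in the example above (extend it if you forward other fit() arguments), and the length check assumes metrics_prob_inputs holds one flag per metric, as the TabularModel.fit() documentation describes; confirm this against the version you use.

```python
from typing import Any, Dict

# Hypothetical allow-list: only the keys used in the example above.
KNOWN_KEYS = {"loss", "metrics", "metrics_prob_inputs", "optimizer"}


def check_custom_fit_params(params: Dict[str, Any]) -> None:
    """Hypothetical pre-flight check for a custom_fit_params dictionary."""
    unknown = set(params) - KNOWN_KEYS
    if unknown:
        raise ValueError(f"unexpected keys: {sorted(unknown)}")
    metrics = params.get("metrics")
    prob_inputs = params.get("metrics_prob_inputs")
    # Assumption: metrics_prob_inputs is a parallel list, one flag per metric.
    if metrics is not None and prob_inputs is not None and len(metrics) != len(prob_inputs):
        raise ValueError(
            f"metrics has {len(metrics)} entries but "
            f"metrics_prob_inputs has {len(prob_inputs)}"
        )
```

Calling check_custom_fit_params on the dictionary before passing it to model_sweep fails fast, in seconds, instead of after the first model has started training.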
📚 Documentation preview 📚: https://pytorch-tabular--587.org.readthedocs.build/en/587/