
Regression Models Underperform Compared to Classification in Benchmark

Open DanilaEremenko opened this issue 4 months ago • 4 comments

Hello,

I'm using the following wrapper around models from your repository in my benchmark. When I compare the rankings on classification and regression datasets separately, I observe a significant drop in the performance of your models on the regression datasets relative to the other models in the benchmark, even though the raw metric values are not far from those of models like CatBoost.

As you can see from the code, I took into account that your implementation normalizes y for regression tasks but does not apply the inverse transform to the predictions. To address this, I apply the inverse normalization myself using the parameters stored in the model under self.method.y_info.

Could you please help me understand what might be causing this issue?

import dataclasses
import json
from pathlib import Path
from typing import List, Optional

import numpy as np
import pandas as pd
import torch
from optuna import Trial

import sys

from TALENT.model.lib.data import get_dataset
from TALENT.model.utils import get_deep_args, get_classical_args, set_seeds, get_method, merge_sampled_parameters, \
    sample_parameters


def init_lamda_ds(dsets_dir: Path, ds_name: str, x_train, y_train, x_val, y_val, x_test, y_test, c_mask,
                  task_type: str, n_classes: int):
    dsets_dir.mkdir(exist_ok=True, parents=True)
    assert dsets_dir.exists(), dsets_dir
    ds_dir = dsets_dir.joinpath(ds_name)

    ds_dir.mkdir(exist_ok=True, parents=True)

    info = {
        "name": ds_name,
        "n_num_features": int(len(c_mask) - sum(c_mask)),
        "n_cat_features": int(sum(c_mask)),
        "train_size": len(x_train),
        "val_size": len(x_val),
        "test_size": len(x_test),
        "source": None,
        "task_intro": None,
        "task_type": task_type,
        "openml_id": None,
        "n_classes": n_classes,
        "num_feature_intro": {f"x{i + 1}": f"x{i + 1}" for i, c in enumerate(c_mask) if not c},
        "cat_feature_intro": {f"x{i + 1}": f"x{i + 1}" for i, c in enumerate(c_mask) if c}
    }

    with open(ds_dir.joinpath('info.json'), 'w') as fp:
        json.dump(obj=info, fp=fp)

    for pref, dsets in {
        'N': (x_train[:, ~c_mask], x_val[:, ~c_mask], x_test[:, ~c_mask]),
        'C': (x_train[:, c_mask], x_val[:, c_mask], x_test[:, c_mask]),
        'y': (y_train, y_val, y_test)
    }.items():
        if len(dsets[0].flatten()):  # skip empty splits (e.g. no categorical features)
            np.save(file=ds_dir.joinpath(f'{pref}_train.npy'), arr=dsets[0])
            np.save(file=ds_dir.joinpath(f'{pref}_val.npy'), arr=dsets[1])
            np.save(file=ds_dir.joinpath(f'{pref}_test.npy'), arr=dsets[2])


def get_lamda_opt_and_config_trial(trial: Trial, model_type: str, opt_space: dict):
    config = {}
    # Depending on the model, n_bins lives under 'training' or 'fit' in opt_space.
    try:
        opt_space[model_type]['training']['n_bins'] = [
            "int",
            2,
            256
        ]
    except KeyError:
        opt_space[model_type]['fit']['n_bins'] = [
            "int",
            2,
            256
        ]
    sampled_parameters = sample_parameters(trial, opt_space[model_type], config)
    merge_sampled_parameters(config, sampled_parameters)

    if model_type in ['resnet']:
        config['model']['activation'] = 'relu'
        config['model']['normalization'] = 'batchnorm'

    if model_type in ['tabr']:
        config['model']["num_embeddings"].setdefault('type', 'PLREmbeddings')
        config['model']["num_embeddings"].setdefault('lite', True)
        config['model'].setdefault('d_multiplier', 2.0)
        config['model'].setdefault('mixer_normalization', 'auto')
        config['model'].setdefault('dropout1', 0.0)
        config['model'].setdefault('normalization', "LayerNorm")
        config['model'].setdefault('activation', "ReLU")

    if model_type in ['mlp_plr']:
        config['model']["num_embeddings"].setdefault('type', 'PLREmbeddings')
        config['model']["num_embeddings"].setdefault('lite', True)

    if model_type in ['modernNCA', 'tabm']:
        config['model']["num_embeddings"].setdefault('type', 'PLREmbeddings')
        config['model']["num_embeddings"].setdefault('lite', True)

    if model_type in ['tabm']:
        config['model']['backbone'].setdefault('type', 'MLP')
        config['model'].setdefault("arch_type", "tabm")
        config['model'].setdefault("k", 32)

    return config


@dataclasses.dataclass
class LamdaLow:
    model_type: str  # modernNCA, tabr
    cat_policy: str  # e.g. tabr_ohe, ohe, indices
    ds_name: str  # directories in LAMDA-TALENT/data
    lamda_path: str
    task_type: str
    n_classes: int
    inverse_norm_y: bool
    dsets_dir: Path
    model_dir: Path
    model_family: str
    hparams: Optional[dict] = None
    seed = 42

    def __post_init__(self):
        # get_deep_args()/get_classical_args() parse sys.argv, so build a
        # synthetic argv before calling them.
        sys.argv = []
        sys.argv.extend(['--tune'])
        sys.argv.extend(['--model_type', self.model_type])
        sys.argv.extend(['--cat_policy', self.cat_policy])
        sys.argv.extend(['--dataset', str(self.ds_name)])
        sys.argv.extend(['--dataset_path', str(self.dsets_dir)])
        sys.argv.extend(['--model_path', str(self.model_dir)])

        if self.model_family == 'deep':
            self.args, self.default_para, self.opt_space = get_deep_args()
        elif self.model_family == 'classical':
            self.args, self.default_para, self.opt_space = get_classical_args()
        else:
            raise ValueError(f"Undefined {self.model_family}")

    def fit(self, train_val_data, info):
        assert self.hparams is not None
        self.args.config = self.hparams
        self.args.seed = self.seed
        set_seeds(self.args.seed)

        self.method = get_method(self.args.model_type)(
            self.args,
            info['task_type'] == 'regression'
        )

        self._time_cost = self.method.fit(
            train_val_data,
            info,
            train=True,
            config=self.args.config
        )

    def predict(self, x_num, x_cat, info):
        # Dummy targets: method.predict expects labelled data so that it can
        # report its internal metrics; the values do not affect the predictions.
        test_data = (
            {'test': x_num},
            {'test': x_cat} if len(x_cat.flatten()) else None,
            {
                'test': np.random.randint(low=0, high=self.n_classes, size=len(x_num)) if self.task_type != 'regression'
                else np.random.random(size=len(x_num))
            }
        )
        if self.model_family == 'deep':
            _, vres, metric_name, y_pred = self.method.predict(
                test_data,
                info,
                model_name=self.args.evaluate_option
            )
            self.metrics_d = dict(zip(metric_name, vres, strict=True))
        elif self.model_family == 'classical':
            _, _, y_pred = self.method.predict(
                test_data,
                info,
                model_name=self.args.evaluate_option
            )
        else:
            raise ValueError(f"Undefined {self.model_family}")

        if self.inverse_norm_y and self.task_type == 'regression':
            # TALENT standardizes y for regression (mean_std policy); map the
            # predictions back to the original scale.
            assert self.method.y_info['policy'] == 'mean_std', self.method.y_info['policy']
            mean = self.method.y_info['mean']
            std = self.method.y_info['std']
            return y_pred * std + mean
        else:
            return y_pred


@dataclasses.dataclass
class LamdaHigh:
    model_type: str  # modernNCA, tabr
    cat_ids: List[int]
    tmp_ds_dir: Path
    talent_dir: Path
    task_type: str
    n_classes: int
    inverse_norm_y: bool
    model: Optional[LamdaLow] = None
    info: Optional[dict] = None
    hparams: Optional[dict] = None
    seed = 42

    @staticmethod
    def cast_arr(arr) -> np.ndarray:
        if isinstance(arr, pd.DataFrame):
            return arr.to_numpy().astype(np.float64)
        elif isinstance(arr, np.ndarray):
            return arr.astype(np.float64)
        else:
            raise ValueError(f"Undefined type  = {type(arr)}")

    def fit(self, x_train_df, y_train_df, x_val_df, y_val_df, actually_fit=True):
        x_train = self.cast_arr(x_train_df)
        y_train = self.cast_arr(y_train_df)
        x_val = self.cast_arr(x_val_df)
        y_val = self.cast_arr(y_val_df)

        c_mask = np.array([i in self.cat_ids for i in range(x_train.shape[1])])
        ds_name = 'tmp_lamda_ds'

        init_lamda_ds(
            dsets_dir=self.tmp_ds_dir,
            ds_name=ds_name,
            x_train=x_train, y_train=y_train,
            x_val=x_val, y_val=y_val,
            x_test=x_val, y_test=y_val,
            c_mask=c_mask,
            task_type=self.task_type,
            n_classes=self.n_classes
        )
        deep_models = ['modernNCA', 'tabr', 'tabnet', 'mlp', 'mlp_plr']
        if self.model_type in deep_models:
            model_family = 'deep'
            cat_policy = 'tabr_ohe'
        elif self.model_type in ['lightgbm', 'xgboost', 'RandomForest']:
            model_family = 'classical'
            cat_policy = 'ohe'
        elif self.model_type in ['catboost']:
            model_family = 'classical'
            cat_policy = 'indices'
        else:
            model_family = 'deep'
            cat_policy = 'indices'

        self.model = LamdaLow(
            model_type=self.model_type,
            cat_policy=cat_policy,
            model_dir=self.tmp_ds_dir.joinpath('tmp_lamda_model'),
            dsets_dir=self.tmp_ds_dir,
            ds_name=ds_name,
            lamda_path=str(self.talent_dir),
            task_type=self.task_type,
            n_classes=self.n_classes,
            hparams=self.hparams,
            inverse_norm_y=self.inverse_norm_y,
            model_family=model_family
        )
        train_val_data, test_data, self.info = get_dataset(self.model.args.dataset, self.model.args.dataset_path)

        if actually_fit:
            self.model.fit(train_val_data=train_val_data, info=self.info)
            if self.model_type in deep_models:
                self._best_epoch = self.model.method.trlog['best_epoch']

    def predict(self, X: pd.DataFrame):
        X = self.cast_arr(X)
        c_mask = np.array([i in self.cat_ids for i in range(X.shape[1])])
        y_pred = self.model.predict(x_num=X[:, ~c_mask], x_cat=X[:, c_mask], info=self.info)
        if isinstance(y_pred, torch.Tensor):
            return y_pred.detach().cpu().numpy()
        elif isinstance(y_pred, np.ndarray):
            return y_pred
        else:
            raise TypeError(f"Undefined type = {type(y_pred)}")

    def predict_proba(self, X: pd.DataFrame):
        logits = self.predict(X)
        if self.model_type in [
            'lightgbm', 'xgboost', 'catboost', 'RandomForest',
            'realmlp'
        ]:  # these models already return probabilities; renormalize rows just in case
            return logits / logits.sum(axis=1, keepdims=True)
        else:
            return torch.nn.functional.softmax(torch.tensor(logits), dim=1).detach().cpu().numpy()
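
For context, a typical call into this wrapper in my benchmark looks roughly like the following (an illustrative sketch only; paths, cat_ids, and hparams are placeholders, and the fit/predict calls are left commented out):

# Illustrative usage only; paths, cat_ids, and hparams are placeholders.
wrapper = LamdaHigh(
    model_type='modernNCA',
    cat_ids=[],                      # indices of categorical columns
    tmp_ds_dir=Path('tmp_lamda'),
    talent_dir=Path('LAMDA-TALENT'),
    task_type='regression',
    n_classes=1,                     # unused for regression
    inverse_norm_y=True,
    hparams=None,                    # replaced with Optuna-tuned parameters before fitting
)
# x_*/y_* below are pandas DataFrames prepared by the benchmark (not shown):
# wrapper.fit(x_train_df, y_train_df, x_val_df, y_val_df)
# y_pred = wrapper.predict(x_test_df)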


DanilaEremenko, Aug 18 '25

Hello,

Thank you for your detailed feedback and for using our toolbox.

Regarding the performance difference you observed: ModernNCA and similar models performing slightly worse than CatBoost on regression tasks aligns with our own observations (as shown in our regression benchmarks: regression result.png). CatBoost does indeed demonstrate strong performance on regression tasks.

A few points that might help:

  1. The primary contributor to ModernNCA is currently completing an industry internship, but he will be able to provide more specific improvement suggestions once he is back.
  2. From my personal experience, reducing dropout in regression tasks can sometimes improve performance (as discussed in this Kaggle thread). We have already implemented this approach in our codebase (see here), but further tuning might yield additional improvements (see the sketch below).
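
As an illustration only (not code from our repository), forcing dropout-style hyperparameters to zero in the config sampled by your wrapper could look roughly like this; the key names checked here are assumptions and vary by model:

def zero_out_dropout(config: dict) -> dict:
    # Set every dropout-like hyperparameter in the sampled model config to 0.
    # Key names such as 'dropout', 'dropout1', 'residual_dropout' are assumed
    # examples, not canonical TALENT names.
    model_cfg = config.get('model', {})
    for key in list(model_cfg):
        if 'dropout' in key.lower() and isinstance(model_cfg[key], (int, float)):
            model_cfg[key] = 0.0
    return config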

We appreciate your interest in our work. Please feel free to share any additional findings or questions you might have.

Best regards

6sy666, Aug 18 '25

@6sy666

I am reporting a significant drop in performance on the regression tasks, not a slight one.

I am including the model vs. model game matrices on the test sets for the classification task (Figure 1) and the regression task (Figure 2).

I initially thought the issue might be related to standardization. However, as you can see, the problem does not occur in the regression task when using realmlp (also taken from your repository).

When I analyze the models' metrics on the training sets (Figure 3), the MNCA model shows the strongest tendency to overfit and to underperform on the test set. This issue is not observed with the other models, including realmlp or tabm.

The datasets were also taken from your repository – they are relatively small, with a total number of samples not exceeding 5,000.

Could you please help me understand what might be causing this behavior in the MNCA model? Any insights or suggestions would be greatly appreciated.

Thank you for your time and for maintaining this repository.

[Attached images: the game matrices for classification (Figure 1) and regression (Figure 2), and the training-set metrics (Figure 3)]

DanilaEremenko, Nov 18 '25

Hi,

Thank you for your feedback and for sharing the results across the different methods! These results are unexpected, as we did not observe ModernNCA performing significantly worse than other methods on regression datasets during our own experiments.

Could you please provide the names of the specific datasets in TALENT on which ModernNCA significantly underperformed? Could you also share your general experimental settings (e.g., whether you used default parameters, and how many hyperparameter search trials you ran)? It would be even better if you could share the specific result metrics for the different methods so that we can try to reproduce and verify the issue you described.

If you have observed any other phenomena or have further questions, please feel free to reach out.

Best regards.

Yinhuaihong, Nov 22 '25

@Yinhuaihong

Hi,

I've created a repository to reproduce the issue (LAMDA-TALENT-MNCA-REG-ISSUE). The main experiment uses nested cross-validation with 5 inner and 2/5 outer folds, and model hyperparameters are optimized with Optuna over 100 trials. The datasets are specified in the main_regression_issue.py script, and the distributions of the test metrics can be found in main_check_results.ipynb. Before running the script, you need to clone TALENT into the working directory.
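
For reference, the protocol is roughly the following simplified sketch (not the exact code from that repository; it assumes numpy arrays, scikit-learn's KFold, and a hypothetical build_model factory that constructs a LamdaHigh instance from a trial's sampled hyperparameters):

import numpy as np
import optuna
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import KFold


def nested_cv_rmse(X, y, build_model, n_outer=5, n_inner=5, n_trials=100, seed=42):
    # Outer folds estimate test performance; inner folds drive the Optuna search.
    outer = KFold(n_splits=n_outer, shuffle=True, random_state=seed)
    scores = []
    for train_idx, test_idx in outer.split(X):
        X_tr, y_tr = X[train_idx], y[train_idx]
        X_te, y_te = X[test_idx], y[test_idx]

        def objective(trial):
            inner = KFold(n_splits=n_inner, shuffle=True, random_state=seed)
            fold_rmse = []
            for fit_idx, val_idx in inner.split(X_tr):
                model = build_model(trial)  # hypothetical factory wrapping LamdaHigh
                model.fit(X_tr[fit_idx], y_tr[fit_idx], X_tr[val_idx], y_tr[val_idx])
                pred = model.predict(X_tr[val_idx])
                fold_rmse.append(mean_squared_error(y_tr[val_idx], pred) ** 0.5)
            return float(np.mean(fold_rmse))

        study = optuna.create_study(direction='minimize')
        study.optimize(objective, n_trials=n_trials)

        # Refit with the best hyperparameters on the outer training split.
        fit_idx, val_idx = next(iter(KFold(n_splits=n_inner, shuffle=True,
                                           random_state=seed).split(X_tr)))
        best_model = build_model(optuna.trial.FixedTrial(study.best_params))
        best_model.fit(X_tr[fit_idx], y_tr[fit_idx], X_tr[val_idx], y_tr[val_idx])
        scores.append(mean_squared_error(y_te, best_model.predict(X_te)) ** 0.5)
    return scores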

Kind regards.

DanilaEremenko, Nov 22 '25

This might have to do with Issue 89

andreasgoethals, Dec 12 '25

@andreasgoethals I'm not sure; as you can see, I already apply the inverse normalization for regression inside LamdaLow.predict(...).

DanilaEremenko, Dec 13 '25

@Yinhuaihong @6sy666 This problem also occurs on classification tasks when log loss is used both for early stopping and as the test metric for the comparison.
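
For example, the comparison metric in that case is computed roughly like this (a sketch assuming scikit-learn's log_loss and a fitted LamdaHigh classifier, not the exact benchmark code):

from sklearn.metrics import log_loss

def test_log_loss(model, x_test_df, y_test) -> float:
    # model: a fitted LamdaHigh classification wrapper (see the code above);
    # probabilities come from its predict_proba method.
    return log_loss(y_test, model.predict_proba(x_test_df))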

DanilaEremenko, Dec 23 '25