m2cgen

add code for multioutput regression

Open AaronDavidSchneider opened this issue 1 year ago • 2 comments

This PR adds support for multi-output regression in XGBoost. The same approach could work for other booster classes as well. However, there is no general attribute across booster classes that reports the number of targets. That's why this solution is restricted to XGBoost.
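To illustrate what multi-output support means for a boosted tree ensemble, here is a minimal pure-Python sketch. It assumes the same interleaved layout XGBoost uses for multiclass models (one tree per target per boosting round, so tree i contributes to target i % n_targets) and a single base_score shared by all targets; both are assumptions for illustration, not a description of this PR's code. The names predict_multi_output and toy_trees are hypothetical, and each "tree" is stubbed as a plain function.

```python
from typing import Callable, List, Sequence

# A "tree" here is any function mapping a feature vector to a leaf value.
Tree = Callable[[Sequence[float]], float]


def predict_multi_output(
    trees: List[Tree],
    n_targets: int,
    base_score: float,
    x: Sequence[float],
) -> List[float]:
    """Sum tree outputs per target, assuming tree i belongs to target i % n_targets."""
    if len(trees) % n_targets != 0:
        raise ValueError("number of trees must be a multiple of n_targets")
    # Every target starts from the (assumed shared) base score.
    outputs = [base_score] * n_targets
    for i, tree in enumerate(trees):
        outputs[i % n_targets] += tree(x)
    return outputs


# Two boosting rounds for two targets, ordered
# (round 0 / target 0, round 0 / target 1, round 1 / target 0, round 1 / target 1).
toy_trees = [
    lambda x: 1.0 if x[0] > 0.5 else -1.0,
    lambda x: 0.25,
    lambda x: 0.5,
    lambda x: -0.25 if x[1] > 0.0 else 0.75,
]

print(predict_multi_output(toy_trees, n_targets=2, base_score=0.5, x=[0.7, 1.0]))  # [2.0, 0.5]
```

The only difference from a single-output model is that the per-tree sums are accumulated into one slot per target instead of a single scalar, which is why the exporter needs to know the number of targets.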

Closes #559

AaronDavidSchneider, Feb 27 '23 16:02

I have also added some tests. Here is another simple way to verify that it works:

import numpy as np
import m2cgen
from sklearn import datasets
from xgboost import XGBRegressor

n_targets = 3
n_features = 3
n_test = 20
n_train = 20

# Small synthetic multi-output regression problem
X, y = datasets.make_regression(n_targets=n_targets, n_features=n_features, n_samples=n_train, random_state=1)

multi_output_model_params = {
    'n_estimators': 3,
    'max_depth': 2
}

model = XGBRegressor(**multi_output_model_params).fit(X, y)

# Export the fitted model to pure Python and save it as an importable module
code = m2cgen.export_to_python(model, function_name='predict')
with open('test_file.py', 'w') as f:
    f.write(code)

from test_file import predict

# Random test inputs: one row per sample, one column per feature
inputs = np.random.random((n_test, n_features))

closeness = []
for input_i in inputs:
    hardcoded = predict(input_i)  # prediction from the generated code
    apipred = model.predict(input_i.reshape((1, n_features)))  # prediction from the XGBoost API
    closeness.append(np.allclose(apipred, hardcoded))

print(f'All close: {np.all(closeness)}, fraction of close: {np.sum(closeness)/n_test}')

This works well for me.
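The check above relies on np.allclose, which treats two values as close when |a - b| <= atol + rtol * |b| (defaults rtol=1e-05, atol=1e-08). The following dependency-free sketch of the same per-sample criterion may help when reading the "fraction of close" figure; the helper names allclose and fraction_close are hypothetical, and unlike NumPy they do not broadcast or special-case NaN.

```python
from typing import Sequence


def allclose(a: Sequence[float], b: Sequence[float], rtol: float = 1e-05, atol: float = 1e-08) -> bool:
    """Elementwise |a - b| <= atol + rtol * |b|, the criterion documented for numpy.allclose."""
    if len(a) != len(b):
        raise ValueError("length mismatch")
    return all(abs(x - y) <= atol + rtol * abs(y) for x, y in zip(a, b))


def fraction_close(
    preds_a: Sequence[Sequence[float]],
    preds_b: Sequence[Sequence[float]],
    rtol: float = 1e-05,
    atol: float = 1e-08,
) -> float:
    """Fraction of samples whose multi-output predictions are all close."""
    if len(preds_a) != len(preds_b):
        raise ValueError("different number of samples")
    hits = sum(allclose(a, b, rtol, atol) for a, b in zip(preds_a, preds_b))
    return hits / len(preds_a)


reference = [[1.0, 2.0], [3.0, 4.0]]
generated = [[1.0, 2.0000001], [3.0, 4.1]]
print(fraction_close(generated, reference))  # 0.5
```

Note that the tolerance is scaled by the second argument, so the criterion is not symmetric; passing the API prediction first and the generated prediction second, as in the script above, compares against the generated values.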

I also ran the make pre-pr command, which passed all tests except two on my MacBook. I don't know why those two tests failed, but I suspect it is related to different versions of Python packages being installed on my machine and on the test machine (the exact same failures also occur on the master branch).

Unfortunately, the docker command also fails on my machine. The failure does not seem to be caused by my changes but rather by the docker build command itself.

I hope that the CI will give more insight!

AaronDavidSchneider, Feb 28 '23 08:02

Kindly pinging @izeigerman and @StrikerRUS

AaronDavidSchneider, Mar 02 '23 08:03