AMPL

Inference from a multiclass classification model

Open · mmagithub opened this issue 1 year ago · 1 comment

Hi,

I am wondering how I can use `predict_from_model` to make predictions on unknown molecules with a pre-trained multi-class classification model (6 classes) saved as a tar.gz file.

I have tried to make the prediction in the standard way:

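```python
from atomsci.ddm.pipeline import predict_from_model as pfm

# mfile_NN is the path to the .tar.gz model; input_df, smiles_col and
# response_col are set up earlier in the script.
pred_df = pfm.predict_from_model_file(model_path=mfile_NN, dont_standardize=True,
                                      input_df=input_df, smiles_col=smiles_col,
                                      response_col=response_col)
```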

as I do with binary models, but for multi-class models, I keep getting this error:

```
Traceback (most recent call last):
  File "~/make_predictions_atomsci_models.py", line 35, in <module>
    response_col = response_col)
  File "/AMPL/atomsci/ddm/pipeline/predict_from_model.py", line 111, in predict_from_model_file
    pipe = mp.create_prediction_pipeline_from_file(pred_params, reload_dir=None, model_path=model_path)
  File "/AMPL/atomsci/ddm/pipeline/model_pipeline.py", line 1342, in create_prediction_pipeline_from_file
    pipeline.model_wrapper.reload_model(model_dir)
  File "/AMPL/atomsci/ddm/pipeline/model_wrapper.py", line 2378, in reload_model
    self.model.restore(best_chkpt, reload_dir)
  File "/miniconda3/envs/atomsci/lib/python3.7/site-packages/deepchem/models/torch_models/torch_model.py", line 1038, in restore
    self.model.load_state_dict(data['model_state_dict'])
  File "/miniconda3/envs/atomsci/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1052, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for PytorchImpl:
	size mismatch for output_layer.weight: copying a param with shape torch.Size([6, 512]) from checkpoint, the shape in current model is torch.Size([2, 512]).
	size mismatch for output_layer.bias: copying a param with shape torch.Size([6]) from checkpoint, the shape in current model is torch.Size([2]).
```

Any clues?

Thanks, M

mmagithub · Jun 19 '23 00:06

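The `class_number` parameter was never saved in the model metadata, so when the model is reloaded it is rebuilt with 2 output classes instead of 6. That is why the checkpoint's output layer (`torch.Size([6, 512])`) does not fit the freshly built one (`torch.Size([2, 512])`). You have two options to work around this bug.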
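Before applying either workaround, you can confirm that a given model is affected by checking whether `class_number` appears in its metadata. Below is a minimal sketch using only Python's standard `tarfile` and `json` modules. `read_model_parameters` is a hypothetical helper, not part of AMPL, and it assumes the archive contains a single `model_metadata.json` with a top-level `model_parameters` object.

```python
import json
import posixpath
import tarfile


def read_model_parameters(tar_path):
    """Return the "model_parameters" dict from model_metadata.json in a model tarball.

    Hypothetical helper (not an AMPL API). Assumes exactly one member whose
    basename is model_metadata.json, at the top level or in a subdirectory.
    """
    with tarfile.open(tar_path, "r:gz") as tar:
        for member in tar.getmembers():
            if member.isfile() and posixpath.basename(member.name) == "model_metadata.json":
                # json.load accepts the binary stream returned by extractfile
                metadata = json.load(tar.extractfile(member))
                return metadata["model_parameters"]
    raise FileNotFoundError(f"model_metadata.json not found in {tar_path}")
```

For the model in this issue, `"class_number" in read_model_parameters(mfile_NN)` would presumably be `False`, which is consistent with the 6-vs-2 mismatch in the traceback.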

First, you can try the fix in this commit: https://github.com/ATOMScience-org/AMPL/commit/a7fb0d6c77cdcdbfbf06ddb5a843a2a08592f0e9

Alternatively, you can un-tar the model, manually edit `model_metadata.json`, re-tar the model, and run predictions again. Make the following change:

```json
"model_parameters": {
    "ampl_version": "1.6.0",
    "class_number": 6,
    "featurizer": "ecfp",
    "hyperparam_uuid": null,
    "model_bucket": "public",
    "model_choice_score_type": "roc_auc",
    "model_type": "NN",
    "num_model_tasks": 1,
    "prediction_type": "classification",
    ...
```

Note the added line `"class_number": 6,`.
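If you'd rather not un-tar and re-tar by hand, the same edit can be scripted with Python's standard `tarfile` and `json` modules. This is a minimal sketch, not part of AMPL: `patch_class_number` is a hypothetical helper, and it assumes the archive holds one `model_metadata.json` with a top-level `model_parameters` object. It copies every other member unchanged, so the archive layout AMPL expects is preserved.

```python
import io
import json
import posixpath
import tarfile


def patch_class_number(src_path, dst_path, class_number):
    """Write a copy of the model tarball src_path to dst_path with
    model_parameters["class_number"] set to class_number.

    Hypothetical helper (not an AMPL API). All other members are copied as-is.
    """
    patched = False
    with tarfile.open(src_path, "r:gz") as src, tarfile.open(dst_path, "w:gz") as dst:
        for member in src.getmembers():
            if not member.isfile():
                # Directories, links, etc. have no data stream to copy.
                dst.addfile(member)
                continue
            data = src.extractfile(member)
            if posixpath.basename(member.name) == "model_metadata.json":
                metadata = json.load(data)
                metadata["model_parameters"]["class_number"] = class_number
                payload = json.dumps(metadata, indent=4).encode("utf-8")
                member.size = len(payload)  # header size must match the new content
                dst.addfile(member, io.BytesIO(payload))
                patched = True
            else:
                dst.addfile(member, data)
    if not patched:
        raise FileNotFoundError(f"model_metadata.json not found in {src_path}")
```

Write to a new path rather than overwriting the original (reading and writing the same archive at once would corrupt it), e.g. `patch_class_number(mfile_NN, "model_fixed.tar.gz", 6)`, then pass the patched file as `model_path` to `predict_from_model_file`.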

stewarthe6 · Jun 22 '23 17:06