matbench-discovery icon indicating copy to clipboard operation
matbench-discovery copied to clipboard

Simplified user interface

Open pbenner opened this issue 1 year ago • 1 comment

What about providing a simplified user interface for training and testing models? This would be a simple example:

import pandas as pd

from matbench_discovery.data import DATA_FILES, df_wbm
from pymatgen.core import Structure
from sklearn.metrics import r2_score

class MatbenchDiscovery:
    """Convenience wrapper for loading train/test data and scoring predictions.

    The task type selects which structure column is used as test input:
    ``"IS2RE"`` reads ``initial_structure`` and ``"RS2RE"`` reads
    ``relaxed_structure``. Targets and reference energies are taken from the
    module-level ``df_wbm`` data frame.
    """

    # Task types accepted by the constructor.
    _TASK_TYPES = ("IS2RE", "RS2RE")
    # Index column shared by the JSON/CSV files read below.
    _ID_COL = "material_id"
    # df_wbm column used as ground truth for test data and evaluation.
    _TEST_TARGET_COL = "e_form_per_atom_mp2020_corrected"

    def __init__(self, task_type="IS2RE"):
        """Store the task type after validating it.

        Args:
            task_type: Either ``"IS2RE"`` or ``"RS2RE"``.

        Raises:
            ValueError: If ``task_type`` is not one of the supported values.
        """
        if task_type not in self._TASK_TYPES:
            raise ValueError(f'Invalid task_type {task_type}')
        self.task_type = task_type

    def get_test_data(self):
        """Load test structures and their target energies.

        Returns:
            A ``(X, y)`` tuple of ``pd.Series``. ``y`` is the
            ``e_form_per_atom_mp2020_corrected`` column of ``df_wbm``; ``X``
            holds ``Structure`` objects selected by (and ordered like)
            ``y``'s index.
        """
        if self.task_type == "IS2RE":
            input_col = "initial_structure"
            data_path = DATA_FILES.wbm_initial_structures
        else:
            # NOTE(review): presumes the computed-structure-entries file has a
            # "relaxed_structure" column holding Structure dicts - confirm.
            input_col = "relaxed_structure"
            data_path = DATA_FILES.wbm_computed_structure_entries

        df_in = pd.read_json(data_path).set_index(self._ID_COL)

        X = pd.Series(
            [Structure.from_dict(struct_dict) for struct_dict in df_in[input_col]],
            index=df_in.index,
        )
        y = pd.Series(df_wbm[self._TEST_TARGET_COL])

        return X[y.index], y

    def get_train_data(self):
        """Load training structures and their formation energies.

        Returns:
            A ``(X, y)`` tuple of ``pd.Series``. ``y`` is the
            ``formation_energy_per_atom`` column of the energies CSV; ``X``
            holds ``Structure`` objects built from the ``structure`` field of
            each entry, selected by ``y``'s index.

        Raises:
            NotImplementedError: If the task type is not ``"IS2RE"``.
        """
        # Raise explicitly instead of asserting: asserts vanish under
        # `python -O`, which would silently hand back data for the wrong task.
        if self.task_type != "IS2RE":
            raise NotImplementedError(
                f"get_train_data is not implemented for task_type {self.task_type}"
            )

        target_col = "formation_energy_per_atom"
        input_col = "structure"

        df_cse = pd.read_json(DATA_FILES.mp_computed_structure_entries).set_index(
            self._ID_COL
        )
        df_eng = pd.read_csv(DATA_FILES.mp_energies).set_index(self._ID_COL)

        X = pd.Series(
            [Structure.from_dict(cse[input_col]) for cse in df_cse.entry],
            index=df_cse.index,
        )
        y = pd.Series(df_eng[target_col], index=df_eng.index)

        return X[y.index], y

    def evaluate_predictions(self, y_pred, apply_correction=False):
        """Score predictions against the reference energies in ``df_wbm``.

        Args:
            y_pred: ``pd.Series`` of predictions indexed by material id.
                NaN entries are dropped before scoring.
            apply_correction: If true, subtract
                ``e_correction_per_atom_mp_legacy`` and add
                ``e_correction_per_atom_mp2020`` (both from ``df_wbm``) to the
                predictions before scoring.

        Returns:
            A dict with keys ``'mae'``, ``'r2'``, ``'y_true'`` and ``'y_pred'``
            (the latter two being the aligned series actually scored).

        Raises:
            TypeError: If ``y_pred`` is not a ``pd.Series``.
        """
        if not isinstance(y_pred, pd.Series):
            raise TypeError(
                f"y_pred must be a pandas Series, got {type(y_pred).__name__}"
            )

        # dropna returns a new Series, so the caller's object is never mutated.
        y_pred = y_pred.dropna()
        y_true = df_wbm[self._TEST_TARGET_COL][y_pred.index]

        if apply_correction:
            # Align both correction columns to the prediction index explicitly
            # rather than relying on in-place arithmetic against all of df_wbm.
            legacy = df_wbm.e_correction_per_atom_mp_legacy.reindex(y_pred.index)
            mp2020 = df_wbm.e_correction_per_atom_mp2020.reindex(y_pred.index)
            y_pred = y_pred - legacy + mp2020

        mae = (y_true - y_pred).abs().mean()
        r2 = r2_score(y_true, y_pred)

        return {'mae': mae, 'r2': r2, 'y_true': y_true, 'y_pred': y_pred}

pbenner avatar Jul 11 '23 11:07 pbenner