
[TODO]: v2 interpret functionality

Open kevingreenman opened this issue 1 year ago • 3 comments

In the MLPDS steering committee meeting today, two companies said that they regularly use the interpret function and get good results from it, so they would like it to be carried over to v2. I think it would be appropriate to add this in v2.1 or v2.2. I'll set it as v2.1 for now, but we can move it to v2.2 if needed.

kevingreenman, Jan 30 '24 14:01

Great to see you at MLPDS today! Thanks for these notes. To document some of my feedback from using the interpret.py script: by default it seems to save only one substructure. It also prioritizes the smallest substructure, which means the result can be sensitive to whatever min_atoms the user passed. If I later decide that slightly larger substructures with higher explainability scores are acceptable, I have to re-run the script, which can take some time. If possible, could we update the functionality so that the top N substructures are saved? Based on some local testing, I believe a short addition like the following should work. I just hardcoded N here.

    property_name = header[args.property_id] if len(header) > args.property_id else 'score'
    csv_contents = f'smiles,{property_name},rationale,rationale_score'
    csv_contents_topN = f'smiles,{property_name}'
    N = 10
    for i in range(1, N + 1):
        csv_contents_topN += f',rationale_{i},rationale_{i}_score'
    print(f'smiles,{property_name},rationale,rationale_score')

    for smiles in all_smiles:
        print('*'*88)
        print(smiles)
        score = scoring_function([smiles])[0]
        if score > args.prop_delta:
            rationales = mcts(
                smiles=smiles[0],
                scoring_function=scoring_function,
                n_rollout=args.rollout,
                max_atoms=args.max_atoms,
                prop_delta=args.prop_delta
            )
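        else:
            rationales = []

        # Top-N file: always record the molecule, then up to N rationales by score.
        csv_contents_topN += f'\n{smiles},{score:.3f}'
        rationales_sorted = sorted(rationales, key=lambda x: x.P, reverse=True)
        for i, x in enumerate(rationales_sorted):
            if i < N:
                csv_contents_topN += f',{x.smiles},{x.P:.3f}'
            print(f'{x.smiles}, {x.P:.3f}')
        print('\n')

        if len(rationales) == 0:
            # No rationale found: keep the row and leave the rationale columns empty.
            print(f'{smiles},{score:.3f},,')
            csv_contents += f'\n{smiles},{score:.3f},,'
        else:
            # Original behavior: best-scoring rationale among the smallest ones.
            min_size = min(len(x.atoms) for x in rationales)
            min_rationales = [x for x in rationales if len(x.atoms) == min_size]
            rats = sorted(min_rationales, key=lambda x: x.P, reverse=True)
            print(f'{smiles},{score:.3f},{rats[0].smiles},{rats[0].P:.3f}')
            csv_contents += f'\n{smiles},{score:.3f},{rats[0].smiles},{rats[0].P:.3f}'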
    
    with open('MCTS_rationales.csv', 'w') as f:
        f.write(csv_contents)

    with open(f'MCTS_rationales_top{N}.csv', 'w') as f:
        f.write(csv_contents_topN)
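
For reference, here is a self-contained sketch of the same top-N idea, separated from the rest of interpret.py. It is not the Chemprop implementation: the `Rationale` dataclass is a hypothetical stand-in for the objects returned by `mcts()` (only the `smiles` and `P` attributes used above are assumed), and `top_n_header`, `top_n_row`, and `write_top_n` are illustrative names. It differs from the snippet above in two small ways: rows with fewer than N rationales are padded with empty cells so every row has the same number of columns, and the standard-library `csv` module handles the writing instead of manual string concatenation.

```python
import csv
import io
from dataclasses import dataclass
from typing import Iterable, List, Sequence, TextIO


@dataclass
class Rationale:
    """Hypothetical stand-in for a rationale returned by mcts()."""
    smiles: str
    P: float  # rationale score, as in the snippet above


def top_n_header(property_name: str, n: int) -> List[str]:
    """Column names: smiles, property, then n (rationale, score) pairs."""
    header = ["smiles", property_name]
    for i in range(1, n + 1):
        header += [f"rationale_{i}", f"rationale_{i}_score"]
    return header


def top_n_row(smiles: str, score: float,
              rationales: Sequence[Rationale], n: int) -> List[str]:
    """One CSV row holding the n highest-scoring rationales, padded if fewer."""
    best = sorted(rationales, key=lambda r: r.P, reverse=True)[:n]
    row = [smiles, f"{score:.3f}"]
    for r in best:
        row += [r.smiles, f"{r.P:.3f}"]
    # Pad with empty cells so every row has 2 + 2 * n columns.
    row += [""] * (2 * (n - len(best)))
    return row


def write_top_n(rows: Iterable[List[str]], property_name: str,
                n: int, handle: TextIO) -> None:
    """Write the header and all rows to an open text handle."""
    writer = csv.writer(handle)
    writer.writerow(top_n_header(property_name, n))
    writer.writerows(rows)


# Small demonstration with made-up rationales and scores.
demo_rationales = [
    Rationale("c1ccccc1", 0.61),
    Rationale("CO", 0.87),
    Rationale("CCO", 0.72),
]
demo_rows = [
    top_n_row("OCCc1ccccc1", 0.9, demo_rationales, n=2),
    top_n_row("CCC", 0.1, [], n=2),  # molecule with no rationale
]
buffer = io.StringIO()
write_top_n(demo_rows, "score", 2, buffer)
print(buffer.getvalue())
```

In the real script, `handle` would be a file opened with `open(path, 'w', newline='')`, and N could be exposed as a command-line argument rather than hardcoded.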

As a longer-term goal, it would be awesome if we could add ablation (e.g., removing atom symbol and maybe MW too) so we could easily see atomic contributions to the prediction. Happy to hear additional brainstorming! Thanks again!

kspieks, Apr 17 '24 22:04

@kspieks thanks for the input! We'll definitely consider this as we implement interpretation in v2. I don't think many people on the dev team have used the interpretation functionality in a while, so this type of input is very valuable.

kevingreenman, Apr 18 '24 03:04

Related https://github.com/chemprop/chemprop/pull/923

oscarwumit, Jun 18 '24 18:06