chemprop
[TODO]: v2 interpret functionality
In the MLPDS steering committee meeting today, two companies said that they regularly use the interpret function with good results, so they would like it to be included in v2. I think this would be appropriate to add in v2.1 or v2.2. I'll set it as v2.1 for now, but we can move it to v2.2 if needed.
Great to see you at MLPDS today, and thanks for these notes! To document some of my feedback from using the `interpret.py` script: the default behavior seems to save only one substructure per molecule. It also prioritizes the smallest substructure, which means the result can be sensitive to whatever `min_atoms` value the user passed. If I later decide that slightly larger substructures with higher explainability scores are acceptable, I have to re-run the script, which can take some time. If possible, could we update the functionality so that the top N substructures are saved? Based on some local testing, I believe a short addition like this should work. I just hardcoded `N` here.
```python
# Replaces the output loop in interpret(); `args`, `header`, `all_smiles`,
# `scoring_function` and `mcts` come from the surrounding interpret() code.
property_name = header[args.property_id] if len(header) > args.property_id else 'score'

# Existing output: the best of the smallest rationales for each molecule
csv_contents = f'smiles,{property_name},rationale,rationale_score'

# New output: the top-N rationales for each molecule, ranked by score regardless of size
N = 10  # hardcoded for now
csv_contents_topN = f'smiles,{property_name}'
for i in range(1, N + 1):
    csv_contents_topN += f',rationale_{i},rationale_{i}_score'

print(f'smiles,{property_name},rationale,rationale_score')
for smiles in all_smiles:
    print('*' * 88)
    print(smiles)
    score = scoring_function([smiles])[0]
    if score > args.prop_delta:
        rationales = mcts(
            smiles=smiles[0],
            scoring_function=scoring_function,
            n_rollout=args.rollout,
            max_atoms=args.max_atoms,
            prop_delta=args.prop_delta
        )
    else:
        rationales = []

    # Write up to N rationales (highest score first) and print all of them
    csv_contents_topN += f'\n{smiles},{score:.3f}'
    rationales_sorted = sorted(rationales, key=lambda x: x.P, reverse=True)
    for i, x in enumerate(rationales_sorted):
        if i < N:
            csv_contents_topN += f',{x.smiles},{x.P:.3f}'
        print(f'{x.smiles}, {x.P:.3f}')
    print('\n')

    if len(rationales) == 0:
        # min() below would raise ValueError on an empty list, so leave the columns blank
        print(f'{smiles},{score:.3f},,')
        csv_contents += f'\n{smiles},{score:.3f},,'
    else:
        # Existing behavior: highest-scoring rationale among the smallest ones
        min_size = min(len(x.atoms) for x in rationales)
        min_rationales = [x for x in rationales if len(x.atoms) == min_size]
        rats = sorted(min_rationales, key=lambda x: x.P, reverse=True)
        print(f'{smiles},{score:.3f},{rats[0].smiles},{rats[0].P:.3f}')
        csv_contents += f'\n{smiles},{score:.3f},{rats[0].smiles},{rats[0].P:.3f}'

with open('MCTS_rationales.csv', 'w') as f:
    f.write(csv_contents)
with open(f'MCTS_rationales_top{N}.csv', 'w') as f:
    f.write(csv_contents_topN)
```
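To make the difference between the two selection rules concrete, here is a small self-contained sketch in plain Python. `Rationale` is a hypothetical stand-in for the MCTS nodes returned by `mcts()` (it only mirrors the `.smiles`, `.atoms` and `.P` attributes used in the snippet), and the SMILES and scores are made up for illustration. `best_with_size_limit` is also hypothetical; it shows how a larger size cap could be applied afterwards to saved top-N results without re-running MCTS.

```python
from dataclasses import dataclass
from typing import FrozenSet, List, Optional


@dataclass(frozen=True)
class Rationale:
    """Hypothetical stand-in for an MCTS node: rationale SMILES, atom indices, score P."""
    smiles: str
    atoms: FrozenSet[int]
    P: float


def smallest_then_best(rationales: List[Rationale]) -> Optional[Rationale]:
    """Current rule: keep only the smallest rationales, then take the highest score."""
    if not rationales:
        return None
    min_size = min(len(r.atoms) for r in rationales)
    smallest = [r for r in rationales if len(r.atoms) == min_size]
    return max(smallest, key=lambda r: r.P)


def top_n(rationales: List[Rationale], n: int) -> List[Rationale]:
    """Proposed rule: keep the n highest-scoring rationales, regardless of size."""
    return sorted(rationales, key=lambda r: r.P, reverse=True)[:n]


def best_with_size_limit(saved: List[Rationale], max_size: int) -> Optional[Rationale]:
    """Re-select from already-saved rationales, allowing up to max_size atoms."""
    candidates = [r for r in saved if len(r.atoms) <= max_size]
    return max(candidates, key=lambda r: r.P) if candidates else None


# Made-up rationales for one molecule (scores are illustrative only)
rs = [
    Rationale('c1ccccc1', frozenset(range(6)), 0.55),
    Rationale('Oc1ccccc1', frozenset(range(7)), 0.81),
    Rationale('OCc1ccccc1', frozenset(range(8)), 0.74),
    Rationale('OCc1ccc(O)cc1', frozenset(range(9)), 0.90),
]
print(smallest_then_best(rs).smiles)                 # c1ccccc1
print([r.smiles for r in top_n(rs, 2)])              # ['OCc1ccc(O)cc1', 'Oc1ccccc1']
print(best_with_size_limit(top_n(rs, 3), 8).smiles)  # Oc1ccccc1
```

The smallest-first rule returns the 6-atom ring even though it has the lowest score, while the saved top-N list lets a 7- or 8-atom rationale be chosen later. Re-selection can only pick among rationales that made it into the saved top N, so a larger N leaves more room for this.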
As a longer-term goal, it would be awesome if we could add ablation (e.g., removing atom symbol and maybe MW too) so we could easily see atomic contributions to the prediction. Happy to hear additional brainstorming! Thanks again!
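One way to read the ablation idea is occlusion-style feature ablation: for each atom, zero out a chosen group of its input features (for example an atom-type one-hot and a mass entry) and record how much the prediction drops. The sketch below is a generic illustration in plain Python, not an existing chemprop API; `atom_ablation_scores`, the `predict` callable, the toy linear model and the three-column feature layout are all hypothetical.

```python
from typing import Callable, List, Sequence

AtomFeatures = List[List[float]]  # one feature vector per atom


def atom_ablation_scores(
    atom_features: AtomFeatures,
    predict: Callable[[AtomFeatures], float],
    ablate_idx: Sequence[int],
) -> List[float]:
    """For each atom, zero the feature columns in ablate_idx and return
    baseline minus ablated prediction (positive = the atom raised the prediction)."""
    baseline = predict(atom_features)
    scores = []
    for i in range(len(atom_features)):
        masked = [row[:] for row in atom_features]  # copy so the input is untouched
        for j in ablate_idx:
            masked[i][j] = 0.0
        scores.append(baseline - predict(masked))
    return scores


# Hypothetical feature layout: [is_C, is_O, atomic_mass]
features = [[1.0, 0.0, 12.0], [0.0, 1.0, 16.0]]
weights = [1.0, 2.0, 0.1]


def toy_predict(feats: AtomFeatures) -> float:
    """Toy sum-pooled linear model standing in for a trained network."""
    return sum(sum(w * x for w, x in zip(weights, row)) for row in feats)


print(atom_ablation_scores(features, toy_predict, ablate_idx=[0, 1, 2]))  # approximately [2.2, 3.6]
print(atom_ablation_scores(features, toy_predict, ablate_idx=[2]))        # approximately [1.2, 1.6]
```

With this toy linear model the scores simply recover each atom's own term. In a message-passing network, zeroing one atom's features also changes the messages its neighbors receive, so per-atom scores would mix in some neighborhood effects.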
@kspieks thanks for the input! We'll definitely consider this as we implement interpretation in v2. I don't think many people on the dev team have used the interpretation functionality in a while, so this type of input is very valuable.
Related: https://github.com/chemprop/chemprop/pull/923