alignn
alignn copied to clipboard
Classification model prediction
Add an example showing how to make predictions with a trained classification model, something like the following:
from alignn.models.alignn import ALIGNN, ALIGNNConfig
import torch
import pprint
from alignn.config import TrainingConfig
from jarvis.core.atoms import Atoms
from jarvis.core.graphs import Graph
from jarvis.db.jsonutils import dumpjson, loadjson
# Example: predict the class of a single structure with a trained ALIGNN
# classification model loaded from a training checkpoint.

# Prefer the GPU when one is available. Using torch.device in both branches
# keeps `device` a consistent type for torch.load / .to(...).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Checkpoint written during training; its "model" entry holds the weights.
filename = "checkpoint_100.pt"
# Graph-construction settings passed to Graph.atom_dgl_multigraph below.
# NOTE(review): these presumably need to match the values used at training
# time (likely the cutoff / max_neighbors entries in config.json) — confirm,
# otherwise the graphs seen at inference differ from those seen in training.
cutoff = 8
max_neighbors = 12

# Load and display the training configuration the checkpoint came from.
config = loadjson("config.json")
# pprint.pprint writes to stdout itself and returns None; wrapping it in
# print() would emit an extra "None" line.
pprint.pprint(config)
config = TrainingConfig(**config)

# Rebuild the network from its config, then load the trained weights.
model = ALIGNN(config.model)
model.load_state_dict(torch.load(filename, map_location=device)["model"])
model.to(device)
model.eval()  # inference mode for layers like dropout / batch-norm

# Read the input structure and build its atom graph and line graph.
# NOTE(review): for CIF inputs, jarvis-tools also offers Atoms.from_cif —
# confirm against the installed jarvis version.
atoms = Atoms.from_poscar("POSCAR")
g, lg = Graph.atom_dgl_multigraph(
    atoms,
    cutoff=float(cutoff),
    max_neighbors=max_neighbors,
)

# Pure inference: no_grad skips building an autograd graph, so the output
# needs no .detach() before conversion to NumPy.
with torch.no_grad():
    logits = model([g.to(device), lg.to(device)])

# Take the argmax over the last axis (assumed to index classes — confirm
# against the model's output shape). For one structure this yields a single
# predicted class index.
out_data = (
    torch.argmax(logits, dim=-1)
    .cpu()
    .numpy()
    .flatten()
    .tolist()
)[0]
print("out_data class ", out_data)
Can I make predictions on a new dataset of CIF files using a model trained on an older dataset?