targetdiff
How to get batch.ligand_element_batch
Hi, thank you for sharing such great work. However, I am a little confused about how to get `batch.ligand_element_batch` in the training loop:

```python
def train(it):
    model.train()
    optimizer.zero_grad()
    for _ in range(config.train.n_acc_batch):
        batch = next(train_iterator).to(args.device)
        results = model.get_diffusion_loss(
            ligand_pos=batch.ligand_pos,
            ligand_v=batch.ligand_atom_feature_full,
            batch_ligand=batch.ligand_element_batch
        )
```
Could you tell me where the ligand element batch is processed? Thank you!
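A likely explanation, assuming the repository collates its samples with PyTorch Geometric: `ligand_element_batch` is probably not produced by a custom processing step but generated at collation time. When `torch_geometric.loader.DataLoader` (or `Batch.from_data_list`) receives `follow_batch=['ligand_element']`, PyG adds a `ligand_element_batch` tensor that maps every ligand atom to the index of the sample it came from. Whether the training script passes exactly this argument is an assumption to verify in its `DataLoader` construction. The plain-Python sketch below (function names are hypothetical, no PyG dependency) reproduces what that vector contains and shows one typical use, averaging per-atom values per molecule.

```python
from typing import List, Sequence, Tuple


def collate_with_batch_vector(
    per_graph_elements: Sequence[Sequence[int]],
) -> Tuple[List[int], List[int]]:
    """Concatenate per-molecule element lists and build the matching batch vector.

    Hypothetical helper that mimics what PyG's follow_batch=['ligand_element']
    yields: entry i of the batch vector is the index of the molecule that
    atom i came from.
    """
    ligand_element: List[int] = []
    ligand_element_batch: List[int] = []
    for graph_idx, elements in enumerate(per_graph_elements):
        ligand_element.extend(elements)
        # One graph index per atom of this molecule.
        ligand_element_batch.extend([graph_idx] * len(elements))
    return ligand_element, ligand_element_batch


def per_graph_mean(
    values: Sequence[float], batch: Sequence[int], num_graphs: int
) -> List[float]:
    """Average per-atom values within each molecule (a plain-Python scatter-mean)."""
    sums = [0.0] * num_graphs
    counts = [0] * num_graphs
    for value, graph_idx in zip(values, batch):
        sums[graph_idx] += value
        counts[graph_idx] += 1
    # Guard against molecules with no atoms.
    return [s / c if c else 0.0 for s, c in zip(sums, counts)]


if __name__ == "__main__":
    # Two hypothetical ligands given as atomic numbers: (C, C, O) and (N, C).
    elements, batch = collate_with_batch_vector([[6, 6, 8], [7, 6]])
    print(elements)  # [6, 6, 8, 7, 6]
    print(batch)     # [0, 0, 0, 1, 1]
    # Hypothetical per-atom losses reduced to one value per molecule.
    print(per_graph_mean([1.0, 2.0, 3.0, 4.0, 6.0], batch, num_graphs=2))  # [2.0, 5.0]
```

With PyG itself, the equivalent would be something like `DataLoader(dataset, batch_size=4, follow_batch=['ligand_element'])`, after which each collated batch exposes `batch.ligand_element_batch` alongside `batch.ligand_element`.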