Multiclass classification
Thank you for making this tool! I am running into an issue when I run `python main.py --task classification`:
/Users/gianmarcoterrones/opt/anaconda3/envs/cgcnn/lib/python3.11/site-packages/pymatgen/io/cif.py:1134: UserWarning: Issues encountered while parsing CIF: Some fractional coordinates rounded to ideal values to avoid issues with finite precision.
warnings.warn("Issues encountered while parsing CIF: " + "\n".join(self.warnings))
Traceback (most recent call last):
File "/Users/gianmarcoterrones/Research/cgcnn/main.py", line 513, in <module>
main()
File "/Users/gianmarcoterrones/Research/cgcnn/main.py", line 175, in main
train(train_loader, model, criterion, optimizer, epoch, normalizer)
File "/Users/gianmarcoterrones/Research/cgcnn/main.py", line 252, in train
loss = criterion(output, target_var)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/gianmarcoterrones/opt/anaconda3/envs/cgcnn/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/gianmarcoterrones/opt/anaconda3/envs/cgcnn/lib/python3.11/site-packages/torch/nn/modules/loss.py", line 216, in forward
return F.nll_loss(input, target, weight=self.weight, ignore_index=self.ignore_index, reduction=self.reduction)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/gianmarcoterrones/opt/anaconda3/envs/cgcnn/lib/python3.11/site-packages/torch/nn/functional.py", line 2704, in nll_loss
return torch._C._nn.nll_loss_nd(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
IndexError: Target 2 is out of bounds.
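For context, the message can be read as a class-index lookup that falls outside the width of the model output. The sketch below is a pure-Python stand-in for that lookup, not PyTorch's implementation: `nll_loss_sketch` is a hypothetical name, and the two-column output is an assumption about how the classification head is sized.

```python
import math


def nll_loss_sketch(log_probs, targets):
    """Mean negative log-likelihood over a batch (illustrative only).

    log_probs: list of per-sample lists of log-probabilities, one per class.
    targets:   list of integer class indices, one per sample.
    """
    total = 0.0
    for row, target in zip(log_probs, targets):
        n_classes = len(row)
        # A target is only valid if the output has a column for it.
        if not 0 <= target < n_classes:
            raise IndexError(f"Target {target} is out of bounds.")
        total += -row[target]
    return total / len(targets)


# Assumed two-class output: each row holds log-probabilities for classes 0 and 1.
two_class_output = [
    [math.log(0.7), math.log(0.3)],
    [math.log(0.2), math.log(0.8)],
]

# Labels 0 and 1 both have a matching column, so this succeeds.
loss_ok = nll_loss_sketch(two_class_output, [0, 1])

# Label 2 has no matching column, so the lookup fails.
try:
    nll_loss_sketch(two_class_output, [0, 2])
    error_message = None
except IndexError as exc:
    error_message = str(exc)
```

Under that assumption, any label other than 0 or 1 would trigger the same failure regardless of how the dataset is split.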
Have I made an error in setting up the customized dataset? Or does the code not currently support multiclass classification? The entries of my id_prop.csv look like this:
| ID | Label |
|---|---|
| ACOFUU | 1 |
| ACOGAB | 1 |
| ACOGEF | 1 |
| ADABAK | 2 |
| AFEJUQ | 1 |
| AGUBUA | 1 |
| AKOXIJ | 1 |
| ALAMUW | 0 |
To clarify, the full command I run is `python main.py --task classification --train-ratio 0.6 --val-ratio 0.1 --test-ratio 0.3 root_dir`
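One way to narrow this down is to count the distinct labels in `id_prop.csv` and compare them with the number of outputs the classifier produces. The sketch below is a minimal check, not part of the repository: `summarize_labels` is a hypothetical helper, `n_outputs=2` is an assumption about the classification head, and the inline sample simply repeats the rows shown above in comma-separated form.

```python
import csv
import io
from collections import Counter


def summarize_labels(csv_file, n_outputs=2):
    """Count class labels and list those outside [0, n_outputs).

    csv_file:  an open file (or file-like object) with rows of `id,label`.
    n_outputs: assumed number of classes the model can predict.
    """
    counts = Counter()
    for row in csv.reader(csv_file):
        # Skip blank or malformed rows rather than failing on them.
        if len(row) < 2 or not row[1].strip():
            continue
        counts[int(float(row[1]))] += 1
    out_of_range = sorted(
        label for label in counts if not 0 <= label < n_outputs
    )
    return counts, out_of_range


# Sample rows copied from the issue; a real run would open id_prop.csv instead.
sample = io.StringIO(
    "ACOFUU,1\nACOGAB,1\nACOGEF,1\nADABAK,2\n"
    "AFEJUQ,1\nAGUBUA,1\nAKOXIJ,1\nALAMUW,0\n"
)
counts, out_of_range = summarize_labels(sample)
print("label counts:", dict(counts))
print("labels outside the assumed range:", out_of_range)
```

If the second list is non-empty, the labels and the output size disagree, and either the labels need remapping or the output layer needs as many units as there are classes.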