pytorch_geometric

Point cloud classification for points with additional input features

Open narges-tk opened this issue 1 year ago • 2 comments

🐛 Describe the bug

Could you please let me know how I can do point cloud classification (e.g., using PointCNN) with additional input features? What should I change if my point clouds have XYZRGB values? Currently, I am getting the following dimension error:

Traceback (most recent call last):
  File "/media/emre/Data/Downloads/pytorch_geometric-master/benchmark/points/point_cnn_Narges.py", line 94, in <module>
    run(train_dataset, test_dataset, model, args.epochs, args.batch_size, args.lr,
  File "/media/emre/Data/Downloads/pytorch_geometric-master/benchmark/points/train_eval.py", line 102, in run
    run_train(train_dataset, test_dataset, model, epochs, batch_size,
  File "/media/emre/Data/Downloads/pytorch_geometric-master/benchmark/points/train_eval.py", line 42, in run_train
    train(model, optimizer, train_loader, device)
  File "/media/emre/Data/Downloads/pytorch_geometric-master/benchmark/points/train_eval.py", line 116, in train
    out = model(data.pos, data.batch)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/emre/anaconda3/envs/PyG3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/emre/anaconda3/envs/PyG3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/media/emre/Data/Downloads/pytorch_geometric-master/benchmark/points/point_cnn_Narges.py", line 48, in forward
    x = F.relu(self.conv1(None, pos, batch))
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/emre/anaconda3/envs/PyG3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/emre/anaconda3/envs/PyG3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/emre/anaconda3/envs/PyG3/lib/python3.11/site-packages/torch_geometric/nn/conv/x_conv.py", line 147, in forward
    x_star = self.mlp1(pos)
             ^^^^^^^^^^^^^^
  File "/home/emre/anaconda3/envs/PyG3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/emre/anaconda3/envs/PyG3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/emre/anaconda3/envs/PyG3/lib/python3.11/site-packages/torch/nn/modules/container.py", line 215, in forward
    input = module(input)
            ^^^^^^^^^^^^^
  File "/home/emre/anaconda3/envs/PyG3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/emre/anaconda3/envs/PyG3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/emre/anaconda3/envs/PyG3/lib/python3.11/site-packages/torch/nn/modules/linear.py", line 114, in forward
    return F.linear(input, self.weight, self.bias)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: mat1 and mat2 shapes cannot be multiplied (0x1 and 3x32)

Versions

curl -OL https://raw.githubusercontent.com/pytorch/pytorch/main/torch/utils/collect_env.py
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 22068  100 22068    0     0  46361      0 --:--:-- --:--:-- --:--:-- 46264

python3 collect_env.py
Collecting environment information...
PyTorch version: 2.1.1
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
Clang version: Could not collect
CMake version: version 3.16.3
Libc version: glibc-2.31

Python version: 3.11.5 (main, Sep 11 2023, 13:54:46) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.15.0-84-generic-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 10.1.243
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 2080 Ti
Nvidia driver version: 470.223.02
cuDNN version: /usr/lib/x86_64-linux-gnu/libcudnn.so.7.6.5
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
Address sizes: 43 bits physical, 48 bits virtual
CPU(s): 32
On-line CPU(s) list: 0-31
Thread(s) per core: 2
Core(s) per socket: 16
Socket(s): 1
NUMA node(s): 2
Vendor ID: AuthenticAMD
CPU family: 23
Model: 8
Model name: AMD Ryzen Threadripper 2950X 16-Core Processor
Stepping: 2
Frequency boost: enabled
CPU MHz: 2200.000
CPU max MHz: 3500,0000
CPU min MHz: 2200,0000
BogoMIPS: 6999.02
Virtualization: AMD-V
L1d cache: 512 KiB
L1i cache: 1 MiB
L2 cache: 8 MiB
L3 cache: 32 MiB
NUMA node0 CPU(s): 0-7,16-23
NUMA node1 CPU(s): 8-15,24-31
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Retbleed: Mitigation; untrained return thunk; SMT vulnerable
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines, IBPB conditional, STIBP disabled, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid amd_dcm aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb hw_pstate ssbd ibpb vmmcall fsgsbase bmi1 avx2 smep bmi2 rdseed adx smap clflushopt sha_ni xsaveopt xsavec xgetbv1 clzero irperf xsaveerptr arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif overflow_recov succor smca sme sev sev_es

Versions of relevant libraries:
[pip3] numpy==1.26.2
[pip3] torch==2.1.1
[pip3] torch-cluster==1.6.3
[pip3] torch_geometric==2.4.0
[pip3] torch-scatter==2.1.2
[pip3] torch-spline-conv==1.2.2
[pip3] torchaudio==2.1.1
[pip3] torchvision==0.16.1
[pip3] triton==2.1.0
[conda] blas 1.0 mkl
[conda] ffmpeg 4.3 hf484d3e_0 pytorch
[conda] libjpeg-turbo 2.0.0 h9bf148f_0 pytorch
[conda] mkl 2023.1.0 h213fc3f_46344
[conda] mkl-service 2.4.0 py311h5eee18b_1
[conda] mkl_fft 1.3.8 py311h5eee18b_0
[conda] mkl_random 1.2.4 py311hdb19cb5_0
[conda] numpy 1.26.2 py311h08b1b3b_0
[conda] numpy-base 1.26.2 py311hf175353_0
[conda] pytorch 2.1.1 py3.11_cuda11.8_cudnn8.7.0_0 pytorch
[conda] pytorch-cluster 1.6.3 py311_torch_2.1.0_cu118 pyg
[conda] pytorch-cuda 11.8 h7e8668a_5 pytorch
[conda] pytorch-mutex 1.0 cuda pytorch
[conda] pytorch-scatter 2.1.2 py311_torch_2.1.0_cu118 pyg
[conda] pytorch-spline-conv 1.2.2 py311_torch_2.1.0_cu118 pyg
[conda] torch-geometric 2.4.0 pypi_0 pypi
[conda] torchaudio 2.1.1 py311_cu118 pytorch
[conda] torchtriton 2.1.0 py311 pytorch
[conda] torchvision 0.16.1 py311_cu118 pytorch

narges-tk, Feb 15 '24 14:02

Your first layer needs an increased number of input features, and you need to pass the features `x` to it instead of `None`. The forward pass for, e.g., benchmark/points/point_net.py would then look like this:

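def forward(self, x, pos, batch):
    # x: additional per-point features (e.g., RGB); pos: XYZ coordinates
    radius = 0.2
    edge_index = radius_graph(pos, r=radius, batch=batch)
    # Pass x (instead of None) so the features reach the first layer
    x = F.relu(self.conv1(x, pos, edge_index))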

rusty1s, Feb 17 '24 13:02

It worked! Many thanks for your prompt answer.

narges-tk, Feb 20 '24 14:02