pixelsplat
Training Error Occurred: Type Error
Hello, I just started training with the re10k dataset, and I got this error:
================================================================================
❯ python3 -m src.main +experiment=re10k data_loader.train.batch_size=1
Saving outputs to /n/pixelsplat2/outputs/2025-08-22/16-07-13.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
Using cache found in /home/pkd/.cache/torch/hub/facebookresearch_dino_main
Using cache found in /home/pkd/.cache/torch/hub/facebookresearch_dino_main
[2025-08-22 16:07:14,084][py.warnings][WARNING] - /home/pkd/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
warnings.warn(
[2025-08-22 16:07:14,084][py.warnings][WARNING] - /home/pkd/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=None`.
warnings.warn(msg)
Setting up [LPIPS] perceptual loss: trunk [vgg], v[0.1], spatial [off]
[2025-08-22 16:07:14,328][py.warnings][WARNING] - /home/pkd/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=VGG16_Weights.IMAGENET1K_V1`. You can also use `weights=VGG16_Weights.DEFAULT` to get the most up-to-date weights.
warnings.warn(msg)
Loading model from: /home/pkd/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/lpips/weights/v0.1/vgg.pth
You are using a CUDA device ('NVIDIA GeForce RTX 4090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
| Name | Type | Params | Mode
---------------------------------------------------------
0 | encoder | EncoderEpipolar | 118 M | train
1 | decoder | DecoderSplattingCUDA | 0 | train
2 | losses | ModuleList | 0 | train
---------------------------------------------------------
118 M Trainable params
0 Non-trainable params
118 M Total params
475.918 Total estimated model params size (MB)
493 Modules in train mode
59 Modules in eval mode
[2025-08-22 16:07:15,289][py.warnings][WARNING] - /home/pkd/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.
validation step 0; scene = ['306e2b7785657539']; context = [[48, 73]]
[2025-08-22 16:07:15,920][py.warnings][WARNING] - /home/pkd/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/lightning/pytorch/utilities/data.py:79: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 1. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.
Setting up [LPIPS] perceptual loss: trunk [vgg], v[0.1], spatial [off]
[2025-08-22 16:07:15,921][py.warnings][WARNING] - /home/pkd/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
warnings.warn(
[2025-08-22 16:07:15,921][py.warnings][WARNING] - /home/pkd/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=VGG16_Weights.IMAGENET1K_V1`. You can also use `weights=VGG16_Weights.DEFAULT` to get the most up-to-date weights.
warnings.warn(msg)
Loading model from: /home/pkd/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/lpips/weights/v0.1/vgg.pth
Error executing job with overrides: ['+experiment=re10k', 'data_loader.train.batch_size=1']
Traceback (most recent call last):
File "/home/pkd/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/jaxtyping/_decorator.py", line 436, in wrapped_fn_impl
param_fn(*args, **kwargs)
File "<@beartype(src.visualization.layout.vcat) at 0x74e449816560>", line 55, in vcat
beartype.roar.BeartypeCallHintParamViolation: Function src.visualization.layout.vcat() parameter images="tensor([[[0.3804, 0.4118, 0.4392, ..., 0.9961, 0.9961, 0.9961],
[0.3882, 0.4196...')" violates type hint typing.Iterable[jaxtyping.Float[Tensor, 'channel _ _']], as <protocol "torch.Tensor"> index 0 item this array has 2 dimensions, not the 3 expected by the type hint.
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/home/pkd/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/jaxtyping/_decorator.py", line 753, in _get_problem_arg
fn(*args, **kwargs)
File "<@beartype(src.visualization.layout.check_single_arg) at 0x74e503b75480>", line 49, in check_single_arg
beartype.roar.BeartypeCallHintParamViolation: Function src.visualization.layout.check_single_arg() parameter images="tensor([[[0.3804, 0.4118, 0.4392, ..., 0.9961, 0.9961, 0.9961],
[0.3882, 0.4196...')" violates type hint typing.Iterable[jaxtyping.Float[Tensor, 'channel _ _']], as <protocol "torch.Tensor"> index 0 item this array has 2 dimensions, not the 3 expected by the type hint.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/pkd/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/jaxtyping/_decorator.py", line 441, in wrapped_fn_impl
argmsg = _get_problem_arg(
File "/home/pkd/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/jaxtyping/_decorator.py", line 756, in _get_problem_arg
raise TypeCheckError(
jaxtyping.TypeCheckError:
The problem arose whilst typechecking parameter 'images'.
Actual value: (f32[3,256,256](torch), f32[3,256,256](torch))
Expected type: typing.Iterable[Float[Tensor, 'channel _ _']].
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/n/pixelsplat2/src/main.py", line 128, in train
trainer.fit(model_wrapper, datamodule=data_module, ckpt_path=checkpoint_path)
File "/home/pkd/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 561, in fit
call._call_and_handle_interrupt(
File "/home/pkd/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 48, in _call_and_handle_interrupt
return trainer_fn(*args, **kwargs)
File "/home/pkd/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 599, in _fit_impl
self._run(model, ckpt_path=ckpt_path)
File "/home/pkd/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 1012, in _run
results = self._run_stage()
File "/home/pkd/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 1054, in _run_stage
self._run_sanity_check()
File "/home/pkd/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 1083, in _run_sanity_check
val_loop.run()
File "/home/pkd/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/lightning/pytorch/loops/utilities.py", line 179, in _decorator
return loop_run(self, *args, **kwargs)
File "/home/pkd/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/lightning/pytorch/loops/evaluation_loop.py", line 145, in run
self._evaluation_step(batch, batch_idx, dataloader_idx, dataloader_iter)
File "/home/pkd/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/lightning/pytorch/loops/evaluation_loop.py", line 437, in _evaluation_step
output = call._call_strategy_hook(trainer, hook_name, *step_args)
File "/home/pkd/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 328, in _call_strategy_hook
output = fn(*args, **kwargs)
File "/home/pkd/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/lightning/pytorch/strategies/strategy.py", line 412, in validation_step
return self.lightning_module.validation_step(*args, **kwargs)
File "/home/pkd/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/lightning_utilities/core/rank_zero.py", line 41, in wrapped_fn
return fn(*args, **kwargs)
File "/home/pkd/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/jaxtyping/_decorator.py", line 549, in wrapped_fn
return wrapped_fn_impl(args, kwargs, bound, memos)
File "/home/pkd/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/jaxtyping/_decorator.py", line 473, in wrapped_fn_impl
out = fn(*args, **kwargs)
File "/n/pixelsplat2/src/model/model_wrapper.py", line 258, in validation_step
add_label(vcat(*batch["context"]["image"][0]), "Context"),
File "/home/pkd/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/jaxtyping/_decorator.py", line 549, in wrapped_fn
return wrapped_fn_impl(args, kwargs, bound, memos)
File "/home/pkd/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/jaxtyping/_decorator.py", line 470, in wrapped_fn_impl
raise TypeCheckError(msg) from e
jaxtyping.TypeCheckError: Type-check error whilst checking the parameters of src.visualization.layout.vcat.
The problem arose whilst typechecking parameter 'images'.
Actual value: (f32[3,256,256](torch), f32[3,256,256](torch))
Expected type: typing.Iterable[Float[Tensor, 'channel _ _']].
----------------------
Called with parameters: {
'images': (f32[3,256,256](torch), f32[3,256,256](torch)),
'align': 'start',
'gap': 8,
'gap_color': 1
}
Parameter annotations: (*images: Iterable[Float[Tensor, 'channel _ _']], align: Literal['start', 'center', 'end', 'left', 'right'] = 'start', gap: int = 8, gap_color: Union[int, float, Iterable[int], Iterable[float], Float[Tensor, '#channel'], Float[Tensor, '']] = 1) -> Any.
Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.
================================================================================
How can I fix this?
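For reference, here is a minimal sketch that seems to reproduce the same check failure outside the training loop. `vcat_like` is my own stand-in, with its signature copied from the `vcat` annotation shown in the traceback; this is just my reduction, not code from the repo:

import torch
from torch import Tensor
from typing import Iterable
from beartype import beartype
from jaxtyping import Float, jaxtyped

# Each positional argument is checked against Iterable[Float[Tensor, "channel _ _"]].
# Because a torch.Tensor is itself iterable, beartype inspects its first element,
# which is a 2-D [height, width] slice and therefore violates the three-axis hint.
@jaxtyped(typechecker=beartype)
def vcat_like(*images: Iterable[Float[Tensor, "channel _ _"]]) -> None:
    pass

# With jaxtyping 0.3.2 / beartype 0.21.0 from my environment, this appears to
# raise the same TypeCheckError as in the log above.
vcat_like(torch.rand(3, 256, 256), torch.rand(3, 256, 256))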
Here is my pip list:
❯ pip list
Package Version
--------------------------- ------------
addict 2.4.0
aiohappyeyeballs 2.6.1
aiohttp 3.12.14
aiosignal 1.4.0
annotated-types 0.7.0
antlr4-python3-runtime 4.9.3
asttokens 3.0.0
async-timeout 5.0.1
attrs 25.3.0
beartype 0.21.0
black 25.1.0
blinker 1.9.0
brotlicffi 1.0.9.2
certifi 2025.7.14
cffi 1.17.1
charset-normalizer 3.4.2
click 8.2.1
colorama 0.4.6
colorspacious 1.1.2
comm 0.2.2
ConfigArgParse 1.7.1
contourpy 1.3.2
cycler 0.12.1
dacite 1.9.2
dash 3.1.1
decorator 4.4.2
diff_gaussian_rasterization 0.0.0
e3nn 0.5.6
einops 0.8.1
exceptiongroup 1.3.0
executing 2.2.0
fastjsonschema 2.21.1
filelock 3.18.0
Flask 3.1.1
fonttools 4.59.0
frozenlist 1.7.0
fsspec 2025.7.0
gitdb 4.0.12
GitPython 3.1.44
gmpy2 2.2.1
hf-xet 1.1.5
huggingface-hub 0.33.4
hydra-core 1.3.2
idna 3.10
imageio 2.37.0
imageio-ffmpeg 0.6.0
importlib_metadata 8.7.0
ipython 8.37.0
ipywidgets 8.1.7
itsdangerous 2.2.0
jaxtyping 0.3.2
jedi 0.19.2
Jinja2 3.1.6
joblib 1.5.1
jsonschema 4.24.1
jsonschema-specifications 2025.4.1
jupyter_core 5.8.1
jupyterlab_widgets 3.0.15
kiwisolver 1.4.8
lazy_loader 0.4
lightning 2.5.2
lightning-utilities 0.14.3
lpips 0.1.4
MarkupSafe 3.0.2
matplotlib 3.10.3
matplotlib-inline 0.1.7
mkl_fft 1.3.10
mkl_random 1.2.8
mkl-service 2.4.0
moviepy 1.0.3
mpmath 1.3.0
multidict 6.6.3
mypy_extensions 1.1.0
narwhals 1.47.1
nbformat 5.10.4
nest-asyncio 1.6.0
networkx 3.4.2
numpy 1.26.4
omegaconf 2.3.0
open3d 0.19.0
opencv-python 4.9.0.80
opt_einsum 3.4.0
opt-einsum-fx 0.1.4
packaging 25.0
pandas 2.3.1
parso 0.8.4
pathspec 0.12.1
pexpect 4.9.0
pillow 11.3.0
pip 25.1
platformdirs 4.3.8
plotly 6.2.0
plyfile 1.1.2
proglog 0.1.12
prompt_toolkit 3.0.51
propcache 0.3.2
protobuf 6.31.1
ptyprocess 0.7.0
pure_eval 0.2.3
pycparser 2.21
pydantic 2.11.7
pydantic_core 2.33.2
Pygments 2.19.2
pyparsing 3.2.3
pyquaternion 0.9.9
PySocks 1.7.1
python-dateutil 2.9.0.post0
python-dotenv 1.1.1
pytorch-lightning 2.5.2
pytz 2025.2
PyYAML 6.0.2
referencing 0.36.2
requests 2.32.4
retrying 1.4.0
rpds-py 0.26.0
ruff 0.12.3
safetensors 0.5.3
scikit-image 0.25.2
scikit-learn 1.7.0
scipy 1.15.3
sentry-sdk 2.33.0
setuptools 78.1.1
six 1.17.0
smmap 5.0.2
stack-data 0.6.3
svg.py 1.7.0
sympy 1.14.0
tabulate 0.9.0
threadpoolctl 3.6.0
tifffile 2025.5.10
timm 1.0.17
tomli 2.2.1
torch 2.1.2+cu118
torchaudio 2.1.0
torchmetrics 1.7.4
torchvision 0.16.2+cu118
tqdm 4.67.1
traitlets 5.14.3
triton 2.1.0
typing_extensions 4.14.1
typing-inspection 0.4.1
tzdata 2025.2
urllib3 2.5.0
wadler_lindig 0.1.7
wandb 0.21.0
wcwidth 0.2.13
Werkzeug 3.1.3
wheel 0.45.1
widgetsnbextension 4.0.14
yarl 1.20.1
zipp 3.23.0
Same issue here. I downgraded jaxtyping and beartype to the versions listed in requirements_w_version.txt from MVSplat, and that fixed it.
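In case it helps anyone else, the workaround is just a pair of downgrades along these lines. The pins below are placeholders, not the actual values from that file, so copy the exact versions it lists:

❯ pip install "jaxtyping==<version from the file>" "beartype==<version from the file>"

With the older versions installed, the type-check error in vcat went away for me.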