Training Error Occurred: Type Error

Open KyungdaePark opened this issue 4 months ago • 2 comments

Hello, I just started training with the re10k dataset, and this error occurred:

================================================================================

❯ python3 -m src.main +experiment=re10k data_loader.train.batch_size=1 
Saving outputs to /n/pixelsplat2/outputs/2025-08-22/16-07-13.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
Using cache found in /home/pkd/.cache/torch/hub/facebookresearch_dino_main
Using cache found in /home/pkd/.cache/torch/hub/facebookresearch_dino_main
[2025-08-22 16:07:14,084][py.warnings][WARNING] - /home/pkd/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
  warnings.warn(

[2025-08-22 16:07:14,084][py.warnings][WARNING] - /home/pkd/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=None`.
  warnings.warn(msg)

Setting up [LPIPS] perceptual loss: trunk [vgg], v[0.1], spatial [off]
[2025-08-22 16:07:14,328][py.warnings][WARNING] - /home/pkd/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=VGG16_Weights.IMAGENET1K_V1`. You can also use `weights=VGG16_Weights.DEFAULT` to get the most up-to-date weights.
  warnings.warn(msg)

Loading model from: /home/pkd/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/lpips/weights/v0.1/vgg.pth
You are using a CUDA device ('NVIDIA GeForce RTX 4090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type                 | Params | Mode
---------------------------------------------------------
0 | encoder | EncoderEpipolar      | 118 M  | train
1 | decoder | DecoderSplattingCUDA | 0      | train
2 | losses  | ModuleList           | 0      | train
---------------------------------------------------------
118 M     Trainable params
0         Non-trainable params
118 M     Total params
475.918   Total estimated model params size (MB)
493       Modules in train mode
59        Modules in eval mode
[2025-08-22 16:07:15,289][py.warnings][WARNING] - /home/pkd/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.

validation step 0; scene = ['306e2b7785657539']; context = [[48, 73]]
[2025-08-22 16:07:15,920][py.warnings][WARNING] - /home/pkd/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/lightning/pytorch/utilities/data.py:79: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 1. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.

Setting up [LPIPS] perceptual loss: trunk [vgg], v[0.1], spatial [off]
[2025-08-22 16:07:15,921][py.warnings][WARNING] - /home/pkd/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
  warnings.warn(

[2025-08-22 16:07:15,921][py.warnings][WARNING] - /home/pkd/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=VGG16_Weights.IMAGENET1K_V1`. You can also use `weights=VGG16_Weights.DEFAULT` to get the most up-to-date weights.
  warnings.warn(msg)

Loading model from: /home/pkd/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/lpips/weights/v0.1/vgg.pth
Error executing job with overrides: ['+experiment=re10k', 'data_loader.train.batch_size=1']
Traceback (most recent call last):
  File "/home/pkd/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/jaxtyping/_decorator.py", line 436, in wrapped_fn_impl
    param_fn(*args, **kwargs)
  File "<@beartype(src.visualization.layout.vcat) at 0x74e449816560>", line 55, in vcat
beartype.roar.BeartypeCallHintParamViolation: Function src.visualization.layout.vcat() parameter images="tensor([[[0.3804, 0.4118, 0.4392,  ..., 0.9961, 0.9961, 0.9961],
         [0.3882, 0.4196...')" violates type hint typing.Iterable[jaxtyping.Float[Tensor, 'channel _ _']], as <protocol "torch.Tensor"> index 0 item this array has 2 dimensions, not the 3 expected by the type hint.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/pkd/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/jaxtyping/_decorator.py", line 753, in _get_problem_arg
    fn(*args, **kwargs)
  File "<@beartype(src.visualization.layout.check_single_arg) at 0x74e503b75480>", line 49, in check_single_arg
beartype.roar.BeartypeCallHintParamViolation: Function src.visualization.layout.check_single_arg() parameter images="tensor([[[0.3804, 0.4118, 0.4392,  ..., 0.9961, 0.9961, 0.9961],
         [0.3882, 0.4196...')" violates type hint typing.Iterable[jaxtyping.Float[Tensor, 'channel _ _']], as <protocol "torch.Tensor"> index 0 item this array has 2 dimensions, not the 3 expected by the type hint.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/pkd/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/jaxtyping/_decorator.py", line 441, in wrapped_fn_impl
    argmsg = _get_problem_arg(
  File "/home/pkd/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/jaxtyping/_decorator.py", line 756, in _get_problem_arg
    raise TypeCheckError(
jaxtyping.TypeCheckError:
The problem arose whilst typechecking parameter 'images'.
Actual value: (f32[3,256,256](torch), f32[3,256,256](torch))
Expected type: typing.Iterable[Float[Tensor, 'channel _ _']].

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/n/pixelsplat2/src/main.py", line 128, in train
    trainer.fit(model_wrapper, datamodule=data_module, ckpt_path=checkpoint_path)
  File "/home/pkd/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 561, in fit
    call._call_and_handle_interrupt(
  File "/home/pkd/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 48, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/home/pkd/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 599, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/home/pkd/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 1012, in _run
    results = self._run_stage()
  File "/home/pkd/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 1054, in _run_stage
    self._run_sanity_check()
  File "/home/pkd/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 1083, in _run_sanity_check
    val_loop.run()
  File "/home/pkd/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/lightning/pytorch/loops/utilities.py", line 179, in _decorator
    return loop_run(self, *args, **kwargs)
  File "/home/pkd/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/lightning/pytorch/loops/evaluation_loop.py", line 145, in run
    self._evaluation_step(batch, batch_idx, dataloader_idx, dataloader_iter)
  File "/home/pkd/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/lightning/pytorch/loops/evaluation_loop.py", line 437, in _evaluation_step
    output = call._call_strategy_hook(trainer, hook_name, *step_args)
  File "/home/pkd/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 328, in _call_strategy_hook
    output = fn(*args, **kwargs)
  File "/home/pkd/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/lightning/pytorch/strategies/strategy.py", line 412, in validation_step
    return self.lightning_module.validation_step(*args, **kwargs)
  File "/home/pkd/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/lightning_utilities/core/rank_zero.py", line 41, in wrapped_fn
    return fn(*args, **kwargs)
  File "/home/pkd/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/jaxtyping/_decorator.py", line 549, in wrapped_fn
    return wrapped_fn_impl(args, kwargs, bound, memos)
  File "/home/pkd/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/jaxtyping/_decorator.py", line 473, in wrapped_fn_impl
    out = fn(*args, **kwargs)
  File "/n/pixelsplat2/src/model/model_wrapper.py", line 258, in validation_step
    add_label(vcat(*batch["context"]["image"][0]), "Context"),
  File "/home/pkd/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/jaxtyping/_decorator.py", line 549, in wrapped_fn
    return wrapped_fn_impl(args, kwargs, bound, memos)
  File "/home/pkd/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/jaxtyping/_decorator.py", line 470, in wrapped_fn_impl
    raise TypeCheckError(msg) from e
jaxtyping.TypeCheckError: Type-check error whilst checking the parameters of src.visualization.layout.vcat.
The problem arose whilst typechecking parameter 'images'.
Actual value: (f32[3,256,256](torch), f32[3,256,256](torch))
Expected type: typing.Iterable[Float[Tensor, 'channel _ _']].
----------------------
Called with parameters: {
  'images': (f32[3,256,256](torch), f32[3,256,256](torch)),
  'align': 'start',
  'gap': 8,
  'gap_color': 1
}
Parameter annotations: (*images: Iterable[Float[Tensor, 'channel _ _']], align: Literal['start', 'center', 'end', 'left', 'right'] = 'start', gap: int = 8, gap_color: Union[int, float, Iterable[int], Iterable[float], Float[Tensor, '#channel'], Float[Tensor, '']] = 1) -> Any.


Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.

================================================================================

How can I fix it?

KyungdaePark • Aug 22 '25 07:08
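
A possible reading of the traceback above, offered as a sketch rather than a confirmed diagnosis: under PEP 484, an annotation on `*images` describes each positional argument, not the tuple of them. Read that way, the hint `Iterable[Float[Tensor, 'channel _ _']]` asks every single image to be an iterable of 3-D tensors, and iterating a `[3, 256, 256]` tensor yields `[256, 256]` slices, which matches the message "index 0 item this array has 2 dimensions, not the 3 expected". The snippet below mimics that with nested lists in place of tensors; `ndim`, `make_image`, and `first_item_ndims` are hypothetical helpers, not part of pixelSplat, jaxtyping, or beartype.

```python
def ndim(value):
    """Count nesting depth of a nested list (a stand-in for tensor.ndim)."""
    depth = 0
    while isinstance(value, list):
        depth += 1
        if not value:  # empty list: cannot descend further
            break
        value = value[0]
    return depth


def make_image(channels=3, height=4, width=4):
    """Build a nested list shaped [channels, height, width]."""
    return [[[0.0] * width for _ in range(height)] for _ in range(channels)]


def first_item_ndims(*images):
    """For each positional argument, report the depth of its first item.

    This is what an item-level check of Iterable[...] applied to each
    argument would look at: the first element obtained by iterating it.
    """
    return [ndim(next(iter(image))) for image in images]
```

Here `first_item_ndims(make_image(), make_image())` gives `[2, 2]`: each argument is itself 3-D, but its first item is 2-D. Whether older jaxtyping and beartype releases simply skip this item-level check is an assumption; the thread itself only confirms that downgrading both packages removes the error.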

pip list :

❯ pip list
Package                     Version
--------------------------- ------------
addict                      2.4.0
aiohappyeyeballs            2.6.1
aiohttp                     3.12.14
aiosignal                   1.4.0
annotated-types             0.7.0
antlr4-python3-runtime      4.9.3
asttokens                   3.0.0
async-timeout               5.0.1
attrs                       25.3.0
beartype                    0.21.0
black                       25.1.0
blinker                     1.9.0
brotlicffi                  1.0.9.2
certifi                     2025.7.14
cffi                        1.17.1
charset-normalizer          3.4.2
click                       8.2.1
colorama                    0.4.6
colorspacious               1.1.2
comm                        0.2.2
ConfigArgParse              1.7.1
contourpy                   1.3.2
cycler                      0.12.1
dacite                      1.9.2
dash                        3.1.1
decorator                   4.4.2
diff_gaussian_rasterization 0.0.0
e3nn                        0.5.6
einops                      0.8.1
exceptiongroup              1.3.0
executing                   2.2.0
fastjsonschema              2.21.1
filelock                    3.18.0
Flask                       3.1.1
fonttools                   4.59.0
frozenlist                  1.7.0
fsspec                      2025.7.0
gitdb                       4.0.12
GitPython                   3.1.44
gmpy2                       2.2.1
hf-xet                      1.1.5
huggingface-hub             0.33.4
hydra-core                  1.3.2
idna                        3.10
imageio                     2.37.0
imageio-ffmpeg              0.6.0
importlib_metadata          8.7.0
ipython                     8.37.0
ipywidgets                  8.1.7
itsdangerous                2.2.0
jaxtyping                   0.3.2
jedi                        0.19.2
Jinja2                      3.1.6
joblib                      1.5.1
jsonschema                  4.24.1
jsonschema-specifications   2025.4.1
jupyter_core                5.8.1
jupyterlab_widgets          3.0.15
kiwisolver                  1.4.8
lazy_loader                 0.4
lightning                   2.5.2
lightning-utilities         0.14.3
lpips                       0.1.4
MarkupSafe                  3.0.2
matplotlib                  3.10.3
matplotlib-inline           0.1.7
mkl_fft                     1.3.10
mkl_random                  1.2.8
mkl-service                 2.4.0
moviepy                     1.0.3
mpmath                      1.3.0
multidict                   6.6.3
mypy_extensions             1.1.0
narwhals                    1.47.1
nbformat                    5.10.4
nest-asyncio                1.6.0
networkx                    3.4.2
numpy                       1.26.4
omegaconf                   2.3.0
open3d                      0.19.0
opencv-python               4.9.0.80
opt_einsum                  3.4.0
opt-einsum-fx               0.1.4
packaging                   25.0
pandas                      2.3.1
parso                       0.8.4
pathspec                    0.12.1
pexpect                     4.9.0
pillow                      11.3.0
pip                         25.1
platformdirs                4.3.8
plotly                      6.2.0
plyfile                     1.1.2
proglog                     0.1.12
prompt_toolkit              3.0.51
propcache                   0.3.2
protobuf                    6.31.1
ptyprocess                  0.7.0
pure_eval                   0.2.3
pycparser                   2.21
pydantic                    2.11.7
pydantic_core               2.33.2
Pygments                    2.19.2
pyparsing                   3.2.3
pyquaternion                0.9.9
PySocks                     1.7.1
python-dateutil             2.9.0.post0
python-dotenv               1.1.1
pytorch-lightning           2.5.2
pytz                        2025.2
PyYAML                      6.0.2
referencing                 0.36.2
requests                    2.32.4
retrying                    1.4.0
rpds-py                     0.26.0
ruff                        0.12.3
safetensors                 0.5.3
scikit-image                0.25.2
scikit-learn                1.7.0
scipy                       1.15.3
sentry-sdk                  2.33.0
setuptools                  78.1.1
six                         1.17.0
smmap                       5.0.2
stack-data                  0.6.3
svg.py                      1.7.0
sympy                       1.14.0
tabulate                    0.9.0
threadpoolctl               3.6.0
tifffile                    2025.5.10
timm                        1.0.17
tomli                       2.2.1
torch                       2.1.2+cu118
torchaudio                  2.1.0
torchmetrics                1.7.4
torchvision                 0.16.2+cu118
tqdm                        4.67.1
traitlets                   5.14.3
triton                      2.1.0
typing_extensions           4.14.1
typing-inspection           0.4.1
tzdata                      2025.2
urllib3                     2.5.0
wadler_lindig               0.1.7
wandb                       0.21.0
wcwidth                     0.2.13
Werkzeug                    3.1.3
wheel                       0.45.1
widgetsnbextension          4.0.14
yarl                        1.20.1
zipp                        3.23.0

KyungdaePark • Aug 22 '25 07:08

Same issue. I tried downgrading jaxtyping and beartype to the versions pinned in requirements_w_version.txt in MVSplat, and this fixed the issue.

KeKer7 • Aug 26 '25 07:08
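
To see which installed packages differ from a pinned requirements file such as the one mentioned above, a small standard-library check could look like the following. It is a sketch under assumptions: `installed_versions`, `parse_pins`, and `mismatches` are hypothetical helper names, only exact `name==version` pins are handled, and the path to the requirements file is left to the reader.

```python
from importlib.metadata import PackageNotFoundError, version

# The two packages named in the traceback; extend the tuple as needed.
PACKAGES = ("jaxtyping", "beartype")


def installed_versions(names=PACKAGES):
    """Return {package: version string, or None if it is not installed}."""
    found = {}
    for name in names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


def parse_pins(requirements_text):
    """Collect exact 'name==version' pins from requirements-style text."""
    pins = {}
    for raw in requirements_text.splitlines():
        # Drop trailing comments and environment markers, then whitespace.
        line = raw.split("#", 1)[0].split(";", 1)[0].strip()
        if "==" in line:
            name, pinned = line.split("==", 1)
            pins[name.strip().lower()] = pinned.strip()
    return pins


def mismatches(pins, installed):
    """Return {package: (installed, pinned)} wherever the two differ."""
    return {
        name: (have, pins[name])
        for name, have in installed.items()
        if name in pins and have != pins[name]
    }
```

Passing the text of the pinned file to `parse_pins` and the result, together with `installed_versions()`, to `mismatches` returns an empty dict once jaxtyping and beartype match their pins; any remaining entry shows the installed version next to the pinned one.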