Trying to add custom-trained YOLOv8n-pose model
First of all, thank you for such well-written code and documentation. Everything was easy to read and understand, and the code was organized very nicely. As someone without a computer science degree, I really appreciate this, especially since you don't often see it from people in our field.
What I have tried to do is what the title says: I have a custom-trained YOLOv8n pose model in .pt format, and I wanted to 1) add it as a supervised tracking model, and 2) possibly extend it into a semi-supervised tracker.
However, even after following your detailed instructions on adding a new model, I have not been able to get it working. Here is what I did, in the order of the documentation page:
- Added `YOLOtracker.py`, which defines two new tracker classes, `YOLOtracker` and `SemisupervisedYOLOtracker` (a rough sketch of both classes follows right after this list):
  - `YOLOtracker(RegressionTracker)`
  - `SemisupervisedYOLOtracker(SemiSupervisedTrackerMixin, YOLOtracker)`
  - The only differences from the parent classes are:
    - `backbone=YOLO('~Somepath~/ym-pretrained.pt')` (yes, I have imported `YOLO` from ultralytics here)
    - `num_keypoints=11`
- Added `"YOLOtracker"` to `ALLOWED_MODELS` in `models/__init__.py`
- Created a new config with `model_type: "YOLOtracker"`
- Changed line 85 of `utils/scripts.py` to:
  ```python
  if cfg.model.model_type == "regression" or cfg.model.model_type == "YOLOtracker":
  ```
- Added the following branch to the model construction:
  ```python
  elif cfg.model.model_type == "YOLOtracker":
      model = YOLOtracker(
          num_keypoints=cfg.data.num_keypoints,
          # loss_factory=loss_factories["supervised"],
          backbone=cfg.model.backbone,
          # torch_seed=cfg.training.rng_seed_model_pt,
          # lr_scheduler=lr_scheduler,
          # lr_scheduler_params=lr_scheduler_params,
          # image_size=image_h,  # only used by ViT
      )
  ```
- Added to `get_model_class`:
  ```python
  elif map_type == "YOLOtracker":
      from lightning_pose.models import YOLOtracker as Model
  ```
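For reference, this is roughly the shape of the two classes (the import paths and constructor signature here are approximate, and the weight path is elided as above):

```python
# YOLOtracker.py -- sketch of the two classes described above
from ultralytics import YOLO

from lightning_pose.models.base import SemiSupervisedTrackerMixin  # import path approximate
from lightning_pose.models.regression_tracker import RegressionTracker  # import path approximate


class YOLOtracker(RegressionTracker):
    """Regression tracker whose backbone is a custom-trained YOLOv8n-pose model."""

    def __init__(self, num_keypoints: int = 11, backbone=None, **kwargs):
        print("Initializing a YOLOtracker instance.")
        if backbone is None:
            # load the custom-trained YOLOv8n-pose weights (path elided)
            backbone = YOLO("~Somepath~/ym-pretrained.pt")
        # NOTE: the stock RegressionTracker expects a backbone *name* string,
        # so passing an ultralytics YOLO object here is the experimental part.
        super().__init__(num_keypoints=num_keypoints, backbone=backbone, **kwargs)


class SemisupervisedYOLOtracker(SemiSupervisedTrackerMixin, YOLOtracker):
    """Semi-supervised variant of the YOLO tracker."""
    pass
```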
I also went ahead and added `"YOLOtracker": RegressionMSELoss` to `losses.losses`, which was missing from the document; roughly like this:
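```python
# sketch only: the exact name of the mapping inside lightning_pose/losses/losses.py is assumed
from lightning_pose.losses.losses import RegressionMSELoss

loss_classes = {
    # ... existing entries for the built-in model types ...
    "YOLOtracker": RegressionMSELoss,  # reuse the plain MSE-on-coordinates loss
}
```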
Then I went to the unit tests and created:
```python
def test_supervised_YOLO(
    cfg, base_data_module, video_dataloader, trainer, remove_logs
):
    """Test the initialization and training of a supervised YOLO model."""
    # cfg = '/home/tarislada/Behavitproject/lightning-pose/scripts/configs/config_custom.yaml'
    cfg_tmp = copy.deepcopy(cfg)
    cfg_tmp.model.model_type = "YOLOtracker"
    cfg_tmp.model.losses_to_use = []
    run_model_test(
        cfg=cfg_tmp,
        data_module=base_data_module,
        video_dataloader=video_dataloader,
        trainer=trainer,
        remove_logs_fn=remove_logs,
    )
```
This failed with `FAILED tests/models/test_custom_trackers.py::test_supervised_YOLO - RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn`. Because I did see the "Initializing a YOLOtracker instance." message and the usual YOLO command-line output, I assume the steps from the lightning-pose documentation were implemented correctly. Any idea how to get this working?
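(For reference, that RuntimeError is the generic PyTorch error raised when `.backward()` is called on a tensor with no autograd history; a standalone snippet, independent of lightning-pose, that reproduces the exact message:)

```python
import torch

# `pred` never passed through any operation that tracks gradients,
# so the loss built from it has no grad_fn either.
pred = torch.randn(4, 22)                  # requires_grad defaults to False
target = torch.zeros_like(pred)
loss = ((pred - target) ** 2).mean()
loss.backward()
# RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
```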
Hi @Tarislada, thanks for checking out the Lightning Pose repo! Happy to help you get your YOLO tracker up and running. Could you run `pytest tests/models/test_custom_trackers.py::test_supervised_YOLO -s` (the `-s` will actually output the print statements to the terminal) and then paste the error traceback here?
Thanks for the quick reply, and sorry for the late follow-up. I ran the command, and it came back with:
========================================================================================================== test session starts ===========================================================================================================
platform linux -- Python 3.11.6, pytest-7.4.3, pluggy-1.3.0
rootdir: /home/tarislada/Behavitproject/lightning-pose
plugins: typeguard-4.1.5, hydra-core-1.3.2, anyio-4.1.0, torchtyping-0.1.4
collected 1 item
tests/models/test_custom_trackers.py using default image augmentation pipeline (resizing only)
Number of labeled images in the full dataset (train+val+test): 90
Size of -- train set: 85, val set: 2, test set: 3
Warning: the argument `seed` shadows a Pipeline constructor argument of the same name.
[/opt/dali/dali/util/nvml_wrap.cc:69] nvmlInitChecked failed: Driver/library version mismatch
E
================================================================================================================= ERRORS =================================================================================================================
_________________________________________________________________________________________________ ERROR at setup of test_supervised_YOLO _________________________________________________________________________________________________
cfg = {'data': {'image_orig_dims': {'width': 396, 'height': 406}, 'image_resize_dims': {'width': 256, 'height': 256}, 'data_...'total_unsupervised_importance', 'init_val': 0.0, 'increase_factor': 0.01, 'final_val': 1.0, 'freeze_until_epoch': 0}}}
base_dataset = <lightning_pose.data.datasets.BaseTrackingDataset object at 0x7fed2d156b50>, video_list = ['data/mirror-mouse-example/videos/test_vid.mp4']
@pytest.fixture
def video_dataloader(cfg, base_dataset, video_list) -> LitDaliWrapper:
"""Create a prediction dataloader for a new video."""
# setup
vid_pred_class = PrepareDALI(
train_stage="predict",
model_type="base",
dali_config=cfg.dali,
filenames=video_list,
resize_dims=[base_dataset.height, base_dataset.width],
)
> video_dataloader = vid_pred_class()
tests/conftest.py:254:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
lightning_pose/data/dali.py:397: in __call__
return LitDaliWrapper(pipe, **args[self.train_stage][self.model_type])
lightning_pose/data/dali.py:154: in __init__
super().__init__(*args, **kwargs)
../../mambaforge/envs/Action_transformer/lib/python3.11/site-packages/nvidia/dali/plugin/pytorch.py:181: in __init__
_DaliBaseIterator.__init__(self,
../../mambaforge/envs/Action_transformer/lib/python3.11/site-packages/nvidia/dali/plugin/base_iterator.py:198: in __init__
p.build()
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <nvidia.dali.pipeline.Pipeline object at 0x7fed2c2606d0>
def build(self):
"""Build the pipeline.
Pipeline needs to be built in order to run it standalone.
Framework-specific plugins handle this step automatically.
"""
if self._built:
return
if self.num_threads < 1:
raise ValueError("Pipeline created with `num_threads` < 1 can only be used "
"for serialization.")
self.start_py_workers()
if not self._backend_prepared:
self._init_pipeline_backend()
self._setup_pipe_pool_dependency()
> self._pipe.Build(self._generate_build_args())
E RuntimeError: nvml error (18): nvml: RM detects a driver/library version mismatch
../../mambaforge/envs/Action_transformer/lib/python3.11/site-packages/nvidia/dali/pipeline.py:881: RuntimeError
============================================================================================================ warnings summary ============================================================================================================
../../mambaforge/envs/Action_transformer/lib/python3.11/site-packages/lightning_utilities/core/imports.py:14
/home/tarislada/mambaforge/envs/Action_transformer/lib/python3.11/site-packages/lightning_utilities/core/imports.py:14: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html
../../mambaforge/envs/Action_transformer/lib/python3.11/site-packages/pkg_resources/__init__.py:2871
/home/tarislada/mambaforge/envs/Action_transformer/lib/python3.11/site-packages/pkg_resources/__init__.py:2871: DeprecationWarning: Deprecated call to `pkg_resources.declare_namespace('sphinxcontrib')`.
Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages
declare_namespace(pkg)
../../mambaforge/envs/Action_transformer/lib/python3.11/site-packages/lightning/fabric/__init__.py:40
/home/tarislada/mambaforge/envs/Action_transformer/lib/python3.11/site-packages/lightning/fabric/__init__.py:40: Deprecated call to `pkg_resources.declare_namespace('lightning.fabric')`.
Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages
../../mambaforge/envs/Action_transformer/lib/python3.11/site-packages/pkg_resources/__init__.py:2350
../../mambaforge/envs/Action_transformer/lib/python3.11/site-packages/pkg_resources/__init__.py:2350
/home/tarislada/mambaforge/envs/Action_transformer/lib/python3.11/site-packages/pkg_resources/__init__.py:2350: DeprecationWarning: Deprecated call to `pkg_resources.declare_namespace('lightning')`.
Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages
declare_namespace(parent)
../../mambaforge/envs/Action_transformer/lib/python3.11/site-packages/lightning/pytorch/__init__.py:37
/home/tarislada/mambaforge/envs/Action_transformer/lib/python3.11/site-packages/lightning/pytorch/__init__.py:37: Deprecated call to `pkg_resources.declare_namespace('lightning.pytorch')`.
Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages
../../mambaforge/envs/Action_transformer/lib/python3.11/site-packages/nvidia/dali/backend.py:46
/home/tarislada/mambaforge/envs/Action_transformer/lib/python3.11/site-packages/nvidia/dali/backend.py:46: Warning: DALI support for Python 3.11 is experimental and some functionalities may not work.
deprecation_warning("DALI support for Python {0}.{1} is experimental and some "
../../mambaforge/envs/Action_transformer/lib/python3.11/site-packages/nvidia/dali/_autograph/pyct/gast_util.py:78
/home/tarislada/mambaforge/envs/Action_transformer/lib/python3.11/site-packages/nvidia/dali/_autograph/pyct/gast_util.py:78: DeprecationWarning: distutils Version classes are deprecated. Use packaging.version instead.
if get_gast_version() < LooseVersion("0.5"):
../../mambaforge/envs/Action_transformer/lib/python3.11/site-packages/setuptools/_distutils/version.py:345
/home/tarislada/mambaforge/envs/Action_transformer/lib/python3.11/site-packages/setuptools/_distutils/version.py:345: DeprecationWarning: distutils Version classes are deprecated. Use packaging.version instead.
other = LooseVersion(other)
../../mambaforge/envs/Action_transformer/lib/python3.11/site-packages/torch/cuda/__init__.py:611
/home/tarislada/mambaforge/envs/Action_transformer/lib/python3.11/site-packages/torch/cuda/__init__.py:611: UserWarning: Can't initialize NVML
warnings.warn("Can't initialize NVML")
tests/models/test_custom_trackers.py::test_supervised_YOLO
/home/tarislada/mambaforge/envs/Action_transformer/lib/python3.11/site-packages/nvidia/dali/backend.py:92: Warning: nvidia-dali-cuda120 is no longer shipped with CUDA runtime. You need to install it separately. NPP is typically provided with CUDA Toolkit installation or an appropriate wheel. Please check https://docs.nvidia.com/cuda/cuda-quick-start-guide/index.html#pip-wheels-installation-linux for the reference.
deprecation_warning("nvidia-dali-cuda120 is no longer shipped with CUDA runtime. "
-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
======================================================================================================== short test summary info =========================================================================================================
ERROR tests/models/test_custom_trackers.py::test_supervised_YOLO - RuntimeError: nvml error (18): nvml: RM detects a driver/library version mismatch
I wonder if this is because the environment is set up with CUDA 12?
Possibly it's a CUDA/DALI issue...
I also see that you use `python=3.11.6`. First, let's try creating a new conda environment with `python=3.8`:

```
conda create --name <YOUR_ENVIRONMENT_NAME> python=3.8
```

Then re-install lightning-pose from your modified lightning-pose folder (in editable mode, with the extra models and development tools):

```
pip install -e ".[dev,extra_models]"
```
Then run `pytest -vv` again and copy-paste the remaining errors if you can.
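As a quick sanity check of the driver/runtime pair that doesn't involve lightning-pose or DALI at all, you could also run something like:

```python
import torch

# Which CUDA runtime this PyTorch build expects, and whether the driver can serve it.
print(torch.__version__, torch.version.cuda)
print("cuda available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))
```

If `is_available()` already comes back `False` (or prints an NVML warning like the one in your traceback), that points to the system-level driver install rather than anything inside lightning-pose.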
@Tarislada just checking in to see if you're still working on this; otherwise I'll close the issue due to inactivity
closing due to inactivity