training_extensions
Implement sampler API & add BalancedSampler
Summary
- Add configuration options so that the sampler can be swapped.
- Add `BalancedSampler` and `RepeatSampler`.
BalancedSampler in OTX 1.6: https://github.com/openvinotoolkit/training_extensions/blob/develop/src/otx/algorithms/common/adapters/torch/dataloaders/samplers/balanced_sampler.py
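To illustrate the general idea behind class-balanced sampling, here is a minimal sketch in plain Python. It is not the OTX implementation: the class name `SimpleBalancedSampler`, its constructor arguments, and the choice to draw the same number of indices from every class with replacement are all assumptions made for this example.

```python
import random
from collections import defaultdict


class SimpleBalancedSampler:
    """Hypothetical class-balanced index sampler (illustration only, not OTX code)."""

    def __init__(self, labels, seed=None):
        # Group dataset indices by their class label.
        self.indices_per_class = defaultdict(list)
        for idx, label in enumerate(labels):
            self.indices_per_class[label].append(idx)
        self.num_classes = len(self.indices_per_class)
        # Assumption: draw about len(labels) indices per epoch, split evenly across classes.
        self.samples_per_class = -(-len(labels) // self.num_classes)  # ceiling division
        self.rng = random.Random(seed)

    def __iter__(self):
        picked = []
        for class_indices in self.indices_per_class.values():
            # Sampling with replacement oversamples rare classes.
            picked.extend(self.rng.choices(class_indices, k=self.samples_per_class))
        self.rng.shuffle(picked)
        return iter(picked)

    def __len__(self):
        return self.samples_per_class * self.num_classes


# Toy dataset: 90 samples of class 0 and 10 samples of class 1.
labels = [0] * 90 + [1] * 10
sampler = SimpleBalancedSampler(labels, seed=1234)
drawn = list(sampler)
counts = {c: sum(1 for i in drawn if labels[i] == c) for c in (0, 1)}
print(counts)  # {0: 50, 1: 50}
```

Even though class 1 makes up only 10% of the toy dataset, each epoch yields the same number of indices for both classes. A sampler of this kind only changes results when the classes are imbalanced to begin with.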
otx train \
--config src/otx/recipe/classification/multi_class_cls/otx_efficientnet_b0.yaml \
--data_root /home/harimkan/workspace/repo/datasets/otx_v2_dataset/multiclass_classification/multiclass_CUB_medium \
--seed 1234 \
--deterministic True \
--data.config.train_subset.sampler.class_path otx.algo.samplers.balanced_sampler.BalancedSampler \
--work_dir otx-cls-balanced-sampler
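The `class_path` override above selects the sampler by its dotted import path. As a rough sketch of how such a path can be turned into an object, the snippet below resolves a dotted path with `importlib`. The helper `instantiate_from_class_path` and the `init_args` parameter are hypothetical names for this example, and the way OTX actually resolves `class_path` may differ. A standard-library class stands in for the sampler so the snippet runs without OTX installed.

```python
import importlib


def instantiate_from_class_path(class_path, init_args=None):
    """Hypothetical helper: import `pkg.module.ClassName` and instantiate it."""
    # Split "pkg.module.ClassName" into the module path and the class name.
    module_name, _, class_name = class_path.rpartition(".")
    module = importlib.import_module(module_name)
    cls = getattr(module, class_name)
    # Pass the optional keyword arguments to the constructor.
    return cls(**(init_args or {}))


# Stand-in example using a standard-library class instead of a sampler.
value = instantiate_from_class_path(
    "fractions.Fraction", {"numerator": 3, "denominator": 4}
)
print(value)  # 3/4
```

With a mechanism like this, switching samplers needs only a different string in the configuration or on the command line, with no code changes.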
Here are the results of a run on a simple regression-test dataset.
*seed: 1234, deterministic: True. (The example below was produced with a random sampler driven by np.Generator, so it may differ from the current implementation, which uses torch.Generator.)
Note that the current regression-test dataset is not a typical dataset and its classes are not imbalanced, so more experiments are needed.
For now, the default therefore remains the same RandomSampler as before.
How to test
Checklist
- [ ] I have added unit tests to cover my changes.
- [ ] I have added integration tests to cover my changes.
- [ ] I have added e2e tests for validation.
- [ ] I have added the description of my changes into CHANGELOG in my target branch (e.g., CHANGELOG in develop).
- [ ] I have updated the documentation in my target branch accordingly (e.g., documentation in develop).
- [ ] I have linked related issues.
License
- [ ] I submit my code changes under the same Apache License that covers the project. Feel free to contact the maintainers if that's a concern.
- [ ] I have updated the license header for each file (see an example below).
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0