Cannot use builtin datasets for detection training
Bug description
I've modified the reference (PyTorch) training code to use the SVHN dataset.
That did not work (see the traceback below).
By looking at what the DetectionDataset class does, and with some trial and error, I managed to get it working (or at least running) by restructuring the targets from the dataset as follows:
targets = [{"words": t} for t in targets]
I don't know if this is a "valid" fix, though.
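For context, here is roughly where that wrapping ends up in the training loop (a sketch, assuming the built-in dataset yields plain numpy box arrays as targets, which is what the traceback below suggests):

```python
# Built-in datasets such as SVHN yield (image, boxes) samples where boxes is a
# plain float32 numpy array, while the detection models expect a list of dicts
# mapping a class name (here "words") to such an array.
for images, targets in train_loader:
    targets = [{"words": t} for t in targets]  # the wrapping from above
    train_loss = model(images, targets)["loss"]
```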
Code snippet to reproduce the bug
See https://gist.github.com/KenjiTakahashi/9bb22093d584bb2b203eb003a2bbb414.
As mentioned, this is mostly the same code as
https://github.com/mindee/doctr/blob/e6bf82d6a74a52cedac17108e596b9265c4e43c5/references/detection/train_pytorch.py
with slight modifications to work with the SVHN class instead of DetectionDataset.
Error traceback
Traceback (most recent call last): | 0/16701 [00:00<?, ?it/s]
File "<string>", line 1, in <module>
File "ocr/src/ocr/main2.py", line 499, in main
_main(args)
File "ocr/src/ocr/main2.py", line 405, in _main
fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, amp=args.amp)
File "ocr/src/ocr/main2.py", line 126, in fit_one_epoch
train_loss = model(images, targets)["loss"]
^^^^^^^^^^^^^^^^^^^^^^
File "ocr/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "ocr/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "ocr/.venv/lib/python3.12/site-packages/doctr/models/detection/fast/pytorch.py", line 208, in forward
loss = self.compute_loss(logits, target)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "ocr/.venv/lib/python3.12/site-packages/doctr/models/detection/fast/pytorch.py", line 231, in compute_loss
targets = self.build_target(target, out_map.shape[1:], False) # type: ignore[arg-type]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "ocr/.venv/lib/python3.12/site-packages/doctr/models/detection/fast/base.py", line 177, in build_target
if any(t.dtype != np.float32 for tgt in target for t in tgt.values()):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "ocr/.venv/lib/python3.12/site-packages/doctr/models/detection/fast/base.py", line 177, in <genexpr>
if any(t.dtype != np.float32 for tgt in target for t in tgt.values()):
^^^^^^^^^^
AttributeError: 'numpy.ndarray' object has no attribute 'values'
Environment
This script does not seem to work well on (my) macOS; it mostly returns N/A's. Anyway, I am currently running on macOS 14 with the latest doctr (I tried both the 0.10 release and master at e6bf82d) and PyTorch.
Deep Learning backend
I'm using PyTorch, but the same problem happens with TF as well.
Hi @KenjiTakahashi :wave:,
Yeah, you are right: currently only the validation scripts provide direct support for the built-in datasets, so your "quick fix" is valid.
Support for the built-in datasets is something we could add; especially for pretraining on SynthText it would be useful.
The issue here is SVHN, because it contains really small image crops of house numbers, and currently we resize everything to 1024x1024 while keeping the aspect ratio and using symmetric padding, which is well chosen for the document case.
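In code, that resize corresponds roughly to the following (a sketch; the exact arguments in the reference script may differ):

```python
from doctr import transforms as T

# Fixed 1024x1024 target size, keeping the aspect ratio and padding
# symmetrically: well suited to document pages, but tiny SVHN crops
# end up heavily padded/upscaled.
resize = T.Resize((1024, 1024), preserve_aspect_ratio=True, symmetric_pad=True)
```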
Have you tried any of the other built-in datasets? It should work with every other one (only SVHN is a bit "special").
Best regards, Felix :)
@felixdittrich92,
I am working on this; if you have any reference or any further info regarding this issue, it would be helpful.
thanks :)
Hi @sarjil77 👋,
So the idea is that people can use the built-in datasets we provide (only the ones which are directly available via download, where we don't need to provide the path to an images/labels file/folder) for the detection and recognition training.
I would suggest starting with the recognition part; that should be easier.
Currently we have these train options:
- from local files
- to use our synth word generator
parser.add_argument("--train_path", type=str, default=None, help="path to train data folder(s)")
parser.add_argument("--val_path", type=str, default=None, help="path to val data folder")
parser.add_argument(
"--train-samples",
type=int,
default=1000,
help="Multiplied by the vocab length gets you the number of synthetic training samples that will be used.",
)
parser.add_argument(
"--val-samples",
type=int,
default=20,
help="Multiplied by the vocab length gets you the number of synthetic validation samples that will be used.",
)
parser.add_argument(
"--font", type=str, default="FreeMono.ttf,FreeSans.ttf,FreeSerif.ttf", help="Font family to be used"
)
parser.add_argument("--min-chars", type=int, default=1, help="Minimum number of characters per synthetic sample")
parser.add_argument("--max-chars", type=int, default=12, help="Maximum number of characters per synthetic sample")
A third option then would be something like:
parser.add_argument(
"--val-datasets",
type=str,
nargs="+",
choices=["FUNSD", "CORD", "dataset3", ...],
default=None,
help="Builtin dataset names (choose from: dataset1, dataset2, dataset3)",
)
The same applies to --train-datasets.
Then modify: https://github.com/mindee/doctr/blob/3e213203a1c5c82d87cffc6a664a7f4274ffc82c/references/recognition/train_pytorch.py#L214 and https://github.com/mindee/doctr/blob/3e213203a1c5c82d87cffc6a664a7f4274ffc82c/references/recognition/train_pytorch.py#L287
If multiple built-in datasets are passed, we init the first one passed and extend its data with the other ones, like here:
https://github.com/mindee/doctr/blob/3e213203a1c5c82d87cffc6a664a7f4274ffc82c/references/recognition/evaluate_pytorch.py#L95
For detection it looks mostly similar, but it requires a little trick to transform the data format... so let's do 2 PRs, starting with the recognition one.
Does this help? :)
@felixdittrich92,
Yeah, this helps; I have already started working on it.
thanks :)
Hey @felixdittrich92,
I made some changes and was able to load multiple built-in datasets, but I am facing some issues.
using:
python train_pytorch.py crnn_vgg16_bn --train_datasets "FUNSD" "SVHN" --val_datasets "FUNSD" "SVHN"
Namespace(arch='crnn_vgg16_bn', output_dir='.', train_path=None, val_path=None, train_samples=1000, val_samples=20, train_datasets=['FUNSD', 'SVHN'], val_datasets=['FUNSD', 'SVHN'], font='FreeMono.ttf,FreeSans.ttf,FreeSerif.ttf', min_chars=1, max_chars=12, name=None, epochs=10, batch_size=64, device=None, input_size=32, lr=0.001, weight_decay=0, workers=None, resume=None, vocab='french', test_only=False, freeze_backbone=False, show_samples=False, wb=False, clearml=False, push_to_hub=False, pretrained=False, optim='adam', sched='cosine', amp=False, find_lr=False, early_stop=False, early_stop_epochs=5, early_stop_delta=0.01)
Preparing and Loading FUNSD: 100%|████████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:00<00:00, 2828.26it/s]
Preparing and Loading SVHN: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 33402/33402 [01:14<00:00, 451.31it/s]
Validation set loaded in 75.68s (33551 samples in 525 batches)
/data/aiuserinj/sarjil/doctr/doctr/models/utils/pytorch.py:62: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
state_dict = torch.load(archive_path, map_location="cpu")
Preparing and Loading FUNSD: 100%|████████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:00<00:00, 2570.29it/s]
Preparing and Loading SVHN: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 33402/33402 [01:11<00:00, 469.69it/s]
Train set loaded in 72.65s (33551 samples in 524 batches)
0%| | 0/524 [00:01<?, ?it/s]
Traceback (most recent call last):
File "/data/aiuserinj/sarjil/doctr/references/recognition/train_pytorch.py", line 623, in <module>
main(args)
File "/data/aiuserinj/sarjil/doctr/references/recognition/train_pytorch.py", line 501, in main
train_loss, actual_lr = fit_one_epoch(
^^^^^^^^^^^^^^
File "/data/aiuserinj/sarjil/doctr/references/recognition/train_pytorch.py", line 126, in fit_one_epoch
for images, targets in pbar:
File "/data/aiuserinj/anaconda3/envs/sarjil_dev/lib/python3.11/site-packages/tqdm/std.py", line 1181, in __iter__
for obj in iterable:
File "/data/aiuserinj/anaconda3/envs/sarjil_dev/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 630, in __next__
data = self._next_data()
^^^^^^^^^^^^^^^^^
File "/data/aiuserinj/anaconda3/envs/sarjil_dev/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 1344, in _next_data
return self._process_data(data)
^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/aiuserinj/anaconda3/envs/sarjil_dev/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 1370, in _process_data
data.reraise()
File "/data/aiuserinj/anaconda3/envs/sarjil_dev/lib/python3.11/site-packages/torch/_utils.py", line 706, in reraise
raise exception
RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
File "/data/aiuserinj/anaconda3/envs/sarjil_dev/lib/python3.11/site-packages/torch/utils/data/_utils/worker.py", line 309, in _worker_loop
data = fetcher.fetch(index) # type: ignore[possibly-undefined]
^^^^^^^^^^^^^^^^^^^^
File "/data/aiuserinj/anaconda3/envs/sarjil_dev/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py", line 55, in fetch
return self.collate_fn(data)
^^^^^^^^^^^^^^^^^^^^^
File "/data/aiuserinj/sarjil/doctr/doctr/datasets/datasets/pytorch.py", line 53, in collate_fn
images = torch.stack(images, dim=0)
^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: stack expects each tensor to be equal size, but got [3, 64, 139] at entry 0 and [3, 76, 160] at entry 1
So should I pad here, or resize, or is there another way you'd suggest?
Yeah, at a minimum we need to apply the Resize transform like here: https://github.com/mindee/doctr/blob/3e213203a1c5c82d87cffc6a664a7f4274ffc82c/references/recognition/train_pytorch.py#L232 :)
Addition:
It should be enough to define the img_transforms on the first initial dataset, because we only extend it with the following ones (if more than one is given).
Pseudo code (untested):
# Assumed imports, matching the reference script (untested, as noted above)
from doctr import datasets
from doctr import transforms as T
from torchvision.transforms import Compose

train_datasets = args.train_datasets

# Initialize the first dataset with the image transforms ...
train_set = datasets.__dict__[train_datasets[0]](
    train=True,
    download=True,
    recognition_task=True,
    use_polygons=True,
    img_transforms=Compose([
        T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True),
        # Augmentations
        T.RandomApply(T.ColorInversion(), 0.1),
    ]),
)
# ... and extend its data with the samples of every further dataset
if len(train_datasets) > 1:
    for dataset_name in train_datasets[1:]:
        _ds = datasets.__dict__[dataset_name](
            train=True,
            download=True,
            recognition_task=True,
            use_polygons=True,
        )
        train_set.data.extend((np_img, target) for np_img, target in _ds.data)
And for the val_set the same, but with train=False and img_transforms=T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True).
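A minimal sketch of that validation setup (same assumed imports as the training snippet above; extended with further datasets in the same way as train_set):

```python
val_set = datasets.__dict__[args.val_datasets[0]](
    train=False,
    download=True,
    recognition_task=True,
    use_polygons=True,
    # No augmentations for validation, just the fixed-ratio resize
    img_transforms=T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True),
)
```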
- [x] Recognition scripts #1904
- [ ] Detection scripts
- [ ] Optional: Built-in datasets which require local files (IMGUR5K, COCOTEXT, IC13, MySynth, WILDRECEIPT, IIITHWS)