
Cannot use builtin datasets for detection training

Open KenjiTakahashi opened this issue 11 months ago • 8 comments

Bug description

I've modified the reference (PyTorch) training code to use the SVHN dataset. That did not work (see traceback). By looking at what the DetectionDataset class does, plus some trial and error, I managed to get it working (or at least running) by restructuring the targets from the dataset as follows:

targets = [{"words": t} for t in targets]

I don't know if this is a "valid" fix, though.
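
For reference, a minimal sketch of the format mismatch (the array values here are made up; the "words" key is what the detection target format appears to expect, judging from the build_target code in the traceback below):

import numpy as np

# What SVHN yields per sample: a bare (N, 4) float32 array of relative boxes
svhn_target = np.array([[0.1, 0.2, 0.3, 0.4]], dtype=np.float32)

# What the detection models expect: a dict mapping a class name to its boxes
detection_target = {"words": svhn_target}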

Code snippet to reproduce the bug

See https://gist.github.com/KenjiTakahashi/9bb22093d584bb2b203eb003a2bbb414.

As mentioned, this is mostly the same code as https://github.com/mindee/doctr/blob/e6bf82d6a74a52cedac17108e596b9265c4e43c5/references/detection/train_pytorch.py, with slight modifications to use the SVHN class instead of DetectionDataset.

Error traceback

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "ocr/src/ocr/main2.py", line 499, in main
    _main(args)
  File "ocr/src/ocr/main2.py", line 405, in _main
    fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, amp=args.amp)
  File "ocr/src/ocr/main2.py", line 126, in fit_one_epoch
    train_loss = model(images, targets)["loss"]
                 ^^^^^^^^^^^^^^^^^^^^^^
  File "ocr/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "ocr/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "ocr/.venv/lib/python3.12/site-packages/doctr/models/detection/fast/pytorch.py", line 208, in forward
    loss = self.compute_loss(logits, target)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "ocr/.venv/lib/python3.12/site-packages/doctr/models/detection/fast/pytorch.py", line 231, in compute_loss
    targets = self.build_target(target, out_map.shape[1:], False)  # type: ignore[arg-type]
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "ocr/.venv/lib/python3.12/site-packages/doctr/models/detection/fast/base.py", line 177, in build_target
    if any(t.dtype != np.float32 for tgt in target for t in tgt.values()):
       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "ocr/.venv/lib/python3.12/site-packages/doctr/models/detection/fast/base.py", line 177, in <genexpr>
    if any(t.dtype != np.float32 for tgt in target for t in tgt.values()):
                                                            ^^^^^^^^^^
AttributeError: 'numpy.ndarray' object has no attribute 'values'

Environment

This script does not seem to work well on (my) macOS; it mostly returns N/A's. In any case, I currently run it on macOS 14 with the latest doctr (tried both the 0.10 release and master at e6bf82d) and the latest PyTorch.

Deep Learning backend

I'm using PyTorch, but the same problem happens on TF as well.

KenjiTakahashi avatar Jan 05 '25 17:01 KenjiTakahashi

Hi @KenjiTakahashi :wave:,

Yeah, you are right: currently only the validation scripts provide direct support for the built-in datasets, so your "quick fix" is valid.

Support for the built-in datasets is something we could add; it would be especially useful for pretraining on SynthText.

The issue here is SVHN specifically, because it contains really small image crops of house numbers, while we currently resize everything to 1024x1024, keeping the aspect ratio and padding symmetrically, which is well chosen for the document case.
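
For illustration, the detection preprocessing is roughly equivalent to the following (a sketch using doctr's Resize transform; the exact call site in the pipeline may differ):

from doctr import transforms as T

# Resize to the detection input size while keeping the aspect ratio,
# padding the shorter side symmetrically
resize = T.Resize((1024, 1024), preserve_aspect_ratio=True, symmetric_pad=True)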

Have you tried any of the other built-in datasets? It should work with every other one (only SVHN is a bit "special").

Best regards, Felix :)

felixdittrich92 avatar Jan 06 '25 07:01 felixdittrich92

@felixdittrich92,

i am working on this, if you have any reference or any further info regarding this issue would be helpful.

thanks :)

sarjil77 avatar Mar 12 '25 17:03 sarjil77

Hi @sarjil77 👋,

So the idea is that people can use the builtin datasets we provide (only the ones that are directly available via download, where we don't need to provide the path to an images/labels file/folder) for detection and recognition training.

I would suggest starting with recognition, which should be easier.

Currently we have two training options:

  • from local files
  • using our synth word generator

with the corresponding arguments:
parser.add_argument("--train_path", type=str, default=None, help="path to train data folder(s)")
    parser.add_argument("--val_path", type=str, default=None, help="path to val data folder")
    parser.add_argument(
        "--train-samples",
        type=int,
        default=1000,
        help="Multiplied by the vocab length gets you the number of synthetic training samples that will be used.",
    )
    parser.add_argument(
        "--val-samples",
        type=int,
        default=20,
        help="Multiplied by the vocab length gets you the number of synthetic validation samples that will be used.",
    )
    parser.add_argument(
        "--font", type=str, default="FreeMono.ttf,FreeSans.ttf,FreeSerif.ttf", help="Font family to be used"
    )
    parser.add_argument("--min-chars", type=int, default=1, help="Minimum number of characters per synthetic sample")
    parser.add_argument("--max-chars", type=int, default=12, help="Maximum number of characters per synthetic sample")

A third option then would be something like:

parser.add_argument(
    "--val-datasets",
    type=str,
    nargs="+",
    choices=["FUNSD", "CORD", "dataset3", ...],
    default=None,
    help="Builtin dataset names (choose from: dataset1, dataset2, dataset3)",
)

Same for the --train-datasets
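
A sketch of that mirrored flag (same placeholder choices as above; untested):

parser.add_argument(
    "--train-datasets",
    type=str,
    nargs="+",
    choices=["FUNSD", "CORD", ...],
    default=None,
    help="Builtin dataset names to train on",
)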

Then modify: https://github.com/mindee/doctr/blob/3e213203a1c5c82d87cffc6a664a7f4274ffc82c/references/recognition/train_pytorch.py#L214 and https://github.com/mindee/doctr/blob/3e213203a1c5c82d87cffc6a664a7f4274ffc82c/references/recognition/train_pytorch.py#L287

If multiple builtin datasets are passed, we init the first one and extend its data with the data from the other ones, like here:

https://github.com/mindee/doctr/blob/3e213203a1c5c82d87cffc6a664a7f4274ffc82c/references/recognition/evaluate_pytorch.py#L95

For detection it looks mostly similar, but it requires a little trick to transform the data format ... so let's do 2 PRs, starting with recognition.

Does this help? :)

felixdittrich92 avatar Mar 13 '25 09:03 felixdittrich92

@felixdittrich92,

Yeah, this helps; I have already started working on it.

thanks :)

sarjil77 avatar Mar 14 '25 10:03 sarjil77

hey @felixdittrich92,

I made some changes and was able to load multiple builtin datasets, but I am facing an issue.

using:

python train_pytorch.py crnn_vgg16_bn --train_datasets "FUNSD" "SVHN" --val_datasets "FUNSD" "SVHN"

Namespace(arch='crnn_vgg16_bn', output_dir='.', train_path=None, val_path=None, train_samples=1000, val_samples=20, train_datasets=['FUNSD', 'SVHN'], val_datasets=['FUNSD', 'SVHN'], font='FreeMono.ttf,FreeSans.ttf,FreeSerif.ttf', min_chars=1, max_chars=12, name=None, epochs=10, batch_size=64, device=None, input_size=32, lr=0.001, weight_decay=0, workers=None, resume=None, vocab='french', test_only=False, freeze_backbone=False, show_samples=False, wb=False, clearml=False, push_to_hub=False, pretrained=False, optim='adam', sched='cosine', amp=False, find_lr=False, early_stop=False, early_stop_epochs=5, early_stop_delta=0.01)
Preparing and Loading FUNSD: 100%|████████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:00<00:00, 2828.26it/s]
Preparing and Loading SVHN: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 33402/33402 [01:14<00:00, 451.31it/s]
Validation set loaded in 75.68s (33551 samples in 525 batches)
/data/aiuserinj/sarjil/doctr/doctr/models/utils/pytorch.py:62: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  state_dict = torch.load(archive_path, map_location="cpu")
Preparing and Loading FUNSD: 100%|████████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:00<00:00, 2570.29it/s]
Preparing and Loading SVHN: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 33402/33402 [01:11<00:00, 469.69it/s]
Train set loaded in 72.65s (33551 samples in 524 batches)
  0%|                                                                                                                                     | 0/524 [00:01<?, ?it/s]
Traceback (most recent call last):
  File "/data/aiuserinj/sarjil/doctr/references/recognition/train_pytorch.py", line 623, in <module>
    main(args)
  File "/data/aiuserinj/sarjil/doctr/references/recognition/train_pytorch.py", line 501, in main
    train_loss, actual_lr = fit_one_epoch(
                            ^^^^^^^^^^^^^^
  File "/data/aiuserinj/sarjil/doctr/references/recognition/train_pytorch.py", line 126, in fit_one_epoch
    for images, targets in pbar:
  File "/data/aiuserinj/anaconda3/envs/sarjil_dev/lib/python3.11/site-packages/tqdm/std.py", line 1181, in __iter__
    for obj in iterable:
  File "/data/aiuserinj/anaconda3/envs/sarjil_dev/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 630, in __next__
    data = self._next_data()
           ^^^^^^^^^^^^^^^^^
  File "/data/aiuserinj/anaconda3/envs/sarjil_dev/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 1344, in _next_data
    return self._process_data(data)
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/aiuserinj/anaconda3/envs/sarjil_dev/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 1370, in _process_data
    data.reraise()
  File "/data/aiuserinj/anaconda3/envs/sarjil_dev/lib/python3.11/site-packages/torch/_utils.py", line 706, in reraise
    raise exception
RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/data/aiuserinj/anaconda3/envs/sarjil_dev/lib/python3.11/site-packages/torch/utils/data/_utils/worker.py", line 309, in _worker_loop
    data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
           ^^^^^^^^^^^^^^^^^^^^
  File "/data/aiuserinj/anaconda3/envs/sarjil_dev/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py", line 55, in fetch
    return self.collate_fn(data)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/data/aiuserinj/sarjil/doctr/doctr/datasets/datasets/pytorch.py", line 53, in collate_fn
    images = torch.stack(images, dim=0)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: stack expects each tensor to be equal size, but got [3, 64, 139] at entry 0 and [3, 76, 160] at entry 1

So should I pad here, or resize, or is there another way you'd suggest?

sarjil77 avatar Mar 16 '25 20:03 sarjil77

> So should I pad here, or resize, or is there another way you'd suggest?

Yeah, at a minimum we need to apply the Resize transform like here: https://github.com/mindee/doctr/blob/3e213203a1c5c82d87cffc6a664a7f4274ffc82c/references/recognition/train_pytorch.py#L232 :)

felixdittrich92 avatar Mar 16 '25 20:03 felixdittrich92

Addition: It should be enough to define the img_transforms on the first initial dataset, because we only extend it with the following ones (if more than one is given).

Pseudo code (untested):

# Assumed imports, matching the reference training scripts
from torchvision.transforms import Compose

from doctr import datasets
from doctr import transforms as T

train_datasets = args.train_datasets
# Build the first dataset with the image transforms attached
train_set = datasets.__dict__[train_datasets[0]](
    train=True,
    download=True,
    recognition_task=True,
    use_polygons=True,
    img_transforms=Compose([
        T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True),
        # Augmentations
        T.RandomApply(T.ColorInversion(), 0.1),
    ]),
)
# Extend its data with every additional dataset (no transforms needed there)
if len(train_datasets) > 1:
    for dataset_name in train_datasets[1:]:
        _ds = datasets.__dict__[dataset_name](
            train=True,
            download=True,
            recognition_task=True,
            use_polygons=True,
        )
        train_set.data.extend((np_img, target) for np_img, target in _ds.data)

And for the val_set, do the same but with train=False and img_transforms=T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True).
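
A corresponding sketch for the validation side (same caveat, untested):

val_set = datasets.__dict__[args.val_datasets[0]](
    train=False,
    download=True,
    recognition_task=True,
    use_polygons=True,
    img_transforms=T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True),
)
if len(args.val_datasets) > 1:
    for dataset_name in args.val_datasets[1:]:
        _ds = datasets.__dict__[dataset_name](
            train=False,
            download=True,
            recognition_task=True,
            use_polygons=True,
        )
        val_set.data.extend((np_img, target) for np_img, target in _ds.data)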

felixdittrich92 avatar Mar 17 '25 09:03 felixdittrich92

  • [x] Recognition scripts #1904
  • [ ] Detection scripts
  • [ ] Optional: Built-in datasets which require local files (IMGUR5K, COCOTEXT, IC13, MJSynth, WILDRECEIPT, IIITHWS)

felixdittrich92 avatar Mar 26 '25 13:03 felixdittrich92