anomalib

How to set a threshold at inference time?

Open alevangel opened this issue 2 years ago • 7 comments

I would like to set a threshold at test/inference time, so that I can tune the sensitivity of the model.

How would it be possible to do that?

alevangel, Sep 13 '22 08:09

You can set the thresholds before inference:

    model.image_threshold.value = ...
    model.pixel_threshold.value = ...

Please be aware that these are unnormalized values.
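To make "unnormalized" concrete, here is a minimal sketch. It assumes a min-max normalization in which the raw threshold maps to 0.5, scores are scaled by the (max - min) range seen during validation, and the result is clipped to [0, 1]. This is my understanding of anomalib's min-max scheme, not something confirmed in this thread. The helper names min_max_normalize and raw_threshold_for_level are hypothetical. The second helper inverts the mapping, so you can pick a sensitivity on the normalized scale and get the raw value you would assign to model.image_threshold.value.

```python
def min_max_normalize(score: float, threshold: float, min_val: float, max_val: float) -> float:
    """Map a raw score to [0, 1] so that the raw threshold lands on 0.5 (assumed scheme)."""
    normalized = (score - threshold) / (max_val - min_val) + 0.5
    # Clip to the [0, 1] range
    return min(max(normalized, 0.0), 1.0)


def raw_threshold_for_level(level: float, threshold: float, min_val: float, max_val: float) -> float:
    """Invert the mapping: the raw score whose normalized value equals `level` (0 <= level <= 1)."""
    return threshold + (level - 0.5) * (max_val - min_val)


if __name__ == "__main__":
    # Hypothetical statistics from a validation run
    threshold, min_val, max_val = 25.0, 10.0, 50.0
    less_sensitive = raw_threshold_for_level(0.75, threshold, min_val, max_val)
    print(less_sensitive)  # 35.0
    print(min_max_normalize(less_sensitive, threshold, min_val, max_val))  # 0.75
```

Raising the raw threshold makes the model less sensitive (fewer images flagged as anomalous). Lowering it makes the model more sensitive.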

alexriedel1, Sep 14 '22 10:09

@alexriedel1 This is not enough. If I set model.image_threshold.value = ... and model.pixel_threshold.value = ..., the values are overwritten by trainer.predict(model=model, dataloaders=[dataloader]), which also overwrites the model itself.

alevangel, Sep 14 '22 15:09

Can you show me your testing/inference script? I'm pretty sure calling trainer.predict does not affect the model.

As there is no predict method implemented in the anomalib models, the default Lightning predict step will be called: https://pytorch-lightning.readthedocs.io/en/stable/deploy/production_basic.html#predict-step-with-your-lightningmodule

alexriedel1, Sep 14 '22 16:09

@alexriedel1 This is the full implementation of my script for inference on PatchCore

    import torch
    from pytorch_lightning import Trainer
    from torch.utils.data import DataLoader

    from anomalib.data.inference import InferenceDataset
    from anomalib.models import get_model
    from anomalib.utils.callbacks import get_callbacks

    # config, my_args and CUSTOM_NUMBER are defined elsewhere in the script
    model = get_model(config)
    callbacks = get_callbacks(config)

    trainer = Trainer(callbacks=callbacks, **config.trainer)

    # Set custom thresholds
    model.adaptive_threshold = False
    model.pixel_threshold.value = torch.tensor(float(CUSTOM_NUMBER))
    model.image_threshold.value = torch.tensor(float(CUSTOM_NUMBER))

    transform_config = config.dataset.transform_config.val if "transform_config" in config.dataset.keys() else None
    dataset = InferenceDataset(
        my_args['input'], image_size=tuple(config.dataset.image_size), transform_config=transform_config
    )
    dataloader = DataLoader(dataset)
    trainer.predict(model=model, dataloaders=[dataloader])

At the end of this script, my model.image_threshold.value is back to the one tuned at training time.

alevangel, Sep 14 '22 16:09

I cannot reproduce the issue with a PatchCore model:

    model.adaptive_threshold = False
    model.pixel_threshold.value = torch.tensor(float(1))
    model.image_threshold.value = torch.tensor(float(1))

    print(model.pixel_threshold.value, model.image_threshold.value)

    trainer.predict(model=model, dataloaders=datamodule)

    print(model.pixel_threshold.value, model.image_threshold.value)
    tensor(1.) tensor(1.)
    LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
    Predicting DataLoader 0: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, -4.90it/s]
    tensor(1.) tensor(1.)

alexriedel1, Sep 14 '22 16:09

@alexriedel1 When I print the custom thresholds with print(model.pixel_threshold.value, model.image_threshold.value) before and after trainer.predict(), this is what I get (warnings included):

    tensor(100.) tensor(100.)
    \venv\lib\site-packages\torchmetrics\utilities\prints.py:36: UserWarning: Metric `ROC` will save all targets and predictions in buffer. For large datasets this may lead to large memory footprint.
      warnings.warn(*args, **kwargs)
    \venv\lib\site-packages\torchmetrics\utilities\prints.py:36: UserWarning: Torchmetrics v0.9 introduced a new argument class property called `full_state_update` that has
                    not been set for this class (MinMax). The property determines if `update` by
                    default needs access to the full metric state. If this is not the case, significant speedups can be
                    achieved and we recommend setting this to `False`.
                    We provide an checking function
                    `from torchmetrics.utilities import check_forward_no_full_state`
                    that can be used to check if the `full_state_update=True` (old and potential slower behaviour,
                    default for now) or if `full_state_update=False` can be used safely.

      warnings.warn(*args, **kwargs)
    \venv\lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:240: PossibleUserWarning: The dataloader, predict_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
      rank_zero_warn(
    Predicting DataLoader 0: 100%|██████████| 1/1 [00:00<00:00,  1.31it/s]
    tensor(39.0126) tensor(25.4964)

This is the pip freeze output of the environment:

absl-py==1.2.0
aiohttp==3.8.1
aiosignal==1.2.0
albumentations==1.2.1
analytics-python==1.4.0
-e git+https://github.com/openvinotoolkit/anomalib@f4903525c05cb8843d3e1fd7036f5d59eac36d79#egg=anomalib
antlr4-python3-runtime==4.9.3
anyio==3.6.1
async-timeout==4.0.2
attrs==22.1.0
backoff==1.10.0
bcrypt==4.0.0
cachetools==5.2.0
certifi==2022.6.15
cffi==1.15.1
charset-normalizer==2.1.1
click==8.1.3
colorama==0.4.5
cryptography==37.0.4
cycler==0.11.0
docker-pycreds==0.4.0
docstring-parser==0.14.1
einops==0.4.1
fastapi==0.81.0
ffmpy==0.3.0
Flask==2.2.2
fonttools==4.37.1
frozenlist==1.3.1
fsspec==2022.8.2
gitdb==4.0.9
GitPython==3.1.27
google-auth==2.11.0
google-auth-oauthlib==0.4.6
gradio==3.2
grpcio==1.47.0
h11==0.12.0
httpcore==0.15.0
httpx==0.23.0
idna==3.3
imageio==2.21.2
imgaug==0.4.0
importlib-metadata==4.12.0
itsdangerous==2.1.2
Jinja2==3.1.2
joblib==1.1.0
jsonargparse==4.13.2
kiwisolver==1.4.4
kornia==0.6.7
linkify-it-py==1.0.3
Markdown==3.4.1
markdown-it-py==2.1.0
MarkupSafe==2.1.1
matplotlib==3.5.3
mdit-py-plugins==0.3.0
mdurl==0.1.2
monotonic==1.6
multidict==6.0.2
networkx==2.8.6
numpy==1.23.2
oauthlib==3.2.0
omegaconf==2.2.3
opencv-python==4.6.0.66
opencv-python-headless==4.6.0.66
orjson==3.8.0
packaging==21.3
pandas==1.4.4
paramiko==2.11.0
pathtools==0.1.2
Pillow==9.2.0
promise==2.3
protobuf==3.19.4
psutil==5.9.1
pyasn1==0.4.8
pyasn1-modules==0.2.8
pycparser==2.21
pycryptodome==3.15.0
pydantic==1.10.1
pyDeprecate==0.3.2
pydub==0.25.1
PyNaCl==1.5.0
pyparsing==3.0.9
python-dateutil==2.8.2
python-multipart==0.0.5
pytorch-lightning==1.6.5
pytz==2022.2.1
PyWavelets==1.3.0
PyYAML==6.0
qudida==0.0.4
requests==2.28.1
requests-oauthlib==1.3.1
rfc3986==1.5.0
rsa==4.9
scikit-image==0.19.3
scikit-learn==1.1.2
scipy==1.9.1
sentry-sdk==1.9.6
setproctitle==1.3.2
Shapely==1.8.4
shortuuid==1.0.9
six==1.16.0
smmap==5.0.0
sniffio==1.2.0
starlette==0.19.1
tensorboard==2.10.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
threadpoolctl==3.1.0
tifffile==2022.8.12
timm==0.5.4
torch==1.11.0
torchmetrics==0.9.1
torchtext==0.12.0
torchvision==0.12.0
tqdm==4.64.0
typing_extensions==4.3.0
uc-micro-py==1.0.1
urllib3==1.26.12
uvicorn==0.18.3
wandb==0.12.17
websockets==10.3
Werkzeug==2.2.2
yarl==1.8.1
zipp==3.8.1

alevangel, Sep 14 '22 18:09

@alevangel The thresholds are overwritten when on_predict_start is called from LoadModelCallback. One idea would be to manually call torch.load and pop LoadModelCallback from the callbacks list. Then updating the thresholds will work. This solution might work for your use case for now, but we will try to come up with a better design.

Here is a snippet of the change to the infer function of lightning_inference.py:

    import torch  # needed for the manual torch.load below


    def infer():
        """Run inference."""
        args = get_args()
        config = get_configurable_parameters(config_path=args.config)
        # This is commented out because setting it adds LoadModelCallback to the callback list
        # config.trainer.resume_from_checkpoint = str(args.weights)
        config.visualization.show_images = args.show
        config.visualization.mode = args.visualization_mode
        if args.output:  # overwrite save path
            config.visualization.save_images = True
            config.visualization.image_save_path = args.output
        else:
            config.visualization.save_images = False

        model = get_model(config)
        model.load_state_dict(torch.load(args.weights)["state_dict"])  # manually load weights
        callbacks = get_callbacks(config)

        trainer = Trainer(callbacks=callbacks, **config.trainer)

        transform_config = config.dataset.transform_config.val if "transform_config" in config.dataset.keys() else None
        dataset = InferenceDataset(
            args.input, image_size=tuple(config.dataset.image_size), transform_config=transform_config
        )
        dataloader = DataLoader(dataset)

        model.adaptive_threshold = False
        model.pixel_threshold.value = torch.tensor(float(1))
        model.image_threshold.value = torch.tensor(float(1))

        print(model.pixel_threshold.value, model.image_threshold.value)
        trainer.predict(model=model, dataloaders=[dataloader])
        print(model.pixel_threshold.value, model.image_threshold.value)
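The alternative mentioned above is to keep the callback list but pop LoadModelCallback from it. It can be illustrated with a self-contained mock. Everything in it is a stand-in: FakeModel, the simplified LoadModelCallback and run_predict_hooks only imitate what the thread describes (the callback restoring checkpoint values when prediction starts). They are not anomalib or Lightning code. The without_callback helper is hypothetical, and filtering by class name is an assumption chosen so that it needs no anomalib import.

```python
from dataclasses import dataclass


@dataclass
class FakeModel:
    """Stand-in for an anomalib model: only the two thresholds matter here."""
    image_threshold: float = 0.0
    pixel_threshold: float = 0.0


class LoadModelCallback:
    """Simplified stand-in: restores checkpoint values when prediction starts."""

    def __init__(self, checkpoint: dict):
        self.checkpoint = checkpoint

    def on_predict_start(self, model: FakeModel) -> None:
        # This is the step that silently replaces any thresholds set by hand
        model.image_threshold = self.checkpoint["image_threshold"]
        model.pixel_threshold = self.checkpoint["pixel_threshold"]


def run_predict_hooks(model: FakeModel, callbacks: list) -> tuple:
    """Rough imitation of the trainer: fire every on_predict_start hook, then report thresholds."""
    for callback in callbacks:
        hook = getattr(callback, "on_predict_start", None)
        if hook is not None:
            hook(model)
    return model.image_threshold, model.pixel_threshold


def without_callback(callbacks: list, class_name: str = "LoadModelCallback") -> list:
    """Return a new list with every callback of the given class name removed."""
    return [cb for cb in callbacks if type(cb).__name__ != class_name]


if __name__ == "__main__":
    checkpoint = {"image_threshold": 25.0, "pixel_threshold": 39.0}  # hypothetical values
    callbacks = [LoadModelCallback(checkpoint)]

    model = FakeModel(image_threshold=100.0, pixel_threshold=100.0)
    print(run_predict_hooks(model, callbacks))  # (25.0, 39.0): manual values lost

    model = FakeModel(image_threshold=100.0, pixel_threshold=100.0)
    print(run_predict_hooks(model, without_callback(callbacks)))  # (100.0, 100.0)
```

Removing the callback also removes the automatic weight loading. The weights then have to be loaded manually, as the model.load_state_dict(torch.load(...)["state_dict"]) line in the snippet above does.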

ashwinvaidya17, Sep 22 '22 14:09

@ashwinvaidya17 Thanks, this works for avoiding the threshold overwrite. However, I had to add the following to make it work:

    # device is the torch.device used for inference (defined elsewhere in the script)
    model_state_dict = torch.load(my_args['weights'], map_location=device)["state_dict"]
    # With these keys present, load_state_dict() fails to load the state, so drop them
    model_state_dict.pop("normalization_metrics.min", None)
    model_state_dict.pop("normalization_metrics.max", None)
    model.load_state_dict(model_state_dict)

However, with this approach the heatmap does not look normalized at all, even though the predicted mask is good.
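One possible explanation for the unnormalized heatmap is that the two popped entries hold the min/max statistics that normalization relies on, and pop(..., None) simply throws them away. This is an assumption, not confirmed in this thread. Below is a minimal sketch of separating those entries instead of discarding them, using plain floats in place of tensors. split_normalization_stats and normalize_with_stats are hypothetical helpers, and the normalization formula (threshold mapped to 0.5) is an assumed min-max scheme. How to hand the statistics back to anomalib's normalization step depends on the version and is not shown.

```python
NORMALIZATION_KEYS = ("normalization_metrics.min", "normalization_metrics.max")


def split_normalization_stats(state_dict: dict) -> tuple:
    """Return (model_state, stats): the state dict without the normalization keys, plus those keys."""
    model_state = dict(state_dict)  # shallow copy, so the input dict is left untouched
    stats = {key: model_state.pop(key) for key in NORMALIZATION_KEYS if key in model_state}
    return model_state, stats


def normalize_with_stats(score: float, threshold: float, stats: dict) -> float:
    """Min-max normalize a raw score with the saved statistics (assumed scheme: threshold -> 0.5)."""
    min_val = stats["normalization_metrics.min"]
    max_val = stats["normalization_metrics.max"]
    normalized = (score - threshold) / (max_val - min_val) + 0.5
    return min(max(normalized, 0.0), 1.0)


if __name__ == "__main__":
    checkpoint_state = {
        "model.weight": [0.1, 0.2],  # hypothetical model parameter
        "normalization_metrics.min": 10.0,
        "normalization_metrics.max": 50.0,
    }
    model_state, stats = split_normalization_stats(checkpoint_state)
    print(sorted(model_state))  # ['model.weight']
    print(normalize_with_stats(35.0, 25.0, stats))  # 0.75
```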

alevangel, Sep 23 '22 09:09