anomalib
How to set a threshold at inference time?
I would like to set a threshold at test/inference time, so that I can tune the sensitivity of the model.
How would it be possible to do that?
You can set

model.image_threshold.value = ...
model.pixel_threshold.value = ...

before inference. Please be aware that these are unnormalized values.
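For example, a minimal sketch (the 0.5 values are placeholders; pick thresholds in the raw, unnormalized anomaly-score range of your model):

import torch

model.image_threshold.value = torch.tensor(0.5)  # placeholder: unnormalized image-level threshold
model.pixel_threshold.value = torch.tensor(0.5)  # placeholder: unnormalized pixel-level threshold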
@alexriedel1 this is not enough. If I set

model.image_threshold.value = ...
model.pixel_threshold.value = ...

the values get overwritten by

trainer.predict(model=model, dataloaders=[dataloader])

and this function also overwrites the model itself.
Can you show me your testing / inferencing script? I'm pretty sure calling trainer.predict does not affect the model.
As there is no predict method implemented in the models in anomalib, the default Lightning predict will be called: https://pytorch-lightning.readthedocs.io/en/stable/deploy/production_basic.html#predict-step-with-your-lightningmodule
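For context, Lightning's default predict_step is roughly a plain forward pass on each batch, so no threshold logic runs there:

# Lightning's default predict_step, roughly (see the linked docs):
def predict_step(self, batch, batch_idx, dataloader_idx=0):
    return self(batch)  # just the forward pass; thresholds are not touched here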
@alexriedel1 This is the full implementation of my script for inference on PatchCore
# imports assumed from anomalib's lightning_inference.py; config and my_args are defined earlier in my script
import torch
from torch.utils.data import DataLoader
from pytorch_lightning import Trainer
from anomalib.data.inference import InferenceDataset
from anomalib.models import get_model
from anomalib.utils.callbacks import get_callbacks

model = get_model(config)
callbacks = get_callbacks(config)
trainer = Trainer(callbacks=callbacks, **config.trainer)

# Set custom threshold (unnormalized values)
model.adaptive_threshold = False
model.pixel_threshold.value = torch.tensor(float(CUSTOM_NUMBER))
model.image_threshold.value = torch.tensor(float(CUSTOM_NUMBER))

transform_config = config.dataset.transform_config.val if "transform_config" in config.dataset.keys() else None
dataset = InferenceDataset(
    my_args['input'], image_size=tuple(config.dataset.image_size), transform_config=transform_config
)
dataloader = DataLoader(dataset)
trainer.predict(model=model, dataloaders=[dataloader])
At the end of this script, my model.image_threshold.value is back to the value tuned at training time.
I cannot reproduce the issue with a Patchcore model:
model.adaptive_threshold = False
model.pixel_threshold.value = torch.tensor(float(1))
model.image_threshold.value = torch.tensor(float(1))
print(model.pixel_threshold.value, model.image_threshold.value)
trainer.predict(model=model, dataloaders=datamodule)
print(model.pixel_threshold.value, model.image_threshold.value)
tensor(1.) tensor(1.)
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Predicting DataLoader 0: 100%|██████████| 3/3 [00:00<00:00, 4.90it/s]
tensor(1.) tensor(1.)
@alexriedel1 Printing the custom thresholds with print(model.pixel_threshold.value, model.image_threshold.value) before and after trainer.predict(), this is what I get (warnings included):
tensor(100.) tensor(100.)
\venv\lib\site-packages\torchmetrics\utilities\prints.py:36: UserWarning: Metric `ROC` will save all targets and predictions in buffer. For large datasets this may lead to large memory footprint.
warnings.warn(*args, **kwargs)
\venv\lib\site-packages\torchmetrics\utilities\prints.py:36: UserWarning: Torchmetrics v0.9 introduced a new argument class property called `full_state_update` that has
not been set for this class (MinMax). The property determines if `update` by
default needs access to the full metric state. If this is not the case, significant speedups can be
achieved and we recommend setting this to `False`.
We provide an checking function
`from torchmetrics.utilities import check_forward_no_full_state`
that can be used to check if the `full_state_update=True` (old and potential slower behaviour,
default for now) or if `full_state_update=False` can be used safely.
warnings.warn(*args, **kwargs)
\venv\lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:240: PossibleUserWarning: The dataloader, predict_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
rank_zero_warn(
Predicting DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 1.31it/s]
tensor(39.0126) tensor(25.4964)
This is a pip freeze of the env:
absl-py==1.2.0
aiohttp==3.8.1
aiosignal==1.2.0
albumentations==1.2.1
analytics-python==1.4.0
-e git+https://github.com/openvinotoolkit/anomalib@f4903525c05cb8843d3e1fd7036f5d59eac36d79#egg=anomalib
antlr4-python3-runtime==4.9.3
anyio==3.6.1
async-timeout==4.0.2
attrs==22.1.0
backoff==1.10.0
bcrypt==4.0.0
cachetools==5.2.0
certifi==2022.6.15
cffi==1.15.1
charset-normalizer==2.1.1
click==8.1.3
colorama==0.4.5
cryptography==37.0.4
cycler==0.11.0
docker-pycreds==0.4.0
docstring-parser==0.14.1
einops==0.4.1
fastapi==0.81.0
ffmpy==0.3.0
Flask==2.2.2
fonttools==4.37.1
frozenlist==1.3.1
fsspec==2022.8.2
gitdb==4.0.9
GitPython==3.1.27
google-auth==2.11.0
google-auth-oauthlib==0.4.6
gradio==3.2
grpcio==1.47.0
h11==0.12.0
httpcore==0.15.0
httpx==0.23.0
idna==3.3
imageio==2.21.2
imgaug==0.4.0
importlib-metadata==4.12.0
itsdangerous==2.1.2
Jinja2==3.1.2
joblib==1.1.0
jsonargparse==4.13.2
kiwisolver==1.4.4
kornia==0.6.7
linkify-it-py==1.0.3
Markdown==3.4.1
markdown-it-py==2.1.0
MarkupSafe==2.1.1
matplotlib==3.5.3
mdit-py-plugins==0.3.0
mdurl==0.1.2
monotonic==1.6
multidict==6.0.2
networkx==2.8.6
numpy==1.23.2
oauthlib==3.2.0
omegaconf==2.2.3
opencv-python==4.6.0.66
opencv-python-headless==4.6.0.66
orjson==3.8.0
packaging==21.3
pandas==1.4.4
paramiko==2.11.0
pathtools==0.1.2
Pillow==9.2.0
promise==2.3
protobuf==3.19.4
psutil==5.9.1
pyasn1==0.4.8
pyasn1-modules==0.2.8
pycparser==2.21
pycryptodome==3.15.0
pydantic==1.10.1
pyDeprecate==0.3.2
pydub==0.25.1
PyNaCl==1.5.0
pyparsing==3.0.9
python-dateutil==2.8.2
python-multipart==0.0.5
pytorch-lightning==1.6.5
pytz==2022.2.1
PyWavelets==1.3.0
PyYAML==6.0
qudida==0.0.4
requests==2.28.1
requests-oauthlib==1.3.1
rfc3986==1.5.0
rsa==4.9
scikit-image==0.19.3
scikit-learn==1.1.2
scipy==1.9.1
sentry-sdk==1.9.6
setproctitle==1.3.2
Shapely==1.8.4
shortuuid==1.0.9
six==1.16.0
smmap==5.0.0
sniffio==1.2.0
starlette==0.19.1
tensorboard==2.10.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
threadpoolctl==3.1.0
tifffile==2022.8.12
timm==0.5.4
torch==1.11.0
torchmetrics==0.9.1
torchtext==0.12.0
torchvision==0.12.0
tqdm==4.64.0
typing_extensions==4.3.0
uc-micro-py==1.0.1
urllib3==1.26.12
uvicorn==0.18.3
wandb==0.12.17
websockets==10.3
Werkzeug==2.2.2
yarl==1.8.1
zipp==3.8.1
@alevangel The thresholds are overwritten when on_predict_start is called from LoadModelCallback. One idea would be to manually call torch.load and pop LoadModelCallback from the callbacks list; then updating the thresholds will work, as sketched below. This solution might work for your use case for now, but we will try to come up with a better design.
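A rough sketch of that popping variant (assuming LoadModelCallback is importable from anomalib.utils.callbacks; the import path may differ across versions):

import torch
from anomalib.utils.callbacks import LoadModelCallback  # assumed import path

# load the trained weights manually instead of through the callback
model.load_state_dict(torch.load(args.weights)["state_dict"])

# drop LoadModelCallback so its on_predict_start cannot overwrite the thresholds
callbacks = [cb for cb in callbacks if not isinstance(cb, LoadModelCallback)]
trainer = Trainer(callbacks=callbacks, **config.trainer)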
Here is a snippet of the change to the infer function of lightning_inference.py:
def infer():
    """Run inference."""
    args = get_args()
    config = get_configurable_parameters(config_path=args.config)

    # This is commented out because setting it adds LoadModelCallback to the callback list
    # config.trainer.resume_from_checkpoint = str(args.weights)
    config.visualization.show_images = args.show
    config.visualization.mode = args.visualization_mode
    if args.output:  # overwrite save path
        config.visualization.save_images = True
        config.visualization.image_save_path = args.output
    else:
        config.visualization.save_images = False

    model = get_model(config)
    model.load_state_dict(torch.load(args.weights)["state_dict"])  # manually load weights
    callbacks = get_callbacks(config)
    trainer = Trainer(callbacks=callbacks, **config.trainer)

    transform_config = config.dataset.transform_config.val if "transform_config" in config.dataset.keys() else None
    dataset = InferenceDataset(
        args.input, image_size=tuple(config.dataset.image_size), transform_config=transform_config
    )
    dataloader = DataLoader(dataset)

    model.adaptive_threshold = False
    model.pixel_threshold.value = torch.tensor(float(1))
    model.image_threshold.value = torch.tensor(float(1))
    print(model.pixel_threshold.value, model.image_threshold.value)
    trainer.predict(model=model, dataloaders=[dataloader])
    print(model.pixel_threshold.value, model.image_threshold.value)
@ashwinvaidya17 Thanks, this works to avoid the threshold overwriting. However, I had to add the following to make it work:
model_state_dict = torch.load(my_args['weights'], map_location=device)["state_dict"]
model_state_dict.pop("normalization_metrics.min", None)  # unexpected key that prevents the state dict from loading
model_state_dict.pop("normalization_metrics.max", None)  # unexpected key that prevents the state dict from loading
model.load_state_dict(model_state_dict)
But this way the heatmap is not normalized at all, even though the predicted mask is good.
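One possible explanation and sketch: the popped normalization_metrics.min / normalization_metrics.max buffers are exactly the statistics the MinMax normalization uses for the heatmap, so dropping them leaves it unnormalized. Registering the metric before loading might let the buffers be kept instead of popped (assuming MinMax is importable from anomalib.utils.metrics; untested and version-dependent):

import torch
from anomalib.utils.metrics import MinMax  # assumed import path; may differ by version

model_state_dict = torch.load(my_args['weights'], map_location=device)["state_dict"]

# create the metric first so the normalization_metrics.min / .max buffers exist,
# then the full state dict loads without popping anything
model.normalization_metrics = MinMax()
model.load_state_dict(model_state_dict)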