
AttributeError: 'SegmentationTTAWrapper' object has no attribute 'predict'

Open · chefkrym opened this issue 2 years ago · 0 comments

```python
import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
import segmentation_models_pytorch as smp
import tifffile as tiff  # assumed: `tiff.imread` comes from tifffile
import timm
import torch
import ttach as tta
from PIL import Image

# ENCODER, ENCODER_WEIGHTS, DEVICE and get_preprocessing are defined
# elsewhere in the original (training) script and are not shown here.

# Load the trained model.
model = torch.load('E:/PhD_Projects/egmentation models/new model weights/UNet_mitb2_thresh0.3.pth')

# Test-time augmentation: D4 transforms (flips + 90-degree rotations), outputs averaged.
tta_model = tta.SegmentationTTAWrapper(model, tta.aliases.d4_transform(), merge_mode="mean")

# Load one image patch.
image_dir = 'E:/PhD_Projects/segmentation models/patches'
image_filename_2 = 'image__02_02.tif'
image_path = os.path.join(image_dir, image_filename_2)
image = tiff.imread(image_path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

# Same encoder-specific normalization as used during training.
preprocessing_fn_inference = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)
preprocessing_inference = get_preprocessing(preprocessing_fn_inference)
sample = preprocessing_inference(image=image)
image = sample['image']

# Add a batch dimension and predict.
x_tensor = torch.from_numpy(image).to(DEVICE).unsqueeze(0)
pr_mask = tta_model.predict(x_tensor)  # raises the AttributeError in the title
pr_mask = pr_mask.squeeze().cpu().numpy().round()
pr_mask = pr_mask.astype('float') * 255.0 / 16
# pr_mask = (pr_mask.astype('float') * 255.0/16).astype('uint8')

plt.imshow(pr_mask)
plt.show()
```

Can anyone help me with this prediction problem? Thank you. @qubvel

chefkrym · Jul 13 '23 10:07
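For reference, `tta.aliases.d4_transform()` with `merge_mode="mean"` runs the model on 8 versions of the input (horizontal flip off/on × rotation by 0, 90, 180, 270 degrees), maps each prediction back to the original orientation, and averages the results. The following is a hand-rolled NumPy sketch of that computation, not ttach's own code; `d4_tta_mean` and the toy `predict` callables are hypothetical names used only for illustration:

```python
import numpy as np


def d4_tta_mean(predict, image):
    """Average predict() over the 8 D4 transforms (hflip x 0/90/180/270 rotation).

    `image` is any array whose last two axes are (H, W), e.g. NCHW.
    `predict` must return an array whose last two axes are spatial too.
    """
    outputs = []
    for flip in (False, True):
        for k in range(4):  # k quarter-turns: 0, 90, 180, 270 degrees
            # Augment: flip first, then rotate.
            x = np.flip(image, axis=-1) if flip else image
            x = np.rot90(x, k, axes=(-2, -1))
            y = predict(x)
            # De-augment in reverse order: undo the rotation, then the flip.
            y = np.rot90(y, -k, axes=(-2, -1))
            if flip:
                y = np.flip(y, axis=-1)
            outputs.append(y)
    return np.mean(outputs, axis=0)


rng = np.random.default_rng(0)
img = rng.random((1, 3, 8, 8))
# A per-pixel "model" commutes with flips and rotations, so TTA changes nothing.
out = d4_tta_mean(lambda x: x.mean(axis=1, keepdims=True), img)
print(np.allclose(out, img.mean(axis=1, keepdims=True)))  # True
```

For a model that is not flip/rotation-equivariant (like a real UNet), the 8 de-augmented predictions differ and the mean smooths them, which is why the wrapper's output can differ slightly from a single forward pass.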