
model-zoo: error when running inference with the Breast density classification bundle

Open · shihzenq opened this issue 6 months ago · 5 comments

Hello, my question is as follows. I have downloaded model.pth from the address.

code:

```python
import glob
import os

import torch
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms
from torchvision.models import inception_v3
from monai.bundle import download, load  # imported but not used below

data_dir = 'breast_density_classification/sample_data'

# All sample JPEGs of density class "A"
test_images = sorted(glob.glob(os.path.join(data_dir, "A", "*.jpg")))

# Resize to the 299x299 input size of InceptionV3 and apply ImageNet normalization
preprocess = transforms.Compose([
    transforms.Resize((299, 299)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model_path = "breast_density_classification/models/model.pth"

# torchvision InceptionV3 with a 4-class head and no auxiliary classifier
model = inception_v3(pretrained=False, aux_logits=False, num_classes=4).to(device)

state_dict = torch.load(model_path, map_location=device)
model.load_state_dict(state_dict)  # the RuntimeError below is raised here
model.eval()

for img_path in test_images:
    img = Image.open(img_path).convert('RGB')
    img_tensor = preprocess(img).unsqueeze(0).to(device)

    with torch.no_grad():
        outputs = model(img_tensor)
        probs = torch.nn.functional.softmax(outputs, dim=1)
        pred_class = torch.argmax(probs, dim=1).item()

    # Show each image with its predicted class index
    plt.figure(figsize=(6, 6))
    plt.imshow(img)
    plt.title(f'Predicted Class: {pred_class}')
    plt.axis('off')
    plt.show()
```

error:

raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for Inception3:
    Missing key(s) in state_dict: "Conv2d_1a_3x3.conv.weight", "Conv2d_1a_3x3.bn.weight",

This error indicates that the loaded state_dict does not match the InceptionV3 model I defined. Did I do something wrong? This is an urgent problem. Thank you.
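A minimal sketch for narrowing down this kind of mismatch. It assumes (not verified against the bundle) that model.pth is a flat dict of tensors whose weights belong to a torchvision `Inception3` wrapped inside another module, so every key carries an extra prefix such as `model.`. The prefix is detected, not hard-coded. The helper names `diff_keys`, `guess_prefix`, `strip_prefix` and `load_with_prefix_fix` are hypothetical and not part of MONAI or torchvision. The sketch relies only on `Module.state_dict()` and `Module.load_state_dict(..., strict=False)`, whose return value lists `missing_keys` and `unexpected_keys`.

```python
from collections import OrderedDict
from typing import List, Mapping, Optional, Tuple

import torch


def diff_keys(model: torch.nn.Module, state_dict: Mapping[str, torch.Tensor]) -> Tuple[List[str], List[str]]:
    """Return (missing, unexpected) keys, i.e. what load_state_dict would complain about."""
    model_keys = set(model.state_dict().keys())
    ckpt_keys = set(state_dict.keys())
    return sorted(model_keys - ckpt_keys), sorted(ckpt_keys - model_keys)


def guess_prefix(model: torch.nn.Module, state_dict: Mapping[str, torch.Tensor]) -> Optional[str]:
    """Heuristic: find a prefix p such that p + key is in the checkpoint for every model key.

    Returns "" if the keys already match, or None if no single prefix explains the mismatch.
    """
    model_keys = list(model.state_dict().keys())
    if not model_keys:
        return ""
    anchor = model_keys[0]
    for ckpt_key in state_dict:
        if ckpt_key.endswith(anchor):
            prefix = ckpt_key[: len(ckpt_key) - len(anchor)]
            if all(prefix + k in state_dict for k in model_keys):
                return prefix
    return None


def strip_prefix(state_dict: Mapping[str, torch.Tensor], prefix: str) -> "OrderedDict[str, torch.Tensor]":
    """Remove `prefix` from keys that start with it; other keys are kept unchanged."""
    return OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v) for k, v in state_dict.items()
    )


def load_with_prefix_fix(model: torch.nn.Module, state_dict: Mapping[str, torch.Tensor]) -> Tuple[str, List[str]]:
    """Load weights after removing a detected key prefix.

    Returns (prefix, checkpoint keys the model did not use). Raises RuntimeError with a
    short key diff if no prefix explains the mismatch. Shape mismatches still raise
    inside load_state_dict, even with strict=False.
    """
    prefix = guess_prefix(model, state_dict)
    if prefix is None:
        missing, unexpected = diff_keys(model, state_dict)
        raise RuntimeError(
            f"{len(missing)} missing / {len(unexpected)} unexpected keys; "
            f"e.g. missing {missing[:3]}, unexpected {unexpected[:3]}"
        )
    fixed = strip_prefix(state_dict, prefix)
    # guess_prefix already checked that every model key is present, so strict=False only
    # tolerates extra checkpoint entries (e.g. an auxiliary head) and reports them.
    result = model.load_state_dict(fixed, strict=False)
    return prefix, list(result.unexpected_keys)
```

With the code from the issue, one would call `load_with_prefix_fix(model, torch.load(model_path, map_location=device))` in place of `model.load_state_dict(state_dict)` and print the returned prefix and leftover keys. If it raises instead, no single prefix maps the checkpoint onto torchvision's `Inception3`. In that case the architecture itself likely differs, and the network would have to be built from the definition the bundle ships with rather than from torchvision. The sketch also does not apply if `torch.load` returns something other than a flat dict of tensors.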

shihzenq, Jul 25 '24 09:07