model-zoo
model-zoo copied to clipboard
model-zoo: error when running inference with the breast density classification model
Hello, my question is as follows.
I have downloaded model.pth from the address.
code:
# Inference script from the report: loads model.pth into a torchvision
# Inception v3 and displays the predicted class for each sample image in
# folder "A".
import glob
import os

import matplotlib.pyplot as plt
import torch
from monai.bundle import download, load
from PIL import Image
from torchvision import transforms
from torchvision.models import inception_v3

data_dir = 'breast_density_classification/sample_data'
# Sorted so the images are processed in a stable order.
test_images = sorted(glob.glob(os.path.join(data_dir, "A", "*.jpg")))

# Resize to 299x299, convert to a tensor, then normalise each channel.
preprocess = transforms.Compose([
    transforms.Resize((299, 299)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Use the first GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model_path = "breast_density_classification/models/model.pth"

# Plain torchvision Inception v3: 4 output classes, no auxiliary head.
model = inception_v3(pretrained=False, aux_logits=False, num_classes=4).to(device)

# NOTE(review): this load_state_dict call is the one that raises the
# "Missing key(s) in state_dict" error quoted in the report, i.e. the key names
# in model.pth do not match a bare Inception3. Presumably the bundle wraps the
# network (e.g. MONAI's TorchVisionFCModel) and/or nests the weights under a
# top-level key -- TODO confirm against the bundle's inference config and by
# printing state_dict.keys().
state_dict = torch.load(model_path, map_location=device)
model.load_state_dict(state_dict)
model.eval()

for img_path in test_images:
    img = Image.open(img_path).convert('RGB')
    # unsqueeze(0) adds a leading dimension so the model receives a batch of one.
    img_tensor = preprocess(img).unsqueeze(0).to(device)

    # Gradients are not needed for prediction.
    with torch.no_grad():
        outputs = model(img_tensor)
        probs = torch.nn.functional.softmax(outputs, dim=1)
        pred_class = torch.argmax(probs, dim=1).item()

    # Show the original (un-normalised) image with the predicted class index.
    plt.figure(figsize=(6, 6))
    plt.imshow(img)
    plt.title(f'Predicted Class: {pred_class}')
    plt.axis('off')
    plt.show()
error:
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( RuntimeError: Error(s) in loading state_dict for Inception3: Missing key(s) in state_dict: "Conv2d_1a_3x3.conv.weight", "Conv2d_1a_3x3.bn.weight",
This error indicates that the loaded state_dict does not match the defined InceptionV3 model. Did I do something wrong? This is an urgent problem. Thank you.