AttributeError: 'InternImage' object has no attribute '_initialize_weights'
Code:
import os import numpy as np import pandas as pd import matplotlib.pyplot as plt import seaborn as sns from sklearn.model_selection import train_test_split from sklearn.metrics import classification_report, confusion_matrix, accuracy_score, f1_score, roc_auc_score, roc_curve import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import Dataset, DataLoader from torchvision import transforms, models from PIL import Image import copy import time import math from collections import Counter from tqdm import tqdm from torch.amp import autocast, GradScaler from transformers import AutoModel
seed = 42

# Ensure reproducibility
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Define constants
IMAGE_DIR = "/kaggle/input/safenet-ai-final-dataset/Notebooks + CSV + Images/Images/Images"
BATCH_SIZE = 8
NUM_EPOCHS = 200
LEARNING_RATE = 0.0001
IMAGE_SIZE = 224
NUM_CLASSES = 3
PATIENCE = 3        # early-stopping patience (epochs without val-F1 improvement)
# D_O = .5
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Define the target classes of interest (the dataframe filtering step itself is not shown in this snippet)
target_classes = ['Harmless Trolling','Targeted Trolling','Provocative_Trolls']#,] 'Satirical Trolling' #target_classes = ['Satirical Trolling','Targeted Trolling','Provocative Trolls','Harmless Trolling'] #,Provocative Trolls','Explicit Harassment Trolls'
Custom dataset class
class MemeDataset(Dataset):
    """Dataset of meme images with integer class labels read from a dataframe.

    The dataframe must provide an 'image' column (file name relative to
    ``image_dir``) and a 'class_idx' column (integer label).
    """

    # FIX: the original had `def init(...)` — the `__init__` dunder lost its
    # underscores, so the constructor was never called and instances had no
    # `dataframe`/`image_dir`/`transform` attributes.
    def __init__(self, dataframe, image_dir, transform=None):
        self.dataframe = dataframe
        self.image_dir = image_dir
        self.transform = transform

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        img_name = self.dataframe.iloc[idx]['image']
        img_path = os.path.join(self.image_dir, img_name)
        try:
            image = Image.open(img_path).convert('RGB')
        except Exception:
            # Best-effort fallback: a missing/corrupt file yields a blank
            # white image instead of crashing the whole epoch.
            image = Image.new('RGB', (IMAGE_SIZE, IMAGE_SIZE), color='white')
        label = self.dataframe.iloc[idx]['class_idx']
        if self.transform:
            image = self.transform(image)
        return image, label
from transformers import AutoImageProcessor, AutoModelForImageClassification
Load processor and model
model_name = "OpenGVLab/internimage_b_1k_224" #processor = AutoImageProcessor.from_pretrained(model_name) model = model = AutoModelForImageClassification.from_pretrained( model_name, trust_remote_code=True, num_labels=NUM_CLASSES, ignore_mismatched_sizes=True ) model.to(DEVICE)
# Shared ImageNet normalization; augmented pipeline for training, a
# deterministic one for validation/testing.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

train_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
])

val_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    normalize,
])
Create dataset objects
train_dataset = MemeDataset(train_df, IMAGE_DIR, transform=train_transforms) val_dataset = MemeDataset(val_df, IMAGE_DIR, transform=val_transforms) test_dataset = MemeDataset(test_df, IMAGE_DIR, transform=val_transforms)
Create data loaders
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4) val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4) test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)
# NOTE(review): everything from the ''' below through the trailing #''' on the
# last line is ONE triple-quoted string — it deliberately disables the
# log-balanced class-weighting variant of the loss. Delete the opening '''
# to re-enable it (the closer then rides in that trailing #''' toggle).
'''
Compute class weights
class_counts = Counter(train_df['class_idx']) total_samples = sum(class_counts.values()) num_classes = len(class_counts)
class_weights = torch.tensor([ math.log(total_samples / class_counts[i]) for i in range(num_classes) ], dtype=torch.float32).to(DEVICE)
criterion = nn.CrossEntropyLoss(weight=class_weights) #'''
criterion = nn.CrossEntropyLoss() optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE) scaler = GradScaler()
from tqdm import tqdm from sklearn.metrics import f1_score, accuracy_score, roc_auc_score, classification_report, roc_curve
best_f1 = 0.0
no_improve_epochs = 0
best_model_wts = copy.deepcopy(model.state_dict())


def _to_logits(outputs):
    """Return the raw logits tensor, whether the model yields a plain tensor
    or a transformers ImageClassifierOutput (which wraps them in .logits)."""
    return outputs.logits if hasattr(outputs, "logits") else outputs


print("Starting training...")
for epoch in range(NUM_EPOCHS):
    print(f'Epoch {epoch+1}/{NUM_EPOCHS}')
    print('-' * 10)

    # ---- Training phase ----
    model.train()
    running_loss = 0.0
    running_corrects = 0
    all_preds_train = []
    all_labels_train = []
    for inputs, labels in tqdm(train_loader, desc="Training"):
        inputs = inputs.to(DEVICE)
        labels = labels.to(DEVICE)
        optimizer.zero_grad()
        with autocast(device_type='cuda'):
            # FIX: HF AutoModelForImageClassification returns an output
            # object, not a tensor — CrossEntropyLoss and torch.max need
            # the underlying .logits tensor.
            logits = _to_logits(model(inputs))
            loss = criterion(logits, labels)
            _, preds = torch.max(logits, 1)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        running_loss += loss.item() * inputs.size(0)
        running_corrects += torch.sum(preds == labels.data)
        all_preds_train.extend(preds.cpu().numpy())
        all_labels_train.extend(labels.cpu().numpy())
    train_loss = running_loss / len(train_loader.dataset)
    train_acc = running_corrects.double() / len(train_loader.dataset)
    train_f1 = f1_score(all_labels_train, all_preds_train, average='macro')
    # NOTE(review): AUC computed from hard one-hot predictions, not softmax
    # probabilities — this understates ROC-AUC; consider using scores instead.
    train_auc = roc_auc_score(all_labels_train, np.eye(NUM_CLASSES)[all_preds_train], multi_class='ovr')

    # ---- Validation phase ----
    model.eval()
    val_preds, val_labels = [], []
    with torch.no_grad():
        for inputs, labels in tqdm(val_loader, desc="Validation"):
            inputs = inputs.to(DEVICE)
            labels = labels.to(DEVICE)
            logits = _to_logits(model(inputs))
            _, preds = torch.max(logits, 1)
            val_preds.extend(preds.cpu().numpy())
            val_labels.extend(labels.cpu().numpy())
    val_acc = accuracy_score(val_labels, val_preds)
    val_f1 = f1_score(val_labels, val_preds, average='macro')
    val_auc = roc_auc_score(val_labels, np.eye(NUM_CLASSES)[val_preds], multi_class='ovr')

    print(f'Train Acc: {train_acc:.4f} | Train F1: {train_f1:.4f} | Train ROC-AUC: {train_auc:.4f}')
    print(f'Val Acc: {val_acc:.4f} | Val F1: {val_f1:.4f} | Val ROC-AUC: {val_auc:.4f}')

    # Early stopping based on validation macro F1 score.
    if val_f1 > best_f1:
        best_f1 = val_f1
        best_model_wts = copy.deepcopy(model.state_dict())
        no_improve_epochs = 0
    else:
        no_improve_epochs += 1
        print(f'Early stopping Count {no_improve_epochs}')
        if no_improve_epochs >= PATIENCE:
            print(f'Early stopping triggered after {epoch+1} epochs')
            break

print(f'Best Validation f1:{best_f1:.4f}')
Load best model weights (currently disabled — the line below is commented out, so test evaluation uses the last-epoch weights)
#model.load_state_dict(best_model_wts)
Evaluate model on test set
# ---- Final evaluation on the held-out test split ----
model.eval()
all_preds, all_labels = [], []
with torch.no_grad():
    for inputs, labels in tqdm(test_loader, desc='Testing'):
        inputs = inputs.to(DEVICE)
        labels = labels.to(DEVICE)
        outputs = model(inputs)
        # FIX: HF models return an output object; take the .logits tensor
        # (falls back to the raw output for plain-tensor models).
        logits = outputs.logits if hasattr(outputs, 'logits') else outputs
        _, preds = torch.max(logits, 1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

test_acc = accuracy_score(all_labels, all_preds)
test_f1 = f1_score(all_labels, all_preds, average='macro')
# NOTE(review): AUC from hard predictions, not probabilities (see training loop).
test_auc = roc_auc_score(all_labels, np.eye(NUM_CLASSES)[all_preds], multi_class='ovr')
# FIX: the banner printed 'Swin Transformer', but the model loaded above is
# InternImage (OpenGVLab/internimage_b_1k_224).
print('InternImage')
print("Test Classification Report:")
print(classification_report(all_labels, all_preds, target_names=target_classes, digits=4))
print(f'Test Acc: {test_acc:.4f} | Test F1: {test_f1:.4f} | Test ROC-AUC: {test_auc:.4f}')
ROC Curve Visualization
plt.figure(figsize=(8, 6)) for i in range(NUM_CLASSES): fpr, tpr, _ = roc_curve(np.array(all_labels) == i, np.array(all_preds) == i) plt.plot(fpr, tpr, label=f'Class {target_classes[i]}') plt.xlabel('False Positive Rate') plt.ylabel('True Positive Rate') plt.title('ROC Curve') plt.legend() plt.show()
Confusion Matrix
plt.figure(figsize=(10, 8)) cm = confusion_matrix(all_labels, all_preds) sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=target_classes, yticklabels=target_classes) plt.xlabel('Predicted') plt.ylabel('True') plt.title('Confusion Matrix') plt.tight_layout() plt.savefig('confusion_matrix.png') plt.show()
Save the model
torch.save(model.state_dict(), 'meme_classifier.pth')
Error:
AttributeError Traceback (most recent call last) /tmp/ipykernel_36/1370559554.py in <cell line: 0>() 93 model_name = "OpenGVLab/internimage_b_1k_224" 94 #processor = AutoImageProcessor.from_pretrained(model_name) ---> 95 model = model = AutoModelForImageClassification.from_pretrained( 96 model_name, 97 trust_remote_code=True,
/usr/local/lib/python3.11/dist-packages/transformers/models/auto/auto_factory.py in from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs) 562 elif type(config) in cls._model_mapping.keys(): 563 model_class = _get_model_class(config, cls._model_mapping) --> 564 return model_class.from_pretrained( 565 pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs 566 )
/usr/local/lib/python3.11/dist-packages/transformers/modeling_utils.py in _wrapper(*args, **kwargs)
307
308 def get_state_dict_dtype(state_dict):
--> 309 """
310 Returns the first found floating dtype in state_dict if there is one, otherwise returns the first dtype.
311 """
/usr/local/lib/python3.11/dist-packages/transformers/modeling_utils.py in from_pretrained(cls, pretrained_model_name_or_path, config, cache_dir, ignore_mismatched_sizes, force_download, local_files_only, token, revision, use_safetensors, weights_only, *model_args, **kwargs) 4572 f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably" 4573 " TRAIN this model on a down-stream task to be able to use it for predictions and inference." -> 4574 ) 4575 elif len(mismatched_keys) == 0: 4576 logger.info(
/usr/local/lib/python3.11/dist-packages/transformers/modeling_utils.py in _load_pretrained_model(cls, model, state_dict, checkpoint_files, pretrained_model_name_or_path, ignore_mismatched_sizes, sharded_metadata, device_map, disk_offload_folder, offload_state_dict, dtype, hf_quantizer, keep_in_fp32_regex, device_mesh, key_mapping, weights_only) 4882 start_states = start_states.expand(-1, slen, -1) # shape (bsz, slen, hsz) 4883 -> 4884 x = self.dense_0(torch.cat([hidden_states, start_states], dim=-1)) 4885 x = self.activation(x) 4886 x = self.LayerNorm(x)
/usr/local/lib/python3.11/dist-packages/transformers/modeling_utils.py in _initialize_missing_keys(self, loaded_keys, ignore_mismatched_sizes, is_quantized)
/usr/local/lib/python3.11/dist-packages/torch/utils/_contextlib.py in decorate_context(*args, **kwargs) 114 def decorate_context(*args, **kwargs): 115 with ctx_factory(): --> 116 return func(*args, **kwargs) 117 118 return decorate_context
/usr/local/lib/python3.11/dist-packages/transformers/modeling_utils.py in initialize_weights(self) 2554 if token is not None: 2555 kwargs["token"] = token -> 2556 2557 _hf_peft_config_loaded = getattr(self, "_hf_peft_config_loaded", False) 2558
/usr/local/lib/python3.11/dist-packages/transformers/modeling_utils.py in smart_apply(self, fn)
2545 "The use_auth_token argument is deprecated and will be removed in v5 of Transformers. Please use token instead.",
2546 FutureWarning,
-> 2547 )
2548 if token is not None:
2549 raise ValueError(
/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in getattr(self, name) 1926 if name in modules: 1927 return modules[name] -> 1928 raise AttributeError( 1929 f"'{type(self).name}' object has no attribute '{name}'" 1930 )
AttributeError: 'InternImage' object has no attribute '_initialize_weights'