InternImage icon indicating copy to clipboard operation
InternImage copied to clipboard

AttributeError: 'InternImage' object has no attribute '_initialize_weights'

Open secrakib opened this issue 5 months ago • 0 comments

Code:

import os import numpy as np import pandas as pd import matplotlib.pyplot as plt import seaborn as sns from sklearn.model_selection import train_test_split from sklearn.metrics import classification_report, confusion_matrix, accuracy_score, f1_score, roc_auc_score, roc_curve import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import Dataset, DataLoader from torchvision import transforms, models from PIL import Image import copy import time import math from collections import Counter from tqdm import tqdm from torch.amp import autocast, GradScaler from transformers import AutoModel

# Ensure reproducibility: seed NumPy and PyTorch (CPU + CUDA) and force
# cuDNN into deterministic mode (benchmark off so kernel selection is fixed).
seed = 42

torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# ----------------------------------------------------------------------
# Experiment constants and hyperparameters.
# ----------------------------------------------------------------------
IMAGE_DIR = "/kaggle/input/safenet-ai-final-dataset/Notebooks + CSV + Images/Images/Images"
BATCH_SIZE = 8
NUM_EPOCHS = 200
LEARNING_RATE = 0.0001
IMAGE_SIZE = 224
NUM_CLASSES = 3
PATIENCE = 3  # early-stopping patience: epochs without val macro-F1 improvement
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Restrict the dataframe to these three classes; list order defines the
# class-index -> name mapping used in reports and plots.
target_classes = ['Harmless Trolling', 'Targeted Trolling', 'Provocative_Trolls']

Custom dataset class

class MemeDataset(Dataset):
    """Dataset over a dataframe with 'image' (filename) and 'class_idx' columns.

    Each item is ``(image, label)`` where the image is loaded from
    ``image_dir`` and run through ``transform`` when one is given.
    Unreadable/missing files are replaced by a plain white placeholder so a
    single corrupt image does not abort a whole epoch.
    """

    # FIX: the constructor was written as `def init(...)` — the dunder
    # underscores were lost, so the class had no __init__ at all.
    def __init__(self, dataframe, image_dir, transform=None):
        self.dataframe = dataframe
        self.image_dir = image_dir
        self.transform = transform

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        img_name = self.dataframe.iloc[idx]['image']
        img_path = os.path.join(self.image_dir, img_name)

        try:
            image = Image.open(img_path).convert('RGB')
        except Exception:
            # Fallback for corrupt or missing files: white IMAGE_SIZE square.
            image = Image.new('RGB', (IMAGE_SIZE, IMAGE_SIZE), color='white')

        label = self.dataframe.iloc[idx]['class_idx']

        if self.transform:
            image = self.transform(image)

        return image, label

from transformers import AutoImageProcessor, AutoModelForImageClassification

# ----------------------------------------------------------------------
# Load the InternImage backbone with a fresh NUM_CLASSES classification head.
#
# NOTE(review): this call is where the reported error fires —
#   AttributeError: 'InternImage' object has no attribute '_initialize_weights'
# Recent transformers releases call `model._initialize_weights()` while
# initializing the newly created (missing) head weights, but the remote
# InternImage code on the Hub predates that hook and does not define it.
# Workaround: pin an older transformers release (the repo's model card /
# requirements indicate which versions are supported — confirm there),
# or load the backbone through timm instead of AutoModelForImageClassification.
# ----------------------------------------------------------------------
model_name = "OpenGVLab/internimage_b_1k_224"
# FIX: was `model = model = AutoModelForImageClassification.from_pretrained(...)`
# (duplicated assignment).
model = AutoModelForImageClassification.from_pretrained(
    model_name,
    trust_remote_code=True,        # executes model code shipped in the Hub repo
    num_labels=NUM_CLASSES,        # replace the 1000-class ImageNet head
    ignore_mismatched_sizes=True,  # new head shape differs from the checkpoint
)
model.to(DEVICE)

# ImageNet normalisation statistics — the backbone was pre-trained with these.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Training pipeline: resize plus light augmentation (random horizontal flip).
train_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
])

# Validation/test pipeline: deterministic, no augmentation.
val_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
])

# Wrap the three dataframe splits in datasets; val and test share the
# deterministic transform.
train_dataset = MemeDataset(train_df, IMAGE_DIR, transform=train_transforms)
val_dataset = MemeDataset(val_df, IMAGE_DIR, transform=val_transforms)
test_dataset = MemeDataset(test_df, IMAGE_DIR, transform=val_transforms)

# Batched loaders — only training shuffles; evaluation order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

# --- Optional: log-scaled class weighting (disabled in the original run) ---
# class_counts = Counter(train_df['class_idx'])
# total_samples = sum(class_counts.values())
# num_classes = len(class_counts)
# class_weights = torch.tensor(
#     [math.log(total_samples / class_counts[i]) for i in range(num_classes)],
#     dtype=torch.float32,
# ).to(DEVICE)
# criterion = nn.CrossEntropyLoss(weight=class_weights)
# ---------------------------------------------------------------------------

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
scaler = GradScaler()  # mixed-precision loss scaling for the AMP training loop

from tqdm import tqdm from sklearn.metrics import f1_score, accuracy_score, roc_auc_score, classification_report, roc_curve

# ----------------------------------------------------------------------
# Mixed-precision training loop with early stopping on validation macro-F1.
# The best checkpoint (by val F1) is kept in memory as best_model_wts.
# ----------------------------------------------------------------------
best_f1 = 0.0
no_improve_epochs = 0
best_model_wts = copy.deepcopy(model.state_dict())


def _logits_of(outputs):
    """Return raw logits whether the model yields a bare tensor or an HF
    ModelOutput (AutoModelForImageClassification returns the latter, which
    cannot be passed to CrossEntropyLoss/torch.max directly)."""
    return outputs.logits if hasattr(outputs, "logits") else outputs


print("Starting training...")
for epoch in range(NUM_EPOCHS):
    print(f'Epoch {epoch+1}/{NUM_EPOCHS}')
    print('-' * 10)

    # ---------------- Training phase ----------------
    model.train()
    running_loss = 0.0
    running_corrects = 0
    all_preds_train = []
    all_labels_train = []
    all_probs_train = []  # softmax probabilities, needed for a valid ROC-AUC

    for inputs, labels in tqdm(train_loader, desc="Training"):
        inputs = inputs.to(DEVICE)
        labels = labels.to(DEVICE)

        optimizer.zero_grad()

        with autocast(device_type='cuda'):
            logits = _logits_of(model(inputs))
            loss = criterion(logits, labels)
            _, preds = torch.max(logits, 1)

        # Scaled backward + optimizer step (AMP).
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        running_loss += loss.item() * inputs.size(0)
        running_corrects += torch.sum(preds == labels.data)
        all_preds_train.extend(preds.cpu().numpy())
        all_labels_train.extend(labels.cpu().numpy())
        all_probs_train.extend(torch.softmax(logits.float(), dim=1).detach().cpu().numpy())

    train_loss = running_loss / len(train_loader.dataset)
    train_acc = running_corrects.double() / len(train_loader.dataset)
    train_f1 = f1_score(all_labels_train, all_preds_train, average='macro')
    # FIX: ROC-AUC must be computed from class probabilities; the original
    # used np.eye(NUM_CLASSES)[hard predictions], which degenerates the metric.
    train_auc = roc_auc_score(all_labels_train, np.array(all_probs_train), multi_class='ovr')

    # ---------------- Validation phase ----------------
    model.eval()
    val_preds, val_labels, val_probs = [], [], []
    with torch.no_grad():
        for inputs, labels in tqdm(val_loader, desc="Validation"):
            inputs = inputs.to(DEVICE)
            labels = labels.to(DEVICE)
            logits = _logits_of(model(inputs))
            _, preds = torch.max(logits, 1)
            val_preds.extend(preds.cpu().numpy())
            val_labels.extend(labels.cpu().numpy())
            val_probs.extend(torch.softmax(logits.float(), dim=1).cpu().numpy())

    val_acc = accuracy_score(val_labels, val_preds)
    val_f1 = f1_score(val_labels, val_preds, average='macro')
    val_auc = roc_auc_score(val_labels, np.array(val_probs), multi_class='ovr')

    print(f'Train Acc: {train_acc:.4f} | Train F1: {train_f1:.4f} | Train ROC-AUC: {train_auc:.4f}')
    print(f'Val Acc: {val_acc:.4f} | Val F1: {val_f1:.4f} | Val ROC-AUC: {val_auc:.4f}')

    # Early stopping based on validation macro F1 score.
    if val_f1 > best_f1:
        best_f1 = val_f1
        best_model_wts = copy.deepcopy(model.state_dict())
        no_improve_epochs = 0
    else:
        no_improve_epochs += 1
        print(f'Early stopping Count {no_improve_epochs}')

    if no_improve_epochs >= PATIENCE:
        print(f'Early stopping triggered after {epoch+1} epochs')
        break

print(f'Best Validation f1:{best_f1:.4f}')

# Load best model weights before test evaluation.
# FIX: this line was commented out, so the final-epoch weights — not the best
# early-stopping checkpoint saved in best_model_wts — were being evaluated.
model.load_state_dict(best_model_wts)

# ----------------------------------------------------------------------
# Final evaluation on the held-out test set.
# ----------------------------------------------------------------------
model.eval()
all_preds, all_labels, all_probs = [], [], []
with torch.no_grad():
    for inputs, labels in tqdm(test_loader, desc='Testing'):
        inputs = inputs.to(DEVICE)
        labels = labels.to(DEVICE)
        outputs = model(inputs)
        # HF classification models return a ModelOutput; fall back for tensors.
        logits = outputs.logits if hasattr(outputs, "logits") else outputs
        _, preds = torch.max(logits, 1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())
        all_probs.extend(torch.softmax(logits.float(), dim=1).cpu().numpy())

test_acc = accuracy_score(all_labels, all_preds)
test_f1 = f1_score(all_labels, all_preds, average='macro')
# FIX: ROC-AUC from predicted probabilities, not one-hot hard predictions.
test_auc = roc_auc_score(all_labels, np.array(all_probs), multi_class='ovr')
# FIX: label said 'Swin Transformer' — this run uses InternImage.
print('InternImage')
print("Test Classification Report:")
print(classification_report(all_labels, all_preds, target_names=target_classes, digits=4))
print(f'Test Acc: {test_acc:.4f} | Test F1: {test_f1:.4f} | Test ROC-AUC: {test_auc:.4f}')

# Per-class ROC curves (one-vs-rest).
# NOTE(review): these curves are built from hard 0/1 predictions, so each one
# is a single-step curve; using per-class probabilities would be more
# informative — confirm whether probability scores are available here.
plt.figure(figsize=(8, 6))
for i in range(NUM_CLASSES):
    fpr, tpr, _ = roc_curve(np.array(all_labels) == i, np.array(all_preds) == i)
    plt.plot(fpr, tpr, label=f'Class {target_classes[i]}')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC Curve')
plt.legend()
plt.show()

# Confusion-matrix heat-map; written to disk before being shown.
plt.figure(figsize=(10, 8))
cm = confusion_matrix(all_labels, all_preds)
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=target_classes, yticklabels=target_classes)
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion Matrix')
plt.tight_layout()
plt.savefig('confusion_matrix.png')
plt.show()

# Save the model — state_dict only (weights), not the full pickled module.
torch.save(model.state_dict(), 'meme_classifier.pth')

Error:

AttributeError Traceback (most recent call last) /tmp/ipykernel_36/1370559554.py in <cell line: 0>() 93 model_name = "OpenGVLab/internimage_b_1k_224" 94 #processor = AutoImageProcessor.from_pretrained(model_name) ---> 95 model = model = AutoModelForImageClassification.from_pretrained( 96 model_name, 97 trust_remote_code=True,

/usr/local/lib/python3.11/dist-packages/transformers/models/auto/auto_factory.py in from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs) 562 elif type(config) in cls._model_mapping.keys(): 563 model_class = _get_model_class(config, cls._model_mapping) --> 564 return model_class.from_pretrained( 565 pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs 566 )

/usr/local/lib/python3.11/dist-packages/transformers/modeling_utils.py in _wrapper(*args, **kwargs) 307 308 def get_state_dict_dtype(state_dict): --> 309 """ 310 Returns the first found floating dtype in state_dict if there is one, otherwise returns the first dtype. 311 """

/usr/local/lib/python3.11/dist-packages/transformers/modeling_utils.py in from_pretrained(cls, pretrained_model_name_or_path, config, cache_dir, ignore_mismatched_sizes, force_download, local_files_only, token, revision, use_safetensors, weights_only, *model_args, **kwargs) 4572 f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably" 4573 " TRAIN this model on a down-stream task to be able to use it for predictions and inference." -> 4574 ) 4575 elif len(mismatched_keys) == 0: 4576 logger.info(

/usr/local/lib/python3.11/dist-packages/transformers/modeling_utils.py in _load_pretrained_model(cls, model, state_dict, checkpoint_files, pretrained_model_name_or_path, ignore_mismatched_sizes, sharded_metadata, device_map, disk_offload_folder, offload_state_dict, dtype, hf_quantizer, keep_in_fp32_regex, device_mesh, key_mapping, weights_only) 4882 start_states = start_states.expand(-1, slen, -1) # shape (bsz, slen, hsz) 4883 -> 4884 x = self.dense_0(torch.cat([hidden_states, start_states], dim=-1)) 4885 x = self.activation(x) 4886 x = self.LayerNorm(x)

/usr/local/lib/python3.11/dist-packages/transformers/modeling_utils.py in _initialize_missing_keys(self, loaded_keys, ignore_mismatched_sizes, is_quantized)

/usr/local/lib/python3.11/dist-packages/torch/utils/_contextlib.py in decorate_context(*args, **kwargs) 114 def decorate_context(*args, **kwargs): 115 with ctx_factory(): --> 116 return func(*args, **kwargs) 117 118 return decorate_context

/usr/local/lib/python3.11/dist-packages/transformers/modeling_utils.py in initialize_weights(self) 2554 if token is not None: 2555 kwargs["token"] = token -> 2556 2557 _hf_peft_config_loaded = getattr(self, "_hf_peft_config_loaded", False) 2558

/usr/local/lib/python3.11/dist-packages/transformers/modeling_utils.py in smart_apply(self, fn) 2545 "The use_auth_token argument is deprecated and will be removed in v5 of Transformers. Please use token instead.", 2546 FutureWarning, -> 2547 ) 2548 if token is not None: 2549 raise ValueError(

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in getattr(self, name) 1926 if name in modules: 1927 return modules[name] -> 1928 raise AttributeError( 1929 f"'{type(self).name}' object has no attribute '{name}'" 1930 )

AttributeError: 'InternImage' object has no attribute '_initialize_weights'

secrakib avatar Jul 20 '25 16:07 secrakib