open-metric-learning icon indicating copy to clipboard operation
open-metric-learning copied to clipboard

Test errors on my own dataset

Open — snow-wind-001 opened this issue 9 months ago • 11 comments

import torch
import pandas as pd
import random
from pathlib import Path
from torch.utils.data import DataLoader
from typing import Union, Optional
from pprint import pprint
from oml.const import PATHS_COLUMN
from oml.datasets.base import DatasetQueryGallery
from oml.inference.flat import inference_on_dataframe, inference_on_images
from oml.models import ConcatSiamese, ViTExtractor
from oml.registry.transforms import get_transforms_for_pretrained
from oml.retrieval.postprocessors.pairwise import PairwiseImagesPostprocessor
from oml.utils.misc_torch import pairwise_dist
from oml.utils.io import download_checkpoint_one_of
from NumpyImageDataset import *
import os

# Root directory of the dataset: the CSV file and the image paths listed in it
# are both resolved relative to this directory further down in the script.
dataset_root = Path("/home/snowolf/dataset/bottle_processed_images/part1/archive")

def load_images_from_folder(folder_path):
    """Load every image file found directly inside ``folder_path`` as a NumPy array.

    Files are selected by extension. The comparison is case-insensitive, so
    ``photo.JPG`` is accepted just like ``photo.jpg``. Entries are visited in
    sorted filename order so that the returned list is reproducible between
    runs (``os.listdir`` itself gives no ordering guarantee).

    Args:
        folder_path: Directory to scan. Only its direct entries are listed.

    Returns:
        A list with one array per matching file, ordered by filename. The list
        is empty when no entry has a supported extension.
    """
    # NOTE(review): ``Image`` and ``np`` are not imported explicitly in this
    # file; they are presumably provided by ``from NumpyImageDataset import *``
    # -- confirm.
    valid_extensions = ('.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff')  # supported file formats

    images = []  # collected image arrays
    for filename in sorted(os.listdir(folder_path)):
        # Lower-case only for the comparison; the real name is used to open the file.
        if not filename.lower().endswith(valid_extensions):
            continue
        img_path = os.path.join(folder_path, filename)  # full path of the image
        with Image.open(img_path) as img:
            # Convert inside the ``with`` block, while the file is still open.
            images.append(np.array(img))
    return images

class CustomViTExtractor(ViTExtractor):
    """``ViTExtractor`` variant that loads its weights itself.

    The parent constructor is always called with ``weights=None``; when a
    ``weights`` argument is given, it is loaded afterwards through
    :meth:`load_pretrained_weights`. The instance is finally moved to the
    module-level ``device``.

    NOTE(review): ``device`` is a global that must be defined before the first
    instance is created (it is looked up at call time, not at class definition).
    """

    def __init__(
        self,
        arch: str = "vits16",
        normalise_features: bool = False,
        use_multi_scale: bool = False,
        weights: Optional[Union[Path, str]] = None,
    ):
        """Build the extractor and optionally load ``weights`` into it.

        Args:
            arch: Architecture name forwarded to the parent constructor.
            normalise_features: Forwarded to the parent constructor.
            use_multi_scale: Forwarded to the parent constructor.
            weights: Checkpoint path or a key of ``self.pretrained_models``;
                ``None`` skips weight loading.
        """
        super().__init__(weights=None, arch=arch, normalise_features=normalise_features, use_multi_scale=use_multi_scale)
        if weights is not None:
            self.load_pretrained_weights(weights)
        self.to(device)

    def load_pretrained_weights(self, weights: Union[Path, str]):
        """Load a checkpoint into ``self.model``.

        Args:
            weights: Either a path (``str`` or ``Path``) to an existing
                checkpoint file, or a key of ``self.pretrained_models`` whose
                checkpoint is downloaded first.

        Raises:
            ValueError: If ``weights`` is neither an existing file nor a key of
                ``self.pretrained_models``.
        """
        if isinstance(weights, (str, Path)) and Path(weights).exists():
            ckpt_path = weights
        elif str(weights) in self.pretrained_models:
            pretrained = self.pretrained_models[str(weights)]
            ckpt_path = download_checkpoint_one_of(
                url_or_fid_list=pretrained["url"],
                hash_md5=pretrained["hash"],
                fname=pretrained["fname"],
            )
        else:
            # Fail with a clear message instead of hitting an unbound
            # ``state_dict`` further down.
            raise ValueError(
                f"Cannot load weights from {weights!r}: it is neither an existing file "
                f"nor a key of `pretrained_models`."
            )

        ckpt = torch.load(ckpt_path, map_location=device)
        # Some checkpoints wrap the weights in a "state_dict" entry, others are
        # the state dict itself.
        state_dict = ckpt["state_dict"] if "state_dict" in ckpt else ckpt

        # ``strict=False`` tolerates key mismatches, so a checkpoint whose keys
        # do not match the model would otherwise "load" without changing any
        # parameter. Surface the mismatch report instead of discarding it.
        # NOTE(review): assumes ``self.model`` is a ``torch.nn.Module`` -- confirm.
        load_result = self.model.load_state_dict(state_dict, strict=False)
        missing = list(getattr(load_result, "missing_keys", []))
        unexpected = list(getattr(load_result, "unexpected_keys", []))
        if missing or unexpected:
            pprint(
                f"WARNING: checkpoint does not fully match the model: "
                f"{len(missing)} missing keys (e.g. {missing[:3]}), "
                f"{len(unexpected)} unexpected keys (e.g. {unexpected[:3]})"
            )
        # The message below means "weights loaded into the model!!".
        pprint('加载权重到模型!!')

    @classmethod
    def from_pretrained(cls, path: Union[Path, str], arch: str = "vits16", normalise_features: bool = False, use_multi_scale: bool = False):
        """Alternate constructor: build the extractor and load weights from ``path``.

        Args:
            path: Passed to the constructor as ``weights``.
            arch: Architecture name.
            normalise_features: Forwarded to the constructor.
            use_multi_scale: Forwarded to the constructor.

        Returns:
            A new instance of ``cls``.
        """
        return cls(weights=path, arch=arch, normalise_features=normalise_features, use_multi_scale=use_multi_scale)

def data_bottle(dataset_root, df_name):
    """Read the dataset CSV and return its train and validation parts.

    Args:
        dataset_root: Directory containing the CSV; joined with ``df_name``
            using the ``/`` operator.
        df_name: File name of the CSV inside ``dataset_root``.

    Returns:
        A ``(df_train, df_val)`` tuple. Both frames are re-indexed from zero.
        In ``df_val`` the ``is_query`` and ``is_gallery`` columns are cast to
        ``bool``.
    """
    frame = pd.read_csv(dataset_root / df_name)

    # Rows are selected by the value of the "split" column.
    parts = {}
    for split_name in ("train", "validation"):
        selected = frame[frame["split"] == split_name]
        parts[split_name] = selected.reset_index(drop=True)

    validation = parts["validation"]
    # The flags are cast to bool so they can be used directly as row masks.
    for flag in ("is_query", "is_gallery"):
        validation[flag] = validation[flag].astype(bool)

    return parts["train"], validation

# Load the validation rows and turn the image paths from the CSV into paths
# under ``dataset_root``.
_, df_val = data_bottle(dataset_root, 'df_test.csv')
df_val["path"] = df_val["path"].apply(lambda x: dataset_root / x)
queries = df_val[df_val["is_query"]]["path"].tolist()
galleries = df_val[df_val["is_gallery"]]["path"].tolist()

# Print (index, path) pairs so the indices reported at the end of the script
# can be mapped back to image files.
for i in enumerate(queries):
    print(i)

print('========================')

for i in enumerate(galleries):
    print(i)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Alternatives kept for reference:
# model = ViTExtractor.from_pretrained("vits16_dino")
# checkpoints_path = '/home/snowolf/git_code/open-metric-learning/pipelines/postprocessing/pairwise_postprocessing/logs/2024-05-13_10-58-23_feature_extractor/checkpoints/best.ckpt'
checkpoints_path = '/home/snowolf/git_code/open-metric-learning/pipelines/postprocessing/pairwise_postprocessing/checkpoints/bottle/extractor.ckpt'
model = CustomViTExtractor.from_pretrained(checkpoints_path, arch='vits16')
transform, _ = get_transforms_for_pretrained("vits16_dino")

# Despite its name, a true value of this flag means "reuse the gallery features
# cached on disk"; a false value forces them to be recomputed and re-saved.
save_features_galleries = True
args = {"num_workers": 0, "batch_size": 8}
features_cache = Path('features_galleries.pth')

features_queries = inference_on_images(model, paths=queries, transform=transform, **args)

features_galleries = None
if save_features_galleries and features_cache.exists():
    # Load onto the device of the query features so both can be compared.
    # NOTE(review): assumes ``inference_on_images`` returns a tensor -- confirm.
    features_galleries = torch.load(features_cache, map_location=features_queries.device)
    if features_galleries.shape[0] == len(galleries):
        print('features_galleries loaded!')
    else:
        # A cache built for another gallery list would make the argmin indices
        # below point at the wrong images, so it is discarded.
        print(
            f'Cached features_galleries has {features_galleries.shape[0]} rows but there are '
            f'{len(galleries)} gallery images; recomputing.'
        )
        features_galleries = None

if features_galleries is None:
    # Either recomputation was requested, or no usable cache exists yet.
    features_galleries = inference_on_images(model, paths=galleries, transform=transform, **args)
    # Save features_galleries so later runs can reuse them.
    torch.save(features_galleries, features_cache)

# Round-trip check of the cache file. It is only informative right after a
# save; after a load it compares the file with itself.
features_galleries1 = torch.load(features_cache, map_location=features_galleries.device)
if torch.equal(features_galleries, features_galleries1):
    print('Save and load features_galleries is successful!')
else:
    print('Save and load features_galleries is not successful!')

# For every query, pick the gallery item at the smallest distance.
dist_mat = pairwise_dist(x1=features_queries, x2=features_galleries)
ii_closest = torch.argmin(dist_mat, dim=1)
print(f"Indices of the items closest to queries: {ii_closest}")

snow-wind-001 avatar May 23 '24 15:05 snow-wind-001