open-metric-learning
open-metric-learning copied to clipboard
Test errors on own dataset
import torch
import pandas as pd
import random
from pathlib import Path
from torch.utils.data import DataLoader
from typing import Union, Optional
from pprint import pprint
from oml.const import PATHS_COLUMN
from oml.datasets.base import DatasetQueryGallery
from oml.inference.flat import inference_on_dataframe, inference_on_images
from oml.models import ConcatSiamese, ViTExtractor
from oml.registry.transforms import get_transforms_for_pretrained
from oml.retrieval.postprocessors.pairwise import PairwiseImagesPostprocessor
from oml.utils.misc_torch import pairwise_dist
from oml.utils.io import download_checkpoint_one_of
from NumpyImageDataset import *
import os
# Root directory of the dataset: the dataframe CSV is read from here and the
# relative image paths listed in it are joined onto this directory.
dataset_root = Path("/home/snowolf/dataset/bottle_processed_images/part1/archive")
def load_images_from_folder(folder_path):
    """Load the image files found directly inside ``folder_path`` as numpy arrays.

    Entries are selected by file extension. The extension is matched
    case-insensitively (so ``.JPG`` is accepted as well as ``.jpg``) and the
    entries are visited in sorted name order, which makes the returned list
    reproducible regardless of the order ``os.listdir`` happens to yield.

    Args:
        folder_path: Directory to scan. Only its direct entries are considered.

    Returns:
        A list holding one ``np.array(img)`` per matching entry, in sorted
        name order. The list is empty when nothing matches.
    """
    # NOTE(review): ``Image`` and ``np`` are not imported explicitly in this
    # file; presumably they arrive via ``from NumpyImageDataset import *`` — confirm.
    valid_extensions = ('.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff')  # supported file formats
    images = []  # collected image arrays
    for filename in sorted(os.listdir(folder_path)):
        if not filename.lower().endswith(valid_extensions):
            continue
        img_path = os.path.join(folder_path, filename)  # full path of the image
        with Image.open(img_path) as img:
            # Convert to a numpy array while the file is still open.
            images.append(np.array(img))
    return images
class CustomViTExtractor(ViTExtractor):
    """ViTExtractor subclass with its own, non-strict checkpoint loading.

    The parent constructor is always called with ``weights=None``; weights are
    loaded afterwards by :meth:`load_pretrained_weights`, and the module is
    then moved to the module-level ``device``.
    """

    def __init__(
        self,
        arch: str = "vits16",
        normalise_features: bool = False,
        use_multi_scale: bool = False,
        weights: Optional[Union[Path, str]] = None,
    ):
        """Build the extractor and optionally load weights into it.

        Args:
            arch: Architecture name forwarded to ``ViTExtractor``.
            normalise_features: Forwarded to ``ViTExtractor``.
            use_multi_scale: Forwarded to ``ViTExtractor``.
            weights: Checkpoint path or key of ``self.pretrained_models``;
                ``None`` skips weight loading.
        """
        super().__init__(
            weights=None,
            arch=arch,
            normalise_features=normalise_features,
            use_multi_scale=use_multi_scale,
        )
        if weights is not None:
            self.load_pretrained_weights(weights)
        # ``device`` is a module-level global; it must be defined before the
        # first instance is created.
        self.to(device)

    def load_pretrained_weights(self, weights: Union[Path, str]):
        """Load a state dict into ``self.model`` from a file or a pretrained key.

        Args:
            weights: Path (``str`` or ``Path``) to an existing checkpoint
                file, or a key of ``self.pretrained_models`` whose checkpoint
                is downloaded first.

        Raises:
            ValueError: If ``weights`` is neither an existing path nor a known
                pretrained key. Previously this case failed later with an
                unbound ``state_dict``.
        """
        if isinstance(weights, (str, Path)) and Path(weights).exists():
            ckpt_path = weights
        elif str(weights) in self.pretrained_models:
            pretrained = self.pretrained_models[str(weights)]
            ckpt_path = download_checkpoint_one_of(
                url_or_fid_list=pretrained["url"],
                hash_md5=pretrained["hash"],
                fname=pretrained["fname"],
            )
        else:
            raise ValueError(
                f"Cannot load weights from {weights!r}: it is neither an existing "
                f"file nor one of the known pretrained models."
            )

        ckpt = torch.load(ckpt_path, map_location=device)
        # Some checkpoints wrap the weights under a "state_dict" entry.
        state_dict = ckpt["state_dict"] if "state_dict" in ckpt else ckpt

        # strict=False tolerates key mismatches, so report them explicitly:
        # otherwise a checkpoint whose keys do not match (e.g. because of a
        # name prefix) would leave the model unchanged without any signal.
        # NOTE(review): assumes self.model is a torch.nn.Module, whose
        # load_state_dict returns missing_keys/unexpected_keys — confirm.
        incompatible = self.model.load_state_dict(state_dict, strict=False)
        if incompatible.missing_keys or incompatible.unexpected_keys:
            print(
                f"WARNING: checkpoint does not fully match the model: "
                f"{len(incompatible.missing_keys)} missing keys, "
                f"{len(incompatible.unexpected_keys)} unexpected keys."
            )
        pprint('加载权重到模型!!')

    @classmethod
    def from_pretrained(
        cls,
        path: Union[Path, str],
        arch: str = "vits16",
        normalise_features: bool = False,
        use_multi_scale: bool = False,
    ):
        """Alternate constructor: build an instance and load weights from ``path``."""
        return cls(
            weights=path,
            arch=arch,
            normalise_features=normalise_features,
            use_multi_scale=use_multi_scale,
        )
def data_bottle(dataset_root, df_name):
    """Read ``dataset_root / df_name`` and split it into train and validation frames.

    Rows are selected by the value of the ``split`` column ("train" or
    "validation") and each resulting frame gets a fresh 0-based index. In the
    validation frame the ``is_query`` and ``is_gallery`` columns are cast to bool.

    Returns:
        A ``(df_train, df_val)`` tuple.
    """
    table = pd.read_csv(dataset_root / df_name)
    parts = {
        split_name: table[table["split"] == split_name].reset_index(drop=True)
        for split_name in ("train", "validation")
    }
    df_val = parts["validation"]
    # Cast the flags to bool so they can be used directly as boolean masks.
    for flag_column in ("is_query", "is_gallery"):
        df_val[flag_column] = df_val[flag_column].astype(bool)
    return parts["train"], df_val
# Only the validation frame is used below; the train frame is discarded.
_, df_val = data_bottle(dataset_root, 'df_test.csv')

# Join the paths stored in the CSV onto the dataset root.
df_val["path"] = df_val["path"].map(lambda rel_path: dataset_root / rel_path)

# Paths of the rows flagged as queries / galleries, in dataframe order.
queries = df_val.loc[df_val["is_query"], "path"].tolist()
galleries = df_val.loc[df_val["is_gallery"], "path"].tolist()

# Print (index, path) pairs so positions in the lists can be looked up by eye.
for indexed_query in enumerate(queries):
    print(indexed_query)
print('========================')
for indexed_gallery in enumerate(galleries):
    print(indexed_gallery)
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Alternatives kept for reference:
#   model = ViTExtractor.from_pretrained("vits16_dino")
#   checkpoints_path = '/home/snowolf/git_code/open-metric-learning/pipelines/postprocessing/pairwise_postprocessing/logs/2024-05-13_10-58-23_feature_extractor/checkpoints/best.ckpt'
checkpoints_path = '/home/snowolf/git_code/open-metric-learning/pipelines/postprocessing/pairwise_postprocessing/checkpoints/bottle/extractor.ckpt'
model = CustomViTExtractor.from_pretrained(path=checkpoints_path, arch='vits16')

# Only the transform is needed; the second returned value is ignored.
transform, _ = get_transforms_for_pretrained("vits16_dino")
# When True, previously saved gallery features are reused if a usable cache
# file is present; when False they are always recomputed and saved again.
save_features_galleries = True
gallery_cache_path = 'features_galleries.pth'
args = {"num_workers": 0, "batch_size": 8}

features_queries = inference_on_images(model, paths=queries, transform=transform, **args)

features_galleries = None
if save_features_galleries and os.path.exists(gallery_cache_path):
    cached_features = torch.load(gallery_cache_path)
    # A cache with a different number of rows than the current gallery list
    # would make the argmin indices below point at the wrong images, so such
    # a cache is ignored.
    # NOTE(review): assumes the cached object is a tensor with one row per
    # gallery image — confirm against inference_on_images.
    if len(cached_features) == len(galleries):
        features_galleries = cached_features
        print('features_galleries loaded!')

if features_galleries is None:
    # No usable cache (missing, stale, or caching disabled): compute the
    # gallery features and save them for later runs.
    features_galleries = inference_on_images(model, paths=galleries, transform=transform, **args)
    torch.save(features_galleries, gallery_cache_path)

# Sanity check: what is on disk must equal what is used in memory.
features_galleries1 = torch.load(gallery_cache_path)
if torch.equal(features_galleries, features_galleries1):
    print('Save and load features_galleries is successful!')
else:
    print('Save and load features_galleries is not successful!')

# For every query, take the gallery item at the smallest pairwise distance.
dist_mat = pairwise_dist(x1=features_queries, x2=features_galleries)
ii_closest = torch.argmin(dist_mat, dim=1)
print(f"Indices of the items closest to queries: {ii_closest}")