Masked-Face-Recognition-KD

Sample code for inference

Status: Open — pourfard opened this issue 3 years ago · 2 comments

I have written a few lines of code for inference. Is it correct?

import os
import traceback
import torch
import numpy as np
from backbones.iresnet import iresnet100
import math
import cv2


# Build the IResNet-100 backbone on CPU and load the pretrained ElasticFace
# weights. Any failure here is fatal for the script, so report it and stop.
backbone = iresnet100(num_features=512).to("cpu")
try:
    backbone_pth = "ElasticFaceArcAug_backbone.pth"
    if not os.path.exists(backbone_pth):
        raise FileNotFoundError(f"Model file does not exist: {backbone_pth}")
    backbone.load_state_dict(torch.load(backbone_pth, map_location=torch.device("cpu")))
    # eval() is equivalent to train(False): modules such as dropout and
    # batch-norm (if the network has any) switch to inference behavior.
    backbone.eval()
except Exception:
    # Top-level boundary: show the full traceback before giving up.
    traceback.print_exc()
    print("load teacher backbone init, failed!")
    # SystemExit does not depend on the `site`-provided exit() builtin, which
    # is missing under `python -S` and in frozen executables.
    raise SystemExit(1)


def get_embedding(path):
    """Return the backbone embedding for one pre-aligned face image.

    The image at *path* is read with OpenCV (BGR order), converted to RGB,
    reordered to channel-first, scaled from [0, 255] to [-1, 1], and passed
    through the module-level ``backbone`` as a batch of one.

    NOTE(review): no resizing or alignment happens here. The image is assumed
    to already be a 112x112 face crop (e.g. aligned with MTCNN). Confirm this
    matches the input size the backbone expects.

    Args:
        path: Filesystem path of the image to read.

    Returns:
        The first row of the network output as a NumPy array, i.e. the
        embedding of the single input image.

    Raises:
        FileNotFoundError: If OpenCV cannot read an image from *path*.
    """
    bgr = cv2.imread(path)
    # cv2.imread signals failure by returning None instead of raising. Without
    # this check the problem would surface later as an opaque cv2.error.
    if bgr is None:
        raise FileNotFoundError(f"Could not read image: {path}")

    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    chw = np.transpose(rgb, (2, 0, 1))  # (height, width, channel) -> (channel, height, width)
    # Add the batch dimension and make a contiguous float32 copy for torch.
    batch = np.ascontiguousarray(chw[np.newaxis], dtype=np.float32)
    # Map [0, 255] pixel values to [-1, 1].
    batch = (batch / 255.0 - 0.5) / 0.5

    # Inference only: no_grad avoids building an autograd graph.
    with torch.no_grad():
        net_out = backbone(torch.from_numpy(batch).to("cpu"))
    return net_out.cpu().numpy()[0]


def print_similarity(vector1, vector2):
    """Print, and return, the cosine similarity and normalized angular distance.

    similarity = <v1, v2> / (|v1| * |v2|), clipped to [-1, 1]
    distance   = arccos(similarity) / pi, in [0, 1] (0 means same direction)

    For L2-normalized vectors, the squared Euclidean distance
    ``sum((v1 - v2) ** 2)`` equals ``2 - 2 * similarity``, so both measures
    rank pairs identically. For unnormalized vectors they differ.

    Args:
        vector1: First array-like of numbers (flattened to 1-D).
        vector2: Second array-like of the same length (flattened to 1-D).

    Returns:
        ``(similarity, distance)``: the two values that are printed.

    Raises:
        ValueError: If either vector has zero norm (similarity undefined),
            or if the lengths differ.
    """
    a = np.ravel(vector1)
    b = np.ravel(vector2)
    norm = np.linalg.norm(a) * np.linalg.norm(b)
    if norm == 0:
        raise ValueError("cosine similarity is undefined for a zero vector")
    sim = np.dot(a, b) / norm
    # Rounding can push the ratio marginally outside [-1, 1] (e.g. for
    # identical vectors), where arccos would return NaN.
    sim = np.clip(sim, -1.0, 1.0)
    dist = np.arccos(sim) / math.pi

    print(sim, dist)
    return sim, dist


if __name__ == "__main__":
    # Compare two pre-aligned face crops. Guarded so that importing this
    # module does not trigger inference on the sample files.
    print_similarity(get_embedding("test1.jpg"), get_embedding("test2.jpg"))

— pourfard, Jan 16 '22 09:01

Code looks good

— fdbtrs, Apr 12 '22 15:04

@fdbtrs Your evaluation code uses L2 distance: `diff = np.subtract(embeddings1, embeddings2); dist = np.sum(np.square(diff), 1)`. Are these two methods the same?

— wen1q84, Oct 19 '22 01:10