facenet-pytorch icon indicating copy to clipboard operation
facenet-pytorch copied to clipboard

How to load .pt model in scala?

Open zaryabRiasat opened this issue 6 months ago • 0 comments

I've downloaded pre-trained model from there, which is 20180402-114759-vggface2.pt. I've used this in python and it is working fine with great accuracy.


from facenet_pytorch import MTCNN, InceptionResnetV1
from PIL import Image
import torch

mtcnn = MTCNN(image_size=160, margin=0)
resnet = InceptionResnetV1(pretrained='vggface2').eval()

resnet.load_state_dict(torch.load('../20180402-114759-vggface2.pt'), strict=False)

img1 = Image.open('../img1')
img2 = Image.open('../img2')

img1_cropped = mtcnn(img1)
img2_cropped = mtcnn(img2)

if img1_cropped is not None and img2_cropped is not None:
    img1_embedding = resnet(img1_cropped.unsqueeze(0))
    img2_embedding = resnet(img2_cropped.unsqueeze(0))

    cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
    similarity = cos(img1_embedding, img2_embedding)
    print(f"Cosine Similarity: {similarity.item()}")
    threshold = 0.6  
    if similarity > threshold:
        print("The faces are similar!")
        print("The faces are different!")
    print("Face not detected in one or both images.")

Now I want to use it in Scala (JVM Environment). I've searched a lot, and found that we can use .pt model in scala using DJL (Deep Java Library), the code which I tried in scala is:

libraries in build.sbt:

libraryDependencies ++= Seq(
  "ai.djl" % "api" % "0.29.0",
  "ai.djl.pytorch" % "pytorch-engine" % "0.29.0" % "runtime",
  "ai.djl.pytorch" % "pytorch-model-zoo" % "0.29.0",
  "ai.djl.pytorch" % "pytorch-native-cpu" % "2.3.1" % "runtime" classifier "linux-x86_64",
  "ai.djl.pytorch" % "pytorch-jni" % "2.3.1-0.29.0" % "runtime"


import ai.djl.Model
import ai.djl.modality.cv.Image
import ai.djl.modality.cv.ImageFactory
import ai.djl.ndarray.{NDArray, NDList, NDManager}
import ai.djl.ndarray.types.Shape
import ai.djl.translate.{Batchifier, Translator, TranslatorContext}

import java.nio.file.Paths

object FaceRecognitionDJL {

  def main(args: Array[String]): Unit = {
    val image1Path = Paths.get("../img_1.png")
    val image2Path = Paths.get("../img_2.png")

    val image1 = ImageFactory.getInstance().fromFile(image1Path)
    val image2 = ImageFactory.getInstance().fromFile(image2Path)

    val model = Model.newInstance("face_recognition_model")

    val embeddings1 = getEmbeddings(model, image1)
    val embeddings2 = getEmbeddings(model, image2)

    val similarity = compareEmbeddings(embeddings1, embeddings2)
    println(s"Similarity between faces: $similarity")

    if (similarity > 0.7) {
      println("Faces belong to the same person.")
    } else {
      println("Faces do not belong to the same person.")

  def getEmbeddings(model: Model, image: Image): Array[Float] = {
    val predictor = model.newPredictor(new MyTranslator)

  def compareEmbeddings(embedding1: Array[Float], embedding2: Array[Float]): Double = {
    val dotProduct = embedding1.zip(embedding2).map { case (a, b) => a * b }.sum
    val norm1 = Math.sqrt(embedding1.map(x => x * x).sum)
    val norm2 = Math.sqrt(embedding2.map(x => x * x).sum)
    dotProduct / (norm1 * norm2)

class MyTranslator extends Translator[Image, Array[Float]] {
  override def processInput(ctx: TranslatorContext, input: Image): NDList = {
    val manager = NDManager.newBaseManager()

    val imgArray: NDArray = input.toNDArray(manager)

    val resizedImgArray = imgArray.reshape(new Shape(160, 160))
    val normalizedImgArray = resizedImgArray.div(255.0)

    new NDList(normalizedImgArray)

  override def processOutput(ctx: TranslatorContext, list: NDList): Array[Float] = {

  override def getBatchifier: Batchifier = null

I have tried above code, after searching on different websites. But this is giving an error:

[error] Exception in thread "main" ai.djl.engine.EngineException: PytorchStreamReader failed reading zip archive: failed finding central directory

Same .pt model is working fine in python but I'm unable to run that in scala. Guide me what I'm doing wrong?

zaryabRiasat avatar Aug 10 '24 11:08 zaryabRiasat