
Model just predicts most images as real (painting or photo)

Open xiankgx opened this issue 1 year ago • 0 comments

Hi, I tried using your model to detect AI-generated photos like those from SD, SDXL, DALL-E, etc. However, most of the predictions come back as "real". Do you see any problem with my code?

import clip
import gradio as gr
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image

from model import LASTED

LABELS = ["Real Photo", "Synthetic Photo", "Real Painting", "Synthetic Painting"]


def modify_state_dict(sd: dict) -> dict:
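    # Strip the "module." prefix (typically left by a torch.nn.DataParallel
    # wrapper at save time) so the keys match the bare LASTED module.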
    new_sd = dict()
    for k, v in sd.items():
        new_sd[k.replace("module.", "")] = v
    return new_sd


def classify(image: Image.Image):
    with torch.inference_mode():
        tensor_in = transform(image).unsqueeze(0).to(device)
        text = clip.tokenize(LABELS).to(device)

        # Zero-shot CLIP classification: embed the image and the four label
        # prompts, L2-normalize, and softmax the scaled cosine similarities.
        image_features = model.clip_model.encode_image(tensor_in)
        text_features = model.clip_model.encode_text(text)

        image_features /= image_features.norm(dim=-1, keepdim=True)
        text_features /= text_features.norm(dim=-1, keepdim=True)
        similarity = (
            (100.0 * image_features @ text_features.T)
            .softmax(dim=-1)
            .detach()
            .cpu()
            .numpy()
        )
        print(f"similarity: {similarity}")

        # Return the highest-scoring label for the (single) input image.
        return np.array(LABELS)[np.argmax(similarity, axis=1)].tolist()


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    transform = transforms.Compose(
        [
            # transforms.ToPILImage(),
            transforms.Resize((448, 448)),
            transforms.ToTensor(),
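            # Standard ImageNet mean/std normalization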
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ]
    )

    print("Loading model...")
    model = LASTED()
    model.load_state_dict(
        modify_state_dict(torch.load("LASTED_pretrained.pt", map_location="cpu"))
    )
    model.eval()
    model.to(device)
    print("Done!")

    demo = gr.Interface(
        fn=classify,
        inputs=[gr.Image(label="input image", type="pil")],
        outputs=[gr.Text(label="predicted label")],
    )

    demo.launch(server_name="0.0.0.0", server_port=80)
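
For quick testing, classify can also be called directly instead of going through the Gradio UI; a minimal sketch, assuming the setup code above has already run in the same session and using a placeholder image path:

# Direct call to classify() for a quick check; assumes device, transform, and
# model have been set up exactly as in the script above.
# "some_generated_image.png" is only a placeholder path.
img = Image.open("some_generated_image.png").convert("RGB")
print(classify(img))  # per the report above, this tends to return ['Real Photo']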

xiankgx · Mar 29 '24 08:03