Is there a way to improve execution speed?

Open gaoyong06 opened this issue 1 year ago • 9 comments

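I tried it on a dozen or so images and it was rather slow: on my personal computer each image takes tens of seconds (about 40 seconds). With a large number of images this becomes hard to manage. Is there any way to improve the speed? It would be great if each image could finish within 1 to 3 seconds.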

gaoyong06 avatar Jun 07 '23 14:06 gaoyong06

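For Tag2Text, you can change the code to run batched inference, which greatly improves inference throughput. If you only need the tagging output, you can have the generate function stop and return as soon as tagging is done. In addition, our newly released Recognize Anything Model (RAM) can be customized to recognize arbitrary categories, which can further improve tagging inference efficiency; this feature will be released in the future.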

xinyu1205 avatar Jun 07 '23 14:06 xinyu1205
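
The batching suggestion above can be sketched roughly as follows. This is a minimal illustration, not code from the Tag2Text repository: `chunked` and `run_in_batches` are hypothetical helper names, and `infer_batch` is a placeholder for a function that stacks the preprocessed image tensors of one chunk into a single batch and calls the model once for the whole chunk (assuming the model's generate call accepts a batch dimension larger than 1).

```python
from typing import Callable, Iterator, List, Sequence, TypeVar

T = TypeVar("T")
R = TypeVar("R")


def chunked(items: Sequence[T], batch_size: int) -> Iterator[Sequence[T]]:
    """Yield consecutive slices of `items` holding at most `batch_size` elements."""
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]


def run_in_batches(items: Sequence[T],
                   infer_batch: Callable[[Sequence[T]], List[R]],
                   batch_size: int = 8) -> List[R]:
    """Call `infer_batch` once per chunk and flatten the per-item results."""
    results: List[R] = []
    for batch in chunked(items, batch_size):
        outputs = infer_batch(batch)
        # One result per input keeps results aligned with the original order.
        if len(outputs) != len(batch):
            raise RuntimeError("infer_batch must return one result per input")
        results.extend(outputs)
    return results


def fake_infer(batch: Sequence[str]) -> List[str]:
    # Hypothetical stand-in for "stack tensors, call the model once".
    return [f"tags for {path}" for path in batch]


print(run_in_batches(["a.jpg", "b.jpg", "c.jpg"], fake_infer, batch_size=2))
```

With a real model, the batch size would be limited by available memory, so it is worth starting small and increasing it while watching memory use.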

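The slowness comes from loading the model again for every image; load it once instead. On a Mac M2 CPU, with images fetched over the network, each image then takes between 3 and 10 seconds.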

onexuan avatar Jun 07 '23 14:06 onexuan
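
One way to make sure the model is loaded only once per process is to memoize the loader. The sketch below assumes nothing about the real model API: `make_cached_loader` is a hypothetical helper, and `fake_load` is a stand-in for whatever function actually builds the model from a checkpoint path.

```python
import functools
from typing import Any, Callable


def make_cached_loader(load_fn: Callable[[str], Any]) -> Callable[[str], Any]:
    """Wrap `load_fn` so each checkpoint path is loaded at most once per process."""

    @functools.lru_cache(maxsize=None)
    def get_model(checkpoint_path: str) -> Any:
        return load_fn(checkpoint_path)

    return get_model


load_calls = []


def fake_load(checkpoint_path: str) -> dict:
    # Hypothetical stand-in for the expensive model construction step.
    load_calls.append(checkpoint_path)
    return {"checkpoint": checkpoint_path}


get_model = make_cached_loader(fake_load)

for _ in range(3):
    model = get_model("tag2text_swin_14m.pth")  # only the first call loads

print(len(load_calls))
```

The cache lives in memory, so it only helps when many images are processed inside the same Python process; starting a new process per image would still pay the load cost each time.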

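For Tag2Text, you can change the code to run batched inference, which greatly improves inference throughput. If you only need the tagging output, you can have the generate function stop and return as soon as tagging is done. In addition, our newly released Recognize Anything Model (RAM) can be customized to recognize arbitrary categories, which can further improve tagging inference efficiency; this feature will be released in the future.

Has it already been released, or will it be released tomorrow?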

tensorboy avatar Jun 07 '23 20:06 tensorboy

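@xinyu1205 @onexuan Thank you both for the suggestions. I wrote a test script; it now takes about 4 seconds per image at the slow end and 2 to 3 seconds at the fast end.

The code is below; please take a look and let me know whether there is still room for improvement.

File name: test.py. Usage examples: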

  1. python test.py --cache-path C:/Users/gaoyo/.cache/Tag2Text --image-dir C:/Users/gaoyo/Desktop/test1/
  2. python test.py --cache-path C:/Users/gaoyo/.cache/Tag2Text --images C:/Users/gaoyo/Desktop/test1/20170912_234158_1_14_70wf.jpeg C:/Users/gaoyo/Desktop/test1/20170918_192435_1_57_9yto.jpeg
# -*- coding: utf-8 -*-
'''
Author: gaoyong [email protected]
Date: 2023-06-08 10:51:43
LastEditors: gaoyong [email protected]
LastEditTime: 2023-06-08 11:07:57
FilePath: \Tag2Text\test.py
Description: Automatically generate image tags and captions.
'''
import argparse
import json
import os
import time

import imghdr
import torch
import torchvision.transforms as transforms
from PIL import Image

from models.tag2text import tag2text_caption


def parse_args():
    """
    This function parses command line arguments for a Tag2Text inference model.
    :return: The function `parse_args()` is returning the parsed arguments from the command line using
    the `argparse` module.
    """
    parser = argparse.ArgumentParser(
        description='Tag2Text inference for tagging and captioning')
    parser.add_argument('--image-dir',
                        metavar='DIR',
                        help='path to directory containing input images',
                        default='')
    parser.add_argument('--images',
                        metavar='IMAGE-LIST',
                        nargs='+',
                        help='list of space-separated image filenames',
                        default=[])
    parser.add_argument('--pretrained',
                        metavar='DIR',
                        help='path to pretrained model',
                        default='D:/work/Tag2Text/pretrained/tag2text_swin_14m.pth')
    parser.add_argument('--image-size',
                        default=384,
                        type=int,
                        metavar='N',
                        help='input image size (default: 384)')
    parser.add_argument('--thre',
                        default=0.68,
                        type=float,
                        metavar='N',
                        help='threshold value')
    parser.add_argument('--specified-tags',
                        default='None',
                        help='User input specified tags')
    parser.add_argument('--cache-path',
                        default='None',
                        help='cache model file path')

    return parser.parse_args()


def initialize_model(cache_path, pretrained, image_size, thre):
    """
    This function initializes a Tag2Text model based on specified and identified tags.
    :param cache_path: Cache model file path.
    :param pretrained: Path to the pre-trained model.
    :param image_size: Input image size.
    :param thre: Threshold value for tagging.
    :return: A pre-trained Tag2Text model.
    """

    # delete some tags that may disturb captioning
    # 127: "quarter"; 2961: "back", 3351: "two"; 3265: "three"; 3338: "four"; 3355: "five"; 3359: "one"
    delete_tag_index = [127, 2961, 3351, 3265, 3338, 3355, 3359]

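    if os.path.exists(cache_path):
        # The cache is a fully pickled model. PyTorch releases that default to
        # weights_only=True need torch.load(cache_path, weights_only=False) here.
        model = torch.load(cache_path)
    else:
        model = tag2text_caption(
            pretrained=pretrained,
            image_size=image_size,
            vit='swin_b',
            delete_tag_index=delete_tag_index
        )
        torch.save(model, cache_path)

    # Apply these on both paths so a cached model honours the current --thre.
    model.threshold = thre  # threshold for tagging
    model.eval()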

    return model


def generate(model, image, input_tags=None):
    """
    This function generates tags and captions for an input image.
    :param model: The neural network model used for generating captions and predicting tags for an input
    image.
    :param image: The input image to generate tags and captions for.
    :param input_tags: The input tags used as hints for the model to generate captions for the input image.
    It is an optional parameter and can be set to None or left empty if no tag hint is required.
    :return: A tuple of predicted tags, input tags, and generated captions.
    """

    if input_tags in ('', 'none', 'None'):
        input_tags = None

    with torch.no_grad():
        caption, tag_predict = model.generate(image,
                                              tag_input=None,
                                              max_length=50,
                                              return_tag_predict=True)

    if input_tags is None:
        return tag_predict[0], None, caption[0]

    input_tag_list = [input_tags.replace(',', ' | ')]
    with torch.no_grad():
        caption, input_tags = model.generate(image,
                                             tag_input=input_tag_list,
                                             max_length=50,
                                             return_tag_predict=True)
    return tag_predict[0], input_tags[0], caption[0]


def inference(images_dir, image_list, model, image_size, input_tags=None):
    """
    This function takes a list of images or a directory containing images, a model, generates captions
    for the images, and optionally takes a list of input tags to generate captions with those tags.
    :param images_dir: A directory containing input images that the model will use to generate captions and
    potentially predict tags for.
    :param image_list: A list of input images the model will use to generate captions and potentially
    predict tags for.
    :param model: The neural network model used for generating captions and predicting tags for an input
    image.
    :param image_size: Side length, in pixels, that each image is resized to before inference.
    :param input_tags: The input tags are lists of strings that represent tags or sets of tags that are
    used as hints for the model to generate captions for the given images. It is an optional parameter and
    can be set to None or left empty if no tag hint is required, defaults to None.
    :return: A list of dictionaries, each containing predicted tags, input tags (if provided), and
    generated captions for a given input image.
    """
    
    results = []
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                    std=[0.229, 0.224, 0.225])
    transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(), normalize
    ])

    if images_dir and os.path.isdir(images_dir):
        for filename in os.listdir(images_dir):
            filepath = os.path.join(images_dir, filename)
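            # Skip subdirectories and non-image files; imghdr.what() raises on a directory.
            if not os.path.isfile(filepath) or not imghdr.what(filepath):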
                continue
            img = Image.open(filepath).convert("RGB")
            img_tensor = transform(img).unsqueeze(0).to(device)
            res = generate(model, img_tensor, input_tags)
            results.append({
                "filepath": filepath,
                "model_identified_tags": res[0],
                "user_specified_tags": res[1],
                "image_caption": res[2]
            })
    elif image_list and isinstance(image_list, list):
        for img_path in image_list:
            filepath = os.path.abspath(img_path)
            if not os.path.isfile(filepath) or not imghdr.what(filepath):
                continue
            img = Image.open(filepath).convert("RGB")
            img_tensor = transform(img).unsqueeze(0).to(device)
            res = generate(model, img_tensor, input_tags)
            results.append({
                "filepath": img_path,
                "model_identified_tags": res[0],
                "user_specified_tags": res[1],
                "image_caption": res[2]
            })

    return results

def main():
    """
    This function loads a pre-trained image captioning model, processes input images in a directory,
    and generates captions for each image based on specified and identified tags.
    """
    start_time = time.time()
    args = parse_args()

    # check if a list of images is provided
    images = args.images if args.images else None
    # initialize the model
    model = initialize_model(
        args.cache_path, args.pretrained, args.image_size, args.thre)

    # perform inference on images
    # pass the user-specified tags through; generate() treats 'None' as "no tags"
    data = inference(args.image_dir, images, model,
                     args.image_size, input_tags=args.specified_tags)

    # output the results
    results = {
        "status": 0,
        "message": 'ok',
        "data": data
    }

    end_time = time.time()
    elapsed_time = end_time - start_time

    print(
        f"Processed {len(results['data'])} images in {elapsed_time:.2f} seconds.")

    print(json.dumps(results, ensure_ascii=False, indent=2))

# Usage examples:
# 1. python test.py --cache-path C:/Users/gaoyo/.cache/Tag2Text --image-dir C:/Users/gaoyo/Desktop/test1/
# 2. python test.py --cache-path C:/Users/gaoyo/.cache/Tag2Text --images C:/Users/gaoyo/Desktop/test1/20170912_234158_1_14_70wf.jpeg C:/Users/gaoyo/Desktop/test1/20170918_192435_1_57_9yto.jpeg


if __name__ == '__main__':
    main()

gaoyong06 avatar Jun 08 '23 03:06 gaoyong06
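
The script above starts its timer before the model is initialized, so the reported total mixes model loading with per-image inference. A rough way to separate the two is a small timing helper like the one below. This is only a sketch: `timed` is a hypothetical helper name, and the `time.sleep` calls are placeholders for the real `initialize_model` and `generate` calls.

```python
import time
from contextlib import contextmanager
from typing import Dict, Iterator


@contextmanager
def timed(label: str, totals: Dict[str, float]) -> Iterator[None]:
    """Add the wall-clock time spent inside the block to totals[label]."""
    start = time.perf_counter()
    try:
        yield
    finally:
        totals[label] = totals.get(label, 0.0) + (time.perf_counter() - start)


totals: Dict[str, float] = {}

with timed("model_load", totals):
    time.sleep(0.01)  # placeholder for initialize_model(...)

for _ in range(3):
    with timed("inference", totals):
        time.sleep(0.01)  # placeholder for generate(model, img_tensor, ...)

print({label: round(seconds, 3) for label, seconds in totals.items()})
```

Dividing the accumulated inference time by the number of images gives a per-image figure that is not inflated by the one-off load cost.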

@xinyu1205 How much GPU memory is needed to run the RAM model?

nowgoo avatar Jun 14 '23 01:06 nowgoo

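Could you give me a download link for the cache-path file? I cannot generate it automatically; the code reports a network connection error, possibly because of my network.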

qq846511277 avatar Jun 16 '23 03:06 qq846511277

@xinyu1205 How much GPU memory is needed to run the RAM model?

3.8G

lyy1988323 avatar Jul 07 '23 07:07 lyy1988323

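The slowness comes from loading the model again for every image; load it once instead. On a Mac M2 CPU, with images fetched over the network, each image then takes between 3 and 10 seconds.

For the Recognize Anything Model (RAM), how do I change it so the model is loaded only once?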

alexkinren avatar Jul 13 '23 00:07 alexkinren
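
A general pattern for "load once" that does not depend on any particular model is to keep one long-lived process that builds the model a single time and then handles image paths as they arrive. The sketch below is an assumption-laden illustration, not the RAM API: `serve`, `load_model`, and `tag_image` are hypothetical names, and the two callables would have to wrap the real model construction and inference code.

```python
from typing import Any, Callable, Iterable, Iterator, Tuple


def serve(lines: Iterable[str],
          load_model: Callable[[], Any],
          tag_image: Callable[[Any, str], str]) -> Iterator[Tuple[str, str]]:
    """Load the model once, then yield (path, result) for each non-empty line."""
    model = load_model()  # the expensive step, paid a single time
    for line in lines:
        path = line.strip()
        if not path:
            continue  # ignore blank lines
        yield path, tag_image(model, path)


def fake_load() -> str:
    # Hypothetical stand-in for building the model from a checkpoint.
    return "model"


def fake_tag(model: str, path: str) -> str:
    # Hypothetical stand-in for running inference on one image.
    return f"{model} tagged {path}"


for path, result in serve(["a.jpg\n", "\n", "b.jpg\n"], fake_load, fake_tag):
    print(path, "->", result)
```

Passing `sys.stdin` as `lines` would let another program feed image paths to the same process one per line, so the model stays in memory between requests.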

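For Tag2Text, you can change the code to run batched inference, which greatly improves inference throughput. If you only need the tagging output, you can have the generate function stop and return as soon as tagging is done. In addition, our newly released Recognize Anything Model (RAM) can be customized to recognize arbitrary categories, which can further improve tagging inference efficiency; this feature will be released in the future.

When is the customizable-category feature of RAM planned for release?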

alexkinren avatar Jul 13 '23 00:07 alexkinren