Matte-Anything icon indicating copy to clipboard operation
Matte-Anything copied to clipboard

Demo cannot work

Open ning22 opened this issue 1 year ago • 5 comments

image

No matter where I click after dropping an image, an error occurs. Are there detailed instructions for running the demo?

ning22 avatar Jul 04 '24 03:07 ning22

升级gradio包到4.39.0;使用下面的代码替换matte_anything.py中对应的部分 with gr.Blocks() as demo: gr.Markdown( """ #

Matte Anything🐒 ! """ ) with gr.Row(equal_height=True): #.style(equal_height=True): with gr.Column(): # input image original_image = gr.State(value="numpy") # store original image without points, default None input_image = gr.Image(type="numpy", label="Input Image")
# prompt (point or text) # Point Input with gr.Tab(label='Point Input') as Tab1: with gr.Column(): selected_points = gr.State([]) # store points radio = gr.Radio(['foreground_point', 'background_point'], label='Point Labels') with gr.Row(): undo_button = gr.Button('Remove Point') undo_all_button = gr.Button('Remove All Points') # Foreground Text Input with gr.Tab(label='Foreground Text Input') as Tab2: with gr.Blocks(): gr.Markdown("Foreground Text Input") fg_caption = gr.Textbox(lines=1, placeholder="the girl in the middle", label="foreground input text")
            # Save Config
            with gr.Tab(label='Save Config') as Tab3:
                with gr.Blocks():
                    gr.Markdown("save name")
                    save_dir = gr.Textbox(lines=1, placeholder="the girl in the middle", label="Give a name of your demo. It will be saved in ```your_demos/your_name.pny```")
            
            
            # run button
            button = gr.Button("Start!")

            # Trimap Settings
            with gr.Tab(label='Trimap Settings'):
                gr.Markdown("Trimap Settings")
                erode_kernel_size = gr.Slider(minimum=1, maximum=30, step=1, value=10, label="erode_kernel_size")
                dilate_kernel_size = gr.Slider(minimum=1, maximum=30, step=1, value=10, label="dilate_kernel_size")
            
            # Input Text Settings
            with gr.Tab(label='Input Text Settings'):
                gr.Markdown("Input Text Settings")
                fg_box_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, value=0.25, label="foreground_box_threshold")
                fg_text_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, value=0.25, label="foreground_text_threshold")

            # Transparency Settings
            with gr.Tab(label='Transparency Settings'):
                gr.Markdown("Transparency Settings")
                tr_caption = gr.Textbox(lines=1, placeholder="glass.lens.crystal.diamond.bubble.bulb.web.grid", label="transparency input text")
                tr_box_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.005, value=0.5, label="transparency_box_threshold")
                tr_text_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.005, value=0.25, label="transparency_text_threshold")

        
        with gr.Column():

            # show the image with mask
            with gr.Tab(label='SAM Mask'):
                mask = gr.Image(type='numpy')
            # with gr.Tab(label='Trimap'):
            #     trimap = gr.Image(type='numpy')
            with gr.Tab(label='Alpha Matte'):
                alpha = gr.Image(type='numpy')
            # show only mask
            with gr.Tab(label='Foreground by SAM Mask'):
                foreground_by_sam_mask = gr.Image(type='numpy')
            with gr.Tab(label='Refined by ViTMatte'):
                refined_by_vitmatte = gr.Image(type='numpy')
            # with gr.Tab(label='Transparency Detection'):
            #     transparency = gr.Image(type='numpy')
            with gr.Tab(label='New Background 1'):
                new_bg_1 = gr.Image(type='numpy')
            with gr.Tab(label='New Background 2'):
                new_bg_2 = gr.Image(type='numpy')
            with gr.Tab(label='New Background 3'):
                new_bg_3 = gr.Image(type='numpy')

    input_image.upload(
        store_img,
        [input_image],
        [original_image, selected_points]
    )
    input_image.select(
        get_point,
        [input_image, selected_points, radio],
        [input_image],
    )
    undo_button.click(
        undo_points,
        [original_image, selected_points],
        [input_image]
    )
    undo_all_button.click(
        undo_all_points,
        [original_image, selected_points],
        [input_image]
    )
    Tab1.select(
        clear_fg_caption,
        [fg_caption],
        [fg_caption]
    )
    Tab2.select(
        undo_all_points,
        [original_image, selected_points],
        [input_image]
    )
    
    
    button.click(run_inference, inputs=[original_image, selected_points, erode_kernel_size, dilate_kernel_size, fg_box_threshold, fg_text_threshold, fg_caption, tr_box_threshold, tr_text_threshold, \
                                        save_dir, tr_caption],  outputs=[mask, alpha, foreground_by_sam_mask, refined_by_vitmatte, new_bg_1, new_bg_2, new_bg_3])

    with gr.Row():
        with gr.Column():
            background_image = gr.State(value=None)

demo.launch()

f-chen165 avatar Jul 25 '24 07:07 f-chen165

升级gradio包到4.39.0;使用下面的代码替换matte_anything.py中对应的部分 with gr.Blocks() as demo: gr.Markdown( """ #

Matte Anything🐒 ! """ ) with gr.Row(equal_height=True): #.style(equal_height=True): with gr.Column():

input image

original_image = gr.State(value="numpy") # store original image without points, default None input_image = gr.Image(type="numpy", label="Input Image")

prompt (point or text)

Point Input

with gr.Tab(label='Point Input') as Tab1: with gr.Column(): selected_points = gr.State([]) # store points radio = gr.Radio(['foreground_point', 'background_point'], label='Point Labels') with gr.Row(): undo_button = gr.Button('Remove Point') undo_all_button = gr.Button('Remove All Points')

Foreground Text Input

with gr.Tab(label='Foreground Text Input') as Tab2: with gr.Blocks(): gr.Markdown("Foreground Text Input") fg_caption = gr.Textbox(lines=1, placeholder="the girl in the middle", label="foreground input text")

            # Save Config
            with gr.Tab(label='Save Config') as Tab3:
                with gr.Blocks():
                    gr.Markdown("save name")
                    save_dir = gr.Textbox(lines=1, placeholder="the girl in the middle", label="Give a name of your demo. It will be saved in ```your_demos/your_name.pny```")
            
            
            # run button
            button = gr.Button("Start!")

            # Trimap Settings
            with gr.Tab(label='Trimap Settings'):
                gr.Markdown("Trimap Settings")
                erode_kernel_size = gr.Slider(minimum=1, maximum=30, step=1, value=10, label="erode_kernel_size")
                dilate_kernel_size = gr.Slider(minimum=1, maximum=30, step=1, value=10, label="dilate_kernel_size")
            
            # Input Text Settings
            with gr.Tab(label='Input Text Settings'):
                gr.Markdown("Input Text Settings")
                fg_box_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, value=0.25, label="foreground_box_threshold")
                fg_text_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, value=0.25, label="foreground_text_threshold")

            # Transparency Settings
            with gr.Tab(label='Transparency Settings'):
                gr.Markdown("Transparency Settings")
                tr_caption = gr.Textbox(lines=1, placeholder="glass.lens.crystal.diamond.bubble.bulb.web.grid", label="transparency input text")
                tr_box_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.005, value=0.5, label="transparency_box_threshold")
                tr_text_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.005, value=0.25, label="transparency_text_threshold")

        
        with gr.Column():

            # show the image with mask
            with gr.Tab(label='SAM Mask'):
                mask = gr.Image(type='numpy')
            # with gr.Tab(label='Trimap'):
            #     trimap = gr.Image(type='numpy')
            with gr.Tab(label='Alpha Matte'):
                alpha = gr.Image(type='numpy')
            # show only mask
            with gr.Tab(label='Foreground by SAM Mask'):
                foreground_by_sam_mask = gr.Image(type='numpy')
            with gr.Tab(label='Refined by ViTMatte'):
                refined_by_vitmatte = gr.Image(type='numpy')
            # with gr.Tab(label='Transparency Detection'):
            #     transparency = gr.Image(type='numpy')
            with gr.Tab(label='New Background 1'):
                new_bg_1 = gr.Image(type='numpy')
            with gr.Tab(label='New Background 2'):
                new_bg_2 = gr.Image(type='numpy')
            with gr.Tab(label='New Background 3'):
                new_bg_3 = gr.Image(type='numpy')

    input_image.upload(
        store_img,
        [input_image],
        [original_image, selected_points]
    )
    input_image.select(
        get_point,
        [input_image, selected_points, radio],
        [input_image],
    )
    undo_button.click(
        undo_points,
        [original_image, selected_points],
        [input_image]
    )
    undo_all_button.click(
        undo_all_points,
        [original_image, selected_points],
        [input_image]
    )
    Tab1.select(
        clear_fg_caption,
        [fg_caption],
        [fg_caption]
    )
    Tab2.select(
        undo_all_points,
        [original_image, selected_points],
        [input_image]
    )
    
    
    button.click(run_inference, inputs=[original_image, selected_points, erode_kernel_size, dilate_kernel_size, fg_box_threshold, fg_text_threshold, fg_caption, tr_box_threshold, tr_text_threshold, \
                                        save_dir, tr_caption],  outputs=[mask, alpha, foreground_by_sam_mask, refined_by_vitmatte, new_bg_1, new_bg_2, new_bg_3])

    with gr.Row():
        with gr.Column():
            background_image = gr.State(value=None)

demo.launch()

Now I have updated gradio to the latest version (4.44.1), but another bug appears: the foreground point and background point cannot be selected.

DavidXie03 avatar Oct 29 '24 17:10 DavidXie03

升级gradio包到4.39.0;使用下面的代码替换matte_anything.py中对应的部分 with gr.Blocks() as demo: gr.Markdown( """ # Matte Anything🐒 ! """ ) with gr.Row(equal_height=True): #.style(equal_height=True): with gr.Column():

input image

original_image = gr.State(value="numpy") # store original image without points, default None input_image = gr.Image(type="numpy", label="Input Image")

prompt (point or text)

Point Input

with gr.Tab(label='Point Input') as Tab1: with gr.Column(): selected_points = gr.State([]) # store points radio = gr.Radio(['foreground_point', 'background_point'], label='Point Labels') with gr.Row(): undo_button = gr.Button('Remove Point') undo_all_button = gr.Button('Remove All Points')

Foreground Text Input

with gr.Tab(label='Foreground Text Input') as Tab2: with gr.Blocks(): gr.Markdown("Foreground Text Input") fg_caption = gr.Textbox(lines=1, placeholder="the girl in the middle", label="foreground input text")

            # Save Config
            with gr.Tab(label='Save Config') as Tab3:
                with gr.Blocks():
                    gr.Markdown("save name")
                    save_dir = gr.Textbox(lines=1, placeholder="the girl in the middle", label="Give a name of your demo. It will be saved in ```your_demos/your_name.pny```")
            
            
            # run button
            button = gr.Button("Start!")

            # Trimap Settings
            with gr.Tab(label='Trimap Settings'):
                gr.Markdown("Trimap Settings")
                erode_kernel_size = gr.Slider(minimum=1, maximum=30, step=1, value=10, label="erode_kernel_size")
                dilate_kernel_size = gr.Slider(minimum=1, maximum=30, step=1, value=10, label="dilate_kernel_size")
            
            # Input Text Settings
            with gr.Tab(label='Input Text Settings'):
                gr.Markdown("Input Text Settings")
                fg_box_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, value=0.25, label="foreground_box_threshold")
                fg_text_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, value=0.25, label="foreground_text_threshold")

            # Transparency Settings
            with gr.Tab(label='Transparency Settings'):
                gr.Markdown("Transparency Settings")
                tr_caption = gr.Textbox(lines=1, placeholder="glass.lens.crystal.diamond.bubble.bulb.web.grid", label="transparency input text")
                tr_box_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.005, value=0.5, label="transparency_box_threshold")
                tr_text_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.005, value=0.25, label="transparency_text_threshold")

        
        with gr.Column():

            # show the image with mask
            with gr.Tab(label='SAM Mask'):
                mask = gr.Image(type='numpy')
            # with gr.Tab(label='Trimap'):
            #     trimap = gr.Image(type='numpy')
            with gr.Tab(label='Alpha Matte'):
                alpha = gr.Image(type='numpy')
            # show only mask
            with gr.Tab(label='Foreground by SAM Mask'):
                foreground_by_sam_mask = gr.Image(type='numpy')
            with gr.Tab(label='Refined by ViTMatte'):
                refined_by_vitmatte = gr.Image(type='numpy')
            # with gr.Tab(label='Transparency Detection'):
            #     transparency = gr.Image(type='numpy')
            with gr.Tab(label='New Background 1'):
                new_bg_1 = gr.Image(type='numpy')
            with gr.Tab(label='New Background 2'):
                new_bg_2 = gr.Image(type='numpy')
            with gr.Tab(label='New Background 3'):
                new_bg_3 = gr.Image(type='numpy')

    input_image.upload(
        store_img,
        [input_image],
        [original_image, selected_points]
    )
    input_image.select(
        get_point,
        [input_image, selected_points, radio],
        [input_image],
    )
    undo_button.click(
        undo_points,
        [original_image, selected_points],
        [input_image]
    )
    undo_all_button.click(
        undo_all_points,
        [original_image, selected_points],
        [input_image]
    )
    Tab1.select(
        clear_fg_caption,
        [fg_caption],
        [fg_caption]
    )
    Tab2.select(
        undo_all_points,
        [original_image, selected_points],
        [input_image]
    )
    
    
    button.click(run_inference, inputs=[original_image, selected_points, erode_kernel_size, dilate_kernel_size, fg_box_threshold, fg_text_threshold, fg_caption, tr_box_threshold, tr_text_threshold, \
                                        save_dir, tr_caption],  outputs=[mask, alpha, foreground_by_sam_mask, refined_by_vitmatte, new_bg_1, new_bg_2, new_bg_3])

    with gr.Row():
        with gr.Column():
            background_image = gr.State(value=None)

demo.launch()

Now I have updated gradio to the latest version (4.44.1), but another bug appears: the foreground point and background point cannot be selected.

Previously I was using Python 3.9 (gradio 5 is not supported there). I then upgraded to Python 3.10 (gradio 5 is supported) and to gradio 5, and now the problem is solved.

DavidXie03 avatar Oct 30 '24 05:10 DavidXie03

with fastapi==0.111.0, gradio==4.39.0,

import os
import cv2
import torch
import numpy as np
import gradio as gr
from PIL import Image
from torchvision.ops import box_convert
from detectron2.config import LazyConfig, instantiate
from detectron2.checkpoint import DetectionCheckpointer
from segment_anything import sam_model_registry, SamPredictor
import groundingdino.datasets.transforms as T
from groundingdino.util.inference import load_model as dino_load_model, predict as dino_predict, annotate as dino_annotate

# Segment Anything checkpoint paths, keyed by SAM backbone name.
models = dict(
    vit_h='./pretrained/sam_vit_h_4b8939.pth',
    vit_b='./pretrained/sam_vit_b_01ec64.pth',
)

# ViTMatte checkpoint path, keyed by backbone name.
vitmatte_models = dict(vit_b='./pretrained/ViTMatte_B_DIS.pth')

# Config file loaded through LazyConfig for each ViTMatte backbone.
vitmatte_config = dict(vit_b='./configs/matte_anything.py')

# GroundingDINO model config and weights. NOTE: the __main__ section rebinds
# this name to the loaded model once the script starts.
grounding_dino = dict(
    config='./GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py',
    weight='./pretrained/groundingdino_swint_ogc.pth',
)

def generate_checkerboard_image(height, width, num_squares):
    """Build a light checkerboard image used as a neutral backdrop.

    The board is made of square tiles: ``num_squares`` tiles vertically and
    as many whole tiles as fit horizontally. Tiles alternate between the
    gray levels 255 and 200, starting with 255 in the top-left corner. The
    tiled image is then resized to exactly ``(height, width)`` and expanded
    to three channels.

    Args:
        height: Output height in pixels.
        width: Output width in pixels.
        num_squares: Number of tiles along the vertical axis.

    Returns:
        The checkerboard after ``cv2.resize`` to ``(width, height)`` and
        ``cv2.COLOR_GRAY2RGB`` conversion (uint8 input to both calls).
    """
    # Clamp to at least one pixel / one column: for images smaller than
    # ``num_squares`` pixels the integer division yields 0, which previously
    # caused a ZeroDivisionError (or a zero-width board that cannot be resized).
    square_size = max(1, height // num_squares)
    num_squares_w = max(1, width // square_size)

    new_height = num_squares * square_size
    new_width = num_squares_w * square_size

    # Tile index of every pixel row / column; tile (i, j) is light when
    # i + j is even, exactly as in a per-tile loop.
    tile_rows = np.arange(new_height) // square_size
    tile_cols = np.arange(new_width) // square_size
    is_light = (tile_rows[:, None] + tile_cols[None, :]) % 2 == 0
    image = np.where(is_light, 255, 200).astype(np.uint8)

    # The tiled board may be slightly smaller than requested; stretch it.
    image = cv2.resize(image, (width, height))
    image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)

    return image


def init_segment_anything(model_type):
    """
    Create a ``SamPredictor`` for the SAM backbone named ``model_type``.

    ``model_type`` is looked up in ``sam_model_registry`` and in the
    module-level ``models`` checkpoint table, so it must be a key of both.
    The model is moved to the global ``device`` before being wrapped.
    """
    build_sam = sam_model_registry[model_type]
    checkpoint_path = models[model_type]
    sam = build_sam(checkpoint=checkpoint_path)
    return SamPredictor(sam.to(device))

def init_vitmatte(model_type):
    """
    Instantiate the ViTMatte model registered under ``model_type``.

    ``model_type`` must be a key of the module-level ``vitmatte_config`` and
    ``vitmatte_models`` tables. The model is built from its lazy config,
    moved to the global ``device``, put in eval mode, and only then has its
    checkpoint loaded.
    """
    config = LazyConfig.load(vitmatte_config[model_type])
    matting_model = instantiate(config.model)
    matting_model.to(device)
    matting_model.eval()

    checkpointer = DetectionCheckpointer(matting_model)
    checkpointer.load(vitmatte_models[model_type])

    return matting_model

def generate_trimap(mask, erode_kernel_size=10, dilate_kernel_size=10):
    """Turn a 0/255 mask into a trimap with values 0, 128 and 255.

    Pixels that are still 255 after five erosion passes become 255, pixels
    that are 255 only after five dilation passes become 128, and everything
    else stays 0. The result has the same shape and dtype as ``mask``.
    """
    shrink_kernel = np.ones((erode_kernel_size,) * 2, np.uint8)
    grow_kernel = np.ones((dilate_kernel_size,) * 2, np.uint8)

    sure_foreground = cv2.erode(mask, shrink_kernel, iterations=5) == 255
    possible_foreground = cv2.dilate(mask, grow_kernel, iterations=5) == 255

    result = np.zeros_like(mask)
    # Order matters: the eroded region is written last so it overrides 128.
    result[possible_foreground] = 128
    result[sure_foreground] = 255
    return result

# user click the image to get points, and show the points on the image
def get_point(img, sel_pix, point_type, evt: gr.SelectData):
    """Record a clicked point and redraw every recorded point on ``img``.

    The click position (``evt.index``) is appended to ``sel_pix`` in place
    together with its label: 0 for ``'background_point'``, 1 for anything
    else (including no selection). Markers are drawn directly onto ``img``
    using the module-level ``colors`` and ``markers`` tables.

    Returns the annotated image as a NumPy array.
    """
    # Only an explicit background choice yields label 0.
    label = 0 if point_type == 'background_point' else 1
    sel_pix.append((evt.index, label))

    for position, kind in sel_pix:
        cv2.drawMarker(img, position, colors[kind], markerType=markers[kind], markerSize=20, thickness=5)

    # NOTE(review): channels 0 and 2 are swapped whenever they are equal at
    # the top-left pixel; this heuristic depends on image content only.
    first_channel_px = img[..., 0][0, 0]
    third_channel_px = img[..., 2][0, 0]
    if first_channel_px == third_channel_px:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    if isinstance(img, np.ndarray):
        return img
    return np.array(img)

# undo the selected point
def undo_points(orig_img, sel_pix):
    """Remove the most recent point and redraw the remaining ones.

    Args:
        orig_img: The stored, un-annotated image (NumPy array). Anything
            else means no image has been uploaded yet.
        sel_pix: List of ``(point, label)`` tuples; modified in place.

    Returns:
        A copy of ``orig_img`` with the remaining markers drawn on it.

    Raises:
        gr.Error: If no image is available, matching ``undo_all_points``.
    """
    # The state holding the original image is not an array before the first
    # upload; fail with a readable message instead of an AttributeError.
    if not isinstance(orig_img, np.ndarray):
        raise gr.Error("Please upload pictures first!")

    temp = orig_img.copy()
    # draw points
    if sel_pix:
        sel_pix.pop()
        for point, label in sel_pix:
            cv2.drawMarker(temp, point, colors[label], markerType=markers[label], markerSize=20, thickness=5)
    # NOTE(review): channels 0 and 2 are swapped whenever they are equal at
    # the top-left pixel; kept as-is to stay consistent with get_point.
    if temp[..., 0][0, 0] == temp[..., 2][0, 0]:  # BGR to RGB
        temp = cv2.cvtColor(temp, cv2.COLOR_BGR2RGB)
    return temp

# undo all selected points
def undo_all_points(orig_img, sel_pix):
    """Drop every selected point and return the clean image.

    Args:
        orig_img: The stored, un-annotated image (NumPy array). Anything
            else means no image has been uploaded yet.
        sel_pix: List of ``(point, label)`` tuples; emptied in place.

    Returns:
        A copy of ``orig_img`` without any markers.

    Raises:
        gr.Error: If no image is available.
    """
    # Checking only for None is not enough: the UI state that feeds this
    # function can hold a non-array placeholder before the first upload.
    if not isinstance(orig_img, np.ndarray):
        raise gr.Error("Please upload pictures first!")

    temp = orig_img.copy()
    sel_pix.clear()  # in place, so the shared state list is emptied too
    # NOTE(review): channels 0 and 2 are swapped whenever they are equal at
    # the top-left pixel; kept as-is to stay consistent with get_point.
    if temp[..., 0][0, 0] == temp[..., 2][0, 0]:  # BGR to RGB
        temp = cv2.cvtColor(temp, cv2.COLOR_BGR2RGB)
    return temp

# clear the fg_caption
def clear_fg_caption(fg_caption):
    """Return an empty caption, ignoring the current ``fg_caption`` value."""
    return ""

# once user upload an image, the original image is stored in `original_image`
def store_img(img):
    """Return ``img`` unchanged together with a fresh, empty point list.

    A new upload invalidates previously selected points, so the second
    element is always a newly created empty list.
    """
    fresh_points = []
    return img, fresh_points

def convert_pixels(gray_image, boxes):
    """Return a copy of ``gray_image`` with 1-valued pixels inside ``boxes`` set to 0.5.

    Each box is ``(x1, y1, x2, y2)``; coordinates are truncated with
    ``int()`` and used as a half-open slice ``[y1:y2, x1:x2]``. The input
    array is left untouched.
    """
    result = np.array(gray_image, copy=True)

    for x1, y1, x2, y2 in boxes:
        left, top, right, bottom = int(x1), int(y1), int(x2), int(y2)
        # Basic slicing returns a view, so the masked write lands in `result`.
        region = result[top:bottom, left:right]
        region[region == 1] = 0.5

    return result

if __name__ == "__main__":
    # Run on the GPU when torch can see one, otherwise fall back to the CPU.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Backbone names used to pick checkpoints from `models` / `vitmatte_models`.
    sam_model, vitmatte_model = 'vit_h', 'vit_b'

    # Both tables are indexed by point label (0 = background, 1 = foreground,
    # as assigned in get_point). The marker values are passed to cv2.drawMarker
    # as markerType; presumably OpenCV MarkerTypes constants -- TODO confirm.
    colors = [(255, 0, 0), (0, 255, 0)]
    markers = [1, 5]

    print('Initializing models... Please wait...')

    predictor = init_segment_anything(sam_model)
    vitmatte = init_vitmatte(vitmatte_model)
    # From here on `grounding_dino` is the loaded model, no longer the
    # config/weight path table defined at module level.
    grounding_dino = dino_load_model(
        grounding_dino['config'],
        grounding_dino['weight'],
    )

    def run_inference(input_x, selected_points, erode_kernel_size, dilate_kernel_size, fg_box_threshold, fg_text_threshold, fg_caption, 
                      tr_box_threshold, tr_text_threshold, save_name, tr_caption = "glass, lens, crystal, diamond, bubble, bulb, web, grid"):
        """Segment the image with SAM, refine it with ViTMatte and composite it.

        Args:
            input_x: Uploaded image as a NumPy array. The code unpacks
                ``h, w, _ = input_x.shape`` and permutes ``(2, 0, 1)``, so a
                3-dimensional H x W x C array is required.
            selected_points: List of ``(point, label)`` tuples used as SAM
                point prompts; may be empty.
            erode_kernel_size, dilate_kernel_size: Passed to ``generate_trimap``.
            fg_box_threshold, fg_text_threshold: GroundingDINO thresholds for
                the foreground caption.
            fg_caption: Optional text prompt; when non-empty, detected boxes
                are used as SAM box prompts.
            tr_box_threshold, tr_text_threshold: GroundingDINO thresholds for
                the transparency caption.
            save_name: File stem; the RGBA result is written to
                ``your_demos/{save_name}.png``.
            tr_caption: Caption describing transparent objects. The default
                only applies when the caller omits the argument.

        Returns:
            ``(mask, alpha, foreground_mask, foreground_alpha, new_bg_1,
            new_bg_2, new_bg_3)``: the SAM mask (uint8, 0/255), the ViTMatte
            alpha, the input blended over a checkerboard using the mask and
            the alpha respectively, and the input blended over three stock
            backgrounds using the alpha.

        Raises:
            gr.Error: If no image has been uploaded or a stock background
                image cannot be read.
        """
        # Fail early with a readable message instead of an opaque error from
        # the predictor when the image state has not been filled yet.
        if not isinstance(input_x, np.ndarray):
            raise gr.Error("Please upload pictures first!")

        predictor.set_image(input_x)

        dino_transform = T.Compose(
        [
            T.RandomResize([800], max_size=1333),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])
        image_transformed, _ = dino_transform(Image.fromarray(input_x), None)

        # --- SAM point prompts -------------------------------------------
        if selected_points:
            points = torch.Tensor([p for p, _ in selected_points]).to(device).unsqueeze(1)
            labels = torch.Tensor([int(l) for _, l in selected_points]).to(device).unsqueeze(1)
            transformed_points = predictor.transform.apply_coords_torch(points, input_x.shape[:2])
            print(points.size(), transformed_points.size(), labels.size(), input_x.shape, points)
            point_coords = transformed_points.permute(1, 0, 2)
            point_labels = labels.permute(1, 0)
        else:
            point_coords, point_labels = None, None

        # --- SAM box prompts from the foreground caption -------------------
        transformed_boxes = None
        if fg_caption:  # This section has benefited from the contributions of neuromorph,thanks!
            fg_boxes, fg_logits, fg_phrases = dino_predict(
                model=grounding_dino,
                image=image_transformed,
                caption=fg_caption,
                box_threshold=fg_box_threshold,
                text_threshold=fg_text_threshold,
                device=device)
            print(fg_logits, fg_phrases)
            if fg_boxes.shape[0] != 0:
                h, w, _ = input_x.shape
                fg_boxes = torch.Tensor(fg_boxes).to(device)
                # Boxes are scaled by the image size before conversion.
                fg_boxes = fg_boxes * torch.Tensor([w, h, w, h]).to(device)
                fg_boxes = box_convert(boxes=fg_boxes, in_fmt="cxcywh", out_fmt="xyxy")
                transformed_boxes = predictor.transform.apply_boxes_torch(fg_boxes, input_x.shape[:2])

        # predict segmentation according to the prompts
        masks, _, _ = predictor.predict_torch(
            point_coords = point_coords,
            point_labels = point_labels,
            boxes = transformed_boxes,
            multimask_output = False,
        )
        masks = masks.cpu().detach().numpy()

        # --- Trimap ---------------------------------------------------------
        torch.cuda.empty_cache()
        mask = masks[0][0].astype(np.uint8)*255
        trimap = generate_trimap(mask, erode_kernel_size, dilate_kernel_size).astype(np.float32)
        trimap[trimap==128] = 0.5
        trimap[trimap==255] = 1

        # Regions detected for the transparency caption are demoted from
        # "foreground" (1) to "unknown" (0.5) in the trimap.
        boxes, _, _ = dino_predict(
            model=grounding_dino,
            image=image_transformed,
            caption= tr_caption,
            box_threshold=tr_box_threshold,
            text_threshold=tr_text_threshold,
            device=device)
        if boxes.shape[0] != 0:
            h, w, _ = input_x.shape
            boxes = boxes * torch.Tensor([w, h, w, h])
            xyxy = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
            trimap = convert_pixels(trimap, xyxy)

        # --- Alpha matte ------------------------------------------------------
        vitmatte_input = {
            "image": torch.from_numpy(input_x).permute(2, 0, 1).unsqueeze(0)/255,
            "trimap": torch.from_numpy(trimap).unsqueeze(0).unsqueeze(0),
        }

        torch.cuda.empty_cache()
        # Inference only: no autograd graph is needed for the forward pass.
        with torch.no_grad():
            alpha = vitmatte(vitmatte_input)['phas'].flatten(0,2)
        alpha = alpha.detach().cpu().numpy()

        # Per-pixel blend weights replicated over the three colour channels.
        alpha_3c = np.expand_dims(alpha, axis=2).repeat(3,2)
        mask_3c = np.expand_dims(mask/255, axis=2).repeat(3,2)

        def blend(weight, backdrop):
            # Composite input_x over `backdrop`, scaled into the 0..1 range.
            return input_x * weight/255 + backdrop * (1 - weight)/255

        # checkerboard backdrop for the preview tabs
        background = generate_checkerboard_image(input_x.shape[0], input_x.shape[1], 8)

        # calculate foreground with alpha blending
        foreground_alpha = blend(alpha_3c, background)

        # calculate foreground with mask
        foreground_mask = blend(mask_3c, background)

        # --- Save RGBA result -------------------------------------------------
        # NOTE(review): save_name comes straight from the UI textbox and is
        # interpolated into the path unsanitised -- confirm this is acceptable.
        os.makedirs('your_demos', exist_ok=True)  # target directory may not exist yet
        cv2_alpha = (np.expand_dims(alpha, axis=2)*255).astype(np.uint8)
        cv2_input_x = cv2.cvtColor(input_x, cv2.COLOR_BGR2RGB)
        rgba = np.concatenate((cv2_input_x, cv2_alpha), axis=2)
        cv2.imwrite(f'your_demos/{save_name}.png', rgba)

        foreground_alpha[foreground_alpha>1] = 1
        foreground_mask[foreground_mask>1] = 1

        # --- New backgrounds --------------------------------------------------
        composites = []
        for bg_path in ('figs/sea.jpg', 'figs/forest.jpg', 'figs/sunny.jpg'):
            backdrop = cv2.imread(bg_path)
            if backdrop is None:
                # cv2.imread signals a missing/unreadable file by returning None.
                raise gr.Error(f"Background image not found or unreadable: {bg_path}")
            backdrop = cv2.resize(backdrop, (input_x.shape[1], input_x.shape[0]))
            # to RGB
            backdrop = cv2.cvtColor(backdrop, cv2.COLOR_BGR2RGB)
            # use alpha blending
            composites.append(blend(alpha_3c, backdrop))
        new_bg_1, new_bg_2, new_bg_3 = composites

        return  mask, alpha, foreground_mask, foreground_alpha, new_bg_1, new_bg_2, new_bg_3

    # Build the Gradio UI: a left column with the input image, prompts and
    # settings, and a right column with one tab per output image.
    with gr.Blocks() as demo:
        gr.Markdown(
            """
            # <center>Matte Anything🐒 !
            """
        )
        with gr.Row(equal_height=True):
            with gr.Column():
                # input image
                # Holds the uploaded image without any markers. It starts as
                # None so handlers can tell that nothing has been uploaded.
                original_image = gr.State(value=None)
                input_image = gr.Image(type="numpy", label="Input Image")
                # prompt (point or text)
                # Point Input
                with gr.Tab(label='Point Input') as Tab1:
                    with gr.Column():
                        selected_points = gr.State([])      # store points
                        radio = gr.Radio(['foreground_point', 'background_point'], label='Point Labels')
                        with gr.Row():
                            undo_button = gr.Button('Remove Point')
                            undo_all_button = gr.Button('Remove All  Points')
                # Foreground Text Input
                with gr.Tab(label='Foreground Text Input') as Tab2:
                    with gr.Blocks():
                        gr.Markdown("Foreground Text Input")
                        fg_caption = gr.Textbox(lines=1, placeholder="the girl in the middle", label="foreground input text")

                # Save Config
                with gr.Tab(label='Save Config') as Tab3:
                    with gr.Blocks():
                        gr.Markdown("save name")
                        save_dir = gr.Textbox(lines=1, placeholder="the girl in the middle", label="Give a name of your demo. It will be saved in ```your_demos/your_name.png```")

                # run button
                button = gr.Button("Start!")

                # Trimap Settings
                with gr.Tab(label='Trimap Settings'):
                    gr.Markdown("Trimap Settings")
                    erode_kernel_size = gr.Slider(minimum=1, maximum=30, step=1, value=10, label="erode_kernel_size")
                    dilate_kernel_size = gr.Slider(minimum=1, maximum=30, step=1, value=10, label="dilate_kernel_size")

                # Input Text Settings
                with gr.Tab(label='Input Text Settings'):
                    gr.Markdown("Input Text Settings")
                    fg_box_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, value=0.25, label="foreground_box_threshold")
                    fg_text_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, value=0.25, label="foreground_text_threshold")

                # Transparency Settings
                with gr.Tab(label='Transparency Settings'):
                    gr.Markdown("Transparency Settings")
                    tr_caption = gr.Textbox(lines=1, placeholder="glass.lens.crystal.diamond.bubble.bulb.web.grid", label="transparency input text")
                    tr_box_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.005, value=0.5, label="transparency_box_threshold")
                    tr_text_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.005, value=0.25, label="transparency_text_threshold")

            # Output column, placed inside the same Row as the input column.
            with gr.Column():

                # show the image with mask
                with gr.Tab(label='SAM Mask'):
                    mask = gr.Image(type='numpy')
                with gr.Tab(label='Alpha Matte'):
                    alpha = gr.Image(type='numpy')
                # show only mask
                with gr.Tab(label='Foreground by SAM Mask'):
                    foreground_by_sam_mask = gr.Image(type='numpy')
                with gr.Tab(label='Refined by ViTMatte'):
                    refined_by_vitmatte = gr.Image(type='numpy')
                with gr.Tab(label='New Background 1'):
                    new_bg_1 = gr.Image(type='numpy')
                with gr.Tab(label='New Background 2'):
                    new_bg_2 = gr.Image(type='numpy')
                with gr.Tab(label='New Background 3'):
                    new_bg_3 = gr.Image(type='numpy')

        # --- Event wiring ---------------------------------------------------
        # A new upload stores the clean image and resets the point list.
        input_image.upload(
            store_img,
            [input_image],
            [original_image, selected_points]
        )
        # Clicking on the image records a point and redraws the markers.
        input_image.select(
            get_point,
            [input_image, selected_points, radio],
            [input_image],
        )
        undo_button.click(
            undo_points,
            [original_image, selected_points],
            [input_image]
        )
        undo_all_button.click(
            undo_all_points,
            [original_image, selected_points],
            [input_image]
        )
        # Switching prompt mode discards the other mode's prompt.
        Tab1.select(
            clear_fg_caption,
            [fg_caption],
            [fg_caption]
        )
        Tab2.select(
            undo_all_points,
            [original_image, selected_points],
            [input_image]
        )

        # Input order must match the positional parameters of run_inference.
        button.click(
            run_inference,
            inputs=[
                original_image, selected_points, erode_kernel_size, dilate_kernel_size,
                fg_box_threshold, fg_text_threshold, fg_caption,
                tr_box_threshold, tr_text_threshold, save_dir, tr_caption,
            ],
            outputs=[mask, alpha, foreground_by_sam_mask, refined_by_vitmatte, new_bg_1, new_bg_2, new_bg_3],
        )

    # NOTE(review): this state is created outside the Blocks context and is
    # not referenced by any handler in this file -- confirm it can be dropped.
    with gr.Row():
        with gr.Column():
            background_image = gr.State(value=None)

    demo.launch()

This full code will work.

mokurin000 avatar Feb 11 '25 06:02 mokurin000

@mokurin000

This works for me. Thank you for sharing your work. 🌹😙👍

For people getting "TypeError: argument of type 'bool' is not iterable" while running this code, the problem is with your pydantic version being too recent. You'll need to downgrade it by pip install pydantic==2.10.6. For reference, check out these pages: https://github.com/Stability-AI/stable-virtual-camera/issues/39 https://github.com/gradio-app/gradio/issues/10662#issuecomment-2677236567

感谢楼上大神分享代码🌹😙👍

如果有人在跑这段代码时遇到 "TypeError: argument of type 'bool' is not iterable" 报错,这是由于pydantic版本过新,可以用 pip install pydantic==2.10.6来降级它。参考解决方法来源: https://github.com/Stability-AI/stable-virtual-camera/issues/39 https://github.com/gradio-app/gradio/issues/10662#issuecomment-2677236567

superVariant avatar Apr 12 '25 11:04 superVariant