
How to run it live?

bycloudai opened this issue 4 years ago · 8 comments

Hello, thank you for the amazing work! I am just wondering how we can run this live, just like the online webcam demo. I would love to test 4K/HD live inputs on different GPUs. Thank you

bycloudai avatar Sep 25 '21 16:09 bycloudai

I have been having the same question, and I am interested in developing a solution for Linux (just like the BGMv2 Linux demo using v4l2loopback).
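
In case it helps, here is a rough, untested sketch of that idea, assuming pyvirtualcam (which drives v4l2loopback on Linux), the rvm_mobilenetv3 checkpoint, and that it is run from the repo root:

import cv2
import torch
import pyvirtualcam

from model import MattingNetwork  # from the RobustVideoMatting repo

model = MattingNetwork('mobilenetv3').eval().cuda()
model.load_state_dict(torch.load('rvm_mobilenetv3.pth'))

width, height, fps = 1280, 720, 30
cap = cv2.VideoCapture(0)
cap.set(cv2.CAP_PROP_FRAME_WIDTH, width)
cap.set(cv2.CAP_PROP_FRAME_HEIGHT, height)

with pyvirtualcam.Camera(width=width, height=height, fps=fps) as vcam, torch.no_grad():
    rec = [None] * 4  # recurrent states
    bgr = torch.tensor([0.47, 1.0, 0.6]).view(3, 1, 1).cuda()  # green background
    while True:
        ok, frame = cap.read()
        if not ok:
            break
        src = torch.from_numpy(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)).float().div(255).permute(2, 0, 1).unsqueeze(0).cuda()
        fgr, pha, *rec = model(src, *rec, downsample_ratio=0.375)
        out = (fgr * pha + bgr * (1 - pha))[0].mul(255).byte()
        vcam.send(out.permute(1, 2, 0).cpu().numpy())  # pyvirtualcam expects RGB uint8 frames
        vcam.sleep_until_next_frame()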

charitarthchugh avatar Sep 25 '21 16:09 charitarthchugh

I tried adapting the existing webcam test from BackgroundMattingV2 to work with this one. It could probably be done better, but I'm not that good of a programmer.

import time
import cv2
import torch

from torchvision.transforms import ToTensor
from threading import Thread, Lock
from PIL import Image

# ----------- Utility classes -------------


# A wrapper that reads frames from cv2.VideoCapture in its own thread so the main loop
# never blocks on the camera. Call .read() in a tight loop to get the newest frame.
class Camera:
    def __init__(self, device_id=0, width=1280, height=720):
        self.capture = cv2.VideoCapture(device_id)
        self.capture.set(cv2.CAP_PROP_FRAME_WIDTH, width)
        self.capture.set(cv2.CAP_PROP_FRAME_HEIGHT, height)
        self.width = int(self.capture.get(cv2.CAP_PROP_FRAME_WIDTH))
        self.height = int(self.capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # self.capture.set(cv2.CAP_PROP_BUFFERSIZE, 2)
        self.success_reading, self.frame = self.capture.read()
        self.read_lock = Lock()
        self.thread = Thread(target=self.__update, args=())
        self.thread.daemon = True
        self.thread.start()

    def __update(self):
        while self.success_reading:
            grabbed, frame = self.capture.read()
            with self.read_lock:
                self.success_reading = grabbed
                self.frame = frame

    def read(self):
        with self.read_lock:
            frame = self.frame.copy()
        return frame
    def __exit__(self, exec_type, exc_value, traceback):
        self.capture.release()

# An FPS tracker that computes an exponentially moving average FPS
class FPSTracker:
    def __init__(self, ratio=0.5):
        self._last_tick = None
        self._avg_fps = None
        self.ratio = ratio
    def tick(self):
        if self._last_tick is None:
            self._last_tick = time.time()
            return None
        t_new = time.time()
        fps_sample = 1.0 / (t_new - self._last_tick)
        self._avg_fps = self.ratio * fps_sample + (1 - self.ratio) * self._avg_fps if self._avg_fps is not None else fps_sample
        self._last_tick = t_new
        return self.get()
    def get(self):
        return self._avg_fps

# Wrapper for playing a stream with cv2.imshow(). It can accept an image and return keypress info for basic interactivity.
# It also tracks FPS and optionally overlays info onto the stream.
class Displayer:
    def __init__(self, title, width=None, height=None, show_info=True):
        self.title, self.width, self.height = title, width, height
        self.show_info = show_info
        self.fps_tracker = FPSTracker()
        cv2.namedWindow(self.title, cv2.WINDOW_NORMAL)
        if width is not None and height is not None:
            cv2.resizeWindow(self.title, width, height)
    # Update the currently showing frame and return key press char code
    def step(self, image):
        fps_estimate = self.fps_tracker.tick()
        if self.show_info and fps_estimate is not None:
            message = f"{int(fps_estimate)} fps | {self.width}x{self.height}"
            cv2.putText(image, message, (10, 40), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 0, 0))
        cv2.imshow(self.title, image)
        return cv2.waitKey(1) & 0xFF


# --------------- Main ---------------


from model import MattingNetwork  # requires running from the RobustVideoMatting repo root

model = MattingNetwork(variant='mobilenetv3').eval().cuda()  # change variant to match your checkpoint
model.load_state_dict(torch.load("checkpoints/rvm_mobilenetv3.pth"))  # change to your checkpoint path


width, height = (1280, 720)
cam = Camera(width=width, height=height)
dsp = Displayer('MattingV2', cam.width, cam.height, show_info=True)

def cv2_frame_to_cuda(frame):
    # Convert a BGR OpenCV frame into a normalized RGB CUDA tensor of shape [1, 3, H, W].
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    return ToTensor()(Image.fromarray(frame)).unsqueeze_(0).cuda()

with torch.no_grad():
    while True:
        rec = [None] * 4  # recurrent states, cleared whenever 'b' is pressed
        while True:  # matting loop
            frame = cam.read()
            src = cv2_frame_to_cuda(frame)
            fgr, pha, *rec = model(src, *rec, downsample_ratio=0.375)  # set downsample ratio for your resolution
            bgrgreen = torch.tensor([.47, 1, .6]).view(3, 1, 1).cuda()  # green-screen background colour
            res = fgr * pha + bgrgreen * (1 - pha)  # composite foreground over the background colour
            # res = fgr * pha + (1 - pha) * torch.ones_like(fgr)  # alternative: composite over white
            res = res.mul(255).byte().cpu().permute(0, 2, 3, 1).numpy()[0]
            res = cv2.cvtColor(res, cv2.COLOR_RGB2BGR)
            key = dsp.step(res)
            if key == ord('b'):    # 'b' resets the recurrent states
                break
            elif key == ord('q'):  # 'q' quits
                exit()

DrPleaseRespect avatar Sep 26 '21 06:09 DrPleaseRespect

@DrPleaseRespect If you are running inference at 1280x720, the downsample_ratio should be set much higher; 0.125 is too low for that resolution.
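
For reference, a quick way to pick a value is to aim for roughly 512 px on the longer side after downsampling (the repo's inference.py has an auto_downsample_ratio helper along these lines, if I remember correctly); a minimal sketch:

# Pick downsample_ratio so the longer side lands around 512 px after downsampling.
def pick_downsample_ratio(height, width, target=512):
    return min(target / max(height, width), 1.0)

print(pick_downsample_ratio(720, 1280))   # ~0.4 for 720p
print(pick_downsample_ratio(2160, 3840))  # ~0.13 for 4K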

PeterL1n avatar Sep 26 '21 07:09 PeterL1n

Another thing to note: I believe there is a way to display the tensor without moving it back to the CPU, through an OpenGL-enabled build of OpenCV. The BGMv2 script doesn't do that, and the extra data transfer costs additional latency.
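
Short of an OpenGL-enabled OpenCV build, a smaller (and different) trick is to reuse a pinned host buffer for the device-to-host copy, which makes the transfer cheaper and lets it run asynchronously; a rough sketch, assuming a fixed 1280x720 output:

import torch

# Preallocate a pinned host buffer for the composited uint8 frame (H x W x 3).
# Pinned memory speeds up device-to-host copies and allows them to be asynchronous.
h, w = 720, 1280
host_buf = torch.empty((h, w, 3), dtype=torch.uint8, pin_memory=True)

def download_frame(res):
    # res: float tensor [1, 3, H, W] on the GPU with values in [0, 1]
    frame_gpu = res.mul(255).byte()[0].permute(1, 2, 0)  # [H, W, 3] uint8, still on the GPU
    host_buf.copy_(frame_gpu, non_blocking=True)          # async copy into pinned memory
    torch.cuda.synchronize()                              # make sure the copy has finished
    return host_buf.numpy()                               # hand an RGB array to OpenCV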

PeterL1n avatar Sep 26 '21 07:09 PeterL1n

I added a real-time video matting demo, referencing the one from BackgroundMattingV2; you can also customize your background image in real time.
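
If you want to try that with the script above, swapping the flat green colour for an image is a small change; a sketch, assuming a hypothetical background.jpg resized to the camera resolution:

import cv2
import torch

# Load a custom background once, resize it to the camera frame size, and keep it on the GPU.
bg = cv2.imread("background.jpg")  # hypothetical path
bg = cv2.resize(bg, (cam.width, cam.height))
bg = cv2.cvtColor(bg, cv2.COLOR_BGR2RGB)
bgr_img = torch.from_numpy(bg).float().div(255).permute(2, 0, 1).unsqueeze(0).cuda()

# Then, inside the matting loop, composite over the image instead of the flat colour:
# res = fgr * pha + bgr_img * (1 - pha)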

jackhanyuan avatar Sep 28 '21 10:09 jackhanyuan

Thank you. This is the problem I ran into: Import "model" could not be resolved

hileroy132 avatar Sep 28 '21 12:09 hileroy132

@hileroy132 You're supposed to clone this repo, copy the script above into a new file inside the repo, and run it from there.

PeterL1n avatar Sep 28 '21 19:09 PeterL1n

Was anyone able to set up a live demo just like the webcam demo? It's not included in the repo, is it?

Agusteando avatar Jun 01 '22 01:06 Agusteando