cog icon indicating copy to clipboard operation
cog copied to clipboard

Output Unknown error handling prediction.

Open swapnil-lader opened this issue 1 year ago • 0 comments

Hi, I am trying to return an image from a model prediction, but it throws an unknown error. Locally I am able to get the output from the model, but on Replicate it throws the above error.

## Base Code

from cog import BasePredictor, Input, Path,File import os import cv2 from PIL import Image import pandas as pd import os import torch from src.dataset import DatasetInference from src.model import ModelAlpha, Model from models.alpha_model_config import cfg as alpha_model_cfg from models.trimap_model_config import cfg as trimap_model_cfg from tqdm import tqdm import numpy as np import pandas as pd import cv2 import glob import torch from torch.utils.data import DataLoader import albumentations as alb from os import path import tempfile os.environ['TRANSFORMERS_CACHE'] = './cache/' os.environ['TORCH_HOME'] = './cache/'

class Predictor(BasePredictor):

def setup(self):
    self.device = torch.device("cpu")
    self.alpha_model_general , self.trimap_model_general = self.get_models("models/alpha_model_general.ckpt" , "models/trimap_model_general.ckpt" , self.device)

def remove_data(self):
    """Delete every entry in the input-image and mask working directories.

    Both directories are listed up front and then emptied independently of
    each other, so a directory holding more files than the other one is
    still cleaned completely (pairing the listings with ``zip`` would stop
    at the shorter one and leave stale files behind).

    Raises:
        OSError: if a directory cannot be listed or an entry cannot be
            removed (for example, because it is a sub-directory).
    """
    listings = [
        (directory, os.listdir(directory))
        for directory in ("data/output/val/images/", "output_mask/")
    ]

    for directory, entries in listings:
        for entry in entries:
            os.remove(os.path.join(directory, entry))

    print("data cleaning complete!")
    
def create_csv(self, image_name, tracking_id):
    """Write a one-row CSV describing the image and return the CSV path.

    The file is saved as ``data/test_dataset_<tracking_id>.csv`` with the
    columns ``image_path`` and ``id`` (plus the default pandas index column).
    """
    csv_file = "".join(("data/test_dataset_", tracking_id, ".csv"))
    single_row = pd.DataFrame({"image_path": [image_name], "id": [tracking_id]})
    single_row.to_csv(csv_file)
    return csv_file

def get_models(self, alpha_path, trimap_path, device = torch.device("cpu")):
    """Load the trimap and alpha models from checkpoints and move them to ``device``.

    Each model is constructed from its config, restored through
    ``load_from_checkpoint``, frozen, and transferred to ``device``.
    The trimap model is loaded before the alpha model.

    Returns:
        tuple: ``(alpha_model, trimap_model)`` -- note the alpha model comes first.
    """
    def _restore(model_cls, cfg, checkpoint):
        # Build from config, then swap in the checkpoint-restored instance.
        net = model_cls(cfg=cfg)
        net = net.load_from_checkpoint(checkpoint, cfg=cfg)
        net.freeze()
        return net.to(device)

    trimap_model = _restore(Model, trimap_model_cfg, trimap_path)
    alpha_model = _restore(ModelAlpha, alpha_model_cfg, alpha_path)
    return alpha_model, trimap_model

def blend(self, image_path, alpha_path):
    """Composite an image over a constant background using an alpha image.

    Both files are read with ``cv2.imread``. The alpha image is scaled from
    0..255 to 0..1 and the result is
    ``alpha * foreground + (1 - alpha) * background``, where the background
    is an array of ones with the foreground's shape.

    Args:
        image_path: path of the foreground image.
        alpha_path: path of the alpha image; it must have the same size as
            the foreground for the element-wise OpenCV operations.

    Returns:
        The blended image as a float array.

    Raises:
        OSError: if either file cannot be read by ``cv2.imread``.
    """
    foreground = cv2.imread(image_path)
    if foreground is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise OSError(f"could not read foreground image: {image_path!r}")

    alpha = cv2.imread(alpha_path)
    if alpha is None:
        raise OSError(f"could not read alpha image: {alpha_path!r}")

    foreground = foreground.astype(float)
    # NOTE(review): a background of ones is almost black on a 0..255 scale;
    # if a white background was intended this should be 255 -- confirm.
    background = np.ones(foreground.shape, dtype=float)

    alpha = alpha.astype(float) / 255

    foreground = cv2.multiply(alpha, foreground)
    background = cv2.multiply(1.0 - alpha, background)

    return cv2.add(foreground, background)

def prepare_dataset(self, **kwargs):
    """Build the inference data loader and make sure the result folders exist.

    Recognised keyword arguments:
        image_csv_path: CSV file that is read into the source DataFrame.
        mask_path: folder for the alpha masks (created when missing).
        output_path: folder for outputs (created when missing).
        batch_size: loader batch size, 1 when omitted.
        workers: number of loader worker processes, 2 when omitted.

    Returns:
        tuple: ``(test_dataloader, dst_dir, output_dir)``.
    """
    frame = pd.read_csv(kwargs['image_csv_path'])

    # Resolve and create the mask folder first, then the output folder.
    folders = []
    for key in ('mask_path', "output_path"):
        folder = kwargs[key]
        if not os.path.exists(folder):
            os.makedirs(folder, exist_ok=True)
        folders.append(folder)
    dst_dir, output_dir = folders

    # Resize to the fixed 512x512 input and normalise with the trimap config statistics.
    preprocessing = alb.Compose([
        alb.Resize(512, 512),
        alb.Normalize(mean=trimap_model_cfg.MEAN, std=trimap_model_cfg.STD),
    ])

    loader_options = {
        "dataset": DatasetInference(df=frame, transform=preprocessing),
        "batch_size": kwargs.get('batch_size', 1),
        "shuffle": False,
        "num_workers": kwargs.get('workers', 2),
    }

    return DataLoader(**loader_options), dst_dir, output_dir

def predict_output(self, dataset, test_dataloader, dst_dir, output_dir, device = torch.device("cpu")):
    """Predict alpha masks for all batches, blend them and return the result path.

    For every batch the trimap model output is reduced with ``argmax`` to a
    one-channel class map, remapped (class 1 -> 255, class 2 -> 128) and
    divided by 255, concatenated to the image channels and passed to the
    alpha model. The activated ``refine_output`` of each sample is converted
    to ``uint8``, resized and written to ``dst_dir`` as
    ``<id>_<batch>_<index>.png``.

    After all batches are done, every image in ``data/output/val/images`` is
    blended with the masks whose file name starts with that image's id, each
    blend is written to ``output.png`` in a fresh temporary directory, and
    the working directories are emptied via ``remove_data``.

    Args:
        dataset: DataFrame with ``image_path`` and ``id`` columns.
            Assumed to list rows in the order the loader yields samples
            (i.e. an unshuffled loader) -- TODO confirm for other callers.
        test_dataloader: iterable of ``(img, data)`` batches; ``data['size']``
            supplies the resize target (presumably width, then height, as
            ``cv2.resize`` expects -- verify against ``DatasetInference``).
        dst_dir: directory the mask images are written to.
        output_dir: unused; kept for interface compatibility.
        device: device the image batch is moved to before inference.

    Returns:
        Path of the last blended image that was written.

    Raises:
        RuntimeError: if no blended image was produced.
    """
    # Samples consumed so far; maps a position inside a batch to a dataset row.
    samples_seen = 0

    for batch_no, (img, data) in enumerate(
            tqdm(test_dataloader, total=len(test_dataloader)), start=1):
        img = img.to(device)

        trimap = self.trimap_model_general(img)
        trimap = self.trimap_model_general.activation(trimap)
        trimap = trimap.argmax(1, keepdim=True)
        trimap_processed = trimap.clone().detach()
        trimap_processed[trimap_processed == 1] = 255
        trimap_processed[trimap_processed == 2] = 128
        trimap_processed = trimap_processed.to(torch.float32) / 255

        alpha = self.alpha_model_general(torch.cat([img, trimap_processed], dim=1))
        alpha = self.alpha_model_general.activation(alpha['refine_output'])

        for i in range(len(img)):
            # str() because pandas parses purely numeric ids as numbers.
            sample_id = str(dataset['id'].iloc[samples_seen + i])
            mask_path = os.path.join(dst_dir, f'{sample_id}_{batch_no}_{i}.png')

            m = alpha[i].cpu().detach().numpy()[0, ...]
            m = np.uint8(m * 255)
            m = cv2.resize(m, (int(data['size'][0][i]), int(data['size'][1][i])))

            cv2.imwrite(mask_path, m)

        samples_seen += len(img)

    # Blend only after every batch has produced its masks.
    output_path = None
    alpha_files = sorted(glob.glob(os.path.join(dst_dir, '*')))

    for image_file in tqdm(sorted(glob.glob('data/output/val/images/*'))):
        matching_ids = dataset.loc[dataset["image_path"] == image_file, "id"]
        if matching_ids.empty:
            # Image is not part of this request; nothing to blend.
            continue
        mask_prefix = f"{matching_ids.iloc[-1]}_"

        for alpha_file in tqdm(alpha_files):
            # Match on the file-name prefix so the id cannot accidentally
            # match a directory component or another id's suffix.
            if os.path.basename(alpha_file).startswith(mask_prefix):
                blended_image = self.blend(image_file, alpha_file)
                output_path = Path(tempfile.mkdtemp()) / "output.png"
                cv2.imwrite(str(output_path), blended_image)

    self.remove_data()

    if output_path is None:
        raise RuntimeError("no blended image was produced: no mask matched an input image")

    return output_path

# The arguments and types the model takes as input
def predict(self,
        image: Path = Input(description="Image to run inference on"),
        tracking_id: str = Input(description="Insert Your tracking id",default="test")
) -> Path:
    """Run a single prediction on the model.

    The input is converted to RGB and saved as
    ``data/output/val/images/test.png``, a one-row CSV describing it is
    created, the data loader is prepared and ``predict_output`` is run on
    the device selected in ``setup``.

    Args:
        image: path of the image to run inference on.
        tracking_id: identifier embedded in the CSV and mask file names.
            NOTE(review): the value comes from the caller and is used
            verbatim inside file names -- consider validating it.

    Returns:
        The path returned by ``predict_output``.
    """
    input_image_dir = "data/output/val/images"
    os.makedirs(input_image_dir, exist_ok=True)

    image_download_path = input_image_dir+"/"+"test.png"
    # Context manager releases the input file handle once the copy is saved.
    with Image.open(str(image)) as source_image:
        source_image.convert("RGB").save(image_download_path)

    csv_path = self.create_csv(image_download_path, tracking_id)

    # Create each working directory independently: a combined existence
    # check fails with FileExistsError when only one of them is present.
    for work_dir in ("blended_images/", "output_mask/"):
        os.makedirs(work_dir, exist_ok=True)

    req = {
        "image_csv_path": csv_path,
        "mask_path": "output_mask/",
        "output_path": "output/",
    }
    test_dataloader, dst_dir, output_dir = self.prepare_dataset(**req)
    dataset = pd.read_csv(req["image_csv_path"])

    # Use the device chosen in setup() so inputs and models stay together.
    return self.predict_output(
        dataset,
        test_dataloader,
        dst_dir,
        output_dir,
        self.device)

swapnil-lader avatar Aug 23 '23 06:08 swapnil-lader