
How to get the expression error

shiwk20 opened this issue 2 years ago · 3 comments

Thanks for your wonderful work! I have recently been working on another face-swap network, but when I tried to reproduce the expression error from the paper, I ran into difficulties. I use A Compact Embedding for Facial Expression Similarity, referenced in your paper, to extract expression embeddings; the code I used is AmirSh15/FECNet, which is the only PyTorch implementation I could find. However, when I compute the L2 distance between the 16-dim embeddings of the target face and the swapped face, the result is always around 0.5 (I tested DeepFakes, FaceShifter, etc.), which differs greatly from the results reported in your paper. Could you please share some details about how you compute the expression error? I would appreciate it if you could post your expression-error code at your convenience.
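Concretely, what I am computing is the following (my reading of the paper's protocol, which the paper does not spell out explicitly):

$$\mathrm{Err}_{\mathrm{expr}} = \frac{1}{N} \sum_{i=1}^{N} \left\| f\left(x_i^{\mathrm{tgt}}\right) - f\left(x_i^{\mathrm{swap}}\right) \right\|_2$$

where $f(\cdot)$ is the 16-dim FECNet embedding, $x_i^{\mathrm{tgt}}$ and $x_i^{\mathrm{swap}}$ are the aligned target and swapped faces of frame $i$, and $N$ is the number of evaluated frames.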

shiwk20 · Feb 08 '23 03:02

The code I used to compute the expression error is as follows:

import json
import os
import pickle

import torch
from facenet_pytorch import MTCNN
from PIL import Image
from tqdm import tqdm

from FECNet import FECNet  # from AmirSh15/FECNet; adjust the import path to your checkout

def get_expr(tgt_imgs, gen_imgs, model):
    '''
    Per-image L2 distance between expression embeddings.
    Input tensors: b x 3 x h x w
    '''
    tgt_out = model(tgt_imgs)  # b x 16 FECNet embeddings
    gen_out = model(gen_imgs)

    return torch.sqrt(torch.sum((tgt_out - gen_out) ** 2, dim=1))

def test_deepfakes(model):
    print('test_deepfakes')
    df_data_root = 'data/MyFF++_no_rotation/DeepFakes/images256'
    ori_data_root = 'data/MyFF++_no_rotation/samples/images256'
    # Loaded for completeness; not used in the computation below.
    landmarks = pickle.load(open('data/MyFF++_no_rotation/landmark/landmarks256.pkl', 'rb'))

    df_video_list = sorted(os.listdir(df_data_root))
    ori_video_list = sorted(os.listdir(ori_data_root))
    samples = json.load(open('data/MyFF++/samples/samples.json'))
    mtcnn = MTCNN(image_size=224)
    # Buffers assume exactly 10 sampled frames per video.
    gen_imgs = torch.zeros((10, 3, 224, 224))
    tgt_imgs = torch.zeros((10, 3, 224, 224))

    sum_expr_error = 0
    count = 0
    with torch.no_grad():
        for video_idx, video in enumerate(tqdm(ori_video_list)):
            df_video = df_video_list[video_idx]
            for img_idx, img in enumerate(samples[video]):
                df_image = Image.open(os.path.join(df_data_root, df_video, img)).convert('RGB')
                ori_image = Image.open(os.path.join(ori_data_root, video, img)).convert('RGB')

                # MTCNN returns an aligned 3 x 224 x 224 face crop.
                gen_imgs[img_idx] = mtcnn(df_image, return_prob=False)
                tgt_imgs[img_idx] = mtcnn(ori_image, return_prob=False)
                count += 1

            tmp = get_expr(tgt_imgs.to(device), gen_imgs.to(device), model)
            sum_expr_error += tmp.sum()
            print(f'{video_idx + 1}/{len(ori_video_list)}', sum_expr_error / count, 'all:', tmp, 'count:', count)
    print('final', sum_expr_error / count)

if __name__ == '__main__':
    device = torch.device('cuda:0')  # used as a global inside test_deepfakes
    model = FECNet(pretrained=True)

    model.to(device)
    model.eval()  # inference mode: disable dropout / batch-norm updates

    print('start test')
    test_deepfakes(model)
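One thing I am unsure about (my own observation, not something from the paper): mtcnn() returns None when no face is detected, which would crash the tensor assignments above, and facenet_pytorch's MTCNN also standardizes its crops to roughly [-1, 1] by default (post_process=True), which may or may not match the input range FECNet was trained on. Below is a minimal defensive sketch for the cropping step; safe_crop is a hypothetical helper of mine, not part of either repo:

import torch
from facenet_pytorch import MTCNN
from PIL import Image

mtcnn = MTCNN(image_size=224)  # post_process=True by default: crops are standardized, not [0, 1]

def safe_crop(path, fallback=None):
    # Return an aligned 3 x 224 x 224 face crop, or `fallback` when no face is detected.
    img = Image.open(path).convert('RGB')
    face = mtcnn(img, return_prob=False)  # returns None when detection fails
    return face if face is not None else fallback

In the loop above, one would then skip the frame (and not increment count) whenever safe_crop returns None.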

shiwk20 · Feb 08 '23 03:02

Hello, I have the same issue. Have you solved this problem?

chentting · Mar 17 '23 12:03

Hello, I have a similar situation. Have you found a solution to this problem? @chentting @shiwk20

zhouzk5 · Oct 28 '23 13:10