stylegan2-encoder-pytorch

`interpolation.ipynb`: ImportError: No module named 'fused'

Status: Open · molo32 opened this issue 4 years ago · 2 comments


import os
import random
import numpy as np

import torch
from torch import nn
import torch.nn.functional as F

import torchvision
from torchvision import datasets, transforms

from model import Generator, Encoder
from train_encoder import VGGLoss

import matplotlib.pyplot as plt


def image2tensor(image):
    """Turn a channels-last image into a batched float tensor scaled to [-1, 1].

    Args:
        image: rank-3 array-like, assumed to be laid out (H, W, C); the permute
            below moves its last axis to the front. Values are expected in
            [0, 255], so that 0 maps to -1 and 255 maps to 1.

    Returns:
        float32 tensor of shape (1, C, H, W).
    """
    pixels = torch.as_tensor(image, dtype=torch.float32)
    # (H, W, C) -> (C, H, W), then prepend a size-1 batch axis.
    batched = pixels.permute(2, 0, 1)[None]
    unit = batched / 255.
    # Affine map [0, 1] -> [-1, 1].
    return (unit - 0.5) / 0.5

def tensor2image(tensor):
    """Convert a [-1, 1] image tensor into an (H, W, C) numpy array in [0, 1].

    The input is detached and clamped out of place, so the caller's tensor is
    never modified. This also works on tensors that require grad, where an
    in-place ``clamp_`` would either raise (leaf tensors) or corrupt the graph.

    Args:
        tensor: torch tensor, assumed (1, C, H, W) or (C, H, W). Every size-1
            dimension is squeezed away, so the squeezed result must be rank 3
            (C, H, W) for the permute to succeed.

    Returns:
        numpy array of shape (H, W, C) with values in [0, 1].
    """
    # detach() first so the clamp is never recorded by autograd; clamp() (not
    # clamp_()) returns a new tensor and leaves the argument untouched.
    image = tensor.detach().clamp(-1., 1.).squeeze().permute(1, 2, 0).cpu().numpy()
    # Affine map [-1, 1] -> [0, 1].
    return image * 0.5 + 0.5

def imshow(img, size=5, cmap='jet'):
    """Display *img* in a new square figure with the axes hidden.

    Args:
        img: any image matplotlib's ``imshow`` accepts, e.g. an (H, W, 3)
            array in [0, 1] or a 2-D array rendered through *cmap*.
        size: edge length of the square figure, in inches.
        cmap: colormap name passed to ``imshow``; matplotlib ignores it for
            RGB(A) input.
    """
    _, ax = plt.subplots(figsize=(size, size))
    ax.imshow(img, cmap=cmap)
    ax.set_axis_off()
    plt.show()


device = 'cuda'
image_size = 256

# --- pretrained generator ------------------------------------------------------
g_model_path = '/content/generator_ffhq.pt'
g_ckpt = torch.load(g_model_path, map_location=device)

# The latent width is taken from the training args stored in the checkpoint.
latent_dim = g_ckpt['args'].latent

generator = Generator(image_size, latent_dim, 8).to(device)
# strict=False: missing/unexpected keys in g_ckpt["g_ema"] are silently ignored.
# NOTE(review): confirm this is intentional - a mismatched checkpoint would load
# only partially without raising.
generator.load_state_dict(g_ckpt["g_ema"], strict=False)
generator.eval()
print('[generator loaded]')

# --- pretrained encoder --------------------------------------------------------
e_model_path = '/content/encoder_ffhq.pt'
e_ckpt = torch.load(e_model_path, map_location=device)

encoder = Encoder(image_size, latent_dim).to(device)
encoder.load_state_dict(e_ckpt['e'])
encoder.eval()
print('[encoder loaded]')

# --- sample an n_rows x n_cols grid of random images ---------------------------
truncation = 0.7
n_rows, n_cols = 4, 6

with torch.no_grad():
    # Computed under no_grad so mean_latent does not record an autograd graph;
    # it is still called before the sampling randn below, as before.
    trunc = generator.mean_latent(4096).detach().clone()

    latent = generator.get_latent(torch.randn(n_rows * n_cols, latent_dim, device=device))
    imgs_gen, _ = generator([latent],
                            truncation=truncation,
                            truncation_latent=trunc,
                            input_is_latent=True,
                            randomize_noise=True)

    # Assuming imgs_gen is (N, C, H, W): within a row, images are joined along
    # width (dim=2 of each C,H,W image); rows are then stacked along height (dim=1).
    rows = [torch.cat(list(row), dim=2) for row in imgs_gen.chunk(n_rows, dim=0)]
    result = torch.cat(rows, dim=1)
    print('generated samples:')
    imshow(tensor2image(result), size=15)
---------------------------------------------------------------------------
ImportError                               Traceback (most recent call last)
<ipython-input-254-b06be6808604> in <module>()
     10 from torchvision import datasets, transforms
     11 
---> 12 from model import Generator, Encoder
     13 from train_encoder import VGGLoss
     14 

6 frames
/usr/lib/python3.6/imp.py in find_module(name, path)
    295         break  # Break out of outer loop when breaking out of inner loop.
    296     else:
--> 297         raise ImportError(_ERR_MSG.format(name), name=name)
    298 
    299     encoding = None

ImportError: No module named 'fused'

---------------------------------------------------------------------------
NOTE: If your import is failing due to a missing package, you can
manually install dependencies using either !pip or !apt.

To view examples of installing some common dependencies, click the
"Open Examples" button below.

— molo32, Jan 08 '21 00:01

Hi, I'm having a similar problem. Did you find out how to solve it?

— JohannesAck, Jan 23 '21 23:01

Hi, this thread might help (the link was not preserved in this copy). For reference, this is the environment I used:

- GPU: GeForce RTX 2080 Ti
- torch: 1.3.1
- CUDA: 10.1.243

— bryandlee, Jan 24 '21 22:01