stylegan2-encoder-pytorch
Issue: interpolation.ipynb fails with `ImportError: No module named 'fused'`
import os
import random
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import torchvision
from torchvision import datasets, transforms
from model import Generator, Encoder
from train_encoder import VGGLoss
import matplotlib.pyplot as plt
def image2tensor(image):
    """Convert an HWC image with values in [0, 255] into a normalized NCHW tensor.

    Args:
        image: array-like of shape (H, W, C), e.g. a numpy array or nested
            list; presumably a uint8 RGB image (TODO confirm with callers).

    Returns:
        A float32 tensor of shape (1, C, H, W) whose values are mapped from
        [0, 255] to [-1, 1].
    """
    # torch.as_tensor with an explicit dtype replaces the legacy
    # torch.FloatTensor constructor and yields float32 for any input dtype.
    tensor = torch.as_tensor(image, dtype=torch.float32)
    # HWC -> CHW, then add a leading batch axis and scale to [0, 1].
    tensor = tensor.permute(2, 0, 1).unsqueeze(0) / 255.
    # [0, 1] -> [-1, 1]; tensor2image undoes this with `x * 0.5 + 0.5`.
    return (tensor - 0.5) / 0.5
def tensor2image(tensor):
    """Convert a [-1, 1] image tensor into an HWC numpy image in [0, 1].

    Args:
        tensor: torch tensor of shape (C, H, W) or (1, C, H, W) with values
            nominally in [-1, 1]; values outside that range are clipped.
            The input tensor is left unmodified.

    Returns:
        numpy array of shape (H, W, C) with values in [0, 1].
    """
    # Detach first and clamp out-of-place. The previous in-place `clamp_`
    # overwrote the caller's tensor, and it raises on leaf tensors that
    # require grad.
    image = tensor.detach().clamp(-1., 1.)
    # Drop only the batch axis. A bare squeeze() would also remove size-1
    # channel/height/width axes and make the 3-axis permute below fail.
    if image.dim() == 4:
        image = image.squeeze(0)
    image = image.permute(1, 2, 0).cpu().numpy()
    # [-1, 1] -> [0, 1]
    return image * 0.5 + 0.5
def imshow(img, size=5, cmap='jet'):
    """Show ``img`` in a new square matplotlib figure with the axes hidden.

    Args:
        img: any input accepted by ``plt.imshow``, e.g. an (H, W, 3) array.
        size: width and height of the figure, in inches.
        cmap: colormap name; matplotlib ignores it for RGB(A) data.
    """
    square = (size, size)
    plt.figure(figsize=square)
    plt.imshow(img, cmap=cmap)
    plt.axis('off')  # hide ticks and frame
    plt.show()
# Use the GPU when one is present. Fall back to the CPU instead of failing
# inside torch.load / .to(device) on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
image_size = 256

# --- Generator ---------------------------------------------------------------
# NOTE(review): torch.load unpickles arbitrary objects; only load checkpoints
# from a trusted source.
g_model_path = '/content/generator_ffhq.pt'
g_ckpt = torch.load(g_model_path, map_location=device)
# The latent size comes from the training arguments stored in the checkpoint.
latent_dim = g_ckpt['args'].latent
# NOTE(review): 8 is presumably the mapping-network depth; confirm against
# model.Generator's signature.
generator = Generator(image_size, latent_dim, 8).to(device)
# strict=False: missing/unexpected keys are silently ignored.
# NOTE(review): confirm no real weights are being dropped.
generator.load_state_dict(g_ckpt["g_ema"], strict=False)
generator.eval()
print('[generator loaded]')

# --- Encoder -----------------------------------------------------------------
e_model_path = '/content/encoder_ffhq.pt'
e_ckpt = torch.load(e_model_path, map_location=device)
encoder = Encoder(image_size, latent_dim).to(device)
encoder.load_state_dict(e_ckpt['e'])
encoder.eval()
print('[encoder loaded]')
# Sample an n_rows x n_cols grid of random images with the truncation trick
# and display it as a single tiled picture.
truncation = 0.7
n_rows, n_cols = 4, 6

with torch.no_grad():
    # Mean latent used as the truncation centre. Computing it under no_grad
    # means no autograd graph is recorded for it.
    trunc = generator.mean_latent(4096).detach().clone()
    latent = generator.get_latent(
        torch.randn(n_rows * n_cols, latent_dim, device=device))
    imgs_gen, _ = generator([latent],
                            truncation=truncation,
                            truncation_latent=trunc,
                            input_is_latent=True,
                            randomize_noise=True)

    # Tile the batch into n_rows groups of n_cols images. Assuming each image
    # is (C, H, W), images in a row are joined along width (dim=2) and rows
    # are stacked along height (dim=1).
    rows = [torch.cat(list(row), dim=2) for row in imgs_gen.chunk(n_rows, dim=0)]
    result = torch.cat(rows, dim=1)

print('generated samples:')
imshow(tensor2image(result), size=15)
---------------------------------------------------------------------------
ImportError Traceback (most recent call last)
<ipython-input-254-b06be6808604> in <module>()
10 from torchvision import datasets, transforms
11
---> 12 from model import Generator, Encoder
13 from train_encoder import VGGLoss
14
6 frames
/usr/lib/python3.6/imp.py in find_module(name, path)
295 break # Break out of outer loop when breaking out of inner loop.
296 else:
--> 297 raise ImportError(_ERR_MSG.format(name), name=name)
298
299 encoding = None
ImportError: No module named 'fused'
---------------------------------------------------------------------------
NOTE: If your import is failing due to a missing package, you can
manually install dependencies using either !pip or !apt.
To view examples of installing some common dependencies, click the
"Open Examples" button below.
Hi, I'm having a similar problem. Did you find out how to solve it?
Hi, this thread might help (the link was not preserved). For reference, here is the environment I used:
GPU: GeForce RTX 2080 Ti, torch: 1.3.1, CUDA: 10.1.243