big-sleep
Run on multiple GPUs
Hi, I've tried to add multi-GPU support via nn.DataParallel. Now it looks like this:

```python
import os
import sys
import subprocess
import signal
import string
import re

from datetime import datetime
from pathlib import Path
import random

import torch
import torch.nn.functional as F
from torch import nn
from torch.optim import Adam
from torchvision.utils import save_image
import torchvision.transforms as T

from PIL import Image
from tqdm import tqdm, trange

from ema import EMA
from resample import resample
from biggan import BigGAN
from clip import load, tokenize

assert torch.cuda.is_available(), 'CUDA must be available in order to use Big Sleep'

# graceful keyboard interrupt

terminate = False

def signal_handling(signum, frame):
    global terminate
    terminate = True

signal.signal(signal.SIGINT, signal_handling)

# helpers

def exists(val):
    return val is not None

def open_folder(path):
    if os.path.isfile(path):
        path = os.path.dirname(path)

    if not os.path.isdir(path):
        return

    cmd_list = None
    if sys.platform == 'darwin':
        cmd_list = ['open', '--', path]
    elif sys.platform == 'linux2' or sys.platform == 'linux':
        cmd_list = ['xdg-open', path]
    elif sys.platform in ['win32', 'win64']:
        cmd_list = ['explorer', path.replace('/', '\\')]
    if cmd_list == None:
        return

    try:
        subprocess.check_call(cmd_list)
    except subprocess.CalledProcessError:
        pass
    except OSError:
        pass

def create_text_path(text=None, img=None, encoding=None):
    input_name = ""
    if text is not None:
        input_name += text
    if img is not None:
        if isinstance(img, str):
            img_name = "".join(img.split(".")[:-1])  # remove img extension
            img_name = img_name.split("/")[-1]  # only take img name, not path
        else:
            img_name = "PIL_img"
        input_name += "_" + img_name
    if encoding is not None:
        input_name = "your_encoding"
    # replace spaces by underscores, strip characters that are awkward in filenames
    return input_name.replace("-", "_").replace(",", "").replace(" ", "_").replace("|", "--").strip('-_')[:255]

# tensor helpers

def differentiable_topk(x, k, temperature=1.):
    n, dim = x.shape
    topk_tensors = []

    for i in range(k):
        is_last = i == (k - 1)
        values, indices = (x / temperature).softmax(dim=-1).topk(1, dim=-1)
        topks = torch.zeros_like(x).scatter_(-1, indices, values)
        topk_tensors.append(topks)
        if not is_last:
            x = x.scatter(-1, indices, float('-inf'))

    topks = torch.cat(topk_tensors, dim=-1)
    return topks.reshape(n, k, dim).sum(dim = 1)

def create_clip_img_transform(image_width):
    clip_mean = [0.48145466, 0.4578275, 0.40821073]
    clip_std = [0.26862954, 0.26130258, 0.27577711]
    transform = T.Compose([
        #T.ToPILImage(),
        T.Resize(image_width),
        T.CenterCrop((image_width, image_width)),
        T.ToTensor(),
        T.Normalize(mean=clip_mean, std=clip_std)
    ])
    return transform

def rand_cutout(image, size, center_bias=False, center_focus=2):
    width = image.shape[-1]
    min_offset = 0
    max_offset = width - size
    if center_bias:
        # sample around image center
        center = max_offset / 2
        std = center / center_focus
        offset_x = int(random.gauss(mu=center, sigma=std))
        offset_y = int(random.gauss(mu=center, sigma=std))
        # resample uniformly if over boundaries
        offset_x = random.randint(min_offset, max_offset) if (offset_x > max_offset or offset_x < min_offset) else offset_x
        offset_y = random.randint(min_offset, max_offset) if (offset_y > max_offset or offset_y < min_offset) else offset_y
    else:
        offset_x = random.randint(min_offset, max_offset)
        offset_y = random.randint(min_offset, max_offset)
    cutout = image[:, :, offset_x:offset_x + size, offset_y:offset_y + size]
    return cutout

# load clip

perceptor, normalize_image = load('ViT-B/32', jit = False)

# load biggan

class Latents(torch.nn.Module):
    def __init__(
        self,
        num_latents = 15,
        num_classes = 1000,
        z_dim = 128,
        max_classes = None,
        class_temperature = 2.
    ):
        super().__init__()
        self.normu = torch.nn.Parameter(torch.zeros(num_latents, z_dim).normal_(std = 1))
        self.cls = torch.nn.Parameter(torch.zeros(num_latents, num_classes).normal_(mean = -3.9, std = .3))
        self.register_buffer('thresh_lat', torch.tensor(1))

        assert not exists(max_classes) or max_classes > 0 and max_classes <= num_classes, f'max_classes must be between 0 and {num_classes}'
        self.max_classes = max_classes
        self.class_temperature = class_temperature

    def forward(self):
        if exists(self.max_classes):
            classes = differentiable_topk(self.cls, self.max_classes, temperature = self.class_temperature)
        else:
            classes = torch.sigmoid(self.cls)
        return self.normu, classes

class Model(nn.Module):
    def __init__(
        self,
        image_size,
        max_classes = None,
        class_temperature = 2.,
        ema_decay = 0.99
    ):
        super().__init__()
        assert image_size in (128, 256, 512), 'image size must be one of 128, 256, or 512'
        self.biggan = BigGAN.from_pretrained(f'biggan-deep-{image_size}')
        self.max_classes = max_classes
        self.class_temperature = class_temperature
        self.ema_decay = ema_decay

        self.init_latents()

    def init_latents(self):
        latents = Latents(
            num_latents = len(self.biggan.config.layers) + 1,
            num_classes = self.biggan.config.num_classes,
            z_dim = self.biggan.config.z_dim,
            max_classes = self.max_classes,
            class_temperature = self.class_temperature
        )
        self.latents = EMA(latents, self.ema_decay)

    def forward(self):
        self.biggan.eval()
        out = self.biggan(*self.latents(), 1)
        return (out + 1) / 2

class BigSleep(nn.Module):
    def __init__(
        self,
        num_cutouts = 128,
        loss_coef = 100,
        image_size = 512,
        bilinear = False,
        max_classes = None,
        class_temperature = 2.,
        experimental_resample = False,
        ema_decay = 0.99,
        center_bias = False,
    ):
        super().__init__()
        self.loss_coef = loss_coef
        self.image_size = image_size
        self.num_cutouts = num_cutouts
        self.experimental_resample = experimental_resample
        self.center_bias = center_bias

        self.interpolation_settings = {'mode': 'bilinear', 'align_corners': False} if bilinear else {'mode': 'nearest'}

        # my change: wrapped the inner Model in DataParallel
        self.model = torch.nn.DataParallel(Model(
            image_size = image_size,
            max_classes = max_classes,
            class_temperature = class_temperature,
            ema_decay = ema_decay
        ).cuda())

    def reset(self):
        self.model.init_latents()

    def sim_txt_to_img(self, text_embed, img_embed, text_type="max"):
        sign = -1
        if text_type == "min":
            sign = 1
        return sign * self.loss_coef * torch.cosine_similarity(text_embed, img_embed, dim = -1).mean()

    def forward(self, text_embeds, text_min_embeds=[], return_loss = True):
        width, num_cutouts = self.image_size, self.num_cutouts

        out = self.model()

        if not return_loss:
            return out

        pieces = []
        for ch in range(num_cutouts):
            # sample cutout size
            size = int(width * torch.zeros(1,).normal_(mean=.8, std=.3).clip(.5, .95))
            # get cutout
            apper = rand_cutout(out, size, center_bias=self.center_bias)
            if (self.experimental_resample):
                apper = resample(apper, (224, 224))
            else:
                apper = F.interpolate(apper, (224, 224), **self.interpolation_settings)
            pieces.append(apper)

        into = torch.cat(pieces)
        into = normalize_image(into)

        image_embed = perceptor.encode_image(into)

        latents, soft_one_hot_classes = self.model.latents()
        num_latents = latents.shape[0]
        latent_thres = self.model.latents.model.thresh_lat

        lat_loss = torch.abs(1 - torch.std(latents, dim=1)).mean() + \
                   torch.abs(torch.mean(latents, dim = 1)).mean() + \
                   4 * torch.max(torch.square(latents).mean(), latent_thres)

        for array in latents:
            mean = torch.mean(array)
            diffs = array - mean
            var = torch.mean(torch.pow(diffs, 2.0))
            std = torch.pow(var, 0.5)
            zscores = diffs / std
            skews = torch.mean(torch.pow(zscores, 3.0))
            kurtoses = torch.mean(torch.pow(zscores, 4.0)) - 3.0

            lat_loss = lat_loss + torch.abs(kurtoses) / num_latents + torch.abs(skews) / num_latents

        cls_loss = ((50 * torch.topk(soft_one_hot_classes, largest = False, dim = 1, k = 999)[0]) ** 2).mean()

        results = []
        for txt_embed in text_embeds:
            results.append(self.sim_txt_to_img(txt_embed, image_embed))
        for txt_min_embed in text_min_embeds:
            results.append(self.sim_txt_to_img(txt_min_embed, image_embed, "min"))
        sim_loss = sum(results).mean()

        return out, (lat_loss, cls_loss, sim_loss)

class Imagine(nn.Module):
    def __init__(
        self,
        *,
        text=None,
        img=None,
        encoding=None,
        text_min = "",
        lr = .07,
        image_size = 512,
        gradient_accumulate_every = 1,
        save_every = 50,
        epochs = 20,
        iterations = 1050,
        save_progress = False,
        bilinear = False,
        open_folder = True,
        seed = None,
        append_seed = False,
        torch_deterministic = False,
        max_classes = None,
        class_temperature = 2.,
        save_date_time = False,
        save_best = False,
        experimental_resample = False,
        ema_decay = 0.99,
        num_cutouts = 128,
        center_bias = False,
    ):
        super().__init__()

        if torch_deterministic:
            assert not bilinear, 'the deterministic (seeded) operation does not work with interpolation (PyTorch 1.7.1)'
            torch.set_deterministic(True)

        self.seed = seed
        self.append_seed = append_seed

        if exists(seed):
            print(f'setting seed of {seed}')
            if seed == 0:
                print('you can override this with --seed argument in the command line, or --random for a randomly chosen one')
            torch.manual_seed(seed)

        self.epochs = epochs
        self.iterations = iterations

        # my change: wrapped the top-level BigSleep in DataParallel as well
        model = torch.nn.DataParallel(BigSleep(
            image_size = image_size,
            bilinear = bilinear,
            max_classes = max_classes,
            class_temperature = class_temperature,
            experimental_resample = experimental_resample,
            ema_decay = ema_decay,
            num_cutouts = num_cutouts,
            center_bias = center_bias,
        ).cuda())

        self.model = model

        self.lr = lr
        self.optimizer = Adam(model.model.latents.model.parameters(), lr)
        self.gradient_accumulate_every = gradient_accumulate_every
        self.save_every = save_every
        self.save_progress = save_progress
        self.save_date_time = save_date_time
        self.save_best = save_best
        self.current_best_score = 0
        self.open_folder = open_folder
        self.total_image_updates = (self.epochs * self.iterations) / self.save_every
        self.encoded_texts = {
            "max": [],
            "min": []
        }

        # create img transform
        self.clip_transform = create_clip_img_transform(224)
        # create starting encoding
        self.set_clip_encoding(text=text, img=img, encoding=encoding, text_min=text_min)

    @property
    def seed_suffix(self):
        return f'.{self.seed}' if self.append_seed and exists(self.seed) else ''

    def set_text(self, text):
        self.set_clip_encoding(text = text)

    def create_clip_encoding(self, text=None, img=None, encoding=None):
        self.text = text
        self.img = img
        if encoding is not None:
            encoding = torch.nn.DataParallel(encoding.cuda())
        #elif self.create_story:
        #    encoding = self.update_story_encoding(epoch=0, iteration=1)
        elif text is not None and img is not None:
            encoding = (self.create_text_encoding(text) + self.create_img_encoding(img)) / 2
        elif text is not None:
            encoding = self.create_text_encoding(text)
        elif img is not None:
            encoding = self.create_img_encoding(img)
        return encoding

    def create_text_encoding(self, text):
        tokenized_text = torch.nn.DataParallel(tokenize(text).cuda())
        with torch.no_grad():
            text_encoding = perceptor.encode_text(tokenized_text).detach()
        return text_encoding

    def create_img_encoding(self, img):
        if isinstance(img, str):
            img = Image.open(img)
        normed_img = torch.nn.DataParallel(self.clip_transform(img).unsqueeze(0).cuda())
        with torch.no_grad():
            img_encoding = perceptor.encode_image(normed_img).detach()
        return img_encoding

    def encode_multiple_phrases(self, text, img=None, encoding=None, text_type="max"):
        if text is not None and "|" in text:
            self.encoded_texts[text_type] = [self.create_clip_encoding(text=prompt_min, img=img, encoding=encoding) for prompt_min in text.split("|")]
        else:
            self.encoded_texts[text_type] = [self.create_clip_encoding(text=text, img=img, encoding=encoding)]

    def encode_max_and_min(self, text, img=None, encoding=None, text_min=""):
        self.encode_multiple_phrases(text, img=img, encoding=encoding)
        if text_min is not None and text_min != "":
            self.encode_multiple_phrases(text_min, img=img, encoding=encoding, text_type="min")

    def set_clip_encoding(self, text=None, img=None, encoding=None, text_min=""):
        self.current_best_score = 0
        self.text = text
        self.text_min = text_min

        if len(text_min) > 0:
            text = text + "_wout_" + text_min[:255] if text is not None else "wout_" + text_min[:255]
        text_path = create_text_path(text=text, img=img, encoding=encoding)
        if self.save_date_time:
            text_path = datetime.now().strftime("%y%m%d-%H%M%S-") + text_path

        self.text_path = text_path
        self.filename = Path(f'./{text_path}{self.seed_suffix}.png')
        self.encode_max_and_min(text, img=img, encoding=encoding, text_min=text_min)  # Tokenize and encode each prompt

    def reset(self):
        self.model.reset()
        self.model = torch.nn.DataParallel(self.model.cuda())
        self.optimizer = Adam(self.model.model.latents.parameters(), self.lr)

    def train_step(self, epoch, i, pbar=None):
        total_loss = 0

        for _ in range(self.gradient_accumulate_every):
            out, losses = self.model(self.encoded_texts["max"], self.encoded_texts["min"])
            loss = sum(losses) / self.gradient_accumulate_every
            total_loss += loss
            loss.backward()

        self.optimizer.step()
        self.model.model.latents.update()
        self.optimizer.zero_grad()

        if (i + 1) % self.save_every == 0:
            with torch.no_grad():
                self.model.model.latents.eval()
                out, losses = self.model(self.encoded_texts["max"], self.encoded_texts["min"])
                top_score, best = torch.topk(losses[2], k=1, largest=False)
                image = self.model.model()[best].cpu()
                self.model.model.latents.train()

                save_image(image, str(self.filename))
                if pbar is not None:
                    pbar.update(1)
                else:
                    print(f'image updated at "./{str(self.filename)}"')

                if self.save_progress:
                    total_iterations = epoch * self.iterations + i
                    num = total_iterations // self.save_every
                    save_image(image, Path(f'./{self.text_path}.{num}{self.seed_suffix}.png'))

                if self.save_best and top_score.item() < self.current_best_score:
                    self.current_best_score = top_score.item()
                    save_image(image, Path(f'./{self.text_path}{self.seed_suffix}.best.png'))

        return out, total_loss

    def forward(self):
        penalizing = ""
        if len(self.text_min) > 0:
            penalizing = f'penalizing "{self.text_min}"'
        print(f'Imagining "{self.text_path}" {penalizing}...')

        with torch.no_grad():
            self.model(self.encoded_texts["max"][0])  # one warmup step due to issue with CLIP and CUDA

        if self.open_folder:
            open_folder('./')
            self.open_folder = False

        image_pbar = tqdm(total=self.total_image_updates, desc='image update', position=2, leave=True)
        for epoch in trange(self.epochs, desc='epochs', position=0, leave=True):
            pbar = trange(self.iterations, desc='iteration', position=1, leave=True)
            image_pbar.update(0)
            for i in pbar:
                out, loss = self.train_step(epoch, i, image_pbar)
                pbar.set_description(f'loss: {loss.item():04.2f}')

                if terminate:
                    print('detecting keyboard interrupt, gracefully exiting')
                    return
```
But I receive this error:

```
UserWarning: 'torch.load' received a zip file that looks like a TorchScript archive dispatching to 'torch.jit.load' (call 'torch.jit.load' directly to silence this warning)
Traceback (most recent call last):
  File "/home/jovyan/SberArtist/big-sleep/big_sleep/clip.py", line 101, in load
    model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
  File "/home/user/conda/lib/python3.7/site-packages/torch/jit/__init__.py", line 275, in load
    cpp_module = torch._C.import_ir_module(cu, f, map_location, _extra_files)
RuntimeError:
aten::_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled) -> (Tensor):
Expected at most 12 arguments but found 13 positional arguments.
```
@AlexWortega could you put ``` at the beginning and the end of your code? You can use the Preview tab to check how it renders.
It's hard to read your code :-/
Could you also tell us what you changed?
I would like to use all my GPUs too.
I think you only need to add DataParallel to the top-level model, rather than in multiple places inside the model implementation. This tutorial suggests something like:
```diff
diff --git a/big_sleep/cli.py b/big_sleep/cli.py
index 8aa8820..a4cbdf4 100644
--- a/big_sleep/cli.py
+++ b/big_sleep/cli.py
@@ -4,6 +4,7 @@ from big_sleep import Imagine, version
 from pathlib import Path
 from .version import __version__;
+import torch
 
 def train(
     text=None,
@@ -68,6 +69,11 @@ def train(
     if answer not in ('yes', 'y'):
         exit()
 
+    if torch.cuda.device_count() > 1:
+        print("Let's use", torch.cuda.device_count(), "GPUs!")
+        imagine = torch.nn.DataParallel(imagine)
+        imagine.to(torch.device('cuda'))
+
     imagine()
 
 def main():
```
which ought to allow the `dream` script to use multiple GPUs. However, this model then raises an exception:

```
RuntimeError: Output 237 of BroadcastBackward is a view and its base or another view of its base has been modified inplace. This view is the output of a function that returns multiple views. Such functions do not allow the output views to be modified inplace. You should replace the inplace operation by an out-of-place one.
```

and I think that means this architecture needs a little retooling to support multi-GPU training.
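
For anyone who wants to attempt that retooling, the error message points at the mechanical fix: under DataParallel the per-replica parameters are views produced by the broadcast, and mutating a broadcast view in place during forward is what autograd rejects. The pattern is to swap in-place tensor ops (the underscore-suffixed ones) for their out-of-place equivalents. Here is a minimal sketch of that pattern with hypothetical functions; it is not a verified patch for big-sleep, and I haven't tracked down which op in the forward path actually trips it:

```python
import torch

def forward_inplace(x: torch.Tensor) -> torch.Tensor:
    # in-place ops (trailing underscore) mutate their argument; when `x` is a
    # view created by DataParallel's broadcast, autograd raises the
    # "BroadcastBackward ... modified inplace" error quoted above
    x.clamp_(-1, 1)
    x.add_(1)
    return x

def forward_out_of_place(x: torch.Tensor) -> torch.Tensor:
    # out-of-place equivalents allocate new tensors and leave the view intact
    x = x.clamp(-1, 1)
    x = x + 1
    return x

# sanity check: both variants compute the same values on an ordinary tensor
t = torch.randn(4)
assert torch.equal(forward_inplace(t.clone()), forward_out_of_place(t.clone()))
```

torch.nn.parallel.DistributedDataParallel may also be worth a look: it runs one process per GPU with real (non-view) parameters, so it should sidestep this particular error, at the cost of a launcher and more restructuring.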