x-clip
NaN with mock data
Hi lucidrains,
Try this and it will NaN within 100 steps (latest GitHub code). The loss looks fine before the NaN.
import torch
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = True
import random
import numpy as np
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
num_text_tokens = 10000
batch_sz = 12
text_seq_len = 256
visual_image_size = 256
# mock data
data_sz = 1000
all_text = torch.randint(0, num_text_tokens, (data_sz, text_seq_len)).cuda()
all_images = torch.randn(data_sz, 3, visual_image_size, visual_image_size).cuda()
text = torch.zeros((batch_sz, text_seq_len), dtype=torch.long).cuda()
images = torch.zeros((batch_sz, 3, visual_image_size, visual_image_size)).cuda()
##########################################################################################
import wandb
import datetime
wandb.init(project="Test", name=datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S'), save_code=False)
from x_clip import CLIP
clip = CLIP(
    dim_text = 512,
    dim_image = 512,
    dim_latent = 512,
    num_text_tokens = num_text_tokens,
    text_enc_depth = 6,
    text_seq_len = text_seq_len,
    text_heads = 8,
    visual_enc_depth = 6,
    visual_image_size = visual_image_size,
    visual_patch_size = 32,
    visual_heads = 8,
    use_all_token_embeds = False,           # whether to use fine-grained contrastive learning (FILIP)
    decoupled_contrastive_learning = True,  # use decoupled contrastive learning (DCL) objective, removing positive pairs from the denominator of the InfoNCE loss (CLOOB + DCL)
    extra_latent_projection = True,         # whether to use separate projections for text-to-image vs image-to-text comparisons (CLOOB)
    use_visual_ssl = True,                  # whether to do self-supervised learning on images
    visual_ssl_type = 'simclr',             # can be either 'simclr' or 'simsiam', depending on using DeCLIP or SLIP
    use_mlm = False,                        # use masked language modeling (MLM) on text (DeCLIP)
    text_ssl_loss_weight = 0.05,            # weight for text MLM loss
    image_ssl_loss_weight = 0.05            # weight for image self-supervised learning loss
).cuda()
optimizer = torch.optim.Adam(clip.parameters(), lr=1e-4, betas=(0.9, 0.99))
for step in range(999999):
    for i in range(batch_sz):
        data_id = random.randrange(0, data_sz - 1)
        text[i] = all_text[data_id]
        images[i] = all_images[data_id]

    loss = clip(
        text,
        images,
        freeze_image_encoder = False,  # whether to freeze the image encoder if using a pretrained image net, proposed by the LiT paper
        return_loss = True             # needs to be set to True to return the contrastive loss
    )

    clip.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(clip.parameters(), 1.0)
    optimizer.step()

    now_loss = loss.item()
    wandb.log({"loss": now_loss}, step = step)
    print(step, now_loss)
    if 'nan' in str(now_loss):
        break
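
Side note, not part of the original repro: if it helps to localize which op first produces the NaN, PyTorch's anomaly detection makes backward() raise at the offending operation, and a numeric check is a bit more robust than matching 'nan' in the stringified loss. A minimal sketch (slow, only for debugging):

import math
import torch

# make backward() raise with a traceback at the first op that returns NaN/inf
torch.autograd.set_detect_anomaly(True)

def loss_is_bad(loss: torch.Tensor) -> bool:
    # numeric check instead of `'nan' in str(now_loss)`
    return not math.isfinite(loss.item())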
@BlinkDL Hey Peng Bo! So I quickly checked the script and indeed it NaNs, but not if visual_ssl is turned off.
I suspect it has something to do with augmenting the randomly created images in the visual SSL, but I'm not completely sure.
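
For anyone wanting to confirm this on the repro above, a minimal sketch is to instantiate CLIP with the same settings but the visual SSL branch disabled (only the flag below is changed; everything else is as in the script):

clip_no_ssl = CLIP(
    dim_text = 512,
    dim_image = 512,
    dim_latent = 512,
    num_text_tokens = num_text_tokens,
    text_enc_depth = 6,
    text_seq_len = text_seq_len,
    text_heads = 8,
    visual_enc_depth = 6,
    visual_image_size = visual_image_size,
    visual_patch_size = 32,
    visual_heads = 8,
    decoupled_contrastive_learning = True,
    extra_latent_projection = True,
    use_visual_ssl = False   # visual SSL off - with this the script reportedly does not NaN
).cuda()

Another thing that might be worth trying (pure speculation) is feeding mock images in [0, 1] via torch.rand instead of torch.randn, since the SimCLR-style augmentations may assume image-range inputs.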