NaN with mock data
Hi lucidrains,
Run the script below against the latest GitHub code and the loss becomes NaN within 100 steps. The loss looks fine right up until it turns NaN.
"""Minimal reproduction: x-clip training loss becomes NaN on random mock data.

Builds a CLIP model with DCL, extra latent projections and SimCLR visual SSL
enabled, trains it on randomly generated text/image pairs, logs the loss to
Weights & Biases and stops at the first NaN loss. Every tensor is created
with ``.cuda()``, so a CUDA device is required.
"""
import datetime
import math
import random

import numpy as np
import torch
import wandb
from x_clip import CLIP

torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = True

# Seed every RNG the script touches so the sampled batches and the mock data
# are the same on each run. NOTE(review): cudnn.benchmark is enabled above, so
# GPU kernels may still not be bit-for-bit deterministic.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

num_text_tokens = 10000
batch_sz = 12
text_seq_len = 256
visual_image_size = 256

# mock data
data_sz = 1000
all_text = torch.randint(0, num_text_tokens, (data_sz, text_seq_len)).cuda()
all_images = torch.randn(data_sz, 3, visual_image_size, visual_image_size).cuda()

# Batch buffers, refilled in place on every step.
text = torch.zeros((batch_sz, text_seq_len), dtype=torch.long).cuda()
images = torch.zeros((batch_sz, 3, visual_image_size, visual_image_size)).cuda()

wandb.init(
    project="Test",
    name=datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S'),
    save_code=False,
)

clip = CLIP(
    dim_text = 512,
    dim_image = 512,
    dim_latent = 512,
    num_text_tokens = num_text_tokens,
    text_enc_depth = 6,
    text_seq_len = text_seq_len,
    text_heads = 8,
    visual_enc_depth = 6,
    visual_image_size = visual_image_size,
    visual_patch_size = 32,
    visual_heads = 8,
    use_all_token_embeds = False,           # whether to use fine-grained contrastive learning (FILIP)
    decoupled_contrastive_learning = True,  # use decoupled contrastive learning (DCL) objective function, removing positive pairs from the denominator of the InfoNCE loss (CLOOB + DCL)
    extra_latent_projection = True,         # whether to use separate projections for text-to-image vs image-to-text comparisons (CLOOB)
    use_visual_ssl = True,                  # whether to do self supervised learning on images
    visual_ssl_type = 'simclr',             # can be either 'simclr' or 'simsiam', depending on using DeCLIP or SLIP
    use_mlm = False,                        # use masked language learning (MLM) on text (DeCLIP)
    text_ssl_loss_weight = 0.05,            # weight for text MLM loss
    image_ssl_loss_weight = 0.05            # weight for image self-supervised learning loss
).cuda()  # the inputs are CUDA tensors, so the model must be on the GPU too

optimizer = torch.optim.Adam(clip.parameters(), lr=1e-4, betas=(0.9, 0.99))

for step in range(999999):
    # Draw a random batch (with replacement) from the mock dataset.
    for i in range(batch_sz):
        # randrange() already excludes its stop value, so pass data_sz itself;
        # `data_sz - 1` would make the last sample unreachable.
        data_id = random.randrange(data_sz)
        text[i] = all_text[data_id]
        images[i] = all_images[data_id]

    loss = clip(
        text,
        images,
        freeze_image_encoder = False,  # whether to freeze image encoder if using a pretrained image net, proposed by LiT paper
        return_loss = True             # needs to be set to True to return contrastive loss
    )

    optimizer.zero_grad()
    loss.backward()
    # Clip after backward() and before step() so the clipped gradients are
    # the ones the optimizer applies.
    torch.nn.utils.clip_grad_norm_(clip.parameters(), 1.0)
    optimizer.step()

    now_loss = loss.item()
    wandb.log({"loss": now_loss}, step = step)
    print(step, now_loss)

    # Stop at the first NaN so the failing step is the last one logged.
    if math.isnan(now_loss):
        break
@BlinkDL Hey Peng Bo! So I quickly checked the script, and indeed it NaNs, but not if the visual_ssl is turned off.
I suspect it has something to do with augmenting the randomly created images in the visual SSL, but I'm not completely sure.