imagen-pytorch
Text-to-video: no attention layers
I don't understand why in the Unet3D we don't use attention layers for text conditioning (sorry if this is a dumb question).
i.e. this: `layer_attns = (False, False, False, True), layer_cross_attns = False`
@axel588 hmm, if you are training with text conditioning but have no cross attention layers set, it should error out (does it not?). I can add a check if you show me a script where this is not true.
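For reference, a common compromise is to enable cross-attention only at the lowest-resolution stage, where the token sequence is short, so text conditioning is still wired in without paying the attention cost at full resolution. A minimal sketch using imagen-pytorch's `Unet3D` (the `dim` and `dim_mults` values here are illustrative, not tuned):

```python
from imagen_pytorch import Unet3D

# Sketch: self- and cross-attention only at the deepest (most downsampled)
# stage. Text conditioning is present, but the quadratic attention cost is
# paid only on a short sequence. dim / dim_mults values are illustrative.
unet = Unet3D(
    dim = 64,
    channels = 5,
    dim_mults = (1, 2, 4, 8),
    layer_attns = (False, False, False, True),        # self-attention at the bottom stage only
    layer_cross_attns = (False, False, False, True),  # text cross-attention at the bottom stage only
)
```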
@lucidrains I did enable the cross-attention layers at first, but even with a dimension of 8 (very low, yes) and a batch size of 1 it overflows my 24 GB graphics card; the configuration below already takes 23 GB of VRAM at batch size 2. How do I solve the memory issue? The code below works for text conditioning without cross-attention layers and gives no error, but yes, the samples look random relative to the prompt:
```python
unet = Unet3D(
    dim = config.dim,              # base dimension, i.e. number of filters output by the first layer
    # cond_dim = config.cond_dim,
    channels = 5,
    dim_mults = config.dim_mults,  # channel multipliers for each resolution stage (multiplied by dim)
    # num_resnet_blocks = config.num_resnet_blocks,
    # layer_attns = (False,) + (True,) * (len(config.dim_mults) - 1),
    # layer_cross_attns = (False,) + (True,) * (len(config.dim_mults) - 1)
)

imagen = ElucidatedImagen(
    unets = (unet,),               # trailing comma: (unet) alone is not a tuple
    image_sizes = (reshaped_m,),
    cond_drop_prob = 0.1,
    text_encoder_name = 't5-base',
    channels = 5,
    num_sample_steps = 64,         # number of sample steps - 64 for base unet, 32 for upsampler (just an example, have no clue what the optimal values are)
    sigma_min = 0.002,             # min noise level
    sigma_max = 80,                # max noise level, @crowsonkb recommends double the max noise level for upsampler
    sigma_data = 0.5,              # standard deviation of data distribution
    rho = 7,                       # controls the sampling schedule
    P_mean = -1.2,                 # mean of log-normal distribution from which noise is drawn for training
    P_std = 1.2,                   # standard deviation of log-normal distribution from which noise is drawn for training
    S_churn = 80,                  # parameters for stochastic sampling - depends on dataset, Table 5 in the paper
    S_tmin = 0.05,
    S_tmax = 50,
    S_noise = 1.003,
).cuda()
```
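The overflow with attention enabled is easy to see with back-of-envelope arithmetic: the attention score matrix is quadratic in the number of tokens. A rough sketch (assuming unfactorized fp32 attention over all frames × height × width tokens; 8 frames, 64×64 resolution, and 8 heads are illustrative numbers, not the config above):

```python
# Rough estimate of the score-matrix memory for one self-attention layer.
# Assumes full (unfactorized) spatiotemporal attention and fp32 scores;
# only meant to show why attention at full resolution overflows 24 GB
# while the same layer at the coarsest stage is cheap.

def attn_matrix_bytes(frames, height, width, heads=8, bytes_per_el=4):
    n = frames * height * width          # flattened spatiotemporal sequence length
    return heads * n * n * bytes_per_el  # one n x n score matrix per head

full = attn_matrix_bytes(8, 64, 64)  # attention at full 64x64 resolution
deep = attn_matrix_bytes(8, 8, 8)    # attention at an 8x-downsampled stage

print(f"full res : {full / 1024**3:.0f} GiB")  # far beyond a 24 GB card
print(f"deepest  : {deep / 1024**2:.0f} MiB")  # negligible
```

So attention over the full-resolution video grid alone exceeds a 24 GB card before counting activations and gradients, while the same layer at the most downsampled stage costs megabytes.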