fix a bug
In the file cldm/cldm.py, lines 420-422:
glyphs[i] = einops.rearrange(glyphs[i], 'b h w c -> b c h w')
gly_line[i] = einops.rearrange(gly_line[i], 'b h w c -> b c h w')
positions[i] = einops.rearrange(positions[i], 'b h w c -> b c h w')
This may cause a bug, because it changes the batch's values in place. I changed the code as follows, which should fix it.
def get_input(self, batch, k, bs=None, *args, **kwargs):
    if self.embedding_manager is None:  # fill in full caption
        self.fill_caption(batch)
    x, c, mx = super().get_input(batch, self.first_stage_key, mask_k='masked_img', *args, **kwargs)
    control = batch[self.control_key]  # for log_images and loss_alpha, not real control
    if bs is not None:
        control = control[:bs]
    control = control.to(self.device)
    control = einops.rearrange(control, 'b h w c -> b c h w')
    control = control.to(memory_format=torch.contiguous_format).float()
    inv_mask = batch['inv_mask']
    if bs is not None:
        inv_mask = inv_mask[:bs]
    inv_mask = inv_mask.to(self.device)
    inv_mask = einops.rearrange(inv_mask, 'b h w c -> b c h w')
    inv_mask = inv_mask.to(memory_format=torch.contiguous_format).float()
    glyphs = batch[self.glyph_key]
    gly_line = batch['gly_line']
    positions = batch[self.position_key]
    n_lines = batch['n_lines']
    language = batch['language']
    texts = batch['texts']
    glyphs_new = []
    gly_line_new = []
    positions_new = []
    n_lines_new = []
    assert len(glyphs) == len(positions)
    for i in range(len(glyphs)):
        if bs is not None:
            glyphs[i] = glyphs[i][:bs]
            gly_line[i] = gly_line[i][:bs]
            positions[i] = positions[i][:bs]
            n_lines = n_lines[:bs]
        glyphs[i] = glyphs[i].to(self.device)
        gly_line[i] = gly_line[i].to(self.device)
        positions[i] = positions[i].to(self.device)
        # rearrange into new lists so the tensors stored in batch keep their original b h w c layout
        glyphs_new.append(einops.rearrange(glyphs[i], 'b h w c -> b c h w'))
        gly_line_new.append(einops.rearrange(gly_line[i], 'b h w c -> b c h w'))
        positions_new.append(einops.rearrange(positions[i], 'b h w c -> b c h w'))
        glyphs_new[i] = glyphs_new[i].to(memory_format=torch.contiguous_format).float()
        gly_line_new[i] = gly_line_new[i].to(memory_format=torch.contiguous_format).float()
        positions_new[i] = positions_new[i].to(memory_format=torch.contiguous_format).float()
    info = {}
    info['glyphs'] = glyphs_new
    info['positions'] = positions_new
    info['n_lines'] = n_lines
    info['language'] = language
    info['texts'] = texts
    info['img'] = batch['img']  # nhwc, (-1,1)
    info['masked_x'] = mx
    info['gly_line'] = gly_line_new
    info['inv_mask'] = inv_mask
    return x, dict(c_crossattn=[c], c_concat=[control], text_info=info)
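The key change is that the rearranged tensors go into fresh lists (glyphs_new, gly_line_new, positions_new) instead of being written back into batch, so the tensors held by batch keep their original b h w c layout. As a quick sanity check (a hypothetical helper, not part of the repo; it assumes these maps are single-channel), something like this can be run after get_input:

def assert_channel_last(batch, keys=('glyphs', 'gly_line', 'positions')):
    # Hypothetical helper: after get_input, the per-line tensors in batch should
    # still be b h w c (assumed single-channel maps), i.e. untouched by the rearrange.
    for k in keys:
        for t in batch[k]:
            assert t.shape[-1] == 1, f'{k} is no longer b h w c: {tuple(t.shape)}'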
Hi, thank you, but I didn't quite understand: what kind of bug would these lines of code introduce?
glyphs[i] = einops.rearrange(glyphs[i], 'b h w c -> b c h w')
...
glyphs, gly_line and positions are lists, and they are members of the input batch. This code changes the dimension order of the original tensors in place, so when the function is called a second time on the same data, the dimension order is already wrong.
For example, in validation_step (ddpm.py), shared_step is called twice on the same batch:
def validation_step(self, batch, batch_idx):
    _, loss_dict_no_ema = self.shared_step(batch)
    with self.ema_scope():
        _, loss_dict_ema = self.shared_step(batch)
        loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema}
    self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
    self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
or when log_images is called after shared_step.
It could also be my mistake, using the dataset in the wrong way, but I think this code carries a real risk.
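To make this concrete, here is a minimal standalone sketch (not the project's code; the toy shapes and the shared_step stand-in are only for illustration) of what happens when the same batch dict goes through the rearrange twice:

import einops
import torch

batch = {'glyphs': [torch.zeros(2, 64, 64, 1)]}  # b h w c, as the dataloader provides it

def shared_step(batch):
    glyphs = batch['glyphs']  # same list object as in batch
    for i in range(len(glyphs)):
        # writing back into the list mutates batch for every later caller
        glyphs[i] = einops.rearrange(glyphs[i], 'b h w c -> b c h w')
    return glyphs

print(shared_step(batch)[0].shape)  # torch.Size([2, 1, 64, 64])  -> correct on the first call
print(shared_step(batch)[0].shape)  # torch.Size([2, 64, 1, 64])  -> wrong: the tensor was already b c h w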
I can confirm the issue and your proposed code fixes it. @yuetz