
fix a bug

yuetz opened this issue on Jan 19, 2024 · 3 comments

In cldm/cldm.py, lines 420-422:

            glyphs[i] = einops.rearrange(glyphs[i], 'b h w c -> b c h w')
            gly_line[i] = einops.rearrange(gly_line[i], 'b h w c -> b c h w')
            positions[i] = einops.rearrange(positions[i], 'b h w c -> b c h w')

This may cause a bug, because it changes the batch's values in place. I changed the code as follows, which should solve the problem:

    def get_input(self, batch, k, bs=None, *args, **kwargs):
        if self.embedding_manager is None:  # fill in full caption
            self.fill_caption(batch)
        x, c, mx = super().get_input(batch, self.first_stage_key, mask_k='masked_img', *args, **kwargs)
        control = batch[self.control_key]  # for log_images and loss_alpha, not real control
        if bs is not None:
            control = control[:bs]
        control = control.to(self.device)
        control = einops.rearrange(control, 'b h w c -> b c h w')
        control = control.to(memory_format=torch.contiguous_format).float()

        inv_mask = batch['inv_mask']
        if bs is not None:
            inv_mask = inv_mask[:bs]
        inv_mask = inv_mask.to(self.device)
        inv_mask = einops.rearrange(inv_mask, 'b h w c -> b c h w')
        inv_mask = inv_mask.to(memory_format=torch.contiguous_format).float()

        glyphs = batch[self.glyph_key]
        gly_line = batch['gly_line']
        positions = batch[self.position_key]
        n_lines = batch['n_lines']
        language = batch['language']
        texts = batch['texts']

        glyphs_new = []
        gly_line_new = []
        positions_new = []
        n_lines_new = []
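        # collect the rearranged tensors in new lists instead of writing them
        # back into batch[...], so the batch keeps its original b h w c layout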
        assert len(glyphs) == len(positions)
        for i in range(len(glyphs)):
            if bs is not None:
                glyphs[i] = glyphs[i][:bs]
                gly_line[i] = gly_line[i][:bs]
                positions[i] = positions[i][:bs]
                n_lines = n_lines[:bs]
            glyphs[i] = glyphs[i].to(self.device)
            gly_line[i] = gly_line[i].to(self.device)
            positions[i] = positions[i].to(self.device)

            glyphs_new.append(einops.rearrange(glyphs[i], 'b h w c -> b c h w'))
            gly_line_new.append(einops.rearrange(gly_line[i], 'b h w c -> b c h w'))
            positions_new.append(einops.rearrange(positions[i], 'b h w c -> b c h w'))
            glyphs_new[i] = glyphs_new[i].to(memory_format=torch.contiguous_format).float()
            gly_line_new[i] = gly_line_new[i].to(memory_format=torch.contiguous_format).float()
            positions_new[i] = positions_new[i].to(memory_format=torch.contiguous_format).float()
        info = {}
        info['glyphs'] = glyphs_new
        info['positions'] = positions_new
        info['n_lines'] = n_lines
        info['language'] = language
        info['texts'] = texts
        info['img'] = batch['img']  # nhwc, (-1,1)
        info['masked_x'] = mx
        info['gly_line'] = gly_line_new
        info['inv_mask'] = inv_mask
        return x, dict(c_crossattn=[c], c_concat=[control], text_info=info)

yuetz, Jan 19, 2024

Hi, thank you, but I didn't quite understand: what kind of bug would these lines of code introduce?

glyphs[i] = einops.rearrange(glyphs[i], 'b h w c -> b c h w')
...

tyxsspa, Jan 19, 2024

glyphs, gly_line, and positions are lists, and they are members of the input batch. Those lines change the dimension order of the original values in place, so when the function is called a second time on the same data, the dimension order is already wrong. This happens, for example, in validation_step (ddpm.py):

    def validation_step(self, batch, batch_idx):
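        # shared_step (and therefore get_input) runs twice on the same batch here:
        # once with the regular weights and once under the EMA weights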
        _, loss_dict_no_ema = self.shared_step(batch)
        with self.ema_scope():
            _, loss_dict_ema = self.shared_step(batch)
            loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema}
        self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
        self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)

or when log_images is called after shared_step has already been called.

It could also be my mistake, using the dataset in the wrong way, but I think this code carries a real risk.
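To make the failure mode concrete, here is a minimal standalone sketch (the shapes and the prepare helper are made up for illustration, not taken from AnyText) of what happens when the same batch entry is rearranged in place twice:

    import einops
    import torch

    batch = {'glyphs': [torch.zeros(2, 64, 64, 1)]}  # one text line, b h w c

    def prepare(batch):
        # mimics the original get_input: writes the rearranged tensor back into the batch
        for i in range(len(batch['glyphs'])):
            batch['glyphs'][i] = einops.rearrange(batch['glyphs'][i], 'b h w c -> b c h w')
        return batch['glyphs']

    print(prepare(batch)[0].shape)  # torch.Size([2, 1, 64, 64])  correct on the first call
    print(prepare(batch)[0].shape)  # torch.Size([2, 64, 1, 64])  wrong on the second call

Building new lists (glyphs_new, gly_line_new, positions_new), as in the proposed code above, avoids this because the tensors stored in the batch keep their original b h w c layout.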

yuetz, Jan 19, 2024

Those lines change the dimension order of the original values in place, so when the function is called a second time on the same data, the dimension order is already wrong.

I can confirm the issue and your proposed code fixes it. @yuetz

tydia, Jan 24, 2024