
The implementation is wrong

Open · elvisnava opened this issue 1 year ago • 2 comments

I would advise anyone against using this implementation until these issues are fixed.

In _align_sequence, the function for sequence alignment (and the same criticism applies to _expand_sequence), we have:

```python
def _align_sequence(
        self,
        full_seq: torch.Tensor,
        seq: torch.Tensor,
        span: Span,
        eos_loc: int,
        dim: int = 1,
        zero_out: bool = False,
        replace_pad: bool = False,
) -> torch.Tensor:
    # shape: (77, 768) -> (768, 77)
    seq = seq.transpose(0, dim)
    # shape: (77, 768) -> (768, 77)
    full_seq = full_seq.transpose(0, dim)

    start, end = span.left + 1, span.right + 1
    seg_length = end - start

    full_seq[start:end] = seq[1 : 1 + seg_length]
    if zero_out:
        full_seq[1:start] = 0
        full_seq[end:eos_loc] = 0

    if replace_pad:
        pad_length = len(full_seq) - eos_loc
        full_seq[eos_loc:] = seq[1 + seg_length : 1 + seg_length + pad_length]

    # shape: (768, 77) -> (77, 768)
    return full_seq.transpose(0, dim)
```

which is supposed to replace the embeddings in full_seq (shape (77, 768)) between start and end with the corresponding ones from seq. However, both tensors are first transposed, giving full_seq the shape (768, 77), so the assignment full_seq[start:end] slices along the embedding dimension rather than the token dimension. seq is indexed over the wrong dimension in the same way.
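To make the dimension mixup concrete, here is a minimal standalone sketch (hypothetical; it assumes unbatched (77, 768) CLIP embeddings, matching the shape comments in the code itself): after the transpose, the slice lands on the embedding axis, while the presumably intended behavior indexes the token axis directly.

```python
import torch

# Hypothetical repro; shapes follow the (77, 768) comments in the code.
full_seq = torch.zeros(77, 768)  # rows = token positions, cols = channels
seq = torch.ones(77, 768)
start, end = 2, 5

# What the quoted code does: transpose first, then slice.
t = full_seq.transpose(0, 1)                 # view of shape (768, 77)
t[start:end] = seq.transpose(0, 1)[1:1 + (end - start)]
print((full_seq[:, start:end] != 0).all())   # True: channels 2..4 of *every*
                                             # token position were overwritten

# What is presumably intended: slice the token dimension directly.
fixed = torch.zeros(77, 768)
fixed[start:end] = seq[1:1 + (end - start)]
print((fixed[start:end] != 0).all())         # True: only tokens 2..4 changed
```

Under this unbatched reading, the two transposes do nothing useful and redirect every slice onto the channel axis; the span indices start and end only make sense on the token axis.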

Moreover, I believe the calculation of spans is also incorrect: it operates at the level of words without accounting for the possibility that a single word is broken into multiple tokens by the tokenizer. In the paper author's repository, the following function

```python
import numpy as np

def get_token_alignment_map(tree, tokens):
    if tokens is None:
        return {i: [i] for i in range(len(tree.leaves()) + 1)}

    def get_token(token):
        return token[:-4] if token.endswith("</w>") else token

    idx_map = {}
    j = 0
    max_offset = np.abs(len(tokens) - len(tree.leaves()))
    mytree_prev_leaf = ""
    for i, w in enumerate(tree.leaves()):
        token = get_token(tokens[j])
        idx_map[i] = [j]
        if token == mytree_prev_leaf + w:
            mytree_prev_leaf = ""
            j += 1
        else:
            if len(token) < len(w):
                prev = ""
                while prev + token != w:
                    prev += token
                    j += 1
                    token = get_token(tokens[j])
                    idx_map[i].append(j)
                    # assert j - i <= max_offset
            else:
                mytree_prev_leaf += w
                j -= 1
            j += 1
    idx_map[i + 1] = [j]
    return idx_map
```

is used to build this mapping between word indices and token spans.
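For illustration, here is a hypothetical usage sketch (the tokenization shown is made up for the example, not actual CLIP BPE output): a parse-tree leaf that the tokenizer splits into two sub-word tokens maps to a list of token indices rather than a single index.

```python
from nltk import Tree

# Hypothetical example: the leaf "photorealistic" is assumed to be split
# into the sub-word tokens "photo" + "realistic</w>" by the tokenizer.
tree = Tree.fromstring("(S (NP a photorealistic cat))")
tokens = ["a</w>", "photo", "realistic</w>", "cat</w>"]

print(get_token_alignment_map(tree, tokens))
# {0: [0], 1: [1, 2], 2: [3], 3: [4]}
```

Here word 1 ("photorealistic") maps to the token span [1, 2], which is exactly the multi-token case that a purely word-based span calculation misses.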

elvisnava · Mar 27 '23 20:03