training-free-structured-diffusion-guidance
The implementation is wrong
I would advise anyone against using this implementation until these issues are fixed.
In the function for sequence alignment (but the same can be said about `_expand_sequence`), we have:
```python
def _align_sequence(
    self,
    full_seq: torch.Tensor,
    seq: torch.Tensor,
    span: Span,
    eos_loc: int,
    dim: int = 1,
    zero_out: bool = False,
    replace_pad: bool = False,
) -> torch.Tensor:
    # shape: (77, 768) -> (768, 77)
    seq = seq.transpose(0, dim)
    # shape: (77, 768) -> (768, 77)
    full_seq = full_seq.transpose(0, dim)
    start, end = span.left + 1, span.right + 1
    seg_length = end - start
    full_seq[start:end] = seq[1 : 1 + seg_length]
    if zero_out:
        full_seq[1:start] = 0
        full_seq[end:eos_loc] = 0
    if replace_pad:
        pad_length = len(full_seq) - eos_loc
        full_seq[eos_loc:] = seq[1 + seg_length : 1 + seg_length + pad_length]
    # shape: (768, 77) -> (77, 768)
    return full_seq.transpose(0, dim)
```
which is supposed to replace the embeddings in `full_seq` (shape `(77, 768)`) between `start` and `end` with the ones from `seq`. However, a transpose is performed first, giving `full_seq` a shape of `(768, 77)`, so the assignment `full_seq[start:end]` indexes the embedding dimension rather than the token dimension. `seq` is indexed incorrectly in the same way.
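To make the mis-indexing concrete, here is a minimal sketch (the shapes follow the comments above; the slice bounds are illustrative):

```python
import torch

full_seq = torch.zeros(77, 768)  # 77 token positions, 768 embedding channels
seq = torch.ones(77, 768)

# What the code above does: after transpose(0, 1), dim 0 is the embedding
# axis, so the slice overwrites channels 2..4 for ALL 77 token positions.
t = full_seq.transpose(0, 1)          # shape: (768, 77)
t[2:5] = seq.transpose(0, 1)[1:4]

# What was presumably intended: replace token positions 2..4.
full_seq = torch.zeros(77, 768)
full_seq[2:5] = seq[1:4]
```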
Moreover, I believe the calculation of spans is also incorrect: it works on whole words without accounting for the possibility of a word being broken into multiple tokens. In the paper author's repository, the following function
```python
import numpy as np

def get_token_alignment_map(tree, tokens):
    if tokens is None:
        return {i: [i] for i in range(len(tree.leaves()) + 1)}

    def get_token(token):
        return token[:-4] if token.endswith("</w>") else token

    idx_map = {}
    j = 0
    max_offset = np.abs(len(tokens) - len(tree.leaves()))
    mytree_prev_leaf = ""
    for i, w in enumerate(tree.leaves()):
        token = get_token(tokens[j])
        idx_map[i] = [j]
        if token == mytree_prev_leaf + w:
            mytree_prev_leaf = ""
            j += 1
        else:
            if len(token) < len(w):
                prev = ""
                while prev + token != w:
                    prev += token
                    j += 1
                    token = get_token(tokens[j])
                    idx_map[i].append(j)
                # assert j - i <= max_offset
            else:
                mytree_prev_leaf += w
                j -= 1
            j += 1
    idx_map[i + 1] = [j]
    return idx_map
```
is used to perform this mapping between word spans and token spans, precisely to handle words that split into multiple tokens.
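To illustrate, here is a hypothetical call (the parse tree and the token list are made up for this example; the actual split of "hummingbird" depends on the CLIP BPE vocabulary):

```python
from nltk import Tree

# A toy constituency tree and a CLIP-style BPE tokenization in which
# one word ("hummingbird") occupies two token positions.
tree = Tree.fromstring("(NP (DT a) (JJ watercolor) (NN hummingbird))")
tokens = ["a</w>", "watercolor</w>", "humming", "bird</w>"]

print(get_token_alignment_map(tree, tokens))
# {0: [0], 1: [1], 2: [2, 3], 3: [4]}  -- word 2 maps to tokens 2 and 3
```

A span computed from word indices alone would treat "hummingbird" as a single position and shift every subsequent index by one; this map is what keeps the spans aligned with the actual token sequence.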