
Potential bug in BasicTransformerBlock

Open CarlaSa opened this issue 2 years ago • 0 comments

In the forward method of BasicTransformerBlock, the keyword arguments additional_tokens and n_times_crossframe_attn_in_self are silently ignored: they are collected into kwargs, but kwargs is never passed to checkpoint; only x and context are.

As a result, these arguments cannot currently be changed from their defaults, which may cause unexpected behaviour.

https://github.com/Stability-AI/generative-models/blob/76e549dd94de6a09f9e75eff82d62377274c00f8/sgm/modules/attention.py#L460C47-L460C47

Preview of above:

    def forward(
        self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
    ):
        kwargs = {"x": x}

        if context is not None:
            kwargs.update({"context": context})

        if additional_tokens is not None:
            kwargs.update({"additional_tokens": additional_tokens})

        if n_times_crossframe_attn_in_self:
            kwargs.update(
                {"n_times_crossframe_attn_in_self": n_times_crossframe_attn_in_self}
            )

        # return mixed_checkpoint(self._forward, kwargs, self.parameters(), self.checkpoint)
        return checkpoint(
            self._forward, (x, context), self.parameters(), self.checkpoint
        )
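
For illustration, one possible fix (an untested sketch, assuming the repo's checkpoint helper only forwards the positional inputs tuple to the wrapped function) would be to bind the keyword arguments in a closure so they still reach _forward:

    def forward(
        self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
    ):
        # Sketch only: close over the keyword options, since checkpoint passes
        # just the positional (x, context) tuple on to the wrapped callable.
        def _forward_with_kwargs(x, context=None):
            return self._forward(
                x,
                context=context,
                additional_tokens=additional_tokens,
                n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self,
            )

        return checkpoint(
            _forward_with_kwargs, (x, context), self.parameters(), self.checkpoint
        )

Alternatively, the commented-out mixed_checkpoint call already builds the full kwargs dict, so it may have been intended to handle exactly this case.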
