diffusers

add conv_in_channels in def save_motion_modules

Open ernestchu opened this issue 1 year ago • 1 comments

What does this PR do?

In save_motion_modules, the current implementation doesn't pass conv_in_channels to the MotionAdapter constructor. This incorrectly results in "conv_in_channels": null in the saved config.json, regardless of the model's actual conv_in_channels. This commit fixes the issue.
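A minimal sketch of the symptom, using a toy stand-in rather than the real MotionAdapter (the class `ToyAdapter`, its fields, and the value 9 are hypothetical and only for illustration): when a constructor argument that defaults to `None` is not forwarded, the serialized config records `null` no matter what the model actually uses.

```python
import json


class ToyAdapter:
    """Hypothetical stand-in for an adapter that records its init args as config."""

    def __init__(self, block_out_channels, conv_in_channels=None):
        self.config = {
            "block_out_channels": list(block_out_channels),
            "conv_in_channels": conv_in_channels,
        }

    def to_json(self):
        # Serialize the recorded config, as a saved config.json would contain it
        return json.dumps(self.config, sort_keys=True)


# conv_in_channels is omitted, so the default None is serialized as null
omitted = ToyAdapter(block_out_channels=(320, 640)).to_json()

# conv_in_channels is forwarded, so the actual value is serialized
forwarded = ToyAdapter(block_out_channels=(320, 640), conv_in_channels=9).to_json()

print(omitted)
print(forwarded)
```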

However, more actions may be required.

If an instance is constructed via from_unet2d with a motion_adapter, one can simply copy motion_adapter.config to the UNetMotionModel instance (like this) and use that config to construct the MotionAdapter instance to be saved.

If an instance is NOT constructed from from_unet2d, there has to be a way to identify the correct config that describes the motion_adapter in the instance, but it is beyond my bandwidth for the moment. Maybe @DN6 can help?
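The two cases above can be sketched with toy classes. This is only an illustration of the idea, not the diffusers implementation: `ToyMotionAdapter`, `ToyMotionModel`, the attribute `motion_adapter_config`, and the method `rebuild_adapter` are all hypothetical names.

```python
class ToyMotionAdapter:
    """Hypothetical stand-in; the real MotionAdapter takes more arguments."""

    def __init__(self, motion_max_seq_length=32, conv_in_channels=None):
        self.config = {
            "motion_max_seq_length": motion_max_seq_length,
            "conv_in_channels": conv_in_channels,
        }


class ToyMotionModel:
    """Hypothetical stand-in for a model that may be built from an adapter."""

    def __init__(self, motion_adapter=None):
        # Keep a copy of the adapter's config when an adapter was supplied
        self.motion_adapter_config = (
            dict(motion_adapter.config) if motion_adapter is not None else None
        )

    def rebuild_adapter(self):
        # Case 1: the config was recorded, so the adapter can be reconstructed
        if self.motion_adapter_config is not None:
            return ToyMotionAdapter(**self.motion_adapter_config)
        # Case 2: nothing was recorded; the config would have to be inferred
        raise ValueError("no adapter config recorded for this instance")


with_adapter = ToyMotionModel(ToyMotionAdapter(conv_in_channels=9))
rebuilt = with_adapter.rebuild_adapter()
print(rebuilt.config)
```

The second case (no adapter supplied) is exactly the open question: the sketch raises an error there because it has no way to identify the right config.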

Before submitting

  • [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • [x] Did you read the contributor guideline?
  • [x] Did you read our philosophy doc (important for complex PRs)?
  • [ ] Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
  • [x] Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • [ ] Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.

ernestchu avatar Jun 05 '24 16:06 ernestchu

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Hi @ernestchu sorry for the delay here. Yeah this is tricky since the conv_in channels for models like PIA are loaded directly into the UNet. I assume you're working with a model like that?

We can't save conv_in_channels by default because the conv_in weights of the UNet are not saved with the MotionModules, which is why we have that failing test.

We could add an arg to save_motion_modules such as save_conv_in_channels and save the config and weights appropriately if it is set to True. WDYT?

    def save_motion_modules(
        self,
        save_directory: str,
        is_main_process: bool = True,
        safe_serialization: bool = True,
        variant: Optional[str] = None,
        push_to_hub: bool = False,
        save_conv_in_channels: bool = False,  # Used to save conv_in weights of PIA-like models
        **kwargs,
    ) -> None:
        state_dict = self.state_dict()

        # Extract all motion modules
        motion_state_dict = {}
        for k, v in state_dict.items():
            if "motion_modules" in k:
                motion_state_dict[k] = v
            if save_conv_in_channels and k in ("conv_in.weight", "conv_in.bias"):
                motion_state_dict[k] = v

        adapter = MotionAdapter(
            block_out_channels=self.config["block_out_channels"],
            motion_layers_per_block=self.config["layers_per_block"],
            motion_norm_num_groups=self.config["norm_num_groups"],
            motion_num_attention_heads=self.config["motion_num_attention_heads"],
            motion_max_seq_length=self.config["motion_max_seq_length"],
            use_motion_mid_block=self.config["use_motion_mid_block"],
            conv_in_channels=self.config["in_channels"] if save_conv_in_channels else None,
        )
        adapter.load_state_dict(motion_state_dict)
        adapter.save_pretrained(
            save_directory=save_directory,
            is_main_process=is_main_process,
            safe_serialization=safe_serialization,
            variant=variant,
            push_to_hub=push_to_hub,
            **kwargs,
        )
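To make the key-selection step in the proposal concrete, here is a small standalone sketch that applies the same filter to a plain dictionary. The helper `select_motion_keys` and the key names in `toy_state_dict` are made up for illustration, and strings stand in for tensors.

```python
def select_motion_keys(state_dict, save_conv_in_channels=False):
    """Return the entries an adapter checkpoint would contain under the proposal."""
    conv_in_keys = {"conv_in.weight", "conv_in.bias"}
    selected = {}
    for key, value in state_dict.items():
        if "motion_modules" in key:
            # Motion module weights are always kept
            selected[key] = value
        elif save_conv_in_channels and key in conv_in_keys:
            # conv_in weights are kept only when explicitly requested
            selected[key] = value
    return selected


# Toy state dict: illustrative key names, placeholder values instead of tensors
toy_state_dict = {
    "conv_in.weight": "w",
    "conv_in.bias": "b",
    "down_blocks.0.motion_modules.0.proj_in.weight": "m0",
    "down_blocks.0.resnets.0.conv1.weight": "r0",
}

default_keys = sorted(select_motion_keys(toy_state_dict))
pia_keys = sorted(select_motion_keys(toy_state_dict, save_conv_in_channels=True))
print(default_keys)
print(pia_keys)
```

With the flag left at its default, only the motion module entry is selected, which matches the existing behaviour; with the flag set, the two conv_in entries are selected as well.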

DN6 avatar Jul 10 '24 12:07 DN6

Yeah, I was referring to PIA. The code should generalize to models like it. Your proposal LGTM.

ernestchu avatar Jul 11 '24 13:07 ernestchu

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions[bot] avatar Sep 14 '24 15:09 github-actions[bot]

Hi @ernestchu, let us know if you'll have time to finish this PR :)

yiyixuxu avatar Dec 03 '24 04:12 yiyixuxu

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions[bot] avatar Dec 27 '24 15:12 github-actions[bot]