
initial pass on jaxify pndm scheduler + tests

Open natolambert opened this issue 2 years ago • 4 comments

natolambert avatar Aug 12 '22 22:08 natolambert

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

FYI: JAX precision makes the following test fail, because a very small number gets rounded to zero:

    def test_betas(self):
        self.check_over_configs(beta_start=0.01, beta_end=0.2)
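
For context, here is a minimal sketch of the kind of rounding that may be involved. It assumes the failure is tied to JAX defaulting to 32-bit floats; the thread does not say which quantity in the scheduler is rounded to zero. The helper `tiny_difference` is hypothetical and not part of diffusers, and NumPy is used only to illustrate float32 versus float64 arithmetic.

```python
import numpy as np


def tiny_difference(dtype):
    """Add a tiny increment to 1.0 and subtract 1.0 again in the given dtype.

    Hypothetical helper for illustration only; not taken from the scheduler.
    """
    one = dtype(1.0)
    eps = dtype(1e-9)  # far below float32 machine epsilon (about 1.19e-7)
    return (one + eps) - one


# In float32 the increment is lost when added to 1.0, so the difference is zero.
diff32 = tiny_difference(np.float32)

# In float64 the increment survives, so the difference stays positive.
diff64 = tiny_difference(np.float64)

print(diff32, diff64)
```

If this is indeed the cause, a comparison that works in a 64-bit torch or NumPy path could see an exact zero in the JAX path.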

natolambert avatar Aug 12 '22 22:08 natolambert

I don't think this needs to go into diffusers@main yet, as it would require users to install JAX just to use the PNDM scheduler. It's cool to have the tests, but I'd put the implementation in a separate file rather than changing scheduling_pndm.py.

pcuenca avatar Aug 13 '22 13:08 pcuenca

I agree, this is a proof of concept. Eventually, a lot of work will be needed to make the tests independent of the backend (torch, numpy, jax). This can be referenced at that point.

natolambert avatar Aug 13 '22 21:08 natolambert