diffusers
initial pass on jaxify pndm scheduler + tests
FYI: JAX's default float32 precision makes the following test fail by rounding a very small number to zero:
def test_betas(self):
    self.check_over_configs(beta_start=0.01, beta_end=0.2)
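A minimal sketch of one plausible mechanism, assuming the scheduler uses 1000 training timesteps with a linear beta schedule (assumed defaults; the test above only overrides `beta_start` and `beta_end`). With these values the cumulative product of `1 - beta` falls to roughly 1e-49, which is below the smallest positive float32 value (about 1.4e-45) but easily representable in float64. JAX computes in float32 unless `jax_enable_x64` is set; the dtypes are pinned explicitly here so the comparison does not depend on that flag. This illustrates the underflow only and does not identify which assertion in the real test fails.

```python
import jax.numpy as jnp
import numpy as np

# Assumed scheduler settings: 1000 training timesteps, linear beta schedule.
num_train_timesteps = 1000
beta_start, beta_end = 0.01, 0.2

# float32 (JAX's default dtype), pinned so jax_enable_x64 does not change the result.
betas_f32 = jnp.linspace(beta_start, beta_end, num_train_timesteps, dtype=jnp.float32)
alphas_cumprod_f32 = jnp.cumprod(1.0 - betas_f32)

# float64 (NumPy's default dtype) for comparison.
betas_f64 = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float64)
alphas_cumprod_f64 = np.cumprod(1.0 - betas_f64)

# The tail of the float32 product underflows to exactly zero; float64 keeps a tiny positive value.
print("float32 last value:", float(alphas_cumprod_f32[-1]))
print("float64 last value:", float(alphas_cumprod_f64[-1]))
print("float32 entries that underflowed:", int(jnp.sum(alphas_cumprod_f32 == 0)))
```

If this is the cause, running the JAX path with `jax_enable_x64` enabled, or comparing with an absolute tolerance, would be the usual ways to make the test pass.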
I don't think this needs to go into diffusers@main yet, since it would require users to install JAX just to use the PNDM scheduler. It's cool to have the tests, but I'd put the JAX implementation in a separate file rather than changing scheduling_pndm.py.
I agree, this is a proof of concept. Eventually a lot of work would be needed to make the tests independent of PyTorch, NumPy, and JAX; this PR can be referenced then.
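One possible starting point for framework-independent tests, sketched under assumptions: a hypothetical `to_numpy` helper converts NumPy arrays, JAX arrays, and PyTorch tensors to NumPy before comparison, so a single assertion can check every backend. The tolerance is exposed as a parameter because float32 backends may need looser bounds than float64 ones, as with the JAX precision issue mentioned earlier in this thread. The names `to_numpy` and `assert_outputs_close` are illustrative and not part of diffusers.

```python
import numpy as np


def to_numpy(x):
    """Hypothetical helper: convert a NumPy array, JAX array, or torch tensor to NumPy."""
    # torch tensors may require grad or live on a GPU; detach and move to CPU first.
    if hasattr(x, "detach") and hasattr(x, "cpu"):
        x = x.detach().cpu()
    # NumPy arrays, JAX arrays, CPU torch tensors, and plain lists all work with np.asarray.
    return np.asarray(x)


def assert_outputs_close(actual, expected, atol=1e-5, rtol=1e-5):
    """Hypothetical helper: compare scheduler outputs regardless of the backend that produced them."""
    np.testing.assert_allclose(to_numpy(actual), to_numpy(expected), atol=atol, rtol=rtol)


# Usage: a float32 result whose tiny tail underflowed to zero still matches a float64
# reference when a small absolute tolerance is allowed.
reference = np.array([1.0, 0.5, 1e-49], dtype=np.float64)
candidate = np.array([1.0, 0.5, 0.0], dtype=np.float32)
assert_outputs_close(candidate, reference, atol=1e-6)
print("outputs match within tolerance")
```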