michalk8

Results 40 issues of michalk8

Use `jax.scipy.linalg.cho_{factor,solve}`
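For context, a minimal sketch of the factor-then-solve pattern this issue refers to. The matrix and right-hand side are made-up illustrative data, not code from the library; only `jax.scipy.linalg.cho_factor` and `jax.scipy.linalg.cho_solve` are taken from the issue title.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve

# Illustrative symmetric positive-definite system A x = b (assumed data).
key = jax.random.PRNGKey(0)
m = jax.random.normal(key, (4, 4))
a = m @ m.T + 4.0 * jnp.eye(4)
b = jnp.arange(4.0)

# Factor once; the returned (matrix, lower) pair can be reused for many solves.
c_and_lower = cho_factor(a, lower=True)
x = cho_solve(c_and_lower, b)

# Largest absolute entry of the residual A x - b.
residual = jnp.max(jnp.abs(a @ x - b))
```

Keeping the factorization around is the main appeal when the same matrix is solved against several right-hand sides.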

enhancement

- [ ] fix missing types
- [ ] add mypy
- [ ] properly type PRNGKeys (can be done, see #172)

hackathon

Code to reproduce; most likely introduced in #310:

```python
import jax.numpy as jnp
import ott

x = jnp.ones((10, 12))
ott.geometry.geometry.Geometry(kernel_matrix=x).cost_matrix
```

Traceback:

```pytb
RecursionError Traceback (most recent call last)...
```

bug

To fix wrong links, e.g., missing `PointCloud` reference in [getting started tutorial](https://ott-jax.readthedocs.io/en/latest/tutorials/notebooks/basic_ot_between_datasets.html).

documentation

Meta issues to:
- [ ] remove `absl-py` remnants #344
- [ ] refactor slow tests
- [ ] increase coverage #254
- [ ] test on 16/64-bit?

tests

Remove the `initialize` methods from `tests` and refactor them as fixtures.
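A rough sketch of what such a refactor could look like. The helper `make_point_clouds`, the fixture `point_clouds`, and the test class are hypothetical names chosen for illustration; they are not taken from the actual test suite.

```python
import numpy as np
import pytest


def make_point_clouds(seed: int = 0, n: int = 10, m: int = 12, dim: int = 3):
    """Hypothetical helper: build two random point clouds for a test."""
    rng = np.random.default_rng(seed)
    return rng.normal(size=(n, dim)), rng.normal(size=(m, dim))


@pytest.fixture()
def point_clouds():
    # Stands in for a per-class `initialize` method; pytest passes the
    # return value to any test that names `point_clouds` as an argument.
    return make_point_clouds()


class TestPairwiseCost:

    def test_cost_shape(self, point_clouds):
        x, y = point_clouds
        # Squared Euclidean cost between every pair of points.
        cost = ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)
        assert cost.shape == (10, 12)
```

Compared with an `initialize` method, the fixture makes each test's inputs explicit in its signature and avoids shared mutable state on the test class.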

tests

LGTM, we can either push or add the sharp bits page, as you prefer _Originally posted by @marcocuturi in https://github.com/ott-jax/ott/pull/335#pullrequestreview-1345875461_

documentation

TODOs:
- [ ] remove type hints from docstrings
- [ ] unify style (`.` at the end, etc.)
- [ ] add pydocstyle pre-commit
- [ ] update CONTRIBUTING.md

documentation

To eliminate the need for specifying `tree_{un,}flatten`, see e.g., [flax struct](https://flax.readthedocs.io/en/latest/api_reference/flax.struct.html) for inspiration.

enhancement

Hi, @davidsebfischer , I've started working on the weighted model based on your notes (thanks a lot) and I think I've gotten most of the stuff in numpy right, though...