Refactor the DETR implementation to use the `optax` API.
Many functions and calls in the existing DETR implementation are deprecated in both Flax and OTT.
This PR includes two major changes (all confined to `projects/baselines/detr`):
- Migrates from `flax.optim` to `optax`.
- The Sinkhorn solver now uses updated `ott-jax` calls.
I have tested both configs for train/eval/checkpointing.
Hi @MasterSkepticista, I am experimenting with your modified DETR. I ran into many environment conflicts and tried to resolve them, but some difficult bugs remain. Could you share your environment as a reference? Thank you very much.
@durianer-D I created a new repo with a minimal implementation. It is also significantly faster to train than this PR. https://github.com/MasterSkepticista/detr.