This line prevents the use of jax.distributed.initialize after importing flashbax
https://github.com/instadeepai/flashbax/blob/f518aa8a795ca0c3341816e9f0268b6deb853de6/flashbax/buffers/prioritised_trajectory_buffer.py#L778
As the title says, you cannot call jax.devices (or similar functions) before calling jax.distributed.initialize. If people want to use flashbax in a distributed setup, it might be worthwhile to find an alternative, or simply not to validate devices that way. I don't think it's strictly necessary to check explicitly.
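One possible alternative, sketched under assumptions: if the device lookup happens at import time, it could be deferred until first use so that nothing queries JAX's backends before jax.distributed.initialize has run. The helper name `cpu_device` below is hypothetical and not part of flashbax; it only illustrates the lazy pattern, not what the linked line actually does.

```python
import functools


@functools.lru_cache(maxsize=None)
def cpu_device():
    """Return the first CPU device, resolved lazily on the first call.

    Hypothetical helper: defining it does not touch any JAX backend, so a
    caller can still run jax.distributed.initialize() before using it.
    """
    # Deferred import keeps this sketch importable without JAX installed;
    # the backend is only queried when the function is actually called.
    import jax

    # jax.devices accepts a backend name; "cpu" selects the host devices.
    return jax.devices("cpu")[0]
```

Usage would then be to call `jax.distributed.initialize(...)` first and only afterwards call `cpu_device()` (for example, inside a buffer constructor), so the ordering requirement is met as long as buffers are created after initialization.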
I don't know of an alternative. Wouldn't we expect people to call jax.distributed.initialize before creating buffers?
I think it's due to the import; however, it's possible this issue is not correct. I'll try to make a dummy script to check at some point.