Batched support for `NormalizingFlows.jl` and `Bijectors.jl`

Open yebai opened this issue 1 year ago • 1 comments

Introduce a batched mode to Bijectors.jl and to NormalizingFlows.jl, which is built on top of Bijectors.jl.

Put simply, we want to enable users to provide multiple inputs to the model simultaneously by “stacking” the parameters into a higher-dimensional array.

The implementation can take various forms. As a team of developers who care about both performance and user experience, we are open to different approaches and discussions. One possible approach is to develop a mechanism that signals the code to process the given input as a batch rather than as individual entries.
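To make the "signal" idea concrete, here is a minimal sketch in plain Julia. It does not use the Bijectors.jl API: the `Batch` wrapper and the `exp_with_logjac` function are hypothetical names invented for illustration, and the convention that the last dimension indexes the batch is an assumption. The wrapper type lets dispatch choose between the single-input and the batched code path.

```julia
# Hypothetical wrapper (not part of Bijectors.jl): marks an array whose
# columns are independent inputs rather than one matrix-valued input.
struct Batch{A<:AbstractArray}
    data::A
end

# Toy elementwise transform y = exp.(x) for a single vector input.
# The Jacobian is diagonal with entries exp.(x), so log|det J| = sum(x).
exp_with_logjac(x::AbstractVector) = (exp.(x), sum(x))

# Batched method: each column is treated as a separate input, so one
# log-determinant is returned per column.
function exp_with_logjac(b::Batch{<:AbstractMatrix})
    x = b.data
    return (Batch(exp.(x)), vec(sum(x; dims=1)))
end

# "Stack" two inputs into a 3×2 matrix and process them in one call.
xs = [[0.0, 1.0, 2.0], [1.0, 1.0, 1.0]]
X = reduce(hcat, xs)
Y, logjac = exp_with_logjac(Batch(X))
```

With this design a plain matrix keeps whatever meaning it already has, and only inputs explicitly wrapped in `Batch` take the batched path. Other signalling mechanisms (a keyword argument, a trait) would work equally well under the same assumption.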

A preliminary implementation can be found here.

(copied and pasted from here)

yebai, Mar 11 '25 12:03

AcceleratedKernels.map! could be an interesting alternative to batching.

yebai, Mar 26 '25 14:03