torchsde

Ridding `_check_2d`

Open · lxuechen opened this issue 3 years ago · 2 comments

I'm proposing to get rid of the 2D shape checks. These were added in #88 as part of v0.2.4.

These checks create a huge barrier for applications whose data is not naturally vector-valued, i.e. not naturally of shape `(batch_size, state_size)`. Flattening and unflattening would hurt efficiency tremendously.

lxuechen · Feb 07 '21 04:02

The reason they're there is that several parts of the internal code (for example `misc.batch_mvp` and `ForwardSDE.dg_ga_jvp`) implicitly assume there is a single batch dimension.
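To illustrate the kind of single-batch-dimension assumption meant here, the sketch below contrasts a batched matrix-vector product built on `torch.bmm` (which only accepts 3-D tensors) with one built on `torch.matmul` (which broadcasts over any number of leading batch dimensions). This assumes `misc.batch_mvp` is `bmm`-based, as its single-batch assumption suggests; the function names `batch_mvp_single` and `batch_mvp_general` are hypothetical and not torchsde's code.

```python
import torch


def batch_mvp_single(m: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Batched matrix-vector product assuming exactly one batch dimension.

    m: (B, d, k), v: (B, k) -> (B, d). torch.bmm requires 3-D inputs.
    """
    return torch.bmm(m, v.unsqueeze(dim=2)).squeeze(dim=2)


def batch_mvp_general(m: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Same product for any number of leading batch dimensions.

    m: (*batch, d, k), v: (*batch, k) -> (*batch, d).
    torch.matmul broadcasts over the leading (batch) dimensions.
    """
    return torch.matmul(m, v.unsqueeze(-1)).squeeze(-1)


# One batch dimension: both versions agree.
m = torch.randn(4, 3, 5)
v = torch.randn(4, 5)
same = torch.allclose(batch_mvp_single(m, v), batch_mvp_general(m, v))
print(same)  # True

# Two batch dimensions: only the matmul-based version works.
m2 = torch.randn(2, 4, 3, 5)
v2 = torch.randn(2, 4, 5)
out2 = batch_mvp_general(m2, v2)
print(out2.shape)  # torch.Size([2, 4, 3])
try:
    batch_mvp_single(m2, v2)
except RuntimeError as err:
    print("single-batch version fails:", err)
```

Moving helpers like this from `bmm` to `matmul` is one way the internals could stop caring how many batch dimensions there are.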

If we can go through and sort those out then I am in favour of this; I agree this is a wart. Ideally we'd be able to have `y0` take an arbitrary shape.

Off the top of my head, I think the only parts of the code that need to distinguish batch dimensions from channel dimensions are the creation of a default Brownian motion (which needs one sample per batch element but not one per channel) and `ForwardSDE.dg_ga_jvp`, so those would need some way of specifying that detail.
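As a sketch of what "some way of specifying that detail" could look like for the default Brownian motion, the snippet below assumes a hypothetical `batch_ndim` argument saying how many leading dimensions of `y0` are batch dimensions. The Brownian increments then take shape `(*batch_shape, brownian_size)`, independent of the state's channel shape. `split_shape`, `default_brownian_increment`, and `batch_ndim` are illustrative names, not torchsde API, and `dg_ga_jvp` is not covered here.

```python
import math

import torch


def split_shape(y0_shape, batch_ndim):
    """Hypothetical helper: split y0's shape into (batch_shape, channel_shape)."""
    if not 0 <= batch_ndim <= len(y0_shape):
        raise ValueError("batch_ndim must be between 0 and y0.dim()")
    return tuple(y0_shape[:batch_ndim]), tuple(y0_shape[batch_ndim:])


def default_brownian_increment(y0, batch_ndim, brownian_size, dt):
    """Hypothetical default: one Brownian increment per batch element.

    Returns N(0, dt) samples of shape (*batch_shape, brownian_size);
    the state's channel shape plays no role.
    """
    batch_shape, _ = split_shape(y0.shape, batch_ndim)
    noise = torch.randn(*batch_shape, brownian_size, dtype=y0.dtype, device=y0.device)
    return noise * math.sqrt(dt)


# 8 batch elements, each state a 3x16x16 tensor, driven by 2 noise channels.
y0 = torch.zeros(8, 3, 16, 16)
dW = default_brownian_increment(y0, batch_ndim=1, brownian_size=2, dt=0.01)
print(dW.shape)  # torch.Size([8, 2])

# Two batch dimensions (e.g. 4 x 8 samples) with a 5-dim state.
dW2 = default_brownian_increment(torch.zeros(4, 8, 5), batch_ndim=2, brownian_size=2, dt=0.01)
print(dW2.shape)  # torch.Size([4, 8, 2])
```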

In passing, why would flattening/unflattening hurt efficiency? It should be doable just by re-striding the tensor, which is cheap.
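A quick way to check the re-striding point: for a contiguous tensor, `reshape` returns a view over the same storage, so flattening to `(batch, -1)` and back copies no data. The caveat is that a non-contiguous input (for example after `permute`) cannot be viewed that way, so `reshape` silently copies. The helper names below are hypothetical.

```python
import torch


def flatten_channels(y: torch.Tensor) -> torch.Tensor:
    """Collapse all non-batch dims: (B, *channel_shape) -> (B, C)."""
    return y.reshape(y.size(0), -1)


def unflatten_channels(y_flat: torch.Tensor, channel_shape) -> torch.Tensor:
    """Inverse of flatten_channels: (B, C) -> (B, *channel_shape)."""
    return y_flat.reshape(y_flat.size(0), *channel_shape)


y0 = torch.randn(8, 3, 16, 16)  # contiguous
flat = flatten_channels(y0)
restored = unflatten_channels(flat, y0.shape[1:])

print(flat.shape)                          # torch.Size([8, 768])
print(flat.data_ptr() == y0.data_ptr())    # True: a view, no data copied
print(restored.data_ptr() == y0.data_ptr())  # True

# A permuted (non-contiguous) tensor cannot be flattened by re-striding alone,
# so reshape falls back to making a copy.
y_perm = y0.permute(0, 2, 3, 1)
flat_perm = flatten_channels(y_perm)
print(flat_perm.data_ptr() == y_perm.data_ptr())  # False: a copy was made
```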

patrick-kidger · Feb 07 '21 09:02

I am completely aware of why they are needed. I will come up with a design doc next weekend.

lxuechen · Feb 08 '21 17:02