
Mini-batch training using point set boundary condition

Open · mattragoza opened this pull request 3 years ago · 8 comments

This pull request implements BatchPointSetBC, as described in #744. It is a subclass of PointSetBC that takes the additional constructor arguments batch_size and shuffle, and uses a deepxde.data.sampler.BatchSampler to maintain internal state while iterating over the data points. Each call to collocation_points advances the iterator and returns the next batch of data points, and subsequent calls to error compute the boundary condition residual with respect to the data values of the current batch.
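For concreteness, here is a minimal sketch of the approach (assuming PointSetBC's (points, values, component) constructor and BatchSampler.get_next(batch_size) as in the current codebase; the actual PR may differ in details, and indexing self.values with numpy indices is only verified for the pytorch backend):

```python
import deepxde as dde
from deepxde.data.sampler import BatchSampler


class BatchPointSetBC(dde.icbc.PointSetBC):
    """Dirichlet BC on a point set, evaluated one mini-batch at a time."""

    def __init__(self, points, values, component=0, batch_size=None, shuffle=True):
        super().__init__(points, values, component=component)
        self.batch_size = batch_size
        # The sampler holds the iteration state (permutation, epoch position).
        self.batch_sampler = BatchSampler(len(points), shuffle=shuffle)
        self.batch_indices = None

    def collocation_points(self, X):
        # Advance the sampler and return the next batch of data points.
        self.batch_indices = self.batch_sampler.get_next(self.batch_size)
        return self.points[self.batch_indices]

    def error(self, X, inputs, outputs, beg, end):
        # Residual w.r.t. the data values of the current batch only.
        values = self.values[self.batch_indices]
        return outputs[beg:end, self.component : self.component + 1] - values
```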

I also created an example script examples/pinn_inverse/elliptic_inverse_field_batch.py and verified that the performance is similar to the non-mini-batch version.

mattragoza · Jun 24 '22 14:06

I am aware that the PR contains additional commits from my dev branch that were already merged in a previous PR. I'm not sure what the best practice is for omitting these commits from future pull requests given that this repo uses squash merging. Should I just create a separate branch for each pull request?

mattragoza · Jun 24 '22 14:06

> I am aware that the PR contains additional commits from my dev branch that were already merged in a previous PR. I'm not sure what the best practice is for omitting these commits from future pull requests given that this repo uses squash merging. Should I just create a separate branch for each pull request?

Yes, that would be better. It is fine this time.

lululxvi · Jun 24 '22 14:06

Which backends have you used for testing the code? tensorflow.compat.v1 and pytorch?

This will not work as you expect, because the training points will not be updated. Generating the training points at every iteration could be expensive, so once they are generated, they are cached and do not change:

https://github.com/lululxvi/deepxde/blob/7dc2285275d227393ab70adcd425d05502efec92/deepxde/data/pde.py#L167

https://github.com/lululxvi/deepxde/blob/7dc2285275d227393ab70adcd425d05502efec92/deepxde/data/pde.py#L255

To regenerate them, we should use

https://github.com/lululxvi/deepxde/blob/7dc2285275d227393ab70adcd425d05502efec92/deepxde/data/pde.py#L193

which is used in

https://github.com/lululxvi/deepxde/blob/7dc2285275d227393ab70adcd425d05502efec92/deepxde/callbacks.py#L487
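Paraphrasing the linked lines (a simplified sketch, not the verbatim source), the caching works roughly like this:

```python
# Simplified paraphrase of deepxde/data/pde.py at the linked commit.
@run_if_all_none("train_x", "train_y")
def train_next_batch(self, batch_size=None):
    # The decorator makes this a no-op while train_x/train_y are set,
    # so the training points are generated once and then reused.
    ...

def resample_train_points(self):
    # Clears the cache so the next train_next_batch() call regenerates
    # the points; this is what the resampling callback in callbacks.py uses.
    self.train_x, self.train_y = None, None
    self.train_next_batch()
```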

So, a correct implementation of mini-batch PointSetBC is tricky.

lululxvi · Jun 24 '22 15:06

I have only tested with pytorch so far, but I will test with other backends.

mattragoza · Jun 24 '22 15:06

You are right, I forgot about the caching of training points. In fact, I don't think even resample_train_points will trigger batch resampling. The batch is advanced by calling bc.collocation_points, which is called by pde.bc_points, and bc_points only runs when train_x_bc is None. That attribute is not reset by resample_train_points.
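For reference, bc_points is cached the same way as the training points (again a simplified paraphrase of deepxde/data/pde.py, not the verbatim source):

```python
import numpy as np
from deepxde.utils import run_if_all_none

# The decorator skips the body whenever train_x_bc is already set, so
# each BC's collocation_points() is only called while train_x_bc is None.
@run_if_all_none("train_x_bc")
def bc_points(self):
    x_bcs = [bc.collocation_points(self.train_x_all) for bc in self.bcs]
    self.num_bcs = list(map(len, x_bcs))
    self.train_x_bc = np.vstack(x_bcs)
    return self.train_x_bc
```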

mattragoza · Jun 24 '22 15:06

Maybe we can add an additional method pde.resample_bc_points to PDE, along with a BCResidualResampler callback?
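A sketch of that idea, modeled on the existing PDE residual-resampling callback (both resample_bc_points and BCResidualResampler are proposed names, not an existing deepxde API):

```python
from deepxde.callbacks import Callback


# Proposed method on deepxde.data.PDE (hypothetical):
def resample_bc_points(self):
    # Clear the cached BC points so that bc_points() runs again, which
    # calls each BC's collocation_points() and thus advances any
    # BatchPointSetBC to its next mini-batch.
    self.train_x_bc = None
    self.train_x, self.train_y = None, None
    self.train_next_batch()


class BCResidualResampler(Callback):
    """Proposed callback: re-draw BC batches every `period` epochs."""

    def __init__(self, period=1):
        super().__init__()
        self.period = period
        self.epochs_since_last = 0

    def on_epoch_end(self):
        self.epochs_since_last += 1
        if self.epochs_since_last < self.period:
            return
        self.epochs_since_last = 0
        self.model.data.resample_bc_points()
```

Usage would mirror the existing resampler, e.g. model.train(epochs=10000, callbacks=[BCResidualResampler(period=1)]).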

mattragoza · Jun 24 '22 15:06

> Maybe we can add an additional method pde.resample_bc_points to PDE, along with a BCResidualResampler callback?

Yes, adding a callback is a good idea.

lululxvi · Jun 24 '22 15:06

It might be better to combine BatchPointSetBC and PointSetBC into one class by checking whether batch_size is None.
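In the merged class, the batched methods would then branch on that, e.g.:

```python
def collocation_points(self, X):
    if self.batch_size is None:
        # Original PointSetBC behavior: use all data points.
        return self.points
    self.batch_indices = self.batch_sampler.get_next(self.batch_size)
    return self.points[self.batch_indices]
```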

lululxvi · Jun 24 '22 21:06