Mini-batch training using point set boundary condition
This pull request implements BatchPointSetBC, as described in #744. It is a subclass of PointSetBC that takes the additional constructor arguments batch_size and shuffle, and uses a deepxde.data.sampler.BatchSampler to maintain internal state while iterating over the data points. Each call to collocation_points advances the iterator and returns the next batch of data points; subsequent calls to error then compute the boundary condition residual against the data values for the current batch.
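For illustration, usage would look roughly like this. The data arrays are purely made up, and the import path for BatchPointSetBC is an assumption, not the final API:

```python
import numpy as np
# Assuming BatchPointSetBC is exported next to PointSetBC; adjust the import as needed.
from deepxde.icbc import BatchPointSetBC

# Hypothetical observation data (illustrative only).
observe_x = np.random.rand(100, 2)
observe_u = np.random.rand(100, 1)

# Same arguments as PointSetBC, plus batch_size and shuffle; each training
# iteration then uses only 32 of the 100 data points for this BC's residual.
observe_bc = BatchPointSetBC(
    observe_x, observe_u, component=0, batch_size=32, shuffle=True
)
```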
I also created an example script examples/pinn_inverse/elliptic_inverse_field_batch.py and verified that the performance is similar to the non-minibatch version.
I am aware that the PR contains additional commits from my dev branch that were already merged in a previous PR. I'm not sure what the best practice is for omitting these commits from future pull requests given that this repo uses squash merging. Should I just create a separate branch for each pull request?
Yes, that would be better. It is fine this time.
Which backends have you used for testing the code? tensorflow.compat.v1 and pytorch?
This will not work as you expect, because the training points will not be updated. Generating the training points every iteration could be expensive, so once they are generated, they are cached and do not change:
https://github.com/lululxvi/deepxde/blob/7dc2285275d227393ab70adcd425d05502efec92/deepxde/data/pde.py#L167
https://github.com/lululxvi/deepxde/blob/7dc2285275d227393ab70adcd425d05502efec92/deepxde/data/pde.py#L255
To re-generate the training points, we should use
https://github.com/lululxvi/deepxde/blob/7dc2285275d227393ab70adcd425d05502efec92/deepxde/data/pde.py#L193
which is used in
https://github.com/lululxvi/deepxde/blob/7dc2285275d227393ab70adcd425d05502efec92/deepxde/callbacks.py#L487
So, a correct implementation of mini-batch PointSetBC is tricky.
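To make the caching issue concrete, here is a small sketch of the behavior being described (not a test from the PR; the BatchPointSetBC import path and the exact setup are assumptions and may need adjusting):

```python
import numpy as np
import deepxde as dde
from deepxde.icbc import BatchPointSetBC  # the new class from this PR; adjust the import as needed

# Minimal illustrative setup.
geom = dde.geometry.Interval(0, 1)

def pde(x, u):
    return dde.grad.hessian(u, x)

observe_x = np.linspace(0, 1, 100)[:, None]
observe_bc = BatchPointSetBC(observe_x, np.sin(observe_x), batch_size=10)

data = dde.data.PDE(geom, pde, [observe_bc], num_domain=50)
first = data.bc_points()   # returns the cached train_x_bc (built once during setup)
second = data.bc_points()  # still the same cached array; the batch sampler never advances
print(np.array_equal(first, second))  # True: every training iteration reuses the same batch
```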
I have only tested with pytorch so far, but I will test with other backends.
You are right, I forgot about the caching of the training points. In fact, I don't think even resample_train_points will trigger the batch resampling. The batch is advanced by calling bc.collocation_points, which is called by pde.bc_points, which only runs when train_x_bc is None; that attribute is not set to None by resample_train_points.
Maybe we can add an additional method pde.resample_bc_points to PDE, along with a BCResidualResampler callback?
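A rough sketch of what this could look like; the names and details are only suggestions, modeled on the existing point-resampling callback linked above:

```python
from deepxde import callbacks


# On the PDE data class: clear the cached BC points so that bc_points (and thus
# each BC's collocation_points) runs again and advances the batch sampler.
def resample_bc_points(self):
    self.train_x_bc = None
    self.train_x, self.train_y = None, None
    self.train_next_batch()


class BCResidualResampler(callbacks.Callback):
    """Re-batch the BC (data) points every `period` iterations."""

    def __init__(self, period=1):
        super().__init__()
        self.period = period
        self.epochs_since_last_resample = 0

    def on_epoch_end(self):
        self.epochs_since_last_resample += 1
        if self.epochs_since_last_resample < self.period:
            return
        self.epochs_since_last_resample = 0
        self.model.data.resample_bc_points()
```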
> Maybe we can add an additional method pde.resample_bc_points to PDE, along with a BCResidualResampler callback?
Yes, adding a callback is a good idea.
It might be better to combine BatchPointSetBC and PointSetBC into one class, by checking if batch_size is None.
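Something along these lines, assuming batch_size defaults to None (a sketch of the branching only, not the final implementation; the real class also handles dtype conversion, multiple components, and backend tensors):

```python
import numpy as np
from deepxde.data.sampler import BatchSampler


class PointSetBC:
    def __init__(self, points, values, component=0, batch_size=None, shuffle=True):
        self.points = np.asarray(points)
        self.values = np.asarray(values)
        self.component = component
        self.batch_size = batch_size
        if batch_size is not None:  # mini-batch mode
            self.batch_sampler = BatchSampler(len(self.points), shuffle=shuffle)
            self.batch_indices = None

    def collocation_points(self, X):
        if self.batch_size is None:
            return self.points
        # Advance the sampler and remember the indices for error().
        self.batch_indices = self.batch_sampler.get_next(self.batch_size)
        return self.points[self.batch_indices]

    def error(self, X, inputs, outputs, beg, end):
        values = (
            self.values
            if self.batch_size is None
            else self.values[self.batch_indices]
        )
        return outputs[beg:end, self.component : self.component + 1] - values
```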