FBPINNs
FBPINNs copied to clipboard
Modifying the wave 3D problem
Thank you for the innovative contribution!
I tried modifying the wave 3D problem to use the following initial and boundary conditions:
$u(x,y,0) = 0$ and $u(0,0,t) = 2\sin(2\pi t)$ (a time-dependent source),
as follows:
def boundary_condition(self, x, u, dudt, d2udx2, d2udy2, d2udt2, sd):
    """Apply a hard-constraint ansatz to the network output and its derivatives.

    The intended ansatz (see the first comment below) is
        u = tanh^2((t-0)/sd)*NN + sigmoid((d-t)/sd)*G(x, y),
    where G is a Gaussian centred at (mx, my) = (0, 0) with widths
    self.source_sd.

    NOTE(review): the Gaussian is multiplied by 0 on the ``exp`` line. As a
    result f, d2udfx2, d2udfy2 and every term they feed are zero, and the
    effective ansatz is u_new = t_2*u. This only imposes the zero initial
    condition u(x,y,0) = 0 (assuming tanh2_2 returns tanh^2((t-0)/sd), as the
    comment below says). The point-source condition u(0,0,t) = 2 sin(2 pi t)
    is not enforced anywhere in this function.

    Args:
        x: collocation points. Column 0 is used as x, column 1 as y and
            column 2 as t.
        u, dudt, d2udx2, d2udy2, d2udt2: raw network output and its
            derivatives at x.
        sd: width passed to the tanh^2 / sigmoid constraint helpers.

    Returns:
        (u_new, dudt, d2udx2_new, d2udy2_new, d2udt2_new). ``dudt`` is
        returned unchanged because first-order gradients are not updated
        (they are not needed for the loss).
    """
    # Apply u = tanh^2((t-0)/sd)*NN + sigmoid((d-t)/sd)*exp( -(1/2)((x/sd)^2+(y/sd)^2) ) ansatz
    # presumably returns tanh^2(t/sd) and its 1st/2nd t-derivatives (helper not shown here)
    t_2, dudt_2, d2udt_22 = boundary_conditions.tanh2_2(x[:,2:3], 0, sd)
    s, _, d2uds2 = boundary_conditions.sigmoid_2(-x[:,2:3], -2*sd, 0.2*sd)# beware (!) this gives correct 2nd order gradients but negative 1st order (sign flip!)
    mx = my = 0;  # Gaussian centre
    sx = sy = self.source_sd  # Gaussian widths
    xnx, xny = (x[:,0:1]-mx)/sx, (x[:,1:2]-my)/sy  # normalised spatial coordinates
    #exp = torch.exp(-0.5*(xnx**2 + xny**2))
    # The trailing "*0" zeroes the Gaussian, so the Gaussian computations below are wasted work.
    exp = torch.exp(-0.5*(xnx**2 + xny**2))*0 #IC = 0 instead of exp
    #Initial GP
    f = exp
    # analytic second spatial derivatives of the Gaussian: (1/s^2)*(xn^2 - 1)*exp
    d2udfx2 = (1/sx**2) * ((xnx**2) - 1)*exp
    d2udfy2 = (1/sy**2) * ((xny**2) - 1)*exp
    # Product rule for u_new = t_2*u + s*f, where t_2 and s depend only on
    # column 2 (t) and f only on columns 0 and 1 (x, y)
    u_new = t_2*u + s*f
    d2udx2_new = t_2*d2udx2 + s*d2udfx2
    d2udy2_new = t_2*d2udy2 + s*d2udfy2
    d2udt2_new = d2udt_22*u + 2*dudt_2*dudt + t_2*d2udt2 + d2uds2*f
    #Zero Ic and BC
    # u_new = t_2*u *0
    # d2udx2_new = t_2*d2udx2
    # d2udy2_new = t_2*d2udy2
    # d2udt2_new = d2udt_22*u
    return u_new, dudt, d2udx2_new, d2udy2_new, d2udt2_new# skip updating first order gradients (not needed for loss)
I also changed the finite-difference (FD) simulation file as follows:
import numpy as np
import time
from seismic_CPML_helper import get_dampening_profiles
# todo: is this faster in parallel with np.roll?
def seismicCPML2D_wS(NX,
                     NY,
                     NSTEPS,
                     DELTAX,
                     DELTAY,
                     DELTAT,
                     NPOINTS_PML,
                     velocity,
                     density,
                     initial_pressures,
                     f0=20.,
                     dtype=np.float32,
                     output_wavefields=True,
                     gather_is=None,
                     source_amp=1.,
                     source_freq=1.):
    """Run a 2D acoustic C-PML finite-difference simulation driven by a source term.

    The wavefield starts from zero initial conditions; ``initial_pressures`` is
    NOT used as the initial state. Instead, ``initial_pressures[1]`` is used as
    the spatial profile of a source that is modulated in time by
    ``source_amp * sin(2*pi*source_freq*t)`` and injected at every step.

    Args:
        NX, NY: number of grid points in x and y.
        NSTEPS: number of time steps.
        DELTAX, DELTAY, DELTAT: grid spacings and time step.
        NPOINTS_PML: thickness of the C-PML layer, in grid points.
        velocity, density: model arrays, indexed as [x, y] and combined
            elementwise with (NX, NY) arrays below.
        initial_pressures: sequence whose element [1] is the (NX, NY) spatial
            source profile.
        f0: frequency used to set ALPHA_MAX_PML.
        dtype: floating-point dtype of the simulation arrays.
        output_wavefields: if True, return the full (NSTEPS, NX, NY) wavefield.
        gather_is: optional (n_receivers, 2) integer array of receiver indices.
        source_amp: amplitude of the time modulation (default 1, as before).
        source_freq: frequency of the time modulation (default 1, as before).

    Returns:
        ``[wavefields, gather]``. Each entry is None when not requested.

    Raises:
        Exception: if the Courant number exceeds 1, if NPOWER < 1, or if the
            solution blows up.
    """
    ## INPUT PARAMETERS
    velocity = velocity.astype(dtype)
    density = density.astype(dtype)
    output_gather = gather_is is not None

    K_MAX_PML = 1.
    ALPHA_MAX_PML = 2.*np.pi*(f0/2.)  # from Festa and Vilotte
    NPOWER = 2.  # power to compute d0 profile
    Rcoef = 0.001
    STABILITY_THRESHOLD = 1e25
    PRINT_EVERY = 10000  # report timing every PRINT_EVERY steps

    # STABILITY CHECKS
    # 2D CFL condition: max(v) * DELTAT * sqrt(1/DELTAX**2 + 1/DELTAY**2) <= 1
    courant_number = np.max(velocity) * DELTAT * np.sqrt(1/(DELTAX**2) + 1/(DELTAY**2))
    if courant_number > 1.: raise Exception("ERROR: time step is too large, simulation will be unstable %.2f"%(courant_number))
    if NPOWER < 1: raise Exception("ERROR: NPOWER must be greater than 1")

    # GET DAMPENING PROFILES
    [[a_x, a_x_half, b_x, b_x_half, K_x, K_x_half],
     [a_y, a_y_half, b_y, b_y_half, K_y, K_y_half]] = get_dampening_profiles(velocity, NPOINTS_PML, Rcoef, K_MAX_PML, ALPHA_MAX_PML, NPOWER, DELTAT, DELTAS=(DELTAX, DELTAY), dtype=dtype, qc=False)

    # INITIALISE ARRAYS
    kappa = density*(velocity**2)

    # zero initial conditions
    pressure_present = np.zeros((NX, NY), dtype=dtype)
    pressure_past = np.zeros((NX, NY), dtype=dtype)

    memory_dpressure_dx = np.zeros((NX, NY), dtype=dtype)
    memory_dpressure_dy = np.zeros((NX, NY), dtype=dtype)
    memory_dpressurexx_dx = np.zeros((NX, NY), dtype=dtype)
    memory_dpressureyy_dy = np.zeros((NX, NY), dtype=dtype)

    if output_wavefields: wavefields = np.zeros((NSTEPS, NX, NY), dtype=dtype)
    if output_gather: gather = np.zeros((gather_is.shape[0], NSTEPS), dtype=dtype)

    # precompute density_half arrays
    density_half_x = np.pad(0.5 * (density[1:NX,:]+density[:NX-1,:]), [[0,1],[0,0]], mode="edge")
    density_half_y = np.pad(0.5 * (density[:,1:NY]+density[:,:NY-1]), [[0,0],[0,1]], mode="edge")

    # Spatial source profile (a Gaussian pulse passed from the main file, per
    # the original note). It is loop-invariant, so convert it once here.
    source_profile = initial_pressures[1].astype(dtype)

    def source_time_function(step):
        """Return source_amp * sin(2*pi*source_freq*t) at t = step*DELTAT."""
        t = step*DELTAT
        # NB: 2*pi*freq*t (the original used np.pi**freq, which is only
        # correct when freq == 1)
        return source_amp*np.sin(2*np.pi*source_freq*t)

    # RUN SIMULATION
    start = time.time()
    for it in range(NSTEPS):

        # compute the first spatial derivatives divided by density
        value_dpressure_dx = np.pad((pressure_present[1:NX,:]-pressure_present[:NX-1,:]) / DELTAX, [[0,1],[0,0]], mode="constant", constant_values=0.)
        value_dpressure_dy = np.pad((pressure_present[:,1:NY]-pressure_present[:,:NY-1]) / DELTAY, [[0,0],[0,1]], mode="constant", constant_values=0.)

        memory_dpressure_dx = b_x_half * memory_dpressure_dx + a_x_half * value_dpressure_dx
        memory_dpressure_dy = b_y_half * memory_dpressure_dy + a_y_half * value_dpressure_dy

        value_dpressure_dx = value_dpressure_dx / K_x_half + memory_dpressure_dx
        value_dpressure_dy = value_dpressure_dy / K_y_half + memory_dpressure_dy

        pressure_xx = value_dpressure_dx / density_half_x
        pressure_yy = value_dpressure_dy / density_half_y

        # compute the second spatial derivatives
        value_dpressurexx_dx = np.pad((pressure_xx[1:NX,:]-pressure_xx[:NX-1,:]) / DELTAX, [[1,0],[0,0]], mode="constant", constant_values=0.)
        value_dpressureyy_dy = np.pad((pressure_yy[:,1:NY]-pressure_yy[:,:NY-1]) / DELTAY, [[0,0],[1,0]], mode="constant", constant_values=0.)

        memory_dpressurexx_dx = b_x * memory_dpressurexx_dx + a_x * value_dpressurexx_dx
        memory_dpressureyy_dy = b_y * memory_dpressureyy_dy + a_y * value_dpressureyy_dy

        value_dpressurexx_dx = value_dpressurexx_dx / K_x + memory_dpressurexx_dx
        value_dpressureyy_dy = value_dpressureyy_dy / K_y + memory_dpressureyy_dy

        dpressurexx_dx = value_dpressurexx_dx
        dpressureyy_dy = value_dpressureyy_dy

        # Apply the time-evolution scheme everywhere, including edge points that
        # were not computed above. Those values are undefined, but they are
        # overwritten by the Dirichlet conditions below. The last term injects
        # the source: spatial profile times time modulation.
        pressure_future = - pressure_past \
            + 2 * pressure_present \
            + DELTAT*DELTAT*(dpressurexx_dx+dpressureyy_dy)*kappa \
            + DELTAT*DELTAT*(source_profile*source_time_function(it))

        # Dirichlet conditions at the bottom of the C-PML layers, which keep
        # C-PML stable at long times
        pressure_future[0,:] = pressure_future[-1,:] = 0.
        pressure_future[:,0] = pressure_future[:,-1] = 0.

        if output_wavefields: wavefields[it,:,:] = np.copy(pressure_present)
        if output_gather:
            gather[:,it] = np.copy(pressure_present[gather_is[:,0], gather_is[:,1]])

        # check stability of the code, exit if unstable
        if(np.max(np.abs(pressure_present)) > STABILITY_THRESHOLD):
            raise Exception('code became unstable and blew up')

        # move new values to old values (the present becomes the past, the future becomes the present)
        pressure_past = pressure_present
        pressure_present = pressure_future

        if it % PRINT_EVERY == 0 and it != 0:
            # average wall time per step over the last PRINT_EVERY steps
            rate = (time.time()-start)/PRINT_EVERY
            print("[%i/%i] %.2e s per step"%(it, NSTEPS, rate))
            start = time.time()

    output = [None, None]
    if output_wavefields: output[0]=wavefields
    if output_gather: output[1]=gather
    return output
Mainly, I am attempting to change the IC and add a time-dependent source term to the equation, so that it becomes:
$$\frac{\partial^2 u}{\partial x^2} + \frac{\partial^2 u}{\partial y^2} - \frac{1}{c^2}\frac{\partial^2 u}{\partial t^2} = S(x,y,t)$$
where,
# Source term S(x, y, t) = Amp * GP(x, y) * sin(2*pi*freq*t), where GP is a
# Gaussian centred at (mx, my) with widths (sx, sy).
# Fragment: it is meant to run inside a method (it uses self) and uses column 0
# of x as x, column 1 as y and column 2 as t.
Amp = 1  # source amplitude
freq = 1  # temporal frequency of the source
sx = sy =self.source_sd  # Gaussian widths
mx = 0; my = 0  # Gaussian centre
GP = torch.exp(-0.5*(( (x[:,0:1]-mx)/sx)**2 + ((x[:,1:2]-my)/sy)**2 ))
S = Amp * GP * torch.sin(2*np.pi*freq*x[:,2:3]) #The Source function
but the results are as shown in the image. So, my questions are:
- How can I implement my specified boundary and initial conditions in a better way than the one I tried (if my attempt was correct)? I don't fully understand how to use the implemented boundary condition helper functions to implement my specific equations.
$$\frac{\partial^2 u}{\partial x^2} + \frac{\partial^2 u}{\partial y^2} - \frac{1}{c^2}\frac{\partial^2 u}{\partial t^2} = S(x,y,t)$$
Initial and boundary conditions:
$$u(x,y,0) = 0$$
$$u(0,0,t) = 2\sin(2\pi t) \quad \text{(time-dependent source)}$$
The results in the image were obtained with these batch sizes:
batch_size = (30,30,30)  # training batch shape used for the results shown (presumably points per x, y, t)
batch_size_test = (40,40,15)  # test batch shape, reduced to fit in GPU memory
because of the limited memory on my GPU.
- Does this affect the results? If so, how can I increase `batch_size_test` without getting an out-of-memory (OOM) error?
Thanks again! Looking forward to your reply.
Also, it would be helpful to know how to change the source function from
$$\sin(2\pi x) \qquad (1)$$
to
$$\sin(2\pi t) \qquad (2)$$
Simply modifying the initial condition that contains the original expression (1) only satisfies the condition at $t = 0$, not at all other time instances.
Hi @engsbk, sorry for the slow reply. Please check out the latest FBPINN release: it is a major update, and you no longer need to update the FBPINN gradients by hand when applying constraining operators, because this is now done automatically using autodiff. Your workflow should therefore be much simpler. Memory performance has also been improved, so you may be able to train and test with more points now.
Hi @engsbk, I am also interested in implementing a source term. I have done it for the 1D wave equation by adding a term to physics_loss, but 2D seems problematic, and I am not sure what the ansatz should look like. @benmoseley, could you give an example of how a source term with zero ICs should be implemented? Thanks!
Here is some code for the (2+1)D wave equation with zero ICs and a source term. Please ignore the exact solution (it is only a random-noise placeholder); to compare against a reference solution you would need to add, e.g., finite-difference (FD) modelling code.
import jax
import jax.numpy as jnp
import numpy as np
from fbpinns.domains import RectangularDomainND
from fbpinns.problems import Problem
from fbpinns.decompositions import RectangularDecompositionND
from fbpinns.networks import FCN
from fbpinns.constants import Constants, get_subdomain_ws
from fbpinns.trainers import FBPINNTrainer, PINNTrainer
class WaveEquation3D(Problem):
    """Time-dependent (2+1)D wave equation with constant velocity c and a source.

    Solves

        u_xx + u_yy - (1/c^2) u_tt = s(x, y, t)

    subject to zero initial conditions

        u(x, y, 0) = 0,    u_t(x, y, 0) = 0.

    Both initial conditions are imposed exactly (hard constraints) in
    ``constraining_fn``. The source ``s`` is defined in ``loss_fn``.
    """

    @staticmethod
    def init_params(c=1, sd=1):
        """Return (static_params, trainable_params); there are no trainable params."""
        static_params = dict(dims=(1, 3), c=c, sd=sd)
        return static_params, {}

    @staticmethod
    def sample_constraints(all_params, domain, key, sampler, batch_shapes):
        """Sample interior collocation points and request u_xx, u_yy, u_tt there."""
        x_phys = domain.sample_interior(all_params, key, sampler, batch_shapes[0])
        # (output index, (input dims to differentiate by)) for d2u/dx2, d2u/dy2, d2u/dt2
        second_derivs = tuple((0, (dim, dim)) for dim in range(3))
        return [[x_phys, second_derivs],]

    @staticmethod
    def constraining_fn(all_params, x_batch, u):
        """Multiply u by tanh^2(c t / (2 sd)) so that u and u_t vanish at t = 0."""
        problem = all_params["static"]["problem"]
        c, sd = problem["c"], problem["sd"]
        t = x_batch[:, 2:3]
        envelope = jax.nn.tanh(c*t/(2*sd))**2  # zero value and zero slope at t = 0
        return envelope*u

    @staticmethod
    def loss_fn(all_params, constraints):
        """Return the mean squared residual of the forced wave equation."""
        problem = all_params["static"]["problem"]
        c, sd = problem["c"], problem["sd"]
        x_batch, uxx, uyy, utt = constraints[0]
        x, y, t = (x_batch[:, i:i+1] for i in range(3))
        e = -0.5*(x**2 + y**2 + t**2)/(sd**2)
        # Ricker-style source centred at (x, y, t) = (0, 0, 0) with width sd
        source = 2e3*(1+e)*jnp.exp(e)
        residual = (uxx + uyy) - (1/c**2)*utt - source
        return jnp.mean(residual**2)

    @staticmethod
    def exact_solution(all_params, x_batch, batch_shape):
        """Placeholder: return fixed-seed standard-normal noise, not a real solution.

        Replace with, e.g., a finite-difference reference to get meaningful test metrics.
        """
        return jax.random.normal(jax.random.PRNGKey(0), (x_batch.shape[0], 1))
# Subdomain centres: 5 per dimension over x in [-1, 1], y in [-1, 1], t in [0, 1]
subdomain_xs = [np.linspace(-1,1,5), np.linspace(-1,1,5), np.linspace(0,1,5)]
# Subdomain widths derived from the centres (1.9 presumably sets the overlap factor)
subdomain_ws = get_subdomain_ws(subdomain_xs, 1.9)
# Run configuration. NB: `c` here is the Constants object, not the wave speed
# (the wave speed is the `c=1` entry in problem_init_kwargs).
c = Constants(
    run="test",
    domain=RectangularDomainND,
    domain_init_kwargs=dict(
        xmin=np.array([-1,-1,0]),
        xmax=np.array([1,1,1]),
    ),
    problem=WaveEquation3D,
    problem_init_kwargs=dict(
        c=1, sd=0.1,
    ),
    decomposition=RectangularDecompositionND,
    decomposition_init_kwargs=dict(
        subdomain_xs=subdomain_xs,
        subdomain_ws=subdomain_ws,
        unnorm=(0.,1.),
    ),
    network=FCN,
    network_init_kwargs=dict(
        layer_sizes=[3,32,1],  # per-subdomain network: 3 inputs (x, y, t) -> 1 output
    ),
    ns=((50,50,50),),  # training points for the single physics constraint (presumably per x, y, t)
    n_test=(100,100,5),  # test grid shape (presumably per x, y, t)
    n_steps=5000,
    optimiser_kwargs=dict(learning_rate=1e-3),
    summary_freq=200,
    test_freq=200,
    show_figures=True,
    clear_output=True,
)

# To train an FBPINN instead, uncomment these two lines:
#run = FBPINNTrainer(c)
#run.train()

# Train a plain PINN baseline, using one larger network in place of the per-subdomain network
c["network_init_kwargs"] = dict(layer_sizes=[3,64,64,1])
run = PINNTrainer(c)
run.train()