FBPINNs icon indicating copy to clipboard operation
FBPINNs copied to clipboard

problem with non-constant boundary condition

Open arty0mkha opened this issue 8 months ago • 3 comments

Hello. I tried to solve the 2D Laplace equation div(grad(p)) = 0 on the domain 0<x<L, 0<y<H with Neumann boundary conditions:

  1. x=0,L: dp/dx = psi(y), where psi(y) = 1/c if |y-H/2|<c/2, else 0
  2. y=0,H dp/dy = 0

I used weak (soft) boundary conditions: loss = physics_loss + weight*bc_loss, where bc_loss = sum_i( MSE( dp/dn[i], dp_FBPINN/dn[i] ) )

However, instead of the boundary conditions being fulfilled (i.e. the derivatives at the left and right boundaries being equal to the psi function), FBPINN "averaged" the psi values: at x=0,L it learned dp_FBPINN/dx = 1/L. For this constant (1/L) the value of MSE(psi(y) - CONST) is minimal, but I need dp_FBPINN/dx to be a function of y, not a constant. (See attached image image_2024-06-16_13-36-39.) In the picture c = 0.2 and the psi function was psi(y) = 1 if |y-H/2|<c/2, else 0, so there dp_FBPINN/dx = CONST = c/L, not 1/L.

my code:

class Laplace2D(Problem):
    """Solves the 2D Laplace equation with Neumann boundary conditions.

        div( mobility * grad p ) = 0   on (0, L) x (0, H)
        dp/dx = psi(y)                 at x = 0 and x = L
        dp/dy = 0                      at y = 0 and y = H

    where psi(y) = 1/chi if |y - H/2| < chi/2, else 0, so the net flux
    through each vertical boundary is 1.

    The boundary conditions are imposed weakly:
        loss = physics_loss + weight * bc_loss

    BUG FIX vs the original posting: all boundary target arrays are now
    COLUMN vectors of shape (N, 1), matching the shape of the predicted
    derivatives returned by the trainer.  The original used 1D targets of
    shape (N,), so `(N,1) - (N,)` broadcast to an (N, N) matrix inside
    `loss_fn`; the mean-squared error of that matrix is minimised by a
    CONSTANT derivative equal to the mean of psi -- exactly the
    "averaging" behaviour reported in this issue.
    """

    @staticmethod
    def init_params(chi=0.2):
        """Return (static_params, trainable_params).

        chi: width of the strip |y - H/2| < chi/2 on which dp/dx = 1/chi.
        """
        static_params = {
            "dims": (1, 2),  # (number of solution components, number of input dims)
            "chi": chi,
        }
        return static_params, {}

    @staticmethod
    def sample_constraints(all_params, domain, key, sampler, batch_shapes):
        """Sample the interior (physics) batch and the four boundary batches.

        Returns a list of constraints, each of the form
        [x_batch, (target arrays...,) required_ujs].  The constraint order
        (interior, x=0, x=L, y=0, y=H) matches the unpacking in loss_fn.
        """
        chi = all_params["static"]["problem"]["chi"]

        # physics loss: interior collocation points
        x_batch_phys = domain.sample_interior(all_params, key, sampler, batch_shapes[0])
        required_ujs_phys = (
            (0, (0,)),    # p_x
            (0, (1,)),    # p_y
            (0, (0, 0)),  # p_xx
            (0, (1, 1)),  # p_yy
        )

        Nx, Ny = batch_shapes[0]
        # NOTE(review): assumes a grid sampler whose first/last rows are the
        # domain corners -- TODO confirm this holds for non-grid samplers.
        x_min, x_max = x_batch_phys[0, 0], x_batch_phys[-1, 0]
        y_min, y_max = x_batch_phys[0, 1], x_batch_phys[-1, 1]

        x = np.linspace(x_min, x_max, Nx)
        y = np.linspace(y_min, y_max, Ny)

        # psi(y) as a COLUMN vector (Ny, 1) -- see class docstring for why
        # the (N, 1) shape is essential.
        psi = np.where(np.abs(y - y_max / 2) < chi / 2, 1 / chi, 0).reshape(-1, 1)

        # x = 0 (left boundary): dp/dx = psi(y)
        batch_boundary_left = np.vstack((x[0] * np.ones_like(y), y)).T
        p_x_left = psi
        required_ujs_left = (
            (0, (0,)),  # p_x
        )
        # x = L (right boundary): dp/dx = psi(y)
        batch_boundary_right = np.vstack((x[-1] * np.ones_like(y), y)).T
        p_x_right = psi
        required_ujs_right = (
            (0, (0,)),  # p_x
        )
        # y = 0 (bottom boundary): dp/dy = 0
        batch_boundary_bottom = np.vstack((x, y[0] * np.ones_like(x))).T
        p_y_bottom = np.zeros((Nx, 1))
        required_ujs_bottom = (
            (0, (1,)),  # p_y
        )
        # y = H (top boundary): dp/dy = 0
        batch_boundary_top = np.vstack((x, y[-1] * np.ones_like(x))).T
        p_y_top = np.zeros((Nx, 1))
        required_ujs_top = (
            (0, (1,)),  # p_y
        )

        return [
            [x_batch_phys, required_ujs_phys],
            [batch_boundary_left, p_x_left, required_ujs_left],
            [batch_boundary_right, p_x_right, required_ujs_right],
            [batch_boundary_bottom, p_y_bottom, required_ujs_bottom],
            [batch_boundary_top, p_y_top, required_ujs_top],
        ]

    @staticmethod
    def constraining_fn(all_params, x_batch, u):
        """No hard constraint -- boundary conditions are imposed weakly in loss_fn."""
        return u

    @staticmethod
    def loss_fn(all_params, constraints):
        """Return physics_loss + weight * bc_loss.

        Every boundary residual is an (N, 1) - (N, 1) subtraction; both
        operands must stay column vectors so no (N, N) broadcast occurs.
        """
        # interior PDE residual: div( mobility * grad p ) = 0
        x_batch, p_x, p_y, p_xx, p_yy = constraints[0]

        mobility = 1
        mobility_x = 0
        mobility_y = 0
        residual = mobility * (p_xx + p_yy) + mobility_x * p_x + mobility_y * p_y
        physics_loss = jnp.mean(residual**2)

        bc_loss = 0
        # x = 0: dp/dx = psi(y)
        _, p_x_left, p_x_left_pred = constraints[1]
        bc_loss += jnp.mean((p_x_left_pred - p_x_left) ** 2)
        # x = L: dp/dx = psi(y)
        _, p_x_right, p_x_right_pred = constraints[2]
        bc_loss += jnp.mean((p_x_right_pred - p_x_right) ** 2)
        # y = 0: dp/dy = 0
        _, p_y_bottom, p_y_bottom_pred = constraints[3]
        bc_loss += jnp.mean((p_y_bottom_pred - p_y_bottom) ** 2)
        # y = H: dp/dy = 0
        _, p_y_top, p_y_top_pred = constraints[4]
        bc_loss += jnp.mean((p_y_top_pred - p_y_top) ** 2)

        # large fixed weight to enforce the (weak) Neumann conditions
        return physics_loss + (10**7) * bc_loss

    @staticmethod
    def exact_solution(all_params, x_batch, batch_shape):
        """Truncated Fourier-series reference solution.

        NOTE(review): L and H are taken from the last row of x_batch, which
        assumes an ordered grid batch whose final point is the corner (L, H)
        -- confirm for the samplers actually used.
        """
        x, y = x_batch[:, 0:1], x_batch[:, 1:2]
        sol = np.zeros((x.shape[0], 1))
        N = 10  # number of series terms retained
        L, H = x[-1], y[-1]
        chi = all_params["static"]["problem"]["chi"]
        for n in range(1, N):
            sol += 2/(np.pi*n*chi)*((-1)**(n%2))*np.sin(np.pi*n*chi/H)*np.cos(2*np.pi*n*y/H)*np.sinh(2*np.pi*n*(x-L/2)/H)/(2*np.pi*n/H*np.cosh(2*np.pi*n*L/2/H))
        sol += x/H
        return sol
# ---- problem / domain configuration -----------------------------------
domain = RectangularDomainND
domain_init_kwargs = {
    "xmin": np.array([0, 0]),  # unit square [0,1]^2
    "xmax": np.array([1, 1]),
}

problem = Laplace2D()
problem_init_kwargs = {"chi": 0.4}  # width of the Neumann flux strip

# ---- rectangular domain decomposition: 15 x 15 overlapping subdomains --
decomposition = RectangularDecompositionND
decomposition_init_kwargs = {
    "subdomain_xs": [np.linspace(0, 1, 15), np.linspace(0, 1, 15)],
    "subdomain_ws": get_subdomain_ws(
        [np.linspace(0, 1, 15), np.linspace(0, 1, 15)], 2
    ),
    "unnorm": (0.0, 1.0),
}

# ---- network: one small fully-connected net per subdomain --------------
network = FCN
network_init_kwargs = {"layer_sizes": [2, 32, 1]}

# ---- assemble trainer constants and run --------------------------------
c = Constants(
    domain=domain,
    domain_init_kwargs=domain_init_kwargs,
    problem=problem,
    problem_init_kwargs=problem_init_kwargs,
    decomposition=decomposition,
    decomposition_init_kwargs=decomposition_init_kwargs,
    network=network,
    network_init_kwargs=network_init_kwargs,
    ns=((50, 50),),     # training collocation grid
    n_test=(50, 50),    # test grid
    n_steps=25000,
    clear_output=False,
    summary_freq=1000,
    test_freq=1000,
)

print(c)
run = FBPINNTrainer(c)
all_params = run.train()

arty0mkha avatar Jun 16 '24 14:06 arty0mkha