devito
devito copied to clipboard
compiler: Redundant haloupdate
Elastic-like cross-loop dependencies generate a redundant haloupdate.
Problem, in pseudocode:
for time
    haloupd vx[t0]
    write to vx[t1] - read from vx[t0]
    haloupd vx[t1]
    read from vx[t1]
    read from vx[t0]
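It could instead be:
haloupd vx[t0]    <- hoisted: executed only once, before the time loop
for time
    --- DROP haloupd vx[t0] --- (the buffer written as vx[t1] in the previous iteration is vx[t0] now, and its halo has already been updated)
    write to vx[t1] - read from vx[t0]
    haloupd vx[t1]
    read from vx[t1]
    read from vx[t0]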
Python script to reproduce (reduced from https://github.com/devitocodes/devito/blob/master/examples/seismic/tutorials/06_elastic_varying_parameters.ipynb). Run it with either of:
DEVITO_LOGGING=DEBUG DEVITO_MPI=1 mpirun -n 1 python3 tests/elastic_mfe_1d.py
DEVITO_LOGGING=DEBUG DEVITO_MPI=1 ../../tmpi/tmpi 1 python3 tests/elastic_mfe_1d.py
elastic_mfe_1d.py:
# Minimal 1-D reproducer: with DEVITO_MPI=1, the generated time loop performs a
# halo exchange of v[t0] at the top of every iteration, even though that buffer
# was already exchanged (as v[t1]) in the previous iteration.
# The exact form of the equations below matters for reproducing the issue.
import numpy as np
from devito import (SpaceDimension, Grid, TimeFunction, Eq, Operator,
solve, Constant)
from examples.seismic.source import TimeAxis, Receiver
# Space related: 1-D grid of 201 points over 1500 units.
extent = (1500., )
shape = (201, )
# Grid spacing extent/(points-1) = 7.5, passed as a Constant named 'h_x'.
x = SpaceDimension(name='x', spacing=Constant(name='h_x', value=extent[0]/(shape[0]-1)))
grid = Grid(extent=extent, shape=shape, dimensions=(x, ))
# Time related: simulate from 0 to 30 with a fixed step dt.
t0, tn = 0., 30.
dt = (10. / np.sqrt(2.)) / 6.
time_range = TimeAxis(start=t0, stop=tn, step=dt)
# Velocity (v) and pressure/stress-like (tau) fields.
# time_order=1 yields two time buffers (t0, t1) in the generated loop.
so = 2
to = 1
v = TimeFunction(name='v', grid=grid, space_order=so, time_order=to)
tau = TimeFunction(name='tau', grid=grid, space_order=so, time_order=to)
# The receiver: a single point; linspace(0, extent, num=1) places it at x=0.
nrec = 1
rec = Receiver(name="rec", grid=grid, npoint=nrec, time_range=time_range)
rec.coordinates.data[:, 0] = np.linspace(0., extent[0], num=nrec)
# Interpolation of v at the receiver.
# Its result is concatenated with the equation list passed to Operator below.
rec_term = rec.interpolate(expr=v)
# First-order, elastic-like coupled equations.
# The v update reads tau at the current time level. The tau update reads the
# freshly computed v.forward, which forces a halo exchange of v[t1] between
# the two spatial loops. That buffer is then exchanged again as v[t0] at the
# start of the next time iteration, which is the redundant update.
pde_v = v.dt - (tau.dx)
pde_tau = (tau.dt - ((v.forward).dx))
u_v = Eq(v.forward, solve(pde_v, v.forward))
u_tau = Eq(tau.forward, solve(pde_tau, tau.forward))
op = Operator([u_v] + [u_tau] + rec_term)
op.apply(dt=dt)
# To inspect the generated C code (where the redundant haloupdate shows up),
# uncomment the line below:
# print(op.ccode)
The generated code includes the following excerpt. The line between the dashed markers, `haloupdate1(v_vec,comm,nb,t0);`, is redundant:
for (int time = time_m, t0 = (time)%(2), t1 = (time + 1)%(2); time <= time_M; time += 1, t0 = (time)%(2), t1 = (time + 1)%(2))
{
START(section0)
haloupdate0(tau_vec,comm,nb,t0);
--------------------
haloupdate1(v_vec,comm,nb,t0);
--------------------
for (int x = x_m; x <= x_M; x += 1)
{
v[t1][x + 2] = dt*(r0*v[t0][x + 2] - r1*tau[t0][x + 2] + r1*tau[t0][x + 3]);
}
haloupdate0(v_vec,comm,nb,t1);
for (int x = x_m; x <= x_M; x += 1)
{
tau[t1][x + 2] = dt*(r0*tau[t0][x + 2] - r1*v[t1][x + 2] + r1*v[t1][x + 3]);
}
STOP(section0,timers)
START(section1)
for (int p_rec = p_rec_m; p_rec <= p_rec_M; p_rec += 1)
{
float r5 = r2*(-o_x + rec_coords[p_rec][0]);
float r4 = floorf(r5);
int posx = (int)r4;
float px = -r4 + r5;
float sum = 0.0F;
for (int rrecx = 0; rrecx <= 1; rrecx += 1)
{
if (rrecx + posx >= x_m - 1 && rrecx + posx <= x_M + 1)
{
sum += (rrecx*px + (1 - rrecx)*(1 - px))*v[t0][rrecx + posx + 2];
}
}
rec[time][p_rec] = sum;
}
STOP(section1,timers)
@FabioLuporini here
Reminder for @georgebisbas to open a PR with a test.