WIP: Add the initial part of MPI tablegen for adjoint generator

Open PragmaTwice opened this issue 2 years ago • 7 comments

WIP; some comments and a PR description will be added soon.

PragmaTwice avatar Jul 11 '23 17:07 PragmaTwice

As discussed, here is a slight adjustment of your PR, which you could use to add your MPI code as part of InstructionDerivatives.td instead of copying the style of BlasDerivatives.td.

It might contain some mistakes, so make sure to check it, but if you get something similar to this style to compile, we should be able to improve from there.

class mpiPattern<dag patternToMatch, list<string> funcNames, string overwrittenArg, list<dag> resultOps, dag forwardOps> {
  dag PatternToMatch = patternToMatch;
  string overwritten = overwrittenArg;
  list<string> names = funcNames;
  list<dag> ArgDerivatives = resultOps;
  dag ArgDuals = forwardOps;
}
                    
def : mpiPattern<(Op $sendbuf, $sendcount, $sendtype, $recvbuf, $recvcount, $recvtype, $root, $comm),
                  ["MPI_Gather", "PMPI_Gather"],
                  "recvbuf",
                  [                    
                      (b<"MPI_Scatter"> shadow $recvbuf, $recvcount, $recvtype, $buf, $sendcount, $sendtype, $root, $comm),
                      (InactiveArg), //sendcount
                      (InactiveArg), //sendtype
                      // TODO: if root, Zero diff(recvbuffer) [memset to 0]   
                      // (Select (ifRoot diff(recvbuffer)), (ifNotRoot doSomethingElse)
                      // (Select (FCmpOLT $x, $y), (SelectIfActive $x, (Shadow $x), (Zero $x)), (SelectIfActive $y, (Shadow $y), (Zero $y)))
                      (MemCopyFloats $buf, shadow $sendbuf, (Mul $sendcount, (MPITySize $sendtype))), //recvbuf
                      (InactiveArg), //recvcount
                      (InactiveArg), //recvtype
                      (InactiveArg), //root
                      (InactiveArg), //comm
                  ],
                  (ForwardFromSummedReverse)
                  >;
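
For intuition about what the Scatter rule and the TODO are meant to achieve, here is a minimal serial sketch. It assumes the usual reverse-mode rule for a gather: scatter the shadow of recvbuf back to the ranks, accumulate it into each rank's shadow of sendbuf, then zero the shadow of recvbuf on the root. Whether the final rule should copy or accumulate is for the real implementation to settle. `gather` and `gather_adjoint` are hypothetical names, each inner vector stands in for one rank's buffer, and no MPI or Enzyme API is used.

```cpp
#include <cstddef>
#include <vector>

// Serial stand-in for MPI_Gather: every "rank" contributes one chunk and the
// root receives all chunks back to back.
std::vector<double> gather(const std::vector<std::vector<double>> &send) {
  std::vector<double> recv;
  for (const auto &chunk : send)
    recv.insert(recv.end(), chunk.begin(), chunk.end());
  return recv;
}

// Assumed reverse rule: d_send[r][i] += d_recv[offset_r + i], then d_recv = 0
// because the primal call overwrote recvbuf.
void gather_adjoint(std::vector<double> &d_recv,
                    std::vector<std::vector<double>> &d_send) {
  std::size_t offset = 0;
  for (auto &d_chunk : d_send) {
    for (std::size_t i = 0; i < d_chunk.size(); ++i)
      d_chunk[i] += d_recv[offset + i]; // the "scatter" step plus accumulation
    offset += d_chunk.size();
  }
  for (double &v : d_recv)
    v = 0.0; // the "if root, zero diff(recvbuffer)" step from the TODO
}
```

With two ranks of two elements each, a recvbuf shadow of {1, 2, 3, 4} is split into {1, 2} for rank 0 and {3, 4} for rank 1, and each half is added to that rank's sendbuf shadow.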

ZuseZ4 avatar Jul 12 '23 03:07 ZuseZ4

Based on your second question, this could be a shorter design, which might need a few more adjustments, so we can try this later.

class mpiPattern<dag patternToMatch, list<string> funcNames, list<string> actArgs, string overwrittenArg, list<dag> resultOps, dag forwardOps> {
  dag PatternToMatch = patternToMatch;
  list<string> activeArgs = actArgs;
  string overwritten = overwrittenArg;
  list<string> names = funcNames;
  list<dag> ArgDerivatives = resultOps;
  dag ArgDuals = forwardOps;
}
                    
def : mpiPattern<(Op $sendbuf, $sendcount, $sendtype, $recvbuf, $recvcount, $recvtype, $root, $comm),
                  ["MPI_Gather", "PMPI_Gather"],
                  ["sendbuf", "recvbuf"],
                  "recvbuf",
                  [                    
                      (b<"MPI_Scatter"> shadow $recvbuf, $recvcount, $recvtype, $buf, $sendcount, $sendtype, $root, $comm),
                      // TODO: if root, Zero diff(recvbuffer) [memset to 0]   
                      // (Select (ifRoot diff(recvbuffer)), (ifNotRoot doSomethingElse)
                      // (Select (FCmpOLT $x, $y), (SelectIfActive $x, (Shadow $x), (Zero $x)), (SelectIfActive $y, (Shadow $y), (Zero $y)))
                      (MemCopyFloats $buf, shadow $sendbuf, (Mul $sendcount, (MPITySize $sendtype))), //recvbuf
                  ],
                  (ForwardFromSummedReverse)
                  >;

ZuseZ4 avatar Jul 12 '23 03:07 ZuseZ4

I would heavily recommend against making a new MPI-specific tablegen infrastructure, as opposed to extending and using the existing call infrastructure with whatever new operations you need.

wsmoses avatar Jul 12 '23 15:07 wsmoses

@wsmoses I agree that we should not have a third mpi-tg besides enzyme-tg and blas-tg; I was less concerned about having 4 or 5 different classes inside enzyme-tg. That said, most of the extensions I had in mind (e.g. potentially active args vs. always-inactive args) can also be handled by extending the call class and setting some default values, such as all-active.

@PragmaTwice Using the existing call class should make it easier to get a first compiling version, so maybe focus on that. Most rules will likely be too complex and depend on features that are still missing, so you won't be able to emit all of the code required to handle MPI yet. However, once you get it to compile, we can see which features are missing and add them one by one.

ZuseZ4 avatar Jul 12 '23 16:07 ZuseZ4

I will try to construct a CallMPIPattern that inherits from CallPattern, so that a CallMPIPattern can be treated as a CallPattern while still letting us add some additional arguments for other uses.

PragmaTwice avatar Jul 13 '23 16:07 PragmaTwice

def : mpiPattern<(Op $sendbuf, $sendcount, $sendtype, $recvbuf, $recvcount, $recvtype, $root, $comm),
                  ["MPI_Gather", "PMPI_Gather"],
                  "recvbuf",
                  [                    
                      (b<"MPI_Scatter"> shadow $recvbuf, $recvcount, $recvtype, $buf, $sendcount, $sendtype, $root, $comm),

                      // TODO: if root, Zero diff(recvbuffer) [memset to 0]   
                      // (Select (ifRoot diff(recvbuffer)), (ifNotRoot doSomethingElse)
                      // (Select (FCmpOLT $x, $y), (SelectIfActive $x, (Shadow $x), (Zero $x)), (SelectIfActive $y, (Shadow $y), (Zero $y)))
                      (MemCopyFloats $buf, shadow $sendbuf, (Mul $sendcount, (MPITySize $sendtype))), //recvbuf
                  ],
                  (ForwardFromSummedReverse)
                  >;                 

So if you do want to have a type system as an extension of the CallPattern, you could use it to extend the list. The default CallPattern expects one dag per input argument. So every time your type system returns buffer, you pick the next dag from your own list, and every time you have a type other than buffer, you add (InactiveArg) as the dag rule.

The above example should therefore translate into:

def : mpiPattern<(Op $sendbuf, $sendcount, $sendtype, $recvbuf, $recvcount, $recvtype, $root, $comm),
                  ["MPI_Gather", "PMPI_Gather"],
                  [buf, size, datatype, buf, size, datatype, integer, comm],
                  "recvbuf",
                  [                    
                      (b<"MPI_Scatter"> shadow $recvbuf, $recvcount, $recvtype, $buf, $sendcount, $sendtype, $root, $comm),

                    (InactiveArg),                    
                    (InactiveArg),                  
                      // TODO: if root, Zero diff(recvbuffer) [memset to 0]   
                      // (Select (ifRoot diff(recvbuffer)), (ifNotRoot doSomethingElse)
                      // (Select (FCmpOLT $x, $y), (SelectIfActive $x, (Shadow $x), (Zero $x)), (SelectIfActive $y, (Shadow $y), (Zero $y)))
                      (MemCopyFloats $buf, shadow $sendbuf, (Mul $sendcount, (MPITySize $sendtype))), //recvbuf
                    (InactiveArg),                    
                    (InactiveArg),                    
                    (InactiveArg),                    
                    (InactiveArg),                      
                  ],
                  (ForwardFromSummedReverse)
                  >;
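
The expansion rule described above (one dag per argument, with `(InactiveArg)` filled in for every non-buffer argument) could be sketched roughly as follows. This is only an illustration on plain strings, not the actual tablegen backend: `ArgKind` and `expandRules` are hypothetical names, and the rules are kept as opaque strings.

```cpp
#include <cstddef>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical argument kinds, mirroring [buf, size, datatype, ..., comm].
enum class ArgKind { Buf, Size, Datatype, Integer, Comm };

// Expand the short rule list (one rule per buffer argument) into the full
// list that a CallPattern-style class expects (one rule per argument).
std::vector<std::string> expandRules(const std::vector<ArgKind> &kinds,
                                     const std::vector<std::string> &bufRules) {
  std::vector<std::string> full;
  std::size_t next = 0; // index of the next unused buffer rule
  for (ArgKind k : kinds) {
    if (k == ArgKind::Buf) {
      if (next == bufRules.size())
        throw std::invalid_argument("fewer rules than buffer arguments");
      full.push_back(bufRules[next++]);
    } else {
      full.push_back("(InactiveArg)"); // every non-buffer argument is inactive
    }
  }
  if (next != bufRules.size())
    throw std::invalid_argument("more rules than buffer arguments");
  return full;
}
```

For the MPI_Gather signature this turns two buffer rules into eight entries, with the buffer rules at positions 0 and 3 and `(InactiveArg)` everywhere else, which matches the hand-written list above.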

Simplification: add the following lines at the beginning of all MPI functions:

        Value *rank = MPI_COMM_RANK(comm, Builder2, root->getType());
        Value *tysize = MPI_TYPE_SIZE(sendtype, Builder2, call.getType());

You can also try to generate the following names and helpers for each input argument:

  const int pos_x = 1;
  const auto orig_x = call.getArgOperand(pos_x);
  auto arg_x = gutils->getNewFromOriginal(orig_x);

And for all the shadow arguments (this might already be done by tablegen, so just check):

        Value *shadow_sendbuf = gutils->invertPointerM(orig_sendbuf, Builder2);
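
As a rough idea of how the per-argument helpers above could be generated, here is a small sketch that just builds the text from a list of argument names. `emitArgHelpers` is a hypothetical name, positions are assumed to be zero-based (as `getArgOperand` expects), and the real emitter would write to tablegen's output stream rather than return a string.

```cpp
#include <cstddef>
#include <string>
#include <vector>

// Emit pos_<name>, orig_<name> and arg_<name> declarations for every argument,
// in argument order. Positions are zero-based.
std::string emitArgHelpers(const std::vector<std::string> &argNames) {
  std::string out;
  for (std::size_t i = 0; i < argNames.size(); ++i) {
    const std::string &n = argNames[i];
    out += "const int pos_" + n + " = " + std::to_string(i) + ";\n";
    out += "const auto orig_" + n + " = call.getArgOperand(pos_" + n + ");\n";
    out += "auto arg_" + n + " = gutils->getNewFromOriginal(orig_" + n + ");\n";
  }
  return out;
}
```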

ZuseZ4 avatar Jul 14 '23 03:07 ZuseZ4

Simplification 1: Mark all MPI_Gather arguments as InactiveArg, then specify the forward pass as ForwardFromSummedReverse. Create all the helper args I mentioned above. For every argument that is a buffer according to your type system, look up the shadow, as in the example above. Then create the primal call and emit it, replacing the buffer args with the shadows of those buffer args:

        if (forwardMode) {
          Value *args[] = {
              /*sendbuf*/ shadow_sendbuf,
              /*sendcount*/ sendcount,
              /*sendtype*/ sendtype,
              /*recvbuf*/ shadow_recvbuf,
              /*recvcount*/ recvcount,
              /*recvtype*/ recvtype,
              /*root*/ root,
              /*comm*/ comm,
          };

          auto Defs = gutils->getInvertedBundles(
              &call,
              {ValueType::Shadow, ValueType::Primal, ValueType::Primal,
               ValueType::Shadow, ValueType::Primal, ValueType::Primal,
               ValueType::Primal, ValueType::Primal},
              Builder2, /*lookup*/ false);

#if LLVM_VERSION_MAJOR >= 11
          auto callval = call.getCalledOperand();
#else
          auto callval = call.getCalledValue();
#endif
          Builder2.CreateCall(call.getFunctionType(), callval, args, Defs);
          return;
        }
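
The argument selection in the snippet above follows mechanically from the type list, so it could be derived rather than written by hand. A minimal sketch, assuming a per-argument "is buffer" flag is available from the type system: `ForwardArg` and `forwardArgs` are hypothetical names, and the `useShadow` flag stands in for the choice between `ValueType::Shadow` and `ValueType::Primal`.

```cpp
#include <cstddef>
#include <stdexcept>
#include <string>
#include <vector>

// One operand of the forward-mode call: which variable to pass, and whether
// it is the shadow (true) or the primal (false).
struct ForwardArg {
  std::string operand;
  bool useShadow;
};

// Buffers are replaced by their shadow; everything else is passed unchanged.
std::vector<ForwardArg> forwardArgs(const std::vector<std::string> &names,
                                    const std::vector<bool> &isBuffer) {
  if (names.size() != isBuffer.size())
    throw std::invalid_argument("names and isBuffer must have equal length");
  std::vector<ForwardArg> args;
  for (std::size_t i = 0; i < names.size(); ++i) {
    if (isBuffer[i])
      args.push_back({"shadow_" + names[i], true});
    else
      args.push_back({names[i], false});
  }
  return args;
}
```

For MPI_Gather this yields shadow_sendbuf and shadow_recvbuf at positions 0 and 3 and the primal values everywhere else, the same layout as the hand-written `args` and `getInvertedBundles` lists.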

Besides the 3 functions you already added, this logic would also be able to handle the forward mode of

    if (funcName == "PMPI_Isend" || funcName == "MPI_Isend" ||
        funcName == "PMPI_Irecv" || funcName == "MPI_Irecv") {

So I do think it is worth starting with it :)

ZuseZ4 avatar Jul 14 '23 03:07 ZuseZ4