                        WIP: Add the initial part of MPI tablegen for adjoint generator
WIP; some comments and a PR description will be added soon.
As discussed, here is a slight adjustment of your PR which you could use to add your MPI code as part of InstructionDerivatives.td, instead of copying the style of BlasDerivatives.td.
It might have some mistakes, so make sure to check; but if you get something similar to this style to compile, we should be able to improve from there.
class mpiPattern<dag patternToMatch, list<string> funcNames, string overwrittenArg, list<dag> resultOps, dag forwardOps> {
  dag PatternToMatch = patternToMatch;
  string overwritten = overwrittenArg;
  list<string> names = funcNames;
  list<dag> ArgDerivatives = resultOps;
  dag ArgDuals = forwardOps;
}
                    
def : mpiPattern<(Op $sendbuf, $sendcount, $sendtype, $recvbuf, $recvcount, $recvtype, $root, $comm),
                  ["MPI_Gather", "PMPI_Gather"],
                  "recvbuf",
                  [                    
                      (b<"MPI_Scatter"> shadow $recvbuf, $recvcount, $recvtype, $buf, $sendcount, $sendtype, $root, $comm),
                      (InactiveArg), //sendcount
                      (InactiveArg), //sendtype
                      // TODO: if root, Zero diff(recvbuffer) [memset to 0]   
                      // (Select (ifRoot diff(recvbuffer)), (ifNotRoot doSomethingElse)
                      // (Select (FCmpOLT $x, $y), (SelectIfActive $x, (Shadow $x), (Zero $x)), (SelectIfActive $y, (Shadow $y), (Zero $y)))
                      (MemCopyFloats $buf, shadow $sendbuf, (Mul $sendcount, (MPITySize $sendtype))), // recvbuf
                      (InactiveArg), // recvcount
                      (InactiveArg), //recvtype
                      (InactiveArg), //root
                      (InactiveArg), //comm
                  ],
                  (ForwardFromSummedReverse)
                  >;
                                    
                                    
                                    
                                
Based on your second question, this could be a shorter design, which might need a few more adjustments, so we can try this later.
class mpiPattern<dag patternToMatch, list<string> funcNames, list<string> actArgs, string overwrittenArg, list<dag> resultOps, dag forwardOps> {
  dag PatternToMatch = patternToMatch;
  list<string> activeArgs = actArgs;
  string overwritten = overwrittenArg;
  list<string> names = funcNames;
  list<dag> ArgDerivatives = resultOps;
  dag ArgDuals = forwardOps;
}
                    
def : mpiPattern<(Op $sendbuf, $sendcount, $sendtype, $recvbuf, $recvcount, $recvtype, $root, $comm),
                  ["MPI_Gather", "PMPI_Gather"],
                  ["sendbuf", "recvbuf"],
                  "recvbuf",
                  [                    
                      (b<"MPI_Scatter"> shadow $recvbuf, $recvcount, $recvtype, $buf, $sendcount, $sendtype, $root, $comm),
                      // TODO: if root, Zero diff(recvbuffer) [memset to 0]   
                      // (Select (ifRoot diff(recvbuffer)), (ifNotRoot doSomethingElse)
                      // (Select (FCmpOLT $x, $y), (SelectIfActive $x, (Shadow $x), (Zero $x)), (SelectIfActive $y, (Shadow $y), (Zero $y)))
                      (MemCopyFloats $buf, shadow $sendbuf, (Mul $sendcount, (MPITySize $sendtype))), // recvbuf
                  ],
                  (ForwardFromSummedReverse)
                  >;
                                    
                                    
                                    
                                
I would heavily recommend against making a new MPI-specific tablegen infrastructure, as opposed to extending and using the existing call infrastructure with whatever new operations you need.
@wsmoses I agree on not having a third mpi-tg besides enzyme-tg and blas-tg, but I cared less about having 4 or 5 different classes inside enzyme-tg. Most of the extensions I had in mind (e.g. potentially active args vs. always-inactive args) can also be solved by extending the call class and setting some default values like all-active.
@PragmaTwice Using the existing call class should make it easier to get a first compiling version, so maybe focus on this one. Most rules will likely be too complex and miss features, so you won't be able to emit all of the required code to handle MPI yet. However, once you get it to compile, we can see which features are missing and add them one by one.
I will try to construct a CallMPIPattern that inherits from CallPattern, so that CallMPIPattern can be treated as a CallPattern while we can still add some additional arguments for other uses.
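One possible shape for that subclass is sketched below. The CallPattern parameter list here is an assumption inferred from the patterns in this thread; check the real signature in InstructionDerivatives.td before using it:

```tablegen
// Hypothetical sketch: CallMPIPattern inherits CallPattern, so existing
// consumers can keep treating it as a CallPattern, while MPI-specific
// extras (here, the overwritten argument) live on the subclass.
class CallMPIPattern<dag patternToMatch, list<string> funcNames,
                     string overwrittenArg, list<dag> resultOps, dag forwardOps>
    : CallPattern<patternToMatch, funcNames, resultOps, forwardOps> {
  string overwritten = overwrittenArg;
}
```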
def : mpiPattern<(Op $sendbuf, $sendcount, $sendtype, $recvbuf, $recvcount, $recvtype, $root, $comm),
                  ["MPI_Gather", "PMPI_Gather"],
                  "recvbuf",
                  [                    
                      (b<"MPI_Scatter"> shadow $recvbuf, $recvcount, $recvtype, $buf, $sendcount, $sendtype, $root, $comm),
                      // TODO: if root, Zero diff(recvbuffer) [memset to 0]   
                      // (Select (ifRoot diff(recvbuffer)), (ifNotRoot doSomethingElse)
                      // (Select (FCmpOLT $x, $y), (SelectIfActive $x, (Shadow $x), (Zero $x)), (SelectIfActive $y, (Shadow $y), (Zero $y)))
                      (MemCopyFloats $buf, shadow $sendbuf, (Mul $sendcount, (MPITySize $sendtype))), // recvbuf
                  ],
                  (ForwardFromSummedReverse)
                  >;                 
So if you do want to have a type system as an extension for the CallPattern, you could use this to extend the list of (InactiveArg) dag rules.
The above example would therefore translate into:
def : mpiPattern<(Op $sendbuf, $sendcount, $sendtype, $recvbuf, $recvcount, $recvtype, $root, $comm),
                  ["MPI_Gather", "PMPI_Gather"],
                  [buf, size, datatype, buf, size, datatype, integer, comm],
                  "recvbuf",
                  [                    
                      (b<"MPI_Scatter"> shadow $recvbuf, $recvcount, $recvtype, $buf, $sendcount, $sendtype, $root, $comm),
                    (InactiveArg),                    
                    (InactiveArg),                  
                      // TODO: if root, Zero diff(recvbuffer) [memset to 0]   
                      // (Select (ifRoot diff(recvbuffer)), (ifNotRoot doSomethingElse)
                      // (Select (FCmpOLT $x, $y), (SelectIfActive $x, (Shadow $x), (Zero $x)), (SelectIfActive $y, (Shadow $y), (Zero $y)))
                      (MemCopyFloats $buf, shadow $sendbuf, (Mul $sendcount, (MPITySize $sendtype))), // recvbuf
                    (InactiveArg),                    
                    (InactiveArg),                    
                    (InactiveArg),                    
                    (InactiveArg),                      
                  ],
                  (ForwardFromSummedReverse)
                  >;
Simplification: add the following lines at the beginning of all MPI functions:
        Value *rank = MPI_COMM_RANK(comm, Builder2, root->getType());
        Value *tysize = MPI_TYPE_SIZE(sendtype, Builder2, call.getType());
Also, you can try to generate the following names and helpers for each input argument:
  const int pos_x = 1;
  const auto orig_x = call.getArgOperand(pos_x);
  auto arg_x = gutils->getNewFromOriginal(orig_x);
and for all the shadow arguments (this might already be done by TableGen, just check):
        Value *shadow_sendbuf = gutils->invertPointerM(orig_sendbuf, Builder2);
                                    
                                    
                                    
                                
Simplification 1: mark all MPI_Gather arguments as InactiveArg, then specify the forward pass as ForwardFromSummedReverse. You create all the helper args I mentioned above. For all arguments which are buffers according to your type system, you look up the shadow, as in the example above. Then you create the primal call and call it, replacing your buffer args with the shadows of those buffer args:
        if (forwardMode) {
          Value *args[] = {
              /*sendbuf*/ shadow_sendbuf,
              /*sendcount*/ sendcount,
              /*sendtype*/ sendtype,
              /*recvbuf*/ shadow_recvbuf,
              /*recvcount*/ recvcount,
              /*recvtype*/ recvtype,
              /*root*/ root,
              /*comm*/ comm,
          };
          auto Defs = gutils->getInvertedBundles(
              &call,
              {ValueType::Shadow, ValueType::Primal, ValueType::Primal,
               ValueType::Shadow, ValueType::Primal, ValueType::Primal,
               ValueType::Primal, ValueType::Primal},
              Builder2, /*lookup*/ false);
#if LLVM_VERSION_MAJOR >= 11
          auto callval = call.getCalledOperand();
#else
          auto callval = call.getCalledValue();
#endif
          Builder2.CreateCall(call.getFunctionType(), callval, args, Defs);
          return;
        }
Besides the 3 functions which you already added, this logic would also be able to handle the forward mode of
    if (funcName == "PMPI_Isend" || funcName == "MPI_Isend" ||
        funcName == "PMPI_Irecv" || funcName == "MPI_Irecv") {
So I do think it is worth starting with it :)