Enzyme-JAX

Tracking issue for missing Batch Op Interface

Open Pangoraw opened this issue 1 year ago • 5 comments

NOTE: Struck-through ops are deliberately not implemented because the default broadcasting behavior of the Enzyme batch pass is sufficient.

  • [ ] StableHLO
    • [x] ~AbsOp~
    • [x] ~AddOp~
    • [x] ~AfterAllOp~
    • [ ] AllGatherOp
    • [ ] AllReduceOp
    • [ ] AllToAllOp
    • [x] ~AndOp~
    • [x] ~Atan2Op~
    • [ ] BatchNormGradOp
    • [ ] BatchNormInferenceOp
    • [ ] BatchNormTrainingOp
    • [x] ~BitcastConvertOp~
    • [x] BroadcastInDimOp
    • [ ] CaseOp
    • [x] ~CbrtOp~
    • [x] ~CeilOp~
    • [x] ~CholeskyOp~
    • [x] ~ClampOp~
    • [ ] CollectiveBroadcastOp
    • [ ] CollectivePermuteOp
    • [x] ~CompareOp~
    • [x] ~ComplexOp~
    • [ ] CompositeOp
    • [x] ConcatenateOp
    • [x] ConstantOp
    • [x] ~ConvertOp~
    • [x] ConvolutionOp #151
    • [x] ~CosineOp~
    • [x] ~ClzOp~
    • [x] CustomCallOp
    • [x] ~DivOp~
    • [x] DotGeneralOp
    • [ ] DynamicBroadcastInDimOp
    • [ ] DynamicConvOp
    • [ ] DynamicGatherOp
    • [ ] DynamicIotaOp
    • [ ] DynamicPadOp
    • [ ] DynamicReshapeOp
    • [x] DynamicSliceOp
    • [x] DynamicUpdateSliceOp
    • [x] ~ExpOp~
    • [x] ~Expm1Op~
    • [x] ~FftOp~
    • [x] ~FloorOp~
    • [x] GatherOp
    • [x] GetDimensionSizeOp
    • [x] IfOp #151
    • [x] ~ImagOp~
    • [ ] InfeedOp
    • [x] IotaOp
    • [x] ~IsFiniteOp~
    • [x] ~LogOp~
    • [x] ~Logp1Op~
    • [x] ~LogisticOp~
    • [x] ~MaxOp~
    • [x] ~MinOp~
    • [x] ~MulOp~
    • [x] ~NegateOp~
    • [x] ~NotOp~
    • [ ] OptimizationBarrierOp
    • [x] ~OrOp~
    • [ ] OutfeedOp
    • [ ] PadOp
    • [ ] PartitionIdOp
    • [x] ~PopcntOp~
    • [x] ~PowOp~
    • [x] ~RealOp~
    • [ ] RecvOp
    • [x] ReduceOp
    • [x] ~ReducePrecisionOp~
    • [ ] ReduceScatterOp
    • [x] ReduceWindowOp
    • [x] ~RemainderOp~
    • [ ] ReplicaIdOp
    • [x] ReshapeOp
    • [x] ReverseOp #151
    • [ ] RngBitGeneratorOp
    • [x] ~RoundOp~
    • [x] ~RoundNearestEvenOp~
    • [x] ~RsqrtOp~
    • [ ] ScatterOp
    • [x] SelectOp
    • [ ] SelectAndScatterOp
    • [ ] SendOp
    • [x] ~ShiftLeftOp~
    • [x] ~ShiftRightArithmeticOp~
    • [x] ~ShiftRightLogicalOp~
    • [x] ~SignOp~
    • [x] ~SineOp~
    • [x] SliceOp
    • [x] SortOp
    • [x] ~SqrtOp~
    • [x] ~SubtractOp~
    • [x] ~TanhOp~
    • [x] TransposeOp
    • [x] ~TriangularSolveOp~
    • [ ] UniformDequantizeOp
    • [ ] UniformQuantizeOp
    • [x] WhileOp #151
    • [x] ~XorOp~
    • [ ] Deprecated operations in StableHLO
      • [ ] RngOp
      • [ ] GetTupleElementOp
      • [ ] BroadcastOp
      • [ ] CreateTokenOp
      • [ ] CrossReplicaSumOp
      • [ ] DotOp
      • [ ] EinsumOp
      • [x] ~TorchIndexSelectOp~
      • [ ] UnaryEinsumOp
      • [ ] TupleOp
      • [ ] MapOp
  • [ ] CHLO
    • [x] Binary Element-wise Operations
      • [x] ~BroadcastAddOp~
      • [x] ~BroadcastAtan2Op~
      • [x] ~BroadcastDivOp~
      • [x] ~BroadcastMaxOp~
      • [x] ~BroadcastMinOp~
      • [x] ~BroadcastMulOp~
      • [x] ~BroadcastNextAfterOp~
      • [x] ~BroadcastPolygammaOp~
      • [x] ~BroadcastPowOp~
      • [x] ~BroadcastRemOp~
      • [x] ~BroadcastShiftLeftOp~
      • [x] ~BroadcastShiftRightArithmeticOp~
      • [x] ~BroadcastShiftRightLogicalOp~
      • [x] ~BroadcastSubOp~
      • [x] ~BroadcastZetaOp~
    • [x] Binary Logical Element-wise Operations
      • [x] ~BroadcastAndOp~
      • [x] ~BroadcastOrOp~
      • [x] ~BroadcastXorOp~
    • [ ] Non-broadcasting Binary Operations
      • [ ] NextAfterOp
      • [ ] PolygammaOp
      • [ ] ZetaOp
    • [ ] ComplexOp
    • [x] Unary Element-wise Operations
      • [x] ~AcosOp~
      • [x] ~AcoshOp~
      • [x] ~AsinOp~
      • [x] ~AsinhOp~
      • [x] ~AtanOp~
      • [x] ~AtanhOp~
      • [x] ~BesselI1eOp~
      • [x] ~ConjOp~
      • [x] ~CoshOp~
      • [x] ~SinhOp~
      • [x] ~TanOp~
      • [x] ~ConstantOp~ (shared with StableHLO_ConstantOp)
      • [x] ~ConstantLikeOp~
      • [x] ~DigammaOp~
      • [x] ~ErfOp~
      • [x] ~ErfInvOp~
      • [x] ~ErfcOp~
      • [x] ~IsInfOp~
      • [x] ~IsNegInfOp~
      • [x] ~IsPosInfOp~
      • [ ] LgammaOp
    • [ ] BroadcastCompareOp
    • [ ] BroadcastSelectOp
    • [x] ~TopKOp~
  • [ ] EnzymeXLA
    • [ ] KernelCallOp
    • [ ] JitCallOp
    • [ ] GetStreamOp
    • [ ] Memref2PointerOp
    • [ ] Pointer2MemrefOp
    • [ ] AffineScopeOp
    • [ ] RotateOp
    • [ ] WrapOp
    • [ ] ExtendOp
    • [x] ~CommRegionOp~
    • [x] ~LUFactorizationOp~

Pangoraw avatar Nov 05 '24 12:11 Pangoraw

... because the default broadcasting behavior of enzyme batch is enough.

Do you mean unrolling, or leaving the op unchanged?

mofeing avatar Nov 05 '24 13:11 mofeing

The batch pass will take the original op:

%0 = stablehlo.add %arg0, %arg1 : tensor<10xf32>

and just prepend the broadcasted dimensions (e.g. 20x4):

%0 = stablehlo.add %arg0, %arg1 : tensor<20x4x10xf32>

This is the default behavior for all ops unless they implement BatchOpInterface.
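
As a rough illustration of the type rewrite described above (plain Python, not the actual Enzyme-JAX pass; the helper name `batched_type` is made up for this sketch), prepending the batch dimensions to a ranked tensor type could look like this:

```python
def batched_type(tensor_type, batch_dims):
    """Prepend batch_dims to a type string such as 'tensor<10xf32>'.

    Hypothetical helper for illustration only; the real batch pass
    rewrites MLIR types, not strings.
    """
    prefix, suffix = "tensor<", ">"
    if not (tensor_type.startswith(prefix) and tensor_type.endswith(suffix)):
        raise ValueError(f"not a ranked tensor type: {tensor_type!r}")
    # Everything between 'tensor<' and '>': the shape followed by the element type.
    inner = tensor_type[len(prefix):-len(suffix)]
    # Each batch dimension becomes a leading 'Nx' component.
    batch = "".join(f"{d}x" for d in batch_dims)
    return f"{prefix}{batch}{inner}{suffix}"


print(batched_type("tensor<10xf32>", [20, 4]))  # tensor<20x4x10xf32>
```

For element-wise ops this is all that is needed, because the op's semantics do not depend on which dimension is which.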

Pangoraw avatar Nov 05 '24 18:11 Pangoraw

WhileOp definitely shouldn't be unrolled here in most cases, since it almost always has a number of iterations fixed by a constant (i.e., a non-data-dependent value).

wsmoses avatar Nov 08 '24 15:11 wsmoses

Should CustomCallOp always fall back to the looped version?

avik-pal avatar May 20 '25 00:05 avik-pal

Tentatively I think so; at minimum it feels like a better default (less likely to error).
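
To make the "looped version" concrete, here is a minimal sketch in plain Python, assuming the fallback simply applies the unbatched op once per batch index and collects the results. The names `looped_batch` and `opaque_op` are hypothetical and are not part of Enzyme-JAX:

```python
def looped_batch(op, batched_args, batch_size):
    """Apply `op` once per batch index and collect the per-slice results.

    Hypothetical sketch: `op` is treated as a black box, so nothing is
    assumed about how it handles extra leading dimensions.
    """
    results = []
    for i in range(batch_size):
        # Take the i-th slice of every batched argument.
        slice_args = [arg[i] for arg in batched_args]
        results.append(op(*slice_args))
    return results


def opaque_op(xs, ys):
    """Stand-in for a custom call whose internals the pass cannot inspect."""
    return [x + y for x, y in zip(xs, ys)]


a = [[1, 2], [3, 4], [5, 6]]
b = [[10, 20], [30, 40], [50, 60]]
print(looped_batch(opaque_op, [a, b], 3))  # [[11, 22], [33, 44], [55, 66]]
```

The loop only ever calls the op with shapes it was originally written for, which is why it is less likely to error than prepending dimensions to an opaque call.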

wsmoses avatar May 20 '25 00:05 wsmoses