PiPPy
[SPMD][Fusion] - ensure buffer dtype matches gradient tensor dtype
Currently the fusion buffer defaults to FP32, which is incorrect for mixed-precision cases where gradients are in a lower-precision dtype such as FP16 or BF16. The buffer dtype therefore needs to be derived from the shape prop metadata of the gradient tensors being fused.
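The following is a minimal sketch of the idea. It assumes the gradients are `torch.fx` nodes annotated by `torch.fx.passes.shape_prop.ShapeProp`, which stores a `TensorMetadata` (including `shape` and `dtype`) in `node.meta["tensor_meta"]`. The helpers `fusion_buffer_dtype` and `allocate_fusion_buffer` are hypothetical illustrations, not PiPPy's actual API. The toy module's outputs only stand in for gradient tensors.

```python
import torch
import torch.fx as fx
from torch.fx.passes.shape_prop import ShapeProp


def fusion_buffer_dtype(grad_nodes):
    """Hypothetical helper: read the dtype from ShapeProp metadata instead of assuming FP32."""
    dtypes = {n.meta["tensor_meta"].dtype for n in grad_nodes}
    if len(dtypes) != 1:
        # Refuse to fuse tensors of different dtypes into one flat buffer.
        raise ValueError(f"cannot fuse gradients with mixed dtypes: {dtypes}")
    return dtypes.pop()


def allocate_fusion_buffer(grad_nodes, device="cpu"):
    """Hypothetical helper: flat buffer sized to hold every fused tensor, in their dtype."""
    dtype = fusion_buffer_dtype(grad_nodes)
    numel = sum(n.meta["tensor_meta"].shape.numel() for n in grad_nodes)
    return torch.empty(numel, dtype=dtype, device=device)


class Toy(torch.nn.Module):
    # Its two outputs stand in for gradient tensors produced in BF16.
    def forward(self, x, y):
        return x * 2, y + 1


gm = fx.symbolic_trace(Toy())
x = torch.ones(4, 3, dtype=torch.bfloat16)
y = torch.ones(5, dtype=torch.bfloat16)
ShapeProp(gm).propagate(x, y)  # fills node.meta["tensor_meta"] for each node

grads = [n for n in gm.graph.nodes if n.op == "call_function"]
buf = allocate_fusion_buffer(grads)
print(buf.dtype, buf.numel())  # torch.bfloat16 17
```

Raising on mixed dtypes, rather than silently upcasting, is one possible design choice. Another would be to bucket gradients by dtype and allocate one buffer per bucket.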