ITensorsGPU.jl

Simplify `contract!` overloading logic

Open · mtfishman opened this issue 3 years ago · 2 comments

This is mostly an issue about the contraction code logic being too convoluted in NDTensors, but I'm raising it here because GPU overloading is a good stress test for the design.

It would be helpful to simplify the contraction overloading logic. There are a few layers of places that overloading can happen:

  1. At the level of contract!! here. This function is supposed to be a simple wrapper around contract! that implements the generic logic of expanding the output type if needed, like the interface from BangBang.jl. So in principle this should not need to be overloaded; instead, what should be overloaded are generic functions for computing the output storage of a contraction, as well as contract!.
  2. At the level of contract! here, which is supposed to be the general interface for overloading, but because of convoluted code logic it is buried too far down. This is where the actual contraction happens, and it should be the main place where new contraction backends or types get overloaded.
  3. At the level of a gemm! call, as is done here. This one is a bit more difficult, since it requires a lot of logic to take the contraction plan and turn it into a gemm! call, so it is probably tough to make that generic across CPU and GPU. This would be the level to overload if you wanted to keep the "Transpose-Transpose-GEMM-Transpose (TTGT)" contraction strategy and just change the matrix multiplication backend (a minimal sketch follows this list). @kshyatt, I see that it is possibly used here; is that code being used, or are you only using cuTENSOR at this point?
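For reference, TTGT means permuting and reshaping both inputs into matrices, performing a single GEMM, and permuting/reshaping the result back. Here is a self-contained sketch for one fixed index pattern (the name ttgt_contract! and the A[i,j,k] * B[k,j,l] pattern are hypothetical illustrations, not NDTensors code):

using LinearAlgebra

# Contract A[i,j,k] * B[k,j,l] -> R[i,l] by fusing the contracted
# indices (j, k) on each input and calling a single GEMM.
function ttgt_contract!(R::Matrix, A::Array{<:Number,3}, B::Array{<:Number,3})
  i, j, k = size(A)
  l = size(B, 3)
  Amat = reshape(A, i, j * k)                          # A is already in (i, (j, k)) order
  Bmat = reshape(permutedims(B, (2, 1, 3)), j * k, l)  # bring B to ((j, k), l) order
  mul!(R, Amat, Bmat)                                  # the backend-specific GEMM step
  return R
end

Swapping out mul! is the only change needed to target a different matrix multiplication backend at this level; the hard part, as noted above, is computing the permutations and reshapes for arbitrary label patterns.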

Ultimately, the idea would be that contract!! only implements the logic of analyzing the inputs to the function, checking whether the output type is "wide enough" for the output to be mutated directly, and, if not, creating a proper output storage. The interface I would envision is:

function contract!!(R::Tensor, labelsR, T1::Tensor, labelsT1, T2::Tensor, labelsT2, α = 1, β = 0)
  # Like `contraction_output` but a more generic interface (so could also cover cases like `permutedims!`)
  # `RR` may be an alias for `R` if `R` has a wide enough storage already
  RR = output_storage(contract!, R, labelsR, T1, labelsT1, T2, labelsT2, α, β)
  contract!(RR, labelsR, T1, labelsT1, T2, labelsT2, α, β)
  return RR
end

and contract could be defined trivially as:

function contract(T1::Tensor, labelsT1, T2::Tensor, labelsT2, α = 1, β = 0)
  return contract!!(Empty(), contract_labels(labelsT1, labelsT2), T1, labelsT1, T2, labelsT2, α, β)
end

Then the idea would be that new storage types like CuDense would implement overloads for output_storage (which should mostly be handled by generic Dense code along with generic promote_type-like logic) and for the highest-level contract!(R::Tensor, labelsR, T1::Tensor, labelsT1, T2::Tensor, labelsT2) function.
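To make the promote_type-like logic concrete, a generic fallback for output_storage might look roughly like the following (a sketch against the proposed interface above; output_storage is not an existing NDTensors function, and the widening check is deliberately simplified):

function output_storage(::typeof(contract!), R::Tensor, labelsR,
                        T1::Tensor, labelsT1, T2::Tensor, labelsT2, α, β)
  # Element type the result needs in order to hold the contraction exactly.
  ElR = promote_type(eltype(R), eltype(T1), eltype(T2), typeof(α), typeof(β))
  ElR === eltype(R) && return R  # `R` is already wide enough; mutate in place
  return similar(R, ElR)         # otherwise allocate widened storage
end

With a fallback like this, CuDense would mostly inherit the behavior for free, since similar on a tensor with CuDense storage would already allocate on the GPU.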

A bonus to all of this is that it should improve the type stability of things like contract!!, which I think is the main issue for latency at this point. I tried to improve the type stability of contract!! at some point, but the logic got very convoluted, and ultimately it seemed like it would require rewriting and reorganizing a lot of the code anyway.

For now, I think an actionable item in ITensorsGPU would be to simplify the overload of contract!! to the following:

function contract!!(R::CuDenseTensor, labelsR, T1::CuDenseTensor, labelsT1, T2::CuDenseTensor, labelsT2)
  # It could do a widening of the storage here if needed, but let's assume `R` already has a wide enough storage
  contract!(R, labelsR, T1, labelsT1, T2, labelsT2)
  return R
end

and then overload contract!:

function contract!(R::CuDenseTensor, labelsR, T1::CuDenseTensor, labelsT1, T2::CuDenseTensor, labelsT2)
  [...]
end

where [...] is the body of the function _contract!. This skips over all of the logic involving compute_contraction_properties! (which is not needed at all by cuTENSOR) and _contract_scalar!, outer!!, etc. (which are hopefully optimized within cuTENSOR anyway), greatly reducing the overload burden on ITensorsGPU.jl and making it less sensitive to code reorganizations within NDTensors.
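For illustration, the contract! body could hand the whole contraction directly to cuTENSOR, whose mode labels play the same role as the contraction labels here. A rough sketch, with the caveat that the exact CUTENSOR.contraction! argument order and the array accessor for getting the underlying CuArray are assumptions based on the CUDA.jl wrapper of that era and should be double-checked:

using CUDA, CUDA.CUTENSOR

function contract!(R::CuDenseTensor, labelsR, T1::CuDenseTensor, labelsT1,
                   T2::CuDenseTensor, labelsT2)
  # cuTENSOR infers the contracted modes from the shared labels itself,
  # so nothing like compute_contraction_properties! is needed here.
  op = CUTENSOR.CUTENSOR_OP_IDENTITY
  CUTENSOR.contraction!(one(eltype(R)), array(T1), collect(labelsT1), op,
                        array(T2), collect(labelsT2), op,
                        zero(eltype(R)), array(R), collect(labelsR), op, op)
  return R
end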

mtfishman · Apr 08 '21 19:04

Sorry I somehow missed this until now. I definitely agree the NDTensors logic is pretty opaque -- anytime it changes I have to do a deep dive to figure out what methods I need to extend/change on my end.

kshyatt · Apr 29 '21 13:04

Yeah, I think for the sake of ITensorsGPU.jl my conclusion is that we can just overload the highest-level contract!! function, which is the one that actually gets called by the * operator in ITensors.jl. That would simplify the overloading logic a lot, and since the contract!! interface shouldn't change much, ITensorsGPU.jl will be much less sensitive to code refactoring in NDTensors.jl.

mtfishman · Apr 29 '21 13:04