Simplify `contract!` overloading logic
This is mostly an issue about the contraction code logic in NDTensors being too convoluted, but I'm raising it here because GPU overloading is a good stress test for the design.
It would be helpful to simplify the contraction overloading logic. There are a few layers at which overloading can happen:
- At the level of `contract!!`, here. This function is supposed to be a simple wrapper around `contract!` that implements the generic logic of expanding the output type if needed, like the interface from BangBang.jl (see the toy sketch after this list). So in principle this should not need to get overloaded; instead, what should get overloaded are generic functions for computing the output storage of a contraction, as well as `contract!`.
- At the level of `contract!`, here, which is supposed to be the general interface for overloading, but because of convoluted code logic it is buried too far down. This is where the actual contraction happens, and it should be the main place new contraction backends or types get overloaded.
- At the level of a `gemm!` call, like is done here. This one is a bit more difficult, since it requires a lot of logic to take the contraction plan and turn it into a `gemm!` call, so it is probably tough to make that generic across CPU and GPU. This would be if you wanted to use the "Transpose-Transpose-GEMM-Transpose (TTGT)" contraction strategy and just wanted to change the matrix multiplication backend. @kshyatt, I see that is possibly used here; is that code being used, or are you only using `cuTENSOR` at this point?
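To make the `!!` convention concrete, here is a minimal, self-contained toy version of the pattern, with hypothetical `add!`/`add!!` functions standing in for `contract!`/`contract!!`: mutate the destination when its element type is wide enough, otherwise allocate a widened destination and mutate that.

```julia
# Toy sketch of the BangBang.jl-style `!!` pattern; `add!` and `add!!` are
# hypothetical stand-ins for `contract!` and `contract!!`.
add!(dest, x, y) = (dest .= x .+ y; dest)

function add!!(dest, x, y)
  T = promote_type(eltype(x), eltype(y))
  # If `dest` can hold the result exactly, mutate it in place:
  T <: eltype(dest) && return add!(dest, x, y)
  # Otherwise "expand" the output type, as BangBang.jl does:
  return add!(similar(x, T), x, y)
end

add!!(zeros(Int, 3), [1, 2, 3], [4, 5, 6])        # mutates the Int destination
add!!(zeros(Int, 3), [1, 2, 3], [0.5, 0.5, 0.5])  # allocates a Float64 result
```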
Ultimately, the idea would be that `contract!!` only implements the logic of analyzing the inputs to the function, seeing if the output type is "wide enough" for the output to get mutated directly, and if not, creating a proper output storage. The interface I would envision is:
```julia
function contract!!(R::Tensor, labelsR, T1::Tensor, labelsT1, T2::Tensor, labelsT2, α = 1, β = 0)
  # Like `contraction_output` but a more generic interface (so it could also cover cases like `permutedims!`).
  # `RR` may be an alias for `R` if `R` already has a wide enough storage.
  RR = output_storage(contract!, R, labelsR, T1, labelsT1, T2, labelsT2, α, β)
  contract!(RR, labelsR, T1, labelsT1, T2, labelsT2, α, β)
  return RR
end
```
and `contract` could be defined trivially as:
```julia
function contract(T1::Tensor, labelsT1, T2::Tensor, labelsT2, α = 1, β = 0)
  return contract!!(Empty(), contract_labels(labelsT1, labelsT2), T1, labelsT1, T2, labelsT2, α, β)
end
```
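In the same toy spirit as above, the non-mutating version falls out of the `!!` function by passing a destination whose element type nothing (except `Union{}` itself) is a subtype of; `EmptyVector` here is a hypothetical stand-in for NDTensors' `Empty` storage:

```julia
# Hypothetical stand-in for NDTensors' `Empty` storage: its "element type"
# Union{} forces `add!!` into the widening/allocating branch.
struct EmptyVector end
Base.eltype(::EmptyVector) = Union{}

# The non-mutating version is then a one-liner, mirroring `contract` above:
add(x, y) = add!!(EmptyVector(), x, y)

add([1, 2], [3.0, 4.0])  # allocates and returns a Float64 result
```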
Then the idea would be that new storage types like `CuDense` would implement overloads for `output_storage` (which should mostly get handled by generic `Dense` code along with generic `promote_type`-like logic) and the highest-level `contract!(R::Tensor, labelsR, T1::Tensor, labelsT1, T2::Tensor, labelsT2)` function.
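As a rough illustration of how little backend-specific code that could leave, here is a toy `output_storage` in the same hypothetical setting, where generic `promote_type` logic picks the output element type and the backend only supplies allocation (`MyGPUVector` and this `output_storage` signature are made up for illustration, not real ITensorsGPU/NDTensors names):

```julia
# Made-up GPU-like storage type, just to show the dispatch structure:
struct MyGPUVector{T}
  data::Vector{T}  # pretend this lives on the GPU
end
Base.eltype(::MyGPUVector{T}) where {T} = T

# Generic `promote_type`-like logic decides the output element type;
# the backend only has to say how to allocate it:
output_storage(x::MyGPUVector, y::MyGPUVector, len) =
  MyGPUVector(zeros(promote_type(eltype(x), eltype(y)), len))

output_storage(MyGPUVector([1, 2]), MyGPUVector([0.5]), 4)  # MyGPUVector{Float64}
```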
A bonus to all of this is that it should improve the type stability of things like `contract!!`, which I think is the main source of latency at this point. I tried to improve the type stability of `contract!!` at some point, but the logic got very convoluted, and ultimately I decided that it seemed to require rewriting and reorganizing a lot of the code anyway.
For now, I think an actionable item in ITensorsGPU would be to simplify the overload of `contract!!` to the following:
```julia
function contract!!(R::CuDenseTensor, labelsR, T1::CuDenseTensor, labelsT1, T2::CuDenseTensor, labelsT2)
  # It could do a widening of the storage here if needed, but let's assume
  # `R` already has a wide enough storage.
  contract!(R, labelsR, T1, labelsT1, T2, labelsT2)
  return R
end
```
and then overload `contract!`:
```julia
function contract!(R::CuDenseTensor, labelsR, T1::CuDenseTensor, labelsT1, T2::CuDenseTensor, labelsT2)
  [...]
end
```
where `[...]` is the body of the function `_contract!`. This skips over all of the logic that includes `compute_contraction_properties!` (which is not needed at all by `cuTENSOR`) and `_contract_scalar!`, `outer!!`, etc. (which hopefully are optimized within `cuTENSOR` anyway), making the overload burden on ITensorsGPU.jl a lot smaller and making it less sensitive to code reorganizations within NDTensors.
Sorry I somehow missed this until now. I definitely agree the NDTensors logic is pretty opaque -- anytime it changes I have to do a deep dive to figure out what methods I need to extend/change on my end.
Yeah, I think for the sake of ITensorsGPU.jl my conclusion is that we can just overload the highest-level `contract!!` function, which is the one that actually gets called by the `*` operator in ITensors.jl, and that would simplify the overloading logic a lot. The `contract!!` interface shouldn't be changing much, so ITensorsGPU.jl will be much less sensitive to code refactoring in NDTensors.jl.
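For reference, a hedged sketch of what that buys from the user's side, assuming the usual `cuITensor` conversion constructor from ITensorsGPU.jl: dispatch to the GPU path would happen purely through the storage type, with no extra user-facing API.

```julia
using ITensors, ITensorsGPU

i, j, k = Index(2), Index(3), Index(4)
A = cuITensor(randomITensor(i, j))  # CuDense-backed ITensor
B = cuITensor(randomITensor(j, k))

# `*` calls `contract`, which calls the overloaded
# `contract!!(::CuDenseTensor, ...)`, which can go straight to cuTENSOR:
C = A * B
```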