cutlass
[QST] The best way to do D = func(A x B) x C.
I want to compute a function like D = func(A x B) x C, where:
T = A x B // Matrix multiply
T = func(T) // Some other operator, e.g. a mask or a bias add ...
D = T x C // Matrix multiply
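As a correctness oracle for any fused kernel, here is a plain unfused C++ sketch of the three steps. It assumes row-major float matrices with A: M x K, B: K x N, C: N x P (so D is M x P), and a hypothetical `func` that adds a per-column bias and then applies a causal mask, standing in for "mask add bias". `Mat`, `matmul`, `func` and `reference` are illustrative names, not CUTLASS APIs. Note that the second GEMM reduces over N, the column dimension of T.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Minimal row-major dense matrix (illustrative helper, not a CUTLASS type).
struct Mat {
  int rows, cols;
  std::vector<float> v;
  Mat(int r, int c) : rows(r), cols(c), v(std::size_t(r) * c, 0.f) {}
  float& operator()(int i, int j) { return v[std::size_t(i) * cols + j]; }
  float operator()(int i, int j) const { return v[std::size_t(i) * cols + j]; }
};

// Naive GEMM: Z = X * Y, reducing over X.cols == Y.rows.
Mat matmul(const Mat& X, const Mat& Y) {
  assert(X.cols == Y.rows);
  Mat Z(X.rows, Y.cols);
  for (int i = 0; i < X.rows; ++i)
    for (int k = 0; k < X.cols; ++k)
      for (int j = 0; j < Y.cols; ++j)
        Z(i, j) += X(i, k) * Y(k, j);
  return Z;
}

// Hypothetical "mask add bias" epilogue: add bias[j], then zero out j > i (causal mask).
// It needs the (row, col) coordinate of each element, not just its value.
float func(float t, int i, int j, const std::vector<float>& bias) {
  return j > i ? 0.f : t + bias[j];
}

// Unfused reference: A is M x K, B is K x N, C is N x P, result D is M x P.
Mat reference(const Mat& A, const Mat& B, const Mat& C, const std::vector<float>& bias) {
  Mat T = matmul(A, B);                     // T = A x B   (M x N, fully materialized)
  for (int i = 0; i < T.rows; ++i)
    for (int j = 0; j < T.cols; ++j)
      T(i, j) = func(T(i, j), i, j, bias);  // T = func(T)
  return matmul(T, C);                      // D = T x C   (reduces over N)
}
```

For example, with A = [[1, 2], [3, 4]], B = identity, C = [[1], [1]] and bias = {1, 1}, func(T) is [[2, 0], [4, 5]] and D is [[2], [9]].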
I want to run all three operators (a matrix multiply, followed by a function, followed by another matrix multiply) in one kernel. I have two ideas:
- Compute T = A x B, keep the result in registers and apply func(T), then write func(T) to shared memory. Finally, compute T x C and write the result to D.
- Compute T = A x B, keep the result in registers and apply func(T), but don't write func(T) to shared memory; compute T x C directly from the registers.
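Both ideas share the same fused structure, sketched below on the CPU under the same assumed shapes and the same hypothetical bias + causal-mask `func`. Each outer iteration plays the role of one threadblock that owns BM rows of T and all N of its columns; `t_tile` is the on-chip buffer, which is shared memory in the first idea and the warps' accumulator registers in the second. The full-N requirement follows from the second GEMM reducing over N: a block holding only some columns of T could not finish its rows of D on its own. `fused_b2b`, `BM` and `t_tile` are illustrative names.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

using Mat = std::vector<float>;  // row-major; dimensions are passed separately

// Hypothetical "mask add bias" epilogue: add bias[j], then zero out j > i.
inline float func(float t, int i, int j, const std::vector<float>& bias) {
  return j > i ? 0.f : t + bias[j];
}

// Unfused reference: materializes the whole M x N intermediate T.
Mat reference(const Mat& A, const Mat& B, const Mat& C,
              const std::vector<float>& bias, int M, int K, int N, int P) {
  Mat T(M * N, 0.f), D(M * P, 0.f);
  for (int i = 0; i < M; ++i)
    for (int k = 0; k < K; ++k)
      for (int n = 0; n < N; ++n) T[i * N + n] += A[i * K + k] * B[k * N + n];
  for (int i = 0; i < M; ++i)
    for (int n = 0; n < N; ++n) T[i * N + n] = func(T[i * N + n], i, n, bias);
  for (int i = 0; i < M; ++i)
    for (int n = 0; n < N; ++n)
      for (int p = 0; p < P; ++p) D[i * P + p] += T[i * N + n] * C[n * P + p];
  return D;
}

// Fused sketch: each outer iteration stands in for one threadblock that owns
// BM rows of T and ALL N columns, because T x C reduces over N.
Mat fused_b2b(const Mat& A, const Mat& B, const Mat& C,
              const std::vector<float>& bias, int M, int K, int N, int P, int BM) {
  Mat D(M * P, 0.f);
  Mat t_tile(BM * N);  // on-chip buffer: shared memory (idea 1) or registers (idea 2)
  for (int m0 = 0; m0 < M; m0 += BM) {
    const int rows = std::min(BM, M - m0);
    std::fill(t_tile.begin(), t_tile.end(), 0.f);
    // GEMM 0: t_tile = A[m0 : m0+rows, :] x B
    for (int r = 0; r < rows; ++r)
      for (int k = 0; k < K; ++k)
        for (int n = 0; n < N; ++n)
          t_tile[r * N + n] += A[(m0 + r) * K + k] * B[k * N + n];
    // Epilogue applied to the on-chip tile; T is never written to global memory.
    for (int r = 0; r < rows; ++r)
      for (int n = 0; n < N; ++n)
        t_tile[r * N + n] = func(t_tile[r * N + n], m0 + r, n, bias);
    // GEMM 1: D[m0 : m0+rows, :] = t_tile x C
    for (int r = 0; r < rows; ++r)
      for (int n = 0; n < N; ++n)
        for (int p = 0; p < P; ++p)
          D[(m0 + r) * P + p] += t_tile[r * N + n] * C[n * P + p];
  }
  return D;
}
```

In the first idea `t_tile` lives in shared memory: with, say, BM = 128 and N = 256 in fp16 it takes 128 x 256 x 2 B = 64 KiB, written once by the epilogue and read back by the second GEMM. That traffic is small next to writing T to global memory and reading it back; the tighter constraints are usually the buffer's size and the barrier needed between the two GEMMs.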
The first idea is easy to understand, but how much does it cost to write func(T) to shared memory and read it back for T x C? For the second idea, how can I make sure that the func(T) values each thread holds in registers are exactly the ones it needs for T x C?
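For the second idea with tensor-core MMAs, the question comes down to fragment layouts. The host-side C++ model below transcribes, as I read the PTX ISA fragment diagrams, the per-lane layout of the f32 accumulator of `mma.m16n8k16` and of its f16 A operand, then checks that two adjacent 16x8 accumulator tiles of T hold, in every thread, exactly the elements that thread needs as the 16x16 A operand of the next MMA. This is not device code, and the layouts should be verified against the PTX ISA before relying on them; `acc_coord`, `a_coord` and the check functions are hypothetical helper names.

```cpp
#include <cassert>
#include <set>
#include <utility>

using Coord = std::pair<int, int>;  // (row, col) inside an MMA tile

// Accumulator fragment of mma.m16n8k16 (f32 accumulators), as read from the PTX ISA:
// each lane holds 4 values; value i sits at row = group + 8*(i/2), col = 2*tig + i%2,
// where group = lane / 4 and tig = lane % 4 ("thread in group").
Coord acc_coord(int lane, int i) {
  const int group = lane >> 2, tig = lane & 3;
  return {group + 8 * (i >> 1), 2 * tig + (i & 1)};
}

// A-operand fragment of mma.m16n8k16 (f16 inputs): each lane holds 8 values;
// value i sits at row = group + 8*((i/2)%2), col = 2*tig + i%2 + 8*(i/4).
Coord a_coord(int lane, int i) {
  const int group = lane >> 2, tig = lane & 3;
  return {group + 8 * ((i >> 1) & 1), 2 * tig + (i & 1) + 8 * (i >> 2)};
}

// True if, in every lane, A-fragment value i of a 16x16 tile of T is exactly
// accumulator value i%4 of 16x8 tile number i/4 that this lane already holds,
// i.e. GEMM 0's registers can feed GEMM 1 with no data exchange between threads.
bool accumulators_feed_a_operand() {
  for (int lane = 0; lane < 32; ++lane)
    for (int i = 0; i < 8; ++i) {
      Coord acc = acc_coord(lane, i % 4);
      acc.second += 8 * (i / 4);  // the second 16x8 accumulator tile starts at column 8
      if (acc != a_coord(lane, i)) return false;
    }
  return true;
}

// Sanity check on a_coord: 32 lanes x 8 values cover the 16x16 tile exactly once.
bool a_fragment_covers_tile() {
  std::set<Coord> seen;
  for (int lane = 0; lane < 32; ++lane)
    for (int i = 0; i < 8; ++i) {
      const Coord c = a_coord(lane, i);
      if (c.first < 0 || c.first >= 16 || c.second < 0 || c.second >= 16) return false;
      seen.insert(c);
    }
  return seen.size() == 256u;
}
```

Under these layouts, a warp can apply func to each accumulator (using `acc_coord` plus the warp's tile offset to get the global (row, col) a mask needs), pack adjacent pairs into f16x2, and pass the registers of two neighbouring 16x8 tiles as the A operand of the next `mma.m16n8k16`, with C loaded as the B operand. The check only covers one warp's 16-row slice: the warp tile of the first GEMM must still span all N columns of T, otherwise part of the reduction over N needs values held by other warps, and those have to go through shared memory as in the first idea.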