Merge `ranked` and `shaped` (as in `ADReady`) into `tensor` indexed by `[Maybe Nat]`
This is @awf's idea, which would improve the end-user API. A `Nothing` entry would signify a dimension whose size may differ at runtime. Cognitive load would be reduced (no more `ranked` and `shaped` in signatures) and new expressiveness would be gained (tensors that are ranked on some dimensions but shaped on others).
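As a rough illustration (the `Tensor` name and representation below are invented for this sketch, not horde-ad API), a carrier indexed by `[Maybe Nat]` would mark each dimension as either statically known or free to vary at runtime:

```haskell
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE KindSignatures #-}

import GHC.TypeLits (Nat)

-- Invented placeholder carrier; the payload is elided.
newtype Tensor r (sh :: [Maybe Nat]) = Tensor ()

-- The first dimension may vary at runtime, the second is statically 784:
-- ranked on some dimensions but shaped on others.
type MnistBatch r = Tensor r '[ 'Nothing, 'Just 784 ]
```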
For this to be possible, we'd need an extension of the GHC.TypeLits.Normalise plugin family and somebody to maintain it. [Edit: an extension of orthotope would also be needed.] It would be great to find out whether any other Haskell libraries already require such functionality. Or we could move to Agda.
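For context, this is the kind of `Nat` equality that ghc-typelits-natnormalise already discharges (a standard example of the plugin, not code from this issue); the extension would have to perform the analogous solving over the richer index kind:

```haskell
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# LANGUAGE DataKinds, GADTs, KindSignatures, TypeOperators #-}

import GHC.TypeLits (Nat, type (+))

data Vec (n :: Nat) a where
  Nil  :: Vec 0 a
  Cons :: a -> Vec n a -> Vec (n + 1) a

-- Bare GHC cannot prove n + m ~ m + n; the plugin discharges it.
commuteLen :: Vec (n + m) a -> Vec (m + n) a
commuteLen v = v
```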
We'd also need good examples where the ability to vary some shapes at runtime is crucial, because in all current examples the shapes are known and fixed once they are determined at the start of a run (by looking at the input data). It's not even clear that such varying shapes would work, e.g., with our current optimizer implementations.
An example: a signature in our current formalism:
```haskell
type LayerWeightsRNNShaped shaped in_width out_width r =
  ( shaped r '[out_width, in_width]   -- input weight
  , shaped r '[out_width, out_width]  -- state weight
  , shaped r '[out_width] )           -- bias

rnnMnistTwoS
  :: forall ranked shaped out_width batch_size sizeMnistHeight r.
     ADReady ranked shaped r
  => shaped r '[2 * out_width, batch_size]
  -> PrimalOfS shaped r '[sizeMnistHeight, batch_size]
  -> ( LayerWeightsRNNShaped shaped sizeMnistHeight out_width r
     , LayerWeightsRNNShaped shaped out_width out_width r )
  -> ( shaped r '[out_width, batch_size]
     , shaped r '[2 * out_width, batch_size] )
```
And a new signature for a different program; let's assume this program requires some shapes to vary at runtime:
```haskell
type data Size = Dynamic | N Nat

type LayerWeightsRNNShaped tensor in_width out_width r =
  ( tensor r '[out_width, in_width]  -- input weight
  , tensor r '[Dynamic, out_width]   -- state weight
  , tensor r '[out_width] )          -- bias

rnnMnistTwoS
  :: forall tensor (out_width :: Size) (batch_size :: Size)
            (sizeMnistHeight :: Size) r.
     (ADReady tensor r, N 0 + batch_size ~ batch_size)
  => tensor r '[Dynamic, batch_size]
  -> PrimalOfS tensor r '[sizeMnistHeight, batch_size]
  -> ( LayerWeightsRNNShaped tensor sizeMnistHeight out_width r
     , LayerWeightsRNNShaped tensor out_width out_width r )
  -> ( tensor r '[Dynamic, batch_size]
     , tensor r '[N 2 * out_width, batch_size] )
```
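The `N 0 + batch_size ~ batch_size` constraint above presupposes arithmetic lifted from `Nat` to `Size`. A minimal sketch of such a lifting (an assumption for illustration, not a design settled in this issue) is a closed type family; GHC can reduce it on concrete constructors, but proving `N 0 + b ~ b` for an abstract `b` is exactly the kind of reasoning the extended plugin would have to supply:

```haskell
-- TypeData needs GHC >= 9.6.
{-# LANGUAGE DataKinds, TypeData, TypeFamilies, TypeOperators #-}

import qualified GHC.TypeLits as TL

type data Size = Dynamic | N TL.Nat

-- Addition lifted to Size; a Dynamic operand makes the result Dynamic.
type family (a :: Size) + (b :: Size) :: Size where
  Dynamic + _ = Dynamic
  _ + Dynamic = Dynamic
  N m + N n = N (m TL.+ n)

-- 'N 2 * out_width' would need multiplication lifted analogously.
```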