ocaml-torch
Use a GADT to add type constraints for tensor elements
The current tensor type is Tensor.t. However, tensors can hold several kinds of elements, and calling a function like to_float0_exn on a tensor containing integers is likely to raise.
We could try adding some type information to Tensor.t, in the same way as is done for bigarray or for the TensorFlow tensors in tensorflow-ocaml. This would involve a type like 'a Tensor.t, where 'a is the type of the underlying elements.
The functions could then have types along these lines:
type 'a t
type 'a kind
val float_kind : float kind
val int_kind : int kind
val create : 'a kind -> shape:int list -> 'a t
val to_elem0_exn : 'a t -> 'a
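To make the GADT part concrete, here is a minimal, self-contained sketch of how such a signature might be implemented. It is a toy model only: the tensor is backed by a plain OCaml array rather than a libtorch tensor, and the names Typed_tensor, kind_name and zero are hypothetical and not part of ocaml-torch.

```ocaml
(* Toy model: the real Tensor.t wraps a C++ tensor, not an OCaml array. *)
module Typed_tensor : sig
  type 'a t
  type 'a kind

  val float_kind : float kind
  val int_kind : int kind
  val create : 'a kind -> shape:int list -> 'a t
  val kind_name : 'a t -> string
  val to_elem0_exn : 'a t -> 'a
end = struct
  (* The GADT: each constructor fixes the element type parameter. *)
  type 'a kind =
    | Float : float kind
    | Int : int kind

  type 'a t = { kind : 'a kind; shape : int list; data : 'a array }

  let float_kind = Float
  let int_kind = Int

  (* Matching on the kind refines the type, so each branch returns its own zero. *)
  let zero : type a. a kind -> a = function
    | Float -> 0.
    | Int -> 0

  (* Zero-filled tensor; an empty shape denotes a scalar holding one element. *)
  let create (type a) (kind : a kind) ~shape : a t =
    let size = List.fold_left ( * ) 1 shape in
    { kind; shape; data = Array.make size (zero kind) }

  (* The kind is still available at runtime for dispatch. *)
  let kind_name (type a) (t : a t) =
    match t.kind with
    | Float -> "float"
    | Int -> "int"

  (* Still raises on a shape mismatch, but no longer on an element-type mismatch. *)
  let to_elem0_exn t =
    match t.shape with
    | [] -> t.data.(0)
    | _ -> invalid_arg "to_elem0_exn: expected a zero-dimensional tensor"
end

let scalar = Typed_tensor.(create float_kind ~shape:[])
let x : float = Typed_tensor.to_elem0_exn scalar
(* let y : int = Typed_tensor.to_elem0_exn scalar   is rejected at compile time *)
```

With this encoding, reading an integer out of a float tensor becomes a type error instead of a runtime exception, which is the point of the proposal. The shape check remains dynamic.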
The wrapper code for tensor operations is automatically generated from Declarations.yaml (this file is itself generated when compiling PyTorch). It describes all the operations but does not provide much type information. Some can still be recovered: IndexTensor is used for tensors holding integers, multiple tensors involved in the same op should have the same element type, etc.
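As a rough illustration of what the generator could emit from those two hints, here is a toy sketch restricted to 1-D tensors backed by OCaml arrays. The module name Typed_ops, the helpers of_array and to_array, and these exact signatures are assumptions made for the example; they are not the generated ocaml-torch API.

```ocaml
(* Hypothetical typed signatures; toy 1-D implementation over OCaml arrays. *)
module Typed_ops : sig
  type 'a t

  val of_array : 'a array -> 'a t
  val to_array : 'a t -> 'a array

  (* All tensors taking part in the same op share one element type. *)
  val cat : 'a t list -> 'a t

  (* An IndexTensor argument becomes an [int t], whatever the data type is. *)
  val index_select : 'a t -> index:int t -> 'a t
end = struct
  type 'a t = 'a array

  let of_array a = Array.copy a
  let to_array t = Array.copy t
  let cat ts = Array.concat ts

  (* Picks the elements of [t] at the positions listed in [index]. *)
  let index_select t ~index = Array.map (fun i -> t.(i)) index
end

let data = Typed_ops.of_array [| 1.; 2.; 3. |]
let index = Typed_ops.of_array [| 2; 0 |]
let picked = Typed_ops.(to_array (index_select data ~index))
(* Typed_ops.cat [ data; index ] is rejected: float t and int t do not unify. *)
```

In this sketch, passing a float tensor as the index, or concatenating tensors of different element types, fails at compile time rather than inside the C++ library.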
There is some ongoing work on cleaning up Declarations.yaml, which is also likely to help: https://github.com/pytorch/pytorch/issues/12562.