
Support TensorModule of distinct input/output types

Open · davoclavo opened this issue 2 years ago · 2 comments

Currently TensorModule is parametrized on a single type, so it keeps the transformation within the same DType:

trait TensorModule[D <: DType] extends Module with (Tensor[D] => Tensor[D]):
  override def toString(): String = "TensorModule"
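
For illustration, here is the constraint this implies (runLayer is a hypothetical helper, assuming the trait above and storch's Tensor/DType are in scope):

// With a single type parameter, a module is just a Tensor[D] => Tensor[D],
// so the input and output dtypes of a layer are forced to match.
def runLayer[D <: DType](layer: TensorModule[D], input: Tensor[D]): Tensor[D] =
  layer(input)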

However, there are modules where the input type might be different from the output type, such as nn.Embedding, which accepts Int or Long values as input indices, while the output could be any DType. So the trait would need to be parametrized like this:

trait TensorModule[D <: DType, D2 <: DType] extends Module with (Tensor[D] => Tensor[D2]):
  override def toString(): String = "TensorModule"

and an example implementation would look something like:

final class Embedding[ParamType <: FloatNN | ComplexNN: Default](
    numEmbeddings: Int,
    ...
) extends ...
    with TensorModule[IntNN, ParamType]:
    
    def apply(t: Tensor[IntNN]): Tensor[ParamType] = Tensor(nativeModule.forward(t.native))

This is doable; however, there are useful operators on TensorModule, such as nn.Sequential, which expects an array of modules to chain. With a single type parameter, the compile-time validation is straightforward, but with distinct input/output types it seems to get more complex to validate at compile time.
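
To make the trade-off concrete, here is a rough sketch (not the real nn.Sequential, just a simplified stand-in assuming the single-parameter trait above) of why chaining is trivial when every module keeps the same DType:

// Every stage is a Tensor[D] => Tensor[D], so a plain fold over any number of
// modules type-checks without further machinery.
def chain[D <: DType](modules: Seq[TensorModule[D]])(input: Tensor[D]): Tensor[D] =
  modules.foldLeft(input)((acc, m) => m(acc))

// With distinct input/output types, each stage's output dtype must match the
// next stage's input dtype, which a homogeneous Seq can no longer express.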

I will do some research on how to solve this.

Any pointers or ideas are more than welcome!

davoclavo · Jun 27 '23 21:06

After doing a bit of research, it seems like the `nn.Embedding` module is an exception to the norm. There may be other types of PyTorch modules that change dtype from input to output, but I am not yet aware of them (perhaps `nn.Upsample`, but in practice I couldn't get it to change types).

So, a couple of thoughts:

  1. An additional `TensorModuleDistinct[D <: DType, D2 <: DType]` must be added if we want to have the `Embedding` module (see the sketch after this list)

  2. Regarding `nn.Sequential` - here are the options I can think of:
     a) It would have to accept a Tuple/HList of tensors in order to have all the proper type validations if we wanted to also accept these Modules
     b) Adding an extra optional parameter which allows for an optional initial module which can do type transformations
     c) Ignore this for now and let the user know that they should handle Embedding modules as an extra step when defining their model's layer structure
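
A rough sketch of what the additional trait from point 1 could look like (the name TensorModuleDistinct and its exact shape are just a proposal, not final API):

// Hypothetical sketch only. The existing single-type TensorModule could stay
// as-is, or potentially become an alias for TensorModuleDistinct[D, D].
trait TensorModuleDistinct[D1 <: DType, D2 <: DType]
    extends Module
    with (Tensor[D1] => Tensor[D2]):
  override def toString(): String = "TensorModuleDistinct"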

davoclavo · Jun 27 '23 23:06

> After doing a bit of research, it seems like the `nn.Embedding` module is an exception to the norm. There may be other types of PyTorch modules that change dtype from input to output, but I am not yet aware of them (perhaps `nn.Upsample`, but in practice I couldn't get it to change types).

That's good to know. One thing to keep in mind is that the included models are mostly basic building blocks, and I'm not sure what input/output structures and types look like with more complex custom modules. That's another good reason to implement a few more architectures like transformers, to get a better feeling for it. :)

> 1. An additional `TensorModuleDistinct[D <: DType, D2 <: DType]` must be added if we want to have the `Embedding` module

:+1:

> 2. Regarding `nn.Sequential` - here are the options I can think of:
>    a) It would have to accept a Tuple/HList of tensors in order to have all the proper type validations if we wanted to also accept these Modules
>    b) Adding an extra optional parameter which allows for an optional initial module which can do type transformations
>    c) Ignore this for now and let the user know that they should handle Embedding modules as an extra step when defining their model's layer structure

It would be interesting to see if it's possible to create an easy-to-use API for a).
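
One rough idea for the shape such an API could take, sketched via pairwise composition rather than a literal Tuple/HList (all names below are hypothetical):

// Each appended stage must accept the previous stage's output dtype, so a
// mismatched chain fails to compile.
final class Chain[In <: DType, Out <: DType](run: Tensor[In] => Tensor[Out]):
  def apply(t: Tensor[In]): Tensor[Out] = run(t)
  def andThen[Next <: DType](next: Tensor[Out] => Tensor[Next]): Chain[In, Next] =
    Chain(t => next(run(t)))

// Usage idea (illustrative): an Embedding-like first stage may change the
// dtype, later stages keep it.
// val model = Chain(embedding).andThen(linear).andThen(activation)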

One general thought regarding the parameter types of modules is that we need to consider their recursive structure and mutability. Here's what I said in another discussion about that:

> In a PyTorch module (which is also mutable) you can convert the parameter types of all submodules recursively. Now if we have a `val` module `myModule: MyModule[Float32]` and call `myModule.to(dtype=torch.float16)`, we break type-safety (I think).

So perhaps we need to provide an immutable module API and make sure to copy the module and all its submodules recursively to make this safe, or perhaps you have another idea for how to deal with it.
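
For what it's worth, one direction for an immutable variant might look roughly like this (TypedModule and everything around it are hypothetical names, not storch API):

// Instead of mutating in place like PyTorch's module.to(dtype), converting the
// dtype returns a new, recursively copied module whose dtype is visible in the
// type, so the original value keeps its advertised parameter type.
trait TypedModule[D <: DType]:
  def parameters: Seq[Tensor[D]]
  // must copy this module and all submodules/parameters recursively
  def to[D2 <: DType](dtype: D2): TypedModule[D2]

// val m32: TypedModule[Float32] = ???
// val m16: TypedModule[Float16] = m32.to(float16) // new value; m32 is unchanged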

sbrunk · Jun 28 '23 20:06