Batched equivariant maps basis expansion (?)
Hi @Gabri95,
Do you see any easy way to enable a batched basis expansion of equivariant maps? Here is what I mean.
The process of constructing a linear equivariant map `T` from an array of weights `w` (of the same dimension as the basis of `T`) seems to be tailored to a single weight vector and a single resulting equivariant linear map. This is perfectly suitable for building linear layers from the basis.
However, it is not suitable for parametrically building several equivariant linear maps from a batched collection of weights of shape `(batch, dim(w))`, resulting in `batch` equivariant linear operators. I tried my best to understand and devise a way to do this, but with the current implementation it seems rather tricky.
When using the EMLP library, this was possible by finding the nullspace basis matrix `Q` of shape `[n*n, basis_dim]`, which can be used to project several weight vectors to their corresponding equivariant linear matrices via `T = reshape(Q w)`. This process has an immense memory cost (because of the `n*n` factor, `n` being the dimension of `T`, assuming a square `T`). I understand your approach elegantly avoids this memory problem. Can you think of a way to make a batched version of your basis expansion?
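For concreteness, here is a minimal sketch of the EMLP-style expansion I mean, with placeholder tensors (`Q` stands for the nullspace basis of the equivariance constraint; all names and shapes are purely illustrative):

```python
import torch

n, d, batch = 6, 4, 32            # map size, basis dimension, number of maps
Q = torch.randn(n * n, d)         # placeholder for the actual nullspace basis
w = torch.randn(batch, d)         # batched weight vectors

# Projecting all weight vectors at once and reshaping yields the batched maps,
# at the cost of materializing the (n*n, d) matrix Q.
T = (w @ Q.T).reshape(batch, n, n)
```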
Hi @Danfoa,
The single-block basis expansion and sampler classes could be used for that. That's actually what I do internally in the `BlockBasisSampler` class, for example.
The external interface of the library (via the conv layers) does not directly support this, though. Could you maybe provide a more detailed example of what you'd like to do, so I can suggest something more concrete or try to write some example code?
For instance, do you need to compute a number of convolution kernels for an `RdConv`, or do you want to run multiple `RdPointConv` in parallel? Or are you only interested in `Linear` layers?
Best, Gabriele
This sounds amazing, thanks for the help!
Let me describe my application case.
TLDR: I want to construct multiple equivariant linear maps `T` of shape `[n x n]`. We know that the basis of `T` has dimension `d`. I don't want to learn this map directly; instead:
- I want to learn a function `T(.): X -> R^d` that parameterizes the linear maps `T(x) ∈ R^(n x n)` as a function of their input `x` of shape `(batch, |x|)`.
- The output of the network, of shape `(batch, d)`, will be used to parameterize `batch` distinct equivariant maps, resulting in a tensor of shape `(batch, n, n)`.
- Then, I would like to apply each of the linear maps to its corresponding input vector (see the sketch after this list).
More details: I am learning equivariant dynamical systems with transition operators. The nice thing about this approach is that if you find the appropriate non-linear change of coordinates `x = f(z)`, the dynamics of your system become linear, `dx/dt = T(x) x` with `T(x) ∈ R^(n x n)`, instead of the potentially non-linear dynamics of `z`. Here, think of `z` as the state of your dynamical system (e.g., position and momentum) and `x` as a new "observable" state (e.g., a set of relevant functions of `z`, such as energy, polynomials, etc.). For equivariant systems, `T(x)` needs to be an equivariant linear map. And here is where I need to learn the function `T(.): X -> R^d`, where `d` is the dimension of the space of equivariant endomorphisms `X -> X`. This is why your basis expansion has become so useful to me.
Hi @Danfoa
That sounds like a really cool application!
So, if you know the size of `batch` in advance, the simplest strategy you can use now is to generate a linear map of shape `batch*n x n` and then reshape it into `(batch, n, n)`.
You can just use a `BlockBasisExpansion` for expanding these weights.
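Roughly, the idea is the following (a sketch: for simplicity it goes through `escnn.nn.Linear`, which uses the block basis expansion internally, rather than instantiating a `BlockBasisExpansion` directly; `expand_parameters()` is assumed to return the expanded matrix and bias, and a small cyclic group with its regular representation is used just for illustration):

```python
import torch
from escnn import gspaces, nn as enn
from escnn.group import cyclic_group

batch = 8
G = cyclic_group(4)
gspace = gspaces.no_base_space(G)

in_type = enn.FieldType(gspace, [G.regular_representation])            # size n = 4
out_type = enn.FieldType(gspace, [G.regular_representation] * batch)   # size batch * n

linear = enn.Linear(in_type, out_type, bias=False)
matrix, _ = linear.expand_parameters()      # shape (batch*n, n)

n = in_type.size
T = matrix.view(batch, n, n)                # batch equivariant n x n maps
```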
I can make something a bit more flexible to achieve exactly what you want by removing this assert and just using the last dimension of `weights`.
I am not sure I have time to implement it properly right now, but you could try that yourself and open a PR maybe?
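Schematically, the generalization amounts to treating the leading dimensions of `weights` as batch dimensions in the expansion (the names below are only illustrative, not escnn's actual internals):

```python
import torch

def expand(weights: torch.Tensor, sampled_basis: torch.Tensor) -> torch.Tensor:
    # sampled_basis: (d, o, i) stack of basis maps; weights: (..., d).
    # Instead of asserting that weights is one-dimensional, only its last
    # dimension is contracted with the basis.
    return torch.einsum('...d,doi->...oi', weights, sampled_basis)
```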
I can certainly try @Gabri95,
> So, if you know the size of `batch` in advance, the simplest strategy you can use now is to generate a linear map of shape `batch*n x n` and then reshape it into `(batch, n, n)`. You can just use a `BlockBasisExpansion` for expanding these weights.
I know the batch dimension, but I am a bit unsure about how to interact with the `BlockBasisExpansion`. The code is a bit hard to digest without investing a large amount of time in it. Any hints?
> I can make something a bit more flexible to achieve exactly what you want by removing this assert and just using the last dimension of `weights`. I am not sure I have time to implement it properly right now, but you could try that yourself and open a PR maybe?
I will give it a try. I think I already see the problem. It should not be difficult.