
changing the `embedding_dim` of `SemiSupervisedNet` induces an error

Open yuanqing-wang opened this issue 4 years ago • 2 comments

Hi,

It looks like specifying any `embedding_dim` other than 64 results in an error. For example:

```python
import torch
import pinot

# Load the Moonshot dataset and batch it into a single batch
ds = pinot.data.moonshot()
ds = pinot.data.utils.batch(ds, len(ds))
ds = pinot.data.datasets.Dataset(ds)
g, y = next(iter(ds))

# GraphConv-based representation with config [32, 'tanh']
layer = pinot.representation.dgl_legacy.gn(model_name='GraphConv')
representation = pinot.representation.Sequential(
    layer=layer,
    config=[32, 'tanh'],
)

# Any embedding_dim other than 64 triggers the error below
net = pinot.generative.semi_supervised_net.SemiSupervisedNet(
    representation,
    embedding_dim=32,
)
net.loss(g, y)
```

gives me:

```
RuntimeError                              Traceback (most recent call last)
<ipython-input-7-a5dcd9617e3a> in <module>
     15     embedding_dim=32,
     16 )
---> 17 net.loss(g, y)

~/Documents/GitHub/pinot/pinot/generative/semi_supervised_net.py in loss(self, g, y)
    150         h = self.representation.forward(g, pool=None)  # We always call this
    151         # Compute unsupervised loss
--> 152         total_loss = self.loss_unsupervised(g, h) * self.unsup_scale
    153         # Compute the graph representation from node representation
    154         # Then compute graph representation, by pooling

~/Documents/GitHub/pinot/pinot/generative/semi_supervised_net.py in loss_unsupervised(self, g, h)
    228         # Compute the ELBO loss
    229         # First the reconstruction loss (~~ negative expected log likelihood)
--> 230         recon_loss = self.decoder.decode_and_compute_recon_error(g, z_sample)
    231         # KL-divergence term
    232         KLD = (

~/Documents/GitHub/pinot/pinot/generative/decoder.py in decode_and_compute_recon_error(self, g, z_sample)
    293                     negative expected likelihood term in the ELBO
    294         """
--> 295         decoded_subgraphs = self.forward(g, z_sample)
    296         gs_unbatched = dgl.unbatch(g)
    297         assert len(decoded_subgraphs) == len(gs_unbatched)

~/Documents/GitHub/pinot/pinot/generative/decoder.py in forward(self, g, z_sample)
    411             # Decode each subgraph
    412             decoded_subgraphs = [
--> 413                 self.decode(g_sample.ndata["h"]) for g_sample in gs_unbatched
    414             ]
    415             return decoded_subgraphs

~/Documents/GitHub/pinot/pinot/generative/decoder.py in <listcomp>(.0)
    411             # Decode each subgraph
    412             decoded_subgraphs = [
--> 413                 self.decode(g_sample.ndata["h"]) for g_sample in gs_unbatched
    414             ]
    415             return decoded_subgraphs

~/Documents/GitHub/pinot/pinot/generative/decoder.py in decode(self, z)
    377 
    378         # This has shape (n, n, 2*self.Dx2)
--> 379         e_tensor = temp.view(n, n, 2 * h)
    380 
    381         # e_tensor -> E_tilde

RuntimeError: shape '[29, 29, 128]' is invalid for input of size 80736
```
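The numbers in the error are consistent with the two concatenated halves having different widths: 80736 = 29 × 29 × (64 + 32), whereas the `view` asks for 29 × 29 × 128 = 107648 elements. A minimal sketch in plain PyTorch that reproduces the mismatch (the dims 29, 64, and 32 are read off the traceback here, not taken from pinot internals):

```python
import torch

n, d_zx, d_z = 29, 64, 32   # inferred from the traceback, not verified
zx = torch.randn(n, d_zx)
z = torch.randn(n, d_z)

temp1 = zx.repeat(1, n).view(n * n, d_zx)  # (841, 64)
temp2 = z.repeat(n, 1)                     # (841, 32)
temp = torch.cat((temp1, temp2), 1)        # (841, 96) -> 841 * 96 = 80736 elements

# decode() then assumes both halves share the same width h = 64 and calls
# temp.view(n, n, 2 * h), which would need 841 * 128 = 107648 elements:
try:
    temp.view(n, n, 2 * d_zx)
except RuntimeError as e:
    print(e)  # shape '[29, 29, 128]' is invalid for input of size 80736
```

This suggests the concatenation produces width `d_zx + embedding_dim`, while `decode` hard-codes `2 * h`, so the two only agree when `embedding_dim` happens to equal 64.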

yuanqing-wang · Dec 13 '20

Trying to figure out what's going on here. In this line,

https://github.com/choderalab/pinot/blob/8a03838903c2ac3d09b3d7538040d3a8a33f32ae/pinot/generative/decoder.py#L372

Shouldn't it be something like:

```python
temp2 = zx.repeat(n, 1).view(n * n, h)
```
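For reference, that change does make the shapes consistent: both halves then have width `h`, so the later `temp.view(n, n, 2 * h)` succeeds. Note, though, that it pairs `zx` with itself rather than with `z`. A shape check only, assuming `zx` has shape `(n, h)`:

```python
import torch

n, h = 29, 64
zx = torch.randn(n, h)

temp1 = zx.repeat(1, n).view(n * n, h)  # row i*n + j holds zx[i]
temp2 = zx.repeat(n, 1).view(n * n, h)  # row i*n + j holds zx[j]
temp = torch.cat((temp1, temp2), 1)     # (n*n, 2h)
e_tensor = temp.view(n, n, 2 * h)       # valid: 841 * 128 elements
assert e_tensor.shape == (n, n, 2 * h)
```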

yuanqing-wang · Dec 13 '20

@yuanqing-wang The goal here was to create an `(n, n, 2 * Dx2)` tensor `temp` such that each `temp[i, j, :]` is the concatenation of the vector representations of nodes `i` and `j`. The hope was that the neural network could use information from the representations of both nodes to predict whether there is an edge between them.

https://github.com/choderalab/pinot/blob/8a03838903c2ac3d09b3d7538040d3a8a33f32ae/pinot/generative/decoder.py#L371

```python
        # zx, Atilde -> E_tilde
        temp1 = zx.repeat(1, n).view(n * n, h)  # Shape is (n * n, Dx2)
        temp2 = z.repeat(n, 1)  # Shape should also be (n * n, Dx2)
        temp = torch.cat(
            (temp1, temp2), 1
        )  # An (n * n, 2 * Dx2) tensor, later viewed as (n, n, 2 * Dx2) <-- this is what we ultimately want
```
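An equivalent way to build that pairwise tensor is with broadcasting, which stays correct even when the two representations have different widths. A sketch only, with dims and variable names made up for illustration:

```python
import torch

n, d_zx, d_z = 29, 64, 32
zx = torch.randn(n, d_zx)
z = torch.randn(n, d_z)

# e_tensor[i, j] = cat(zx[i], z[j]); shape (n, n, d_zx + d_z)
e_tensor = torch.cat(
    (
        zx.unsqueeze(1).expand(n, n, d_zx),  # [i, j] -> zx[i]
        z.unsqueeze(0).expand(n, n, d_z),    # [i, j] -> z[j]
    ),
    dim=-1,
)
assert e_tensor.shape == (n, n, d_zx + d_z)
```

This avoids the hard-coded `view(n, n, 2 * h)`; any downstream layer would then take `d_zx + d_z` as its input width instead of assuming the two halves match.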

dnguyen1196 · Dec 13 '20