
Adding DGL and DGL-LifeSci support


Hi,

I have been working with skorch for standard NN workflows on tensors and recently started experimenting with some graph neural networks, in particular dgl and dgl-LifeSci. I have a workaround to get things working with skorch where I add a pass for dgl.DGLGraph() objects in the utils.py check. However, skorch throws an error when measuring the batch size, because the number of nodes varies across graph batches.

The fixes are shown in the screenshots below. Is DGL integration something currently being considered, or is there a separate recommended workaround?

[Screenshot: first proposed change]

[Screenshot: second proposed change]

Fixing these two allows me to train dgl-lifesci models using skorch by defining my own dataloader and dataset classes.
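
For reference, a minimal version of such a dataset/collate pair could look like the following (illustrative only, not my exact code; dgl.batch merges the individual graphs into one batched graph):

import dgl
import torch
from torch.utils.data import Dataset


class GraphDataset(Dataset):
    """Keeps (DGLGraph, target) pairs as plain Python objects."""

    def __init__(self, graphs, targets):
        self.graphs = graphs
        self.targets = targets

    def __len__(self):
        return len(self.graphs)

    def __getitem__(self, idx):
        return self.graphs[idx], self.targets[idx]


def collate_graphs(samples):
    # merge the individual graphs into one batched DGLGraph and stack the targets
    graphs, targets = map(list, zip(*samples))
    return dgl.batch(graphs), torch.tensor(targets, dtype=torch.float32).reshape(-1, 1)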

Thanks!

BernardoHernandezAdame, Aug 03 '21 16:08

Hi, that's pretty cool. I haven't worked with dgl yet. What would be really helpful is a full working example (maybe even a Jupyter notebook added to the skorch repo), including your dataloader and dataset classes.

If it's not too much effort to make it work, I would gladly add better support for dgl. The way you changed (I assume) to_device wouldn't quite work, because it would fail for users who don't have dgl installed in their environment. But I have an idea how to make it easier for users to add their own types.
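
For instance (just a sketch, nothing like this exists in skorch yet), the type check could be keyed on the type's module, so that dgl never has to be imported:

def is_dgl_graph(x):
    # True for dgl graph objects without dgl ever being imported here
    return type(x).__module__.split(".")[0] == "dgl"


def to_device_with_custom_types(x, device, fallback):
    # route dgl graphs through their own .to(); defer everything else to the
    # existing handling (e.g. skorch.utils.to_device)
    if is_dgl_graph(x):
        return x.to(device)
    return fallback(x, device)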

Regarding your second problem, I would need to understand better what the issue is, which is probably best achieved through the full example mentioned above.

BenjaminBossan, Aug 03 '21 22:08

Any updates, @BernardoHernandezAdame?

BenjaminBossan, Nov 13 '21 12:11

Here is a script to reproduce. This issue is quite pressing: it affects basically all easy workarounds for using graph neural networks (PyG, DGL, etc.) with skorch. And the assertion might actually be deleterious for debugging shape issues anyway; better to let torch's linear algebra ops raise the error and turn this check into a warning, maybe?

import numpy
import skorch
import torch


class model(torch.nn.Module):
    def __init__(self, **kwargs):
        super().__init__()
        self.lin = torch.nn.Linear(10, 1)

    def forward(self, x0, x1):
        # matrix product of the two inputs (torch.Tensor.dot only handles 1-D tensors)
        z = torch.matmul(x0, x1)
        return self.lin(z)


M = skorch.NeuralNetRegressor(model)
X = {
    'x0': numpy.random.rand(1000, 25).astype(numpy.float32),
    'x1': numpy.random.rand(25, 10).astype(numpy.float32),
}
y = (X['x0'].dot(X['x1']) * numpy.random.rand(10) + numpy.random.rand(10)).sum(-1)
y = y.astype(numpy.float32).reshape(-1, 1)
# fails in skorch's length check because the dict values have lengths 1000 and 25
M.fit(X, y)

DCoupry, Nov 06 '23 12:11

I think this issue can be closed with a simple update to the documentation. The behaviour described can be circumvented by using a torch Dataset object as X, with a collate function passed to the iterators. This, plus a bit of SliceDataset wizardry, should make most scenarios solvable; I made it work for PyTorch Geometric.
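
A rough sketch of the wiring for the DGL case (illustrative only: MyGNN and GraphDataset are placeholders, dgl.batch does the graph merging, and iterator_train__collate_fn / iterator_valid__collate_fn are simply skorch's normal parameter routing to the DataLoader):

import dgl
import skorch
import torch


def collate_graphs(samples):
    # samples is a list of (DGLGraph, target) pairs coming from the dataset
    graphs, targets = map(list, zip(*samples))
    return dgl.batch(graphs), torch.stack(targets)


net = skorch.NeuralNetRegressor(
    MyGNN,                                     # placeholder module taking a batched DGLGraph
    iterator_train__collate_fn=collate_graphs,
    iterator_valid__collate_fn=collate_graphs,
    max_epochs=10,
)
net.fit(GraphDataset(graphs, targets), None)   # targets live inside the dataset, so y=None

skorch.helper.SliceDataset can then wrap the same dataset whenever an sklearn-style indexable X is needed, e.g. for GridSearchCV.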

DCoupry, Nov 07 '23 21:11