
Introduce Graph Network module

francescofarina opened this pull request on Jan 05 '24 • 6 comments

Proposed changes

This PR introduces a generic Graph Network block module, as defined in "Relational inductive biases, deep learning, and graph networks" (https://arxiv.org/pdf/1806.01261.pdf).

Some points are open for discussion:

  • I wrote a working example in the documentation for the GraphNetworkBlock in which I defined the modules for node, edge and global updates. It may be worth including them as actual modules, but I'm worried they're not generic enough (a rough sketch of the block's shape follows this list).
  • There are currently no tests, as there's not much to test in the GraphNetworkBlock itself.
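
For concreteness, here is a rough sketch of the shape such a block could take, following the edge/node/global update structure from the paper. This is not the code in this PR: the interface (optional edge, node and global update models plus sender/receiver index arrays) and the dense segment_sum helper are illustrative assumptions only.

```python
# NOTE: illustrative sketch only -- not the code from this PR.
import mlx.core as mx
import mlx.nn as nn


def segment_sum(values, segment_ids, num_segments):
    # Dense stand-in for a scatter-add: build a (num_items, num_segments)
    # one-hot matrix and reduce with a matmul. A sparse/scatter primitive
    # would be the memory-efficient way to do this.
    one_hot = (mx.expand_dims(segment_ids, 1) == mx.arange(num_segments)).astype(values.dtype)
    return mx.transpose(one_hot) @ values


class GraphNetworkBlock(nn.Module):
    """Generic GN block (Battaglia et al., 2018) with optional edge, node and
    global update models; any model left as None leaves its features unchanged."""

    def __init__(self, edge_model=None, node_model=None, global_model=None):
        super().__init__()
        self.edge_model = edge_model
        self.node_model = node_model
        self.global_model = global_model

    def __call__(self, nodes, edges, globals_, senders, receivers):
        num_nodes, num_edges = nodes.shape[0], edges.shape[0]

        # 1. Edge update: phi_e(e_k, v_receiver, v_sender, u)
        if self.edge_model is not None:
            edges = self.edge_model(mx.concatenate([
                edges,
                mx.take(nodes, receivers, axis=0),
                mx.take(nodes, senders, axis=0),
                mx.broadcast_to(globals_, (num_edges, globals_.shape[-1])),
            ], axis=-1))

        # 2. Node update: phi_v(sum of incoming edge features, v_i, u)
        if self.node_model is not None:
            aggregated = segment_sum(edges, receivers, num_nodes)
            nodes = self.node_model(mx.concatenate([
                aggregated,
                nodes,
                mx.broadcast_to(globals_, (num_nodes, globals_.shape[-1])),
            ], axis=-1))

        # 3. Global update: phi_u(mean of edges, mean of nodes, u)
        if self.global_model is not None:
            globals_ = self.global_model(mx.concatenate([
                mx.mean(edges, axis=0),
                mx.mean(nodes, axis=0),
                globals_,
            ], axis=-1))

        return nodes, edges, globals_
```

The update models here would just be small MLPs (e.g. an nn.Sequential of nn.Linear and nn.ReLU layers) whose input widths match the concatenations above; sum and mean are used as the aggregations purely for concreteness.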

Checklist

Put an x in the boxes that apply.

  • [x] I have read the CONTRIBUTING document
  • [x] I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • [ ] I have added tests that prove my fix is effective or that my feature works
  • [x] I have updated the necessary documentation (if needed)

francescofarina • Jan 05 '24

I think it is a great idea to have support for GNNs in MLX. However, I think that it is better to implement GNN-related stuff in a separate library (e.g. mlx-gnn, mlx-graph, mlx-geometric). Here are some reasons:

  • There are a lot of domain-specific concepts related to GNNs. As soon as GNNs are added to MLX, people will likely want to implement graph batching and memory-efficient aggregations. Similarly, there are a lot of neighborhood aggregation layers for graphs, such as GAT, GATv2, GIN... These are commonly used in graph ML, but are probably too specific to implement in MLX.
  • Efficient implementations of GNN layers often require support for sparse tensors and sparse linear algebra. I am not sure this is currently implemented in MLX.

gboduljak • Jan 06 '24

When it comes to tests, you could test whether neighborhood aggregation is correctly implemented.
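
For instance, a minimal test on a tiny graph could compare a sum aggregation against hand-computed values. This is purely illustrative: segment_sum below is a dense stand-in for whatever aggregation op ends up being exposed, not an existing MLX or PR API.

```python
import mlx.core as mx


def segment_sum(values, segment_ids, num_segments):
    # Stand-in for the aggregation op under test (dense one-hot matmul).
    one_hot = (mx.expand_dims(segment_ids, 1) == mx.arange(num_segments)).astype(values.dtype)
    return mx.transpose(one_hot) @ values


def test_neighborhood_sum_aggregation():
    # Three edges: two point at node 0, one at node 2; node 1 receives nothing.
    edge_feats = mx.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
    receivers = mx.array([0, 0, 2])

    aggregated = segment_sum(edge_feats, receivers, num_segments=3)

    # Hand-computed sums per receiving node.
    expected = mx.array([[4.0, 6.0], [0.0, 0.0], [5.0, 6.0]])
    assert mx.allclose(aggregated, expected)
```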

gboduljak • Jan 06 '24

> I think it is a great idea to have support for GNNs in MLX. However, I think that it is better to implement GNN-related stuff in a separate library (e.g. mlx-gnn, mlx-graph, mlx-geometric). Here are some reasons:
>
>   • There are a lot of domain-specific concepts related to GNNs. As soon as GNNs are added to MLX, people will likely want to implement graph batching and memory-efficient aggregations. Similarly, there are a lot of neighborhood aggregation layers for graphs, such as GAT, GATv2, GIN... These are commonly used in graph ML, but are probably too specific to implement in MLX.
>   • Efficient implementations of GNN layers often require support for sparse tensors and sparse linear algebra. I am not sure this is currently implemented in MLX.

You're right, but I'm in two minds on that. It's true that in all other frameworks GNNs are implemented in external libraries, so one could think of continuing that trend. However, I personally still find it weird having to rely on external packages to use GNNs in other frameworks and I believe there are historical reasons for that. Nowadays GNNs are first-class citizens in ML and it would be a bit counterintuitive to not have them implemented in the core package. The way I see it, sparse tensors, memory-efficient aggregation ops, scatter ops, etc., are pretty widely used beyond pure GNN applications.

Since we're rightfully including modern architectural layers in the core, like transformers, I think it could make sense to have GNNs in there too.

francescofarina • Jan 06 '24

> When it comes to tests, you could test whether neighborhood aggregation is correctly implemented.

Yes, that makes sense, but right now that part is not exposed (it's only used in an example in the docstring). If we take it out, then absolutely yes. Perhaps that could be done in another PR where we implement more graph ops.

francescofarina • Jan 06 '24

> I think it is a great idea to have support for GNNs in MLX. However, I think that it is better to implement GNN-related stuff in a separate library (e.g. mlx-gnn, mlx-graph, mlx-geometric). Here are some reasons:
>
>   • There are a lot of domain-specific concepts related to GNNs. As soon as GNNs are added to MLX, people will likely want to implement graph batching and memory-efficient aggregations. Similarly, there are a lot of neighborhood aggregation layers for graphs, such as GAT, GATv2, GIN... These are commonly used in graph ML, but are probably too specific to implement in MLX.
>   • Efficient implementations of GNN layers often require support for sparse tensors and sparse linear algebra. I am not sure this is currently implemented in MLX.
>
> You're right, but I'm in two minds on that. It's true that in all other frameworks GNNs are implemented in external libraries, so one could think of continuing that trend. However, I personally still find it weird having to rely on external packages to use GNNs in other frameworks and I believe there are historical reasons for that. Nowadays GNNs are first-class citizens in ML and it would be a bit counterintuitive to not have them implemented in the core package. The way I see it, sparse tensors, memory-efficient aggregation ops, scatter ops, etc., are pretty widely used beyond pure GNN applications.
>
> Since we're rightfully including modern architectural layers in the core, like transformers, I think it could make sense to have GNNs in there too.

Those are really good arguments. I agree about historical reasons.

gboduljak • Jan 06 '24

It makes sense to close this, given the discussion in https://github.com/ml-explore/mlx/pull/398.

gboduljak • Jan 08 '24

Yes, closing now.

francescofarina • Jan 09 '24