Introduce Graph Network module
Proposed changes
This PR introduces a generic Graph Network block module, as defined in https://arxiv.org/pdf/1806.01261.pdf.
Some points are open for discussion:
- I wrote a working example in the documentation for the `GraphNetworkBlock`, in which I defined the modules for node, edge, and global updates. It may be worth including them as actual modules, but I'm worried they're not generic enough.
- There are currently no tests, as there's not much to test in the `GraphNetworkBlock` itself.
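For context, the three update steps of a GN block from the paper (edge update, node update, global update) can be sketched as below. This is an illustrative NumPy sketch under assumed names and shapes (`gn_block`, `phi_e`, `phi_v`, `phi_u` are hypothetical), not the PR's actual API:

```python
# Minimal sketch of one Graph Network block pass (Battaglia et al., 2018).
# Names, shapes, and the sum/mean aggregations are illustrative choices.
import numpy as np

def gn_block(nodes, edges, senders, receivers, globals_,
             phi_e, phi_v, phi_u):
    """Apply the edge, node, and global update functions once."""
    E, N = edges.shape[0], nodes.shape[0]

    # 1) Edge update: each edge sees its features, both endpoints, globals.
    u_e = np.broadcast_to(globals_, (E, globals_.shape[-1]))
    edges_new = phi_e(np.concatenate(
        [edges, nodes[senders], nodes[receivers], u_e], axis=-1))

    # 2) Aggregate updated edges into their receiver nodes (scatter-sum).
    agg_e = np.zeros((N, edges_new.shape[-1]))
    np.add.at(agg_e, receivers, edges_new)

    # 3) Node update: node features, aggregated incoming edges, globals.
    u_v = np.broadcast_to(globals_, (N, globals_.shape[-1]))
    nodes_new = phi_v(np.concatenate([nodes, agg_e, u_v], axis=-1))

    # 4) Global update: globals, pooled edges, pooled nodes.
    globals_new = phi_u(np.concatenate(
        [globals_, edges_new.mean(0), nodes_new.mean(0)], axis=-1))
    return nodes_new, edges_new, globals_new
```

In the PR's example, `phi_e`, `phi_v`, and `phi_u` would be the node, edge, and global update modules defined in the docstring.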
Checklist
Put an `x` in the boxes that apply.
- [x] I have read the CONTRIBUTING document
- [x] I have run `pre-commit run --all-files` to format my code / installed pre-commit prior to committing changes
- [ ] I have added tests that prove my fix is effective or that my feature works
- [x] I have updated the necessary documentation (if needed)
I think it is a great idea to have support for GNNs in MLX. However, I think that it is better to implement GNN-related stuff in a separate library (e.g. `mlx-gnn`, `mlx-graph`, `mlx-geometric`). Here are some reasons:
- There are a lot of domain-specific concepts related to GNNs. As soon as GNNs are added to MLX, people will likely want to implement graph batching and memory-efficient aggregations. Similarly, there are a lot of neighborhood aggregation layers for graphs, such as `GAT`, `GATv2`, `GIN`... These are commonly used in graph ML, but are probably too specific for implementation in `mlx`.
- Efficient implementations of GNN layers often require support for sparse tensors and sparse linear algebra. I am not sure this is currently implemented in MLX.
When it comes to tests, you could test whether neighborhood aggregation is correctly implemented.
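A test along those lines could check the scatter-sum against a hand-computed result. The helper name below (`aggregate_incoming`) is hypothetical; this is only a sketch of the kind of check meant, written in NumPy:

```python
# Sketch of an aggregation correctness test: sum edge messages into their
# receiver nodes (a segment sum) and compare with a hand-computed answer.
import numpy as np

def aggregate_incoming(messages, receivers, num_nodes):
    """Scatter-sum edge messages into their receiver nodes."""
    out = np.zeros((num_nodes, messages.shape[-1]))
    np.add.at(out, receivers, messages)
    return out

# Three edges: 0 -> 1, 1 -> 2, 2 -> 1, carrying scalar messages.
messages = np.array([[1.0], [2.0], [4.0]])
receivers = np.array([1, 2, 1])
agg = aggregate_incoming(messages, receivers, num_nodes=3)
# Node 0 receives nothing, node 1 receives 1 + 4, node 2 receives 2.
```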
> I think it is a great idea to have support for GNNs in MLX. However, I think that it is better to implement GNN-related stuff in a separate library (e.g. `mlx-gnn`, `mlx-graph`, `mlx-geometric`). Here are some reasons:
> - There are a lot of domain-specific concepts related to GNNs. As soon as GNNs are added to MLX, people will likely want to implement graph batching and memory-efficient aggregations. Similarly, there are a lot of neighborhood aggregation layers for graphs, such as `GAT`, `GATv2`, `GIN`... These are commonly used in graph ML, but are probably too specific for implementation in `mlx`.
> - Efficient implementations of GNN layers often require support for sparse tensors and sparse linear algebra. I am not sure this is currently implemented in MLX.
You're right, but I'm in two minds on that. It's true that in all other frameworks GNNs are implemented in external libraries, so one could think of continuing that trend. However, I personally still find it weird having to rely on external packages to use GNNs in other frameworks, and I believe there are historical reasons for that. Nowadays GNNs are first-class citizens in ML and it would be a bit counterintuitive to not have them implemented in the core package. The way I see it, sparse tensors, memory-efficient aggregation ops, scatter ops, etc., are pretty widely used beyond pure GNN applications.
Since we're rightfully including modern architectural layers in the core, like transformers, I think it could make sense to have GNNs in there too.
> When it comes to tests, you could test whether neighborhood aggregation is correctly implemented.
Yes, that makes sense, but right now that part is not exposed (it's just used in an example in the docstring). If we take it out, then absolutely yes. Perhaps that could be done in another PR where we implement more graph ops.
> You're right, but I'm in two minds on that. It's true that in all other frameworks GNNs are implemented in external libraries, so one could think of continuing that trend. However, I personally still find it weird having to rely on external packages to use GNNs in other frameworks, and I believe there are historical reasons for that. Nowadays GNNs are first-class citizens in ML and it would be a bit counterintuitive to not have them implemented in the core package. The way I see it, sparse tensors, memory-efficient aggregation ops, scatter ops, etc., are pretty widely used beyond pure GNN applications. Since we're rightfully including modern architectural layers in the core, like transformers, I think it could make sense to have GNNs in there too.
Those are really good arguments. I agree about historical reasons.
It makes sense to close this, given the discussion in https://github.com/ml-explore/mlx/pull/398.
Yes, closing now.