Add Message Passing Neural Network (MPNN)

TristanBilot opened this pull request 1 year ago • 22 comments

Proposed changes

This PR proposes a draft of the MPNN architecture, used as a basic building block for almost any GNN following the message-passing paradigm. This abstraction is also used in PyTorch Geometric and DGL.
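To make the abstraction concrete, here is a minimal sketch of what such a base class can look like in MLX. The method names (`message`, `aggregate`, `update`, `propagate`) follow the PyG convention and are assumptions rather than the exact API of this PR, and the Python loop in `aggregate` merely stands in for a fused `scatter_add`:

```python
import mlx.core as mx
import mlx.nn as nn


class MessagePassing(nn.Module):
    """Sketch of the message-passing pattern: for each edge (u, v), a message
    is computed from u's features, the messages arriving at each node v are
    aggregated, and v's state is updated from the aggregate."""

    def message(self, src_features: mx.array) -> mx.array:
        # Default message: the source node features, unchanged.
        return src_features

    def aggregate(self, messages: mx.array, dst: list, num_nodes: int) -> mx.array:
        # Sum incoming messages per destination node. A real implementation
        # would use a fused scatter_add; this loop only shows the semantics.
        rows = [mx.zeros((messages.shape[-1],)) for _ in range(num_nodes)]
        for i, d in enumerate(dst):
            rows[d] = rows[d] + messages[i]
        return mx.stack(rows)

    def update(self, aggregated: mx.array) -> mx.array:
        # Default update: identity; subclasses typically apply an MLP here.
        return aggregated

    def propagate(self, node_features: mx.array, edges: list) -> mx.array:
        # `edges` is assumed to be a list of (src, dst) index pairs.
        src = [u for u, _ in edges]
        dst = [v for _, v in edges]
        msgs = self.message(node_features[mx.array(src)])
        return self.update(self.aggregate(msgs, dst, node_features.shape[0]))
```

A concrete layer (e.g. a GCN-style convolution) would then subclass this, override `message` and/or `update`, and call `propagate` from its `__call__`.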

This is a draft; tests will come very soon. Anyone willing to contribute to other GNN-related work, please say so here!

Checklist

Put an x in the boxes that apply.

  • [x] I have read the CONTRIBUTING document
  • [x] I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • [ ] I have added tests that prove my fix is effective or that my feature works
  • [ ] I have updated the necessary documentation (if needed)

TristanBilot · Jan 08 '24 01:01

As mentioned here by @gboduljak, an important question we may ask before diving into the implementation of the various GNN layers is:

Should we create a separate library specifically made for this purpose, or should we integrate these graph-related features directly within MLX?

I personally think that integrating everything within MLX (graphs + other specialized fields) could lead to a very heavy MLX library, with plenty of niche features used by only a few people. What's nice about MLX is also its lightweight structure and understandable code; adding too many features may hurt the pleasant experience of the framework. On the other hand, I agree with @francescofarina that GNN libraries like PyG can quickly become hard to install and maintain because of the many dependencies and varying versions.

We're lucky to be at an early enough stage of MLX to make these kinds of decisions, which will shape its future direction! Any ideas @awni @angeloskath ?

TristanBilot · Jan 08 '24 02:01

I would like to add that an efficient MPNN layer implementation also depends on scatter_{op} and sparse linear algebra support. In addition, popular graph attention layers such as GAT and GATv2 depend on scatter_softmax, which can be either fused or implemented using scatter_add and scatter_mul/div. Given the issues with the scatter vjp implementation (https://github.com/ml-explore/mlx/pull/394), we might want to hold off and implement the scatter operations first.
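To make the composition concrete, here is a sketch of a numerically stable scatter_softmax built from scatter_max, scatter_add, and an elementwise division. The function name and signature are assumptions, not an mlx API, and the plain Python loops stand in for the missing fused scatter primitives:

```python
import mlx.core as mx


def scatter_softmax(values: mx.array, index: list, num_groups: int) -> mx.array:
    """Softmax over `values`, computed independently within each group given
    by `index` (e.g. attention logits grouped by destination node in GAT).
    The loops stand in for fused scatter_max / scatter_add primitives."""
    # scatter_max: per-group maximum, for numerical stability.
    maxes = [float("-inf")] * num_groups
    for v, g in zip(values.tolist(), index):
        maxes[g] = max(maxes[g], v)
    shifted = mx.exp(values - mx.array([maxes[g] for g in index]))
    # scatter_add: per-group sum of the exponentiated values.
    sums = [0.0] * num_groups
    for v, g in zip(shifted.tolist(), index):
        sums[g] += v
    # scatter_div: normalize each value by its group's sum.
    return shifted / mx.array([sums[g] for g in index])
```

In a GAT layer, `values` would be the unnormalized attention logits per edge and `index` the destination node of each edge, e.g. `scatter_softmax(mx.array([0.5, 1.0, 2.0]), [0, 0, 1], num_groups=2)`.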

gboduljak · Jan 08 '24 02:01

I would like to contribute to the implementation of MPNN layers, graph normalization layers, and graph utilities such as (mini-)batching of graphs.
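For context, graph mini-batching usually means merging several graphs into one large block-diagonal graph: node features are concatenated, and each graph's edge indices are offset by the number of nodes preceding it. A rough sketch, where the helper name and the (2, num_edges) edge layout are assumptions:

```python
import mlx.core as mx


def batch_graphs(node_features: list, edge_indices: list):
    """Merge several graphs into one block-diagonal batch: node features are
    concatenated, and each graph's edge indices are shifted by the number of
    nodes that come before it. Each edge_index is assumed to be an mx.array
    of shape (2, num_edges)."""
    offset = 0
    shifted_edges, batch_ids = [], []
    for graph_id, (x, e) in enumerate(zip(node_features, edge_indices)):
        shifted_edges.append(e + offset)
        # Remember which graph each node belongs to, e.g. for graph pooling.
        batch_ids.extend([graph_id] * x.shape[0])
        offset += x.shape[0]
    return (
        mx.concatenate(node_features, axis=0),
        mx.concatenate(shifted_edges, axis=1),
        mx.array(batch_ids),
    )
```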

gboduljak · Jan 08 '24 02:01

Yes, implementing all the scatter_* operations will be one of the main things to achieve. I plan to work on that too.

TristanBilot · Jan 08 '24 02:01

Thanks for the contribution! This looks really nice so far. I'm thinking about the best way to incorporate it with MLX. At this stage, my thought is that it belongs in one of two places:

  1. mlx-examples
  2. A higher-level mlx-gnn package which depends on mlx

My experience with these higher-level models is that they often work really well as a simple and hackable example which can be used as a starting point by people doing graph NNs, who will likely want to customize it one way or another.

A higher-level package could also work well, but there one has to know more about the users and usage patterns (which I don't, really, for this). Edit: I see you posted a couple of examples in the message. I will take a look at those.

What do you think between those two options?

awni · Jan 08 '24 05:01

On the other hand, I agree with @francescofarina that GNN libraries like PyG can quickly become hard to install and maintain because of the many dependencies and varying versions.

Can you say more about that @federicobucchi? It seems like maintaining a dependency on mlx should not be that bad?

awni · Jan 08 '24 05:01

I think the mlx-examples option is a no-go, as this PR goes beyond just releasing a new model. The MPNN is one of the fundamental building blocks that many other GNN-related features will rely on. We thus want to build an entire ecosystem that people can easily build on top of, not just an example.

@awni I really encourage you to take a look at the torch_geometric lib, as it contains all the features we also want to implement in MLX, and it has a nice API built on top of MPNN. Basically, it's written in plain Python, with the performance-critical parts implemented in C++ in separate modules like torch_sparse and torch_scatter. All these modules are maintained by its founder, Matthias Fey, and are not included in PyTorch by default. To use torch_geometric properly, it's thus necessary to install all these modules alongside torch, which is why there are so many dependencies to rely on.

Since you don't consider the option of integrating these features directly within MLX, it seems preferable to create a new dedicated package, as was done for torch_geometric and other GNN libs. I'm not sure that using a separate module for every type of low-level operation is a good option though; I would say we can integrate all these operations directly within the same mlx_graphs or mlx_geometric lib.

GNNs are a rapidly evolving field with a large community; with some proper communication about this new lib, I'm sure we'll find many contributors. I'm willing to fully commit to it!

TristanBilot · Jan 08 '24 08:01

Just to add to the comment I made here: https://github.com/ml-explore/mlx/pull/380#issuecomment-1879761512. It seems we're going down the road of implementing scatter and sparse ops in the mlx core, so maintaining an external library specific to GNNs shouldn't be too much of an extra burden. The open question I still have is whether it would make sense to have some basic GNN blocks as part of the mlx core and delegate more specific/higher-level GNN architectures to an external library (and even mlx-examples could make sense at that point, IMO).

francescofarina · Jan 08 '24 11:01

It would probably be much easier for maintenance to have all GNN-related code within a separate mlx-graphs lib (MPNN, GraphNet, etc.), while the basic operations (scatter, sparse ops) remain in mlx.core, as they may be used in non-GNN scenarios.

TristanBilot · Jan 08 '24 17:01

It would probably be much easier for maintenance to have all GNN-related code within a separate mlx-graphs lib (MPNN, GraphNet, etc.), while the basic operations (scatter, sparse ops) remain in mlx.core, as they may be used in non-GNN scenarios.

I also think this is the best idea.

gboduljak · Jan 08 '24 17:01

It would probably be much easier for maintenance to have all GNN-related code within a separate mlx-graphs lib

The fastest path to get this going is an external repo (outside ml-explore that is). If you all are interested in making one I think that is great and we will gladly support the features / encourage contributions in core needed to make it work.

If you're interested in it, we can look into including it as a repo in the ml-explore org but that will take a little more time to set up.

Any thoughts there? @TristanBilot @gboduljak @francescofarina ?

@francescofarina to address your question:

is whether it would make sense to have some basic GNN blocks as part of the mlx core

I assume the message passing base class is what you are referring to? Are there others you were thinking of?

Potentially the answer will be decided on a case-by-case basis, but I would also suggest we cross that bridge at a future time. There's no need to close the door on including it now, but no need to do it yet either. Having these in mlx-graphs could be a good way to refine the API a bit before merging them back.

awni · Jan 08 '24 20:01

@awni, I think this is a wonderful idea! And it is, IMO, the most logical and straightforward way to go. I personally can allocate a lot of time to this project during my PhD, so it will be a real pleasure to contribute to this mlx-graphs package.

Although creating the new repo may take time, we can still think about the new API and core functions to develop beforehand.

TristanBilot · Jan 08 '24 20:01

So you suggest starting the development of the package in our own repo, and then merging it into ml-explore later?

TristanBilot · Jan 08 '24 20:01

At this point, I cannot allocate enough time to lead the development of mlx-graphs, but I plan to contribute to it. I think it is best to develop the core mlx-graphs features outside of mlx. When more graph features are developed and more is known about their usage, we might consider including the library, or parts of it, within mlx. Assuming that scatter operations and sparse linear algebra are at some point supported natively in mlx, maintaining dependencies should not be too difficult.

gboduljak · Jan 08 '24 20:01

Totally agree with you @gboduljak. On my side, I have enough time to take the lead. I just created a draft repo to start playing with some implementations; for the moment it only contains the two files from this PR.

TristanBilot · Jan 08 '24 20:01

The fastest path to get this going is an external repo (outside ml-explore that is). If you all are interested in making one I think that is great and we will gladly support the features / encourage contributions in core needed to make it work. If you're interested in it, we can look into including it as a repo in the ml-explore org but that will take a little more time to set up.

@awni no preference there, but I believe whether to start it inside or outside the ml-explore org may also be a question for you? I'm happy to allocate time and contribute either way.

francescofarina · Jan 08 '24 20:01

I assume the message passing base class is what you are referring to? Are there others you were thinking of?

@awni yes, I was thinking of the MPNN and the generic block I was working on in https://github.com/ml-explore/mlx/pull/380. But as you suggest, I think it probably makes sense to cross that bridge later on.

francescofarina · Jan 08 '24 20:01

I will also be able to allocate a lot of time to the project during my PhD and will be collaborating closely with @TristanBilot :).

Kheims · Jan 08 '24 21:01

I'm considering creating a chat group to make quick communication between us easier. @francescofarina @gboduljak, do you use X?

TristanBilot · Jan 08 '24 21:01

@TristanBilot sure, my handle is in my profile.

francescofarina · Jan 08 '24 21:01

I'm considering creating a chat group to make quick communication between us easier. @francescofarina @gboduljak, do you use X?

I do not use X :(. Please create the group using the communication channel you like most, and let me know how to join :)

gboduljak · Jan 08 '24 22:01

To be honest, I just installed X last week and it's very nice for sharing ideas with people from the open source community. I'd encourage you to give it a try ^^

TristanBilot · Jan 08 '24 22:01

Ok, it sounds like the immediate next step is to start a community-led repo, mlx-graph (or something similarly named). @TristanBilot may have made the repo already.

I'm very excited to see how this develops! Thanks for the discussion everyone!!

For now I will close this PR since we will not immediately add anything from it.

awni · Jan 09 '24 01:01