GraphCast improvements - Part I

Open mnabian opened this issue 1 year ago • 4 comments

Modulus Pull Request

Description

Closes https://github.com/NVIDIA/modulus/issues/506, https://github.com/NVIDIA/modulus/issues/505, https://github.com/NVIDIA/modulus/issues/486, https://github.com/NVIDIA/modulus/issues/508, https://github.com/NVIDIA/modulus/issues/509, https://github.com/NVIDIA/modulus/issues/511

Checklist

  • [x] I am familiar with the Contributing Guidelines.
  • [x] New or existing tests cover these changes.
  • [x] The documentation is up to date with these changes.
  • [x] The CHANGELOG.md is up to date with these changes.
  • [x] An issue is linked to this pull request.

Dependencies

mnabian avatar May 21 '24 01:05 mnabian

/blossom-ci

mnabian avatar May 21 '24 01:05 mnabian

@mnabian Since you are revisiting GraphCast now, adding a few comments

  • Can we add the option to use transformer_engine.LayerNorm? In AIFS benchmarks, we were able to get a 1.3x end-to-end improvement just from doing so, since the PyTorch implementation is rather slow for the sizes we encounter in these workloads.
  • Can you check whether the current combination of MeshGraphNodeBlock and MeshGraphEdgeBlock actually matches the paper (https://github.com/NVIDIA/modulus/blob/main/modulus/models/graphcast/graph_cast_processor.py#L97-L98)? I created a schematic of the GraphCast architecture for some Arch folks last week, and I think the order of residuals over the edges does not match the paper here. I might have made a mistake the last time I tried to reuse the shared primitives. The issue is that in MeshGraphNet, the EdgeBlock already applies the "residual" on the edge features, so the NodeBlock then expects features that already include the residual connection prior to message passing, whereas in GraphCast all residual connections are applied only after both the updated edge and node features have been computed (at least according to the paper); see the sketch after this list.
  • What would you think of splitting GraphCastNet into a GraphCastNetERA5 and a GraphCastNet model? The current issue I see with GraphCastNet is that it is very specific to the nature of the ERA5 dataset (e.g. when it comes to preparing the input and output to switch between the HxW layout and the typical "serial" graph layout). GraphCastNet could then be a rather data-agnostic model defining the operations on (g2m_graph, mesh_graph, m2g_graph), while GraphCastNetERA5 defines the workload-specific parts such as checkpointing, input/output conversions, etc. In the longer term, I think it could really make sense to make things a bit more modular. In particular, this also includes things like "history" or the actual "prediction" mode, i.e. whether GraphCastNetERA5 predicts y_t = f(x_{t-1}) or y_t = x_{t-1} + f(x_{t-1}). It could make sense for the "backbone" to be agnostic to these things, with a specialized prediction wrapper on top (a rough sketch follows after this list).
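
To make the residual-ordering concern concrete, here is a minimal, schematic sketch of the two update orders; it is not the actual Modulus code, and the callables and tensor names are placeholders.

```python
def meshgraphnet_block(edge_feat, node_feat, edge_fn, node_fn, aggregate):
    # MeshGraphNet-style ordering: the edge residual is applied immediately,
    # so the node update aggregates edge features that already include it.
    edge_feat = edge_feat + edge_fn(edge_feat, node_feat)
    node_feat = node_feat + node_fn(node_feat, aggregate(edge_feat))
    return edge_feat, node_feat


def graphcast_block(edge_feat, node_feat, edge_fn, node_fn, aggregate):
    # GraphCast-style ordering (as described in the paper): both updates are
    # computed from the pre-residual features; residuals are applied at the end.
    edge_update = edge_fn(edge_feat, node_feat)
    node_update = node_fn(node_feat, aggregate(edge_update))
    return edge_feat + edge_update, node_feat + node_update
```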
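
And a rough, hypothetical sketch of the proposed backbone/wrapper split for the prediction mode; `ResidualPredictionWrapper` is an illustrative name, not an existing class in the repository.

```python
import torch


class ResidualPredictionWrapper(torch.nn.Module):
    # Hypothetical wrapper that keeps the backbone agnostic to the prediction
    # mode, i.e. whether the model predicts y_t directly or a residual on x_{t-1}.
    def __init__(self, backbone: torch.nn.Module, residual: bool = True):
        super().__init__()
        self.backbone = backbone
        self.residual = residual

    def forward(self, x_prev: torch.Tensor) -> torch.Tensor:
        out = self.backbone(x_prev)
        # residual mode: y_t = x_{t-1} + f(x_{t-1}); otherwise y_t = f(x_{t-1})
        return x_prev + out if self.residual else out
```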

stadlmax avatar May 21 '24 09:05 stadlmax

@mnabian Since you are revisiting GraphCast now, adding a few comments […]

Thanks @stadlmax, I'll add your comments to my epic and consider them all.

mnabian avatar May 21 '24 16:05 mnabian

Note to self: the API updates break the GraphCast tests. Need to update them all.

mnabian avatar May 21 '24 17:05 mnabian

@stadlmax as far as I remember, we were using fused layernorm and that gave us a nice speedup: https://github.com/NVIDIA/modulus/blob/main/modulus/models/gnn_layers/mesh_graph_mlp.py#L157 ... Did you also compare transformer_engine.LayerNorm with fused layernorm?

mnabian avatar May 21 '24 20:05 mnabian

@stadlmax as far as I remember, we were using fused layernorm and that gave us a nice speedup (although I can't find it in the most recent code)... Did you also compare transformer_engine.LayerNorm with fused layernorm?

Yes, for AIFS I found TE > APEX > PyTorch across a bunch of the usual sizes from their RFI benchmark. The backward kernels in TE, especially, are much better for our cases. (Reported numbers are runtimes; lower is better.)

num_channels = 256

| layer_norm_impl | 1626240 x 256 | 327660 x 256 | 40962 x 256 | 542080 x 256 | 814540 x 256 |
| --- | --- | --- | --- | --- | --- |
| apex | 9.75127 | 2.03821 | 0.371149 | 3.32072 | 4.9402 |
| pytorch | 10.752 | 4.17265 | 0.957743 | 3.63721 | 10.2774 |
| transformer_engine | 2.59236 | 0.580879 | 0.801795 | 0.916124 | 1.33596 |

num_channels = 384

| layer_norm_impl | 1626240 x 384 | 327660 x 384 | 40962 x 384 | 542080 x 384 | 814540 x 384 |
| --- | --- | --- | --- | --- | --- |
| apex | 11.2164 | 2.3109 | 0.359366 | 3.79922 | 5.64847 |
| pytorch | 11.8419 | 4.33466 | 0.583828 | 3.99414 | 10.6802 |
| transformer_engine | 3.98762 | 0.849599 | 0.396184 | 1.38306 | 2.022 |

num_channels = 512

| layer_norm_impl | 1626240 x 512 | 327660 x 512 | 40962 x 512 | 542080 x 512 | 814540 x 512 |
| --- | --- | --- | --- | --- | --- |
| apex | 12.1739 | 2.50785 | 0.37578 | 4.11927 | 6.14573 |
| pytorch | 12.7752 | 4.5477 | 0.615464 | 4.30874 | 11.2191 |
| transformer_engine | 4.90182 | 1.04243 | 0.391352 | 1.6877 | 2.4967 |
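
The benchmark script itself isn't included in the thread; a minimal sketch of how such a comparison could be reproduced (assuming apex and transformer_engine are installed, and picking one of the shapes from the tables above) might look like this:

```python
import time
import torch
from apex.normalization import FusedLayerNorm
from transformer_engine.pytorch import LayerNorm as TELayerNorm


def time_layernorm(norm, x, iters=100, warmup=10):
    # Time forward + backward for one LayerNorm implementation (milliseconds).
    for _ in range(warmup):
        norm(x).sum().backward()
        x.grad = None
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        norm(x).sum().backward()
        x.grad = None
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters * 1e3


num_rows, num_channels = 40962, 256  # one of the shapes from the tables above
x = torch.randn(num_rows, num_channels, device="cuda", requires_grad=True)

for name, norm in [
    ("pytorch", torch.nn.LayerNorm(num_channels).cuda()),
    ("apex", FusedLayerNorm(num_channels).cuda()),
    ("transformer_engine", TELayerNorm(num_channels).cuda()),
]:
    print(f"{name}: {time_layernorm(norm, x):.3f} ms")
```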

stadlmax avatar May 21 '24 20:05 stadlmax

Yes, for AIFS I found TE > APEX > PyTorch across a bunch of the usual sizes from their RFI benchmark. […]

This is a great comparison, thanks! I'll switch to TE then. Do we have any reason to still keep the fused layernorm from APEX, or should we just remove it?

mnabian avatar May 21 '24 20:05 mnabian

This is a great comparison, thanks! I'll switch to TE then. Do we have any reason to still keep the fused layernorm from APEX, or should we just remove it?

I guess not, not really. TE should also be decently covered when it comes to development specifically for Blackwell and beyond, and I know of a few POCs that try to optimize the LayerNorm in TE even further. If we are based on the DLFW containers, TE should also come pre-installed.

stadlmax avatar May 21 '24 20:05 stadlmax

@stadlmax added support for TE layernorm.
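
For reference, a hedged sketch of what a switchable LayerNorm backend inside the MLP blocks could look like; the `norm_type` argument and fallback logic here are illustrative, not necessarily the exact API that was added in this PR.

```python
import torch.nn as nn


def build_layer_norm(hidden_dim: int, norm_type: str = "TELayerNorm") -> nn.Module:
    # Illustrative factory: select the LayerNorm backend by name and fall back
    # to torch.nn.LayerNorm if transformer_engine is not installed.
    if norm_type == "TELayerNorm":
        try:
            from transformer_engine.pytorch import LayerNorm as TELayerNorm

            return TELayerNorm(hidden_dim)
        except ImportError:
            pass  # transformer_engine not available; fall back below
    return nn.LayerNorm(hidden_dim)
```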

mnabian avatar May 21 '24 21:05 mnabian

Note to self: the API updates break the GraphCast tests. Need to update them all.

Done

mnabian avatar May 21 '24 21:05 mnabian

/blossom-ci

mnabian avatar May 21 '24 21:05 mnabian

/blossom-ci

mnabian avatar May 21 '24 22:05 mnabian

/blossom-ci

mnabian avatar May 22 '24 17:05 mnabian

/blossom-ci

mnabian avatar May 22 '24 17:05 mnabian

/blossom-ci

mnabian avatar May 22 '24 18:05 mnabian

/blossom-ci

mnabian avatar May 22 '24 19:05 mnabian

Thanks for addressing the feedback, looks good to me.

stadlmax avatar May 22 '24 20:05 stadlmax

/blossom-ci

mnabian avatar May 22 '24 20:05 mnabian