PyTorch model in CMSSW
- [x] ONNX export of the model from pytorch
- [ ] ONNX import to CMSSW
- [ ] native/direct pytorch import to CMSSW
- [ ] check compatibility of the results
ONNX export from pytorch currently doesn't work because the MLPF forward function expects pytorch-geometric style inputs; the padding is done internally if an attention/GNN-LSH based model is used.
```python
def forward(self, batch):
    # unfold the Batch object
    if self.ssl:
        input_ = batch.x.float()[:, : self.input_dim]
        VICReg_embeddings = batch.x.float()[:, self.input_dim :]
    else:
        input_ = batch.x.float()
    batch_idx = batch.batch

    embeddings_id = []
    embeddings_reg = []
```
We would need to break out the 3D padded forward function into a different kind of model.
Actually, never mind, changing the forward function as follows worked:
```python
def forward(self, element_features, batch_idx):
    # unfold the Batch object
    if self.ssl:
        input_ = element_features.float()[:, : self.input_dim]
        VICReg_embeddings = element_features.float()[:, self.input_dim :]
    else:
        input_ = element_features.float()

    embeddings_id = []
    embeddings_reg = []
```
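With the flattened signature, the export call could then look roughly like the sketch below. The dummy shapes, file name, opset, and input/output names are illustrative assumptions, not the exact configuration in the repository:

```python
import torch

# 'model' is the pytorch MLPF model instance with the flattened forward() above
dummy_features = torch.randn(256, 55)  # 55 input features assumed
dummy_batch_idx = torch.zeros(256, dtype=torch.long)

torch.onnx.export(
    model,
    (dummy_features, dummy_batch_idx),
    "mlpf.onnx",  # illustrative output file name
    opset_version=17,
    input_names=["element_features", "batch_idx"],
    output_names=["preds"],  # assumed single output name
    dynamic_axes={
        "element_features": {0: "num_elements"},
        "batch_idx": {0: "num_elements"},
    },
)
```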
ONNX export now works for the GNN-LSH since https://github.com/jpata/particleflow/pull/215. Need to train a model with pytorch and test that the import in CMSSW gives reasonable results.
Here is some material on how to integrate pytorch models directly via TorchScript, rather than via ONNX: https://indico.cern.ch/event/1388888/contributions/5839133/attachments/2821898/4928058/2024_03_18%20ML%20production%20news.pdf
A couple of notes:
- Exporting the pytorch model to ONNX with dynamic axes does not actually produce a model that can be evaluated with dynamically sized inputs, because MHA uses `x.shape[0]` internally, which gets converted into a static value during the torchscript export: https://github.com/pytorch/pytorch/issues/99701 (this was not the case when exporting from TF). Fixed by removing `torch.no_grad`: https://discuss.pytorch.org/t/multiheadattention-export-to-onnx-fails-when-using-torch-no-grad/198843. A dynamic-shape check is sketched after this list.
- `scaled_dot_product_attention` is only available in more recent onnxruntime versions; CMSSW_12 does not support it, so the validation recipe needs to be upgraded to CMSSW_14 and confirmed to still work. Addressing it here: https://github.com/jpata/particleflow/pull/323
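To confirm that the exported graph really accepts variable-length inputs, a minimal check could run the same exported file at two different sequence lengths. The file name, input names, and feature count below are illustrative assumptions carried over from the export sketch above:

```python
import numpy as np
import onnxruntime

sess = onnxruntime.InferenceSession("mlpf.onnx")  # assumed export file name

for num_elements in (128, 1000):
    feats = np.random.randn(num_elements, 55).astype(np.float32)  # 55 features assumed
    batch_idx = np.zeros(num_elements, dtype=np.int64)
    outputs = sess.run(None, {"element_features": feats, "batch_idx": batch_idx})
    # if a shape was baked in statically, the second iteration fails with a shape error
    print(num_elements, [o.shape for o in outputs])
```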
Here's the summary of today.
It's possible to export the model (both quantized and unquantized) with dynamic shapes using torch.onnx.export in #324.
However, `scaled_dot_product_attention` gets exported as the inefficient, fully unrolled attention implementation (i.e. the naive or "math" version), so one attention layer becomes an explicit MatMul / scale / Softmax / MatMul subgraph.
This results in somewhat slow runtimes and large memory usage:
```
timing/cpu_fp32.txt:Nelem=5120 mean_time=17029.90 ms stddev_time=126.80 ms mem_used=16672 MB
timing/gpu_fp16.txt:Nelem=5120 mean_time=85.91 ms stddev_time=10.96 ms mem_used=22884 MB
timing/gpu_fp32.txt:Nelem=5120 mean_time=134.03 ms stddev_time=20.34 ms mem_used=45426 MB
timing/gpu_int8.txt:Nelem=5120 mean_time=144.67 ms stddev_time=20.11 ms mem_used=45426 MB
timing/openvino_fp16.txt:Nelem=5120 mean_time=30045.50 ms stddev_time=2774.41 ms mem_used=31867 MB
timing/openvino_fp32.txt:Nelem=5120 mean_time=14351.32 ms stddev_time=642.40 ms mem_used=31592 MB
timing/openvino_int8.txt:Nelem=5120 mean_time=15503.07 ms stddev_time=60.57 ms mem_used=16661 MB
```
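For reference, a minimal sketch of how timings like these could be collected with onnxruntime. This is not the repository's actual timing script; the file name, the assumption of a single padded (batch, Nelem, num_features) input, the feature count, and the iteration counts are all illustrative:

```python
import time
import numpy as np
import onnxruntime

NELEM, NUM_FEATURES = 5120, 55  # feature count assumed

sess = onnxruntime.InferenceSession(
    "mlpf.onnx",  # assumed exported model
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
)
input_name = sess.get_inputs()[0].name
feats = np.random.randn(1, NELEM, NUM_FEATURES).astype(np.float32)  # assumed input layout

for _ in range(5):  # warmup
    sess.run(None, {input_name: feats})

times_ms = []
for _ in range(20):
    t0 = time.perf_counter()
    sess.run(None, {input_name: feats})
    times_ms.append(1000.0 * (time.perf_counter() - t0))
print(f"Nelem={NELEM} mean_time={np.mean(times_ms):.2f} ms stddev_time={np.std(times_ms):.2f} ms")
```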
There is a special MultiHeadAttention op in ONNX contrib, but so far, I don't know how to convince torch / onnxscript to switch to it. https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.MultiHeadAttention
Here's a potential example of how to write the model by hand using onnxscript: https://github.com/microsoft/onnxruntime/issues/19924#issue-2187484945
Here's how the unfused vs. fused MHA compare, based on the example above, with this code:
```python
import torch
import time
import onnxruntime
import pathlib
import onnxscript
import onnx
import math
import numpy

dtype_map = {
    numpy.dtype("float32"): onnx.TensorProto.FLOAT,
    numpy.dtype("bool"): onnx.TensorProto.BOOL,
}


class Model(torch.nn.Module):
    def forward(self, query_states, key_states, value_states, mask):
        return torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=None,
            dropout_p=0.0,
        )


model = Model()
model.eval()

# (B, num_heads, N, head_dim)
query_states = torch.randn(1, 32, 4096, 64)
key_states = torch.randn(1, 32, 4096, 64)
value_states = torch.randn(1, 32, 4096, 64)
mask = torch.randn(1, 32, 4096, 1)

torch_out = model(query_states, key_states, value_states, mask)
print(torch_out.shape)
print(torch_out)

# Another reference perf comparison.
# torch.onnx.export(
#     model,
#     (query_states, key_states, value_states, mask),
#     "sdpa.onnx",
#     verbose=True,
#     opset_version=14,
#     input_names=["query_states", "key_states", "value_states", "mask"],
# )

model_dir = "multihead_attention"
fused_model_name = "multihead_attention.onnx"
fused_model_path = f"{model_dir}/{fused_model_name}"
unfused_model_name = "unfused_multihead_attention.onnx"
unfused_model_path = f"{model_dir}/{unfused_model_name}"
pathlib.Path(model_dir).mkdir(parents=True, exist_ok=True)

msft_op = onnxscript.values.Opset("com.microsoft", 1)
op = onnxscript.opset13
sqrt_head = math.sqrt(64)

query_states_ort = query_states.numpy()
key_states_ort = key_states.numpy()
value_states_ort = value_states.numpy()
attention_mask_ort = mask.numpy()
ort_inputs = {
    "query_states": query_states_ort,
    "key_states": key_states_ort,
    "value_states": value_states_ort,
    "mask": attention_mask_ort,
}

print("Benchmarking PT sdpa and ORT MultiHeadAttention...")


def run_pt():
    # warmup
    for _ in range(30):
        model(query_states, key_states, value_states, mask)
    total_time = 0
    for _ in range(1000):
        start_time = time.perf_counter()
        model(query_states, key_states, value_states, mask)
        total_time += time.perf_counter() - start_time
    return total_time


total_time = run_pt()
print("PT eager:")
print(f"Total time: {total_time:.2f}s")


def mha_onnx_model(query_states, key_states, value_states, mask):
    # query_states = op.Reshape(query_states, shape=[1, 32, 128, 1, 64])
    # key_states = op.Reshape(key_states, shape=[1, 32, 128, 1, 64])
    # value_states = op.Reshape(value_states, shape=[1, 32, 128, 1, 64])
    # qkv = op.Concat(query_states, key_states, value_states, axis=3)
    query_states = op.Reshape(op.Transpose(query_states, perm=[0, 2, 1, 3]), shape=[1, 4096, 2048])
    key_states = op.Reshape(op.Transpose(key_states, perm=[0, 2, 1, 3]), shape=[1, 4096, 2048])
    value_states = op.Reshape(op.Transpose(value_states, perm=[0, 2, 1, 3]), shape=[1, 4096, 2048])
    output, _, _ = msft_op.MultiHeadAttention(
        query_states,
        key_states,
        value_states,
        num_heads=32,
    )
    output = op.Reshape(output, shape=[1, 4096, 32, 64])
    output = op.Transpose(output, perm=[0, 2, 1, 3])
    return output


def unfused_onnx_model(query_states, key_states, value_states, mask):
    scale = op.Constant(value_float=sqrt_head)
    attn_weights = op.MatMul(query_states, op.Transpose(key_states, perm=[0, 1, 3, 2])) / scale
    # attn_weights = op.Add(attn_weights, mask)
    attn_weights = op.Softmax(attn_weights, axis=-1)
    attn_output = op.MatMul(attn_weights, value_states)
    return attn_output


def serialize_model(model_func, model_name, ort_inputs):
    model_path = f"{model_dir}/{model_name}"
    model_proto = onnxscript.script(
        onnxscript.opset13, default_opset=onnxscript.opset13
    )(model_func).to_model_proto()
    for i, value in enumerate(ort_inputs.values()):
        model_proto.graph.input[i].type.CopyFrom(
            onnx.helper.make_tensor_type_proto(
                shape=value.shape,
                elem_type=dtype_map[value.dtype],
            )
        )
    model_proto.graph.output[0].type.CopyFrom(
        onnx.helper.make_tensor_type_proto(
            shape=[1, 32, 4096, 64],
            elem_type=onnx.TensorProto.FLOAT,
        )
    )
    onnx.save(model_proto, model_path)
    return model_proto, model_path


def save_tensor_data(numpy_tensor, output_path):
    from onnx import numpy_helper

    proto_tensor = numpy_helper.from_array(numpy_tensor)
    with open(output_path, "wb") as f:
        f.write(proto_tensor.SerializeToString())


def serialize_inputs_outputs(model_dir, onnx_inputs, onnx_outputs):
    test_data_dir = pathlib.Path(f"{model_dir}/test_data_set_0")
    test_data_dir.mkdir(parents=True, exist_ok=True)
    for i, onnx_input in enumerate(onnx_inputs.values()):
        save_tensor_data(onnx_input, str(test_data_dir / f"input_{i}.pb"))
    for i, onnx_output in enumerate(onnx_outputs):
        save_tensor_data(onnx_output, str(test_data_dir / f"output_{i}.pb"))


def run_ort(model_func, model_name, ort_inputs):
    # Serialize model
    model_proto, model_path = serialize_model(model_func, model_name, ort_inputs)
    # Serialize inputs and outputs
    sess = onnxruntime.InferenceSession(model_path)
    ort_outputs = sess.run(None, ort_inputs)
    # Parity
    torch.testing.assert_close(
        torch_out, torch.tensor(ort_outputs[0]), rtol=1e-3, atol=1e-3
    )
    serialize_inputs_outputs(model_dir, ort_inputs, ort_outputs)
    # warmup
    for _ in range(30):
        sess.run(None, ort_inputs)
    total_time = 0
    for _ in range(10):
        start_time = time.perf_counter()
        sess.run(None, ort_inputs)
        total_time += time.perf_counter() - start_time
    print(f"ORT {model_name}")
    print(f"Total time: {total_time:.2f}s")


run_ort(unfused_onnx_model, unfused_model_name, ort_inputs)
run_ort(mha_onnx_model, fused_model_name, ort_inputs)
```
It seems like, at least on an M2 CPU, the regular ONNX unfused scaled_dot_product_attention is fastest for a sequence length of 4096 elements:
```
Benchmarking PT sdpa and ORT MultiHeadAttention...
PT eager:
Total time: 660.70s
ORT unfused_multihead_attention.onnx
Total time: 11.15s
'MultiHeadAttention' is not a known op in 'com.microsoft'
ORT multihead_attention.onnx
Total time: 18.02s
```
(Graph screenshots: pytorch sdpa, onnx unfused, onnx fused.)
It looks as if the onnxruntime.transformers optimizer, specifically FusionAttention (https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/python/tools/transformers/fusion_attention.py#L712), should replace the attention block with MultiHeadAttention, which on GPU should support flash attention.
If I can replace the unfused SDPA subgraph with the fused MultiHeadAttention node (ignoring the exact shapes), then in principle it should be possible to try flash attention on the ONNX model.
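A minimal sketch of how that graph optimizer could be invoked, assuming an exported file called mlpf.onnx and that the fusion patterns actually match this graph (untested here; model_type, num_heads, and hidden_size are assumptions):

```python
from onnxruntime.transformers import optimizer

# Ask onnxruntime's transformer optimizer to fuse the unrolled attention
# subgraph into a single attention node; the parameters below would need
# to match the exported MLPF graph.
opt_model = optimizer.optimize_model(
    "mlpf.onnx",
    model_type="bert",
    num_heads=32,
    hidden_size=2048,
)
opt_model.save_model_to_file("mlpf_fused.onnx")
```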
Converting the model with the fused attention layer com.microsoft.MultiheadAttention to float16 does run flash attention on A100 with the expected speed and memory improvement.
The following code has batch size 1, sequence length 4096, num_heads 32, head_dim 64.
```python
import torch
import time
import onnxruntime
import pathlib
import onnxscript
import onnx
import math
import numpy
import pynvml

pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(0)

num_iter = 1000


def get_mem_gpu_mb():
    mem = pynvml.nvmlDeviceGetMemoryInfo(handle)
    return mem.used / 1000 / 1000


dtype_map = {
    numpy.dtype("float32"): onnx.TensorProto.FLOAT,
    numpy.dtype("bool"): onnx.TensorProto.BOOL,
}


class Model(torch.nn.Module):
    def forward(self, query_states, key_states, value_states, mask):
        return torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=None,
            dropout_p=0.0,
        )


model = Model()
model.eval()

# (B, num_heads, N, head_dim)
query_states = torch.randn(1, 32, 4096, 64)
key_states = torch.randn(1, 32, 4096, 64)
value_states = torch.randn(1, 32, 4096, 64)
mask = torch.randn(1, 32, 4096, 1)

torch_out = model(query_states, key_states, value_states, mask)
print(torch_out.shape)
print(torch_out)

# Another reference perf comparison.
torch.onnx.export(
    model,
    (query_states, key_states, value_states, mask),
    "sdpa.onnx",
    verbose=True,
    opset_version=18,
    input_names=["query_states", "key_states", "value_states", "mask"],
)

model_dir = "multihead_attention"
fused_model_name = "multihead_attention.onnx"
fused_model_path = f"{model_dir}/{fused_model_name}"
unfused_model_name = "unfused_multihead_attention.onnx"
unfused_model_path = f"{model_dir}/{unfused_model_name}"
pathlib.Path(model_dir).mkdir(parents=True, exist_ok=True)

msft_op = onnxscript.values.Opset("com.microsoft", 1)
op = onnxscript.opset18
sqrt_head = math.sqrt(64)

query_states_ort = query_states.numpy()
key_states_ort = key_states.numpy()
value_states_ort = value_states.numpy()
attention_mask_ort = mask.numpy()
ort_inputs = {
    "query_states": query_states_ort,
    "key_states": key_states_ort,
    "value_states": value_states_ort,
    "mask": attention_mask_ort,
}

print("Benchmarking PT sdpa and ORT MultiHeadAttention...")


def run_pt():
    # warmup
    for _ in range(30):
        model(query_states, key_states, value_states, mask)
    total_time = 0
    for _ in range(num_iter):
        start_time = time.perf_counter()
        model(query_states, key_states, value_states, mask)
        total_time += time.perf_counter() - start_time
    return total_time


total_time = run_pt()
print("PT eager:")
print(f"Total time: {total_time:.2f}s")


def mha_onnx_model(query_states, key_states, value_states, mask):
    # query_states = op.Reshape(query_states, shape=[1, 32, 128, 1, 64])
    # key_states = op.Reshape(key_states, shape=[1, 32, 128, 1, 64])
    # value_states = op.Reshape(value_states, shape=[1, 32, 128, 1, 64])
    # qkv = op.Concat(query_states, key_states, value_states, axis=3)
    query_states = op.Reshape(op.Transpose(query_states, perm=[0, 2, 1, 3]), shape=[1, 4096, 2048])
    key_states = op.Reshape(op.Transpose(key_states, perm=[0, 2, 1, 3]), shape=[1, 4096, 2048])
    value_states = op.Reshape(op.Transpose(value_states, perm=[0, 2, 1, 3]), shape=[1, 4096, 2048])
    output, _, _ = msft_op.MultiHeadAttention(
        query_states,
        key_states,
        value_states,
        num_heads=32,
    )
    output = op.Reshape(output, shape=[1, 4096, 32, 64])
    output = op.Transpose(output, perm=[0, 2, 1, 3])
    return output


def unfused_onnx_model(query_states, key_states, value_states, mask):
    scale = op.Constant(value_float=sqrt_head)
    attn_weights = op.MatMul(query_states, op.Transpose(key_states, perm=[0, 1, 3, 2]))  # / scale
    # attn_weights = op.Add(attn_weights, mask)
    attn_weights = op.Softmax(attn_weights, axis=-1)
    attn_output = op.MatMul(attn_weights, value_states)
    return attn_output


def serialize_model(model_func, model_name, ort_inputs):
    model_path = f"{model_dir}/{model_name}"
    model_proto = onnxscript.script(
        onnxscript.opset18, default_opset=onnxscript.opset18
    )(model_func).to_model_proto()
    for i, value in enumerate(ort_inputs.values()):
        model_proto.graph.input[i].type.CopyFrom(
            onnx.helper.make_tensor_type_proto(
                shape=value.shape,
                elem_type=dtype_map[value.dtype],
            )
        )
    model_proto.graph.output[0].type.CopyFrom(
        onnx.helper.make_tensor_type_proto(
            shape=[1, 32, 4096, 64],
            elem_type=onnx.TensorProto.FLOAT,
        )
    )
    onnx.save(model_proto, model_path)
    return model_proto, model_path


def save_tensor_data(numpy_tensor, output_path):
    from onnx import numpy_helper

    proto_tensor = numpy_helper.from_array(numpy_tensor)
    with open(output_path, "wb") as f:
        f.write(proto_tensor.SerializeToString())


def serialize_inputs_outputs(model_dir, onnx_inputs, onnx_outputs):
    test_data_dir = pathlib.Path(f"{model_dir}/test_data_set_0")
    test_data_dir.mkdir(parents=True, exist_ok=True)
    for i, onnx_input in enumerate(onnx_inputs.values()):
        save_tensor_data(onnx_input, str(test_data_dir / f"input_{i}.pb"))
    for i, onnx_output in enumerate(onnx_outputs):
        save_tensor_data(onnx_output, str(test_data_dir / f"output_{i}.pb"))


def run_ort(model_func, model_name, ort_inputs):
    # Serialize model
    model_proto, model_path = serialize_model(model_func, model_name, ort_inputs)

    from onnxconverter_common import float16

    model = onnx.load(model_path)
    model_fp16 = float16.convert_float_to_float16(model)
    onnx.save(model_fp16, model_path)

    # Serialize inputs and outputs
    sess = onnxruntime.InferenceSession(model_path, providers=["CUDAExecutionProvider"])
    ort_inputs_fp16 = {k: float16.convert_np_to_float16(v) for k, v in ort_inputs.items()}
    ort_outputs = sess.run(None, ort_inputs_fp16)
    # Parity
    # torch.testing.assert_close(
    #     torch_out, torch.tensor(ort_outputs[0]), rtol=1e-3, atol=1e-3
    # )
    serialize_inputs_outputs(model_dir, ort_inputs_fp16, ort_outputs)
    # warmup
    for _ in range(30):
        sess.run(None, ort_inputs_fp16)
    total_time = 0
    mem_used = []
    for _ in range(num_iter):
        start_time = time.perf_counter()
        sess.run(None, ort_inputs_fp16)
        mem_used.append(get_mem_gpu_mb())
        total_time += time.perf_counter() - start_time
    print(f"ORT {model_name}")
    max_mem = numpy.max(mem_used)
    print(f"Total time: {total_time:.2f}s, mem: {max_mem:.0f}MB")


run_ort(unfused_onnx_model, unfused_model_name, ort_inputs)
run_ort(mha_onnx_model, fused_model_name, ort_inputs)
```
gives:
```
# this is on CPU
PT eager:
Total time: 79.24s
#this is on A100
ORT unfused_multihead_attention.onnx
Total time: 11.96s, mem: 5771MB
ORT multihead_attention.onnx
Total time: 9.60s, mem: 1680MB
```
In 7a301da I got the MultiheadAttention OP spliced into the graph using the onnxscript and torch.onnx.export (TorchScript) approach.
In e40f5c3 the ONNX export now works with dynamic shapes, f32/fp16, using com.microsoft.MultiheadAttention (that can use Flash Attention on GPU), and the pytorch and onnx versions return the same values. The timings and mem usage on fp16 are great now (tested on A100):
```
timing/gpu_fp16.txt:Nelem=5120 mean_time=14.26 ms stddev_time=0.45 ms mem_used=1946 MB
timing/gpu_fp32.txt:Nelem=5120 mean_time=122.88 ms stddev_time=6.46 ms mem_used=12274 MB
```
(Exported-graph screenshots: one self-attention block, and the SDPA subgraph inside it.)
The trick was that pytorch needs (batch, seq_len, num_heads, head_dim), while com.microsoft.MultiheadAttention needs (batch, seq_len, num_heads*head_dim).
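In plain pytorch terms, the layout conversion around the fused op amounts to something like this sketch (shapes follow the example above):

```python
import torch

B, NUM_HEADS, N, HEAD_DIM = 1, 32, 4096, 64
q = torch.randn(B, NUM_HEADS, N, HEAD_DIM)  # pytorch SDPA layout

# (B, num_heads, N, head_dim) -> (B, N, num_heads*head_dim) for com.microsoft.MultiHeadAttention
q_mha = q.transpose(1, 2).reshape(B, N, NUM_HEADS * HEAD_DIM)

# and back after the fused op: (B, N, num_heads*head_dim) -> (B, num_heads, N, head_dim)
q_back = q_mha.reshape(B, N, NUM_HEADS, HEAD_DIM).transpose(1, 2)
assert torch.equal(q, q_back)
```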
Actually ONLY the MultiHeadAttention op needs to run in fp16:
```python
@onnxscript.script(custom_opset)
def SDPA(
    query: TFloat,
    key: TFloat,
    value: TFloat,
) -> TFloat:
    # Unlike pytorch scaled_dot_product_attention,
    # the input here MUST BE (batch, seq_len, num_head*head_dim).
    # Also, for the op to be fast on GPU, it needs to run in float16.
    query = op.Cast(query, to=onnx.TensorProto.FLOAT16)
    key = op.Cast(key, to=onnx.TensorProto.FLOAT16)
    value = op.Cast(value, to=onnx.TensorProto.FLOAT16)
    output, _, _ = msft_op.MultiHeadAttention(query, key, value, num_heads=NUM_HEADS)
    output = op.Cast(output, to=onnx.TensorProto.FLOAT)
    return output
```
Then the outputs are basically equivalent to the base model (at least on CPU).
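A minimal parity check between the pytorch model and the exported ONNX graph could look like the sketch below, assuming for simplicity a single padded input tensor, a single output tensor, and the illustrative file name used earlier:

```python
import onnxruntime
import torch

feats = torch.randn(1, 512, 55)  # assumed (batch, Nelem, num_features) padded input

with torch.no_grad():
    torch_out = model(feats)  # the pytorch MLPF model, assumed to return one tensor here

sess = onnxruntime.InferenceSession("mlpf.onnx")  # assumed exported file
(onnx_out,) = sess.run(None, {sess.get_inputs()[0].name: feats.numpy()})

# the fp16 cast inside the fused MHA limits the achievable agreement, hence loose tolerances
torch.testing.assert_close(torch_out, torch.from_numpy(onnx_out), rtol=1e-2, atol=1e-2)
```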
Importing the new model in CMSSW is still to do. Need #323 merged and some results from it first, and then do the CMSSW updates on top of https://github.com/jpata/cmssw/releases/tag/pfanalysis_caloparticle_CMSSW_14_1_0_pre3_acat2022.
The required changes in the CMSSW side to import the new ONNX model are here: https://github.com/jpata/cmssw/commit/3d5455b8fa310af1cb7aa5ee8d0426f9b4353f84
It runs and produces nonzero/nongarbage outputs. Submitted jobs on CPU, will see validations soon.
A couple of tracks being produced:
```
ielem=1203 inputs:0=1 1=0.369013 2=0.492823 3=0.240846 4=0.970563 5=0.41474 6=0 7=0 8=1 9=0 10=0 11=0 12=2.98601 13=-2.15722 14=0 15=0 16=0 17=0.358151 18=0.0888752 19=0.18931 20=0 21=0 22=0 23=0 24=0 25=0 26=0 27=0 28=8 29=0 30=0 31=0 32=0.0194093 33=0.0901159 34=3.56876 35=0.0072579 36=0.00452945 37=0.00699194 38=0.474006 39=0.00403006 40=1.09679 41=0.00403006 42=0 43=0 44=0 45=0 46=0 47=0 48=0 49=0 50=0 51=0 52=0 53=0 54=0
ielem=1203 pred: pid=211 E=0.434683 pt=0.365148 eta=0.49956 phi=0.233267 charge=1
ielem=1204 inputs:0=1 1=0.25898 2=1.00609 3=0.542937 4=0.839773 5=0.401488 6=0 7=0 8=1 9=0 10=0 11=0 12=2.90243 13=-0.830946 14=0 15=0 16=0 17=0.217484 18=0.14061 19=0.306793 20=0 21=0 22=0 23=0 24=0 25=0 26=0 27=0 28=5 29=0 30=0 31=0 32=0.0252961 33=0.0149914 34=1.73247 35=0.00628026 36=0.00790802 37=0.00939261 38=0.869708 39=0.00510107 40=0.701089 41=0.00510107 42=0 43=0 44=0 45=0 46=0 47=0 48=0 49=0 50=0 51=0 52=0 53=0 54=0
ielem=1204 pred: pid=211 E=0.428164 pt=0.261295 eta=1.00512 phi=0.57336 charge=1
```
and a couple of neutrals:
```
ielem=1858 inputs:0=8 1=2.61464 2=3.93908 3=-0.258819 4=-0.965926 5=67.1841 6=11 7=1 8=0 9=0 10=0 11=0 12=0 13=0 14=0 15=0 16=0 17=-42.7738 18=-11.4612 19=1137 20=0 21=0 22=0 23=0 24=0 25=0 26=0 27=0 28=3 29=0 30=-1 31=-1 32=0 33=0 34=0 35=38.7886 36=0.614548 37=0.705366 38=0 39=0 40=0 41=0 42=7.05472 43=0 44=0.487688 45=0 46=0 47=0 48=0.225962 49=0 50=0 51=0 52=30.1196 53=10.016 54=9.34346e-11
ielem=1858 pred: pid=2 E=65.0437 pt=2.45356 eta=3.97031 phi=-2.88292 charge=0
ielem=1859 inputs:0=8 1=0.774559 2=3.06393 3=-0.573577 4=0.819152 5=8.31035 6=11 7=1 8=0 9=0 10=0 11=0 12=0 13=0 14=0 15=0 16=0 17=87.1876 18=-61.0494 19=1137 20=0 21=0 22=0 23=0 24=0 25=0 26=0 27=0 28=1 29=0 30=-1 31=-1 32=0 33=0 34=0 35=0 36=0 37=0 38=0 39=0 40=0 41=0 42=8.5 43=0 44=0 45=0 46=0 47=0 48=0 49=0 50=0 51=0 52=0 53=0 54=0
ielem=1859 pred: pid=2 E=8.56867 pt=0.709752 eta=3.18238 phi=-0.612545 charge=0
```
Here I managed to make the CMSSW ONNX GPU inference work, I think: https://github.com/jpata/cmssw/commit/36be715fa00457c310acae3c033f4788bd47a26b
CPU PF:
```
log_cpu_pf.txt:TimeModule> 35002 1 particleFlowTmp PFProducer 0.00893436
log_cpu_pf.txt:TimeModule> 35005 1 particleFlowTmp PFProducer 0.00696006
log_cpu_pf.txt:TimeModule> 35001 1 particleFlowTmp PFProducer 0.0205714
log_cpu_pf.txt:TimeModule> 35004 1 particleFlowTmp PFProducer 0.0115013
log_cpu_pf.txt:TimeModule> 35003 1 particleFlowTmp PFProducer 0.010012
```
CPU MLPF:
```
log_cpu.txt:TimeModule> 35002 1 particleFlowTmp MLPFProducer 9.4116
log_cpu.txt:TimeModule> 35005 1 particleFlowTmp MLPFProducer 8.02389
log_cpu.txt:TimeModule> 35001 1 particleFlowTmp MLPFProducer 13.4437
log_cpu.txt:TimeModule> 35004 1 particleFlowTmp MLPFProducer 10.4151
log_cpu.txt:TimeModule> 35003 1 particleFlowTmp MLPFProducer 12.1385
```
GPU MLPF (A100, 1 event per batch):
```
log_gpu.txt:TimeModule> 35002 1 particleFlowTmp MLPFProducer 0.177305
log_gpu.txt:TimeModule> 35005 1 particleFlowTmp MLPFProducer 0.0156437
log_gpu.txt:TimeModule> 35001 1 particleFlowTmp MLPFProducer 0.0187983
log_gpu.txt:TimeModule> 35004 1 particleFlowTmp MLPFProducer 0.0158696
log_gpu.txt:TimeModule> 35003 1 particleFlowTmp MLPFProducer 0.0171756
```
All done, moved to CMSSW_14 and updated the C++ inference code.