
executorch model inference time is higher than the torch model

Open tdasika17 opened this issue 8 months ago • 14 comments

I have a model object and converted it to a .pte model with the XNNPACK backend using the following:

exported_graph = export(model, inp)  # Core ATen graph
torch.export.save(exported_graph, 'model.pt2')
edge = to_edge(exported_graph)  # Edge dialect
edge_delegated = edge.to_backend(XnnpackPartitioner())  # using the XNNPACK backend
executorch_program = edge_delegated.to_executorch()

with open("model.pte", "wb") as file:
    file.write(executorch_program.buffer)

Then I used it in a C++ frontend to run my LLM application, along similar lines to the example application. The application has executorch in the third-party folder.

I want help with two things,

  1. The executorch runtime is taking ~16 seconds, whereas torch inference runs in around 1.3 seconds. I want some help in improving the inference time, e.g. whether fusing or removing ops would help. I can share the pte graph log in private if needed.
  2. I want to know if I can selectively build based on the ops needed by the graph. I can see that my exported graph (pt2) has around 16 ATen ops. How should I delegate it to the backend, as it may have a different operator set? Or is that taken care of by the selective build gen_selected_ops function based on the arguments given?

However, I'm unable to build selectively based only on those ops; I would appreciate some help here too (a sketch for listing the graph's ops follows below). Below is the part of my CMakeLists.txt that adds the selected ops and includes the library in the target:

set(_kernel_lib)
gen_selected_ops(LIB_NAME "select_build_lib" "" ROOT_OPS "aten::add.out" INCLUDE_ALL_OPS "OFF")
generate_bindings_for_kernels(LIB_NAME "select_build_lib" FUNCTIONS_YAML ${EXECUTORCH_ROOT}/kernels/portable/functions.yaml)
gen_operators_lib(LIB_NAME "select_build_lib" KERNEL_LIBS ${_kernel_lib} DEPS executorch)
target_link_libraries(
  my_app PRIVATE executorch extension_module_static extension_tensor
  xnnpack_backend select_build_lib)
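For reference, here is a rough sketch (not the official selective-build tooling) of listing the ATen ops that appear in the exported graph, so their names can be fed to ROOT_OPS. `model` and `inp` are the objects from the snippet above, and note the portable runtime kernels are the .out variants, so the mapping of names is approximate:

import torch
from torch.export import export

exported_program = export(model, inp)  # `model` and `inp` as in the original snippet

ops = set()
for node in exported_program.graph_module.graph.nodes:
    if node.op == "call_function" and isinstance(node.target, torch._ops.OpOverload):
        # node.target.name() looks like "aten::add.Tensor"; the portable runtime
        # kernels are out-variants, so this usually corresponds to "aten::add.out".
        ops.add(node.target.name())

print(sorted(ops))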

cc @digantdesai @mcr229 @cbilgin

tdasika17 avatar Apr 18 '25 10:04 tdasika17

Hi @tdasika17,

As a first step, can you try a few things:

  • Switch from to_backend() to to_edge_transform_and_lower() (a minimal sketch follows after this list). Here's an example: https://pytorch.org/executorch/0.6/backends-xnnpack.html#using-the-xnnpack-backend.
  • Make sure you are testing with a release build (full optimizations).
  • Try setting the thread count to 4. Call this before running the model.
#include <executorch/extension/threadpool/threadpool.h>
...
::executorch::extension::threadpool::get_threadpool()->_unsafe_reset_threadpool(4);
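For the first bullet, here's a minimal sketch based on the linked XNNPACK docs, assuming the `model` and example inputs `inp` from your original post:

from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge_transform_and_lower
from torch.export import export

exported_program = export(model, inp)
executorch_program = to_edge_transform_and_lower(
    exported_program,
    partitioner=[XnnpackPartitioner()],
).to_executorch()

with open("model.pte", "wb") as f:
    f.write(executorch_program.buffer)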

There are a few other suggestions here: https://pytorch.org/executorch/0.6/using-executorch-faqs.html#inference-is-slow-performance-troubleshooting. If this doesn't solve it, we can do operator-level profiling to drill down.

GregoryComer avatar Apr 18 '25 20:04 GregoryComer

Hi @GregoryComer ,

The timings I reported were with to_edge_transform_and_lower(). While experimenting I also tried to_backend(), and the inference time is around ~27 seconds there.

So, my original observation of ~16 s is with the code below; I have also set the thread count.

executorch_program = to_edge_transform_and_lower(
    exported_program, partitioner=[XnnpackPartitioner()]
).to_executorch()

I also used the release build type by specifying cmake .. -DCMAKE_BUILD_TYPE=Release.

I'm not sure on how to debug this further. Any help would be greatly appreciated.

Thanks.

tdasika17 avatar Apr 21 '25 06:04 tdasika17

https://pytorch.org/executorch/0.6/runtime-profiling.html

You can try doing some operator profiling through that page. On the other hand, if you can share a flame graph of the model run, that might also be helpful for understanding the performance: https://github.com/brendangregg/FlameGraph
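For reference, a minimal sketch of reading an ETDump with the devtools Inspector from the linked page (module paths follow the 0.6 devtools docs and may differ by version; the optional etrecord, produced at export time, is what ties delegate events back to source ops):

from executorch.devtools import Inspector

# etdump.etdp is produced by the runtime; etrecord.bin (optional) is produced at export time.
inspector = Inspector(etdump_path="etdump.etdp")  # optionally: etrecord="etrecord.bin"
inspector.print_data_tabular()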

mcr229 avatar Apr 21 '25 18:04 mcr229

Attaching the model profiling information: pte_model_profiling.txt

tdasika17 avatar Apr 22 '25 10:04 tdasika17

great! it looks like the issue is that there are aten.mm.default nodes which aren't getting delegated to xnnpack?

Are you using torch.mm between two dynamic inputs? As in one of the inputs is not a constant weight tensor?
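To illustrate the distinction, here is a hypothetical module (not your model); whether XNNPACK picks up each case depends on the partitioner, but the constant-weight case is the one that typically maps to a fully connected op:

import torch

class MmExample(torch.nn.Module):  # hypothetical, for illustration only
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(32, 32))

    def forward(self, x, y):
        a = torch.mm(x, self.weight)  # activation x constant weight
        b = torch.mm(x, y)            # two runtime activations -> shows up as aten.mm.default
        return a + b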

mcr229 avatar Apr 22 '25 19:04 mcr229

Ah, thanks! I fixed that: I modified the graph to replace the 'aten_mm_default' ops with 'aten_mul_tensor' instead, and now the time came down to 3.5 seconds. Are there any other ops from the 'occurrences_in_non_delegated_graphs' list that can be delegated to the backend?

tdasika17 avatar Apr 23 '25 06:04 tdasika17

It looks like concatenate could potentially be lowered? I think the only reason it would fail to be lowered is if it isn't a float value. Based on the timings, it looks like execute takes about 340 ms, by the way; is this profiling run the one for the 16 s run or the 3.5 s run?

mcr229 avatar Apr 23 '25 19:04 mcr229

Another suggestion is building with the optimized operator library enabled: https://github.com/pytorch/executorch/blob/main/CMakeLists.txt#L213

There are some ops that fall through XNNPACK and are run by executorch's own kernels (native_layer_norm), which can be accelerated.

mcr229 avatar Apr 23 '25 19:04 mcr229

Hi @mcr229, I have already used this option in my build, option(EXECUTORCH_BUILD_KERNELS_OPTIMIZED "" ON), and linked the 'optimized_native_cpu_ops_lib' library to my app.

pte_model_profiling_3seconds.txt

Attaching the profiling of the model with 3.5 seconds. Please have a look at 'occurrences_in_non_delegated_graphs' in the latest attachment and let me know if any ops can be lowered.

Thanks,

tdasika17 avatar Apr 24 '25 07:04 tdasika17

Hi @tdasika17, the profiling you shared suggests that running the model only takes 105 ms (see the execute row at the bottom). Are you profiling elsewhere to find the 3.5 seconds?

mcr229 avatar Apr 24 '25 18:04 mcr229

The model trace (.etdp) is collected at actual model inference; below is the code snippet.

Module model(model_path, Module::LoadMode::MmapUseMlockIgnoreErrors, std::move(etdump_gen_));
vector<int64_t> output_ids = generate(model, input_tokens);
ETDumpGen* etdump_gen = static_cast<ETDumpGen*>(model.event_tracer());
ET_LOG(Info, "ETDump size: %zu blocks", etdump_gen->get_num_blocks());
ETDumpResult result = etdump_gen->get_etdump_data();
if (result.buf != nullptr && result.size > 0) {
  // On a device with a file system, users can just write it to a file.
  FILE* f = fopen("etdump.etdp", "w+");
  fwrite((uint8_t*)result.buf, 1, result.size, f);
  fclose(f);
  free(result.buf);
}

Later, I used this .etdp to profile the model,

inspector = Inspector(etdump_path="./etdump.etdp")
inspector.print_data_tabular()
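To see where the wall-clock time goes, a rough aggregation over the same inspector object might help (a sketch; the event/perf_data field names follow the devtools Inspector API and may differ by version):

from collections import defaultdict

totals = defaultdict(float)
for block in inspector.event_blocks:
    for event in block.events:
        if event.perf_data is not None:
            totals[event.name] += event.perf_data.avg  # units follow the Inspector's time scale

for name, total in sorted(totals.items(), key=lambda kv: -kv[1]):
    print(name, total)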

I'm attaching the latest log. I see a few warnings saying "No delegate mapping found for delegate with instruction id". Please refer to the latest log attached.

model_profiling_latest.txt

tdasika17 avatar Apr 25 '25 03:04 tdasika17

Hi, I saw this line in the profiling:

python ssm_single_forward_gen.py -weights ../../models/
Python model setup time: 3.759671 seconds

Is this where the 3.7 seconds you're getting is coming from? Could you share what is being run and measured at 3.759671?

mcr229 avatar Apr 25 '25 19:04 mcr229

Hi,

Sorry for the confusion. That is another Python script used to generate the .pt model and export it.

The 3.5 seconds I'm talking about is purely the C++ inference time, which is timed just before and after the generate method call.

The generate method has nothing but the pure inference logic: it infers each token by calling model.forward and collects the logits.

tdasika17 avatar Apr 26 '25 00:04 tdasika17

I see. So are you saying the measured 3.5 seconds covers the entire ::generate call, which would likely run inference multiple times for prefill and decode? That would clear up my misunderstanding between the profiling info and the observed timings.

Based on the most recent logs, it still looks like many matrix multiplication layers are undelegated. Would it be possible to share how those .mm ops are used? What are the dtypes? Is one of the inputs a constant (weight tensor)?

mcr229 avatar Apr 28 '25 01:04 mcr229

@tdasika17 just wanted to follow up to make sure we have a path forward for this. Let me know if you're still encountering issues with the high inference times.

mcr229 avatar May 01 '25 19:05 mcr229

Hi @mcr229 ,

Yes, the time is computed for the overall inference, which includes the ingestion phase, output token generation, and post-processing of the output tokens. I have converted a .pt model to a .pte model as explained in the initial description.

I'm still seeing high inference times, around 3.5 seconds, whereas for the same model run using torch the inference time is around 1.5 s on the same architecture (an x86_64 system).

Again, below is the ops list:

  op_type occurrences_in_delegated_graphs occurrences_in_non_delegated_graphs
0 aten_add_tensor 120 0
1 aten_arange_start_step 0 80
2 aten_copy_default 0 80
3 aten_embedding_default 0 1
4 aten_eq_scalar 0 80
5 aten_expand_copy_default 0 80
6 aten_linear_default 2 0
7 aten_mul_tensor 280 0
8 aten_native_layer_norm_default 0 40
9 aten_relu_default 1 0
10 aten_select_copy_int 0 440
11 aten_sigmoid_default 40 0
12 aten_squeeze_copy_dims 0 40
13 aten_sub_tensor 40 0
14 aten_sum_dim_int_list 0 80
15 aten_unsqueeze_copy_default 0 80
16 aten_view_copy_default 0 80
17 aten_where_self 0 80
18 getitem 0 240
19 Total 483 1401

Please help me diagnose which additional ops can be delegated to XNNPACK to improve the inference times.

tdasika17 avatar May 07 '25 09:05 tdasika17

@tdasika17 could you share the profiled timings associated with the above table? The previous timing .txt files show that there were native_call_mm.out entries, but the op lists shared here don't have native_call_mm.out. Based on the profiling information, mm.out was taking 68% of the inference time. As for further delegation, based on the ops provided, all the possible ops have been delegated, so I don't think there is any additional lowering potential. For faster inference I could only suggest quantization.
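On the quantization suggestion, here is a rough sketch of PT2E quantization for the XNNPACK backend (module paths follow the 0.6 docs and may differ across versions; the calibration call is a placeholder):

import torch
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)
from executorch.exir import to_edge_transform_and_lower
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

quantizer = XNNPACKQuantizer()
quantizer.set_global(get_symmetric_quantization_config(is_per_channel=True))

training_gm = torch.export.export_for_training(model, inp).module()
prepared = prepare_pt2e(training_gm, quantizer)
prepared(*inp)  # calibrate with representative inputs (placeholder)
quantized = convert_pt2e(prepared)

exported = torch.export.export(quantized, inp)
executorch_program = to_edge_transform_and_lower(
    exported, partitioner=[XnnpackPartitioner()]
).to_executorch()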

mcr229 avatar May 09 '25 21:05 mcr229

@mcr229 ,

Here is the profiling of the model during real inference. I generated etdump.bin while exporting and used the dev tools to generate inspector.txt. Attaching the Excel file related to the above table.

inspector_out.xlsx

tdasika17 avatar Jun 13 '25 08:06 tdasika17

Image

Hi @tdasika17, based on your timings it looks like the model inference is only taking 71 ms, and model load is around 167 ms. I believe this should be significantly faster than PyTorch? I know that you've mentioned measuring 3.5 s, but I can't see that in the profiling logs. Are you measuring the inference (3.5 s) on a separate device/environment from where you're getting these profiles?

mcr229 avatar Jun 13 '25 17:06 mcr229

Hi @mcr229 , The profiling is done on the same machine.

When I mention the inference time of ~3.5 s, I mean the time measured just before and after the generate method call.

My concern is that for the same LLM model, the generate method takes around 1.3 s when using the PyTorch model and 3.5 s for the .pte model.

I want to know if this is expected and how to validate whether the inference time I'm getting is correct.

The inference time is the total time taken for the RAG-LLM model to generate the output. In both torch and executorch I used the exact same question.

Please help me find if there is any benchmark or other way to validate the inference times.

tdasika17 avatar Jun 13 '25 18:06 tdasika17

@tdasika17 thanks for the clarification!

The executorch model should be expected to be faster than the PyTorch model on CPU, but it might be helpful to share how the PyTorch benchmark is being run and how you're running the model. Generating a flame graph will also help us understand which kernels the model is spending the most time in. You can take a look at how to generate one here: https://github.com/brendangregg/FlameGraph.
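As an example of what a controlled baseline measurement could look like (a minimal sketch, not your benchmark; the model path, inputs, and iteration counts are placeholders):

import time
import torch

torch.set_num_threads(4)  # match the ExecuTorch threadpool size for a fair comparison

scripted = torch.jit.load("model_scripted.pt")  # hypothetical TorchScript model path

with torch.inference_mode():
    for _ in range(3):  # warm-up
        scripted(*inp)
    start = time.perf_counter()
    for _ in range(10):
        scripted(*inp)
    avg_s = (time.perf_counter() - start) / 10

print(f"avg latency: {avg_s * 1000:.1f} ms")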

mcr229 avatar Jun 13 '25 20:06 mcr229

Hi @mcr229, today I tried the executorch model on a different x86 server. I got a different inference time here for the same application (~7.8 sec); this may be because of the different clock frequencies. Attaching the FlameGraph generated.

I will try to come up with the FlameGraph on the same machine; somehow I couldn't use that machine today. Meanwhile, can you see if you can infer something from this?

Image

tdasika17 avatar Jun 16 '25 10:06 tdasika17

@tdasika17 yes, this is definitely useful. I think the graph you sent is for a single thread. Would it be possible to share the entire .svg file? This would definitely help with debugging the executorch model. If you could also share the PyTorch inference flame graph, that would give us a lot of insight as well. Thanks!

mcr229 avatar Jun 16 '25 18:06 mcr229

Sure @mcr229, here I'm attaching the SVGs for both torch and executorch.

Image Image

tdasika17 avatar Jun 17 '25 09:06 tdasika17

Hi @mcr229, can you please help me understand the difference between the inference times in torch and executorch and why executorch is taking more time? Do you need any additional info apart from the flamegraphs provided for both?

tdasika17 avatar Jun 20 '25 11:06 tdasika17

Would it be possible to share the flame graph files? It'll help with inspecting the call stacks. On a cursory look at these, I can't immediately tell what the discrepancy is between the inference times.

For the torch model, could you share how you're lowering, running, and profiling the model?

mcr229 avatar Jun 24 '25 00:06 mcr229

Hi @mcr229 ,

I have used the snippet below to convert and lower the torch model to executorch.

#convert.py
partitioner = [XnnpackPartitioner()]

# Lower and partition for ExecuTorch
exported_program = export(model, inp)
executorch_program = to_edge_transform_and_lower(
    exported_program,
    partitioner=partitioner,
)
executorch_program = executorch_program.to_executorch()

For the Torch model, I just ran the Torch C++ application and captured the graphs for the C++ application execution, and similarly for the ExecuTorch program. Attaching the flame graphs: torch_et_fg.zip

tdasika17 avatar Jun 24 '25 06:06 tdasika17

For the Torch model, I just ran the Torch C++ application and captured the graphs for the C++ application execution

what do you mean by this? What is the capture flow? How are you loading and running the captured graph in c++?

mcr229 avatar Jun 24 '25 18:06 mcr229

For the Torch model, I just ran the Torch C++ application and captured the graphs for the C++ application execution

what do you mean by this? What is the capture flow? How are you loading and running the captured graph in c++?

I meant that I have a simple C++ application that loads the TorchScript model using jit.load and calls model.generate on the encoded input, then decodes the output tokens.

tdasika17 avatar Jun 24 '25 19:06 tdasika17

For the Torch model, I just ran the Torch C++ application and captured the graphs for the C++ application execution, and similarly for the ExecuTorch program

Thanks for the FlameGraph and other detailed information about profiling.

Skimming the attached FlameGraphs and focusing on the simple_ssm function and its children, I saw:

  • On the ET side we are spending a lot of time on element-wise multiplication, which I don't see on the torchscript side.
  • Unlike the torchscript side, I don't see a GEMM kernel on the ET side (are you exporting with dynamic shapes enabled? See the sketch after this list. I would expect a GEMM kernel, e.g. 16x32, for prefill and a GEMV kernel, 1x32, for decode, but I only see the GEMV kernel).
    • There is a large-ish unknown function on the ET side; not sure what's in there.
  • Lastly, there is no large element-wise mul on the torchscript side.
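On the dynamic-shape question, a minimal sketch of how a dynamic sequence dimension could be declared at export time (the argument name, dimension index, and bounds are hypothetical and need to match the real forward signature):

import torch
from torch.export import Dim, export

seq_len = Dim("seq_len", min=1, max=2048)  # hypothetical bounds
exported_program = export(
    model,
    (example_tokens,),                        # example_tokens: e.g. a (1, seq) int64 tensor
    dynamic_shapes={"tokens": {1: seq_len}},  # mark dim 1 of the "tokens" argument dynamic
)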

digantdesai avatar Jun 25 '25 03:06 digantdesai