Poor scaling with many calls to add_prefetch

Open nchristensen opened this issue 2 years ago • 3 comments

Even with #755, prefetching many arrays scales poorly. By the 19th add_prefetch operation, a single add_prefetch call takes around 5 seconds on a fused Mirgecom kernel with 100+ einsums. Profiling shows that most of the time is spent in get_grid_sizes_for_insn_ids_as_dicts.

add_prefetch time in seconds, one line per successive call:
0.10026049613952637
0.11436939239501953
0.13449859619140625
0.15656232833862305
0.18714141845703125
0.22336697578430176
0.2757580280303955
0.34452342987060547 
0.4413919448852539
0.5716948509216309 
0.7412521839141846
0.9670014381408691 
1.2657301425933838 
1.6620988845825195 
2.180263042449951 
2.8522701263427734 
3.696044683456421 
4.755332946777344 
6.091721773147583 
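
One way to read these timings is to look at the ratio between successive calls. The snippet below is a minimal sketch that uses only the numbers listed above; the variable names are illustrative. A ratio that stays well above 1 suggests that each call costs a multiple of the previous one rather than a fixed increment.

```python
# Per-call add_prefetch times in seconds, copied from the list above.
times = [
    0.10026049613952637, 0.11436939239501953, 0.13449859619140625,
    0.15656232833862305, 0.18714141845703125, 0.22336697578430176,
    0.2757580280303955, 0.34452342987060547, 0.4413919448852539,
    0.5716948509216309, 0.7412521839141846, 0.9670014381408691,
    1.2657301425933838, 1.6620988845825195, 2.180263042449951,
    2.8522701263427734, 3.696044683456421, 4.755332946777344,
    6.091721773147583,
]

# Ratio of each call's time to the time of the call before it.
ratios = [later / earlier for earlier, later in zip(times, times[1:])]

for call_number, ratio in enumerate(ratios, start=2):
    print(f"call {call_number}: {ratio:.2f}x the previous call")
```
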
         452893 function calls (368789 primitive calls) in 8.962 seconds

   Ordered by: cumulative time
   List reduced from 557 to 30 due to restriction <30>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000    8.962    8.962 __init__.py:779(prefetch_and_project)
        1    0.000    0.000    8.837    8.837 data.py:302(add_prefetch)
        1    0.000    0.000    8.837    8.837 data.py:153(add_prefetch_for_single_kernel)
        1    0.001    0.001    8.784    8.784 precompute.py:353(precompute_for_single_kernel)
2879/1595    0.003    0.000    8.018    0.005 __init__.py:752(wrapper)
      5/1    0.000    0.000    7.958    7.958 tools.py:800(assign_automatic_axes)
       44    0.002    0.000    7.890    0.179 __init__.py:802(get_iname_bounds)
        1    0.000    0.000    7.818    7.818 __init__.py:1031(get_grid_size_upper_bounds_as_exprs)
        1    0.000    0.000    7.818    7.818 __init__.py:990(get_grid_sizes_for_insn_ids_as_exprs)
        1    0.000    0.000    7.817    7.817 __init__.py:939(get_grid_sizes_for_insn_ids)
        1    0.002    0.002    7.817    7.817 __init__.py:845(get_grid_sizes_for_insn_ids_as_dicts)
     6948    7.439    0.001    7.482    0.001 __init__.py:925(wrapper)
       92    0.001    0.000    7.469    0.081 tools.py:352(op)
       46    0.000    0.000    3.740    0.081 tools.py:364(dim_min)
       46    0.038    0.001    3.729    0.081 tools.py:343(_get_dim_min)
       46    0.000    0.000    3.729    0.081 tools.py:370(dim_max)
       46    0.043    0.001    3.719    0.081 tools.py:339(_get_dim_max)
68309/37887    0.635    0.000    0.656    0.000 __init__.py:936(wrapper)
        1    0.001    0.001    0.512    0.512 array_buffer_map.py:196(__init__)
      154    0.002    0.000    0.479    0.003 __init__.py:1263(align_spaces)
       54    0.000    0.000    0.467    0.009 __init__.py:1312(align_two)
        1    0.000    0.000    0.467    0.467 array_buffer_map.py:173(compute_bounds)
        1    0.000    0.000    0.456    0.456 array_buffer_map.py:162(find_var_base_indices_and_shape_from_inames)
        1    0.000    0.000    0.456    0.456 array_buffer_map.py:165(<listcomp>)
        2    0.000    0.000    0.456    0.228 tools.py:379(base_index_and_length)
      462    0.069    0.000    0.434    0.001 __init__.py:1182(_align_dim_type)
       81    0.001    0.000    0.321    0.004 __init__.py:801(expr_like_add)
       84    0.313    0.004    0.320    0.004 __init__.py:769(_number_to_expr_like)
     6773    0.031    0.000    0.220    0.000 __init__.py:1178(_set_dim_id)
       49    0.000    0.000    0.215    0.004 __init__.py:1061(obj_project_out_except)
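
For anyone who wants to reproduce a report in this format, the sketch below uses the standard library's cProfile and pstats modules. It is an assumption about how such numbers can be collected, not necessarily how the profile above was produced. slow_transform is a hypothetical stand-in for a single add_prefetch call, and profile_call is an illustrative helper name.

```python
import cProfile
import io
import pstats
import time


def slow_transform(n):
    # Hypothetical stand-in for one add_prefetch call on a large kernel.
    return sum(i * i for i in range(n))


def profile_call(func, *args, limit=30):
    """Run func under cProfile; return (result, elapsed seconds, report text)."""
    profiler = cProfile.Profile()
    start = time.perf_counter()
    result = profiler.runcall(func, *args)
    elapsed = time.perf_counter() - start

    # Sort by cumulative time and keep at most `limit` rows in the report.
    buf = io.StringIO()
    stats = pstats.Stats(profiler, stream=buf)
    stats.sort_stats("cumulative").print_stats(limit)
    return result, elapsed, buf.getvalue()


result, elapsed, report = profile_call(slow_transform, 100_000)
print(f"{elapsed:.3f} s")
print(report)
```

Calling profile_call once per transformation step, and printing elapsed each time, gives both a per-call timing list and a cumulative-time report for the slowest step.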

nchristensen, Feb 28 '23 03:02

Passing default_tag=None to add_prefetch and then parallelizing the prefetch by explicitly calling split_iname might help.

kaushikcfd, Mar 01 '23 09:03

Also, if the workload is coming from Mirge-Com, it might be useful to evaluate whether such big batched einsums are still relevant. See https://github.com/illinois-ceesd/mirgecom/issues/777 for context.

kaushikcfd, Mar 01 '23 17:03

Good point. I'll dump out the kernels for the current y3 driver to see if anything has changed in terms of batch sizes.

In any case, just setting default_tag=None doesn't affect the overall scaling. It does change the profiling results somewhat: get_grid_sizes_for_insn_ids_as_dicts no longer appears among the top entries.

   Ordered by: cumulative time
   List reduced from 467 to 30 due to restriction <30>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000    8.846    8.846 __init__.py:791(prefetch_and_project)
        1    0.000    0.000    6.842    6.842 data.py:302(add_prefetch)
        1    0.000    0.000    6.842    6.842 data.py:153(add_prefetch_for_single_kernel)
        1    0.001    0.001    6.761    6.761 precompute.py:353(precompute_for_single_kernel)
35274/7782    4.665    0.000    4.703    0.001 __init__.py:936(wrapper)
        1    0.001    0.001    3.908    3.908 array_buffer_map.py:196(__init__)
        1    0.000    0.000    3.814    3.814 array_buffer_map.py:173(compute_bounds)
        1    0.000    0.000    3.788    3.788 array_buffer_map.py:162(find_var_base_indices_and_shape_from_inames)
        1    0.000    0.000    3.788    3.788 array_buffer_map.py:165(<listcomp>)
        2    0.000    0.000    3.788    1.894 tools.py:379(base_index_and_length)
       44    2.102    0.048    2.107    0.048 __init__.py:769(_number_to_expr_like)
       44    0.000    0.000    2.106    0.048 __init__.py:801(expr_like_add)
        3    0.000    0.000    2.059    0.686 __init__.py:1061(obj_project_out_except)
        2    0.000    0.000    2.004    1.002 translation_unit.py:677(_collective_transform)
        1    0.001    0.001    2.004    2.004 decouple_domain.py:38(decouple_domain)
        4    0.000    0.000    0.852    0.213 tools.py:352(op)
        2    0.779    0.390    0.780    0.390 isl_helpers.py:576(find_max_of_pwaff_with_params)
        2    0.000    0.000    0.479    0.240 tools.py:370(dim_max)
        2    0.478    0.239    0.478    0.239 tools.py:339(_get_dim_max)
 6450/630    0.012    0.000    0.435    0.001 __init__.py:256(__call__)
1636/1166    0.002    0.000    0.422    0.000 __init__.py:752(wrapper)
      153    0.001    0.000    0.401    0.003 instruction.py:858(with_transformed_expressions)
      119    0.001    0.000    0.389    0.003 symbolic.py:134(map_reduction)
  190/162    0.001    0.000    0.386    0.002 __init__.py:524(map_sum)
  190/162    0.000    0.000    0.385    0.002 __init__.py:525(<listcomp>)
        1    0.000    0.000    0.379    0.379 precompute.py:302(map_kernel)
        5    0.000    0.000    0.379    0.076 symbolic.py:1370(__call__)
        2    0.000    0.000    0.379    0.189 precompute.py:320(<lambda>)
        2    0.000    0.000    0.378    0.189 symbolic.py:1314(map_call)
        1    0.000    0.000    0.378    0.378 precompute.py:227(map_substitution)

nchristensen, Mar 06 '23 19:03