Add resnet50 benchmark (#443)
Before submitting
- [ ] Was this discussed/approved via a GitHub issue? (no need for typos and docs improvements)
- [ ] Did you read the contributor guideline, Pull Request section?
- [ ] Did you make sure to update the docs?
- [ ] Did you write any new necessary tests?
What does this PR do?
Fixes #443
Modifications of the original resnet50 in torchvision:
- replace `out += identity` with `out = out + identity`
- use `ReLU(inplace=False)`
- set `num_batches_tracked = None` for BatchNorm (a workaround since `add_` is utilized in `nn.BatchNorm` when `num_batches_tracked` is not None)

`adaptive_avg_pool2d` is added (#363); needs: #365
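The modifications can be sketched with a minimal residual block (this is a simplified stand-in, not the actual torchvision Bottleneck):

```python
import torch
import torch.nn as nn

class ResidualBlockSketch(nn.Module):
    """Minimal sketch of the out-of-place modifications described above."""

    def __init__(self, channels: int) -> None:
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(channels)
        # Workaround: with num_batches_tracked set to None, nn.BatchNorm
        # skips the in-place add_ on that buffer during training.
        self.bn.num_batches_tracked = None
        self.relu = nn.ReLU(inplace=False)  # torchvision uses inplace=True

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        identity = x
        out = self.bn(self.conv(x))
        out = out + identity  # out-of-place, instead of `out += identity`
        return self.relu(out)
```

Assigning `None` to a registered buffer is allowed by `nn.Module`, and `nn.BatchNorm2d.forward` only calls `num_batches_tracked.add_(1)` when the buffer is not None, which is what the workaround relies on.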
@kiya00 could you try resnet50 w/o modifications? I expect it to work following https://github.com/Lightning-AI/lightning-thunder/issues/633
sure, I'll try the original one
Hi @jjsjann123, there's an error when running `pytest thunder/benchmarks/targets.py -k "test_resnet50[backward-thunder]"` that could be related to NumberProxy: `nvfuserex_impl.py` doesn't seem to be able to handle a shape list with mixed IntegerProxy and int. Is this a known problem?
```
[...] = nvFusion1(....., i3233, ........)
...
# t14266 = prims.reshape(t14251, (i3233, 512, 1, 7, 7))  # t14266: "cuda:0 bf16[1, 512, 1, 7, 7]"
...
thunder.backward_fn_3:2793: in backward_fn
thunder/executors/nvfuserex_impl.py:402: in __call__
    fd = self.get_fd(to_descriptors(args))
thunder/executors/nvfuserex_impl.py:512: in get_fd
    return create_fd(bsyms, input_descriptors, sorted_unique_inputs, sorted_unique_outputs)
thunder/executors/nvfuserex_impl.py:274: in create_fd
    translate_bound_symbol(bsym)
thunder/executors/nvfuserex_impl.py:264: in translate_bound_symbol
    nvresults = translator(*bsym.args, **bsym.kwargs, fd=fd, lc_to_nv_map=lc_to_nv_map)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
a = t14251, shape = ([IntegerProxy name=i3233, value=1, static=CONSTRAINT.DYNAMIC], 512, 1, 7, 7)

    def reshape(a: TensorProxy, shape: list[int], *, fd: FusionDefinition, lc_to_nv_map: dict) -> Any:
        nv_a = getnv(a, fd, lc_to_nv_map)
        if nv_version < LooseVersion("0.0.22"):
            return fd.ops.reshape(nv_a, a.shape, shape)
        else:
            print(shape)
>           return fd.ops.reshape(nv_a, shape)
E           RuntimeError: Unsupported iterable object type for define_vector! Index:0
```
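For context, the failure is that `define_vector` receives a shape sequence whose elements are not all plain ints. A self-contained sketch of one diagnostic workaround is below: flattening proxies to their trace-time values (`FakeProxy` is a hypothetical stand-in for thunder's IntegerProxy, which carries its concrete value in a `.value` attribute). Note this gives up dynamism, so it is only a way to confirm the diagnosis, not the real fix:

```python
class FakeProxy:
    """Hypothetical stand-in for thunder's IntegerProxy: a number-like
    object exposing its concrete trace-time value via .value."""

    def __init__(self, value: int) -> None:
        self.value = value

def normalize_shape(shape) -> list:
    """Flatten proxies to their static values so every element is a plain
    int. This loses dynamic-shape information, hence only a diagnostic."""
    return [dim.value if hasattr(dim, "value") else dim for dim in shape]

# Mirrors the failing shape tuple from the traceback above.
print(normalize_shape((FakeProxy(1), 512, 1, 7, 7)))  # → [1, 512, 1, 7, 7]
```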
Another problem: when running `pytest thunder/benchmarks/targets.py -k "test_resnet50[backward-thunder+nvfuser+torch.compile]"`, the torch.compile executor has trouble handling `prims.convert_element_type` with a non-tensor input, e.g. `i4803 = prims.convert_element_type(f4802, int)  # i4803: "int 0"`.
The torch.compile executor treats it as fusible even though the checker in torchex returns False
https://github.com/Lightning-AI/lightning-thunder/blob/8c953b37efe2428722b30e2abf910d4a1a9edfc0/thunder/executors/torch_compile.py#L151
and triggers an error in
https://github.com/Lightning-AI/lightning-thunder/blob/8c953b37efe2428722b30e2abf910d4a1a9edfc0/thunder/executors/torch_compile.py#L40
Error message:
```
thunder/__init__.py:593: in get_computation_and_inputs
    computation_trc, backward_trc = split_forward_backward(computation_trc, cd, cs, *inps)
thunder/executors/torch_autograd.py:209: in split_forward_backward
    bw_extrace = transform_for_execution(
thunder/executors/passes.py:155: in transform_for_execution
    extrace = ex.fusion_pass(extrace)
thunder/executors/torch_compile.py:170: in fusion_pass
    fusion_bsym: BoundSymbol = self.fuse(region, fusion_counter)
thunder/executors/torch_compile.py:125: in fuse
    compiled: Callable = make_compiled(region.bound_symbols, sorted_unique_inputs, sorted_unique_outputs)
thunder/executors/torch_compile.py:84: in make_compiled
    torch_trace = trace(inline_trace=False)(torch_interpreted_func, *sorted_unique_inputs)
thunder/core/interpreter.py:1313: in fn_
    return fn(*args, **kwargs)
thunder/common.py:574: in _trace
    result = fn(*proxyargs, **proxykwargs)
thunder/executors/torch_compile.py:65: in torch_interpreted_func
    return eval_trace(region_trace, *args, symbol_mapper=to_torch_translator)
thunder/core/transforms.py:1517: in eval_trace
    result = prim_func(*args, **kwargs)
thunder/executors/torch_compile.py:40: in _to_torch
    return impl_info.execution_transform(*args, **kwargs)
thunder/executors/torchex.py:100: in _convert_element_type_transform
    torch_dtype: torch.dtype = to_torch_dtype(dtype)
thunder/core/dtypes.py:592: in to_torch_dtype
    baseutils.check_type(x, dtype)
thunder/core/baseutils.py:107: in check_type
    check(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
cond = False, s = <function check_type.<locals>.<lambda> at 0x7fa2fc8e7880>, exception_type = <class 'ValueError'>

    def check(cond: bool, s: Callable[[], str], exception_type: type[Exception] = RuntimeError) -> None:
        """Helper function for raising an error_type (default: RuntimeError) if a boolean condition fails.

        s is a callable producing a string to avoid string construction if the error check is passed.
        """
        if not cond:
>           raise exception_type(s())
E           ValueError: <class 'int'> had an unexpected type <class 'type'>. Supported types are <class 'thunder.core.dtypes.dtype'>

thunder/core/baseutils.py:103: ValueError
```
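The root cause of the ValueError is that the `dtype` argument reaching the torchex transform is the Python builtin `int`, which is neither a thunder dtype nor a `torch.dtype`. The same strict-type-check pattern can be reproduced in isolation (a simplified sketch, not thunder's actual `to_torch_dtype`):

```python
import torch

def to_torch_dtype_sketch(x):
    """Simplified stand-in for thunder's to_torch_dtype: accept only real
    dtype objects and reject everything else, as the traceback shows."""
    if isinstance(x, torch.dtype):
        return x
    raise ValueError(f"{x!r} had an unexpected type {type(x)}")

# A real torch dtype passes through unchanged.
assert to_torch_dtype_sketch(torch.float32) is torch.float32

# The failing case mirrors prims.convert_element_type(f4802, int):
# the builtin type `int` is not a torch.dtype, so the check raises.
try:
    to_torch_dtype_sketch(int)
except ValueError as e:
    print("rejected:", e)
```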
> handle shape list with mixed IntegerProxy and int, is this a known problem?
Two parts to this question:
1. nvfuser should work with dynamic reshape (AFAIK), so it could be just an `nvfuser_impl` thing in thunder.
2. thunder doesn't really support dynamic shape properly yet. I'm scared about having a shape list with IntegerProxy in it at this point. I think we should go dig out where it's coming from.
> thunder doesn't really support dynamic shape properly yet. I'm scared about having a shape list with IntegerProxy in it at this point. I think we should go dig out where it's coming from.
`t4702 = prims.reshape(t4687, (i1001, 1024, 64, 13, 13))` — the `i1001` comes from an IntegerProxy passed from the forward trace. The augmented forward trace returns plain int numbers, but `backward_fn` takes them as IntegerProxy; is that expected? @jjsjann123
```python
# fwd return
return {'output': t1234, 'flat_args': [x, t_bn1_bias, t_bn1_num_batches_tracked, ...], 'flat_output': (t1234,)}, ((t0, t1011, ...), (False, 3, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0, 2, 2, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0, 2, 2, 0, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0, 3, 2, 2, 1, 1, 2, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 2, 0, 1, 2, 2, 0, 1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 2, 1, 0, 1, 0, 0, 1, 1, 0, 0, 1, 0, 0, 1, 2, 2, 1, 1, 1, 1, 0, 0, 1, 0, 3, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0, 2, 2, 0, 0, 1, 1, 0))

# bwd inputs
C0, C1, = saved_for_backward
clear_mutable_collection(saved_for_backward)
del saved_for_backward
t11, = cotangents
clear_mutable_collection(cotangents)
del cotangents
t0, t1011, t1012, t1013, t1021, t1024, t1084, t1085, t1086, t1094, t1096, \
... = C0
clear_mutable_collection(C0)
del C0
b66, i10, i1000, i1001, i1011, i11, i12, i136, i137, i138, i139, i14, i140, \
i141, i143, i144, i145, i15, i155, i16, i198, i199, i200, i201, i202, i203, \
i205, i206, i207, i217, i257, i258, i259, i26, i260, i261, i262, i264, i265, \
i266, i276, i319, i320, i321, i322, i323, i324, i326, i327, i328, i338, i381, \
i382, i383, i384, i385, i386, i388, i389, i390, i400, i443, i444, i445, i446, \
i447, i448, i450, i451, i452, i462, i502, i503, i504, i505, i506, i507, i509, \
i510, i511, i521, i564, i565, i566, i567, i568, i569, i571, i572, i573, i583, \
i62, i626, i627, i628, i629, i63, i630, i631, i633, i634, i635, i64, i645, i65, \
i688, i689, i690, i691, i692, i693, i695, i696, i697, i7, i707, i74, i747, \
i748, i749, i75, i750, i751, i752, i754, i755, i756, i76, i766, i77, i78, i79, \
i8, i809, i81, i810, i811, i812, i813, i814, i816, i817, i818, i82, i828, i83, \
i871, i872, i873, i874, i875, i876, i878, i879, i880, i890, i9, i93, i933, \
i934, i935, i936, i937, i938, i940, i941, i942, i952, i992, i993, i994, i995, \
i996, i997, i999, = C1
clear_mutable_collection(C1)
del C1
```
> the augment_fwd_trace returns int numbers, but backward_fn takes it as IntegerProxy, is that expected?
That's not expected, i.e. we should keep them consistent: if the forward generates a vanilla integer, the backward should as well, and likewise when it is a NumberProxy. We just need to further patch #244. I'll give it a try later.
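The invariant being discussed can be illustrated with a toy saved-for-backward round trip (hypothetical functions, not thunder's API): whatever type the forward stashes, the backward should unpack the same type.

```python
def toy_forward(x: float):
    """Toy forward: saves a plain Python int for the backward pass."""
    scale = 2  # a vanilla int, not a proxy
    return x * scale, (scale,)

def toy_backward(saved, grad: float) -> float:
    (scale,) = saved
    # Consistency check: the backward must see the same type the forward
    # saved; mixing a plain int on one side with a proxy on the other is
    # exactly the mismatch reported for the resnet50 backward trace.
    assert type(scale) is int
    return grad * scale

y, saved = toy_forward(3.0)
assert y == 6.0
assert toy_backward(saved, 1.0) == 2.0
```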
#706 patches nvfuser's reshape with dynamic input.
Meanwhile, I'm still trying to get a repro where the grad transform generates a NumberProxy where it shouldn't have. Looking at the trace, I think it's coming from convolution...
Status update: thanks to @jjsjann123 's patch https://github.com/Lightning-AI/lightning-thunder/pull/706, the failure in https://github.com/Lightning-AI/lightning-thunder/pull/451#issuecomment-2186631228 is gone, but there is an nvfuser failure: "Unsupported loop structure. Two loops are mapped together. bS323{1} and bS319{1}". I'm asking the nvfuser team for help to identify whether it's an nvfuser issue.
This might be https://github.com/NVIDIA/Fuser/issues/2685, but the RN50 on this PR is likely a more sane reproducer.
Apologies for the slow turnaround.
With nvfuser's issue patched, I'm no longer seeing any failures coming from this. cc'ing @kiya00
```
root@aca1f438e9da:/opt/pytorch/lightning-thunder# pytest thunder/benchmarks/targets.py -k "test_resnet50[backward-thunder]"
======================================================= test session starts =======================================================
platform linux -- Python 3.10.12, pytest-8.1.1, pluggy-1.5.0
Test order randomisation NOT enabled. Enable with --random-order or --random-order-bucket=<bucket_type>
benchmark: 4.0.0 (defaults: timer=torch.utils.benchmark.utils.timer.timer disable_gc=False min_rounds=5 min_time=0.000005 max_time=1.0 calibration_precision=10 warmup=True warmup_iterations=100000)
rootdir: /opt/pytorch/lightning-thunder
configfile: pyproject.toml
plugins: timestamper-0.0.10, hypothesis-6.104.2, random-order-1.1.1, timeout-2.3.1, cov-5.0.0, benchmark-4.0.0, xdist-3.6.1, shard-0.1.2, typeguard-4.3.0
timeout: 900.0s
timeout method: signal
timeout func_only: False
collected 728 items / 727 deselected / 1 selected
Running 1 items in this shard
thunder/benchmarks/targets.py .                                                                                            [100%]

------------------------------------------------------ benchmark: 1 tests ------------------------------------------------------
Name (time in ms)                      Min       Max      Mean  StdDev    Median     IQR  Outliers     OPS  Rounds  Iterations
--------------------------------------------------------------------------------------------------------------------------------
test_resnet50[backward-thunder]   110.6647  111.1892  110.7535  0.1542  110.7138  0.0271       1;1  9.0291      10           1
--------------------------------------------------------------------------------------------------------------------------------
Legend:
  Outliers: 1 Standard Deviation from Mean; 1.5 IQR (InterQuartile Range) from 1st Quartile and 3rd Quartile.
  OPS: Operations Per Second, computed as 1 / Mean
==================================== 1 passed, 727 deselected, 7 warnings in 134.28s (0:02:14) ====================================
root@aca1f438e9da:/opt/pytorch/lightning-thunder#
```
Hi all, thanks to @jjsjann123 's fix we can get resnet50 working now. Please help review again, thanks!