DeepSpeed

[REQUEST] How to use `Flops Profiler` to test `model.generate()`

Open • CaffreyR opened this issue 3 years ago • 3 comments

Is your feature request related to a problem? Please describe.

Currently, the Flops Profiler can profile the forward pass of transformer models such as BERT, but T5 inference actually goes through `model.generate()`. How can the Flops Profiler be used to profile `model.generate()`?

Currently I use this code:

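```python
import transformers
from deepspeed.profiling.flops_profiler import get_model_profile

tokenizer = transformers.T5Tokenizer.from_pretrained('t5-base', return_dict=False)
model = transformers.T5ForConditionalGeneration.from_pretrained('t5-base')

# t5_input_constructor, batch_size and seq_len are defined elsewhere (not shown);
# t5_input_constructor returns the keyword arguments passed to the model's forward.
flops, macs, params = get_model_profile(
    model,
    kwargs=t5_input_constructor(batch_size, seq_len, tokenizer),
    print_profile=True,
    detailed=True,
    warm_up=10,
    module_depth=-1,
    top_modules=1,
)
```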
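
As I understand it, `get_model_profile` profiles a single call of the model's forward (after the warm-up runs), whereas generation calls the forward repeatedly, roughly once per generated token. The pure-Python toy below illustrates that difference. `CallCounter`, `toy_forward` and `toy_generate` are hypothetical stand-ins, not transformers or DeepSpeed APIs, and a real T5 `generate()` additionally runs the encoder once before its decoding loop.

```python
from typing import Callable, List


class CallCounter:
    """Wraps a callable and counts how many times it is invoked (hypothetical helper)."""

    def __init__(self, fn: Callable[[List[int]], int]):
        self.fn = fn
        self.calls = 0

    def __call__(self, tokens: List[int]) -> int:
        self.calls += 1
        return self.fn(tokens)


def toy_forward(tokens: List[int]) -> int:
    # Hypothetical stand-in for one decoder forward pass:
    # "predicts" the next token from the last one.
    return tokens[-1] + 1


def toy_generate(forward: Callable[[List[int]], int], prompt: List[int], max_new_tokens: int) -> List[int]:
    # Greedy-decoding loop: one forward call per generated token.
    tokens = list(prompt)
    for _ in range(max_new_tokens):
        tokens.append(forward(tokens))
    return tokens


counter = CallCounter(toy_forward)
out = toy_generate(counter, prompt=[0], max_new_tokens=5)
print(out)            # [0, 1, 2, 3, 4, 5]
print(counter.calls)  # 5
```

So a profile of one forward call covers only one decoding step; profiling the whole `generate()` call is what captures all of them.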

Describe the solution you'd like

To measure the time cost of `model.generate()`.
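
For plain wall-clock timing of `model.generate()`, separate from FLOP counting, a small helper with untimed warm-up runs (similar in spirit to the `warm_up` argument above) might look like the sketch below. `time_call` is a hypothetical helper, not part of DeepSpeed, and it is demonstrated here on a cheap stand-in function rather than a model.

```python
import time
from statistics import mean
from typing import Any, Callable, List


def time_call(fn: Callable[[], Any], warm_up: int = 3, repeats: int = 5) -> List[float]:
    """Run fn() `warm_up` times untimed, then return wall-clock seconds for each of `repeats` timed runs."""
    if repeats < 1:
        raise ValueError("repeats must be >= 1")
    for _ in range(warm_up):
        fn()  # warm-up runs are not timed
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        timings.append(time.perf_counter() - start)
    return timings


# Stand-in workload; with a real model this would be e.g. lambda: model.generate(input_ids)
timings = time_call(lambda: sum(range(10_000)), warm_up=2, repeats=3)
print(f"mean: {mean(timings):.6f} s over {len(timings)} runs")
```

On GPU, CUDA kernels run asynchronously, so calling `torch.cuda.synchronize()` inside the timed callable (after `generate`) gives more accurate numbers.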

CaffreyR, Nov 17 '22 13:11

Hi @CaffreyR, here is how you can use it with `model.generate()`. I will add this to the README examples. Thanks.

```python
from transformers import T5Tokenizer, T5ForConditionalGeneration
from deepspeed.profiling.flops_profiler import FlopsProfiler

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")

input_ids = tokenizer("translate English to German: The house is wonderful.", return_tensors="pt").input_ids

prof = FlopsProfiler(model)
# start profiling
prof.start_profile()

# counts accumulate over all forward passes generate() runs while profiling is active
outputs = model.generate(input_ids)

# stop profiling and collect the profiled results
prof.stop_profile()
flops = prof.get_total_flops()
macs = prof.get_total_macs()
params = prof.get_total_params()
prof.print_model_profile()
# remove the attributes and hooks the profiler added to the model
prof.end_profile()

print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```
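
The start/stop/collect/end sequence in the snippet above can be wrapped in a helper so that `stop_profile()` and `end_profile()` still run if `generate()` raises. This is a sketch: `profile_call` and `StubProfiler` are hypothetical names, the helper only calls the `FlopsProfiler` methods used in that snippet, and the stub lets the example run without DeepSpeed or a model.

```python
from typing import Any, Callable, Dict, Tuple


def profile_call(prof: Any, fn: Callable[[], Any], print_profile: bool = False) -> Tuple[Any, Dict[str, Any]]:
    """Run fn() under a FlopsProfiler-like object and return (fn's result, totals)."""
    prof.start_profile()
    try:
        try:
            result = fn()
        finally:
            prof.stop_profile()  # stop counting whether or not fn() raised
        totals = {
            "flops": prof.get_total_flops(),
            "macs": prof.get_total_macs(),
            "params": prof.get_total_params(),
        }
        if print_profile:
            prof.print_model_profile()  # must happen before end_profile()
    finally:
        prof.end_profile()  # always clean up the profiler
    return result, totals


class StubProfiler:
    """Hypothetical stand-in that records calls; not part of DeepSpeed."""

    def __init__(self):
        self.calls = []

    def start_profile(self):
        self.calls.append("start")

    def stop_profile(self):
        self.calls.append("stop")

    def get_total_flops(self):
        return 200

    def get_total_macs(self):
        return 100

    def get_total_params(self):
        return 10

    def print_model_profile(self):
        self.calls.append("print")

    def end_profile(self):
        self.calls.append("end")


stub = StubProfiler()
result, totals = profile_call(stub, lambda: "generated text", print_profile=True)
print(result, totals, stub.calls)
```

With DeepSpeed installed, the call would look like `outputs, totals = profile_call(FlopsProfiler(model), lambda: model.generate(input_ids), print_profile=True)`.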

cli99, Jan 31 '23 20:01

Hi @cli99, I wrote a PR a few weeks ago; the two could probably be merged together: https://github.com/microsoft/DeepSpeed/pull/2515

CaffreyR, Feb 01 '23 10:02

Thanks. Merging them together.

cli99, Feb 03 '23 05:02