[Summary] Regarding memory issue in tests
Description
This is a short summary of the memory issues in our tests.

The following tests definitely have memory issues:

- PyTorch (increase of ~15 MB each call):
  - `test_torch_fx`
  - `test_torch_fx_output_loss`
- TensorFlow:
  - `test_xla_fit`
  - `test_xla_generate_fast` (increase of ~100 MB each call)
  - `test_xla_generate_slow`
  - `test_xla_mode`
  - `test_onnx_runtime_optimize` (increase of ~8 MB each call)
  - `test_dataset_conversion` (increase of ~0.2 MB each call)
- Flax:
  - Almost all test methods have memory issues!
  - The CircleCI job run page demonstrates this issue too
Some other tests are also suspicious, but they need more investigation.

- For example, the test `test_graph_mode` shows the following memory differences between consecutive runs (in KB):
  [936.0, 520.0, 260.0, 520.0, 0.0, 0.0, 260.0, 520.0, 0.0, 0.0, 260.0, 0.0, 0.0, 0.0, 260.0, 260.0, 260.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
  (it doesn't always increase, but it keeps happening)
- For `test_saved_model_creation_extended` (in KB):
  [144436.0, -104552.0, 1280.0, -103908.0, -1536.0, 177868.0, -33572.0, 20240.0, 170852.0, -51704.0, -8448.0, 59904.0, -48128.0, 2440.0, 34856.0, 3068.0, -3420.0, -36864.0, -6756.0, 36136.0, -2048.0, -17400.0, -4608.0, -25896.0, 4096.0, 1024.0, 22344.0, 25784.0, -256.0]
  (sometimes some memory is released, but does it still leak in the long run?)
Pytest itself also seems to accumulate some memory as tests continue to run.
This is just my hypothesis: sometimes I see an increase of a few KB after a sequence of runs without any leak.
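The numbers above come from comparing the process RSS between consecutive test runs. Below is a minimal sketch of that kind of measurement (the helper names and the loop are illustrative, not code from our test suite):

```python
import os

import psutil


def rss_kb():
    """Resident set size of the current process, in KB."""
    return psutil.Process(os.getpid()).memory_info().rss / 1024


def measure(test_fn, n_iter=30):
    """Run `test_fn` repeatedly and record the RSS difference (in KB) of each run."""
    diffs = []
    before = rss_kb()
    for _ in range(n_iter):
        test_fn()
        after = rss_kb()
        diffs.append(after - before)
        before = after
    return diffs
```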
Possible actions to take
- It's probably worth fixing this issue for a few of the tests mentioned above to gain some experience (one way to spot which tests leak is sketched right after this list).
  - In this case, we can focus only on non-slow tests.
- [Not to go] There is a pytest plugin, `pytest-forked`, that runs each test in a forked subprocess. But it doesn't work well with TensorFlow and Flax (some tests will hang forever). I will provide some details in the comments.
- We can try to run the tests per model in each CircleCI job's steps. The output on the job run pages will be a bit noisy, but we can add an extra step to print the test failures in a cleaner way.
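As referenced in the first item above, a cheap way to gain that experience is to make the leaks visible per test. The snippet below is only a sketch of the idea (the fixture name and the 1 MB threshold are made up; nothing like this exists in our `conftest.py` today):

```python
# conftest.py (sketch)
import os

import psutil
import pytest


@pytest.fixture(autouse=True)
def report_rss_delta(request):
    """Print how much the process RSS grew while the test ran."""
    process = psutil.Process(os.getpid())
    before = process.memory_info().rss
    yield
    delta_kb = (process.memory_info().rss - before) / 1024
    if delta_kb > 1024:  # arbitrary threshold: only report growth above 1 MB
        print(f"{request.node.nodeid}: RSS +{delta_kb:.0f} KB")
```

Running with `-s` (as our CI commands already do) makes these prints visible in the job output.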
TensorFlow hangs if a TF model is forked
This will hang:

```python
import tensorflow as tf
from transformers import TFDistilBertModel, DistilBertConfig
import multiprocessing

config = DistilBertConfig()
config.n_layers = 1
config.n_heads = 2
config.dim = 4
config.hidden_dim = 4
model = TFDistilBertModel(config)


def func(i):
    print(f"func with arg {i}: start")
    inputs = tf.ones(shape=(2, 3), dtype=tf.int32)
    outputs = model(inputs)
    print(f"func with arg {i}: done")
    return outputs


print("start")
with multiprocessing.Pool(processes=1) as pool:
    r = pool.map(func=func, iterable=range(16))
    print("all done")
    print(len(r))
```
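The hang is most likely the usual fork-after-initialization problem: the forked worker inherits TensorFlow's threads and locks in an unusable state. A possible workaround, sketched below (not verified in our CI), is to use the `spawn` start method and build the model inside the worker instead of in the parent:

```python
import multiprocessing


def func(i):
    # Build everything inside the worker so the spawned child does not inherit TF state.
    import tensorflow as tf
    from transformers import TFDistilBertModel, DistilBertConfig

    config = DistilBertConfig(n_layers=1, n_heads=2, dim=4, hidden_dim=4)
    model = TFDistilBertModel(config)
    inputs = tf.ones(shape=(2, 3), dtype=tf.int32)
    outputs = model(inputs)
    return tuple(outputs.last_hidden_state.shape)


if __name__ == "__main__":
    ctx = multiprocessing.get_context("spawn")
    with ctx.Pool(processes=1) as pool:
        print(pool.map(func, range(4)))
```

The downside is that `spawn` re-imports everything in each worker process, which is much slower than fork; that cost is probably why this is not a drop-in fix for the whole test suite.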
Strange hanging with TensorFlow Probability
Running the test with `--forked`

```bash
python3 -m pytest --forked -n 2 --max-worker-restart=0 --dist=loadfile -s --make-reports=tests_tf tests/models/auto/test_modeling_tf_auto.py | tee tests_output.txt
```

with `tensorflow-probability` installed will hang. After uninstalling `tensorflow-probability`, the tests finish quickly.
(I am not sure what happens with `tensorflow-probability` here though.)

Actually, running the following also hangs:

```bash
python3 -m pytest --forked -v test_tf.py
```

with `test_tf.py` being
```python
from transformers import TFAutoModelWithLMHead
# import tensorflow_probability as tfp
from transformers.models.tapas.modeling_tf_tapas import TF_TAPAS_PRETRAINED_MODEL_ARCHIVE_LIST


def test_foo():
    model = TFAutoModelWithLMHead.from_pretrained("julien-c/dummy-unknown")
```
`--forked` hangs with Flax tests
Running the following test with `--forked` will hang:

```bash
python3 -m pytest --forked -v test_flax.py
```

with `test_flax.py` being
```python
def test_flax_foo():
    from transformers import FlaxDistilBertModel, DistilBertConfig
    import numpy as np

    config = DistilBertConfig()
    config.n_layers = 1
    config.n_heads = 2
    config.dim = 4
    config.hidden_dim = 4

    model = FlaxDistilBertModel(config)
```
cc @LysandreJik for reading :-)
To ease the debugging process, the code snippet below is a self-contained script for running `FlaxBart`. The results look like this (`mem_FlaxBartForConditionalGeneration.json`, memory usage in KB):

```json
[
157772.0,
823724.0,
850768.0,
878004.0,
905340.0,
933288.0,
959816.0,
986800.0,
1013596.0,
1041560.0,
1067088.0,
1095960.0,
1121640.0,
1149596.0,
1175144.0,
1203396.0,
1228764.0,
1256536.0,
1282528.0,
1309668.0,
1337724.0,
1362584.0,
1390300.0,
1417172.0,
1443084.0,
1471568.0,
1494896.0,
1500424.0,
1512176.0,
1519920.0,
1529484.0
]
```
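For reference, turning the cumulative numbers above into per-call deltas (a throwaway sketch reading the file produced by the script below) shows a growth of roughly 23~28 MB per `test_beam_search_generate` call for most iterations; the first call is much larger, presumably because it includes parameter initialization and compilation:

```python
import json

with open("mem_FlaxBartForConditionalGeneration.json") as fp:
    rss_kb = json.load(fp)

# Consecutive differences, converted from KB to MB.
deltas_mb = [(after - before) / 1024 for before, after in zip(rss_kb, rss_kb[1:])]
print([round(d, 1) for d in deltas_mb])
```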
Here is the code snippet to run `test_beam_search_generate`. (This removes all `unittest` elements and runs without pytest.)

```python
import copy
import json
import numpy as np
import os
import psutil
import random
import jax.numpy as jnp
from jax import jit
from transformers import BartConfig, FlaxBartModel, FlaxBartForConditionalGeneration, FlaxBartForSequenceClassification, FlaxBartForQuestionAnswering


def ids_tensor(shape, vocab_size, rng=None):
    """Creates a random int32 tensor of the shape within the vocab size."""
    if rng is None:
        rng = random.Random()

    total_dims = 1
    for dim in shape:
        total_dims *= dim

    values = []
    for _ in range(total_dims):
        values.append(rng.randint(0, vocab_size - 1))

    output = np.array(values, dtype=jnp.int32).reshape(shape)

    return output


def random_attention_mask(shape, rng=None):
    attn_mask = ids_tensor(shape, vocab_size=2, rng=rng)
    # make sure that at least one token is attended to for each batch
    attn_mask[:, -1] = 1
    return attn_mask


def shift_tokens_right(input_ids: np.array, pad_token_id: int, decoder_start_token_id: int) -> np.ndarray:
    """
    Shift input ids one token to the right.
    """
    shifted_input_ids = np.zeros_like(input_ids)
    shifted_input_ids[:, 1:] = input_ids[:, :-1]
    shifted_input_ids[:, 0] = decoder_start_token_id

    shifted_input_ids = np.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
    return shifted_input_ids


def prepare_bart_inputs_dict(
    config,
    input_ids,
    decoder_input_ids=None,
    attention_mask=None,
    decoder_attention_mask=None,
    head_mask=None,
    decoder_head_mask=None,
    cross_attn_head_mask=None,
):
    if attention_mask is None:
        attention_mask = np.where(input_ids != config.pad_token_id, 1, 0)
    if decoder_attention_mask is None:
        decoder_attention_mask = np.where(decoder_input_ids != config.pad_token_id, 1, 0)
    if head_mask is None:
        head_mask = np.ones((config.encoder_layers, config.encoder_attention_heads))
    if decoder_head_mask is None:
        decoder_head_mask = np.ones((config.decoder_layers, config.decoder_attention_heads))
    if cross_attn_head_mask is None:
        cross_attn_head_mask = np.ones((config.decoder_layers, config.decoder_attention_heads))
    return {
        "input_ids": input_ids,
        "decoder_input_ids": decoder_input_ids,
        "attention_mask": attention_mask,
        "decoder_attention_mask": attention_mask,
    }


class FlaxBartModelTester:
    def __init__(
        self,
        parent,
        batch_size=13,
        seq_length=7,
        is_training=True,
        use_labels=False,
        vocab_size=99,
        hidden_size=16,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=4,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=32,
        eos_token_id=2,
        pad_token_id=1,
        bos_token_id=0,
        initializer_range=0.02,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_labels = use_labels
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id
        self.initializer_range = initializer_range

    def prepare_config_and_inputs(self):
        input_ids = np.clip(ids_tensor([self.batch_size, self.seq_length - 1], self.vocab_size), 3, self.vocab_size)
        input_ids = np.concatenate((input_ids, 2 * np.ones((self.batch_size, 1), dtype=np.int64)), -1)

        decoder_input_ids = shift_tokens_right(input_ids, 1, 2)

        config = BartConfig(
            vocab_size=self.vocab_size,
            d_model=self.hidden_size,
            encoder_layers=self.num_hidden_layers,
            decoder_layers=self.num_hidden_layers,
            encoder_attention_heads=self.num_attention_heads,
            decoder_attention_heads=self.num_attention_heads,
            encoder_ffn_dim=self.intermediate_size,
            decoder_ffn_dim=self.intermediate_size,
            dropout=self.hidden_dropout_prob,
            attention_dropout=self.attention_probs_dropout_prob,
            max_position_embeddings=self.max_position_embeddings,
            eos_token_id=self.eos_token_id,
            bos_token_id=self.bos_token_id,
            pad_token_id=self.pad_token_id,
            initializer_range=self.initializer_range,
            use_cache=False,
        )
        inputs_dict = prepare_bart_inputs_dict(config, input_ids, decoder_input_ids)
        return config, inputs_dict

    def prepare_config_and_inputs_for_common(self):
        config, inputs_dict = self.prepare_config_and_inputs()
        return config, inputs_dict


class FlaxBartModelTest:
    is_encoder_decoder = True

    def __init__(self, model_class):
        self.model_tester = FlaxBartModelTester(self)
        self.model_class = model_class

    def _prepare_for_class(self, inputs_dict, model_class):
        inputs_dict = copy.deepcopy(inputs_dict)

        # hack for now until we have AutoModel classes
        if "ForMultipleChoice" in model_class.__name__:
            inputs_dict = {
                k: jnp.broadcast_to(v[:, None], (v.shape[0], self.model_tester.num_choices, v.shape[-1]))
                if isinstance(v, (jnp.ndarray, np.ndarray))
                else v
                for k, v in inputs_dict.items()
            }

        return inputs_dict

    def _get_input_ids_and_config(self):
        config, inputs = self.model_tester.prepare_config_and_inputs_for_common()

        # cut to half length & take max batch_size 3
        max_batch_size = 2
        sequence_length = inputs["input_ids"].shape[-1] // 2
        input_ids = inputs["input_ids"][:max_batch_size, :sequence_length]
        attention_mask = jnp.ones_like(input_ids)
        attention_mask = attention_mask[:max_batch_size, :sequence_length]

        # generate max 5 tokens
        max_length = input_ids.shape[-1] + 5
        if config.eos_token_id is not None and config.pad_token_id is None:
            # hack to allow generate for models such as GPT2 as is done in `generate()`
            config.pad_token_id = config.eos_token_id
        return config, input_ids, attention_mask, max_length

    def test_hidden_states_output(self):
        def check_hidden_states_output(inputs_dict, config, model_class):
            model = model_class(config)
            model_inputs = self._prepare_for_class(inputs_dict, model_class)
            outputs = model(**model_inputs)

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        inputs_dict["output_hidden_states"] = True
        check_hidden_states_output(inputs_dict, config, self.model_class)

        # check that output_hidden_states also work using config
        del inputs_dict["output_hidden_states"]
        config.output_hidden_states = True
        check_hidden_states_output(inputs_dict, config, self.model_class)

    def test_beam_search_generate(self):
        config, input_ids, _, max_length = self._get_input_ids_and_config()
        config.do_sample = False
        config.max_length = max_length
        config.num_beams = 2

        model = self.model_class(config)

        generation_outputs = model.generate(input_ids).sequences

        jit_generate = jit(model.generate)
        jit_generation_outputs = jit_generate(input_ids).sequences


if __name__ == "__main__":

    all_model_classes = (
        (
            # FlaxBartModel,
            FlaxBartForConditionalGeneration,
            # FlaxBartForSequenceClassification,
            # FlaxBartForQuestionAnswering,
        )
    )

    for model_class in all_model_classes:
        test = FlaxBartModelTest(model_class)

        all_rss = []

        p = psutil.Process(os.getpid())
        m = p.memory_full_info()
        rss = m.rss / 1024
        all_rss.append(rss)

        for i in range(30):
            # This is fine
            # test.test_hidden_states_output()

            # Mem. leak
            test.test_beam_search_generate()

            m = p.memory_full_info()
            rss = m.rss / 1024
            all_rss.append(rss)

        fn = f"mem_{model_class.__name__}.json"
        with open(fn, "w") as fp:
            json.dump(all_rss, fp, ensure_ascii=False, indent=4)
```
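A quick follow-up check on top of the script above (a sketch, I have not wired this into CI) is to compare the RSS growth with what `tracemalloc` sees. `tracemalloc` only tracks allocations made through the Python allocator, so if it reports far less growth than RSS, the leaked memory most likely lives in native (XLA) buffers rather than in Python objects:

```python
import tracemalloc

tracemalloc.start()

test = FlaxBartModelTest(FlaxBartForConditionalGeneration)
test.test_beam_search_generate()  # warm-up call so compilation is not counted

snapshot_before = tracemalloc.take_snapshot()
for _ in range(5):
    test.test_beam_search_generate()
snapshot_after = tracemalloc.take_snapshot()

# Top Python-level allocation growth between the two snapshots.
for stat in snapshot_after.compare_to(snapshot_before, "lineno")[:10]:
    print(stat)
```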
Thanks for summarizing all the info, @ydshieh!
To debug `test_torch_fx` more easily, with `n_iter = 500`:

- with new process: + 60 MB
- without new process: + 1700 MB
- without `scripted(**filtered_inputs)`: + 400 MB
- without `scripted(**filtered_inputs)` and `torch.jit.script(traced_model)`: + 30 MB

```python
import copy
import torch
import tempfile
import os
import json
import pickle
import psutil
import multiprocessing
from transformers.utils.fx import symbolic_trace
from transformers import BartConfig, BartModel
torch_device = "cpu"
model_class = BartModel
config_dict = {
"activation_dropout": 0.0,
"activation_function": "gelu",
"attention_dropout": 0.1,
"bos_token_id": 0,
"classifier_dropout": 0.0,
"d_model": 16,
"decoder_attention_heads": 4,
"decoder_ffn_dim": 4,
"decoder_layerdrop": 0.0,
"decoder_layers": 2,
"decoder_start_token_id": 2,
"dropout": 0.1,
"encoder_attention_heads": 4,
"encoder_ffn_dim": 4,
"encoder_layerdrop": 0.0,
"encoder_layers": 2,
"eos_token_id": 2,
"forced_eos_token_id": None,
"id2label": {
"0": "LABEL_0",
"1": "LABEL_1",
"2": "LABEL_2"
},
"init_std": 0.02,
"is_encoder_decoder": True,
"label2id": {
"LABEL_0": 0,
"LABEL_1": 1,
"LABEL_2": 2
},
"max_position_embeddings": 20,
"model_type": "bart",
"num_hidden_layers": 2,
"pad_token_id": 1,
"scale_embedding": False,
"transformers_version": "4.22.0.dev0",
"use_cache": True,
"vocab_size": 99
}
config = BartConfig(**config_dict)
inputs = {
'input_ids': torch.tensor([
[22, 30, 84, 13, 46, 95, 2],
[74, 91, 58, 38, 3, 48, 2],
[43, 32, 21, 60, 12, 42, 2],
[20, 24, 75, 46, 62, 55, 2],
[59, 91, 36, 57, 40, 36, 2],
[23, 24, 33, 70, 13, 93, 2],
[15, 4, 11, 45, 5, 87, 2],
[78, 76, 67, 38, 3, 46, 2],
[ 3, 31, 35, 85, 81, 46, 2],
[47, 45, 97, 80, 75, 91, 2],
[92, 49, 42, 65, 74, 98, 2],
[67, 37, 84, 88, 55, 57, 2],
[24, 53, 44, 36, 45, 24, 2],
], dtype=torch.int32),
'decoder_input_ids': torch.tensor([
[50, 56, 84, 91, 16, 49, 54],
[ 2, 71, 62, 39, 27, 4, 93],
[73, 45, 61, 63, 35, 25, 7],
[27, 33, 23, 86, 13, 49, 32],
[74, 36, 46, 83, 18, 40, 22],
[45, 69, 41, 3, 29, 56, 49],
[ 3, 38, 8, 52, 17, 55, 15],
[63, 79, 42, 64, 62, 39, 40],
[28, 59, 69, 14, 77, 45, 36],
[56, 55, 82, 35, 66, 51, 19],
[18, 96, 43, 34, 16, 69, 94],
[68, 65, 52, 17, 77, 78, 54],
[68, 57, 74, 42, 60, 13, 91]
]),
'attention_mask': torch.tensor([
[True, True, True, True, True, True, True],
[True, True, True, True, True, True, True],
[True, True, True, True, True, True, True],
[True, True, True, True, True, True, True],
[True, True, True, True, True, True, True],
[True, True, True, True, True, True, True],
[True, True, True, True, True, True, True],
[True, True, True, True, True, True, True],
[True, True, True, True, True, True, True],
[True, True, True, True, True, True, True],
[True, True, True, True, True, True, True],
[True, True, True, True, True, True, True],
[True, True, True, True, True, True, True]
], dtype=torch.bool),
'decoder_attention_mask': torch.tensor([
[True, True, True, True, True, True, True],
[True, True, True, True, True, True, True],
[True, True, True, True, True, True, True],
[True, True, True, True, True, True, True],
[True, True, True, True, True, True, True],
[True, True, True, True, True, True, True],
[True, True, True, True, True, True, True],
[True, True, True, True, True, True, True],
[True, True, True, True, True, True, True],
[True, True, True, True, True, True, True],
[True, True, True, True, True, True, True],
[True, True, True, True, True, True, True],
[True, True, True, True, True, True, True]
], dtype=torch.bool),
'head_mask': torch.tensor([[1., 1., 1., 1.], [1., 1., 1., 1.]]),
'decoder_head_mask': torch.tensor([[1., 1., 1., 1.], [1., 1., 1., 1.]]),
'cross_attn_head_mask': torch.tensor([[1., 1., 1., 1.], [1., 1., 1., 1.]])
}


def _config_zero_init(config):
    configs_no_init = copy.deepcopy(config)
    for key in configs_no_init.__dict__.keys():
        if "_range" in key or "_std" in key or "initializer_factor" in key or "layer_scale" in key:
            setattr(configs_no_init, key, 1e-10)
    return configs_no_init


def _run_torch_jit(in_queue, out_queue):
    model, input_names, filtered_inputs = in_queue.get()

    traced_model = symbolic_trace(model, input_names)
    # blocked if forked
    with torch.no_grad():
        traced_output = traced_model(**filtered_inputs)

    # Test that the model can be TorchScripted
    scripted = torch.jit.script(traced_model)
    with torch.no_grad():
        scripted_output = scripted(**filtered_inputs)

    out_queue.put((traced_model, scripted_output))
    out_queue.join()


def create_and_check_torch_fx_tracing(model_class, config, inputs, n_iter=100, with_new_proc=False):
    configs_no_init = _config_zero_init(config)  # To be sure we have no Nan
    configs_no_init.return_dict = False
    model = model_class(config=configs_no_init)
    model.to(torch_device)
    model.eval()

    model.config.use_cache = False

    input_names = [
        "attention_mask",
        "decoder_attention_mask",
        "decoder_input_ids",
        "input_features",
        "input_ids",
        "input_values",
    ]

    filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
    input_names = list(filtered_inputs.keys())

    model_output = model(**filtered_inputs)

    all_rss = []

    p = psutil.Process(os.getpid())
    m = p.memory_full_info()
    rss = m.rss / 1024
    all_rss.append(rss)

    for i in range(n_iter):
        print(f"idx: {i} - start")

        if not with_new_proc:
            traced_model = symbolic_trace(model, input_names)
            with torch.no_grad():
                traced_output = traced_model(**filtered_inputs)

            # Test that the model can be TorchScripted
            scripted = torch.jit.script(traced_model)
            with torch.no_grad():
                scripted_output = scripted(**filtered_inputs)
        else:
            ctx = multiprocessing.get_context('spawn')
            in_queue = ctx.Queue()
            out_queue = ctx.JoinableQueue()

            in_queue.put((model, input_names, filtered_inputs))

            process = ctx.Process(target=_run_torch_jit, args=(in_queue, out_queue))
            process.start()
            traced_model, scripted_output = out_queue.get()
            out_queue.task_done()
            process.join()

        print(f"idx: {i} - end")
        print("=" * 40)

        m = p.memory_full_info()
        rss = m.rss / 1024
        all_rss.append(rss)

    fn = f"torch_jit_script_mem_with_new_proc={with_new_proc}.json"
    with open(fn, "w") as fp:
        json.dump(all_rss, fp, ensure_ascii=False, indent=4)


if __name__ == "__main__":
    create_and_check_torch_fx_tracing(model_class, config, inputs, n_iter=500, with_new_proc=True)
    create_and_check_torch_fx_tracing(model_class, config, inputs, n_iter=500, with_new_proc=False)
```
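For what it's worth, the `with_new_proc=True` path above can be factored into a small generic helper so other leaking tests could reuse the same pattern. This is only a sketch of the idea (the helper name is made up; with the `spawn` start method, `target`, its arguments and its return value all have to be picklable, and `target` must be a module-level callable):

```python
import multiprocessing


def _child(queue, target, args):
    # Runs in the spawned child and ships the result back to the parent.
    queue.put(target(*args))


def run_in_spawned_process(target, *args, timeout=600):
    """Run `target(*args)` in a spawned child so any memory it leaks dies with it."""
    ctx = multiprocessing.get_context("spawn")
    queue = ctx.Queue()
    process = ctx.Process(target=_child, args=(queue, target, args))
    process.start()
    result = queue.get(timeout=timeout)
    process.join()
    return result
```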
@patil-suraj @sanchit-gandhi @patrickvonplaten
We have a memory leak issue in some Flax tests. Basically, I observed this happening for `test_beam_search_generate`, `test_beam_search_generate_attn_mask` and `test_beam_search_generate_logits_warper`, but there might be more.
Each call to them increases memory usage by 10~30 MB.
The CircleCI job run page also shows the memory issue in Flax testing (https://app.circleci.com/pipelines/github/huggingface/transformers/45317/workflows/5bcb8b8a-776c-4c58-ad99-cf2700304c05/jobs/528556/resources).
To reproduce, see here for `test_beam_search_generate`.
Not very urgent, but we will have trouble once more models are added. Could you have a look, please? Let me know if you need more information.
Hey @ydshieh,
I'm a bit under water at the moment - I'll put the issue on my TODO-list, but I can't promise to find time to look into it very soon. This link: https://app.circleci.com/pipelines/github/huggingface/transformers/45317/workflows/5bcb8b8a-776c-4c58-ad99-cf2700304c05/jobs/528556/resources doesn't seem to show anything useful to me.
Also, just to understand better: are the Flax tests running on GPU or CPU?