gpt-2-simple
Memory leak / performance slowdown due to graph rebuilding on generate() calls
I have a project where I am using this package as a base for a chat bot. To speed it up, I've dissected the code in this package and have found the source for the memory leak mentioned in #71, as well as the main factor causing latency in general for the model on subsequent runs. It's mostly this part right here:
output = sample.sample_sequence(
    hparams=hparams,
    length=min(length, 1023 - (len(context_tokens) if prefix else 0)),
    start_token=enc.encoder['<|endoftext|>'] if not prefix else None,
    context=context if prefix else None,
    batch_size=batch_size,
    temperature=temperature, top_k=top_k, top_p=top_p
)[:, 1:]
Every time this function gets called, it builds a new computation graph on top of the already existing one. This can be resolved by doing two things:
- Make tf.placeholder nodes for all of the inputs to the function.
- Somehow memoize the output on subsequent calls
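The second point can be sketched without TensorFlow at all: build the expensive object once, key it by the parameters that change its structure, and reuse it on later calls. The names GraphCache and build_graph below are hypothetical, purely for illustration; in the real code the expensive step is sample.sample_sequence.

```python
# Hypothetical sketch of the "build once, reuse" pattern; not the package's
# actual API. build_graph stands in for the graph construction that
# sample.sample_sequence performs on every generate() call.
class GraphCache:
    def __init__(self, build_graph):
        self._build_graph = build_graph  # expensive constructor
        self._cache = {}
        self.build_count = 0             # how many times we actually built

    def get(self, batch_size, length):
        # Only structural parameters belong in the key; run-time values
        # like temperature would be fed through placeholders instead.
        key = (batch_size, length)
        if key not in self._cache:
            self.build_count += 1
            self._cache[key] = self._build_graph(batch_size, length)
        return self._cache[key]


def build_graph(batch_size, length):
    # Stand-in for graph construction; returns a callable "output tensor".
    return lambda: ("output", batch_size, length)


cache = GraphCache(build_graph)
out1 = cache.get(1, 1023)
out2 = cache.get(1, 1023)   # same key: reused, not rebuilt
out3 = cache.get(2, 1023)   # different shape -> new graph
```

With this shape of cache, repeated generate() calls with the same batch size and length never touch graph construction again.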
The first one I've been able to do experimentally with partial success, though not without rewriting small portions of sample.sample_sequence to accommodate the changes. The second is an organizational change. One could do it easily by wrapping the package in an object and holding the graph there; another way may be smart caching of the graph itself. The computation graph is independent of the Python runtime, so you should hypothetically be able to retrieve the output by name. I know that with placeholders you can key the feed dictionary with strings representing their names (i.e., {'context:0': foo} instead of {context: foo}, where context is the placeholder object). It may be possible to do the same for the fetches of a session.run(). This would allow for a setup, a check, and holding on to the same computation graph, dramatically speeding up the package and removing the memory leak.
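Under TF 1.x (the version this issue targets), the setup-and-check pattern might look roughly like the sketch below. The tensor names 'context:0' and 'output:0' are illustrative choices, not names the package actually uses, and the snippet assumes the variables from the code above (sample, hparams, length, batch_size, temperature, top_k, top_p) are in scope.

```python
import tensorflow as tf  # TF 1.x

def get_output_tensor(sess):
    # Build the sampling ops exactly once per graph; on later calls, look
    # the output up by name instead of calling sample.sample_sequence again.
    try:
        return sess.graph.get_tensor_by_name('output:0')
    except KeyError:
        # First call: construct the ops, giving the I/O tensors stable names.
        context = tf.placeholder(tf.int32, [None, None], name='context')
        output = sample.sample_sequence(
            hparams=hparams, length=length,
            context=context, batch_size=batch_size,
            temperature=temperature, top_k=top_k, top_p=top_p,
        )[:, 1:]
        return tf.identity(output, name='output')

# Subsequent generations feed the placeholder by name and reuse the same ops:
#   out = sess.run('output:0', feed_dict={'context:0': context_tokens})
```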
Of course, this might already be taken care of by the migration to TensorFlow 2.0, since I know it handles graphs more Pythonically.
The workaround in the meantime is to reset the graph after a few generations; not great but it works.
How far off is the TF 2.0 upgrade?
I wouldn't plan on it.
Haha, ouch. On resetting the graph: if I run
import tensorflow as tf
tf.reset_default_graph()
I get the error
ValueError: Tensor Tensor("strided_slice_1:0", shape=(1, ?), dtype=int32) is not an element of this graph.
Is there a better way to reset a (fine-tuned) graph? I'm running non-GPU TensorFlow just to generate from a pretrained model (I'm trying to set up an autocomplete service).
The workaround in the meantime is to reset the graph after a few generations; not great but it works.
tf.reset_default_graph()
doesn't work for me:
AssertionError: ("Do not use tf.reset_default_graph() to clear nested graphs")
@minimaxir How would you go about resetting the graph?
For some reason you have to reset the session, then the graph.
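In TF 1.x the order matters because an open session still holds references into the old graph. A minimal sketch of that order; the helper name reset below is hypothetical, and after resetting you would still need to reload the model checkpoint into the new session before generating again:

```python
import tensorflow as tf  # TF 1.x

def reset(sess):
    # Close the session first so nothing still references tensors from
    # the old graph, then clear the default graph, then start fresh.
    sess.close()
    tf.reset_default_graph()
    return tf.Session()  # caller must reload the model into this session
```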
@minimaxir Hello, I ran into the same mem leak issue. How would you go about resetting the session and the graph?
Use this: https://github.com/minimaxir/gpt-2-simple/blob/92d35962d9aaeadba70e39d11d040f1e377ffdb3/gpt_2_simple/gpt_2.py#L113