
Memory leak / performance slowdown due to graph rebuilding on generate() calls

Open huntrontrakkr opened this issue 5 years ago • 8 comments

I have a project where I am using this package as the base for a chat bot. To speed it up, I've dissected the code in this package and found the source of the memory leak mentioned in #71, which is also the main cause of latency on subsequent runs of the model. It's mostly this part right here:

output = sample.sample_sequence(
        hparams=hparams,
        length=min(length, 1023 - (len(context_tokens) if prefix else 0)),
        start_token=enc.encoder['<|endoftext|>'] if not prefix else None,
        context=context if prefix else None,
        batch_size=batch_size,
        temperature=temperature, top_k=top_k, top_p=top_p
    )[:, 1:]

Every time this function gets called, it rebuilds a new computation graph on top of the already existing one. This can be resolved by doing two things:

  1. Make `tf.placeholder`s for all of the inputs to the function.
  2. Memoize the output tensor so subsequent calls reuse it.

The first one I've been able to do experimentally with partial success, though not without rewriting small portions of `sample.sample_sequence` to accommodate the changes. The second part is an organizational change: one easy route is to make the package an object and manage state that way. Another may be smart caching of the graph itself.

The computation graph is independent of the Python runtime, so you should hypothetically be able to fetch the output by name. With placeholders, you can already key the feed dictionary by strings representing their names (i.e., `{'context:0': foo}` instead of `{context: foo}`, where `context` is the placeholder object). It may be possible to do the same for the fetch argument of `session.run()`. That would allow a one-time setup plus a cheap check on each call, keeping a single computation graph alive, dramatically speeding up the package and removing the memory leak.

Of course, this might already be taken care of by the migration to TensorFlow 2.0, since I know it handles graphs more Pythonically.

huntrontrakkr avatar Oct 06 '19 17:10 huntrontrakkr

The workaround in the meantime is to reset the graph after a few generations; not great but it works.

minimaxir avatar Oct 13 '19 20:10 minimaxir

How far off is the TF 2.0 upgrade?

inspire22 avatar Mar 06 '20 03:03 inspire22

I wouldn't plan on it.

minimaxir avatar Mar 06 '20 03:03 minimaxir

Haha, ouch. About resetting the graph: if I run `import tensorflow as tf; tf.reset_default_graph()`, I get the error `ValueError: Tensor Tensor("strided_slice_1:0", shape=(1, ?), dtype=int32) is not an element of this graph.` Is there a better way to reset a (fine-tuned) graph? I'm running non-GPU TensorFlow just to generate from a pretrained model (I'm trying to set up an autocomplete service).

inspire22 avatar Mar 06 '20 03:03 inspire22

> The workaround in the meantime is to reset the graph after a few generations; not great but it works.

`tf.reset_default_graph()` doesn't work for me: `AssertionError: Do not use tf.reset_default_graph() to clear nested graphs`. @minimaxir How would you go about resetting the graph?

bjoernhommel avatar Apr 10 '20 19:04 bjoernhommel

For some reason you have to reset the session, then the graph.

minimaxir avatar Apr 11 '20 20:04 minimaxir
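A minimal sketch of that order, assuming the TF 1.x-style API (reachable via `tf.compat.v1` on newer installs). The helper name `reset_session` is illustrative, not the package's exact implementation:

```python
import tensorflow as tf

tf1 = tf.compat.v1          # TF 1.x-style graph/session API
tf1.disable_eager_execution()


def reset_session(sess):
    # Order matters, per the comment above: close the session first,
    # then clear the default graph, then start a fresh session.
    sess.close()
    tf1.reset_default_graph()
    return tf1.Session()
```

Closing the session before resetting the graph avoids leaving a live session pointed at a graph that no longer exists, which is one way to end up with "is not an element of this graph" errors.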

@minimaxir Hello, I ran into the same mem leak issue. How would you go about resetting the session and the graph?

wsong-fv avatar May 29 '20 14:05 wsong-fv

> @minimaxir Hello, I ran into the same mem leak issue. How would you go about resetting the session and the graph?

Use this: https://github.com/minimaxir/gpt-2-simple/blob/92d35962d9aaeadba70e39d11d040f1e377ffdb3/gpt_2_simple/gpt_2.py#L113

zacc avatar Jun 20 '20 06:06 zacc