Improve prediction time (predict_max)
The current prediction time is quite slow, and we agree that there is probably room for improvement.
After having a good look at it, it seemed clear that we keep computing on items that no longer need it (strings that have already reached EOS).
I propose here my refactor of the predict_max function, which stops computing over elements that have reached EOS. There is probably still room for improvement here.
```
For a group of 19 sentences over 100 iterations
Average tagging time with `default`: 0.556127781867981 s
Median tagging time with `default`: 0.5420029163360596
Total tagging time with `default`: 55.612778186798096 s

For a group of 19 sentences over 100 iterations
Average tagging time with `new`: 0.4061899709701538 s
Median tagging time with `new`: 0.40130531787872314
Total tagging time with `new`: 40.61899709701538 s
```
That is roughly a 27 % reduction of the whole tagging time (lemma only).
Improving predict_max
The idea is to reduce computation at prediction time: the current system keeps predicting over every token, even those that have already reached EOS. So until the longest prediction is finished, the whole batch is computed on.
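To illustrate the core trick outside of the model (a minimal sketch with toy values, not pie's actual tensors): at each step we build a mask of the rows that have not emitted EOS yet and narrow every step-wise tensor to those rows.

```python
import torch

# Toy illustration of the pruning idea: a fake batch of 4 decoder rows,
# where rows 1 and 3 just predicted EOS (here EOS = 0).
EOS = 0
inp = torch.tensor([5, 0, 7, 0])

non_eos = inp != EOS                             # rows still running
keep = torch.nonzero(non_eos, as_tuple=True)[0]  # their indexes: tensor([0, 2])

# Narrow the step-wise tensors to the surviving rows, so the next decoder
# step only runs on 2 rows instead of 4.
inp = inp[keep]
# hidden = hidden[:, keep]      # (num_layers, batch, hidden) in a real decoder
# enc_outs = enc_outs[:, keep]  # (seq_len, batch, hidden)
```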
Setup
```python
import copy
import os
import time

import torch
import torch.quantization
import torch.nn as nn
import torch.nn.functional as F

from pie.models.decoder import AttentionalDecoder
from pie.models import BaseModel
from pie.tagger import Tagger
from pie.data import Dataset, Reader
from pie.settings import load_default_settings, settings_from_file

TINY = 1e-8
DEVICE = "cpu"

# Keep a reference to the original implementation so it can be restored
old = copy.deepcopy(AttentionalDecoder.predict_max)


def load_and_monkey_patch(patch=None):
    if not patch:
        patch = old
    AttentionalDecoder.predict_max = patch
    tagger = Tagger()
    tagger.add_model("models/Final-Latin-Lemma-H384-C700-lemma-2020_08_06-18_44_24.tar", "lemma")
    NormalModel = tagger.models[0][0]
    settings = NormalModel._settings
    settings.device = DEVICE
    settings.shuffle = False  # avoid shuffling
    return tagger, NormalModel
```
sentences = """Lorem ipsum dolor sit amet, consectetur adipiscing elit.
Phasellus dolor sapien, laoreet non turpis eget, tincidunt commodo magna. Duis at dapibus ipsum.
Etiam fringilla et magna sed vehicula.
Nunc tristique eros non faucibus viverra.
Sed dictum scelerisque tortor, eu ullamcorper odio.
Aenean fermentum a urna quis tempus.
Maecenas imperdiet est a nisi pellentesque dictum.
Maecenas ac hendrerit ante. Vestibulum eleifend nulla at vulputate sagittis.
Maecenas sed magna diam sed facilisis tempus ipsum, nec mattis elit tincidunt lobortis Phasellus vel ex lorem nulla nunc odio, tempor non consequat in, luctus elementum dolor.
Nullam tincidunt purus vel lorem placerat, ac pulvinar turpis sodales.
Sed eget urna ac quam cursus porta.
Pellentesque luctus aliquet sem, a egestas purus finibus ac.
Mauris nec mauris non metus tempor faucibus non in est.
Class aptent taciti sociosqu ad litora torquent per conubia nostra, per inceptos himenaeos.
Proin tristique nulla nec purus iaculis, eu scelerisque mi egestas.
In hac habitasse platea dictumst.
Ut placerat a neque eget aliquet. """.lower().replace("\n", "").replace(",", "").split(".")
Test = [
sent.split()
for sent in sentences
if sent.split()
]
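For reference, Test is just a list of token lists (which is what tagger.tag receives below); a quick sanity check:

```python
print(len(Test))  # 19 sentences
print(Test[0])    # ['lorem', 'ipsum', 'dolor', 'sit', 'amet', 'consectetur', 'adipiscing', 'elit']
```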
Testing function
```python
import time
import statistics


def test(patch, n_iters=10, sentences=Test, name="default"):
    tagger, _ = load_and_monkey_patch(patch)
    lengths = [len(x) for x in sentences]
    times = []
    for i in range(n_iters):
        start = time.time()
        out = tagger.tag(sentences, lengths)
        times.append(time.time() - start)
    print("====")
    print(f"For a group of {len(sentences)} sentences over {n_iters} iterations")
    print(f"Average tagging time with `{name}`: {sum(times) / n_iters} s")
    print(f"Median tagging time with `{name}`: {statistics.median(times)}")
    print(f"Total tagging time with `{name}`: {sum(times)} s")
    return out
```
New function
```python
def predict_max_debug(self, enc_outs, lengths,
                      max_seq_len=20, bos=None, eos=None,
                      context=None):
    """
    Decoding routine for inference with step-wise argmax procedure

    Parameters
    ===========
    enc_outs : tensor(src_seq_len x batch x hidden_size)
    context : tensor(batch x hidden_size), optional
    """
    eos = eos or self.label_encoder.get_eos()
    bos = bos or self.label_encoder.get_bos()
    hidden, batch, device = None, enc_outs.size(1), enc_outs.device
    inp = torch.zeros(batch, dtype=torch.int64, device=device) + bos
    hyps, scores = [], [0 for _ in range(batch)]

    # Conversion table: Tensor (row) Index -> Hypothesis Index
    indexes = {x: x for x in range(batch)}

    for _ in range(max_seq_len):
        # prepare input
        # Context is NEVER changed after the method has been called
        emb = self.embs(inp)
        if context is not None:
            emb = torch.cat([emb, context], dim=1)
        # run rnn
        # Move embeddings from a 2-D tensor to a 3-D tensor (1, word number, emb size(+context))
        emb = emb.unsqueeze(0)
        # Hidden is always reused
        # -> Hidden is (1, word number, emb size)
        outs, hidden = self.rnn(emb, hidden)
        outs, _ = self.attn(outs, enc_outs, lengths)
        outs = self.proj(outs).squeeze(0)
        # get logits
        probs = F.log_softmax(outs, dim=1)
        # sample and accumulate
        score, inp = probs.max(1)

        # Mask of the rows that have not reached EOS yet
        non_eos = (inp != eos)
        # keep holds the indexes of the rows we keep computing on (ie, not EOS)
        keep = torch.nonzero(non_eos, as_tuple=True)[0]

        # add new chars to hypotheses
        # We prepare a list the size of the full batch (padded with EOS)
        # and fill it using the conversion table
        to_append = [eos for _ in range(batch)]
        for ind, (hyp, sc) in enumerate(zip(inp.tolist(), score.tolist())):
            to_append[indexes[ind]] = hyp
            if hyp != eos:
                scores[indexes[ind]] += sc
        hyps.append(to_append)

        # If every row has reached EOS, prediction is over
        if not non_eos.any():
            break

        # Update the conversion table so that each tensor row maps to the
        # correct hypothesis
        indexes = {elem: indexes[former_index]
                   for elem, former_index in enumerate(keep.tolist())}

        # Drop the finished rows from every step-wise tensor
        inp = inp[keep]
        if context is not None:
            context = context[keep]
        lengths = lengths[keep]
        # Hidden is 3-D with 1 in the first dimension
        hidden = hidden.squeeze(0)[keep].unsqueeze(0)
        # enc_outs is (seq x batch x size), so we transpose, index and transpose back.
        # seq is supposed to be equal to max(lengths); if the row with the maximum
        # length is dropped, we need to shrink the first dimension of enc_outs as well
        max_seq_len = lengths.max()
        enc_outs = enc_outs[:max_seq_len].transpose(0, 1)[keep].transpose(0, 1)

    hyps = [self.label_encoder.stringify(hyp) for hyp in zip(*hyps)]
    scores = [s / (len(hyp) + TINY) for s, hyp in zip(scores, hyps)]

    return hyps, scores
```
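The trickiest part is probably the indexes bookkeeping, so here is a standalone toy walk-through of the remapping logic (plain Python, no model involved):

```python
# Tensor (row) Index -> Hypothesis Index, for a toy batch of 4 hypotheses.
indexes = {0: 0, 1: 1, 2: 2, 3: 3}

# Step 1: rows 1 and 3 emit EOS, so we keep tensor rows [0, 2].
keep = [0, 2]
indexes = {new: indexes[old] for new, old in enumerate(keep)}
print(indexes)  # {0: 0, 1: 2} -> row 1 of the shrunken batch now feeds hypothesis 2

# Step 2: in the shrunken batch, row 0 emits EOS, so we keep tensor row [1].
keep = [1]
indexes = {new: indexes[old] for new, old in enumerate(keep)}
print(indexes)  # {0: 2} -> the single remaining row still writes into hypothesis 2
```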
```python
new_out = test(patch=predict_max_debug, sentences=Test, name="new", n_iters=100)
```

```
====
For a group of 19 sentences over 100 iterations
Average tagging time with `new`: 0.4061899709701538 s
Median tagging time with `new`: 0.40130531787872314
Total tagging time with `new`: 40.61899709701538 s
```
Current function
```python
def predict_max(self, enc_outs, lengths,
                max_seq_len=20, bos=None, eos=None,
                context=None):
    """
    Decoding routine for inference with step-wise argmax procedure

    Parameters
    ===========
    enc_outs : tensor(src_seq_len x batch x hidden_size)
    context : tensor(batch x hidden_size), optional
    """
    eos = eos or self.label_encoder.get_eos()
    bos = bos or self.label_encoder.get_bos()
    hidden, batch, device = None, enc_outs.size(1), enc_outs.device
    mask = torch.ones(batch, dtype=torch.int64, device=device)
    inp = torch.zeros(batch, dtype=torch.int64, device=device) + bos
    hyps, scores = [], 0

    for _ in range(max_seq_len):
        if mask.sum().item() == 0:
            break

        # prepare input
        emb = self.embs(inp)
        if context is not None:
            emb = torch.cat([emb, context], dim=1)
        # run rnn
        emb = emb.unsqueeze(0)
        outs, hidden = self.rnn(emb, hidden)
        outs, _ = self.attn(outs, enc_outs, lengths)
        outs = self.proj(outs).squeeze(0)
        # get logits
        probs = F.log_softmax(outs, dim=1)
        # sample and accumulate
        score, inp = probs.max(1)
        hyps.append(inp.tolist())
        mask = mask * (inp != eos).long()
        score = score.cpu()
        score[mask == 0] = 0
        scores += score

    hyps = [self.label_encoder.stringify(hyp) for hyp in zip(*hyps)]
    scores = [s / (len(hyp) + TINY) for s, hyp in zip(scores.tolist(), hyps)]

    return hyps, scores
```
```python
former_out = test(patch=predict_max, sentences=Test, name="default", n_iters=100)
```

```
====
For a group of 19 sentences over 100 iterations
Average tagging time with `default`: 0.556127781867981 s
Median tagging time with `default`: 0.5420029163360596
Total tagging time with `default`: 55.612778186798096 s
```
```python
new_out == former_out
```

```
True
```
This time I'll wait for your feedback @emanjavacas
Note that the improvements I measured are highly dependent on the seq_len and on the disparity of prediction sizes across the batch.
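As a rough back-of-the-envelope model of that dependence (illustrative only, ignoring per-step overhead): the current version runs max(lengths) decoder steps over the full batch, while the pruned version processes each row only until its own EOS.

```python
# Rough cost model: decoder "row-steps" with and without pruning (illustrative only).
def row_steps(pred_lengths):
    full = max(pred_lengths) * len(pred_lengths)  # current: every row, every step
    pruned = sum(pred_lengths)                    # pruned: each row until its own EOS
    return full, pruned

# Homogeneous prediction lengths: almost nothing to gain.
print(row_steps([10, 10, 10, 9]))  # (40, 39)

# One long outlier in the batch: most of the work can be skipped.
print(row_steps([20, 5, 5, 5]))    # (80, 35)
```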
Are you testing this on a gpu?
I tested on CPU for now, which is generally what users have for inference or web services.
For a quick and dirty evaluation, I ran the same model on the README.md: it went from 13 seconds to 7 seconds...
I actually opened a PR so that you can pull and test :)
I am quite busy right now, so it will take some time until I can look at this. My impression is that the GPU will actually be slower, since parallelizing there is very cheap. I am not opposed to the idea of optimizing at this level if people can profit from it (even if the gains aren't really that significant), but the code would have to improve in readability, because this is a bug-sensitive part of the codebase that I eventually come back to and modify.
Actually, on real-world data, the improvements are substantial:
For this test, CPU = i7 6700k, GPU = 1080
| Text | Unit | New (CPU) | Actual (CPU) | Diff (CPU) | New (GPU) | Actual (GPU) | Diff (GPU) |
|---|---|---|---|---|---|---|---|
| Aeneid (9843 units) | sec | 58 | 81 | 71.60% | 19 | 23 | 82.61% |
| Aeneid (9843 units) | it/sec | 169.64 | 120.45 | 140.84% | 495.12 | 417.56 | 118.57% |
| Priapea (127 units) | sec | 3 | 5 | 60.00% | 1 | 2 | 50.00% |
| Priapea (127 units) | it/sec | 39.98 | 24.4 | 163.85% | 64.25 | 61.3 | 104.81% |
| Martial (1600 units) | sec | 44 | 76 | 57.89% | 8 | 9 | 88.89% |
| Martial (1600 units) | it/sec | 36.06 | 20.88 | 172.70% | 187.6 | 163.6 | 114.67% |
For this test, an i5 8265u laptop (base freq. 1.6 GHz, 4 cores, 8 threads):
| Text | Unit | New | Actual | Diff |
|---|---|---|---|---|
| Aeneid | sec | 176 | 245 | 71.84% |
| Aeneid | it/sec | 55.74 | 40.11 | 138.97% |
| Priapea | sec | 8 | 15 | 53.33% |
| Priapea | it/sec | 14.26 | 8.1 | 176.05% |
| Martial | sec | 125 | 224 | 55.80% |
| Martial | it/sec | 12.77 | 7.12 | 179.35% |
For it/sec rows, a Diff above 100 % means the new version is faster; for sec rows, a Diff below 100 % means the new version is faster.
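To make the Diff column concrete with the first table: for the Aeneid on CPU, 58 s against 81 s is 58 / 81 ≈ 71.6 % of the original wall time, i.e. roughly a 1.4× speed-up, and the it/sec rows express the same thing the other way around (169.64 / 120.45 ≈ 140.8 %).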
It's actually even better on these files: test.zip
The new version consistently beats the current one on CPU by a huge margin, and on GPU by a relatively small margin (from 4 to 18 %).
> I am quite busy right now,

I did not address this earlier, but please take the time you need. It's August, you are probably on vacation, and if you are not, I technically am. So it can wait ;)
I finished the PR by adding a lot of comments and making sure the variable names are clear. I also added / edited my earlier comments with some more information.
I know you are busy, but if you get to merging PRs, #74 might be a good place to stop if you have time :) Those are some nice improvements :)