Calling the model in parallel
Hello,
I'm currently using RelBERT for my experiment, and calls to the model are quite slow.
pairlist = [['token1', 'token2'], ...]
for pair in pairlist:
    emb = model.get_embedding(pair)  # this call is slow
I tried using multiprocessing for it, but an error occurred.
import multiprocessing as mp
from functools import partial

def getemb(model, pair):
    return model.get_embedding(pair)

pairlist = [['token1', 'token2'], ...]
model = RelBERT('asahi417/relbert-roberta-large')
with mp.Pool(mp.cpu_count()) as pool:
    pool.map(partial(getemb, model=model), pairlist)  # an error occurred here
The error was:
AssertionError: daemonic processes are not allowed to have children
It would be very helpful if you could point out how I can run it in parallel.
Thank you in advance, Steven
Thanks for using RelBERT! You can pass the list of word pairs (pairlist in your example) directly, and the model will run inference in parallel. You can control the parallelization by specifying the batch_size argument.
pairlist = [['token1', 'token2'], ...]
emb = model.get_embedding(pairlist, batch_size=8)  # if you do not provide batch_size, everything is processed at once
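As for the AssertionError from your multiprocessing attempt: get_embedding already creates a multiprocessing pool internally (you can see pool.map inside relbert/lm.py in your traceback), so wrapping it in your own Pool means a daemonic worker process tries to spawn children of its own, which Python forbids. A minimal stdlib sketch reproducing that restriction, with no RelBERT involved:

```python
import multiprocessing as mp

def nested(_):
    # Simulates a library that parallelizes internally: the Pool worker
    # running this function tries to create its own Pool.
    with mp.Pool(2) as inner:
        return inner.map(abs, [-1, -2])

def try_nested():
    """Run nested() inside a Pool worker and return the resulting error message."""
    with mp.Pool(2) as outer:
        try:
            outer.map(nested, [0])
            return None
        except AssertionError as exc:
            return str(exc)

if __name__ == "__main__":
    print(try_nested())  # daemonic processes are not allowed to have children
```

So the fix is simply to drop the outer pool and let the library batch internally, as above.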
Hello, thank you for your detailed explanation.
I'm using the second method you mentioned. However, some token pairs raise an exception when I call model.get_embedding(). I want to filter them out while keeping the model running in parallel.
i.e. I want to avoid the following style:
try:
    # e.g. pairlist has over a million ['token_1', 'token_1_match'] pairs;
    # each token contains only ASCII characters, with _ replacing the space(s) between words
    relList = model.get_embedding(pairlist, batch_size=10000)
except Exception:
    # split pairlist into two halves
    # try get_embedding on each half
    # if an exception is raised again, halve the failing part further
because the model still has to be called multiple times, which reduces efficiency.
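Spelled out, that fallback style would look roughly like the sketch below, with fake_embed as a stand-in for model.get_embedding; every failing pair still costs extra model calls on the order of log(batch size), which is what I'd like to avoid:

```python
def embed_with_fallback(embed, pairs):
    """Recursively halve the batch when embed() raises, so that only the
    individual failing pairs are skipped.  Returns (embeddings, failed_pairs)."""
    try:
        return embed(pairs), []
    except Exception:
        if len(pairs) == 1:
            return [], pairs  # isolated the offending pair: skip it
        mid = len(pairs) // 2
        left_ok, left_bad = embed_with_fallback(embed, pairs[:mid])
        right_ok, right_bad = embed_with_fallback(embed, pairs[mid:])
        return left_ok + right_ok, left_bad + right_bad

# Toy demonstration: pretend any pair containing "bad" raises.
def fake_embed(pairs):
    if any("bad" in tok for pair in pairs for tok in pair):
        raise AssertionError("exceeded length")
    return [[0.0] for _ in pairs]  # dummy embeddings

pairs = [['a', 'b'], ['bad', 'c'], ['d', 'e'], ['f', 'g']]
embs, failed = embed_with_fallback(fake_embed, pairs)
# failed == [['bad', 'c']]; embs holds the 3 dummy embeddings
```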
It would also be extremely helpful if there were a way to skip the pairs that cause the exception, or an efficient way to pre-process the pairs so that they contain only valid tokens.
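For instance, I could imagine a crude pre-filter on raw character length, like the sketch below; the 50-character threshold is just a guess on my part, and the model's real limit presumably applies to the tokenized length rather than the character count:

```python
MAX_CHARS = 50  # guessed threshold; the model's actual limit is on tokenized length

def plausible(pair, max_chars=MAX_CHARS):
    """Keep only pairs whose tokens are ASCII and not absurdly long."""
    return all(tok.isascii() and len(tok) <= max_chars for tok in pair)

pairs = [['China', 'U0026u0026' * 10], ['Paris', 'France']]
kept = [p for p in pairs if plausible(p)]
# kept == [['Paris', 'France']]
```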
Could you share the error message and the word pairs that raise the error? Maybe I can fix the code to pre-process them so that the model prediction works without issues.
I selected a random batch of 10,000 token pairs and tested which ones cause the error. All tokens come from the Dresden Web Table Corpus (https://wwwdb.inf.tu-dresden.de/misc/dwtc/). The runtime was approximately 10 hours.
The model I used is the one claimed to be the best:
model = RelBERT('asahi417/relbert-roberta-large')
For example, for the pair
['China', 'U0026u0026u0026u0026u0026u0026u0026u0026u0026u0026u0026u0026u0026u002600_u0026u0026u0026u0026u002600']
I received the following error:
AssertionError: exceeded length [0, 100, 938, 17, 27, 90, 2542, 9, 42, 1291, 6, 53, 38, 95, 1166, 11, 5, 45975, 14, 436, 16, 5, 50264, 9, 121, 612, 2481, 257, 612, 2481, 257, 612, 2481, 257, 612, 2481, 257, 612, 2481, 257, 612, 2481, 257, 612, 2481, 257, 612, 2481, 257, 612, 2481, 257, 612, 2481, 257, 612, 2481, 257, 612, 2481, 257, 612, 2481, 2]
The full stack trace is:
Traceback (most recent call last):
File "D:\Project\embtransform\main.py", line 727, in <module>
model.get_embedding(['China', 'U0026u0026u0026u0026u0026u0026u0026u0026u0026u0026u0026u0026u0026u002600_u0026u0026u0026u0026u002600'])
File "D:\Software\anaconda\anaconda\envs\rweenv\lib\site-packages\relbert\lm.py", line 327, in get_embedding
data = self.preprocess(x, pairwise_input=False)
File "D:\Software\anaconda\anaconda\envs\rweenv\lib\site-packages\relbert\lm.py", line 268, in preprocess
positive_embedding = pool_map(positive_samples_list.flatten_list)
File "D:\Software\anaconda\anaconda\envs\rweenv\lib\site-packages\relbert\lm.py", line 264, in pool_map
out = pool.map(EncodePlus(**shared), _list)
File "D:\Software\anaconda\anaconda\envs\rweenv\lib\multiprocessing\pool.py", line 364, in map
return self._map_async(func, iterable, mapstar, chunksize).get()
File "D:\Software\anaconda\anaconda\envs\rweenv\lib\multiprocessing\pool.py", line 771, in get
raise self._value
AssertionError: exceeded length [0, 100, 938, 17, 27, 90, 2542, 9, 42, 1291, 6, 53, 38, 95, 1166, 11, 5, 45975, 14, 436, 16, 5, 50264, 9, 121, 612, 2481, 257, 612, 2481, 257, 612, 2481, 257, 612, 2481, 257, 612, 2481, 257, 612, 2481, 257, 612, 2481, 257, 612, 2481, 257, 612, 2481, 257, 612, 2481, 257, 612, 2481, 257, 612, 2481, 257, 612, 2481, 2]
Another example is quite different.
The tested token pair is
['Vanuatu','2010-09-_08']
If tested in a batch, the same AssertionError occurs.
However, if I test it separately using
model.get_embedding(['Vanuatu','2010-09-_08'])
no error occurs.
In total, 488 out of the 10,000 pairs cause an error.
Please find the tested batch (both pickled and in printed form), the token pairs that caused errors, and the stand-alone script I used to test them, in my GitHub repo:
https://github.com/RoManInv/relbert-batchtest
where
fail-0.pkl ---- a pickled file of the 10,000 tested token pairs
fail-0.txt ---- the same 10,000 token pairs in text form (using it directly may not reproduce the same errors as the pkl; perhaps non-printable tokens exist?)
fail-0-breakdown.txt ---- all token pairs that cause an error
main.py ---- the script I used to test whether a token pair causes an error