infercode
Question: Is there a reason that the batch size is capped at 5?
Hey, first off, love the repo, thanks for providing it.
I just have a question about the batch size that the model can handle.
When I put in a list of more than 5 pieces of code like so:
from infercode.client.infercode_client import InferCodeClient
import os
import logging
logging.basicConfig(level=logging.INFO)
# Change from -1 to 0 to enable GPU
os.environ['CUDA_VISIBLE_DEVICES'] = "0"
infercode = InferCodeClient(language="java")
infercode.init_from_config()
# Here we pass in 6 identical snippets (int i = 0;)
vectors = infercode.encode(["int i = 0;"] * 6)
I get the following error:
AssertionError Traceback (most recent call last)
Input In [20], in <cell line: 1>()
----> 1 vectors = infercode.encode(["int i = 0;"] * 6)
File ~\Anaconda3\envs\infercode_new_env\lib\site-packages\infercode\client\infercode_client.py:76, in InferCodeClient.encode(self, batch_code_snippets)
75 def encode(self, batch_code_snippets):
---> 76 tensors = self.snippets_to_tensors(batch_code_snippets)
77 embeddings = self.sess.run(
78 [self.infercode_model.code_vector],
79 feed_dict={
(...)
87 }
88 )
89 return embeddings[0]
File ~\Anaconda3\envs\infercode_new_env\lib\site-packages\infercode\client\infercode_client.py:62, in InferCodeClient.snippets_to_tensors(self, batch_code_snippets)
60 def snippets_to_tensors(self, batch_code_snippets):
61 batch_tree_indexes = []
---> 62 assert len(batch_code_snippets) <= 5
63 for code_snippet in batch_code_snippets:
64 # tree-sitter parser requires bytes as the input, not string
65 code_snippet_to_byte = str.encode(code_snippet)
AssertionError:
This comes from an assertion in InferCodeClient.snippets_to_tensors that caps the number of input snippets at 5:
https://github.com/bdqnghi/infercode/blob/8c22a3353aabb2f02ce1e044d439969df7463a0d/infercode/client/infercode_client.py#L62
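In the meantime, one workaround is to split the input on the caller side and encode it in slices of at most 5. This is a minimal sketch that assumes only two things: encode accepts a list of up to 5 snippets, and it returns one vector per snippet. encode_in_chunks and chunk_size are hypothetical names and are not part of infercode:

```python
from typing import Callable, List, Sequence, TypeVar

V = TypeVar("V")


def encode_in_chunks(
    encode_fn: Callable[[Sequence[str]], Sequence[V]],
    snippets: Sequence[str],
    chunk_size: int = 5,
) -> List[V]:
    """Hypothetical helper: encode snippets in slices of at most chunk_size.

    encode_fn is assumed to return one vector per input snippet, e.g.
    InferCodeClient.encode, whose 2-D result yields one row per snippet
    when iterated.
    """
    if chunk_size < 1:
        raise ValueError("chunk_size must be at least 1")
    vectors: List[V] = []
    for start in range(0, len(snippets), chunk_size):
        # Each slice stays within the <= 5 assertion in snippets_to_tensors
        vectors.extend(encode_fn(snippets[start:start + chunk_size]))
    return vectors
```

With the client from above this would be called as encode_in_chunks(infercode.encode, ["int i = 0;"] * 6). The result is a list of row vectors, and numpy.stack can turn it back into a single array if needed.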
Is there a reason this limit is hard-coded? If not, would it make sense to add a batch_size
parameter that defaults to 5 but can be raised depending on available compute (for example, GPU memory)?
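For illustration only, a configurable limit could look roughly like the sketch below. BatchLimit, batch_size and check are hypothetical names, not existing infercode API. The default of 5 mirrors the current assertion. Instead of the bare AssertionError in the traceback above, it raises a ValueError with a message. Note also that assert statements are stripped entirely when Python runs with -O, so the current check would silently disappear in that mode.

```python
from typing import Sequence


class BatchLimit:
    """Hypothetical configurable replacement for `assert len(batch) <= 5`."""

    def __init__(self, batch_size: int = 5) -> None:
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        self.batch_size = batch_size

    def check(self, batch_code_snippets: Sequence[str]) -> None:
        # Fail with an explicit message rather than an empty AssertionError
        if len(batch_code_snippets) > self.batch_size:
            raise ValueError(
                f"got {len(batch_code_snippets)} snippets, but batch_size is "
                f"{self.batch_size}; raise batch_size or split the input"
            )
```

The client could, for example, accept a hypothetical batch_size argument in its constructor, store a BatchLimit built from it, and call check at the top of snippets_to_tensors. Users with more GPU memory could then raise the limit without editing the package.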