generative-ai-python
Docs fail to document thread-safety of genai client, and it fails irrecoverably on multi-threaded use
Description of the bug:
The Generative Service Client and GenerativeModel classes don't document their thread-safety assumptions, and don't appear to be usable in a multithreaded environment for making concurrent API requests.
I'd suggest either:
- documenting thread safety assumptions and guarantees, or
- investigating behaviour when a client is shared between threads
Behaviour observed: after trying to make concurrent calls to the generative text API, most calls failed with a 60s timeout. The client never recovered: every new call attempt also froze for 60s and then timed out with an error.
Sample error output:
10%|▉ | 199/2047.0 [29:31<5:46:51, 11.26s/it]
HTTPConnectionPool(host='localhost', port=46423): Read timed out. (read timeout=60.0)
10%|▉ | 204/2047.0 [30:22<4:59:27, 9.75s/it]
('Connection aborted.', RemoteDisconnected('Remote end closed connection without response'))
('Connection aborted.', RemoteDisconnected('Remote end closed connection without response'))
10%|█ | 209/2047.0 [31:10<6:08:26, 12.03s/it]
('Connection aborted.', RemoteDisconnected('Remote end closed connection without response'))
('Connection aborted.', RemoteDisconnected('Remote end closed connection without response'))
11%|█ | 216/2047.0 [31:43<3:43:00, 7.31s/it]
('Connection aborted.', RemoteDisconnected('Remote end closed connection without response'))
('Connection aborted.', RemoteDisconnected('Remote end closed connection without response'))
11%|█ | 225/2047.0 [32:48<3:52:42, 7.66s/it]
('Connection aborted.', RemoteDisconnected('Remote end closed connection without response'))
11%|█▏ | 231/2047.0 [33:38<4:22:00, 8.66s/it]
('Connection aborted.', RemoteDisconnected('Remote end closed connection without response'))
('Connection aborted.', RemoteDisconnected('Remote end closed connection without response'))
12%|█▏ | 245/2047.0 [35:55<6:14:28, 12.47s/it]
HTTPConnectionPool(host='localhost', port=46423): Read timed out. (read timeout=60.0)
('Connection aborted.', RemoteDisconnected('Remote end closed connection without response'))
14%|█▍ | 296/2047.0 [43:38<4:30:46, 9.28s/it]
HTTPConnectionPool(host='localhost', port=46423): Read timed out. (read timeout=60.0)
Example snippet:
# [... regular imports ...]
from concurrent.futures import ThreadPoolExecutor
import tqdm

safety_settings = ...

executor = ThreadPoolExecutor(max_workers=5)

def build_data_batch():
    ## build batches of data to process
    pass

def generate(data_batch):
    model_out = 'error'
    try:
        # this ends up failing whether the model client is
        # created freshly per request or shared across threads
        model = genai.GenerativeModel('models/gemini-pro', safety_settings=safety_settings)
        model_out = model.generate_content(build_prompt(data_batch)).text
    except Exception as e:
        print(e)
    return model_out

all_outputs = executor.map(generate, build_data_batch())

with open('./outputs.txt', 'w') as f:
    for result in tqdm.tqdm(all_outputs, total=totalbatches):
        f.write(result)
Actual vs expected behavior:
Actual: all calls fail.
Expected: this case should either work, or the client docs should state that the client is not safe for concurrent use, given how common batch-inference scenarios are likely to be.
Any other information you'd like to share?
No response
Hey, I'm blocked on the same issue - did you manage to find any sort of workaround?
@adenalhardan Yes, instead of using the API client library, I ended up calling the HTTP API endpoint directly. With independent HTTP clients (one per worker thread), there were no problems.
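A minimal sketch of that one-client-per-thread pattern, assuming the public generativelanguage REST endpoint and an API key in a GEMINI_API_KEY environment variable (the helper names here are illustrative, not from the original post):

import os
import threading
import requests

_local = threading.local()

def _session() -> requests.Session:
    # Each worker thread lazily creates and reuses its own Session,
    # and therefore its own connection pool.
    if not hasattr(_local, 'session'):
        _local.session = requests.Session()
    return _local.session

def generate_text(prompt: str, model: str = 'gemini-pro') -> str:
    url = (
        f'https://generativelanguage.googleapis.com/v1beta/models/'
        f'{model}:generateContent?key={os.environ["GEMINI_API_KEY"]}'
    )
    body = {'contents': [{'parts': [{'text': prompt}]}]}
    resp = _session().post(url, json=body, timeout=60)
    resp.raise_for_status()
    return resp.json()['candidates'][0]['content']['parts'][0]['text']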
Awesome thank you
> @adenalhardan Yes, instead of using the API client library, I ended up calling the HTTP API endpoint directly. With independent HTTP clients (one per worker thread), there were no problems.
Can you kindly share some general code showing how you did that?
Oh wait - it's in the docs. Figured it out, thanks! https://ai.google.dev/api/rest/v1beta/media/upload?hl=en This works.
UPDATE - Halfway through the upload, I get this though: Error processing batch: ('Connection aborted.', RemoteDisconnected('Remote end closed connection without response'))
@RukshanJS Here's what I've been using, maybe it'll fix your error
import base64
import json
import requests
import google.auth
import google.auth.transport.requests
from google.oauth2 import service_account

class GoogleHTTPClient:
    def __init__(self, model: str = 'gemini-1.5-pro-preview-0409', max_retries: int = 5):
        self.model = model
        self.max_retries = max_retries
        self.project = 'YOUR PROJECT ID'
        self.region = 'YOUR REGION'
        self.access_token = self._get_access_token()

    def request_message(self, messages: list[object], retries: int = 0) -> str:
        url = f"https://{self.region}-aiplatform.googleapis.com/v1/projects/{self.project}/locations/{self.region}/publishers/google/models/{self.model}:generateContent"
        headers = {
            'Authorization': f'Bearer {self.access_token}',
            'Content-Type': 'application/json',
        }
        data = json.dumps({
            'contents': {
                'role': 'user',
                'parts': messages,  # each entry is a part built by the format_* helpers
            }
        })
        response = requests.post(url, headers=headers, data=data)
        response = json.loads(response.text)
        try:
            return response['candidates'][0]['content']['parts'][0]['text'].strip()
        except Exception as error:
            if retries == self.max_retries:
                raise Exception(f'Failed to fetch from google: {error}')
            return self.request_message(messages, retries + 1)

    def format_image_message(self, image: bytes) -> object:
        return {
            "inlineData": {
                "mimeType": 'image/png',
                "data": base64.b64encode(image).decode('utf-8'),
            }
        }

    def format_text_message(self, text: str) -> object:
        return {'text': text}

    def _get_access_token(self) -> str:
        # Exchanges a service-account key file for a short-lived OAuth access token.
        service_account_key = './google_credentials.json'
        credentials = service_account.Credentials.from_service_account_file(
            service_account_key,
            scopes=['https://www.googleapis.com/auth/cloud-platform'],
        )
        auth_req = google.auth.transport.requests.Request()
        credentials.refresh(auth_req)
        return credentials.token
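For completeness, a hedged usage sketch of this client with a thread pool (one independent client per task, matching the workaround above; the prompts are illustrative):

from concurrent.futures import ThreadPoolExecutor

def worker(prompt: str) -> str:
    # An independent client per task means no HTTP state is shared across threads.
    client = GoogleHTTPClient()
    return client.request_message([client.format_text_message(prompt)])

with ThreadPoolExecutor(max_workers=5) as executor:
    results = list(executor.map(worker, ['Describe a cat.', 'Describe a dog.']))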
This is very helpful, thanks a lot! Do you mind sharing the usage of this client as well? I have a little trouble understanding why we have to use an access token here. Can't we use the Gemini API itself with the API key, without the genai client?
My current code is:
# [assumed imports and logger setup for this snippet]
import concurrent.futures
import logging
import pathlib
from concurrent.futures import ThreadPoolExecutor, as_completed

import google.generativeai as genai
from tqdm import tqdm

logger = logging.getLogger(__name__)

def upload_file(file_path):
    return genai.upload_file(pathlib.Path(file_path))

def analyze_frames(frame_paths):
    model = genai.GenerativeModel("gemini-1.5-pro")
    prompt = """
    Give me a description after looking at these images
    """
    # Upload the images using the File API
    logger.info("Uploading files...")
    file_references = []
    with concurrent.futures.ProcessPoolExecutor() as executor:
        futures = [
            executor.submit(upload_file, frame_path) for frame_path in frame_paths
        ]
        for future in tqdm(
            concurrent.futures.as_completed(futures),
            total=len(futures),
            desc="Uploading files",
        ):
            file_references.append(future.result())
    logger.info("Uploading files completed...")
    # Generate content using the file references
    logger.info("Making inferences using model %s...", model.model_name)
    response = model.generate_content(
        [
            prompt,
            *file_references,
        ]
    )
    return response.text

def process_batches(frames, batch_size):
    # Split the frames into batches
    batch_list = [frames[i : i + batch_size] for i in range(0, len(frames), batch_size)]
    batch_results = []
    with ThreadPoolExecutor() as executor:
        futures = {
            executor.submit(analyze_frames, batch): batch for batch in batch_list
        }
        for future in tqdm(
            as_completed(futures),
            total=len(futures),
            desc="Processing batches",
        ):
            batch_results.append(future.result())
    return batch_results
I call this as:
result = process_batches(frame_paths, 3550)
Problem
It uploads in parallel correctly for some time, and then hangs (no further progress after this point):
Uploading files: 4%|████▍ 155/3550 [00:53<19:29, 2.90it/s]
Uploading files: 9%|████████▊ 158/1850 [00:55<09:53, 2.85it/s]
Observations
- It always fails at around ~160 images in each batch process
Hey everybody,
I think @jpdaigle's original issue is resolved.
Aside from 429 Quota exceeded I haven't gotten any errors. I added request_options=dict(retry=retry.Retry(timeout=600)) to allow retries with a long timeout.
The retries may help with a lot of other errors, which seem much less common now.
https://gist.github.com/MarkDaoust/dcd65b626bf4683860aa510b79bc225e
So I think this bug is fixed.
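For reference, a minimal sketch of that retry option (assuming google-generativeai with google-api-core installed; the API-key env var and prompt are illustrative):

import os
import google.generativeai as genai
from google.api_core import retry

genai.configure(api_key=os.environ['GEMINI_API_KEY'])  # assumed env var

model = genai.GenerativeModel('gemini-pro')
# Retry transient failures for up to 10 minutes overall.
response = model.generate_content(
    'Hello',
    request_options=dict(retry=retry.Retry(timeout=600)),
)
print(response.text)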
@RukshanJS I've heard other reports of failures with threading specifically for file uploads.
How about we continue this in https://github.com/google-gemini/generative-ai-python/issues/327