generative-ai-python
Function Calling Does Not Work With Stop Sequences and Streaming
Description of the bug:
Function calling does not work when `stop_sequences` is provided together with `stream=True`: instead of a `FunctionCall` part, the model returns the call as plain text.
Actual vs expected behavior:
Actual:
```python
import google.generativeai as genai
import os
from google.generativeai.types import GenerationConfig

GOOGLE_API_KEY = os.environ["GOOGLE_API_KEY"]
genai.configure(api_key=GOOGLE_API_KEY)


def multiply(a: float, b: float):
    """returns a * b."""
    return a * b


model = genai.GenerativeModel(model_name="gemini-1.5-flash")
response = model.generate_content(
    "What is 4 * 3?",
    stream=True,
    tools=[multiply],
    # stop_sequences combined with stream=True is what triggers the bug
    generation_config=GenerationConfig(stop_sequences=["Anything"]),
)
for chunk in response:
    print(chunk.parts)
    print("_" * 80)
```
No `FunctionCall` is returned; the only part is text containing the call written as Python code:
```
[text: "print(default_api.multiply(a = 4, b = 3))"
]
```
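The text part above still spells out the intended call (`multiply` with `a = 4`, `b = 3`), just rendered as Python source rather than as a `FunctionCall`. A minimal sketch, using only the standard-library `ast` module, that recovers the function name and keyword arguments from that string; `parse_text_call` is a hypothetical helper for illustration, not part of `google.generativeai`.

```python
import ast


def parse_text_call(source: str) -> tuple[str, dict]:
    """Hypothetical helper: extract (name, kwargs) from text such as
    'print(default_api.multiply(a = 4, b = 3))'."""
    expr = ast.parse(source, mode="eval").body
    # Unwrap an outer print(...) if the call is wrapped in one.
    if (
        isinstance(expr, ast.Call)
        and isinstance(expr.func, ast.Name)
        and expr.func.id == "print"
        and len(expr.args) == 1
    ):
        expr = expr.args[0]
    if not isinstance(expr, ast.Call):
        raise ValueError(f"not a function call: {source!r}")
    # default_api.multiply -> "multiply"; a bare multiply -> "multiply".
    if isinstance(expr.func, ast.Attribute):
        name = expr.func.attr
    elif isinstance(expr.func, ast.Name):
        name = expr.func.id
    else:
        raise ValueError(f"unsupported callee in {source!r}")
    # Only literal keyword arguments are accepted; literal_eval never executes code.
    kwargs = {kw.arg: ast.literal_eval(kw.value) for kw in expr.keywords}
    return name, kwargs


name, kwargs = parse_text_call("print(default_api.multiply(a = 4, b = 3))")
print(name, kwargs)  # multiply {'a': 4, 'b': 3}
```

This only inspects the text; it does not turn the response into a real `FunctionCall`.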
Expected (after removing `stop_sequences`):
```python
import google.generativeai as genai
import os

GOOGLE_API_KEY = os.environ["GOOGLE_API_KEY"]
genai.configure(api_key=GOOGLE_API_KEY)


def multiply(a: float, b: float):
    """returns a * b."""
    return a * b


model = genai.GenerativeModel(model_name="gemini-1.5-flash")
response = model.generate_content(
    "What is 4 * 3?",
    stream=True,
    tools=[multiply],
    # no generation_config / stop_sequences in this run
)
for chunk in response:
    print(chunk.parts)
    print("_" * 80)
```
A `Part` containing a `FunctionCall` is returned:
```
[function_call {
  name: "multiply"
  args {
    fields {
      key: "b"
      value {
        number_value: 3
      }
    }
    fields {
      key: "a"
      value {
        number_value: 4
      }
    }
  }
}
]
```
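The `args` struct above maps each parameter name to a `number_value`; protobuf `Struct` numbers are doubles, so the call amounts to `multiply(a=4.0, b=3.0)`. A minimal sketch of dispatching such a call to the local Python function, assuming the call has already been converted to a plain name string and `args` dict (the `FUNCTIONS` registry and `dispatch` are hypothetical names for illustration):

```python
from typing import Any, Callable


def multiply(a: float, b: float):
    """returns a * b."""
    return a * b


# Hypothetical registry mapping function-call names to local callables.
FUNCTIONS: dict[str, Callable[..., Any]] = {"multiply": multiply}


def dispatch(name: str, args: dict[str, Any]) -> Any:
    """Look up the named function and call it with keyword arguments."""
    try:
        fn = FUNCTIONS[name]
    except KeyError:
        raise ValueError(f"unknown function {name!r}") from None
    return fn(**args)


# Plain-dict stand-in for the FunctionCall shown above (Struct numbers are doubles).
result = dispatch("multiply", {"b": 3.0, "a": 4.0})
print(result)  # 12.0
```

In the library, the name and arguments would come from `part.function_call.name` and `part.function_call.args`; how best to convert `args` into a plain dict is an assumption not covered by this report.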
Expected (after setting `stream=False`, keeping `stop_sequences`):
```python
import google.generativeai as genai
import os
from google.generativeai.types import GenerationConfig

GOOGLE_API_KEY = os.environ["GOOGLE_API_KEY"]
genai.configure(api_key=GOOGLE_API_KEY)


def multiply(a: float, b: float):
    """returns a * b."""
    return a * b


model = genai.GenerativeModel(model_name="gemini-1.5-flash")
response = model.generate_content(
    "What is 4 * 3?",
    stream=False,  # streaming disabled; stop_sequences kept
    tools=[multiply],
    generation_config=GenerationConfig(stop_sequences=["Anything"]),
)
for chunk in response:
    print(chunk.parts)
    print("_" * 80)
```
A `Part` containing a `FunctionCall` is returned:
```
[function_call {
  name: "multiply"
  args {
    fields {
      key: "b"
      value {
        number_value: 3
      }
    }
    fields {
      key: "a"
      value {
        number_value: 4
      }
    }
  }
}
]
```
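Across the three runs, the distinguishing symptom is whether any streamed chunk carries a function-call part. A minimal sketch of that check, using hypothetical stand-in dataclasses (`FakeFunctionCall`, `FakePart`) in place of the library's protobuf types so it runs without an API key; the stand-in uses `None` to mean "no function call", which is an assumption of this sketch rather than the library's representation.

```python
from dataclasses import dataclass, field
from typing import Iterable, Optional


@dataclass
class FakeFunctionCall:
    """Hypothetical stand-in for a FunctionCall: a name plus arguments."""
    name: str
    args: dict = field(default_factory=dict)


@dataclass
class FakePart:
    """Hypothetical stand-in for a Part holding either text or a function call."""
    text: str = ""
    function_call: Optional[FakeFunctionCall] = None


def collect_function_calls(chunks: Iterable[list[FakePart]]) -> list[FakeFunctionCall]:
    """Gather every function call across streamed chunks (each chunk is a list of parts)."""
    calls = []
    for parts in chunks:
        for part in parts:
            if part.function_call is not None:
                calls.append(part.function_call)
    return calls


# "Actual": stop_sequences + stream=True yields only a text part.
actual = [[FakePart(text="print(default_api.multiply(a = 4, b = 3))")]]
# "Expected": the same request without stop_sequences yields a FunctionCall part.
expected = [[FakePart(function_call=FakeFunctionCall("multiply", {"b": 3, "a": 4}))]]

print(len(collect_function_calls(actual)))    # 0
print(len(collect_function_calls(expected)))  # 1
```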
Any other information you'd like to share?
I am using version `0.7.2` of this library.