mirror of
https://github.com/microsoft/autogen.git
synced 2025-11-07 21:34:00 +00:00
Fix missing messages in Gemini history (#2906)
* fix missing message in history * fix message handling * add list of Parts to Content object * add test for gemini message conversion function * add test for gemini message conversion * add message to asserts * add safety setting support for vertexai * remove vertexai safety settings
This commit is contained in:
parent
6d4cf406f9
commit
10b8fa548b
@ -194,7 +194,7 @@ class GeminiClient:
|
||||
for attempt in range(max_retries):
|
||||
ans = None
|
||||
try:
|
||||
response = chat.send_message(gemini_messages[-1].parts[0].text, stream=stream)
|
||||
response = chat.send_message(gemini_messages[-1], stream=stream)
|
||||
except InternalServerError:
|
||||
delay = 5 * (2**attempt)
|
||||
warnings.warn(
|
||||
@ -344,19 +344,19 @@ class GeminiClient:
|
||||
for i, message in enumerate(messages):
|
||||
parts = self._oai_content_to_gemini_content(message["content"])
|
||||
role = "user" if message["role"] in ["user", "system"] else "model"
|
||||
|
||||
if prev_role is None or role == prev_role:
|
||||
if (prev_role is None) or (role == prev_role):
|
||||
curr_parts += parts
|
||||
elif role != prev_role:
|
||||
if self.use_vertexai:
|
||||
rst.append(VertexAIContent(parts=self._concat_parts(curr_parts), role=prev_role))
|
||||
rst.append(VertexAIContent(parts=curr_parts, role=prev_role))
|
||||
else:
|
||||
rst.append(Content(parts=curr_parts, role=prev_role))
|
||||
curr_parts = parts
|
||||
prev_role = role
|
||||
|
||||
# handle the last message
|
||||
if self.use_vertexai:
|
||||
rst.append(VertexAIContent(parts=self._concat_parts(curr_parts), role=role))
|
||||
rst.append(VertexAIContent(parts=curr_parts, role=role))
|
||||
else:
|
||||
rst.append(Content(parts=curr_parts, role=role))
|
||||
|
||||
|
||||
@ -52,6 +52,48 @@ def test_valid_initialization(gemini_client):
|
||||
assert gemini_client.api_key == "fake_api_key", "API Key should be correctly set"
|
||||
|
||||
|
||||
@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
|
||||
def test_gemini_message_handling(gemini_client):
|
||||
messages = [
|
||||
{"role": "system", "content": "You are my personal assistant."},
|
||||
{"role": "model", "content": "How can I help you?"},
|
||||
{"role": "user", "content": "Which planet is the nearest to the sun?"},
|
||||
{"role": "user", "content": "Which planet is the farthest from the sun?"},
|
||||
{"role": "model", "content": "Mercury is the closest palnet to the sun."},
|
||||
{"role": "model", "content": "Neptune is the farthest palnet from the sun."},
|
||||
{"role": "user", "content": "How can we determine the mass of a black hole?"},
|
||||
]
|
||||
|
||||
# The datastructure below defines what the structure of the messages
|
||||
# should resemble after converting to Gemini format.
|
||||
# Messages of similar roles are expected to be merged to a single message,
|
||||
# where the contents of the original messages will be included in
|
||||
# consecutive parts of the converted Gemini message
|
||||
expected_gemini_struct = [
|
||||
# system role is converted to user role
|
||||
{"role": "user", "parts": ["You are my personal assistant."]},
|
||||
{"role": "model", "parts": ["How can I help you?"]},
|
||||
{
|
||||
"role": "user",
|
||||
"parts": ["Which planet is the nearest to the sun?", "Which planet is the farthest from the sun?"],
|
||||
},
|
||||
{
|
||||
"role": "model",
|
||||
"parts": ["Mercury is the closest palnet to the sun.", "Neptune is the farthest palnet from the sun."],
|
||||
},
|
||||
{"role": "user", "parts": ["How can we determine the mass of a black hole?"]},
|
||||
]
|
||||
|
||||
converted_messages = gemini_client._oai_messages_to_gemini_messages(messages)
|
||||
|
||||
assert len(converted_messages) == len(expected_gemini_struct), "The number of messages is not as expected"
|
||||
|
||||
for i, expected_msg in enumerate(expected_gemini_struct):
|
||||
assert expected_msg["role"] == converted_messages[i].role, "Incorrect mapped message role"
|
||||
for j, part in enumerate(expected_msg["parts"]):
|
||||
assert converted_messages[i].parts[j].text == part, "Incorrect mapped message text"
|
||||
|
||||
|
||||
# Test error handling
|
||||
@patch("autogen.oai.gemini.genai")
|
||||
@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user