Mirror of https://github.com/microsoft/autogen.git (synced 2025-11-09 14:24:05 +00:00)
Fix missing messages in Gemini history (#2906)
* fix missing message in history
* fix message handling
* add list of Parts to Content object
* add test for gemini message conversion function
* add test for gemini message conversion
* add message to asserts
* add safety setting support for vertexai
* remove vertexai safety settings
parent 6d4cf406f9
commit 10b8fa548b
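For orientation, the behaviour this commit fixes and then tests can be sketched without the Google SDKs: OAI-style messages whose roles map to the same Gemini role ("user" and "system" become user, everything else model) are collapsed into one history entry, and every original message survives as its own part. The sketch below is purely illustrative, with plain dicts standing in for the SDK's Content/Part objects; merge_oai_messages is a made-up name, not the repo's _oai_messages_to_gemini_messages.

from typing import Any, Dict, List


def merge_oai_messages(messages: List[Dict[str, str]]) -> List[Dict[str, Any]]:
    """Collapse consecutive same-role messages into one entry with a list of parts."""
    merged: List[Dict[str, Any]] = []
    for msg in messages:
        # "system" and "user" both map to Gemini's "user" role; everything else to "model".
        role = "user" if msg["role"] in ("user", "system") else "model"
        if merged and merged[-1]["role"] == role:
            merged[-1]["parts"].append(msg["content"])  # keep every message as its own part
        else:
            merged.append({"role": role, "parts": [msg["content"]]})
    return merged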
@@ -194,7 +194,7 @@ class GeminiClient:
             for attempt in range(max_retries):
                 ans = None
                 try:
-                    response = chat.send_message(gemini_messages[-1].parts[0].text, stream=stream)
+                    response = chat.send_message(gemini_messages[-1], stream=stream)
                 except InternalServerError:
                     delay = 5 * (2**attempt)
                     warnings.warn(
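The one-line change above is the heart of the fix: the old call forwarded only the text of the first part of the last history entry, so when consecutive same-role messages had been merged into several parts, everything after the first part was silently dropped. A tiny stand-alone illustration follows; the Part and Content dataclasses here are stand-ins for the google-generativeai types, not the real classes.

from dataclasses import dataclass, field
from typing import List


@dataclass
class Part:
    text: str


@dataclass
class Content:
    role: str
    parts: List[Part] = field(default_factory=list)


last = Content(
    role="user",
    parts=[Part("Which planet is the nearest to the sun?"),
           Part("Which planet is the farthest from the sun?")],
)

print(last.parts[0].text)             # old behaviour: only the first question reaches the model
print([p.text for p in last.parts])   # new behaviour: the whole Content, both questions, is sent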
@@ -344,19 +344,19 @@ class GeminiClient:
         for i, message in enumerate(messages):
             parts = self._oai_content_to_gemini_content(message["content"])
             role = "user" if message["role"] in ["user", "system"] else "model"

-            if prev_role is None or role == prev_role:
+            if (prev_role is None) or (role == prev_role):
                 curr_parts += parts
             elif role != prev_role:
                 if self.use_vertexai:
-                    rst.append(VertexAIContent(parts=self._concat_parts(curr_parts), role=prev_role))
+                    rst.append(VertexAIContent(parts=curr_parts, role=prev_role))
                 else:
                     rst.append(Content(parts=curr_parts, role=prev_role))
                 curr_parts = parts
             prev_role = role

         # handle the last message
         if self.use_vertexai:
-            rst.append(VertexAIContent(parts=self._concat_parts(curr_parts), role=role))
+            rst.append(VertexAIContent(parts=curr_parts, role=role))
         else:
             rst.append(Content(parts=curr_parts, role=role))
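With the parts list now handed to VertexAIContent (and Content) as-is instead of being concatenated via self._concat_parts, a run of same-role messages survives as separate parts. Reusing the merge_oai_messages sketch from above on the kind of conversation the new test uses shows the shape the client now builds; the expected counts are from the sketch, not from running the repo's code.

history = [
    {"role": "system", "content": "You are my personal assistant."},
    {"role": "model", "content": "How can I help you?"},
    {"role": "user", "content": "Which planet is the nearest to the sun?"},
    {"role": "user", "content": "Which planet is the farthest from the sun?"},
]

for entry in merge_oai_messages(history):
    print(entry["role"], len(entry["parts"]))
# user 1   <- the system prompt, remapped to the user role
# model 1
# user 2   <- the two user questions kept as two separate parts

The new regression test below, added to the Gemini client test module, asserts exactly this merged structure against _oai_messages_to_gemini_messages.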
@@ -52,6 +52,48 @@ def test_valid_initialization(gemini_client):
     assert gemini_client.api_key == "fake_api_key", "API Key should be correctly set"


+@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
+def test_gemini_message_handling(gemini_client):
+    messages = [
+        {"role": "system", "content": "You are my personal assistant."},
+        {"role": "model", "content": "How can I help you?"},
+        {"role": "user", "content": "Which planet is the nearest to the sun?"},
+        {"role": "user", "content": "Which planet is the farthest from the sun?"},
+        {"role": "model", "content": "Mercury is the closest planet to the sun."},
+        {"role": "model", "content": "Neptune is the farthest planet from the sun."},
+        {"role": "user", "content": "How can we determine the mass of a black hole?"},
+    ]
+
+    # The data structure below defines what the structure of the messages
+    # should resemble after converting to Gemini format.
+    # Messages of similar roles are expected to be merged into a single message,
+    # where the contents of the original messages will be included in
+    # consecutive parts of the converted Gemini message.
+    expected_gemini_struct = [
+        # system role is converted to user role
+        {"role": "user", "parts": ["You are my personal assistant."]},
+        {"role": "model", "parts": ["How can I help you?"]},
+        {
+            "role": "user",
+            "parts": ["Which planet is the nearest to the sun?", "Which planet is the farthest from the sun?"],
+        },
+        {
+            "role": "model",
+            "parts": ["Mercury is the closest planet to the sun.", "Neptune is the farthest planet from the sun."],
+        },
+        {"role": "user", "parts": ["How can we determine the mass of a black hole?"]},
+    ]
+
+    converted_messages = gemini_client._oai_messages_to_gemini_messages(messages)
+
+    assert len(converted_messages) == len(expected_gemini_struct), "The number of messages is not as expected"
+
+    for i, expected_msg in enumerate(expected_gemini_struct):
+        assert expected_msg["role"] == converted_messages[i].role, "Incorrect mapped message role"
+        for j, part in enumerate(expected_msg["parts"]):
+            assert converted_messages[i].parts[j].text == part, "Incorrect mapped message text"
+
+
 # Test error handling
 @patch("autogen.oai.gemini.genai")
 @pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
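Assuming the new test sits alongside the existing Gemini client tests (the @patch target suggests the module under test is autogen.oai.gemini), it can presumably be run on its own with something like the snippet below; the test file path is a guess, not taken from the diff.

import pytest

# Hypothetical path -- adjust to wherever the Gemini tests actually live in the repo.
pytest.main(["test/oai/test_gemini.py", "-k", "test_gemini_message_handling"])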