Fix missing messages in Gemini history (#2906)

* fix missing message in history

* fix message handling

* add list of Parts to Content object

* add test for gemini message conversion function

* add test for gemini message conversion

* add message to asserts

* add safety setting support for vertexai

* remove vertexai safety settings
This commit is contained in:
Zoltan Lux 2024-06-14 20:13:19 +02:00 committed by GitHub
parent 6d4cf406f9
commit 10b8fa548b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 47 additions and 5 deletions

View File

@ -194,7 +194,7 @@ class GeminiClient:
for attempt in range(max_retries):
ans = None
try:
response = chat.send_message(gemini_messages[-1].parts[0].text, stream=stream)
response = chat.send_message(gemini_messages[-1], stream=stream)
except InternalServerError:
delay = 5 * (2**attempt)
warnings.warn(
@ -344,19 +344,19 @@ class GeminiClient:
for i, message in enumerate(messages):
parts = self._oai_content_to_gemini_content(message["content"])
role = "user" if message["role"] in ["user", "system"] else "model"
if prev_role is None or role == prev_role:
if (prev_role is None) or (role == prev_role):
curr_parts += parts
elif role != prev_role:
if self.use_vertexai:
rst.append(VertexAIContent(parts=self._concat_parts(curr_parts), role=prev_role))
rst.append(VertexAIContent(parts=curr_parts, role=prev_role))
else:
rst.append(Content(parts=curr_parts, role=prev_role))
curr_parts = parts
prev_role = role
# handle the last message
if self.use_vertexai:
rst.append(VertexAIContent(parts=self._concat_parts(curr_parts), role=role))
rst.append(VertexAIContent(parts=curr_parts, role=role))
else:
rst.append(Content(parts=curr_parts, role=role))

View File

@ -52,6 +52,48 @@ def test_valid_initialization(gemini_client):
assert gemini_client.api_key == "fake_api_key", "API Key should be correctly set"
@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
def test_gemini_message_handling(gemini_client):
messages = [
{"role": "system", "content": "You are my personal assistant."},
{"role": "model", "content": "How can I help you?"},
{"role": "user", "content": "Which planet is the nearest to the sun?"},
{"role": "user", "content": "Which planet is the farthest from the sun?"},
{"role": "model", "content": "Mercury is the closest palnet to the sun."},
{"role": "model", "content": "Neptune is the farthest palnet from the sun."},
{"role": "user", "content": "How can we determine the mass of a black hole?"},
]
# The datastructure below defines what the structure of the messages
# should resemble after converting to Gemini format.
# Messages of similar roles are expected to be merged to a single message,
# where the contents of the original messages will be included in
# consecutive parts of the converted Gemini message
expected_gemini_struct = [
# system role is converted to user role
{"role": "user", "parts": ["You are my personal assistant."]},
{"role": "model", "parts": ["How can I help you?"]},
{
"role": "user",
"parts": ["Which planet is the nearest to the sun?", "Which planet is the farthest from the sun?"],
},
{
"role": "model",
"parts": ["Mercury is the closest palnet to the sun.", "Neptune is the farthest palnet from the sun."],
},
{"role": "user", "parts": ["How can we determine the mass of a black hole?"]},
]
converted_messages = gemini_client._oai_messages_to_gemini_messages(messages)
assert len(converted_messages) == len(expected_gemini_struct), "The number of messages is not as expected"
for i, expected_msg in enumerate(expected_gemini_struct):
assert expected_msg["role"] == converted_messages[i].role, "Incorrect mapped message role"
for j, part in enumerate(expected_msg["parts"]):
assert converted_messages[i].parts[j].text == part, "Incorrect mapped message text"
# Test error handling
@patch("autogen.oai.gemini.genai")
@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")