Fix generate_init_message for Multimodal Messages (#2124)

* multimodal carryover

* adds mm carryover tests

* more tests + cleanup code

* check content instead

* beibin suggestion

* cleanup

* fix async

* use deepcopy

* handle carryover method

* remove content copy

* sonichi suggestions

---------

Co-authored-by: Beibin Li <BeibinLi@users.noreply.github.com>
Co-authored-by: Chi Wang <wang.chi@microsoft.com>
This commit is contained in:
Wael Karkoub 2024-03-30 03:10:24 +01:00 committed by GitHub
parent 6ed8f696ef
commit 7a685b52d8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 94 additions and 26 deletions

View File

@ -2259,30 +2259,54 @@ class ConversableAgent(LLMAgent):
"""
if message is None:
message = self.get_human_input(">")
if isinstance(message, str):
return self._process_carryover(message, kwargs)
elif isinstance(message, dict):
message = message.copy()
# TODO: Do we need to do the following?
# if message.get("content") is None:
# message["content"] = self.get_human_input(">")
message["content"] = self._process_carryover(message.get("content", ""), kwargs)
return self._handle_carryover(message, kwargs)
def _handle_carryover(self, message: Union[str, Dict], kwargs: dict) -> Union[str, Dict]:
if not kwargs.get("carryover"):
return message
def _process_carryover(self, message: str, kwargs: dict) -> str:
carryover = kwargs.get("carryover")
if carryover:
# if carryover is string
if isinstance(carryover, str):
message += "\nContext: \n" + carryover
elif isinstance(carryover, list):
message += "\nContext: \n" + ("\n").join([t for t in carryover])
else:
raise InvalidCarryOverType(
"Carryover should be a string or a list of strings. Not adding carryover to the message."
)
if isinstance(message, str):
return self._process_carryover(message, kwargs)
elif isinstance(message, dict):
if isinstance(message.get("content"), str):
# Makes sure the original message is not mutated
message = message.copy()
message["content"] = self._process_carryover(message["content"], kwargs)
elif isinstance(message.get("content"), list):
# Makes sure the original message is not mutated
message = message.copy()
message["content"] = self._process_multimodal_carryover(message["content"], kwargs)
else:
raise InvalidCarryOverType("Carryover should be a string or a list of strings.")
return message
def _process_carryover(self, content: str, kwargs: dict) -> str:
# Makes sure there's a carryover
if not kwargs.get("carryover"):
return content
# if carryover is string
if isinstance(kwargs["carryover"], str):
content += "\nContext: \n" + kwargs["carryover"]
elif isinstance(kwargs["carryover"], list):
content += "\nContext: \n" + ("\n").join([t for t in kwargs["carryover"]])
else:
raise InvalidCarryOverType(
"Carryover should be a string or a list of strings. Not adding carryover to the message."
)
return content
def _process_multimodal_carryover(self, content: List[Dict], kwargs: dict) -> List[Dict]:
"""Prepends the context to a multimodal message."""
# Makes sure there's a carryover
if not kwargs.get("carryover"):
return content
return [{"type": "text", "text": self._process_carryover("", kwargs)}] + content
async def a_generate_init_message(self, message: Union[Dict, str, None], **kwargs) -> Union[str, Dict]:
"""Generate the initial message for the agent.
If message is None, input() will be called to get the initial message.
@ -2295,12 +2319,8 @@ class ConversableAgent(LLMAgent):
"""
if message is None:
message = await self.a_get_human_input(">")
if isinstance(message, str):
return self._process_carryover(message, kwargs)
elif isinstance(message, dict):
message = message.copy()
message["content"] = self._process_carryover(message["content"], kwargs)
return message
return self._handle_carryover(message, kwargs)
def register_function(self, function_map: Dict[str, Union[Callable, None]]):
"""Register functions to the agent.

View File

@ -1263,6 +1263,54 @@ def test_messages_with_carryover():
with pytest.raises(InvalidCarryOverType):
agent1.generate_init_message(**context)
# Test multimodal messages
mm_content = [
{"type": "text", "text": "hello"},
{"type": "text", "text": "goodbye"},
{
"type": "image_url",
"image_url": {"url": "https://example.com/image.png"},
},
]
mm_message = {"content": mm_content}
context = dict(
message=mm_message,
carryover="Testing carryover.",
)
generated_message = agent1.generate_init_message(**context)
assert isinstance(generated_message, dict)
assert len(generated_message["content"]) == 4
context = dict(message=mm_message, carryover=["Testing carryover.", "This should pass"])
generated_message = agent1.generate_init_message(**context)
assert isinstance(generated_message, dict)
assert len(generated_message["content"]) == 4
context = dict(message=mm_message, carryover=3)
with pytest.raises(InvalidCarryOverType):
agent1.generate_init_message(**context)
# Test without carryover
print(mm_message)
context = dict(message=mm_message)
generated_message = agent1.generate_init_message(**context)
assert isinstance(generated_message, dict)
assert len(generated_message["content"]) == 3
# Test without text in multimodal message
mm_content = [
{"type": "image_url", "image_url": {"url": "https://example.com/image.png"}},
]
mm_message = {"content": mm_content}
context = dict(message=mm_message)
generated_message = agent1.generate_init_message(**context)
assert isinstance(generated_message, dict)
assert len(generated_message["content"]) == 1
generated_message = agent1.generate_init_message(**context, carryover="Testing carryover.")
assert isinstance(generated_message, dict)
assert len(generated_message["content"]) == 2
if __name__ == "__main__":
# test_trigger()