Fix generate_init_message for Multimodal Messages (#2124)

* multimodal carryover

* add multimodal (mm) carryover tests

* more tests + cleanup code

* check content instead

* apply Beibin's suggestion

* cleanup

* fix async

* use deepcopy

* add `_handle_carryover` method

* remove content copy

* sonichi suggestions

---------

Co-authored-by: Beibin Li <BeibinLi@users.noreply.github.com>
Co-authored-by: Chi Wang <wang.chi@microsoft.com>
This commit is contained in:
Wael Karkoub 2024-03-30 03:10:24 +01:00 committed by GitHub
parent 6ed8f696ef
commit 7a685b52d8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 94 additions and 26 deletions

View File

@ -2259,30 +2259,54 @@ class ConversableAgent(LLMAgent):
""" """
if message is None: if message is None:
message = self.get_human_input(">") message = self.get_human_input(">")
if isinstance(message, str):
return self._process_carryover(message, kwargs) return self._handle_carryover(message, kwargs)
elif isinstance(message, dict):
message = message.copy() def _handle_carryover(self, message: Union[str, Dict], kwargs: dict) -> Union[str, Dict]:
# TODO: Do we need to do the following? if not kwargs.get("carryover"):
# if message.get("content") is None:
# message["content"] = self.get_human_input(">")
message["content"] = self._process_carryover(message.get("content", ""), kwargs)
return message return message
def _process_carryover(self, message: str, kwargs: dict) -> str: if isinstance(message, str):
carryover = kwargs.get("carryover") return self._process_carryover(message, kwargs)
if carryover:
# if carryover is string elif isinstance(message, dict):
if isinstance(carryover, str): if isinstance(message.get("content"), str):
message += "\nContext: \n" + carryover # Makes sure the original message is not mutated
elif isinstance(carryover, list): message = message.copy()
message += "\nContext: \n" + ("\n").join([t for t in carryover]) message["content"] = self._process_carryover(message["content"], kwargs)
else: elif isinstance(message.get("content"), list):
raise InvalidCarryOverType( # Makes sure the original message is not mutated
"Carryover should be a string or a list of strings. Not adding carryover to the message." message = message.copy()
) message["content"] = self._process_multimodal_carryover(message["content"], kwargs)
else:
raise InvalidCarryOverType("Carryover should be a string or a list of strings.")
return message return message
def _process_carryover(self, content: str, kwargs: dict) -> str:
# Makes sure there's a carryover
if not kwargs.get("carryover"):
return content
# if carryover is string
if isinstance(kwargs["carryover"], str):
content += "\nContext: \n" + kwargs["carryover"]
elif isinstance(kwargs["carryover"], list):
content += "\nContext: \n" + ("\n").join([t for t in kwargs["carryover"]])
else:
raise InvalidCarryOverType(
"Carryover should be a string or a list of strings. Not adding carryover to the message."
)
return content
def _process_multimodal_carryover(self, content: List[Dict], kwargs: dict) -> List[Dict]:
"""Prepends the context to a multimodal message."""
# Makes sure there's a carryover
if not kwargs.get("carryover"):
return content
return [{"type": "text", "text": self._process_carryover("", kwargs)}] + content
async def a_generate_init_message(self, message: Union[Dict, str, None], **kwargs) -> Union[str, Dict]: async def a_generate_init_message(self, message: Union[Dict, str, None], **kwargs) -> Union[str, Dict]:
"""Generate the initial message for the agent. """Generate the initial message for the agent.
If message is None, input() will be called to get the initial message. If message is None, input() will be called to get the initial message.
@ -2295,12 +2319,8 @@ class ConversableAgent(LLMAgent):
""" """
if message is None: if message is None:
message = await self.a_get_human_input(">") message = await self.a_get_human_input(">")
if isinstance(message, str):
return self._process_carryover(message, kwargs) return self._handle_carryover(message, kwargs)
elif isinstance(message, dict):
message = message.copy()
message["content"] = self._process_carryover(message["content"], kwargs)
return message
def register_function(self, function_map: Dict[str, Union[Callable, None]]): def register_function(self, function_map: Dict[str, Union[Callable, None]]):
"""Register functions to the agent. """Register functions to the agent.

View File

@ -1263,6 +1263,54 @@ def test_messages_with_carryover():
with pytest.raises(InvalidCarryOverType): with pytest.raises(InvalidCarryOverType):
agent1.generate_init_message(**context) agent1.generate_init_message(**context)
# Test multimodal messages
mm_content = [
{"type": "text", "text": "hello"},
{"type": "text", "text": "goodbye"},
{
"type": "image_url",
"image_url": {"url": "https://example.com/image.png"},
},
]
mm_message = {"content": mm_content}
context = dict(
message=mm_message,
carryover="Testing carryover.",
)
generated_message = agent1.generate_init_message(**context)
assert isinstance(generated_message, dict)
assert len(generated_message["content"]) == 4
context = dict(message=mm_message, carryover=["Testing carryover.", "This should pass"])
generated_message = agent1.generate_init_message(**context)
assert isinstance(generated_message, dict)
assert len(generated_message["content"]) == 4
context = dict(message=mm_message, carryover=3)
with pytest.raises(InvalidCarryOverType):
agent1.generate_init_message(**context)
# Test without carryover
print(mm_message)
context = dict(message=mm_message)
generated_message = agent1.generate_init_message(**context)
assert isinstance(generated_message, dict)
assert len(generated_message["content"]) == 3
# Test without text in multimodal message
mm_content = [
{"type": "image_url", "image_url": {"url": "https://example.com/image.png"}},
]
mm_message = {"content": mm_content}
context = dict(message=mm_message)
generated_message = agent1.generate_init_message(**context)
assert isinstance(generated_message, dict)
assert len(generated_message["content"]) == 1
generated_message = agent1.generate_init_message(**context, carryover="Testing carryover.")
assert isinstance(generated_message, dict)
assert len(generated_message["content"]) == 2
if __name__ == "__main__": if __name__ == "__main__":
# test_trigger() # test_trigger()