diff --git a/autogen/agentchat/conversable_agent.py b/autogen/agentchat/conversable_agent.py
index 20398965c..2ca06e2fe 100644
--- a/autogen/agentchat/conversable_agent.py
+++ b/autogen/agentchat/conversable_agent.py
@@ -2259,30 +2259,70 @@ class ConversableAgent(LLMAgent):
         """
         if message is None:
             message = self.get_human_input(">")
-        if isinstance(message, str):
-            return self._process_carryover(message, kwargs)
-        elif isinstance(message, dict):
-            message = message.copy()
-            # TODO: Do we need to do the following?
-            # if message.get("content") is None:
-            #     message["content"] = self.get_human_input(">")
-            message["content"] = self._process_carryover(message.get("content", ""), kwargs)
+
+        return self._handle_carryover(message, kwargs)
+
+    def _handle_carryover(self, message: Union[str, Dict], kwargs: dict) -> Union[str, Dict]:
+        """Attach the carryover (if any) to the initial message.
+
+        String messages and string dict contents get the carryover appended; list
+        (multimodal) dict contents get it prepended as a text item. The caller's
+        dict is copied before being modified. A dict whose content is neither a
+        string nor a list is returned unchanged.
+
+        Raises:
+            InvalidCarryOverType: if the message is neither a string nor a dict, or
+                if the carryover itself is neither a string nor a list.
+        """
+        if not kwargs.get("carryover"):
             return message
 
-    def _process_carryover(self, message: str, kwargs: dict) -> str:
-        carryover = kwargs.get("carryover")
-        if carryover:
-            # if carryover is string
-            if isinstance(carryover, str):
-                message += "\nContext: \n" + carryover
-            elif isinstance(carryover, list):
-                message += "\nContext: \n" + ("\n").join([t for t in carryover])
-            else:
-                raise InvalidCarryOverType(
-                    "Carryover should be a string or a list of strings. Not adding carryover to the message."
-                )
+        if isinstance(message, str):
+            return self._process_carryover(message, kwargs)
+
+        elif isinstance(message, dict):
+            if isinstance(message.get("content"), str):
+                # Makes sure the original message is not mutated
+                message = message.copy()
+                message["content"] = self._process_carryover(message["content"], kwargs)
+            elif isinstance(message.get("content"), list):
+                # Makes sure the original message is not mutated
+                message = message.copy()
+                message["content"] = self._process_multimodal_carryover(message["content"], kwargs)
+        else:
+            raise InvalidCarryOverType("Message should be a string or a dict when a carryover is provided.")
+
         return message
 
+    def _process_carryover(self, content: str, kwargs: dict) -> str:
+        """Append the carryover to `content` under a "Context" heading.
+
+        Raises:
+            InvalidCarryOverType: if the carryover is neither a string nor a list.
+        """
+        # Makes sure there's a carryover
+        if not kwargs.get("carryover"):
+            return content
+
+        # if carryover is string
+        if isinstance(kwargs["carryover"], str):
+            content += "\nContext: \n" + kwargs["carryover"]
+        elif isinstance(kwargs["carryover"], list):
+            content += "\nContext: \n" + "\n".join(kwargs["carryover"])
+        else:
+            raise InvalidCarryOverType(
+                "Carryover should be a string or a list of strings. Not adding carryover to the message."
+            )
+        return content
+
+    def _process_multimodal_carryover(self, content: List[Dict], kwargs: dict) -> List[Dict]:
+        """Prepends the context to a multimodal message."""
+        # Makes sure there's a carryover
+        if not kwargs.get("carryover"):
+            return content
+
+        return [{"type": "text", "text": self._process_carryover("", kwargs)}] + content
+
     async def a_generate_init_message(self, message: Union[Dict, str, None], **kwargs) -> Union[str, Dict]:
         """Generate the initial message for the agent.
         If message is None, input() will be called to get the initial message.
@@ -2295,12 +2335,8 @@ class ConversableAgent(LLMAgent):
         """
         if message is None:
             message = await self.a_get_human_input(">")
-        if isinstance(message, str):
-            return self._process_carryover(message, kwargs)
-        elif isinstance(message, dict):
-            message = message.copy()
-            message["content"] = self._process_carryover(message["content"], kwargs)
-            return message
+
+        return self._handle_carryover(message, kwargs)
 
     def register_function(self, function_map: Dict[str, Union[Callable, None]]):
         """Register functions to the agent.
diff --git a/test/agentchat/test_conversable_agent.py b/test/agentchat/test_conversable_agent.py
index 6d9c1ac32..5ff2c0c10 100755
--- a/test/agentchat/test_conversable_agent.py
+++ b/test/agentchat/test_conversable_agent.py
@@ -1263,6 +1263,53 @@ def test_messages_with_carryover():
     with pytest.raises(InvalidCarryOverType):
         agent1.generate_init_message(**context)
 
+    # Test multimodal messages
+    mm_content = [
+        {"type": "text", "text": "hello"},
+        {"type": "text", "text": "goodbye"},
+        {
+            "type": "image_url",
+            "image_url": {"url": "https://example.com/image.png"},
+        },
+    ]
+    mm_message = {"content": mm_content}
+    context = dict(
+        message=mm_message,
+        carryover="Testing carryover.",
+    )
+    generated_message = agent1.generate_init_message(**context)
+    assert isinstance(generated_message, dict)
+    assert len(generated_message["content"]) == 4
+
+    context = dict(message=mm_message, carryover=["Testing carryover.", "This should pass"])
+    generated_message = agent1.generate_init_message(**context)
+    assert isinstance(generated_message, dict)
+    assert len(generated_message["content"]) == 4
+
+    context = dict(message=mm_message, carryover=3)
+    with pytest.raises(InvalidCarryOverType):
+        agent1.generate_init_message(**context)
+
+    # Test without carryover
+    context = dict(message=mm_message)
+    generated_message = agent1.generate_init_message(**context)
+    assert isinstance(generated_message, dict)
+    assert len(generated_message["content"]) == 3
+
+    # Test without text in multimodal message
+    mm_content = [
+        {"type": "image_url", "image_url": {"url": "https://example.com/image.png"}},
+    ]
+    mm_message = {"content": mm_content}
+    context = dict(message=mm_message)
+    generated_message = agent1.generate_init_message(**context)
+    assert isinstance(generated_message, dict)
+    assert len(generated_message["content"]) == 1
+
+    generated_message = agent1.generate_init_message(**context, carryover="Testing carryover.")
+    assert isinstance(generated_message, dict)
+    assert len(generated_message["content"]) == 2
+
 
 if __name__ == "__main__":
     # test_trigger()