mirror of
https://github.com/microsoft/autogen.git
synced 2025-08-13 11:11:25 +00:00
Fix generate_init_message
for Multimodal Messages (#2124)
* multimodal carryover
* adds mm carryover tests
* more tests + cleanup code
* check content instead
* beibin suggestion
* cleanup
* fix async
* use deepcopy
* handle carryover method
* remove content copy
* sonichi suggestions

---------

Co-authored-by: Beibin Li <BeibinLi@users.noreply.github.com>
Co-authored-by: Chi Wang <wang.chi@microsoft.com>
This commit is contained in:
parent
6ed8f696ef
commit
7a685b52d8
@ -2259,30 +2259,54 @@ class ConversableAgent(LLMAgent):
|
|||||||
"""
|
"""
|
||||||
if message is None:
|
if message is None:
|
||||||
message = self.get_human_input(">")
|
message = self.get_human_input(">")
|
||||||
if isinstance(message, str):
|
|
||||||
return self._process_carryover(message, kwargs)
|
return self._handle_carryover(message, kwargs)
|
||||||
elif isinstance(message, dict):
|
|
||||||
message = message.copy()
|
def _handle_carryover(self, message: Union[str, Dict], kwargs: dict) -> Union[str, Dict]:
|
||||||
# TODO: Do we need to do the following?
|
if not kwargs.get("carryover"):
|
||||||
# if message.get("content") is None:
|
|
||||||
# message["content"] = self.get_human_input(">")
|
|
||||||
message["content"] = self._process_carryover(message.get("content", ""), kwargs)
|
|
||||||
return message
|
return message
|
||||||
|
|
||||||
def _process_carryover(self, message: str, kwargs: dict) -> str:
|
if isinstance(message, str):
|
||||||
carryover = kwargs.get("carryover")
|
return self._process_carryover(message, kwargs)
|
||||||
if carryover:
|
|
||||||
# if carryover is string
|
elif isinstance(message, dict):
|
||||||
if isinstance(carryover, str):
|
if isinstance(message.get("content"), str):
|
||||||
message += "\nContext: \n" + carryover
|
# Makes sure the original message is not mutated
|
||||||
elif isinstance(carryover, list):
|
message = message.copy()
|
||||||
message += "\nContext: \n" + ("\n").join([t for t in carryover])
|
message["content"] = self._process_carryover(message["content"], kwargs)
|
||||||
else:
|
elif isinstance(message.get("content"), list):
|
||||||
raise InvalidCarryOverType(
|
# Makes sure the original message is not mutated
|
||||||
"Carryover should be a string or a list of strings. Not adding carryover to the message."
|
message = message.copy()
|
||||||
)
|
message["content"] = self._process_multimodal_carryover(message["content"], kwargs)
|
||||||
|
else:
|
||||||
|
raise InvalidCarryOverType("Carryover should be a string or a list of strings.")
|
||||||
|
|
||||||
return message
|
return message
|
||||||
|
|
||||||
|
def _process_carryover(self, content: str, kwargs: dict) -> str:
    """Append the carryover context, if any, to a plain-text message.

    Args:
        content: The text of the message.
        kwargs: Keyword arguments of the init-message call. Only the
            "carryover" entry is read.

    Returns:
        `content` unchanged when "carryover" is missing or falsy. Otherwise
        `content` followed by a newline, the header "Context: ", another
        newline, and the carryover text. A list carryover is joined with
        newlines.

    Raises:
        InvalidCarryOverType: If a truthy carryover is neither a string nor
            a list.
        TypeError: If a list carryover contains a non-string item (raised by
            `str.join`).
    """
    carryover = kwargs.get("carryover")
    # Missing, None, "" and [] all mean "no carryover": leave the text as is.
    if not carryover:
        return content

    if isinstance(carryover, str):
        context = carryover
    elif isinstance(carryover, list):
        # str.join accepts the list directly; no intermediate copy is needed.
        context = "\n".join(carryover)
    else:
        raise InvalidCarryOverType(
            "Carryover should be a string or a list of strings. Not adding carryover to the message."
        )
    return content + "\nContext: \n" + context
|
||||||
|
|
||||||
|
def _process_multimodal_carryover(self, content: List[Dict], kwargs: dict) -> List[Dict]:
    """Put the carryover context in front of a multimodal message.

    Args:
        content: The list of content parts of the message.
        kwargs: Keyword arguments of the init-message call. Only the
            "carryover" entry is read.

    Returns:
        `content` itself when "carryover" is missing or falsy. Otherwise a
        new list whose first element is a ``{"type": "text", ...}`` part
        holding the text produced by ``_process_carryover("", kwargs)``,
        followed by the original parts.
    """
    if kwargs.get("carryover"):
        # Reuse the plain-text formatter with an empty message so the
        # context text matches the one used for string messages.
        context_text = self._process_carryover("", kwargs)
        context_part = {"type": "text", "text": context_text}
        return [context_part] + content
    return content
|
||||||
|
|
||||||
async def a_generate_init_message(self, message: Union[Dict, str, None], **kwargs) -> Union[str, Dict]:
|
async def a_generate_init_message(self, message: Union[Dict, str, None], **kwargs) -> Union[str, Dict]:
|
||||||
"""Generate the initial message for the agent.
|
"""Generate the initial message for the agent.
|
||||||
If message is None, input() will be called to get the initial message.
|
If message is None, input() will be called to get the initial message.
|
||||||
@ -2295,12 +2319,8 @@ class ConversableAgent(LLMAgent):
|
|||||||
"""
|
"""
|
||||||
if message is None:
|
if message is None:
|
||||||
message = await self.a_get_human_input(">")
|
message = await self.a_get_human_input(">")
|
||||||
if isinstance(message, str):
|
|
||||||
return self._process_carryover(message, kwargs)
|
return self._handle_carryover(message, kwargs)
|
||||||
elif isinstance(message, dict):
|
|
||||||
message = message.copy()
|
|
||||||
message["content"] = self._process_carryover(message["content"], kwargs)
|
|
||||||
return message
|
|
||||||
|
|
||||||
def register_function(self, function_map: Dict[str, Union[Callable, None]]):
|
def register_function(self, function_map: Dict[str, Union[Callable, None]]):
|
||||||
"""Register functions to the agent.
|
"""Register functions to the agent.
|
||||||
|
@ -1263,6 +1263,54 @@ def test_messages_with_carryover():
|
|||||||
with pytest.raises(InvalidCarryOverType):
|
with pytest.raises(InvalidCarryOverType):
|
||||||
agent1.generate_init_message(**context)
|
agent1.generate_init_message(**context)
|
||||||
|
|
||||||
|
# Test multimodal messages
|
||||||
|
mm_content = [
|
||||||
|
{"type": "text", "text": "hello"},
|
||||||
|
{"type": "text", "text": "goodbye"},
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {"url": "https://example.com/image.png"},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
mm_message = {"content": mm_content}
|
||||||
|
context = dict(
|
||||||
|
message=mm_message,
|
||||||
|
carryover="Testing carryover.",
|
||||||
|
)
|
||||||
|
generated_message = agent1.generate_init_message(**context)
|
||||||
|
assert isinstance(generated_message, dict)
|
||||||
|
assert len(generated_message["content"]) == 4
|
||||||
|
|
||||||
|
context = dict(message=mm_message, carryover=["Testing carryover.", "This should pass"])
|
||||||
|
generated_message = agent1.generate_init_message(**context)
|
||||||
|
assert isinstance(generated_message, dict)
|
||||||
|
assert len(generated_message["content"]) == 4
|
||||||
|
|
||||||
|
context = dict(message=mm_message, carryover=3)
|
||||||
|
with pytest.raises(InvalidCarryOverType):
|
||||||
|
agent1.generate_init_message(**context)
|
||||||
|
|
||||||
|
# Test without carryover
|
||||||
|
print(mm_message)
|
||||||
|
context = dict(message=mm_message)
|
||||||
|
generated_message = agent1.generate_init_message(**context)
|
||||||
|
assert isinstance(generated_message, dict)
|
||||||
|
assert len(generated_message["content"]) == 3
|
||||||
|
|
||||||
|
# Test without text in multimodal message
|
||||||
|
mm_content = [
|
||||||
|
{"type": "image_url", "image_url": {"url": "https://example.com/image.png"}},
|
||||||
|
]
|
||||||
|
mm_message = {"content": mm_content}
|
||||||
|
context = dict(message=mm_message)
|
||||||
|
generated_message = agent1.generate_init_message(**context)
|
||||||
|
assert isinstance(generated_message, dict)
|
||||||
|
assert len(generated_message["content"]) == 1
|
||||||
|
|
||||||
|
generated_message = agent1.generate_init_message(**context, carryover="Testing carryover.")
|
||||||
|
assert isinstance(generated_message, dict)
|
||||||
|
assert len(generated_message["content"]) == 2
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# test_trigger()
|
# test_trigger()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user