mirror of
https://github.com/microsoft/autogen.git
synced 2025-08-11 02:01:10 +00:00
Fix generate_init_message
for Multimodal Messages (#2124)
* multimodal carryover * adds mm carryover tests * more tests + cleanup code * check content instead * beibin suggestion * cleanup * fix async * use deepcopy * handle carryover method * remove content copy * sonichi suggestions --------- Co-authored-by: Beibin Li <BeibinLi@users.noreply.github.com> Co-authored-by: Chi Wang <wang.chi@microsoft.com>
This commit is contained in:
parent
6ed8f696ef
commit
7a685b52d8
@ -2259,30 +2259,54 @@ class ConversableAgent(LLMAgent):
|
||||
"""
|
||||
if message is None:
|
||||
message = self.get_human_input(">")
|
||||
if isinstance(message, str):
|
||||
return self._process_carryover(message, kwargs)
|
||||
elif isinstance(message, dict):
|
||||
message = message.copy()
|
||||
# TODO: Do we need to do the following?
|
||||
# if message.get("content") is None:
|
||||
# message["content"] = self.get_human_input(">")
|
||||
message["content"] = self._process_carryover(message.get("content", ""), kwargs)
|
||||
|
||||
return self._handle_carryover(message, kwargs)
|
||||
|
||||
def _handle_carryover(self, message: Union[str, Dict], kwargs: dict) -> Union[str, Dict]:
|
||||
if not kwargs.get("carryover"):
|
||||
return message
|
||||
|
||||
def _process_carryover(self, message: str, kwargs: dict) -> str:
|
||||
carryover = kwargs.get("carryover")
|
||||
if carryover:
|
||||
# if carryover is string
|
||||
if isinstance(carryover, str):
|
||||
message += "\nContext: \n" + carryover
|
||||
elif isinstance(carryover, list):
|
||||
message += "\nContext: \n" + ("\n").join([t for t in carryover])
|
||||
else:
|
||||
raise InvalidCarryOverType(
|
||||
"Carryover should be a string or a list of strings. Not adding carryover to the message."
|
||||
)
|
||||
if isinstance(message, str):
|
||||
return self._process_carryover(message, kwargs)
|
||||
|
||||
elif isinstance(message, dict):
|
||||
if isinstance(message.get("content"), str):
|
||||
# Makes sure the original message is not mutated
|
||||
message = message.copy()
|
||||
message["content"] = self._process_carryover(message["content"], kwargs)
|
||||
elif isinstance(message.get("content"), list):
|
||||
# Makes sure the original message is not mutated
|
||||
message = message.copy()
|
||||
message["content"] = self._process_multimodal_carryover(message["content"], kwargs)
|
||||
else:
|
||||
raise InvalidCarryOverType("Carryover should be a string or a list of strings.")
|
||||
|
||||
return message
|
||||
|
||||
def _process_carryover(self, content: str, kwargs: dict) -> str:
|
||||
# Makes sure there's a carryover
|
||||
if not kwargs.get("carryover"):
|
||||
return content
|
||||
|
||||
# if carryover is string
|
||||
if isinstance(kwargs["carryover"], str):
|
||||
content += "\nContext: \n" + kwargs["carryover"]
|
||||
elif isinstance(kwargs["carryover"], list):
|
||||
content += "\nContext: \n" + ("\n").join([t for t in kwargs["carryover"]])
|
||||
else:
|
||||
raise InvalidCarryOverType(
|
||||
"Carryover should be a string or a list of strings. Not adding carryover to the message."
|
||||
)
|
||||
return content
|
||||
|
||||
def _process_multimodal_carryover(self, content: List[Dict], kwargs: dict) -> List[Dict]:
|
||||
"""Prepends the context to a multimodal message."""
|
||||
# Makes sure there's a carryover
|
||||
if not kwargs.get("carryover"):
|
||||
return content
|
||||
|
||||
return [{"type": "text", "text": self._process_carryover("", kwargs)}] + content
|
||||
|
||||
async def a_generate_init_message(self, message: Union[Dict, str, None], **kwargs) -> Union[str, Dict]:
|
||||
"""Generate the initial message for the agent.
|
||||
If message is None, input() will be called to get the initial message.
|
||||
@ -2295,12 +2319,8 @@ class ConversableAgent(LLMAgent):
|
||||
"""
|
||||
if message is None:
|
||||
message = await self.a_get_human_input(">")
|
||||
if isinstance(message, str):
|
||||
return self._process_carryover(message, kwargs)
|
||||
elif isinstance(message, dict):
|
||||
message = message.copy()
|
||||
message["content"] = self._process_carryover(message["content"], kwargs)
|
||||
return message
|
||||
|
||||
return self._handle_carryover(message, kwargs)
|
||||
|
||||
def register_function(self, function_map: Dict[str, Union[Callable, None]]):
|
||||
"""Register functions to the agent.
|
||||
|
@ -1263,6 +1263,54 @@ def test_messages_with_carryover():
|
||||
with pytest.raises(InvalidCarryOverType):
|
||||
agent1.generate_init_message(**context)
|
||||
|
||||
# Test multimodal messages
|
||||
mm_content = [
|
||||
{"type": "text", "text": "hello"},
|
||||
{"type": "text", "text": "goodbye"},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": "https://example.com/image.png"},
|
||||
},
|
||||
]
|
||||
mm_message = {"content": mm_content}
|
||||
context = dict(
|
||||
message=mm_message,
|
||||
carryover="Testing carryover.",
|
||||
)
|
||||
generated_message = agent1.generate_init_message(**context)
|
||||
assert isinstance(generated_message, dict)
|
||||
assert len(generated_message["content"]) == 4
|
||||
|
||||
context = dict(message=mm_message, carryover=["Testing carryover.", "This should pass"])
|
||||
generated_message = agent1.generate_init_message(**context)
|
||||
assert isinstance(generated_message, dict)
|
||||
assert len(generated_message["content"]) == 4
|
||||
|
||||
context = dict(message=mm_message, carryover=3)
|
||||
with pytest.raises(InvalidCarryOverType):
|
||||
agent1.generate_init_message(**context)
|
||||
|
||||
# Test without carryover
|
||||
print(mm_message)
|
||||
context = dict(message=mm_message)
|
||||
generated_message = agent1.generate_init_message(**context)
|
||||
assert isinstance(generated_message, dict)
|
||||
assert len(generated_message["content"]) == 3
|
||||
|
||||
# Test without text in multimodal message
|
||||
mm_content = [
|
||||
{"type": "image_url", "image_url": {"url": "https://example.com/image.png"}},
|
||||
]
|
||||
mm_message = {"content": mm_content}
|
||||
context = dict(message=mm_message)
|
||||
generated_message = agent1.generate_init_message(**context)
|
||||
assert isinstance(generated_message, dict)
|
||||
assert len(generated_message["content"]) == 1
|
||||
|
||||
generated_message = agent1.generate_init_message(**context, carryover="Testing carryover.")
|
||||
assert isinstance(generated_message, dict)
|
||||
assert len(generated_message["content"]) == 2
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# test_trigger()
|
||||
|
Loading…
x
Reference in New Issue
Block a user