mirror of
https://github.com/microsoft/autogen.git
synced 2025-12-26 14:38:50 +00:00
Replace the use of assert in non-test code (#80)
* Replace `assert`s in the `conversable_agent` module with `if-log-raise`. * Use a `logger` object in the `code_utils` module. * Replace use of `assert` with `if-log-raise` in the `code_utils` module. * Replace use of `assert` in the `math_utils` module with `if-not-raise`. * Replace `assert` with `if` in the `oai.completion` module. * Replace `assert` in the `retrieve_utils` module with an if statement. * Add missing `not`. * Blacken `completion.py`. * Test `generate_reply` and `a_generate_reply` raise an assertion error when there are neither `messages` nor a `sender`. * Test `execute_code` raises an `AssertionError` when neither code nor filename is provided. * Test `split_text_to_chunks` raises when passed an invalid chunk mode. * * Add `tiktoken` and `chromadb` to test dependencies as they're used in the `test_retrieve_utils` module. * Sort the test requirements alphabetically.
This commit is contained in:
parent
39c145dd53
commit
a3547f82c4
@ -2,6 +2,7 @@ import asyncio
|
||||
from collections import defaultdict
|
||||
import copy
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
|
||||
from autogen import oai
|
||||
from .agent import Agent
|
||||
@ -21,6 +22,9 @@ except ImportError:
|
||||
return x
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ConversableAgent(Agent):
|
||||
"""(In preview) A class for generic conversable agents which can be configured as assistant or user proxy.
|
||||
|
||||
@ -757,7 +761,11 @@ class ConversableAgent(Agent):
|
||||
Returns:
|
||||
str or dict or None: reply. None if no reply is generated.
|
||||
"""
|
||||
assert messages is not None or sender is not None, "Either messages or sender must be provided."
|
||||
if all((messages is None, sender is None)):
|
||||
error_msg = f"Either {messages=} or {sender=} must be provided."
|
||||
logger.error(error_msg)
|
||||
raise AssertionError(error_msg)
|
||||
|
||||
if messages is None:
|
||||
messages = self._oai_messages[sender]
|
||||
|
||||
@ -804,7 +812,11 @@ class ConversableAgent(Agent):
|
||||
Returns:
|
||||
str or dict or None: reply. None if no reply is generated.
|
||||
"""
|
||||
assert messages is not None or sender is not None, "Either messages or sender must be provided."
|
||||
if all((messages is None, sender is None)):
|
||||
error_msg = f"Either {messages=} or {sender=} must be provided."
|
||||
logger.error(error_msg)
|
||||
raise AssertionError(error_msg)
|
||||
|
||||
if messages is None:
|
||||
messages = self._oai_messages[sender]
|
||||
|
||||
|
||||
@ -26,6 +26,8 @@ DEFAULT_TIMEOUT = 600
|
||||
WIN32 = sys.platform == "win32"
|
||||
PATH_SEPARATOR = WIN32 and "\\" or "/"
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def infer_lang(code):
|
||||
"""infer the language for the code.
|
||||
@ -250,7 +252,11 @@ def execute_code(
|
||||
str: The error message if the code fails to execute; the stdout otherwise.
|
||||
image: The docker image name after container run when docker is used.
|
||||
"""
|
||||
assert code is not None or filename is not None, "Either code or filename must be provided."
|
||||
if all((code is None, filename is None)):
|
||||
error_msg = f"Either {code=} or {filename=} must be provided."
|
||||
logger.error(error_msg)
|
||||
raise AssertionError(error_msg)
|
||||
|
||||
timeout = timeout or DEFAULT_TIMEOUT
|
||||
original_filename = filename
|
||||
if WIN32 and lang in ["sh", "shell"]:
|
||||
@ -276,7 +282,7 @@ def execute_code(
|
||||
f".\\{filename}" if WIN32 else filename,
|
||||
]
|
||||
if WIN32:
|
||||
logging.warning("SIGALRM is not supported on Windows. No timeout will be enforced.")
|
||||
logger.warning("SIGALRM is not supported on Windows. No timeout will be enforced.")
|
||||
result = subprocess.run(
|
||||
cmd,
|
||||
cwd=work_dir,
|
||||
|
||||
@ -35,8 +35,9 @@ def remove_boxed(string: str) -> Optional[str]:
|
||||
"""
|
||||
left = "\\boxed{"
|
||||
try:
|
||||
assert string[: len(left)] == left
|
||||
assert string[-1] == "}"
|
||||
if not all((string[: len(left)] == left, string[-1] == "}")):
|
||||
raise AssertionError
|
||||
|
||||
return string[len(left) : -1]
|
||||
except Exception:
|
||||
return None
|
||||
@ -94,7 +95,8 @@ def _fix_fracs(string: str) -> str:
|
||||
new_str += substr
|
||||
else:
|
||||
try:
|
||||
assert len(substr) >= 2
|
||||
if not len(substr) >= 2:
|
||||
raise AssertionError
|
||||
except Exception:
|
||||
return string
|
||||
a = substr[0]
|
||||
@ -129,7 +131,8 @@ def _fix_a_slash_b(string: str) -> str:
|
||||
try:
|
||||
a = int(a_str)
|
||||
b = int(b_str)
|
||||
assert string == "{}/{}".format(a, b)
|
||||
if not string == "{}/{}".format(a, b):
|
||||
raise AssertionError
|
||||
new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
|
||||
return new_string
|
||||
except Exception:
|
||||
@ -143,7 +146,8 @@ def _remove_right_units(string: str) -> str:
|
||||
"""
|
||||
if "\\text{ " in string:
|
||||
splits = string.split("\\text{ ")
|
||||
assert len(splits) == 2
|
||||
if not len(splits) == 2:
|
||||
raise AssertionError
|
||||
return splits[0]
|
||||
else:
|
||||
return string
|
||||
|
||||
@ -582,23 +582,31 @@ class Completion(openai_Completion):
|
||||
cls._prompts = space.get("prompt")
|
||||
if cls._prompts is None:
|
||||
cls._messages = space.get("messages")
|
||||
assert isinstance(cls._messages, list) and isinstance(
|
||||
cls._messages[0], (dict, list)
|
||||
), "messages must be a list of dicts or a list of lists."
|
||||
if not all((isinstance(cls._messages, list), isinstance(cls._messages[0], (dict, list)))):
|
||||
error_msg = "messages must be a list of dicts or a list of lists."
|
||||
logger.error(error_msg)
|
||||
raise AssertionError(error_msg)
|
||||
if isinstance(cls._messages[0], dict):
|
||||
cls._messages = [cls._messages]
|
||||
space["messages"] = tune.choice(list(range(len(cls._messages))))
|
||||
else:
|
||||
assert space.get("messages") is None, "messages and prompt cannot be provided at the same time."
|
||||
assert isinstance(cls._prompts, (str, list)), "prompt must be a string or a list of strings."
|
||||
if space.get("messages") is not None:
|
||||
error_msg = "messages and prompt cannot be provided at the same time."
|
||||
logger.error(error_msg)
|
||||
raise AssertionError(error_msg)
|
||||
if not isinstance(cls._prompts, (str, list)):
|
||||
error_msg = "prompt must be a string or a list of strings."
|
||||
logger.error(error_msg)
|
||||
raise AssertionError(error_msg)
|
||||
if isinstance(cls._prompts, str):
|
||||
cls._prompts = [cls._prompts]
|
||||
space["prompt"] = tune.choice(list(range(len(cls._prompts))))
|
||||
cls._stops = space.get("stop")
|
||||
if cls._stops:
|
||||
assert isinstance(
|
||||
cls._stops, (str, list)
|
||||
), "stop must be a string, a list of strings, or a list of lists of strings."
|
||||
if not isinstance(cls._stops, (str, list)):
|
||||
error_msg = "stop must be a string, a list of strings, or a list of lists of strings."
|
||||
logger.error(error_msg)
|
||||
raise AssertionError(error_msg)
|
||||
if not (isinstance(cls._stops, list) and isinstance(cls._stops[0], list)):
|
||||
cls._stops = [cls._stops]
|
||||
space["stop"] = tune.choice(list(range(len(cls._stops))))
|
||||
@ -969,7 +977,10 @@ class Completion(openai_Completion):
|
||||
elif isinstance(agg_method, dict):
|
||||
for key in metric_keys:
|
||||
metric_agg_method = agg_method[key]
|
||||
assert callable(metric_agg_method), "please provide a callable for each metric"
|
||||
if not callable(metric_agg_method):
|
||||
error_msg = "please provide a callable for each metric"
|
||||
logger.error(error_msg)
|
||||
raise AssertionError(error_msg)
|
||||
result_agg[key] = metric_agg_method([r[key] for r in result_list])
|
||||
else:
|
||||
raise ValueError(
|
||||
|
||||
@ -29,6 +29,7 @@ TEXT_FORMATS = [
|
||||
"yml",
|
||||
"pdf",
|
||||
]
|
||||
VALID_CHUNK_MODES = frozenset({"one_line", "multi_lines"})
|
||||
|
||||
|
||||
def num_tokens_from_text(
|
||||
@ -96,7 +97,8 @@ def split_text_to_chunks(
|
||||
overlap: int = 10,
|
||||
):
|
||||
"""Split a long text into chunks of max_tokens."""
|
||||
assert chunk_mode in {"one_line", "multi_lines"}
|
||||
if chunk_mode not in VALID_CHUNK_MODES:
|
||||
raise AssertionError
|
||||
if chunk_mode == "one_line":
|
||||
must_break_at_empty_line = False
|
||||
chunks = []
|
||||
|
||||
9
setup.py
9
setup.py
@ -38,15 +38,18 @@ setuptools.setup(
|
||||
install_requires=install_requires,
|
||||
extras_require={
|
||||
"test": [
|
||||
"pytest>=6.1.1",
|
||||
"chromadb",
|
||||
"coverage>=5.3",
|
||||
"pre-commit",
|
||||
"datasets",
|
||||
"ipykernel",
|
||||
"nbconvert",
|
||||
"nbformat",
|
||||
"ipykernel",
|
||||
"pre-commit",
|
||||
"pydantic==1.10.9",
|
||||
"pytest-asyncio",
|
||||
"pytest>=6.1.1",
|
||||
"sympy",
|
||||
"tiktoken",
|
||||
"wolframalpha",
|
||||
],
|
||||
"blendsearch": ["flaml[blendsearch]"],
|
||||
|
||||
@ -2,6 +2,17 @@ import pytest
|
||||
from autogen.agentchat import ConversableAgent
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def conversable_agent():
|
||||
return ConversableAgent(
|
||||
"conversable_agent_0",
|
||||
max_consecutive_auto_reply=10,
|
||||
code_execution_config=False,
|
||||
llm_config=False,
|
||||
human_input_mode="NEVER",
|
||||
)
|
||||
|
||||
|
||||
def test_trigger():
|
||||
agent = ConversableAgent("a0", max_consecutive_auto_reply=0, llm_config=False, human_input_mode="NEVER")
|
||||
agent1 = ConversableAgent("a1", max_consecutive_auto_reply=0, human_input_mode="NEVER")
|
||||
@ -217,6 +228,17 @@ def test_generate_reply():
|
||||
), "generate_reply not working when messages is None"
|
||||
|
||||
|
||||
def test_generate_reply_raises_on_messages_and_sender_none(conversable_agent):
|
||||
with pytest.raises(AssertionError):
|
||||
conversable_agent.generate_reply(messages=None, sender=None)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_a_generate_reply_raises_on_messages_and_sender_none(conversable_agent):
|
||||
with pytest.raises(AssertionError):
|
||||
await conversable_agent.a_generate_reply(messages=None, sender=None)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_trigger()
|
||||
# test_context()
|
||||
|
||||
@ -264,6 +264,11 @@ def test_execute_code(use_docker=None):
|
||||
assert isinstance(image, str) or docker is None or os.path.exists("/.dockerenv") or use_docker is False
|
||||
|
||||
|
||||
def test_execute_code_raises_when_code_and_filename_are_both_none():
|
||||
with pytest.raises(AssertionError):
|
||||
execute_code(code=None, filename=None)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.platform in ["darwin"],
|
||||
reason="do not run on MacOS",
|
||||
|
||||
@ -48,6 +48,10 @@ class TestRetrieveUtils:
|
||||
chunks = split_text_to_chunks(long_text, max_tokens=1000)
|
||||
assert all(num_tokens_from_text(chunk) <= 1000 for chunk in chunks)
|
||||
|
||||
def test_split_text_to_chunks_raises_on_invalid_chunk_mode(self):
|
||||
with pytest.raises(AssertionError):
|
||||
split_text_to_chunks("A" * 10000, chunk_mode="bogus_chunk_mode")
|
||||
|
||||
def test_extract_text_from_pdf(self):
|
||||
pdf_file_path = os.path.join(test_dir, "example.pdf")
|
||||
assert "".join(expected_text.split()) == "".join(extract_text_from_pdf(pdf_file_path).strip().split())
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user