mirror of
https://github.com/microsoft/autogen.git
synced 2025-07-25 18:01:03 +00:00

* Add custom embedding function * Add support to custom vector db * Improve docstring * Improve docstring * Improve docstring * Add support to customized is_termination_msg function * Add a test for customized vector db with lancedb * Fix tests * Add test for embedding_function * Update docstring
98 lines
2.8 KiB
Python
98 lines
2.8 KiB
Python
import pytest
|
|
import sys
|
|
import autogen
|
|
from test_assistant_agent import KEY_LOC, OAI_CONFIG_LIST
|
|
|
|
# Optional retrieval dependencies. If any import fails (e.g. chromadb is not
# installed), set skip_test so the tests below are skipped instead of erroring.
try:
    from autogen.agentchat.contrib.retrieve_assistant_agent import (
        RetrieveAssistantAgent,
    )
    from autogen.agentchat.contrib.retrieve_user_proxy_agent import (
        RetrieveUserProxyAgent,
    )
    from autogen.retrieve_utils import create_vector_db_from_dir, query_vector_db
    import chromadb
    from chromadb.utils import embedding_functions as ef

    skip_test = False
except ImportError:
    skip_test = True
|
|
|
|
|
|
@pytest.mark.skipif(
    sys.platform in ["darwin", "win32"] or skip_test,
    reason="do not run on MacOS or windows, or dependencies are not installed",
)
def test_retrievechat():
    """Run a retrieval-augmented chat between a RetrieveAssistantAgent and a
    RetrieveUserProxyAgent, using a custom SentenceTransformer embedding
    function and a persistent chromadb client, then print the logged
    conversations.
    """
    try:
        import openai
    except ImportError:
        # openai is optional; without it this integration test cannot run.
        return

    conversations = {}
    autogen.ChatCompletion.start_logging(conversations)

    config_list = autogen.config_list_from_json(
        OAI_CONFIG_LIST,
        file_location=KEY_LOC,
        filter_dict={
            "model": ["gpt-4", "gpt4", "gpt-4-32k", "gpt-4-32k-0314"],
        },
    )

    assistant = RetrieveAssistantAgent(
        name="assistant",
        system_message="You are a helpful assistant.",
        llm_config={
            "request_timeout": 600,
            "seed": 42,  # fixed seed for reproducible completions
            "config_list": config_list,
        },
    )

    # Exercise the custom embedding_function code path (rather than the default).
    sentence_transformer_ef = ef.SentenceTransformerEmbeddingFunction()
    ragproxyagent = RetrieveUserProxyAgent(
        name="ragproxyagent",
        human_input_mode="NEVER",
        max_consecutive_auto_reply=2,
        retrieve_config={
            "docs_path": "./website/docs",
            "chunk_token_size": 2000,
            "model": config_list[0]["model"],
            "client": chromadb.PersistentClient(path="/tmp/chromadb"),
            "embedding_function": sentence_transformer_ef,
        },
    )

    assistant.reset()

    code_problem = "How can I use FLAML to perform a classification task, set use_spark=True, train 30 seconds and force cancel jobs if time limit is reached."
    ragproxyagent.initiate_chat(assistant, problem=code_problem, search_string="spark", silent=True)

    print(conversations)
|
|
|
|
|
|
@pytest.mark.skipif(
    sys.platform in ["darwin", "win32"] or skip_test,
    reason="do not run on MacOS or windows, or dependencies are not installed",
)
def test_retrieve_utils():
    """Build a vector db from the docs directory, then verify that a
    search-string-filtered query returns exactly ``n_results`` document ids.
    """
    client = chromadb.PersistentClient(path="/tmp/chromadb")
    create_vector_db_from_dir(dir_path="./website/docs", client=client, collection_name="autogen-docs")
    results = query_vector_db(
        query_texts=[
            "How can I use AutoGen UserProxyAgent and AssistantAgent to do code generation?",
        ],
        n_results=4,
        client=client,
        collection_name="autogen-docs",
        search_string="AutoGen",  # only return chunks containing this substring
    )
    print(results["ids"][0])
    assert len(results["ids"][0]) == 4
|
|
|
|
|
|
if __name__ == "__main__":
    # Allow running this module directly as a plain script, outside pytest.
    for _case in (test_retrievechat, test_retrieve_utils):
        _case()
|