2024-02-28 17:11:08 -08:00
#!/usr/bin/env python3 -m pytest
2023-11-03 21:01:49 -07:00
import os
2023-08-13 20:51:54 +08:00
import sys
2024-04-05 10:26:06 +08:00
import pytest
2023-09-16 16:34:16 +00:00
import autogen
2023-11-03 21:01:49 -07:00
2024-01-10 04:42:07 +03:00
sys . path . append ( os . path . join ( os . path . dirname ( __file__ ) , " ../.. " ) )
from conftest import skip_openai # noqa: E402
2023-11-03 21:01:49 -07:00
sys . path . append ( os . path . join ( os . path . dirname ( __file__ ) , " .. " ) )
from test_assistant_agent import KEY_LOC , OAI_CONFIG_LIST # noqa: E402
2023-08-13 20:51:54 +08:00
try :
2024-04-05 10:26:06 +08:00
import chromadb
2023-11-03 21:01:49 -07:00
import openai
2024-04-05 10:26:06 +08:00
from chromadb . utils import embedding_functions as ef
2023-09-16 16:34:16 +00:00
from autogen . agentchat . contrib . retrieve_assistant_agent import (
2023-08-13 20:51:54 +08:00
RetrieveAssistantAgent ,
)
2023-09-16 16:34:16 +00:00
from autogen . agentchat . contrib . retrieve_user_proxy_agent import (
2023-08-13 20:51:54 +08:00
RetrieveUserProxyAgent ,
)
except ImportError :
2024-01-10 04:42:07 +03:00
skip = True
else :
skip = False or skip_openai
2023-08-13 20:51:54 +08:00
@pytest.mark.skipif(
    sys.platform in ["darwin", "win32"] or skip,
    reason="do not run on MacOS or windows OR dependency is not installed OR requested to skip",
)
def test_retrievechat():
    """Run an end-to-end retrieval-augmented chat between two agents.

    A RetrieveUserProxyAgent indexes ``./website/docs`` into a persistent
    ChromaDB client (using a SentenceTransformer embedding function) and then
    starts a chat with a RetrieveAssistantAgent about a FLAML/Spark coding
    question. The test has no explicit assertions: it passes if the whole
    conversation completes without raising.

    The assistant's ``llm_config`` is built from the OAI config list found at
    ``KEY_LOC``, so this test issues real model calls through that config.
    """
    config_list = autogen.config_list_from_json(
        OAI_CONFIG_LIST,
        file_location=KEY_LOC,
    )
    assistant = RetrieveAssistantAgent(
        name="assistant",
        system_message="You are a helpful assistant.",
        llm_config={
            "timeout": 600,
            "seed": 42,
            "config_list": config_list,
        },
    )

    sentence_transformer_ef = ef.SentenceTransformerEmbeddingFunction()

    ragproxyagent = RetrieveUserProxyAgent(
        name="ragproxyagent",
        human_input_mode="NEVER",
        max_consecutive_auto_reply=2,
        retrieve_config={
            # NOTE(review): relative to the current working directory;
            # presumably pytest is expected to run from the repository root.
            "docs_path": "./website/docs",
            "chunk_token_size": 2000,
            "model": config_list[0]["model"],
            "client": chromadb.PersistentClient(path="/tmp/chromadb"),
            "embedding_function": sentence_transformer_ef,
            # presumably reuses an existing collection across runs instead of
            # failing when it already exists — verify against the agent code.
            "get_or_create": True,
        },
    )

    assistant.reset()

    code_problem = "How can I use FLAML to perform a classification task, set use_spark=True, train 30 seconds and force cancel jobs if time limit is reached."
    # The proxy's message_generator builds the opening message from `problem`
    # (and the retrieved context); `silent=True` suppresses console output.
    ragproxyagent.initiate_chat(
        assistant, message=ragproxyagent.message_generator, problem=code_problem, search_string="spark", silent=True
    )
2023-12-01 00:34:45 +08:00
@pytest.mark.skipif(
    sys.platform in ["darwin", "win32"] or skip,
    reason="do not run on MacOS or windows OR dependency is not installed OR requested to skip",
)
def test_retrieve_config(caplog):
    """Check that omitting ``docs_path`` from ``retrieve_config`` logs a warning.

    Constructing a RetrieveUserProxyAgent without ``docs_path`` must emit a
    WARNING-level record telling the user that a ValueError will be raised if
    the agent's collection does not already exist.

    Args:
        caplog: pytest fixture that captures log records emitted during the test.
    """
    # test warning message when no docs_path is provided
    ragproxyagent = RetrieveUserProxyAgent(
        name="ragproxyagent",
        human_input_mode="NEVER",
        max_consecutive_auto_reply=2,
        retrieve_config={
            "chunk_token_size": 2000,
            "get_or_create": True,
        },
    )
    expected = (
        "docs_path is not provided in retrieve_config. "
        f"Will raise ValueError if the collection `{ragproxyagent._collection_name}` doesn't exist."
    )
    # Scan every captured record instead of assuming ours is records[0]:
    # other libraries may log first, and an empty capture would otherwise
    # surface as an unhelpful IndexError.
    matching = [record for record in caplog.records if expected in record.getMessage()]
    assert matching, f"expected warning was not logged; captured: {[r.getMessage() for r in caplog.records]}"
    assert matching[0].levelname == "WARNING"
2023-08-13 20:51:54 +08:00
if __name__ == "__main__":
    # test_retrieve_config requires pytest's `caplog` fixture, so calling it
    # directly raises TypeError. Delegate to pytest, select only that test,
    # and propagate pytest's exit code to the shell.
    sys.exit(pytest.main([__file__, "-k", "test_retrieve_config"]))