2023-10-01 11:22:58 +01:00
"""
Unit test for retrieve_utils.py
"""
from autogen . retrieve_utils import (
split_text_to_chunks ,
extract_text_from_pdf ,
split_files_to_chunks ,
get_files_from_dir ,
is_url ,
create_vector_db_from_dir ,
query_vector_db ,
)
2023-10-27 08:57:35 -04:00
from autogen . token_count_utils import count_token
2023-10-01 11:22:58 +01:00
import os
import pytest
import chromadb
# Directory holding the fixture files (example.pdf, example.txt) used by the tests below.
test_dir = os.path.join(os.path.dirname(__file__), "test_files")

# Text the tests expect to extract from example.pdf; compared after whitespace normalisation.
expected_text = """AutoGen is an advanced tool designed to assist developers in harnessing the capabilities
of Large Language Models (LLMs) for various applications. The primary purpose of AutoGen is to automate and
simplify the process of building applications that leverage the power of LLMs, allowing for seamless
integration, testing, and deployment."""
class TestRetrieveUtils:
    """Tests for the text-splitting, file-discovery and vector-db helpers of ``autogen.retrieve_utils``.

    Several tests share the on-disk chromadb store at
    ``/tmp/test_retrieve_utils_chromadb.db``, so state created by one test is
    visible to the others.
    """

    def test_split_text_to_chunks(self):
        """Every chunk produced with ``max_tokens=1000`` holds at most 1000 tokens."""
        long_text = "A" * 10000
        chunks = split_text_to_chunks(long_text, max_tokens=1000)
        assert all(count_token(chunk) <= 1000 for chunk in chunks)

    def test_split_text_to_chunks_raises_on_invalid_chunk_mode(self):
        """An unknown ``chunk_mode`` is rejected with ``AssertionError``."""
        with pytest.raises(AssertionError):
            split_text_to_chunks("A" * 10000, chunk_mode="bogus_chunk_mode")

    def test_extract_text_from_pdf(self):
        """Text extracted from example.pdf equals ``expected_text`` once all whitespace is removed."""
        pdf_file_path = os.path.join(test_dir, "example.pdf")
        assert "".join(expected_text.split()) == "".join(extract_text_from_pdf(pdf_file_path).strip().split())

    def test_split_files_to_chunks(self):
        """Chunks built from the PDF and text fixtures are all non-blank strings."""
        pdf_file_path = os.path.join(test_dir, "example.pdf")
        txt_file_path = os.path.join(test_dir, "example.txt")
        chunks = split_files_to_chunks([pdf_file_path, txt_file_path])
        assert all(isinstance(chunk, str) and chunk.strip() for chunk in chunks)

    def test_get_files_from_dir(self):
        """``get_files_from_dir`` returns existing files for a directory and for an explicit path list."""
        files = get_files_from_dir(test_dir)
        assert all(os.path.isfile(file) for file in files)

        pdf_file_path = os.path.join(test_dir, "example.pdf")
        txt_file_path = os.path.join(test_dir, "example.txt")
        files = get_files_from_dir([pdf_file_path, txt_file_path])
        assert all(os.path.isfile(file) for file in files)

    def test_is_url(self):
        """``is_url`` accepts an https URL and rejects a bare word."""
        assert is_url("https://www.example.com")
        assert not is_url("not_a_url")

    def test_create_vector_db_from_dir(self):
        """The default collection exists after building (or reopening) the shared store."""
        db_path = "/tmp/test_retrieve_utils_chromadb.db"
        # Record existence BEFORE opening the client, since opening may itself create the path.
        db_already_present = os.path.exists(db_path)
        client = chromadb.PersistentClient(path=db_path)
        if not db_already_present:
            # Populate only a brand-new store; an existing one is reused as-is.
            create_vector_db_from_dir(test_dir, client=client)

        assert client.get_collection("all-my-documents")

    def test_query_vector_db(self):
        """Querying the shared store for "autogen" returns at least one document mentioning it."""
        db_path = "/tmp/test_retrieve_utils_chromadb.db"
        # Record existence BEFORE opening the client, since opening may itself create the path.
        db_already_present = os.path.exists(db_path)
        client = chromadb.PersistentClient(path=db_path)
        if not db_already_present:
            # If the database does not exist, create it first
            create_vector_db_from_dir(test_dir, client=client)

        results = query_vector_db(["autogen"], client=client)
        assert isinstance(results, dict) and any("autogen" in res[0].lower() for res in results.get("documents", []))

    def test_custom_vector_db(self):
        """A ``RetrieveUserProxyAgent`` subclass can swap in a custom (lancedb) vector store.

        The test is reported as skipped when ``lancedb`` is not installed.
        """
        # Skip visibly instead of returning early, which pytest would count as a pass.
        lancedb = pytest.importorskip("lancedb")
        from autogen.agentchat.contrib.retrieve_user_proxy_agent import RetrieveUserProxyAgent

        db_path = "/tmp/lancedb"

        def create_lancedb():
            """Create ``my_table`` with six 2-d vectors; three of the documents contain "spark"."""
            db = lancedb.connect(db_path)
            data = [
                {"vector": [1.1, 1.2], "id": 1, "documents": "This is a test document spark"},
                {"vector": [0.2, 1.8], "id": 2, "documents": "This is another test document"},
                {"vector": [0.1, 0.3], "id": 3, "documents": "This is a third test document spark"},
                {"vector": [0.5, 0.7], "id": 4, "documents": "This is a fourth test document"},
                {"vector": [2.1, 1.3], "id": 5, "documents": "This is a fifth test document spark"},
                {"vector": [5.1, 8.3], "id": 6, "documents": "This is a sixth test document"},
            ]
            try:
                db.create_table("my_table", data)
            except OSError:
                # Best effort: presumably raised when the table is left over from an earlier run.
                pass

        class MyRetrieveUserProxyAgent(RetrieveUserProxyAgent):
            """Agent whose retrieval is served from the lancedb table instead of the default store."""

            def query_vector_db(
                self,
                query_texts,
                n_results=10,
                search_string="",
            ):
                """Return ids and documents of rows containing ``search_string``.

                The result is shaped as ``{"ids": [[...]], "documents": [[...]]}``.
                """
                # No embedding is computed: a fixed probe vector stands in for the query, so the
                # content of ``query_texts`` does not influence the search.
                vector = [0.1, 0.3]
                db = lancedb.connect(db_path)
                table = db.open_table("my_table")
                query = table.search(vector).where(f"documents LIKE '%{search_string}%'").limit(n_results).to_df()
                return {"ids": [query["id"].tolist()], "documents": [query["documents"].tolist()]}

            def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = ""):
                """Run the custom query for ``problem`` and keep the raw result on ``self._results``."""
                results = self.query_vector_db(
                    query_texts=[problem],
                    n_results=n_results,
                    search_string=search_string,
                )

                self._results = results
                print("doc_ids: ", results["ids"])

        ragragproxyagent = MyRetrieveUserProxyAgent(
            name="ragproxyagent",
            human_input_mode="NEVER",
            max_consecutive_auto_reply=2,
            retrieve_config={
                "task": "qa",
                "chunk_token_size": 2000,
                # NOTE(review): "__" looks like a placeholder client value - confirm against
                # RetrieveUserProxyAgent.
                "client": "__",
                "embedding_model": "all-mpnet-base-v2",
            },
        )

        create_lancedb()
        ragragproxyagent.retrieve_docs("This is a test document spark", n_results=10, search_string="spark")
        # Only the three "spark" rows pass the filter; they are expected in this order.
        assert ragragproxyagent._results["ids"] == [[3, 1, 5]]

    def test_custom_text_split_function(self):
        """A user-supplied splitter is honoured: the top hit is the first half of example.txt."""

        def custom_text_split_function(text):
            """Split ``text`` into two halves at its midpoint."""
            midpoint = len(text) // 2
            return [text[:midpoint], text[midpoint:]]

        db_path = "/tmp/test_retrieve_utils_chromadb.db"
        client = chromadb.PersistentClient(path=db_path)
        create_vector_db_from_dir(
            os.path.join(test_dir, "example.txt"),
            client=client,
            collection_name="mytestcollection",
            custom_text_split_function=custom_text_split_function,
            get_or_create=True,
        )
        results = query_vector_db(["autogen"], client=client, collection_name="mytestcollection", n_results=1)
        assert (
            results.get("documents")[0][0]
            == "AutoGen is an advanced tool designed to assist developers in harnessing the capabilities\nof Large Language Models (LLMs) for various applications. The primary purpose o"
        )

    def test_retrieve_utils(self):
        """Index ``./website/docs`` and check that a filtered query returns the 4 requested hits.

        The docs path is relative, so the result depends on the working directory
        the tests are started from.
        """
        client = chromadb.PersistentClient(path="/tmp/chromadb")
        create_vector_db_from_dir(dir_path="./website/docs", client=client, collection_name="autogen-docs")
        results = query_vector_db(
            query_texts=[
                "How can I use AutoGen UserProxyAgent and AssistantAgent to do code generation?",
            ],
            n_results=4,
            client=client,
            collection_name="autogen-docs",
            search_string="AutoGen",
        )
        print(results["ids"][0])
        assert len(results["ids"][0]) == 4
2023-10-01 11:22:58 +01:00
if __name__ == "__main__":
    # Imported here because it is needed only for the script-mode cleanup below.
    import shutil

    pytest.main()

    # Delete the database after tests are finished. The chromadb persistence path may be a
    # directory, which os.remove cannot delete, so handle both a directory and a plain file.
    db_path = "/tmp/test_retrieve_utils_chromadb.db"
    if os.path.isdir(db_path):
        shutil.rmtree(db_path)
    elif os.path.exists(db_path):
        os.remove(db_path)