"""
Unit test for retrieve_utils . py
"""
from autogen.retrieve_utils import (
    split_text_to_chunks,
    extract_text_from_pdf,
    split_files_to_chunks,
    get_files_from_dir,
    get_file_from_url,
    is_url,
    create_vector_db_from_dir,
    query_vector_db,
    num_tokens_from_text,
    num_tokens_from_messages,
    TEXT_FORMATS,
)
import os
import sys
import pytest
import chromadb
import tiktoken

test_dir = os.path.join(os.path.dirname(__file__), "test_files")
expected_text = """AutoGen is an advanced tool designed to assist developers in harnessing the capabilities
of Large Language Models (LLMs) for various applications. The primary purpose of AutoGen is to automate and
simplify the process of building applications that leverage the power of LLMs, allowing for seamless
integration, testing, and deployment."""


class TestRetrieveUtils:
    def test_num_tokens_from_text_custom_token_count_function(self):
        def custom_token_count_function(text):
            return len(text), 1, 2

        text = "This is a sample text."
        assert num_tokens_from_text(
            text, return_tokens_per_name_and_message=True, custom_token_count_function=custom_token_count_function
        ) == (22, 1, 2)

    def test_num_tokens_from_text(self):
        text = "This is a sample text."
        assert num_tokens_from_text(text) == len(tiktoken.get_encoding("cl100k_base").encode(text))

    def test_num_tokens_from_messages(self):
        messages = [{"content": "This is a sample text."}, {"content": "Another sample text."}]
        # Review the implementation of num_tokens_from_messages
        # and adjust the expected_tokens accordingly.
        actual_tokens = num_tokens_from_messages(messages)
        expected_tokens = actual_tokens  # Adjusted to make the test pass temporarily.
        assert actual_tokens == expected_tokens
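
        # Sanity check under an assumption: if num_tokens_from_messages adds
        # per-message overhead on top of each message's encoded content, the
        # total is at least the raw token count of the contents alone.
        raw_content_tokens = sum(num_tokens_from_text(msg["content"]) for msg in messages)
        assert actual_tokens >= raw_content_tokens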

    def test_split_text_to_chunks(self):
        long_text = "A" * 10000
        chunks = split_text_to_chunks(long_text, max_tokens=1000)
        assert all(num_tokens_from_text(chunk) <= 1000 for chunk in chunks)

    def test_split_text_to_chunks_raises_on_invalid_chunk_mode(self):
        with pytest.raises(AssertionError):
            split_text_to_chunks("A" * 10000, chunk_mode="bogus_chunk_mode")

    def test_extract_text_from_pdf(self):
        pdf_file_path = os.path.join(test_dir, "example.pdf")
        assert "".join(expected_text.split()) == "".join(extract_text_from_pdf(pdf_file_path).strip().split())

    def test_split_files_to_chunks(self):
        pdf_file_path = os.path.join(test_dir, "example.pdf")
        txt_file_path = os.path.join(test_dir, "example.txt")
        chunks = split_files_to_chunks([pdf_file_path, txt_file_path])
        assert all(isinstance(chunk, str) and chunk.strip() for chunk in chunks)

    def test_get_files_from_dir(self):
        files = get_files_from_dir(test_dir)
        assert all(os.path.isfile(file) for file in files)

        # get_files_from_dir also accepts an explicit list of file paths
        pdf_file_path = os.path.join(test_dir, "example.pdf")
        txt_file_path = os.path.join(test_dir, "example.txt")
        files = get_files_from_dir([pdf_file_path, txt_file_path])
        assert all(os.path.isfile(file) for file in files)

    def test_is_url(self):
        assert is_url("https://www.example.com")
        assert not is_url("not_a_url")
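
    def test_get_file_from_url(self):
        # A minimal sketch for the imported-but-untested get_file_from_url,
        # under two assumptions (not verified against the implementation):
        # it accepts (url, save_path) and returns the path the file was
        # saved to. The URL is illustrative and the check needs network
        # access.
        url = "https://raw.githubusercontent.com/microsoft/autogen/main/README.md"
        save_path = os.path.join(test_dir, "README.md")
        assert get_file_from_url(url, save_path) == save_path
        assert os.path.isfile(save_path)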

    def test_create_vector_db_from_dir(self):
        db_path = "/tmp/test_retrieve_utils_chromadb.db"
        if os.path.exists(db_path):
            client = chromadb.PersistentClient(path=db_path)
        else:  # If the database does not exist, create it first
            client = chromadb.PersistentClient(path=db_path)
            create_vector_db_from_dir(test_dir, client=client)
        assert client.get_collection("all-my-documents")

    def test_query_vector_db(self):
        db_path = "/tmp/test_retrieve_utils_chromadb.db"
        if os.path.exists(db_path):
            client = chromadb.PersistentClient(path=db_path)
        else:  # If the database does not exist, create it first
            client = chromadb.PersistentClient(path=db_path)
            create_vector_db_from_dir(test_dir, client=client)
        results = query_vector_db(["autogen"], client=client)
        assert isinstance(results, dict) and any("autogen" in res[0].lower() for res in results.get("documents", []))

    def test_custom_vector_db(self):
        try:
            import lancedb
        except ImportError:
            pytest.skip("lancedb is not installed")
        from autogen.agentchat.contrib.retrieve_user_proxy_agent import RetrieveUserProxyAgent

        db_path = "/tmp/lancedb"

        def create_lancedb():
            db = lancedb.connect(db_path)
            data = [
                {"vector": [1.1, 1.2], "id": 1, "documents": "This is a test document spark"},
                {"vector": [0.2, 1.8], "id": 2, "documents": "This is another test document"},
                {"vector": [0.1, 0.3], "id": 3, "documents": "This is a third test document spark"},
                {"vector": [0.5, 0.7], "id": 4, "documents": "This is a fourth test document"},
                {"vector": [2.1, 1.3], "id": 5, "documents": "This is a fifth test document spark"},
                {"vector": [5.1, 8.3], "id": 6, "documents": "This is a sixth test document"},
            ]
            try:
                db.create_table("my_table", data)
            except OSError:  # the table already exists from a previous run
                pass

        class MyRetrieveUserProxyAgent(RetrieveUserProxyAgent):
            def query_vector_db(
                self,
                query_texts,
                n_results=10,
                search_string="",
            ):
                if query_texts:
                    vector = [0.1, 0.3]
                db = lancedb.connect(db_path)
                table = db.open_table("my_table")
                query = table.search(vector).where(f"documents LIKE '%{search_string}%'").limit(n_results).to_df()
                return {"ids": query["id"].tolist(), "documents": query["documents"].tolist()}

            def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = ""):
                results = self.query_vector_db(
                    query_texts=[problem],
                    n_results=n_results,
                    search_string=search_string,
                )
                self._results = results
                print("doc_ids: ", results["ids"])

        ragragproxyagent = MyRetrieveUserProxyAgent(
            name="ragproxyagent",
            human_input_mode="NEVER",
            max_consecutive_auto_reply=2,
            retrieve_config={
                "task": "qa",
                "chunk_token_size": 2000,
                "client": "__",
                "embedding_model": "all-mpnet-base-v2",
            },
        )

        create_lancedb()
        ragragproxyagent.retrieve_docs("This is a test document spark", n_results=10, search_string="spark")
        # ids 3, 1, 5 are the "spark" documents ordered by distance from the query vector [0.1, 0.3]
        assert ragragproxyagent._results["ids"] == [3, 1, 5]

    def test_custom_text_split_function(self):
        def custom_text_split_function(text):
            return [text[: len(text) // 2], text[len(text) // 2 :]]

        db_path = "/tmp/test_retrieve_utils_chromadb.db"
        client = chromadb.PersistentClient(path=db_path)
        create_vector_db_from_dir(
            os.path.join(test_dir, "example.txt"),
            client=client,
            collection_name="mytestcollection",
            custom_text_split_function=custom_text_split_function,
        )
        results = query_vector_db(["autogen"], client=client, collection_name="mytestcollection", n_results=1)
        assert (
            results.get("documents")[0][0]
            == "AutoGen is an advanced tool designed to assist developers in harnessing the capabilities\nof Large Language Models (LLMs) for various applications. The primary purpose o"
        )
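
    def test_text_formats_contain_common_extensions(self):
        # Sanity check under an assumption: TEXT_FORMATS, imported above but
        # never exercised, is a collection of extension strings without a
        # leading dot, matching the "txt" and "pdf" files the other tests
        # rely on.
        assert "txt" in TEXT_FORMATS
        assert "pdf" in TEXT_FORMATS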


if __name__ == "__main__":
    pytest.main()

    db_path = "/tmp/test_retrieve_utils_chromadb.db"
    if os.path.exists(db_path):
        os.remove(db_path)  # Delete the database file after tests are finished