#!/usr/bin/env python3 -m pytest
"""
Unit test for retrieve_utils.py
"""

import pytest

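# chromadb and autogen's retrieve utilities are optional extras; if either import fails,
# every test in TestRetrieveUtils below is skipped via the module-level `skip` flag.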
try:
    import chromadb

    from autogen.retrieve_utils import (
        split_text_to_chunks,
        extract_text_from_pdf,
        split_files_to_chunks,
        get_files_from_dir,
        is_url,
        create_vector_db_from_dir,
        query_vector_db,
    )
    from autogen.token_count_utils import count_token
except ImportError:
    skip = True
else:
    skip = False

import os

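# `unstructured` is optional; when it is not installed, only test_unstructured is skipped.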
try:
    from unstructured.partition.auto import partition

    HAS_UNSTRUCTURED = True
except ImportError:
    HAS_UNSTRUCTURED = False

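# expected_text is the sample passage stored in the files under test_files/; the assertions
# below normalize whitespace before comparing, so spacing differences from extraction do not
# cause spurious failures.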
test_dir = os.path.join(os.path.dirname(__file__), "test_files")
expected_text = """AutoGen is an advanced tool designed to assist developers in harnessing the capabilities
of Large Language Models (LLMs) for various applications. The primary purpose of AutoGen is to automate and
simplify the process of building applications that leverage the power of LLMs, allowing for seamless
integration, testing, and deployment."""


@pytest.mark.skipif(skip, reason="dependency is not installed")
class TestRetrieveUtils:
    def test_split_text_to_chunks(self):
        long_text = "A" * 10000
        chunks = split_text_to_chunks(long_text, max_tokens=1000)
        assert all(count_token(chunk) <= 1000 for chunk in chunks)

    def test_split_text_to_chunks_raises_on_invalid_chunk_mode(self):
        with pytest.raises(AssertionError):
            split_text_to_chunks("A" * 10000, chunk_mode="bogus_chunk_mode")

    def test_extract_text_from_pdf(self):
        pdf_file_path = os.path.join(test_dir, "example.pdf")
        assert "".join(expected_text.split()) == "".join(extract_text_from_pdf(pdf_file_path).strip().split())

    def test_split_files_to_chunks(self):
        pdf_file_path = os.path.join(test_dir, "example.pdf")
        txt_file_path = os.path.join(test_dir, "example.txt")
        chunks = split_files_to_chunks([pdf_file_path, txt_file_path])
        assert all(
            isinstance(chunk, str) and "AutoGen is an advanced tool designed to assist developers" in chunk.strip()
            for chunk in chunks
        )

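    # get_files_from_dir accepts a single directory, a list of explicit file paths, and URLs;
    # the calls below also exercise recursive traversal and filtering by file type.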
    def test_get_files_from_dir(self):
        files = get_files_from_dir(test_dir, recursive=False)
        assert all(os.path.isfile(file) for file in files)
        pdf_file_path = os.path.join(test_dir, "example.pdf")
        txt_file_path = os.path.join(test_dir, "example.txt")
        files = get_files_from_dir([pdf_file_path, txt_file_path])
        assert all(os.path.isfile(file) for file in files)
        files = get_files_from_dir(
            [
                pdf_file_path,
                txt_file_path,
                os.path.join(test_dir, "..", "..", "website/docs"),
                "https://raw.githubusercontent.com/microsoft/autogen/main/README.md",
            ],
            recursive=True,
        )
        assert all(os.path.isfile(file) for file in files)
        files = get_files_from_dir(
            [
                pdf_file_path,
                txt_file_path,
                os.path.join(test_dir, "..", "..", "website/docs"),
                "https://raw.githubusercontent.com/microsoft/autogen/main/README.md",
            ],
            recursive=True,
            types=["pdf", "txt"],
        )
        assert all(os.path.isfile(file) for file in files)
        assert len(files) == 3

    def test_is_url(self):
        assert is_url("https://www.example.com")
        assert not is_url("not_a_url")

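    # The next two tests share a persistent ChromaDB store under /tmp: the test directory is
    # only indexed when the store does not exist yet, and later runs reuse the existing
    # "all-my-documents" collection.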
    def test_create_vector_db_from_dir(self):
        db_path = "/tmp/test_retrieve_utils_chromadb.db"
        if os.path.exists(db_path):
            client = chromadb.PersistentClient(path=db_path)
        else:
            client = chromadb.PersistentClient(path=db_path)
            create_vector_db_from_dir(test_dir, client=client)

        assert client.get_collection("all-my-documents")

    def test_query_vector_db(self):
        db_path = "/tmp/test_retrieve_utils_chromadb.db"
        if os.path.exists(db_path):
            client = chromadb.PersistentClient(path=db_path)
        else:  # If the database does not exist, create it first
            client = chromadb.PersistentClient(path=db_path)
            create_vector_db_from_dir(test_dir, client=client)

        results = query_vector_db(["autogen"], client=client)
        assert isinstance(results, dict) and any("autogen" in res[0].lower() for res in results.get("documents", []))

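    # Demonstrates swapping in a custom vector store: a small LanceDB table is built and
    # RetrieveUserProxyAgent is subclassed to override query_vector_db and retrieve_docs.
    # The test returns early (effectively a no-op) when lancedb is not installed.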
    def test_custom_vector_db(self):
        try:
            import lancedb
        except ImportError:
            return
        from autogen.agentchat.contrib.retrieve_user_proxy_agent import RetrieveUserProxyAgent

        db_path = "/tmp/lancedb"

        def create_lancedb():
            db = lancedb.connect(db_path)
            data = [
                {"vector": [1.1, 1.2], "id": 1, "documents": "This is a test document spark"},
                {"vector": [0.2, 1.8], "id": 2, "documents": "This is another test document"},
                {"vector": [0.1, 0.3], "id": 3, "documents": "This is a third test document spark"},
                {"vector": [0.5, 0.7], "id": 4, "documents": "This is a fourth test document"},
                {"vector": [2.1, 1.3], "id": 5, "documents": "This is a fifth test document spark"},
                {"vector": [5.1, 8.3], "id": 6, "documents": "This is a sixth test document"},
            ]
            try:
                db.create_table("my_table", data)
            except OSError:
                pass

        class MyRetrieveUserProxyAgent(RetrieveUserProxyAgent):
            def query_vector_db(
                self,
                query_texts,
                n_results=10,
                search_string="",
            ):
                if query_texts:
                    vector = [0.1, 0.3]
                db = lancedb.connect(db_path)
                table = db.open_table("my_table")
                query = table.search(vector).where(f"documents LIKE '%{search_string}%'").limit(n_results).to_df()
                return {"ids": [query["id"].tolist()], "documents": [query["documents"].tolist()]}

            def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = ""):
                results = self.query_vector_db(
                    query_texts=[problem],
                    n_results=n_results,
                    search_string=search_string,
                )

                self._results = results
                print("doc_ids: ", results["ids"])

        ragragproxyagent = MyRetrieveUserProxyAgent(
            name="ragproxyagent",
            human_input_mode="NEVER",
            max_consecutive_auto_reply=2,
            retrieve_config={
                "task": "qa",
                "chunk_token_size": 2000,
                "client": "__",
                "embedding_model": "all-mpnet-base-v2",
            },
        )

        create_lancedb()
        ragragproxyagent.retrieve_docs("This is a test document spark", n_results=10, search_string="spark")
        assert ragragproxyagent._results["ids"] == [[3, 1, 5]]

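    # A custom splitter that simply halves the text is threaded through create_vector_db_from_dir,
    # then the resulting collection is queried to confirm the custom chunks were indexed.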
    def test_custom_text_split_function(self):
        def custom_text_split_function(text):
            return [text[: len(text) // 2], text[len(text) // 2 :]]

        db_path = "/tmp/test_retrieve_utils_chromadb.db"
        client = chromadb.PersistentClient(path=db_path)
        create_vector_db_from_dir(
            os.path.join(test_dir, "example.txt"),
            client=client,
            collection_name="mytestcollection",
            custom_text_split_function=custom_text_split_function,
            get_or_create=True,
            recursive=False,
        )
        results = query_vector_db(["autogen"], client=client, collection_name="mytestcollection", n_results=1)
        assert (
            "AutoGen is an advanced tool designed to assist developers in harnessing the capabilities"
            in results.get("documents")[0][0]
        )

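    # End-to-end check against the repository docs: dir_path is relative ("./website/docs"),
    # so this test assumes it is run from the repository root.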
    def test_retrieve_utils(self):
        client = chromadb.PersistentClient(path="/tmp/chromadb")
        create_vector_db_from_dir(
            dir_path="./website/docs",
            client=client,
            collection_name="autogen-docs",
            custom_text_types=["txt", "md", "rtf", "rst"],
            get_or_create=True,
        )
        results = query_vector_db(
            query_texts=[
                "How can I use AutoGen UserProxyAgent and AssistantAgent to do code generation?",
            ],
            n_results=4,
            client=client,
            collection_name="autogen-docs",
            search_string="AutoGen",
        )
        print(results["ids"][0])
        assert len(results["ids"][0]) == 4

    @pytest.mark.skipif(
        not HAS_UNSTRUCTURED,
        reason="do not run if unstructured is not installed",
    )
    def test_unstructured(self):
        pdf_file_path = os.path.join(test_dir, "example.pdf")
        txt_file_path = os.path.join(test_dir, "example.txt")
        word_file_path = os.path.join(test_dir, "example.docx")
        chunks = split_files_to_chunks([pdf_file_path, txt_file_path, word_file_path])
        assert all(
            isinstance(chunk, str) and "AutoGen is an advanced tool designed to assist developers" in chunk.strip()
            for chunk in chunks
        )


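# When this module is run directly, the suite executes first and the temporary ChromaDB path
# under /tmp is then cleaned up.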
if __name__ == "__main__":
    pytest.main()

    db_path = "/tmp/test_retrieve_utils_chromadb.db"
    if os.path.exists(db_path):
        os.remove(db_path)  # Delete the database file after tests are finished