From f2d7553cdcfed076a4dd793aec0ffd349b6fbdd0 Mon Sep 17 00:00:00 2001
From: Li Jiang
Date: Tue, 17 Oct 2023 22:53:40 +0800
Subject: [PATCH] Add support for custom text splitter (#270)

* Add support for a custom text splitter function and a list of files or URLs

* Add parameter to retrieve_config, add tests

* Fix tests

* Fix tests
---
 .../contrib/retrieve_user_proxy_agent.py |  4 +++
 autogen/retrieve_utils.py                | 36 ++++++++++++++++---
 test/test_retrieve_utils.py              | 22 ++++++++++++
 3 files changed, 57 insertions(+), 5 deletions(-)

diff --git a/autogen/agentchat/contrib/retrieve_user_proxy_agent.py b/autogen/agentchat/contrib/retrieve_user_proxy_agent.py
index 0f29aa62d..94677244a 100644
--- a/autogen/agentchat/contrib/retrieve_user_proxy_agent.py
+++ b/autogen/agentchat/contrib/retrieve_user_proxy_agent.py
@@ -122,6 +122,8 @@ class RetrieveUserProxyAgent(UserProxyAgent):
             - custom_token_count_function(Optional, Callable): a custom function to count the number of tokens in a string.
                 The function should take a string as input and return three integers (token_count, tokens_per_message, tokens_per_name).
                 Default is None, tiktoken will be used and may not be accurate for non-OpenAI models.
+            - custom_text_split_function(Optional, Callable): a custom function to split a string into a list of strings.
+                Default is None, will use the default function in `autogen.retrieve_utils.split_text_to_chunks`.
             **kwargs (dict): other kwargs in [UserProxyAgent](../user_proxy_agent#__init__).
 
         Example of overriding retrieve_docs:
@@ -175,6 +177,7 @@ class RetrieveUserProxyAgent(UserProxyAgent):
             self._retrieve_config.get("get_or_create", False) if self._docs_path is not None else False
         )
         self.custom_token_count_function = self._retrieve_config.get("custom_token_count_function", None)
+        self.custom_text_split_function = self._retrieve_config.get("custom_text_split_function", None)
         self._context_max_tokens = self._max_tokens * 0.8
         self._collection = True if self._docs_path is None else False  # whether the collection is created
         self._ipython = get_ipython()
@@ -364,6 +367,7 @@ class RetrieveUserProxyAgent(UserProxyAgent):
                 embedding_model=self._embedding_model,
                 get_or_create=self._get_or_create,
                 embedding_function=self._embedding_function,
+                custom_text_split_function=self.custom_text_split_function,
             )
             self._collection = True
             self._get_or_create = False
diff --git a/autogen/retrieve_utils.py b/autogen/retrieve_utils.py
index 721b1ec29..c660fa85d 100644
--- a/autogen/retrieve_utils.py
+++ b/autogen/retrieve_utils.py
@@ -180,7 +180,11 @@ def extract_text_from_pdf(file: str) -> str:
 
 
 def split_files_to_chunks(
-    files: list, max_tokens: int = 4000, chunk_mode: str = "multi_lines", must_break_at_empty_line: bool = True
+    files: list,
+    max_tokens: int = 4000,
+    chunk_mode: str = "multi_lines",
+    must_break_at_empty_line: bool = True,
+    custom_text_split_function: Callable = None,
 ):
     """Split a list of files into chunks of max_tokens."""
 
@@ -200,18 +204,33 @@ def split_files_to_chunks(
             logger.warning(f"No text available in file: {file}")
             continue  # Skip to the next file if no text is available
 
-        chunks += split_text_to_chunks(text, max_tokens, chunk_mode, must_break_at_empty_line)
+        if custom_text_split_function is not None:
+            chunks += custom_text_split_function(text)
+        else:
+            chunks += split_text_to_chunks(text, max_tokens, chunk_mode, must_break_at_empty_line)
 
     return chunks
 
 
-def get_files_from_dir(dir_path: str, types: list = TEXT_FORMATS, recursive: bool = True):
+def get_files_from_dir(dir_path: Union[str, List[str]], types: list = TEXT_FORMATS, recursive: bool = True):
     """Return a list of all the files in a given directory."""
     if len(types) == 0:
         raise ValueError("types cannot be empty.")
     types = [t[1:].lower() if t.startswith(".") else t.lower() for t in set(types)]
     types += [t.upper() for t in types]
 
+    files = []
+    # If the path is a list of files or urls, process and return them
+    if isinstance(dir_path, list):
+        for item in dir_path:
+            if os.path.isfile(item):
+                files.append(item)
+            elif is_url(item):
+                files.append(get_file_from_url(item))
+            else:
+                logger.warning(f"File {item} does not exist. Skipping.")
+        return files
+
     # If the path is a file, return it
     if os.path.isfile(dir_path):
         return [dir_path]
@@ -220,7 +239,6 @@ def get_files_from_dir(dir_path: str, types: list = TEXT_FORMATS, recursive: boo
     if is_url(dir_path):
         return [get_file_from_url(dir_path)]
 
-    files = []
     if os.path.exists(dir_path):
         for type in types:
             if recursive:
@@ -265,6 +283,7 @@ def create_vector_db_from_dir(
     must_break_at_empty_line: bool = True,
     embedding_model: str = "all-MiniLM-L6-v2",
     embedding_function: Callable = None,
+    custom_text_split_function: Callable = None,
 ):
     """Create a vector db from all the files in a given directory, the directory can also be a single file or a url to
         a single file. We support chromadb compatible APIs to create the vector db, this function is not required if
@@ -304,7 +323,14 @@ def create_vector_db_from_dir(
             metadata={"hnsw:space": "ip", "hnsw:construction_ef": 30, "hnsw:M": 32},  # ip, l2, cosine
         )
 
-        chunks = split_files_to_chunks(get_files_from_dir(dir_path), max_tokens, chunk_mode, must_break_at_empty_line)
+        if custom_text_split_function is not None:
+            chunks = split_files_to_chunks(
+                get_files_from_dir(dir_path), custom_text_split_function=custom_text_split_function
+            )
+        else:
+            chunks = split_files_to_chunks(
+                get_files_from_dir(dir_path), max_tokens, chunk_mode, must_break_at_empty_line
+            )
         logger.info(f"Found {len(chunks)} chunks.")
         # Upsert in batch of 40000 or less if the total number of chunks is less than 40000
         for i in range(0, len(chunks), min(40000, len(chunks))):
diff --git a/test/test_retrieve_utils.py b/test/test_retrieve_utils.py
index be215facb..a1c70d9cf 100644
--- a/test/test_retrieve_utils.py
+++ b/test/test_retrieve_utils.py
@@ -74,6 +74,10 @@ class TestRetrieveUtils:
     def test_get_files_from_dir(self):
         files = get_files_from_dir(test_dir)
         assert all(os.path.isfile(file) for file in files)
+        pdf_file_path = os.path.join(test_dir, "example.pdf")
+        txt_file_path = os.path.join(test_dir, "example.txt")
+        files = get_files_from_dir([pdf_file_path, txt_file_path])
+        assert all(os.path.isfile(file) for file in files)
 
     def test_is_url(self):
         assert is_url("https://www.example.com")
@@ -164,6 +168,24 @@ class TestRetrieveUtils:
         ragragproxyagent.retrieve_docs("This is a test document spark", n_results=10, search_string="spark")
         assert ragragproxyagent._results["ids"] == [3, 1, 5]
 
+    def test_custom_text_split_function(self):
+        def custom_text_split_function(text):
+            return [text[: len(text) // 2], text[len(text) // 2 :]]
+
+        db_path = "/tmp/test_retrieve_utils_chromadb.db"
+        client = chromadb.PersistentClient(path=db_path)
+        create_vector_db_from_dir(
+            os.path.join(test_dir, "example.txt"),
+            client=client,
+            collection_name="mytestcollection",
+            custom_text_split_function=custom_text_split_function,
+        )
+        results = query_vector_db(["autogen"], client=client, collection_name="mytestcollection", n_results=1)
+        assert (
+            results.get("documents")[0][0]
+            == "AutoGen is an advanced tool designed to assist developers in harnessing the capabilities\nof Large Language Models (LLMs) for various applications. The primary purpose o"
+        )
+
 
 if __name__ == "__main__":
     pytest.main()
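
The new hooks can also be exercised directly through autogen.retrieve_utils. A minimal sketch based only on the signatures introduced by this patch; the paragraph splitter and the file paths are hypothetical placeholders, not part of the change:

    from autogen.retrieve_utils import get_files_from_dir, split_files_to_chunks

    def paragraph_split_function(text: str) -> list:
        # A custom splitter receives the full document text and
        # returns a list of chunk strings.
        return [p for p in text.split("\n\n") if p.strip()]

    # get_files_from_dir now also accepts a list of local file paths and/or URLs
    # (hypothetical example paths).
    files = get_files_from_dir(["./docs/example.txt", "./docs/example.pdf"])

    # When custom_text_split_function is given, it replaces the default
    # split_text_to_chunks and the token-based chunking parameters are unused.
    chunks = split_files_to_chunks(files, custom_text_split_function=paragraph_split_function)
    print(f"Produced {len(chunks)} chunks.")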
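At the agent level, the splitter is passed through the new custom_text_split_function key in retrieve_config, which this patch reads in retrieve_user_proxy_agent.py. A sketch assuming the agent's pre-existing "task" and "docs_path" config keys; the docs path is a hypothetical placeholder:

    from autogen.agentchat.contrib.retrieve_user_proxy_agent import RetrieveUserProxyAgent

    def custom_text_split_function(text: str) -> list:
        # Same shape as the splitter in the test above: halve each document.
        return [text[: len(text) // 2], text[len(text) // 2 :]]

    ragproxyagent = RetrieveUserProxyAgent(
        name="ragproxyagent",
        human_input_mode="NEVER",
        retrieve_config={
            "task": "qa",
            "docs_path": "./docs/example.txt",  # hypothetical path
            "custom_text_split_function": custom_text_split_function,
        },
    )

Because the key defaults to None, existing configs keep the default split_text_to_chunks behavior unchanged.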