Mirror of https://github.com/microsoft/autogen.git, synced 2025-11-02 02:40:21 +00:00
Add support for customized vectordb and embedding functions (#161)

* Add custom embedding function
* Add support for custom vector db
* Improve docstring
* Improve docstring
* Improve docstring
* Add support for customized is_termination_msg function
* Add a test for customized vector db with lancedb
* Fix tests
* Add test for embedding_function
* Update docstring
This commit is contained in:
parent 37a07a83c3 · commit fa6e2a52c0
.github/workflows/build.yml (vendored) · 2 changed lines
@@ -40,7 +40,7 @@ jobs:
           python -m pip install --upgrade pip wheel
           pip install -e .
           python -c "import autogen"
-          pip install -e.[mathchat,retrievechat] datasets pytest
+          pip install -e.[mathchat,retrievechat,test] datasets pytest
           pip uninstall -y openai
       - name: Test with pytest
         if: matrix.python-version != '3.10'
autogen/agentchat/contrib/retrieve_user_proxy_agent.py

@@ -67,6 +67,7 @@ class RetrieveUserProxyAgent(UserProxyAgent):
         self,
         name="RetrieveChatAgent",  # default set to RetrieveChatAgent
         human_input_mode: Optional[str] = "ALWAYS",
+        is_termination_msg: Optional[Callable[[Dict], bool]] = None,
         retrieve_config: Optional[Dict] = None,  # config for the retrieve agent
         **kwargs,
     ):
@@ -82,14 +83,17 @@ class RetrieveUserProxyAgent(UserProxyAgent):
                 the number of auto reply reaches the max_consecutive_auto_reply.
                 (3) When "NEVER", the agent will never prompt for human input. Under this mode, the conversation stops
                 when the number of auto reply reaches the max_consecutive_auto_reply or when is_termination_msg is True.
+            is_termination_msg (function): a function that takes a message in the form of a dictionary
+                and returns a boolean value indicating if this received message is a termination message.
+                The dict can contain the following keys: "content", "role", "name", "function_call".
             retrieve_config (dict or None): config for the retrieve agent.
                 To use default config, set to None. Otherwise, set to a dictionary with the following keys:
                 - task (Optional, str): the task of the retrieve chat. Possible values are "code", "qa" and "default". System
                     prompt will be different for different tasks. The default value is `default`, which supports both code and qa.
-                - client (Optional, chromadb.Client): the chromadb client.
-                    If key not provided, a default client `chromadb.Client()` will be used.
+                - client (Optional, chromadb.Client): the chromadb client. If key not provided, a default client `chromadb.Client()`
+                    will be used. If you want to use another vector db, extend this class and override the `retrieve_docs` function.
                 - docs_path (Optional, str): the path to the docs directory. It can also be the path to a single file,
-                    or the url to a single file. If key not provided, a default path `./docs` will be used.
+                    or the url to a single file. Default is None, which works only if the collection is already created.
                 - collection_name (Optional, str): the name of the collection.
                     If key not provided, a default name `autogen-docs` will be used.
                 - model (Optional, str): the model to use for the retrieve chat.
@@ -106,16 +110,45 @@ class RetrieveUserProxyAgent(UserProxyAgent):
                     If key not provided, a default model `all-MiniLM-L6-v2` will be used. All available models
                     can be found at `https://www.sbert.net/docs/pretrained_models.html`. The default model is a
                     fast model. If you want to use a high performance model, `all-mpnet-base-v2` is recommended.
+                - embedding_function (Optional, Callable): the embedding function for creating the vector db. Default is None,
+                    SentenceTransformer with the given `embedding_model` will be used. If you want to use OpenAI, Cohere, HuggingFace
+                    or other embedding functions, you can pass it here; follow the examples in `https://docs.trychroma.com/embeddings`.
                 - customized_prompt (Optional, str): the customized prompt for the retrieve chat. Default is None.
                 - customized_answer_prefix (Optional, str): the customized answer prefix for the retrieve chat. Default is "".
                     If not "" and the customized_answer_prefix is not in the answer, `Update Context` will be triggered.
                 - update_context (Optional, bool): if False, will not apply `Update Context` for interactive retrieval. Default is True.
                 - get_or_create (Optional, bool): if True, will create/recreate a collection for the retrieve chat.
-                    This is the same as that used in chromadb. Default is False.
+                    This is the same as that used in chromadb. Default is False. Will be set to False if docs_path is None.
+                - custom_token_count_function (Optional, Callable): a custom function to count the number of tokens in a string.
+                    The function should take a string as input and return three integers (token_count, tokens_per_message, tokens_per_name).
+                    Default is None; tiktoken will be used, which may not be accurate for non-OpenAI models.
             **kwargs (dict): other kwargs in [UserProxyAgent](../user_proxy_agent#__init__).

+        Example of overriding retrieve_docs:
+            If you have set up a customized vector db and it's not compatible with chromadb, you can easily plug it in with the code below.
+            ```python
+            class MyRetrieveUserProxyAgent(RetrieveUserProxyAgent):
+                def query_vector_db(
+                    self,
+                    query_texts: List[str],
+                    n_results: int = 10,
+                    search_string: str = "",
+                    **kwargs,
+                ) -> Dict[str, Union[List[str], List[List[str]]]]:
+                    # define your own query function here
+                    pass
+
+                def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = "", **kwargs):
+                    results = self.query_vector_db(
+                        query_texts=[problem],
+                        n_results=n_results,
+                        search_string=search_string,
+                        **kwargs,
+                    )
+
+                    self._results = results
+                    print("doc_ids: ", results["ids"])
+            ```
         """
         super().__init__(
             name=name,
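For context, a minimal usage sketch (not part of this diff) of the new `embedding_function` key, assuming a hypothetical `./docs` directory; any callable following chromadb's embedding-function protocol should work, e.g. `ef.OpenAIEmbeddingFunction`:

```python
import chromadb
from chromadb.utils import embedding_functions as ef
from autogen.agentchat.contrib.retrieve_user_proxy_agent import RetrieveUserProxyAgent

# a chromadb-provided embedding function; swap in OpenAI/Cohere/HuggingFace ones as needed
custom_ef = ef.SentenceTransformerEmbeddingFunction("all-mpnet-base-v2")

ragproxyagent = RetrieveUserProxyAgent(
    name="ragproxyagent",
    human_input_mode="NEVER",
    retrieve_config={
        "task": "qa",
        "docs_path": "./docs",  # hypothetical path; must exist, or the collection must already be created
        "client": chromadb.PersistentClient(path="/tmp/chromadb"),
        "embedding_function": custom_ef,  # takes precedence over embedding_model
    },
)
```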
@@ -126,7 +159,7 @@ class RetrieveUserProxyAgent(UserProxyAgent):
         self._retrieve_config = {} if retrieve_config is None else retrieve_config
         self._task = self._retrieve_config.get("task", "default")
         self._client = self._retrieve_config.get("client", chromadb.Client())
-        self._docs_path = self._retrieve_config.get("docs_path", "./docs")
+        self._docs_path = self._retrieve_config.get("docs_path", None)
         self._collection_name = self._retrieve_config.get("collection_name", "autogen-docs")
         self._model = self._retrieve_config.get("model", "gpt-4")
         self._max_tokens = self.get_max_tokens(self._model)
@@ -134,20 +167,26 @@ class RetrieveUserProxyAgent(UserProxyAgent):
         self._chunk_mode = self._retrieve_config.get("chunk_mode", "multi_lines")
         self._must_break_at_empty_line = self._retrieve_config.get("must_break_at_empty_line", True)
         self._embedding_model = self._retrieve_config.get("embedding_model", "all-MiniLM-L6-v2")
+        self._embedding_function = self._retrieve_config.get("embedding_function", None)
         self.customized_prompt = self._retrieve_config.get("customized_prompt", None)
         self.customized_answer_prefix = self._retrieve_config.get("customized_answer_prefix", "").upper()
         self.update_context = self._retrieve_config.get("update_context", True)
-        self._get_or_create = self._retrieve_config.get("get_or_create", False)
+        self._get_or_create = (
+            self._retrieve_config.get("get_or_create", False) if self._docs_path is not None else False
+        )
         self.custom_token_count_function = self._retrieve_config.get("custom_token_count_function", None)
         self._context_max_tokens = self._max_tokens * 0.8
-        self._collection = False  # the collection is not created
+        self._collection = True if self._docs_path is None else False  # whether the collection is created
         self._ipython = get_ipython()
         self._doc_idx = -1  # the index of the current used doc
         self._results = {}  # the results of the current query
         self._intermediate_answers = set()  # the intermediate answers
         self._doc_contents = []  # the contents of the current used doc
         self._doc_ids = []  # the ids of the current used doc
-        self._is_termination_msg = self._is_termination_msg_retrievechat  # update the termination message function
+        # update the termination message function
+        self._is_termination_msg = (
+            self._is_termination_msg_retrievechat if is_termination_msg is None else is_termination_msg
+        )
         self.register_reply(Agent, RetrieveUserProxyAgent._generate_retrieve_user_reply, position=1)

     def _is_termination_msg_retrievechat(self, message):
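A short sketch (an assumption, not part of this diff) of the now-customizable termination predicate; the "TERMINATE" convention here is illustrative:

```python
from autogen.agentchat.contrib.retrieve_user_proxy_agent import RetrieveUserProxyAgent

def my_termination(msg: dict) -> bool:
    # msg may carry "content", "role", "name", "function_call" (see the docstring above)
    return (msg.get("content") or "").rstrip().endswith("TERMINATE")

agent = RetrieveUserProxyAgent(
    name="RetrieveChatAgent",
    human_input_mode="NEVER",
    is_termination_msg=my_termination,  # replaces the default _is_termination_msg_retrievechat
)
```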
@@ -188,7 +227,7 @@ class RetrieveUserProxyAgent(UserProxyAgent):
         self._doc_contents = []  # the contents of the current used doc
         self._doc_ids = []  # the ids of the current used doc

-    def _get_context(self, results):
+    def _get_context(self, results: Dict[str, Union[List[str], List[List[str]]]]):
         doc_contents = ""
         current_tokens = 0
         _doc_idx = self._doc_idx
@@ -297,6 +336,22 @@ class RetrieveUserProxyAgent(UserProxyAgent):
         return False, None

     def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = ""):
+        """Retrieve docs based on the given problem and assign the results to the class property `_results`.
+        In case you want to customize the retrieval process, such as using a different vector db whose APIs are not
+        compatible with chromadb, or filtering results with metadata, you can override this function. Just keep the
+        current parameters, add your own parameters with default values, and keep the results in the type below.
+
+        Type of the results: Dict[str, List[List[Any]]]; should have keys "ids" and "documents", "ids" for the ids of
+        the retrieved docs and "documents" for the contents of the retrieved docs. Any other keys are optional. Refer
+        to `chromadb.api.types.QueryResult` as an example.
+            ids: List[string]
+            documents: List[List[string]]
+
+        Args:
+            problem (str): the problem to be solved.
+            n_results (int): the number of results to be retrieved.
+            search_string (str): only docs containing this string will be retrieved.
+        """
         if not self._collection or self._get_or_create:
             print("Trying to create collection.")
             create_vector_db_from_dir(
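To make the expected shape concrete, a hand-written example (hypothetical ids and contents) of a valid `_results` value per the docstring above:

```python
# "ids" and "documents" are required; other chromadb QueryResult keys are optional
results = {
    "ids": ["doc_0", "doc_7"],  # ids of the retrieved docs
    "documents": [["first matching chunk", "second matching chunk"]],  # contents of the retrieved docs
}
```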
@@ -308,6 +363,7 @@ class RetrieveUserProxyAgent(UserProxyAgent):
                 must_break_at_empty_line=self._must_break_at_empty_line,
                 embedding_model=self._embedding_model,
                 get_or_create=self._get_or_create,
+                embedding_function=self._embedding_function,
             )
             self._collection = True
             self._get_or_create = False
@@ -319,6 +375,7 @@ class RetrieveUserProxyAgent(UserProxyAgent):
             client=self._client,
             collection_name=self._collection_name,
             embedding_model=self._embedding_model,
+            embedding_function=self._embedding_function,
         )
         self._results = results
         print("doc_ids: ", results["ids"])
autogen/retrieve_utils.py

@@ -6,6 +6,7 @@ import glob
 import tiktoken
 import chromadb
 from chromadb.api import API
+from chromadb.api.types import QueryResult
 import chromadb.utils.embedding_functions as ef
 import logging
 import pypdf
@@ -263,12 +264,36 @@ def create_vector_db_from_dir(
     chunk_mode: str = "multi_lines",
     must_break_at_empty_line: bool = True,
     embedding_model: str = "all-MiniLM-L6-v2",
+    embedding_function: Callable = None,
 ):
-    """Create a vector db from all the files in a given directory."""
+    """Create a vector db from all the files in a given directory; the directory can also be a single file or a url to
+    a single file. We support chromadb compatible APIs to create the vector db; this function is not required if
+    you have prepared your own vector db.
+
+    Args:
+        dir_path (str): the path to the directory, file or url.
+        max_tokens (Optional, int): the maximum number of tokens per chunk. Default is 4000.
+        client (Optional, API): the chromadb client. Default is None.
+        db_path (Optional, str): the path to the chromadb. Default is "/tmp/chromadb.db".
+        collection_name (Optional, str): the name of the collection. Default is "all-my-documents".
+        get_or_create (Optional, bool): Whether to get or create the collection. Default is False. If True, the collection
+            will be recreated if it already exists.
+        chunk_mode (Optional, str): the chunk mode. Default is "multi_lines".
+        must_break_at_empty_line (Optional, bool): Whether to break at empty line. Default is True.
+        embedding_model (Optional, str): the embedding model to use. Default is "all-MiniLM-L6-v2". Will be ignored if
+            embedding_function is not None.
+        embedding_function (Optional, Callable): the embedding function to use. Default is None; SentenceTransformer with
+            the given `embedding_model` will be used. If you want to use OpenAI, Cohere, HuggingFace or other embedding
+            functions, you can pass it here; follow the examples in `https://docs.trychroma.com/embeddings`.
+    """
     if client is None:
         client = chromadb.PersistentClient(path=db_path)
     try:
-        embedding_function = ef.SentenceTransformerEmbeddingFunction(embedding_model)
+        embedding_function = (
+            ef.SentenceTransformerEmbeddingFunction(embedding_model)
+            if embedding_function is None
+            else embedding_function
+        )
         collection = client.create_collection(
             collection_name,
             get_or_create=get_or_create,
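A minimal call sketch for the updated signature (assuming a local `./docs` directory exists); note that `embedding_model` is ignored once `embedding_function` is given:

```python
from chromadb.utils import embedding_functions as ef
from autogen.retrieve_utils import create_vector_db_from_dir

create_vector_db_from_dir(
    dir_path="./docs",  # hypothetical directory, file, or url
    collection_name="all-my-documents",
    get_or_create=True,  # see get_or_create semantics in the docstring above
    embedding_function=ef.SentenceTransformerEmbeddingFunction("all-MiniLM-L6-v2"),
)
```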
@@ -300,14 +325,41 @@ def query_vector_db(
     collection_name: str = "all-my-documents",
     search_string: str = "",
     embedding_model: str = "all-MiniLM-L6-v2",
-) -> Dict[str, List[str]]:
-    """Query a vector db."""
+    embedding_function: Callable = None,
+) -> QueryResult:
+    """Query a vector db. We support chromadb compatible APIs; it's not required if you have prepared your own vector db
+    and query function.
+
+    Args:
+        query_texts (List[str]): the query texts.
+        n_results (Optional, int): the number of results to return. Default is 10.
+        client (Optional, API): the chromadb compatible client. Default is None; a chromadb client will be used.
+        db_path (Optional, str): the path to the vector db. Default is "/tmp/chromadb.db".
+        collection_name (Optional, str): the name of the collection. Default is "all-my-documents".
+        search_string (Optional, str): the search string. Default is "".
+        embedding_model (Optional, str): the embedding model to use. Default is "all-MiniLM-L6-v2". Will be ignored if
+            embedding_function is not None.
+        embedding_function (Optional, Callable): the embedding function to use. Default is None; SentenceTransformer with
+            the given `embedding_model` will be used. If you want to use OpenAI, Cohere, HuggingFace or other embedding
+            functions, you can pass it here; follow the examples in `https://docs.trychroma.com/embeddings`.
+
+    Returns:
+        QueryResult: the query result. The format is:
+            class QueryResult(TypedDict):
+                ids: List[IDs]
+                embeddings: Optional[List[List[Embedding]]]
+                documents: Optional[List[List[Document]]]
+                metadatas: Optional[List[List[Metadata]]]
+                distances: Optional[List[List[float]]]
+    """
     if client is None:
         client = chromadb.PersistentClient(path=db_path)
     # the collection's embedding function is always the default one, but we want to use the one we used to create the
     # collection. So we compute the embeddings ourselves and pass it to the query function.
     collection = client.get_collection(collection_name)
-    embedding_function = ef.SentenceTransformerEmbeddingFunction(embedding_model)
+    embedding_function = (
+        ef.SentenceTransformerEmbeddingFunction(embedding_model) if embedding_function is None else embedding_function
+    )
     query_embeddings = embedding_function(query_texts)
     # Query/search n most similar results. You can also .get by id
     results = collection.query(
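And a matching query sketch (same assumptions as above); per the comment in the hunk, the embedding function must be supplied again because the collection's stored one is not reused:

```python
from chromadb.utils import embedding_functions as ef
from autogen.retrieve_utils import query_vector_db

results = query_vector_db(
    query_texts=["how can I customize the vector db?"],  # hypothetical query
    n_results=5,
    collection_name="all-my-documents",
    embedding_function=ef.SentenceTransformerEmbeddingFunction("all-MiniLM-L6-v2"),
)
print(results["ids"], results["documents"])
```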
setup.py · 1 changed line
@@ -40,6 +40,7 @@ setuptools.setup(
     extras_require={
         "test": [
             "chromadb",
+            "lancedb",
             "coverage>=5.3",
             "datasets",
             "ipykernel",
test_retrievechat.py

@@ -12,6 +12,7 @@ try:
     )
     from autogen.retrieve_utils import create_vector_db_from_dir, query_vector_db
     import chromadb
+    from chromadb.utils import embedding_functions as ef

     skip_test = False
 except ImportError:
@@ -49,6 +50,7 @@ def test_retrievechat():
         },
     )

+    sentence_transformer_ef = ef.SentenceTransformerEmbeddingFunction()
     ragproxyagent = RetrieveUserProxyAgent(
         name="ragproxyagent",
         human_input_mode="NEVER",
@@ -58,6 +60,7 @@ def test_retrievechat():
             "chunk_token_size": 2000,
             "model": config_list[0]["model"],
             "client": chromadb.PersistentClient(path="/tmp/chromadb"),
+            "embedding_function": sentence_transformer_ef,
         },
     )
test_retrieve_utils.py

@@ -100,6 +100,70 @@ class TestRetrieveUtils:
         results = query_vector_db(["autogen"], client=client)
         assert isinstance(results, dict) and any("autogen" in res[0].lower() for res in results.get("documents", []))

+    def test_custom_vector_db(self):
+        try:
+            import lancedb
+        except ImportError:
+            return
+        from autogen.agentchat.contrib.retrieve_user_proxy_agent import RetrieveUserProxyAgent
+
+        db_path = "/tmp/lancedb"
+
+        def create_lancedb():
+            db = lancedb.connect(db_path)
+            data = [
+                {"vector": [1.1, 1.2], "id": 1, "documents": "This is a test document spark"},
+                {"vector": [0.2, 1.8], "id": 2, "documents": "This is another test document"},
+                {"vector": [0.1, 0.3], "id": 3, "documents": "This is a third test document spark"},
+                {"vector": [0.5, 0.7], "id": 4, "documents": "This is a fourth test document"},
+                {"vector": [2.1, 1.3], "id": 5, "documents": "This is a fifth test document spark"},
+                {"vector": [5.1, 8.3], "id": 6, "documents": "This is a sixth test document"},
+            ]
+            try:
+                db.create_table("my_table", data)
+            except OSError:
+                pass
+
+        class MyRetrieveUserProxyAgent(RetrieveUserProxyAgent):
+            def query_vector_db(
+                self,
+                query_texts,
+                n_results=10,
+                search_string="",
+            ):
+                if query_texts:
+                    vector = [0.1, 0.3]
+                db = lancedb.connect(db_path)
+                table = db.open_table("my_table")
+                query = table.search(vector).where(f"documents LIKE '%{search_string}%'").limit(n_results).to_df()
+                return {"ids": query["id"].tolist(), "documents": query["documents"].tolist()}
+
+            def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = ""):
+                results = self.query_vector_db(
+                    query_texts=[problem],
+                    n_results=n_results,
+                    search_string=search_string,
+                )
+
+                self._results = results
+                print("doc_ids: ", results["ids"])
+
+        ragragproxyagent = MyRetrieveUserProxyAgent(
+            name="ragproxyagent",
+            human_input_mode="NEVER",
+            max_consecutive_auto_reply=2,
+            retrieve_config={
+                "task": "qa",
+                "chunk_token_size": 2000,
+                "client": "__",
+                "embedding_model": "all-mpnet-base-v2",
+            },
+        )
+
+        create_lancedb()
+        ragragproxyagent.retrieve_docs("This is a test document spark", n_results=10, search_string="spark")
+        assert ragragproxyagent._results["ids"] == [3, 1, 5]


 if __name__ == "__main__":
     pytest.main()