Cleanup query api - remove code duplication (#1690)

* consolidate query api functions and remove code duplication

* refactor and remove more code duplication

* Add semversioner file

* fix basic search

* fix drift search and update base class function names

* update example notebooks
This commit is contained in:
Josh Bradley 2025-02-13 16:31:08 -05:00 committed by GitHub
parent fe461417b5
commit b8b949f3bb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 131 additions and 344 deletions

View File

@ -0,0 +1,4 @@
{
"type": "patch",
"description": "cleanup query code duplication."
}

View File

@ -327,7 +327,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": null,
"metadata": {},
"outputs": [
{
@ -388,7 +388,7 @@
}
],
"source": [
"resp = await search.asearch(\"Who is agent Mercer?\")"
"resp = await search.search(\"Who is agent Mercer?\")"
]
},
{

View File

@ -392,7 +392,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": null,
"metadata": {},
"outputs": [
{
@ -420,7 +420,7 @@
}
],
"source": [
"result = await search_engine.asearch(\n",
"result = await search_engine.search(\n",
" \"What is Cosmic Vocalization and who are involved in it?\"\n",
")\n",
"\n",

View File

@ -394,7 +394,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": null,
"metadata": {},
"outputs": [
{
@ -420,7 +420,7 @@
}
],
"source": [
"result = await search_engine.asearch(\n",
"result = await search_engine.search(\n",
" \"What is Cosmic Vocalization and who are involved in it?\"\n",
")\n",
"\n",

View File

@ -963,7 +963,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": null,
"metadata": {},
"outputs": [
{
@ -991,13 +991,13 @@
}
],
"source": [
"result = await search_engine.asearch(\"Tell me about Agent Mercer\")\n",
"result = await search_engine.search(\"Tell me about Agent Mercer\")\n",
"print(result.response)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": null,
"metadata": {},
"outputs": [
{
@ -1026,7 +1026,7 @@
],
"source": [
"question = \"Tell me about Dr. Jordan Hayes\"\n",
"result = await search_engine.asearch(question)\n",
"result = await search_engine.search(question)\n",
"print(result.response)"
]
},

View File

@ -384,7 +384,7 @@
"metadata": {},
"outputs": [],
"source": [
"result = await search_engine.asearch(\"Tell me about Agent Mercer\")\n",
"result = await search_engine.search(\"Tell me about Agent Mercer\")\n",
"print(result.response)"
]
},
@ -395,7 +395,7 @@
"outputs": [],
"source": [
"question = \"Tell me about Dr. Jordan Hayes\"\n",
"result = await search_engine.asearch(question)\n",
"result = await search_engine.search(question)\n",
"print(result.response)"
]
},

View File

@ -18,7 +18,7 @@ Backwards compatibility is not guaranteed at this time.
"""
from collections.abc import AsyncGenerator
from typing import TYPE_CHECKING, Any
from typing import Any
import pandas as pd
from pydantic import validate_call
@ -53,9 +53,6 @@ from graphrag.utils.api import (
)
from graphrag.utils.cli import redact
if TYPE_CHECKING:
from graphrag.query.structured_search.base import SearchResult
logger = PrintProgressLogger("")
@ -94,40 +91,27 @@ async def global_search(
------
TODO: Document any exceptions to expect.
"""
communities_ = read_indexer_communities(communities, community_reports)
reports = read_indexer_reports(
community_reports,
communities,
full_response = ""
context_data = {}
get_context_data = True
# NOTE: when streaming, the first chunk of returned data is the complete context data.
# All subsequent chunks are the query response.
async for chunk in global_search_streaming(
config=config,
entities=entities,
communities=communities,
community_reports=community_reports,
community_level=community_level,
dynamic_community_selection=dynamic_community_selection,
)
entities_ = read_indexer_entities(
entities, communities, community_level=community_level
)
map_prompt = load_search_prompt(config.root_dir, config.global_search.map_prompt)
reduce_prompt = load_search_prompt(
config.root_dir, config.global_search.reduce_prompt
)
knowledge_prompt = load_search_prompt(
config.root_dir, config.global_search.knowledge_prompt
)
search_engine = get_global_search_engine(
config,
reports=reports,
entities=entities_,
communities=communities_,
response_type=response_type,
dynamic_community_selection=dynamic_community_selection,
map_system_prompt=map_prompt,
reduce_system_prompt=reduce_prompt,
general_knowledge_inclusion_prompt=knowledge_prompt,
)
result: SearchResult = await search_engine.asearch(query=query)
response = result.response
context_data = reformat_context_data(result.context_data) # type: ignore
return response, context_data
query=query,
):
if get_context_data:
context_data = chunk
get_context_data = False
else:
full_response += chunk
return full_response, context_data
@validate_call(config={"arbitrary_types_allowed": True})
@ -193,11 +177,11 @@ async def global_search_streaming(
reduce_system_prompt=reduce_prompt,
general_knowledge_inclusion_prompt=knowledge_prompt,
)
search_result = search_engine.astream_search(query=query)
search_result = search_engine.stream_search(query=query)
# when streaming results, a context data object is returned as the first result
# NOTE: when streaming results, a context data object is returned as the first result
# and the query response in subsequent tokens
context_data = None
context_data = {}
get_context_data = True
async for stream_chunk in search_result:
if get_context_data:
@ -385,34 +369,29 @@ async def local_search(
------
TODO: Document any exceptions to expect.
"""
vector_store_args = {}
for index, store in config.vector_store.items():
vector_store_args[index] = store.model_dump()
logger.info(f"Vector Store Args: {redact(vector_store_args)}") # type: ignore # noqa
description_embedding_store = get_embedding_store(
config_args=vector_store_args, # type: ignore
embedding_name=entity_description_embedding,
)
entities_ = read_indexer_entities(entities, communities, community_level)
covariates_ = read_indexer_covariates(covariates) if covariates is not None else []
prompt = load_search_prompt(config.root_dir, config.local_search.prompt)
search_engine = get_local_search_engine(
full_response = ""
context_data = {}
get_context_data = True
# NOTE: when streaming, the first chunk of returned data is the complete context data.
# All subsequent chunks are the query response.
async for chunk in local_search_streaming(
config=config,
reports=read_indexer_reports(community_reports, communities, community_level),
text_units=read_indexer_text_units(text_units),
entities=entities_,
relationships=read_indexer_relationships(relationships),
covariates={"claims": covariates_},
description_embedding_store=description_embedding_store, # type: ignore
entities=entities,
communities=communities,
community_reports=community_reports,
text_units=text_units,
relationships=relationships,
covariates=covariates,
community_level=community_level,
response_type=response_type,
system_prompt=prompt,
)
result: SearchResult = await search_engine.asearch(query=query)
response = result.response
context_data = reformat_context_data(result.context_data) # type: ignore
return response, context_data
query=query,
):
if get_context_data:
context_data = chunk
get_context_data = False
else:
full_response += chunk
return full_response, context_data
@validate_call(config={"arbitrary_types_allowed": True})
@ -475,11 +454,11 @@ async def local_search_streaming(
response_type=response_type,
system_prompt=prompt,
)
search_result = search_engine.astream_search(query=query)
search_result = search_engine.stream_search(query=query)
# when streaming results, a context data object is returned as the first result
# NOTE: when streaming results, a context data object is returned as the first result
# and the query response in subsequent tokens
context_data = None
context_data = {}
get_context_data = True
async for stream_chunk in search_result:
if get_context_data:
@ -751,47 +730,28 @@ async def drift_search(
------
TODO: Document any exceptions to expect.
"""
vector_store_args = {}
for index, store in config.vector_store.items():
vector_store_args[index] = store.model_dump()
logger.info(f"Vector Store Args: {redact(vector_store_args)}") # type: ignore # noqa
description_embedding_store = get_embedding_store(
config_args=vector_store_args, # type: ignore
embedding_name=entity_description_embedding,
)
full_content_embedding_store = get_embedding_store(
config_args=vector_store_args, # type: ignore
embedding_name=community_full_content_embedding,
)
entities_ = read_indexer_entities(entities, communities, community_level)
reports = read_indexer_reports(community_reports, communities, community_level)
read_indexer_report_embeddings(reports, full_content_embedding_store)
prompt = load_search_prompt(config.root_dir, config.drift_search.prompt)
reduce_prompt = load_search_prompt(
config.root_dir, config.drift_search.reduce_prompt
)
search_engine = get_drift_search_engine(
config=config,
reports=reports,
text_units=read_indexer_text_units(text_units),
entities=entities_,
relationships=read_indexer_relationships(relationships),
description_embedding_store=description_embedding_store, # type: ignore
local_system_prompt=prompt,
reduce_system_prompt=reduce_prompt,
response_type=response_type,
)
result: SearchResult = await search_engine.asearch(query=query)
response = result.response
full_response = ""
context_data = {}
for key in result.context_data:
context_data[key] = reformat_context_data(result.context_data[key]) # type: ignore
return response, context_data
get_context_data = True
# NOTE: when streaming, the first chunk of returned data is the complete context data.
# All subsequent chunks are the query response.
async for chunk in drift_search_streaming(
config=config,
entities=entities,
communities=communities,
community_reports=community_reports,
text_units=text_units,
relationships=relationships,
community_level=community_level,
response_type=response_type,
query=query,
):
if get_context_data:
context_data = chunk
get_context_data = False
else:
full_response += chunk
return full_response, context_data
@validate_call(config={"arbitrary_types_allowed": True})
@ -860,12 +820,11 @@ async def drift_search_streaming(
reduce_system_prompt=reduce_prompt,
response_type=response_type,
)
search_result = search_engine.stream_search(query=query)
search_result = search_engine.astream_search(query=query)
# when streaming results, a context data object is returned as the first result
# NOTE: when streaming results, a context data object is returned as the first result
# and the query response in subsequent tokens
context_data = None
context_data = {}
get_context_data = True
async for stream_chunk in search_result:
if get_context_data:
@ -1105,29 +1064,22 @@ async def basic_search(
------
TODO: Document any exceptions to expect.
"""
vector_store_args = {}
for index, store in config.vector_store.items():
vector_store_args[index] = store.model_dump()
logger.info(f"Vector Store Args: {redact(vector_store_args)}") # type: ignore # noqa
description_embedding_store = get_embedding_store(
config_args=vector_store_args, # type: ignore
embedding_name=text_unit_text_embedding,
)
prompt = load_search_prompt(config.root_dir, config.basic_search.prompt)
search_engine = get_basic_search_engine(
full_response = ""
context_data = {}
get_context_data = True
# NOTE: when streaming, the first chunk of returned data is the complete context data.
# All subsequent chunks are the query response.
async for chunk in basic_search_streaming(
config=config,
text_units=read_indexer_text_units(text_units),
text_unit_embeddings=description_embedding_store,
system_prompt=prompt,
)
result: SearchResult = await search_engine.asearch(query=query)
response = result.response
context_data = reformat_context_data(result.context_data) # type: ignore
return response, context_data
text_units=text_units,
query=query,
):
if get_context_data:
context_data = chunk
get_context_data = False
else:
full_response += chunk
return full_response, context_data
@validate_call(config={"arbitrary_types_allowed": True})
@ -1155,8 +1107,6 @@ async def basic_search_streaming(
vector_store_args = {}
for index, store in config.vector_store.items():
vector_store_args[index] = store.model_dump()
else:
vector_store_args = None
logger.info(f"Vector Store Args: {redact(vector_store_args)}") # type: ignore # noqa
description_embedding_store = get_embedding_store(
@ -1172,12 +1122,11 @@ async def basic_search_streaming(
text_unit_embeddings=description_embedding_store,
system_prompt=prompt,
)
search_result = search_engine.stream_search(query=query)
search_result = search_engine.astream_search(query=query)
# when streaming results, a context data object is returned as the first result
# NOTE: when streaming results, a context data object is returned as the first result
# and the query response in subsequent tokens
context_data = None
context_data = {}
get_context_data = True
async for stream_chunk in search_result:
if get_context_data:

View File

@ -69,27 +69,22 @@ class BaseSearch(ABC, Generic[T]):
self.context_builder_params = context_builder_params or {}
@abstractmethod
def search(
self,
query: str,
conversation_history: ConversationHistory | None = None,
**kwargs,
) -> SearchResult:
"""Search for the given query."""
@abstractmethod
async def asearch(
async def search(
self,
query: str,
conversation_history: ConversationHistory | None = None,
**kwargs,
) -> SearchResult:
"""Search for the given query asynchronously."""
msg = "Subclasses must implement this method"
raise NotImplementedError(msg)
@abstractmethod
def astream_search(
def stream_search(
self,
query: str,
conversation_history: ConversationHistory | None = None,
) -> AsyncGenerator[str, None] | None:
) -> AsyncGenerator[Any, None]:
"""Stream search for the given query."""
msg = "Subclasses must implement this method"
raise NotImplementedError(msg)

View File

@ -55,7 +55,7 @@ class BasicSearch(BaseSearch[BasicContextBuilder]):
self.callbacks = callbacks
self.response_type = response_type
async def asearch(
async def search(
self,
query: str,
conversation_history: ConversationHistory | None = None,
@ -121,77 +121,11 @@ class BasicSearch(BaseSearch[BasicContextBuilder]):
output_tokens=0,
)
def search(
async def stream_search(
self,
query: str,
conversation_history: ConversationHistory | None = None,
**kwargs,
) -> SearchResult:
"""Build basic search context that fits a single context window and generate answer for the user question."""
start_time = time.time()
search_prompt = ""
llm_calls, prompt_tokens, output_tokens = {}, {}, {}
context_result = self.context_builder.build_context(
query=query,
conversation_history=conversation_history,
**kwargs,
**self.context_builder_params,
)
llm_calls["build_context"] = context_result.llm_calls
prompt_tokens["build_context"] = context_result.prompt_tokens
output_tokens["build_context"] = context_result.output_tokens
log.info("GENERATE ANSWER: %d. QUERY: %s", start_time, query)
try:
search_prompt = self.system_prompt.format(
context_data=context_result.context_chunks,
response_type=self.response_type,
)
search_messages = [
{"role": "system", "content": search_prompt},
{"role": "user", "content": query},
]
response = self.llm.generate(
messages=search_messages,
streaming=True,
callbacks=self.callbacks,
**self.llm_params,
)
llm_calls["response"] = 1
prompt_tokens["response"] = num_tokens(search_prompt, self.token_encoder)
output_tokens["response"] = num_tokens(response, self.token_encoder)
return SearchResult(
response=response,
context_data=context_result.context_records,
context_text=context_result.context_chunks,
completion_time=time.time() - start_time,
llm_calls=sum(llm_calls.values()),
prompt_tokens=sum(prompt_tokens.values()),
output_tokens=sum(output_tokens.values()),
llm_calls_categories=llm_calls,
prompt_tokens_categories=prompt_tokens,
output_tokens_categories=output_tokens,
)
except Exception:
log.exception("Exception in _map_response_single_batch")
return SearchResult(
response="",
context_data=context_result.context_records,
context_text=context_result.context_chunks,
completion_time=time.time() - start_time,
llm_calls=1,
prompt_tokens=num_tokens(search_prompt, self.token_encoder),
output_tokens=0,
)
async def astream_search(
self,
query: str,
conversation_history: ConversationHistory | None = None,
) -> AsyncGenerator:
) -> AsyncGenerator[Any, None]:
"""Build basic search context that fits a single context window and generate answer for the user query."""
start_time = time.time()

View File

@ -50,7 +50,7 @@ class DriftAction:
"""Check if the action is complete (i.e., an answer is available)."""
return self.answer is not None
async def asearch(self, search_engine: Any, global_query: str, scorer: Any = None):
async def search(self, search_engine: Any, global_query: str, scorer: Any = None):
"""
Execute an asynchronous search using the search engine, and update the action with the results.
@ -70,7 +70,7 @@ class DriftAction:
log.warning("Action already complete. Skipping search.")
return self
search_result = await search_engine.asearch(
search_result = await search_engine.search(
drift_query=global_query, query=self.query
)

View File

@ -154,7 +154,7 @@ class DRIFTPrimer:
return parsed_response, token_ct
async def asearch(
async def search(
self,
query: str,
top_k_reports: pd.DataFrame,

View File

@ -144,7 +144,7 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
error_msg = "Response must be a list of dictionaries."
raise ValueError(error_msg)
async def asearch_step(
async def _search_step(
self, global_query: str, search_engine: LocalSearch, actions: list[DriftAction]
) -> list[DriftAction]:
"""
@ -160,12 +160,12 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
list[DriftAction]: The results from executing the search actions asynchronously.
"""
tasks = [
action.asearch(search_engine=search_engine, global_query=global_query)
action.search(search_engine=search_engine, global_query=global_query)
for action in actions
]
return await tqdm_asyncio.gather(*tasks, leave=False)
async def asearch(
async def search(
self,
query: str,
conversation_history: Any = None,
@ -204,7 +204,7 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
prompt_tokens["build_context"] = token_ct["prompt_tokens"]
output_tokens["build_context"] = token_ct["prompt_tokens"]
primer_response = await self.primer.asearch(
primer_response = await self.primer.search(
query=query, top_k_reports=primer_context
)
llm_calls["primer"] = primer_response.llm_calls
@ -229,7 +229,7 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
len(actions) - self.context_builder.config.drift_k_followups
)
# Process actions
results = await self.asearch_step(
results = await self._search_step(
global_query=query, search_engine=self.local_search, actions=actions
)
@ -278,37 +278,17 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
output_tokens_categories=output_tokens,
)
def search(
self,
query: str,
conversation_history: Any = None,
**kwargs,
) -> SearchResult:
"""
Perform a synchronous DRIFT search (Not Implemented).
Args:
query (str): The query to search for.
conversation_history (Any, optional): The conversation history.
Raises
------
NotImplementedError: Synchronous DRIFT is not implemented.
"""
error_msg = "Synchronous DRIFT is not implemented."
raise NotImplementedError(error_msg)
async def astream_search(
async def stream_search(
self, query: str, conversation_history: ConversationHistory | None = None
) -> AsyncGenerator[str, None]:
"""
Perform a streaming DRIFT search (Not Implemented).
Perform a streaming DRIFT search asynchronously.
Args:
query (str): The query to search for.
conversation_history (ConversationHistory, optional): The conversation history.
"""
result = await self.asearch(
result = await self.search(
query=query, conversation_history=conversation_history, reduce=False
)

View File

@ -102,7 +102,7 @@ class GlobalSearch(BaseSearch[GlobalContextBuilder]):
self.semaphore = asyncio.Semaphore(concurrent_coroutines)
async def astream_search(
async def stream_search(
self,
query: str,
conversation_history: ConversationHistory | None = None,
@ -135,7 +135,7 @@ class GlobalSearch(BaseSearch[GlobalContextBuilder]):
):
yield response
async def asearch(
async def search(
self,
query: str,
conversation_history: ConversationHistory | None = None,
@ -204,15 +204,6 @@ class GlobalSearch(BaseSearch[GlobalContextBuilder]):
output_tokens_categories=output_tokens,
)
def search(
self,
query: str,
conversation_history: ConversationHistory | None = None,
**kwargs: Any,
) -> GlobalSearchResult:
"""Perform a global search synchronously."""
return asyncio.run(self.asearch(query, conversation_history))
async def _map_response_single_batch(
self,
context_data: str,
@ -235,7 +226,7 @@ class GlobalSearch(BaseSearch[GlobalContextBuilder]):
log.info("Map response: %s", search_response)
try:
# parse search response json
processed_response = self.parse_search_response(search_response)
processed_response = self._parse_search_response(search_response)
except ValueError:
log.warning(
"Warning: Error parsing search response json - skipping this batch"
@ -264,7 +255,7 @@ class GlobalSearch(BaseSearch[GlobalContextBuilder]):
output_tokens=0,
)
def parse_search_response(self, search_response: str) -> list[dict[str, Any]]:
def _parse_search_response(self, search_response: str) -> list[dict[str, Any]]:
"""Parse the search response json and return a list of key points.
Parameters

View File

@ -54,7 +54,7 @@ class LocalSearch(BaseSearch[LocalContextBuilder]):
self.callbacks = callbacks
self.response_type = response_type
async def asearch(
async def search(
self,
query: str,
conversation_history: ConversationHistory | None = None,
@ -128,7 +128,7 @@ class LocalSearch(BaseSearch[LocalContextBuilder]):
output_tokens=0,
)
async def astream_search(
async def stream_search(
self,
query: str,
conversation_history: ConversationHistory | None = None,
@ -158,69 +158,3 @@ class LocalSearch(BaseSearch[LocalContextBuilder]):
**self.llm_params,
):
yield response
def search(
self,
query: str,
conversation_history: ConversationHistory | None = None,
**kwargs,
) -> SearchResult:
"""Build local search context that fits a single context window and generate answer for the user question."""
start_time = time.time()
search_prompt = ""
llm_calls, prompt_tokens, output_tokens = {}, {}, {}
context_result = self.context_builder.build_context(
query=query,
conversation_history=conversation_history,
**kwargs,
**self.context_builder_params,
)
llm_calls["build_context"] = context_result.llm_calls
prompt_tokens["build_context"] = context_result.prompt_tokens
output_tokens["build_context"] = context_result.output_tokens
log.info("GENERATE ANSWER: %d. QUERY: %s", start_time, query)
try:
search_prompt = self.system_prompt.format(
context_data=context_result.context_chunks,
response_type=self.response_type,
)
search_messages = [
{"role": "system", "content": search_prompt},
{"role": "user", "content": query},
]
response = self.llm.generate(
messages=search_messages,
streaming=True,
callbacks=self.callbacks,
**self.llm_params,
)
llm_calls["response"] = 1
prompt_tokens["response"] = num_tokens(search_prompt, self.token_encoder)
output_tokens["response"] = num_tokens(response, self.token_encoder)
return SearchResult(
response=response,
context_data=context_result.context_records,
context_text=context_result.context_chunks,
completion_time=time.time() - start_time,
llm_calls=sum(llm_calls.values()),
prompt_tokens=sum(prompt_tokens.values()),
output_tokens=sum(output_tokens.values()),
llm_calls_categories=llm_calls,
prompt_tokens_categories=prompt_tokens,
output_tokens_categories=output_tokens,
)
except Exception:
log.exception("Exception in _map_response_single_batch")
return SearchResult(
response="",
context_data=context_result.context_records,
context_text=context_result.context_chunks,
completion_time=time.time() - start_time,
llm_calls=1,
prompt_tokens=num_tokens(search_prompt, self.token_encoder),
output_tokens=0,
)