Cleanup query api - remove code duplication (#1690)

* consolidate query api functions and remove code duplication

* refactor and remove more code duplication

* Add semversioner file

* fix basic search

* fix drift search and update base class function names

* update example notebooks
This commit is contained in:
Josh Bradley 2025-02-13 16:31:08 -05:00 committed by GitHub
parent fe461417b5
commit b8b949f3bb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 131 additions and 344 deletions
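
The diffs below boil down to one public-API change: each search engine now exposes a single async search method (formerly asearch) and an async generator stream_search (formerly astream_search), and the synchronous duplicates are deleted. A minimal caller-side sketch of the new names follows; it assumes `engine` is an already-constructed search engine (for example the LocalSearch instance built in the example notebooks), and `run_query` is an illustrative helper, not part of the codebase:

async def run_query(engine, question: str) -> None:
    # Non-streaming: `search` is the async entry point (formerly `asearch`).
    result = await engine.search(question)
    print(result.response)

    # Streaming: `stream_search` (formerly `astream_search`) yields the
    # complete context data first, then the response tokens.
    first = True
    async for chunk in engine.stream_search(question):
        if first:
            first = False  # first chunk is the context data; ignored here
        else:
            print(chunk, end="", flush=True)

# From synchronous code: asyncio.run(run_query(engine, "Who is agent Mercer?"))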

View File

@ -0,0 +1,4 @@
{
"type": "patch",
"description": "cleanup query code duplication."
}

View File

@ -327,7 +327,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 6, "execution_count": null,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
@ -388,7 +388,7 @@
} }
], ],
"source": [ "source": [
"resp = await search.asearch(\"Who is agent Mercer?\")" "resp = await search.search(\"Who is agent Mercer?\")"
] ]
}, },
{ {
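
In the notebooks the change is just the method name (plus execution_count being reset to null); the cell above now amounts to the following, where the print line is an illustrative follow-up rather than part of the diff:

resp = await search.search("Who is agent Mercer?")   # was: search.asearch(...)
print(resp.response)  # illustrative; the committed cell only runs the query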

View File

@ -392,7 +392,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 12, "execution_count": null,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
@ -420,7 +420,7 @@
} }
], ],
"source": [ "source": [
"result = await search_engine.asearch(\n", "result = await search_engine.search(\n",
" \"What is Cosmic Vocalization and who are involved in it?\"\n", " \"What is Cosmic Vocalization and who are involved in it?\"\n",
")\n", ")\n",
"\n", "\n",

View File

@ -394,7 +394,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 9, "execution_count": null,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
@ -420,7 +420,7 @@
} }
], ],
"source": [ "source": [
"result = await search_engine.asearch(\n", "result = await search_engine.search(\n",
" \"What is Cosmic Vocalization and who are involved in it?\"\n", " \"What is Cosmic Vocalization and who are involved in it?\"\n",
")\n", ")\n",
"\n", "\n",

View File

@ -963,7 +963,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 13, "execution_count": null,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
@ -991,13 +991,13 @@
} }
], ],
"source": [ "source": [
"result = await search_engine.asearch(\"Tell me about Agent Mercer\")\n", "result = await search_engine.search(\"Tell me about Agent Mercer\")\n",
"print(result.response)" "print(result.response)"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 14, "execution_count": null,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
@ -1026,7 +1026,7 @@
], ],
"source": [ "source": [
"question = \"Tell me about Dr. Jordan Hayes\"\n", "question = \"Tell me about Dr. Jordan Hayes\"\n",
"result = await search_engine.asearch(question)\n", "result = await search_engine.search(question)\n",
"print(result.response)" "print(result.response)"
] ]
}, },

View File

@ -384,7 +384,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"result = await search_engine.asearch(\"Tell me about Agent Mercer\")\n", "result = await search_engine.search(\"Tell me about Agent Mercer\")\n",
"print(result.response)" "print(result.response)"
] ]
}, },
@ -395,7 +395,7 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"question = \"Tell me about Dr. Jordan Hayes\"\n", "question = \"Tell me about Dr. Jordan Hayes\"\n",
"result = await search_engine.asearch(question)\n", "result = await search_engine.search(question)\n",
"print(result.response)" "print(result.response)"
] ]
}, },

View File

@ -18,7 +18,7 @@ Backwards compatibility is not guaranteed at this time.
""" """
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from typing import TYPE_CHECKING, Any from typing import Any
import pandas as pd import pandas as pd
from pydantic import validate_call from pydantic import validate_call
@ -53,9 +53,6 @@ from graphrag.utils.api import (
) )
from graphrag.utils.cli import redact from graphrag.utils.cli import redact
if TYPE_CHECKING:
from graphrag.query.structured_search.base import SearchResult
logger = PrintProgressLogger("") logger = PrintProgressLogger("")
@ -94,40 +91,27 @@ async def global_search(
------ ------
TODO: Document any exceptions to expect. TODO: Document any exceptions to expect.
""" """
communities_ = read_indexer_communities(communities, community_reports) full_response = ""
reports = read_indexer_reports( context_data = {}
community_reports, get_context_data = True
communities, # NOTE: when streaming, the first chunk of returned data is the complete context data.
# All subsequent chunks are the query response.
async for chunk in global_search_streaming(
config=config,
entities=entities,
communities=communities,
community_reports=community_reports,
community_level=community_level, community_level=community_level,
dynamic_community_selection=dynamic_community_selection, dynamic_community_selection=dynamic_community_selection,
)
entities_ = read_indexer_entities(
entities, communities, community_level=community_level
)
map_prompt = load_search_prompt(config.root_dir, config.global_search.map_prompt)
reduce_prompt = load_search_prompt(
config.root_dir, config.global_search.reduce_prompt
)
knowledge_prompt = load_search_prompt(
config.root_dir, config.global_search.knowledge_prompt
)
search_engine = get_global_search_engine(
config,
reports=reports,
entities=entities_,
communities=communities_,
response_type=response_type, response_type=response_type,
dynamic_community_selection=dynamic_community_selection, query=query,
map_system_prompt=map_prompt, ):
reduce_system_prompt=reduce_prompt, if get_context_data:
general_knowledge_inclusion_prompt=knowledge_prompt, context_data = chunk
) get_context_data = False
result: SearchResult = await search_engine.asearch(query=query) else:
response = result.response full_response += chunk
context_data = reformat_context_data(result.context_data) # type: ignore return full_response, context_data
return response, context_data
@validate_call(config={"arbitrary_types_allowed": True}) @validate_call(config={"arbitrary_types_allowed": True})
@ -193,11 +177,11 @@ async def global_search_streaming(
reduce_system_prompt=reduce_prompt, reduce_system_prompt=reduce_prompt,
general_knowledge_inclusion_prompt=knowledge_prompt, general_knowledge_inclusion_prompt=knowledge_prompt,
) )
search_result = search_engine.astream_search(query=query) search_result = search_engine.stream_search(query=query)
# when streaming results, a context data object is returned as the first result # NOTE: when streaming results, a context data object is returned as the first result
# and the query response in subsequent tokens # and the query response in subsequent tokens
context_data = None context_data = {}
get_context_data = True get_context_data = True
async for stream_chunk in search_result: async for stream_chunk in search_result:
if get_context_data: if get_context_data:
@ -385,34 +369,29 @@ async def local_search(
------ ------
TODO: Document any exceptions to expect. TODO: Document any exceptions to expect.
""" """
vector_store_args = {} full_response = ""
for index, store in config.vector_store.items(): context_data = {}
vector_store_args[index] = store.model_dump() get_context_data = True
logger.info(f"Vector Store Args: {redact(vector_store_args)}") # type: ignore # noqa # NOTE: when streaming, the first chunk of returned data is the complete context data.
# All subsequent chunks are the query response.
description_embedding_store = get_embedding_store( async for chunk in local_search_streaming(
config_args=vector_store_args, # type: ignore
embedding_name=entity_description_embedding,
)
entities_ = read_indexer_entities(entities, communities, community_level)
covariates_ = read_indexer_covariates(covariates) if covariates is not None else []
prompt = load_search_prompt(config.root_dir, config.local_search.prompt)
search_engine = get_local_search_engine(
config=config, config=config,
reports=read_indexer_reports(community_reports, communities, community_level), entities=entities,
text_units=read_indexer_text_units(text_units), communities=communities,
entities=entities_, community_reports=community_reports,
relationships=read_indexer_relationships(relationships), text_units=text_units,
covariates={"claims": covariates_}, relationships=relationships,
description_embedding_store=description_embedding_store, # type: ignore covariates=covariates,
community_level=community_level,
response_type=response_type, response_type=response_type,
system_prompt=prompt, query=query,
) ):
if get_context_data:
result: SearchResult = await search_engine.asearch(query=query) context_data = chunk
response = result.response get_context_data = False
context_data = reformat_context_data(result.context_data) # type: ignore else:
return response, context_data full_response += chunk
return full_response, context_data
@validate_call(config={"arbitrary_types_allowed": True}) @validate_call(config={"arbitrary_types_allowed": True})
@ -475,11 +454,11 @@ async def local_search_streaming(
response_type=response_type, response_type=response_type,
system_prompt=prompt, system_prompt=prompt,
) )
search_result = search_engine.astream_search(query=query) search_result = search_engine.stream_search(query=query)
# when streaming results, a context data object is returned as the first result # NOTE: when streaming results, a context data object is returned as the first result
# and the query response in subsequent tokens # and the query response in subsequent tokens
context_data = None context_data = {}
get_context_data = True get_context_data = True
async for stream_chunk in search_result: async for stream_chunk in search_result:
if get_context_data: if get_context_data:
@ -751,47 +730,28 @@ async def drift_search(
------ ------
TODO: Document any exceptions to expect. TODO: Document any exceptions to expect.
""" """
vector_store_args = {} full_response = ""
for index, store in config.vector_store.items():
vector_store_args[index] = store.model_dump()
logger.info(f"Vector Store Args: {redact(vector_store_args)}") # type: ignore # noqa
description_embedding_store = get_embedding_store(
config_args=vector_store_args, # type: ignore
embedding_name=entity_description_embedding,
)
full_content_embedding_store = get_embedding_store(
config_args=vector_store_args, # type: ignore
embedding_name=community_full_content_embedding,
)
entities_ = read_indexer_entities(entities, communities, community_level)
reports = read_indexer_reports(community_reports, communities, community_level)
read_indexer_report_embeddings(reports, full_content_embedding_store)
prompt = load_search_prompt(config.root_dir, config.drift_search.prompt)
reduce_prompt = load_search_prompt(
config.root_dir, config.drift_search.reduce_prompt
)
search_engine = get_drift_search_engine(
config=config,
reports=reports,
text_units=read_indexer_text_units(text_units),
entities=entities_,
relationships=read_indexer_relationships(relationships),
description_embedding_store=description_embedding_store, # type: ignore
local_system_prompt=prompt,
reduce_system_prompt=reduce_prompt,
response_type=response_type,
)
result: SearchResult = await search_engine.asearch(query=query)
response = result.response
context_data = {} context_data = {}
for key in result.context_data: get_context_data = True
context_data[key] = reformat_context_data(result.context_data[key]) # type: ignore # NOTE: when streaming, the first chunk of returned data is the complete context data.
# All subsequent chunks are the query response.
return response, context_data async for chunk in drift_search_streaming(
config=config,
entities=entities,
communities=communities,
community_reports=community_reports,
text_units=text_units,
relationships=relationships,
community_level=community_level,
response_type=response_type,
query=query,
):
if get_context_data:
context_data = chunk
get_context_data = False
else:
full_response += chunk
return full_response, context_data
@validate_call(config={"arbitrary_types_allowed": True}) @validate_call(config={"arbitrary_types_allowed": True})
@ -860,12 +820,11 @@ async def drift_search_streaming(
reduce_system_prompt=reduce_prompt, reduce_system_prompt=reduce_prompt,
response_type=response_type, response_type=response_type,
) )
search_result = search_engine.stream_search(query=query)
search_result = search_engine.astream_search(query=query) # NOTE: when streaming results, a context data object is returned as the first result
# when streaming results, a context data object is returned as the first result
# and the query response in subsequent tokens # and the query response in subsequent tokens
context_data = None context_data = {}
get_context_data = True get_context_data = True
async for stream_chunk in search_result: async for stream_chunk in search_result:
if get_context_data: if get_context_data:
@ -1105,29 +1064,22 @@ async def basic_search(
------ ------
TODO: Document any exceptions to expect. TODO: Document any exceptions to expect.
""" """
vector_store_args = {} full_response = ""
for index, store in config.vector_store.items(): context_data = {}
vector_store_args[index] = store.model_dump() get_context_data = True
logger.info(f"Vector Store Args: {redact(vector_store_args)}") # type: ignore # noqa # NOTE: when streaming, the first chunk of returned data is the complete context data.
# All subsequent chunks are the query response.
description_embedding_store = get_embedding_store( async for chunk in basic_search_streaming(
config_args=vector_store_args, # type: ignore
embedding_name=text_unit_text_embedding,
)
prompt = load_search_prompt(config.root_dir, config.basic_search.prompt)
search_engine = get_basic_search_engine(
config=config, config=config,
text_units=read_indexer_text_units(text_units), text_units=text_units,
text_unit_embeddings=description_embedding_store, query=query,
system_prompt=prompt, ):
) if get_context_data:
context_data = chunk
result: SearchResult = await search_engine.asearch(query=query) get_context_data = False
response = result.response else:
context_data = reformat_context_data(result.context_data) # type: ignore full_response += chunk
return response, context_data return full_response, context_data
@validate_call(config={"arbitrary_types_allowed": True}) @validate_call(config={"arbitrary_types_allowed": True})
@ -1155,8 +1107,6 @@ async def basic_search_streaming(
vector_store_args = {} vector_store_args = {}
for index, store in config.vector_store.items(): for index, store in config.vector_store.items():
vector_store_args[index] = store.model_dump() vector_store_args[index] = store.model_dump()
else:
vector_store_args = None
logger.info(f"Vector Store Args: {redact(vector_store_args)}") # type: ignore # noqa logger.info(f"Vector Store Args: {redact(vector_store_args)}") # type: ignore # noqa
description_embedding_store = get_embedding_store( description_embedding_store = get_embedding_store(
@ -1172,12 +1122,11 @@ async def basic_search_streaming(
text_unit_embeddings=description_embedding_store, text_unit_embeddings=description_embedding_store,
system_prompt=prompt, system_prompt=prompt,
) )
search_result = search_engine.stream_search(query=query)
search_result = search_engine.astream_search(query=query) # NOTE: when streaming results, a context data object is returned as the first result
# when streaming results, a context data object is returned as the first result
# and the query response in subsequent tokens # and the query response in subsequent tokens
context_data = None context_data = {}
get_context_data = True get_context_data = True
async for stream_chunk in search_result: async for stream_chunk in search_result:
if get_context_data: if get_context_data:
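
The api.py changes above all follow the same shape: global_search, local_search, drift_search, and basic_search no longer build their own engines, load prompts, and call asearch; instead each one iterates its *_streaming counterpart, takes the first chunk as the context data, and concatenates the remaining chunks into the response. A condensed sketch of that shared pattern, where `_collect_stream` is a hypothetical helper used only to show the shape (the real functions inline it with their full parameter lists):

async def _collect_stream(streaming_fn, **kwargs) -> tuple[str, dict]:
    full_response = ""
    context_data: dict = {}
    get_context_data = True
    # NOTE: when streaming, the first chunk of returned data is the complete
    # context data. All subsequent chunks are the query response.
    async for chunk in streaming_fn(**kwargs):
        if get_context_data:
            context_data = chunk
            get_context_data = False
        else:
            full_response += chunk
    return full_response, context_data

# e.g. inside local_search:
#   return await _collect_stream(local_search_streaming, config=config,
#                                entities=entities, ..., query=query)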

View File

@ -69,27 +69,22 @@ class BaseSearch(ABC, Generic[T]):
self.context_builder_params = context_builder_params or {} self.context_builder_params = context_builder_params or {}
@abstractmethod @abstractmethod
def search( async def search(
self,
query: str,
conversation_history: ConversationHistory | None = None,
**kwargs,
) -> SearchResult:
"""Search for the given query."""
@abstractmethod
async def asearch(
self, self,
query: str, query: str,
conversation_history: ConversationHistory | None = None, conversation_history: ConversationHistory | None = None,
**kwargs, **kwargs,
) -> SearchResult: ) -> SearchResult:
"""Search for the given query asynchronously.""" """Search for the given query asynchronously."""
msg = "Subclasses must implement this method"
raise NotImplementedError(msg)
@abstractmethod @abstractmethod
def astream_search( def stream_search(
self, self,
query: str, query: str,
conversation_history: ConversationHistory | None = None, conversation_history: ConversationHistory | None = None,
) -> AsyncGenerator[str, None] | None: ) -> AsyncGenerator[Any, None]:
"""Stream search for the given query.""" """Stream search for the given query."""
msg = "Subclasses must implement this method"
raise NotImplementedError(msg)
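
Reassembled from the interleaved diff above, the consolidated base class now declares just two abstract members, both asynchronous. This is a simplified sketch; the real signatures use the ConversationHistory and SearchResult types imported in base.py:

from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator
from typing import Any, Generic, TypeVar

T = TypeVar("T")

class BaseSearch(ABC, Generic[T]):
    """Base class for search implementations (simplified sketch)."""

    @abstractmethod
    async def search(self, query: str, conversation_history=None, **kwargs) -> Any:
        """Search for the given query asynchronously."""
        msg = "Subclasses must implement this method"
        raise NotImplementedError(msg)

    @abstractmethod
    def stream_search(
        self, query: str, conversation_history=None
    ) -> AsyncGenerator[Any, None]:
        """Stream search for the given query."""
        msg = "Subclasses must implement this method"
        raise NotImplementedError(msg)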

View File

@ -55,7 +55,7 @@ class BasicSearch(BaseSearch[BasicContextBuilder]):
self.callbacks = callbacks self.callbacks = callbacks
self.response_type = response_type self.response_type = response_type
async def asearch( async def search(
self, self,
query: str, query: str,
conversation_history: ConversationHistory | None = None, conversation_history: ConversationHistory | None = None,
@ -121,77 +121,11 @@ class BasicSearch(BaseSearch[BasicContextBuilder]):
output_tokens=0, output_tokens=0,
) )
def search( async def stream_search(
self, self,
query: str, query: str,
conversation_history: ConversationHistory | None = None, conversation_history: ConversationHistory | None = None,
**kwargs, ) -> AsyncGenerator[Any, None]:
) -> SearchResult:
"""Build basic search context that fits a single context window and generate answer for the user question."""
start_time = time.time()
search_prompt = ""
llm_calls, prompt_tokens, output_tokens = {}, {}, {}
context_result = self.context_builder.build_context(
query=query,
conversation_history=conversation_history,
**kwargs,
**self.context_builder_params,
)
llm_calls["build_context"] = context_result.llm_calls
prompt_tokens["build_context"] = context_result.prompt_tokens
output_tokens["build_context"] = context_result.output_tokens
log.info("GENERATE ANSWER: %d. QUERY: %s", start_time, query)
try:
search_prompt = self.system_prompt.format(
context_data=context_result.context_chunks,
response_type=self.response_type,
)
search_messages = [
{"role": "system", "content": search_prompt},
{"role": "user", "content": query},
]
response = self.llm.generate(
messages=search_messages,
streaming=True,
callbacks=self.callbacks,
**self.llm_params,
)
llm_calls["response"] = 1
prompt_tokens["response"] = num_tokens(search_prompt, self.token_encoder)
output_tokens["response"] = num_tokens(response, self.token_encoder)
return SearchResult(
response=response,
context_data=context_result.context_records,
context_text=context_result.context_chunks,
completion_time=time.time() - start_time,
llm_calls=sum(llm_calls.values()),
prompt_tokens=sum(prompt_tokens.values()),
output_tokens=sum(output_tokens.values()),
llm_calls_categories=llm_calls,
prompt_tokens_categories=prompt_tokens,
output_tokens_categories=output_tokens,
)
except Exception:
log.exception("Exception in _map_response_single_batch")
return SearchResult(
response="",
context_data=context_result.context_records,
context_text=context_result.context_chunks,
completion_time=time.time() - start_time,
llm_calls=1,
prompt_tokens=num_tokens(search_prompt, self.token_encoder),
output_tokens=0,
)
async def astream_search(
self,
query: str,
conversation_history: ConversationHistory | None = None,
) -> AsyncGenerator:
"""Build basic search context that fits a single context window and generate answer for the user query.""" """Build basic search context that fits a single context window and generate answer for the user query."""
start_time = time.time() start_time = time.time()
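
BasicSearch (and LocalSearch further down) keeps a single async search, and the former astream_search survives as stream_search; the synchronous duplicate that rebuilt the same prompt and token bookkeeping is deleted. Roughly, the surviving streaming method has this shape, sketched under assumptions: `_stream_llm` stands in for whatever streaming call the engine makes on its model, and details such as timing and token accounting are omitted:

async def stream_search(self, query, conversation_history=None):
    context_result = self.context_builder.build_context(
        query=query,
        conversation_history=conversation_history,
        **self.context_builder_params,
    )
    # First yield: the complete context data, so API callers can capture it.
    yield context_result.context_records

    search_prompt = self.system_prompt.format(
        context_data=context_result.context_chunks,
        response_type=self.response_type,
    )
    messages = [
        {"role": "system", "content": search_prompt},
        {"role": "user", "content": query},
    ]
    # Subsequent yields: response tokens from the model.
    async for token in self._stream_llm(messages):  # hypothetical helper
        yield token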

View File

@ -50,7 +50,7 @@ class DriftAction:
"""Check if the action is complete (i.e., an answer is available).""" """Check if the action is complete (i.e., an answer is available)."""
return self.answer is not None return self.answer is not None
async def asearch(self, search_engine: Any, global_query: str, scorer: Any = None): async def search(self, search_engine: Any, global_query: str, scorer: Any = None):
""" """
Execute an asynchronous search using the search engine, and update the action with the results. Execute an asynchronous search using the search engine, and update the action with the results.
@ -70,7 +70,7 @@ class DriftAction:
log.warning("Action already complete. Skipping search.") log.warning("Action already complete. Skipping search.")
return self return self
search_result = await search_engine.asearch( search_result = await search_engine.search(
drift_query=global_query, query=self.query drift_query=global_query, query=self.query
) )

View File

@ -154,7 +154,7 @@ class DRIFTPrimer:
return parsed_response, token_ct return parsed_response, token_ct
async def asearch( async def search(
self, self,
query: str, query: str,
top_k_reports: pd.DataFrame, top_k_reports: pd.DataFrame,

View File

@ -144,7 +144,7 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
error_msg = "Response must be a list of dictionaries." error_msg = "Response must be a list of dictionaries."
raise ValueError(error_msg) raise ValueError(error_msg)
async def asearch_step( async def _search_step(
self, global_query: str, search_engine: LocalSearch, actions: list[DriftAction] self, global_query: str, search_engine: LocalSearch, actions: list[DriftAction]
) -> list[DriftAction]: ) -> list[DriftAction]:
""" """
@ -160,12 +160,12 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
list[DriftAction]: The results from executing the search actions asynchronously. list[DriftAction]: The results from executing the search actions asynchronously.
""" """
tasks = [ tasks = [
action.asearch(search_engine=search_engine, global_query=global_query) action.search(search_engine=search_engine, global_query=global_query)
for action in actions for action in actions
] ]
return await tqdm_asyncio.gather(*tasks, leave=False) return await tqdm_asyncio.gather(*tasks, leave=False)
async def asearch( async def search(
self, self,
query: str, query: str,
conversation_history: Any = None, conversation_history: Any = None,
@ -204,7 +204,7 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
prompt_tokens["build_context"] = token_ct["prompt_tokens"] prompt_tokens["build_context"] = token_ct["prompt_tokens"]
output_tokens["build_context"] = token_ct["prompt_tokens"] output_tokens["build_context"] = token_ct["prompt_tokens"]
primer_response = await self.primer.asearch( primer_response = await self.primer.search(
query=query, top_k_reports=primer_context query=query, top_k_reports=primer_context
) )
llm_calls["primer"] = primer_response.llm_calls llm_calls["primer"] = primer_response.llm_calls
@ -229,7 +229,7 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
len(actions) - self.context_builder.config.drift_k_followups len(actions) - self.context_builder.config.drift_k_followups
) )
# Process actions # Process actions
results = await self.asearch_step( results = await self._search_step(
global_query=query, search_engine=self.local_search, actions=actions global_query=query, search_engine=self.local_search, actions=actions
) )
@ -278,37 +278,17 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
output_tokens_categories=output_tokens, output_tokens_categories=output_tokens,
) )
def search( async def stream_search(
self,
query: str,
conversation_history: Any = None,
**kwargs,
) -> SearchResult:
"""
Perform a synchronous DRIFT search (Not Implemented).
Args:
query (str): The query to search for.
conversation_history (Any, optional): The conversation history.
Raises
------
NotImplementedError: Synchronous DRIFT is not implemented.
"""
error_msg = "Synchronous DRIFT is not implemented."
raise NotImplementedError(error_msg)
async def astream_search(
self, query: str, conversation_history: ConversationHistory | None = None self, query: str, conversation_history: ConversationHistory | None = None
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
""" """
Perform a streaming DRIFT search (Not Implemented). Perform a streaming DRIFT search asynchronously.
Args: Args:
query (str): The query to search for. query (str): The query to search for.
conversation_history (ConversationHistory, optional): The conversation history. conversation_history (ConversationHistory, optional): The conversation history.
""" """
result = await self.asearch( result = await self.search(
query=query, conversation_history=conversation_history, reduce=False query=query, conversation_history=conversation_history, reduce=False
) )
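
Two DRIFT-specific renames round this out: asearch_step becomes the private _search_step, and the per-action asearch becomes search. The fan-out itself is unchanged; restated from the diff (tqdm_asyncio is the asyncio-aware gather from the tqdm package):

from tqdm.asyncio import tqdm_asyncio

async def _search_step(self, global_query, search_engine, actions):
    # Run every follow-up action's local search concurrently,
    # with a progress bar over the pending tasks.
    tasks = [
        action.search(search_engine=search_engine, global_query=global_query)
        for action in actions
    ]
    return await tqdm_asyncio.gather(*tasks, leave=False)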

View File

@ -102,7 +102,7 @@ class GlobalSearch(BaseSearch[GlobalContextBuilder]):
self.semaphore = asyncio.Semaphore(concurrent_coroutines) self.semaphore = asyncio.Semaphore(concurrent_coroutines)
async def astream_search( async def stream_search(
self, self,
query: str, query: str,
conversation_history: ConversationHistory | None = None, conversation_history: ConversationHistory | None = None,
@ -135,7 +135,7 @@ class GlobalSearch(BaseSearch[GlobalContextBuilder]):
): ):
yield response yield response
async def asearch( async def search(
self, self,
query: str, query: str,
conversation_history: ConversationHistory | None = None, conversation_history: ConversationHistory | None = None,
@ -204,15 +204,6 @@ class GlobalSearch(BaseSearch[GlobalContextBuilder]):
output_tokens_categories=output_tokens, output_tokens_categories=output_tokens,
) )
def search(
self,
query: str,
conversation_history: ConversationHistory | None = None,
**kwargs: Any,
) -> GlobalSearchResult:
"""Perform a global search synchronously."""
return asyncio.run(self.asearch(query, conversation_history))
async def _map_response_single_batch( async def _map_response_single_batch(
self, self,
context_data: str, context_data: str,
@ -235,7 +226,7 @@ class GlobalSearch(BaseSearch[GlobalContextBuilder]):
log.info("Map response: %s", search_response) log.info("Map response: %s", search_response)
try: try:
# parse search response json # parse search response json
processed_response = self.parse_search_response(search_response) processed_response = self._parse_search_response(search_response)
except ValueError: except ValueError:
log.warning( log.warning(
"Warning: Error parsing search response json - skipping this batch" "Warning: Error parsing search response json - skipping this batch"
@ -264,7 +255,7 @@ class GlobalSearch(BaseSearch[GlobalContextBuilder]):
output_tokens=0, output_tokens=0,
) )
def parse_search_response(self, search_response: str) -> list[dict[str, Any]]: def _parse_search_response(self, search_response: str) -> list[dict[str, Any]]:
"""Parse the search response json and return a list of key points. """Parse the search response json and return a list of key points.
Parameters Parameters
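
GlobalSearch loses its synchronous search wrapper (which simply called asyncio.run on the async method), and parse_search_response becomes the private _parse_search_response. Callers that still need a blocking call can do at the call site exactly what the deleted wrapper did, assuming no event loop is already running (a plain script rather than a notebook):

import asyncio

# `search_engine` stands in for a constructed GlobalSearch (setup not shown).
result = asyncio.run(search_engine.search("Tell me about Agent Mercer"))
print(result.response)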

View File

@ -54,7 +54,7 @@ class LocalSearch(BaseSearch[LocalContextBuilder]):
self.callbacks = callbacks self.callbacks = callbacks
self.response_type = response_type self.response_type = response_type
async def asearch( async def search(
self, self,
query: str, query: str,
conversation_history: ConversationHistory | None = None, conversation_history: ConversationHistory | None = None,
@ -128,7 +128,7 @@ class LocalSearch(BaseSearch[LocalContextBuilder]):
output_tokens=0, output_tokens=0,
) )
async def astream_search( async def stream_search(
self, self,
query: str, query: str,
conversation_history: ConversationHistory | None = None, conversation_history: ConversationHistory | None = None,
@ -158,69 +158,3 @@ class LocalSearch(BaseSearch[LocalContextBuilder]):
**self.llm_params, **self.llm_params,
): ):
yield response yield response
def search(
self,
query: str,
conversation_history: ConversationHistory | None = None,
**kwargs,
) -> SearchResult:
"""Build local search context that fits a single context window and generate answer for the user question."""
start_time = time.time()
search_prompt = ""
llm_calls, prompt_tokens, output_tokens = {}, {}, {}
context_result = self.context_builder.build_context(
query=query,
conversation_history=conversation_history,
**kwargs,
**self.context_builder_params,
)
llm_calls["build_context"] = context_result.llm_calls
prompt_tokens["build_context"] = context_result.prompt_tokens
output_tokens["build_context"] = context_result.output_tokens
log.info("GENERATE ANSWER: %d. QUERY: %s", start_time, query)
try:
search_prompt = self.system_prompt.format(
context_data=context_result.context_chunks,
response_type=self.response_type,
)
search_messages = [
{"role": "system", "content": search_prompt},
{"role": "user", "content": query},
]
response = self.llm.generate(
messages=search_messages,
streaming=True,
callbacks=self.callbacks,
**self.llm_params,
)
llm_calls["response"] = 1
prompt_tokens["response"] = num_tokens(search_prompt, self.token_encoder)
output_tokens["response"] = num_tokens(response, self.token_encoder)
return SearchResult(
response=response,
context_data=context_result.context_records,
context_text=context_result.context_chunks,
completion_time=time.time() - start_time,
llm_calls=sum(llm_calls.values()),
prompt_tokens=sum(prompt_tokens.values()),
output_tokens=sum(output_tokens.values()),
llm_calls_categories=llm_calls,
prompt_tokens_categories=prompt_tokens,
output_tokens_categories=output_tokens,
)
except Exception:
log.exception("Exception in _map_response_single_batch")
return SearchResult(
response="",
context_data=context_result.context_records,
context_text=context_result.context_chunks,
completion_time=time.time() - start_time,
llm_calls=1,
prompt_tokens=num_tokens(search_prompt, self.token_encoder),
output_tokens=0,
)
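
LocalSearch gets the same treatment as BasicSearch: the duplicated synchronous body above is deleted and only the async search / stream_search pair remains. One point worth noting is that the renamed search still accepts a conversation history; a hedged sketch follows, where the ConversationHistory.from_list helper, its import path, and the expected dict keys are recalled from graphrag's query package and should be treated as assumptions:

from graphrag.query.context_builder.conversation_history import ConversationHistory

history = ConversationHistory.from_list([
    {"role": "user", "content": "Tell me about Agent Mercer"},
    {"role": "assistant", "content": "..."},
])
# Inside a notebook (top-level await) or any coroutine:
result = await search_engine.search(
    "Tell me about Dr. Jordan Hayes",
    conversation_history=history,
)
print(result.response)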