Cleanup query api - remove code duplication (#1690)

* consolidate query api functions and remove code duplication
* refactor and remove more code duplication
* Add semversioner file
* fix basic search
* fix drift search and update base class function names
* update example notebooks

parent fe461417b5
commit b8b949f3bb
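The heart of the change: `global_search`, `local_search`, `drift_search`, and `basic_search` previously duplicated the engine setup already performed by their `*_streaming` counterparts. They are now thin wrappers that consume the streaming generator, whose first chunk is the complete context data and whose remaining chunks are response text. A minimal sketch of that shared pattern (illustrative only; `example_search_streaming` and `example_search` are hypothetical stand-ins, not functions from this repo):

```python
from collections.abc import AsyncGenerator


async def example_search_streaming(query: str) -> AsyncGenerator:
    """Hypothetical streaming search: context data first, then response chunks."""
    yield {"reports": ["..."]}  # first chunk: the complete context data
    for token in ("Hello", ", ", "world"):
        yield token  # subsequent chunks: pieces of the response text


async def example_search(query: str) -> tuple[str, dict]:
    """Non-streaming wrapper: the shape global_search and friends now share."""
    full_response = ""
    context_data: dict = {}
    get_context_data = True
    async for chunk in example_search_streaming(query):
        if get_context_data:
            context_data = chunk  # first chunk is the context data
            get_context_data = False
        else:
            full_response += chunk  # everything after it is response text
    return full_response, context_data
```

Keeping a single code path per search mode means fixes to context building or prompt loading land in one place instead of two.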
.semversioner release note (new file):

@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "cleanup query code duplication."
+}
Example notebooks (rename asearch to search; execution counts reset to null):

@@ -327,7 +327,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 6,
+    "execution_count": null,
     "metadata": {},
     "outputs": [
      {
@@ -388,7 +388,7 @@
      }
     ],
     "source": [
-     "resp = await search.asearch(\"Who is agent Mercer?\")"
+     "resp = await search.search(\"Who is agent Mercer?\")"
     ]
    },
    {
@@ -392,7 +392,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 12,
+    "execution_count": null,
     "metadata": {},
     "outputs": [
      {
@@ -420,7 +420,7 @@
      }
     ],
     "source": [
-     "result = await search_engine.asearch(\n",
+     "result = await search_engine.search(\n",
      "    \"What is Cosmic Vocalization and who are involved in it?\"\n",
      ")\n",
      "\n",
@@ -394,7 +394,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 9,
+    "execution_count": null,
     "metadata": {},
     "outputs": [
      {
@@ -420,7 +420,7 @@
      }
     ],
     "source": [
-     "result = await search_engine.asearch(\n",
+     "result = await search_engine.search(\n",
      "    \"What is Cosmic Vocalization and who are involved in it?\"\n",
      ")\n",
      "\n",
@@ -963,7 +963,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 13,
+    "execution_count": null,
     "metadata": {},
     "outputs": [
      {
@@ -991,13 +991,13 @@
      }
     ],
     "source": [
-     "result = await search_engine.asearch(\"Tell me about Agent Mercer\")\n",
+     "result = await search_engine.search(\"Tell me about Agent Mercer\")\n",
      "print(result.response)"
     ]
    },
    {
     "cell_type": "code",
-    "execution_count": 14,
+    "execution_count": null,
     "metadata": {},
     "outputs": [
      {
@@ -1026,7 +1026,7 @@
     ],
     "source": [
      "question = \"Tell me about Dr. Jordan Hayes\"\n",
-     "result = await search_engine.asearch(question)\n",
+     "result = await search_engine.search(question)\n",
      "print(result.response)"
     ]
    },
@@ -384,7 +384,7 @@
     "metadata": {},
     "outputs": [],
     "source": [
-     "result = await search_engine.asearch(\"Tell me about Agent Mercer\")\n",
+     "result = await search_engine.search(\"Tell me about Agent Mercer\")\n",
      "print(result.response)"
     ]
    },
@@ -395,7 +395,7 @@
     "outputs": [],
     "source": [
      "question = \"Tell me about Dr. Jordan Hayes\"\n",
-     "result = await search_engine.asearch(question)\n",
+     "result = await search_engine.search(question)\n",
      "print(result.response)"
     ]
    },
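For notebook users, the visible change is just the method rename: `asearch` becomes `search` (and `astream_search` becomes `stream_search`). Both remain coroutines, so existing `await` calls only need the new name:

```python
# before this commit
result = await search_engine.asearch("Tell me about Agent Mercer")

# after this commit
result = await search_engine.search("Tell me about Agent Mercer")
print(result.response)
```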
graphrag/api/query.py:

@@ -18,7 +18,7 @@ Backwards compatibility is not guaranteed at this time.
 """
 
 from collections.abc import AsyncGenerator
-from typing import TYPE_CHECKING, Any
+from typing import Any
 
 import pandas as pd
 from pydantic import validate_call
@@ -53,9 +53,6 @@ from graphrag.utils.api import (
 )
 from graphrag.utils.cli import redact
 
-if TYPE_CHECKING:
-    from graphrag.query.structured_search.base import SearchResult
-
 logger = PrintProgressLogger("")
 
 
@@ -94,40 +91,27 @@ async def global_search(
     ------
     TODO: Document any exceptions to expect.
     """
-    communities_ = read_indexer_communities(communities, community_reports)
-    reports = read_indexer_reports(
-        community_reports,
-        communities,
+    full_response = ""
+    context_data = {}
+    get_context_data = True
+    # NOTE: when streaming, the first chunk of returned data is the complete context data.
+    # All subsequent chunks are the query response.
+    async for chunk in global_search_streaming(
+        config=config,
+        entities=entities,
+        communities=communities,
+        community_reports=community_reports,
         community_level=community_level,
         dynamic_community_selection=dynamic_community_selection,
-    )
-    entities_ = read_indexer_entities(
-        entities, communities, community_level=community_level
-    )
-
-    map_prompt = load_search_prompt(config.root_dir, config.global_search.map_prompt)
-    reduce_prompt = load_search_prompt(
-        config.root_dir, config.global_search.reduce_prompt
-    )
-    knowledge_prompt = load_search_prompt(
-        config.root_dir, config.global_search.knowledge_prompt
-    )
-
-    search_engine = get_global_search_engine(
-        config,
-        reports=reports,
-        entities=entities_,
-        communities=communities_,
         response_type=response_type,
-        dynamic_community_selection=dynamic_community_selection,
-        map_system_prompt=map_prompt,
-        reduce_system_prompt=reduce_prompt,
-        general_knowledge_inclusion_prompt=knowledge_prompt,
-    )
-    result: SearchResult = await search_engine.asearch(query=query)
-    response = result.response
-    context_data = reformat_context_data(result.context_data)  # type: ignore
-    return response, context_data
+        query=query,
+    ):
+        if get_context_data:
+            context_data = chunk
+            get_context_data = False
+        else:
+            full_response += chunk
+    return full_response, context_data
 
 
 @validate_call(config={"arbitrary_types_allowed": True})
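Since the non-streaming function is now a wrapper, callers that want incremental output can go straight to `global_search_streaming` (whose own diff follows). A hedged usage sketch; it assumes `config`, `entities`, `communities`, and `community_reports` were already loaded from index output, and the literal argument values are placeholders:

```python
async def run_streaming_query() -> None:
    # first chunk is context data; later chunks are response tokens
    context_data = None
    async for chunk in global_search_streaming(
        config=config,
        entities=entities,
        communities=communities,
        community_reports=community_reports,
        community_level=2,
        dynamic_community_selection=False,
        response_type="Multiple Paragraphs",
        query="Who is agent Mercer?",
    ):
        if context_data is None:
            context_data = chunk
        else:
            print(chunk, end="", flush=True)
```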
@@ -193,11 +177,11 @@ async def global_search_streaming(
         reduce_system_prompt=reduce_prompt,
         general_knowledge_inclusion_prompt=knowledge_prompt,
     )
-    search_result = search_engine.astream_search(query=query)
+    search_result = search_engine.stream_search(query=query)
 
-    # when streaming results, a context data object is returned as the first result
+    # NOTE: when streaming results, a context data object is returned as the first result
     # and the query response in subsequent tokens
-    context_data = None
+    context_data = {}
     get_context_data = True
     async for stream_chunk in search_result:
         if get_context_data:
@@ -385,34 +369,29 @@ async def local_search(
     ------
     TODO: Document any exceptions to expect.
     """
-    vector_store_args = {}
-    for index, store in config.vector_store.items():
-        vector_store_args[index] = store.model_dump()
-    logger.info(f"Vector Store Args: {redact(vector_store_args)}")  # type: ignore # noqa
-
-    description_embedding_store = get_embedding_store(
-        config_args=vector_store_args,  # type: ignore
-        embedding_name=entity_description_embedding,
-    )
-    entities_ = read_indexer_entities(entities, communities, community_level)
-    covariates_ = read_indexer_covariates(covariates) if covariates is not None else []
-    prompt = load_search_prompt(config.root_dir, config.local_search.prompt)
-    search_engine = get_local_search_engine(
+    full_response = ""
+    context_data = {}
+    get_context_data = True
+    # NOTE: when streaming, the first chunk of returned data is the complete context data.
+    # All subsequent chunks are the query response.
+    async for chunk in local_search_streaming(
         config=config,
-        reports=read_indexer_reports(community_reports, communities, community_level),
-        text_units=read_indexer_text_units(text_units),
-        entities=entities_,
-        relationships=read_indexer_relationships(relationships),
-        covariates={"claims": covariates_},
-        description_embedding_store=description_embedding_store,  # type: ignore
+        entities=entities,
+        communities=communities,
+        community_reports=community_reports,
+        text_units=text_units,
+        relationships=relationships,
+        covariates=covariates,
+        community_level=community_level,
         response_type=response_type,
-        system_prompt=prompt,
-    )
-
-    result: SearchResult = await search_engine.asearch(query=query)
-    response = result.response
-    context_data = reformat_context_data(result.context_data)  # type: ignore
-    return response, context_data
+        query=query,
+    ):
+        if get_context_data:
+            context_data = chunk
+            get_context_data = False
+        else:
+            full_response += chunk
+    return full_response, context_data
 
 
 @validate_call(config={"arbitrary_types_allowed": True})
@@ -475,11 +454,11 @@ async def local_search_streaming(
         response_type=response_type,
         system_prompt=prompt,
     )
-    search_result = search_engine.astream_search(query=query)
+    search_result = search_engine.stream_search(query=query)
 
-    # when streaming results, a context data object is returned as the first result
+    # NOTE: when streaming results, a context data object is returned as the first result
     # and the query response in subsequent tokens
-    context_data = None
+    context_data = {}
     get_context_data = True
     async for stream_chunk in search_result:
         if get_context_data:
@@ -751,47 +730,28 @@ async def drift_search(
     ------
     TODO: Document any exceptions to expect.
     """
-    vector_store_args = {}
-    for index, store in config.vector_store.items():
-        vector_store_args[index] = store.model_dump()
-    logger.info(f"Vector Store Args: {redact(vector_store_args)}")  # type: ignore # noqa
-
-    description_embedding_store = get_embedding_store(
-        config_args=vector_store_args,  # type: ignore
-        embedding_name=entity_description_embedding,
-    )
-
-    full_content_embedding_store = get_embedding_store(
-        config_args=vector_store_args,  # type: ignore
-        embedding_name=community_full_content_embedding,
-    )
-
-    entities_ = read_indexer_entities(entities, communities, community_level)
-    reports = read_indexer_reports(community_reports, communities, community_level)
-    read_indexer_report_embeddings(reports, full_content_embedding_store)
-    prompt = load_search_prompt(config.root_dir, config.drift_search.prompt)
-    reduce_prompt = load_search_prompt(
-        config.root_dir, config.drift_search.reduce_prompt
-    )
-    search_engine = get_drift_search_engine(
-        config=config,
-        reports=reports,
-        text_units=read_indexer_text_units(text_units),
-        entities=entities_,
-        relationships=read_indexer_relationships(relationships),
-        description_embedding_store=description_embedding_store,  # type: ignore
-        local_system_prompt=prompt,
-        reduce_system_prompt=reduce_prompt,
-        response_type=response_type,
-    )
-
-    result: SearchResult = await search_engine.asearch(query=query)
-    response = result.response
+    full_response = ""
     context_data = {}
-    for key in result.context_data:
-        context_data[key] = reformat_context_data(result.context_data[key])  # type: ignore
-
-    return response, context_data
+    get_context_data = True
+    # NOTE: when streaming, the first chunk of returned data is the complete context data.
+    # All subsequent chunks are the query response.
+    async for chunk in drift_search_streaming(
+        config=config,
+        entities=entities,
+        communities=communities,
+        community_reports=community_reports,
+        text_units=text_units,
+        relationships=relationships,
+        community_level=community_level,
+        response_type=response_type,
+        query=query,
+    ):
+        if get_context_data:
+            context_data = chunk
+            get_context_data = False
+        else:
+            full_response += chunk
+    return full_response, context_data
 
 
 @validate_call(config={"arbitrary_types_allowed": True})
@@ -860,12 +820,11 @@ async def drift_search_streaming(
         reduce_system_prompt=reduce_prompt,
         response_type=response_type,
     )
+    search_result = search_engine.stream_search(query=query)
 
-    search_result = search_engine.astream_search(query=query)
-
-    # when streaming results, a context data object is returned as the first result
+    # NOTE: when streaming results, a context data object is returned as the first result
     # and the query response in subsequent tokens
-    context_data = None
+    context_data = {}
     get_context_data = True
     async for stream_chunk in search_result:
         if get_context_data:
@@ -1105,29 +1064,22 @@ async def basic_search(
     ------
     TODO: Document any exceptions to expect.
     """
-    vector_store_args = {}
-    for index, store in config.vector_store.items():
-        vector_store_args[index] = store.model_dump()
-    logger.info(f"Vector Store Args: {redact(vector_store_args)}")  # type: ignore # noqa
-
-    description_embedding_store = get_embedding_store(
-        config_args=vector_store_args,  # type: ignore
-        embedding_name=text_unit_text_embedding,
-    )
-
-    prompt = load_search_prompt(config.root_dir, config.basic_search.prompt)
-
-    search_engine = get_basic_search_engine(
+    full_response = ""
+    context_data = {}
+    get_context_data = True
+    # NOTE: when streaming, the first chunk of returned data is the complete context data.
+    # All subsequent chunks are the query response.
+    async for chunk in basic_search_streaming(
         config=config,
-        text_units=read_indexer_text_units(text_units),
-        text_unit_embeddings=description_embedding_store,
-        system_prompt=prompt,
-    )
-
-    result: SearchResult = await search_engine.asearch(query=query)
-    response = result.response
-    context_data = reformat_context_data(result.context_data)  # type: ignore
-    return response, context_data
+        text_units=text_units,
+        query=query,
+    ):
+        if get_context_data:
+            context_data = chunk
+            get_context_data = False
+        else:
+            full_response += chunk
+    return full_response, context_data
 
 
 @validate_call(config={"arbitrary_types_allowed": True})
@@ -1155,8 +1107,6 @@ async def basic_search_streaming(
     vector_store_args = {}
     for index, store in config.vector_store.items():
         vector_store_args[index] = store.model_dump()
-    else:
-        vector_store_args = None
     logger.info(f"Vector Store Args: {redact(vector_store_args)}")  # type: ignore # noqa
 
     description_embedding_store = get_embedding_store(
@@ -1172,12 +1122,11 @@ async def basic_search_streaming(
         text_unit_embeddings=description_embedding_store,
         system_prompt=prompt,
     )
+    search_result = search_engine.stream_search(query=query)
 
-    search_result = search_engine.astream_search(query=query)
-
-    # when streaming results, a context data object is returned as the first result
+    # NOTE: when streaming results, a context data object is returned as the first result
     # and the query response in subsequent tokens
-    context_data = None
+    context_data = {}
     get_context_data = True
     async for stream_chunk in search_result:
         if get_context_data:
graphrag/query/structured_search/base.py (BaseSearch):

@@ -69,27 +69,22 @@ class BaseSearch(ABC, Generic[T]):
         self.context_builder_params = context_builder_params or {}
 
     @abstractmethod
-    def search(
-        self,
-        query: str,
-        conversation_history: ConversationHistory | None = None,
-        **kwargs,
-    ) -> SearchResult:
-        """Search for the given query."""
-
-    @abstractmethod
-    async def asearch(
+    async def search(
         self,
         query: str,
         conversation_history: ConversationHistory | None = None,
         **kwargs,
     ) -> SearchResult:
         """Search for the given query asynchronously."""
+        msg = "Subclasses must implement this method"
+        raise NotImplementedError(msg)
 
     @abstractmethod
-    def astream_search(
+    def stream_search(
         self,
         query: str,
         conversation_history: ConversationHistory | None = None,
-    ) -> AsyncGenerator[str, None] | None:
+    ) -> AsyncGenerator[Any, None]:
         """Stream search for the given query."""
+        msg = "Subclasses must implement this method"
+        raise NotImplementedError(msg)
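With the synchronous methods gone, a concrete engine now implements exactly two members: the `search` coroutine and the `stream_search` async generator. A minimal sketch of a conforming subclass (`EchoSearch` is invented for illustration; the import paths are assumptions based on the module layout visible in this diff, and constructor arguments such as the LLM and context builder are omitted):

```python
from collections.abc import AsyncGenerator
from typing import Any

# assumed paths; ConversationHistory lives in graphrag's context_builder package
from graphrag.query.context_builder.conversation_history import ConversationHistory
from graphrag.query.structured_search.base import BaseSearch, SearchResult


class EchoSearch(BaseSearch):
    """Toy engine illustrating the post-refactor interface."""

    async def search(
        self,
        query: str,
        conversation_history: ConversationHistory | None = None,
        **kwargs,
    ) -> SearchResult:
        # a real engine builds context and calls the LLM here
        return SearchResult(
            response=f"echo: {query}",
            context_data={},
            context_text="",
            completion_time=0.0,
            llm_calls=0,
            prompt_tokens=0,
            output_tokens=0,
        )

    async def stream_search(
        self,
        query: str,
        conversation_history: ConversationHistory | None = None,
    ) -> AsyncGenerator[Any, None]:
        yield {}  # context data first, per the streaming convention
        yield f"echo: {query}"
```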
BasicSearch engine:

@@ -55,7 +55,7 @@ class BasicSearch(BaseSearch[BasicContextBuilder]):
         self.callbacks = callbacks
         self.response_type = response_type
 
-    async def asearch(
+    async def search(
         self,
         query: str,
         conversation_history: ConversationHistory | None = None,
@@ -121,77 +121,11 @@ class BasicSearch(BaseSearch[BasicContextBuilder]):
             output_tokens=0,
         )
 
-    def search(
+    async def stream_search(
         self,
         query: str,
         conversation_history: ConversationHistory | None = None,
-        **kwargs,
-    ) -> SearchResult:
-        """Build basic search context that fits a single context window and generate answer for the user question."""
-        start_time = time.time()
-        search_prompt = ""
-        llm_calls, prompt_tokens, output_tokens = {}, {}, {}
-        context_result = self.context_builder.build_context(
-            query=query,
-            conversation_history=conversation_history,
-            **kwargs,
-            **self.context_builder_params,
-        )
-        llm_calls["build_context"] = context_result.llm_calls
-        prompt_tokens["build_context"] = context_result.prompt_tokens
-        output_tokens["build_context"] = context_result.output_tokens
-
-        log.info("GENERATE ANSWER: %d. QUERY: %s", start_time, query)
-        try:
-            search_prompt = self.system_prompt.format(
-                context_data=context_result.context_chunks,
-                response_type=self.response_type,
-            )
-            search_messages = [
-                {"role": "system", "content": search_prompt},
-                {"role": "user", "content": query},
-            ]
-
-            response = self.llm.generate(
-                messages=search_messages,
-                streaming=True,
-                callbacks=self.callbacks,
-                **self.llm_params,
-            )
-            llm_calls["response"] = 1
-            prompt_tokens["response"] = num_tokens(search_prompt, self.token_encoder)
-            output_tokens["response"] = num_tokens(response, self.token_encoder)
-
-            return SearchResult(
-                response=response,
-                context_data=context_result.context_records,
-                context_text=context_result.context_chunks,
-                completion_time=time.time() - start_time,
-                llm_calls=sum(llm_calls.values()),
-                prompt_tokens=sum(prompt_tokens.values()),
-                output_tokens=sum(output_tokens.values()),
-                llm_calls_categories=llm_calls,
-                prompt_tokens_categories=prompt_tokens,
-                output_tokens_categories=output_tokens,
-            )
-
-        except Exception:
-            log.exception("Exception in _map_response_single_batch")
-            return SearchResult(
-                response="",
-                context_data=context_result.context_records,
-                context_text=context_result.context_chunks,
-                completion_time=time.time() - start_time,
-                llm_calls=1,
-                prompt_tokens=num_tokens(search_prompt, self.token_encoder),
-                output_tokens=0,
-            )
-
-    async def astream_search(
-        self,
-        query: str,
-        conversation_history: ConversationHistory | None = None,
-    ) -> AsyncGenerator:
+    ) -> AsyncGenerator[Any, None]:
         """Build basic search context that fits a single context window and generate answer for the user query."""
         start_time = time.time()
 
DriftAction:

@@ -50,7 +50,7 @@ class DriftAction:
         """Check if the action is complete (i.e., an answer is available)."""
         return self.answer is not None
 
-    async def asearch(self, search_engine: Any, global_query: str, scorer: Any = None):
+    async def search(self, search_engine: Any, global_query: str, scorer: Any = None):
         """
         Execute an asynchronous search using the search engine, and update the action with the results.
 
@@ -70,7 +70,7 @@ class DriftAction:
             log.warning("Action already complete. Skipping search.")
             return self
 
-        search_result = await search_engine.asearch(
+        search_result = await search_engine.search(
             drift_query=global_query, query=self.query
         )
 
DRIFTPrimer:

@@ -154,7 +154,7 @@ class DRIFTPrimer:
 
         return parsed_response, token_ct
 
-    async def asearch(
+    async def search(
         self,
         query: str,
         top_k_reports: pd.DataFrame,
DRIFTSearch:

@@ -144,7 +144,7 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
             error_msg = "Response must be a list of dictionaries."
             raise ValueError(error_msg)
 
-    async def asearch_step(
+    async def _search_step(
         self, global_query: str, search_engine: LocalSearch, actions: list[DriftAction]
     ) -> list[DriftAction]:
         """
@@ -160,12 +160,12 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
             list[DriftAction]: The results from executing the search actions asynchronously.
         """
         tasks = [
-            action.asearch(search_engine=search_engine, global_query=global_query)
+            action.search(search_engine=search_engine, global_query=global_query)
             for action in actions
         ]
         return await tqdm_asyncio.gather(*tasks, leave=False)
 
-    async def asearch(
+    async def search(
         self,
         query: str,
         conversation_history: Any = None,
@@ -204,7 +204,7 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
         prompt_tokens["build_context"] = token_ct["prompt_tokens"]
         output_tokens["build_context"] = token_ct["prompt_tokens"]
 
-        primer_response = await self.primer.asearch(
+        primer_response = await self.primer.search(
             query=query, top_k_reports=primer_context
         )
         llm_calls["primer"] = primer_response.llm_calls
@@ -229,7 +229,7 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
             len(actions) - self.context_builder.config.drift_k_followups
         )
         # Process actions
-        results = await self.asearch_step(
+        results = await self._search_step(
             global_query=query, search_engine=self.local_search, actions=actions
         )
 
@@ -278,37 +278,17 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
             output_tokens_categories=output_tokens,
         )
 
-    def search(
-        self,
-        query: str,
-        conversation_history: Any = None,
-        **kwargs,
-    ) -> SearchResult:
-        """
-        Perform a synchronous DRIFT search (Not Implemented).
-
-        Args:
-            query (str): The query to search for.
-            conversation_history (Any, optional): The conversation history.
-
-        Raises
-        ------
-            NotImplementedError: Synchronous DRIFT is not implemented.
-        """
-        error_msg = "Synchronous DRIFT is not implemented."
-        raise NotImplementedError(error_msg)
-
-    async def astream_search(
+    async def stream_search(
         self, query: str, conversation_history: ConversationHistory | None = None
     ) -> AsyncGenerator[str, None]:
         """
-        Perform a streaming DRIFT search (Not Implemented).
+        Perform a streaming DRIFT search asynchronously.
 
         Args:
             query (str): The query to search for.
             conversation_history (ConversationHistory, optional): The conversation history.
         """
-        result = await self.asearch(
+        result = await self.search(
             query=query, conversation_history=conversation_history, reduce=False
         )
 
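`_search_step` (renamed from `asearch_step` and made private) fans the follow-up actions out concurrently; each `DriftAction.search` awaits the local search engine and records its answer. The same fan-out can be sketched with plain `asyncio.gather` (the repo's `tqdm_asyncio.gather` only adds a progress bar on top of this):

```python
import asyncio


async def search_step(global_query: str, search_engine, actions) -> list:
    """Run every follow-up action's search concurrently, as _search_step does."""
    tasks = [
        action.search(search_engine=search_engine, global_query=global_query)
        for action in actions
    ]
    # each task awaits the engine independently; gather preserves input order
    return await asyncio.gather(*tasks)
```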
GlobalSearch:

@@ -102,7 +102,7 @@ class GlobalSearch(BaseSearch[GlobalContextBuilder]):
 
         self.semaphore = asyncio.Semaphore(concurrent_coroutines)
 
-    async def astream_search(
+    async def stream_search(
         self,
         query: str,
         conversation_history: ConversationHistory | None = None,
@@ -135,7 +135,7 @@ class GlobalSearch(BaseSearch[GlobalContextBuilder]):
         ):
             yield response
 
-    async def asearch(
+    async def search(
         self,
         query: str,
         conversation_history: ConversationHistory | None = None,
@@ -204,15 +204,6 @@ class GlobalSearch(BaseSearch[GlobalContextBuilder]):
             output_tokens_categories=output_tokens,
         )
 
-    def search(
-        self,
-        query: str,
-        conversation_history: ConversationHistory | None = None,
-        **kwargs: Any,
-    ) -> GlobalSearchResult:
-        """Perform a global search synchronously."""
-        return asyncio.run(self.asearch(query, conversation_history))
-
     async def _map_response_single_batch(
         self,
         context_data: str,
@@ -235,7 +226,7 @@ class GlobalSearch(BaseSearch[GlobalContextBuilder]):
         log.info("Map response: %s", search_response)
         try:
             # parse search response json
-            processed_response = self.parse_search_response(search_response)
+            processed_response = self._parse_search_response(search_response)
         except ValueError:
             log.warning(
                 "Warning: Error parsing search response json - skipping this batch"
@@ -264,7 +255,7 @@ class GlobalSearch(BaseSearch[GlobalContextBuilder]):
                 output_tokens=0,
             )
 
-    def parse_search_response(self, search_response: str) -> list[dict[str, Any]]:
+    def _parse_search_response(self, search_response: str) -> list[dict[str, Any]]:
         """Parse the search response json and return a list of key points.
 
         Parameters
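Note that the synchronous `GlobalSearch.search` wrapper, which simply called `asyncio.run(self.asearch(...))`, is removed rather than renamed; synchronous callers can apply the same one-liner at the call site. A hedged sketch, assuming `search_engine` is an already-constructed `GlobalSearch`:

```python
import asyncio

# run the coroutine to completion from synchronous code
result = asyncio.run(search_engine.search("Who is agent Mercer?"))
print(result.response)
```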
LocalSearch:

@@ -54,7 +54,7 @@ class LocalSearch(BaseSearch[LocalContextBuilder]):
         self.callbacks = callbacks
         self.response_type = response_type
 
-    async def asearch(
+    async def search(
         self,
         query: str,
         conversation_history: ConversationHistory | None = None,
@@ -128,7 +128,7 @@ class LocalSearch(BaseSearch[LocalContextBuilder]):
             output_tokens=0,
         )
 
-    async def astream_search(
+    async def stream_search(
         self,
         query: str,
         conversation_history: ConversationHistory | None = None,
@@ -158,69 +158,3 @@ class LocalSearch(BaseSearch[LocalContextBuilder]):
             **self.llm_params,
         ):
             yield response
-
-    def search(
-        self,
-        query: str,
-        conversation_history: ConversationHistory | None = None,
-        **kwargs,
-    ) -> SearchResult:
-        """Build local search context that fits a single context window and generate answer for the user question."""
-        start_time = time.time()
-        search_prompt = ""
-        llm_calls, prompt_tokens, output_tokens = {}, {}, {}
-        context_result = self.context_builder.build_context(
-            query=query,
-            conversation_history=conversation_history,
-            **kwargs,
-            **self.context_builder_params,
-        )
-        llm_calls["build_context"] = context_result.llm_calls
-        prompt_tokens["build_context"] = context_result.prompt_tokens
-        output_tokens["build_context"] = context_result.output_tokens
-
-        log.info("GENERATE ANSWER: %d. QUERY: %s", start_time, query)
-        try:
-            search_prompt = self.system_prompt.format(
-                context_data=context_result.context_chunks,
-                response_type=self.response_type,
-            )
-            search_messages = [
-                {"role": "system", "content": search_prompt},
-                {"role": "user", "content": query},
-            ]
-
-            response = self.llm.generate(
-                messages=search_messages,
-                streaming=True,
-                callbacks=self.callbacks,
-                **self.llm_params,
-            )
-            llm_calls["response"] = 1
-            prompt_tokens["response"] = num_tokens(search_prompt, self.token_encoder)
-            output_tokens["response"] = num_tokens(response, self.token_encoder)
-
-            return SearchResult(
-                response=response,
-                context_data=context_result.context_records,
-                context_text=context_result.context_chunks,
-                completion_time=time.time() - start_time,
-                llm_calls=sum(llm_calls.values()),
-                prompt_tokens=sum(prompt_tokens.values()),
-                output_tokens=sum(output_tokens.values()),
-                llm_calls_categories=llm_calls,
-                prompt_tokens_categories=prompt_tokens,
-                output_tokens_categories=output_tokens,
-            )
-
-        except Exception:
-            log.exception("Exception in _map_response_single_batch")
-            return SearchResult(
-                response="",
-                context_data=context_result.context_records,
-                context_text=context_result.context_chunks,
-                completion_time=time.time() - start_time,
-                llm_calls=1,
-                prompt_tokens=num_tokens(search_prompt, self.token_encoder),
-                output_tokens=0,
-            )