Mirror of https://github.com/microsoft/graphrag.git, synced 2025-06-26 23:19:58 +00:00
Cleanup query api - remove code duplication (#1690)
* consolidate query api functions and remove code duplication
* refactor and remove more code duplication
* Add semversioner file
* fix basic search
* fix drift search and update base class function names (a sketch of the resulting call sites follows the commit metadata below)
* update example notebooks
This commit is contained in:
parent fe461417b5
commit b8b949f3bb
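The thrust of the change, visible in every file below, is that the query engines now expose a single asynchronous API: asearch becomes search, astream_search becomes stream_search, and the synchronous duplicates are deleted outright. A minimal sketch of the resulting call sites, assuming a configured engine (search_engine stands in for any LocalSearch, GlobalSearch, or BasicSearch instance built as in the example notebooks):

import asyncio


async def demo(search_engine) -> None:
    # Async one-shot query: asearch(...) is now search(...).
    result = await search_engine.search("Tell me about Agent Mercer")
    print(result.response)

    # Streaming: astream_search(...) is now stream_search(...). Per the
    # comments in the updated API module, the first chunk yielded is the
    # complete context data; later chunks are response tokens.
    first = True
    async for chunk in search_engine.stream_search("Tell me about Agent Mercer"):
        if first:
            first = False  # context-data chunk
        else:
            print(chunk, end="")


# Synchronous callers that used the removed sync wrappers now bridge at
# the entry point themselves, e.g. asyncio.run(demo(search_engine)).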
@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "cleanup query code duplication."
+}
@@ -327,7 +327,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": null,
    "metadata": {},
    "outputs": [
     {
@@ -388,7 +388,7 @@
    }
   ],
   "source": [
-   "resp = await search.asearch(\"Who is agent Mercer?\")"
+   "resp = await search.search(\"Who is agent Mercer?\")"
   ]
  },
  {
@@ -392,7 +392,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 12,
+   "execution_count": null,
    "metadata": {},
    "outputs": [
     {
@@ -420,7 +420,7 @@
    }
   ],
   "source": [
-   "result = await search_engine.asearch(\n",
+   "result = await search_engine.search(\n",
    "    \"What is Cosmic Vocalization and who are involved in it?\"\n",
    ")\n",
    "\n",
@@ -394,7 +394,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 9,
+   "execution_count": null,
    "metadata": {},
    "outputs": [
     {
@@ -420,7 +420,7 @@
    }
   ],
   "source": [
-   "result = await search_engine.asearch(\n",
+   "result = await search_engine.search(\n",
    "    \"What is Cosmic Vocalization and who are involved in it?\"\n",
    ")\n",
    "\n",
@@ -963,7 +963,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 13,
+   "execution_count": null,
    "metadata": {},
    "outputs": [
     {
@@ -991,13 +991,13 @@
    }
   ],
   "source": [
-   "result = await search_engine.asearch(\"Tell me about Agent Mercer\")\n",
+   "result = await search_engine.search(\"Tell me about Agent Mercer\")\n",
    "print(result.response)"
   ]
  },
  {
   "cell_type": "code",
-  "execution_count": 14,
+  "execution_count": null,
   "metadata": {},
   "outputs": [
    {
@@ -1026,7 +1026,7 @@
   ],
   "source": [
    "question = \"Tell me about Dr. Jordan Hayes\"\n",
-   "result = await search_engine.asearch(question)\n",
+   "result = await search_engine.search(question)\n",
    "print(result.response)"
   ]
  },
@@ -384,7 +384,7 @@
   "metadata": {},
   "outputs": [],
   "source": [
-   "result = await search_engine.asearch(\"Tell me about Agent Mercer\")\n",
+   "result = await search_engine.search(\"Tell me about Agent Mercer\")\n",
    "print(result.response)"
   ]
  },
@@ -395,7 +395,7 @@
   "outputs": [],
   "source": [
    "question = \"Tell me about Dr. Jordan Hayes\"\n",
-   "result = await search_engine.asearch(question)\n",
+   "result = await search_engine.search(question)\n",
    "print(result.response)"
   ]
  },
@@ -18,7 +18,7 @@ Backwards compatibility is not guaranteed at this time.
 """
 
 from collections.abc import AsyncGenerator
-from typing import TYPE_CHECKING, Any
+from typing import Any
 
 import pandas as pd
 from pydantic import validate_call
@@ -53,9 +53,6 @@ from graphrag.utils.api import (
 )
 from graphrag.utils.cli import redact
 
-if TYPE_CHECKING:
-    from graphrag.query.structured_search.base import SearchResult
-
 logger = PrintProgressLogger("")
 
 
@@ -94,40 +91,27 @@ async def global_search(
     ------
     TODO: Document any exceptions to expect.
     """
-    communities_ = read_indexer_communities(communities, community_reports)
-    reports = read_indexer_reports(
-        community_reports,
-        communities,
+    full_response = ""
+    context_data = {}
+    get_context_data = True
+    # NOTE: when streaming, the first chunk of returned data is the complete context data.
+    # All subsequent chunks are the query response.
+    async for chunk in global_search_streaming(
+        config=config,
+        entities=entities,
+        communities=communities,
+        community_reports=community_reports,
         community_level=community_level,
         dynamic_community_selection=dynamic_community_selection,
-    )
-    entities_ = read_indexer_entities(
-        entities, communities, community_level=community_level
-    )
-
-    map_prompt = load_search_prompt(config.root_dir, config.global_search.map_prompt)
-    reduce_prompt = load_search_prompt(
-        config.root_dir, config.global_search.reduce_prompt
-    )
-    knowledge_prompt = load_search_prompt(
-        config.root_dir, config.global_search.knowledge_prompt
-    )
-
-    search_engine = get_global_search_engine(
-        config,
-        reports=reports,
-        entities=entities_,
-        communities=communities_,
         response_type=response_type,
-        dynamic_community_selection=dynamic_community_selection,
-        map_system_prompt=map_prompt,
-        reduce_system_prompt=reduce_prompt,
-        general_knowledge_inclusion_prompt=knowledge_prompt,
-    )
-    result: SearchResult = await search_engine.asearch(query=query)
-    response = result.response
-    context_data = reformat_context_data(result.context_data)  # type: ignore
-    return response, context_data
+        query=query,
+    ):
+        if get_context_data:
+            context_data = chunk
+            get_context_data = False
+        else:
+            full_response += chunk
+    return full_response, context_data
 
 
 @validate_call(config={"arbitrary_types_allowed": True})
@@ -193,11 +177,11 @@ async def global_search_streaming(
         reduce_system_prompt=reduce_prompt,
         general_knowledge_inclusion_prompt=knowledge_prompt,
     )
-    search_result = search_engine.astream_search(query=query)
+    search_result = search_engine.stream_search(query=query)
 
-    # when streaming results, a context data object is returned as the first result
+    # NOTE: when streaming results, a context data object is returned as the first result
     # and the query response in subsequent tokens
-    context_data = None
+    context_data = {}
     get_context_data = True
     async for stream_chunk in search_result:
         if get_context_data:
@@ -385,34 +369,29 @@ async def local_search(
     ------
     TODO: Document any exceptions to expect.
     """
-    vector_store_args = {}
-    for index, store in config.vector_store.items():
-        vector_store_args[index] = store.model_dump()
-    logger.info(f"Vector Store Args: {redact(vector_store_args)}")  # type: ignore # noqa
-
-    description_embedding_store = get_embedding_store(
-        config_args=vector_store_args,  # type: ignore
-        embedding_name=entity_description_embedding,
-    )
-    entities_ = read_indexer_entities(entities, communities, community_level)
-    covariates_ = read_indexer_covariates(covariates) if covariates is not None else []
-    prompt = load_search_prompt(config.root_dir, config.local_search.prompt)
-    search_engine = get_local_search_engine(
+    full_response = ""
+    context_data = {}
+    get_context_data = True
+    # NOTE: when streaming, the first chunk of returned data is the complete context data.
+    # All subsequent chunks are the query response.
+    async for chunk in local_search_streaming(
         config=config,
-        reports=read_indexer_reports(community_reports, communities, community_level),
-        text_units=read_indexer_text_units(text_units),
-        entities=entities_,
-        relationships=read_indexer_relationships(relationships),
-        covariates={"claims": covariates_},
-        description_embedding_store=description_embedding_store,  # type: ignore
+        entities=entities,
+        communities=communities,
+        community_reports=community_reports,
+        text_units=text_units,
+        relationships=relationships,
+        covariates=covariates,
+        community_level=community_level,
         response_type=response_type,
-        system_prompt=prompt,
-    )
-
-    result: SearchResult = await search_engine.asearch(query=query)
-    response = result.response
-    context_data = reformat_context_data(result.context_data)  # type: ignore
-    return response, context_data
+        query=query,
+    ):
+        if get_context_data:
+            context_data = chunk
+            get_context_data = False
+        else:
+            full_response += chunk
+    return full_response, context_data
 
 
 @validate_call(config={"arbitrary_types_allowed": True})
@@ -475,11 +454,11 @@ async def local_search_streaming(
         response_type=response_type,
         system_prompt=prompt,
     )
-    search_result = search_engine.astream_search(query=query)
+    search_result = search_engine.stream_search(query=query)
 
-    # when streaming results, a context data object is returned as the first result
+    # NOTE: when streaming results, a context data object is returned as the first result
     # and the query response in subsequent tokens
-    context_data = None
+    context_data = {}
     get_context_data = True
     async for stream_chunk in search_result:
         if get_context_data:
@@ -751,47 +730,28 @@ async def drift_search(
     ------
     TODO: Document any exceptions to expect.
     """
-    vector_store_args = {}
-    for index, store in config.vector_store.items():
-        vector_store_args[index] = store.model_dump()
-    logger.info(f"Vector Store Args: {redact(vector_store_args)}")  # type: ignore # noqa
-
-    description_embedding_store = get_embedding_store(
-        config_args=vector_store_args,  # type: ignore
-        embedding_name=entity_description_embedding,
-    )
-
-    full_content_embedding_store = get_embedding_store(
-        config_args=vector_store_args,  # type: ignore
-        embedding_name=community_full_content_embedding,
-    )
-
-    entities_ = read_indexer_entities(entities, communities, community_level)
-    reports = read_indexer_reports(community_reports, communities, community_level)
-    read_indexer_report_embeddings(reports, full_content_embedding_store)
-    prompt = load_search_prompt(config.root_dir, config.drift_search.prompt)
-    reduce_prompt = load_search_prompt(
-        config.root_dir, config.drift_search.reduce_prompt
-    )
-    search_engine = get_drift_search_engine(
-        config=config,
-        reports=reports,
-        text_units=read_indexer_text_units(text_units),
-        entities=entities_,
-        relationships=read_indexer_relationships(relationships),
-        description_embedding_store=description_embedding_store,  # type: ignore
-        local_system_prompt=prompt,
-        reduce_system_prompt=reduce_prompt,
-        response_type=response_type,
-    )
-
-    result: SearchResult = await search_engine.asearch(query=query)
-    response = result.response
+    full_response = ""
     context_data = {}
-    for key in result.context_data:
-        context_data[key] = reformat_context_data(result.context_data[key])  # type: ignore
-
-    return response, context_data
+    get_context_data = True
+    # NOTE: when streaming, the first chunk of returned data is the complete context data.
+    # All subsequent chunks are the query response.
+    async for chunk in drift_search_streaming(
+        config=config,
+        entities=entities,
+        communities=communities,
+        community_reports=community_reports,
+        text_units=text_units,
+        relationships=relationships,
+        community_level=community_level,
+        response_type=response_type,
+        query=query,
+    ):
+        if get_context_data:
+            context_data = chunk
+            get_context_data = False
+        else:
+            full_response += chunk
+    return full_response, context_data
 
 
 @validate_call(config={"arbitrary_types_allowed": True})
@@ -860,12 +820,11 @@ async def drift_search_streaming(
         reduce_system_prompt=reduce_prompt,
         response_type=response_type,
     )
-
-    search_result = search_engine.astream_search(query=query)
+    search_result = search_engine.stream_search(query=query)
 
-    # when streaming results, a context data object is returned as the first result
+    # NOTE: when streaming results, a context data object is returned as the first result
     # and the query response in subsequent tokens
-    context_data = None
+    context_data = {}
     get_context_data = True
     async for stream_chunk in search_result:
         if get_context_data:
@@ -1105,29 +1064,22 @@ async def basic_search(
     ------
     TODO: Document any exceptions to expect.
     """
-    vector_store_args = {}
-    for index, store in config.vector_store.items():
-        vector_store_args[index] = store.model_dump()
-    logger.info(f"Vector Store Args: {redact(vector_store_args)}")  # type: ignore # noqa
-
-    description_embedding_store = get_embedding_store(
-        config_args=vector_store_args,  # type: ignore
-        embedding_name=text_unit_text_embedding,
-    )
-
-    prompt = load_search_prompt(config.root_dir, config.basic_search.prompt)
-
-    search_engine = get_basic_search_engine(
+    full_response = ""
+    context_data = {}
+    get_context_data = True
+    # NOTE: when streaming, the first chunk of returned data is the complete context data.
+    # All subsequent chunks are the query response.
+    async for chunk in basic_search_streaming(
        config=config,
-        text_units=read_indexer_text_units(text_units),
-        text_unit_embeddings=description_embedding_store,
-        system_prompt=prompt,
-    )
-
-    result: SearchResult = await search_engine.asearch(query=query)
-    response = result.response
-    context_data = reformat_context_data(result.context_data)  # type: ignore
-    return response, context_data
+        text_units=text_units,
+        query=query,
+    ):
+        if get_context_data:
+            context_data = chunk
+            get_context_data = False
+        else:
+            full_response += chunk
+    return full_response, context_data
 
 
 @validate_call(config={"arbitrary_types_allowed": True})
@@ -1155,8 +1107,6 @@ async def basic_search_streaming(
     vector_store_args = {}
     for index, store in config.vector_store.items():
         vector_store_args[index] = store.model_dump()
-    else:
-        vector_store_args = None
     logger.info(f"Vector Store Args: {redact(vector_store_args)}")  # type: ignore # noqa
 
     description_embedding_store = get_embedding_store(
@@ -1172,12 +1122,11 @@ async def basic_search_streaming(
         text_unit_embeddings=description_embedding_store,
         system_prompt=prompt,
     )
-
-    search_result = search_engine.astream_search(query=query)
+    search_result = search_engine.stream_search(query=query)
 
-    # when streaming results, a context data object is returned as the first result
+    # NOTE: when streaming results, a context data object is returned as the first result
    # and the query response in subsequent tokens
-    context_data = None
+    context_data = {}
     get_context_data = True
     async for stream_chunk in search_result:
         if get_context_data:
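The repeated edit in the API module above is one consolidation applied four times: global_search, local_search, drift_search, and basic_search no longer build their own engine and call asearch; each now drains its *_streaming counterpart, taking the first yielded chunk as the context data and concatenating the rest into the response text. A sketch of that shared loop follows; the helper name collect_streaming_response is illustrative, not from the commit:

from collections.abc import AsyncGenerator
from typing import Any


async def collect_streaming_response(
    stream: AsyncGenerator[Any, None],
) -> tuple[str, dict[str, Any]]:
    """Collapse a streaming search into (response, context_data).

    Mirrors the loop now repeated in global_search, local_search,
    drift_search, and basic_search: the first chunk yielded by the
    streaming variant is the complete context data; every later
    chunk is a piece of the response text.
    """
    full_response = ""
    context_data: dict[str, Any] = {}
    get_context_data = True
    async for chunk in stream:
        if get_context_data:
            context_data = chunk
            get_context_data = False
        else:
            full_response += chunk
    return full_response, context_data

The commit keeps this loop inline in each of the four functions; the helper above only names the shared shape.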
@@ -69,27 +69,22 @@ class BaseSearch(ABC, Generic[T]):
         self.context_builder_params = context_builder_params or {}
 
     @abstractmethod
-    def search(
-        self,
-        query: str,
-        conversation_history: ConversationHistory | None = None,
-        **kwargs,
-    ) -> SearchResult:
-        """Search for the given query."""
-
-    @abstractmethod
-    async def asearch(
+    async def search(
         self,
         query: str,
         conversation_history: ConversationHistory | None = None,
         **kwargs,
     ) -> SearchResult:
         """Search for the given query asynchronously."""
+        msg = "Subclasses must implement this method"
+        raise NotImplementedError(msg)
 
     @abstractmethod
-    def astream_search(
+    def stream_search(
         self,
         query: str,
         conversation_history: ConversationHistory | None = None,
-    ) -> AsyncGenerator[str, None] | None:
+    ) -> AsyncGenerator[Any, None]:
         """Stream search for the given query."""
+        msg = "Subclasses must implement this method"
+        raise NotImplementedError(msg)
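After this hunk, BaseSearch exposes exactly two abstract entry points, both asynchronous, and both raise NotImplementedError if invoked without an override. A condensed, self-contained sketch of the resulting shape (the real class also holds the llm, context builder, token encoder, and parameter dicts, and types search's return as SearchResult rather than Any):

from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator
from typing import Any, Generic, TypeVar

T = TypeVar("T")


class BaseSearchSketch(ABC, Generic[T]):
    """Condensed stand-in for the revised BaseSearch."""

    @abstractmethod
    async def search(self, query: str, **kwargs) -> Any:
        """Search for the given query asynchronously."""
        msg = "Subclasses must implement this method"
        raise NotImplementedError(msg)

    @abstractmethod
    def stream_search(self, query: str) -> AsyncGenerator[Any, None]:
        """Stream search for the given query."""
        msg = "Subclasses must implement this method"
        raise NotImplementedError(msg)

Raising inside an @abstractmethod body is belt-and-braces: ABC already blocks instantiation, but a subclass that delegates to super() still gets a clear error.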
@@ -55,7 +55,7 @@ class BasicSearch(BaseSearch[BasicContextBuilder]):
         self.callbacks = callbacks
         self.response_type = response_type
 
-    async def asearch(
+    async def search(
         self,
         query: str,
         conversation_history: ConversationHistory | None = None,
@@ -121,77 +121,11 @@ class BasicSearch(BaseSearch[BasicContextBuilder]):
             output_tokens=0,
         )
 
-    def search(
+    async def stream_search(
         self,
         query: str,
         conversation_history: ConversationHistory | None = None,
-        **kwargs,
-    ) -> SearchResult:
-        """Build basic search context that fits a single context window and generate answer for the user question."""
-        start_time = time.time()
-        search_prompt = ""
-        llm_calls, prompt_tokens, output_tokens = {}, {}, {}
-        context_result = self.context_builder.build_context(
-            query=query,
-            conversation_history=conversation_history,
-            **kwargs,
-            **self.context_builder_params,
-        )
-        llm_calls["build_context"] = context_result.llm_calls
-        prompt_tokens["build_context"] = context_result.prompt_tokens
-        output_tokens["build_context"] = context_result.output_tokens
-
-        log.info("GENERATE ANSWER: %d. QUERY: %s", start_time, query)
-        try:
-            search_prompt = self.system_prompt.format(
-                context_data=context_result.context_chunks,
-                response_type=self.response_type,
-            )
-            search_messages = [
-                {"role": "system", "content": search_prompt},
-                {"role": "user", "content": query},
-            ]
-
-            response = self.llm.generate(
-                messages=search_messages,
-                streaming=True,
-                callbacks=self.callbacks,
-                **self.llm_params,
-            )
-            llm_calls["response"] = 1
-            prompt_tokens["response"] = num_tokens(search_prompt, self.token_encoder)
-            output_tokens["response"] = num_tokens(response, self.token_encoder)
-
-            return SearchResult(
-                response=response,
-                context_data=context_result.context_records,
-                context_text=context_result.context_chunks,
-                completion_time=time.time() - start_time,
-                llm_calls=sum(llm_calls.values()),
-                prompt_tokens=sum(prompt_tokens.values()),
-                output_tokens=sum(output_tokens.values()),
-                llm_calls_categories=llm_calls,
-                prompt_tokens_categories=prompt_tokens,
-                output_tokens_categories=output_tokens,
-            )
-
-        except Exception:
-            log.exception("Exception in _map_response_single_batch")
-            return SearchResult(
-                response="",
-                context_data=context_result.context_records,
-                context_text=context_result.context_chunks,
-                completion_time=time.time() - start_time,
-                llm_calls=1,
-                prompt_tokens=num_tokens(search_prompt, self.token_encoder),
-                output_tokens=0,
-            )
-
-    async def astream_search(
-        self,
-        query: str,
-        conversation_history: ConversationHistory | None = None,
-    ) -> AsyncGenerator:
+    ) -> AsyncGenerator[Any, None]:
         """Build basic search context that fits a single context window and generate answer for the user query."""
         start_time = time.time()
 
@@ -50,7 +50,7 @@ class DriftAction:
         """Check if the action is complete (i.e., an answer is available)."""
         return self.answer is not None
 
-    async def asearch(self, search_engine: Any, global_query: str, scorer: Any = None):
+    async def search(self, search_engine: Any, global_query: str, scorer: Any = None):
         """
         Execute an asynchronous search using the search engine, and update the action with the results.
 
@@ -70,7 +70,7 @@ class DriftAction:
             log.warning("Action already complete. Skipping search.")
             return self
 
-        search_result = await search_engine.asearch(
+        search_result = await search_engine.search(
             drift_query=global_query, query=self.query
         )
 
@@ -154,7 +154,7 @@ class DRIFTPrimer:
 
         return parsed_response, token_ct
 
-    async def asearch(
+    async def search(
         self,
         query: str,
         top_k_reports: pd.DataFrame,
@@ -144,7 +144,7 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
             error_msg = "Response must be a list of dictionaries."
             raise ValueError(error_msg)
 
-    async def asearch_step(
+    async def _search_step(
         self, global_query: str, search_engine: LocalSearch, actions: list[DriftAction]
     ) -> list[DriftAction]:
         """
@@ -160,12 +160,12 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
             list[DriftAction]: The results from executing the search actions asynchronously.
         """
         tasks = [
-            action.asearch(search_engine=search_engine, global_query=global_query)
+            action.search(search_engine=search_engine, global_query=global_query)
             for action in actions
         ]
         return await tqdm_asyncio.gather(*tasks, leave=False)
 
-    async def asearch(
+    async def search(
         self,
         query: str,
         conversation_history: Any = None,
@@ -204,7 +204,7 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
         prompt_tokens["build_context"] = token_ct["prompt_tokens"]
         output_tokens["build_context"] = token_ct["prompt_tokens"]
 
-        primer_response = await self.primer.asearch(
+        primer_response = await self.primer.search(
             query=query, top_k_reports=primer_context
         )
         llm_calls["primer"] = primer_response.llm_calls
@@ -229,7 +229,7 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
             len(actions) - self.context_builder.config.drift_k_followups
         )
         # Process actions
-        results = await self.asearch_step(
+        results = await self._search_step(
             global_query=query, search_engine=self.local_search, actions=actions
         )
 
@@ -278,37 +278,17 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
             output_tokens_categories=output_tokens,
         )
 
-    def search(
-        self,
-        query: str,
-        conversation_history: Any = None,
-        **kwargs,
-    ) -> SearchResult:
-        """
-        Perform a synchronous DRIFT search (Not Implemented).
-
-        Args:
-            query (str): The query to search for.
-            conversation_history (Any, optional): The conversation history.
-
-        Raises
-        ------
-        NotImplementedError: Synchronous DRIFT is not implemented.
-        """
-        error_msg = "Synchronous DRIFT is not implemented."
-        raise NotImplementedError(error_msg)
-
-    async def astream_search(
+    async def stream_search(
         self, query: str, conversation_history: ConversationHistory | None = None
     ) -> AsyncGenerator[str, None]:
         """
-        Perform a streaming DRIFT search (Not Implemented).
+        Perform a streaming DRIFT search asynchronously.
 
         Args:
             query (str): The query to search for.
             conversation_history (ConversationHistory, optional): The conversation history.
         """
-        result = await self.asearch(
+        result = await self.search(
             query=query, conversation_history=conversation_history, reduce=False
         )
 
@@ -102,7 +102,7 @@ class GlobalSearch(BaseSearch[GlobalContextBuilder]):
 
         self.semaphore = asyncio.Semaphore(concurrent_coroutines)
 
-    async def astream_search(
+    async def stream_search(
         self,
         query: str,
         conversation_history: ConversationHistory | None = None,
@@ -135,7 +135,7 @@ class GlobalSearch(BaseSearch[GlobalContextBuilder]):
         ):
             yield response
 
-    async def asearch(
+    async def search(
         self,
         query: str,
         conversation_history: ConversationHistory | None = None,
@@ -204,15 +204,6 @@ class GlobalSearch(BaseSearch[GlobalContextBuilder]):
             output_tokens_categories=output_tokens,
         )
 
-    def search(
-        self,
-        query: str,
-        conversation_history: ConversationHistory | None = None,
-        **kwargs: Any,
-    ) -> GlobalSearchResult:
-        """Perform a global search synchronously."""
-        return asyncio.run(self.asearch(query, conversation_history))
-
     async def _map_response_single_batch(
         self,
         context_data: str,
@@ -235,7 +226,7 @@ class GlobalSearch(BaseSearch[GlobalContextBuilder]):
         log.info("Map response: %s", search_response)
         try:
             # parse search response json
-            processed_response = self.parse_search_response(search_response)
+            processed_response = self._parse_search_response(search_response)
         except ValueError:
             log.warning(
                 "Warning: Error parsing search response json - skipping this batch"
@@ -264,7 +255,7 @@ class GlobalSearch(BaseSearch[GlobalContextBuilder]):
             output_tokens=0,
         )
 
-    def parse_search_response(self, search_response: str) -> list[dict[str, Any]]:
+    def _parse_search_response(self, search_response: str) -> list[dict[str, Any]]:
         """Parse the search response json and return a list of key points.
 
         Parameters
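With GlobalSearch's synchronous wrapper gone (it was a one-liner around asyncio.run on the async search) and parse_search_response renamed to the private _parse_search_response, the class keeps a single async public surface. Synchronous callers can reproduce the removed wrapper at the call site; a sketch, with engine standing in for a configured GlobalSearch:

import asyncio


def search_sync(engine, query: str):
    """Call-site equivalent of the removed synchronous wrapper, which
    was itself just asyncio.run(self.asearch(...))."""
    return asyncio.run(engine.search(query))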
@@ -54,7 +54,7 @@ class LocalSearch(BaseSearch[LocalContextBuilder]):
         self.callbacks = callbacks
         self.response_type = response_type
 
-    async def asearch(
+    async def search(
         self,
         query: str,
         conversation_history: ConversationHistory | None = None,
@@ -128,7 +128,7 @@ class LocalSearch(BaseSearch[LocalContextBuilder]):
             output_tokens=0,
         )
 
-    async def astream_search(
+    async def stream_search(
         self,
         query: str,
         conversation_history: ConversationHistory | None = None,
@@ -158,69 +158,3 @@ class LocalSearch(BaseSearch[LocalContextBuilder]):
             **self.llm_params,
         ):
             yield response
-
-    def search(
-        self,
-        query: str,
-        conversation_history: ConversationHistory | None = None,
-        **kwargs,
-    ) -> SearchResult:
-        """Build local search context that fits a single context window and generate answer for the user question."""
-        start_time = time.time()
-        search_prompt = ""
-        llm_calls, prompt_tokens, output_tokens = {}, {}, {}
-        context_result = self.context_builder.build_context(
-            query=query,
-            conversation_history=conversation_history,
-            **kwargs,
-            **self.context_builder_params,
-        )
-        llm_calls["build_context"] = context_result.llm_calls
-        prompt_tokens["build_context"] = context_result.prompt_tokens
-        output_tokens["build_context"] = context_result.output_tokens
-
-        log.info("GENERATE ANSWER: %d. QUERY: %s", start_time, query)
-        try:
-            search_prompt = self.system_prompt.format(
-                context_data=context_result.context_chunks,
-                response_type=self.response_type,
-            )
-            search_messages = [
-                {"role": "system", "content": search_prompt},
-                {"role": "user", "content": query},
-            ]
-
-            response = self.llm.generate(
-                messages=search_messages,
-                streaming=True,
-                callbacks=self.callbacks,
-                **self.llm_params,
-            )
-            llm_calls["response"] = 1
-            prompt_tokens["response"] = num_tokens(search_prompt, self.token_encoder)
-            output_tokens["response"] = num_tokens(response, self.token_encoder)
-
-            return SearchResult(
-                response=response,
-                context_data=context_result.context_records,
-                context_text=context_result.context_chunks,
-                completion_time=time.time() - start_time,
-                llm_calls=sum(llm_calls.values()),
-                prompt_tokens=sum(prompt_tokens.values()),
-                output_tokens=sum(output_tokens.values()),
-                llm_calls_categories=llm_calls,
-                prompt_tokens_categories=prompt_tokens,
-                output_tokens_categories=output_tokens,
-            )
-
-        except Exception:
-            log.exception("Exception in _map_response_single_batch")
-            return SearchResult(
-                response="",
-                context_data=context_result.context_records,
-                context_text=context_result.context_chunks,
-                completion_time=time.time() - start_time,
-                llm_calls=1,
-                prompt_tokens=num_tokens(search_prompt, self.token_encoder),
-                output_tokens=0,
-            )