Mirror of https://github.com/microsoft/graphrag.git, synced 2026-01-07 12:40:58 +00:00
Add streaming support for local/global search (#944)
* Added streaming output support for global search. Introduce `--streaming` flag to enable or disable streaming mode
* ran ruff format --preview
* update
* cleanup code and streaming api
* update cli argument
* remove whitespace
* checkpoint - add context data to streaming api
* cleanup help menu
* ruff format update
* add context data to streaming response
* add semversioner file
* rename variable for better readability
* rename variable for better readability
* ruff fixes
* fix abstract class type annotation
* add documentation for --streaming CLI flag

---------

Co-authored-by: 6GOD <55304045+6ixGODD@users.noreply.github.com>
Co-authored-by: Alonso Guevara <alonsog@microsoft.com>
This commit is contained in:
parent
a6238c654a
commit
62546a3c14
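For orientation, here is a minimal consumption sketch of the streaming query API added by this change. It is not part of the diff: the `graphrag.query.api` import path, the artifact directory layout, and the example query are assumptions; the call signature and the yield order (context data first, then response tokens) come from `global_search_streaming` in the diff below.

```python
import pandas as pd

from graphrag.query import api  # assumed import path for the Query API module shown in this diff


async def stream_global_answer(config, artifacts_dir: str, query: str):
    """Stream a global-search answer to stdout and return it with its context data."""
    # parquet file names are the ones referenced in the API docstrings below
    final_nodes = pd.read_parquet(f"{artifacts_dir}/create_final_nodes.parquet")
    final_entities = pd.read_parquet(f"{artifacts_dir}/create_final_entities.parquet")
    final_community_reports = pd.read_parquet(
        f"{artifacts_dir}/create_final_community_reports.parquet"
    )

    full_response = ""
    context_data = None
    first_chunk = True
    async for chunk in api.global_search_streaming(
        config=config,  # a GraphRagConfig, e.g. built from settings.yaml
        nodes=final_nodes,
        entities=final_entities,
        community_reports=final_community_reports,
        community_level=2,  # CLI default
        response_type="Multiple Paragraphs",  # CLI default
        query=query,
    ):
        if first_chunk:
            # the first yielded item is the context data (dict of record lists)
            context_data = chunk
            first_chunk = False
        else:
            # subsequent items are response tokens
            full_response += chunk
            print(chunk, end="", flush=True)
    return full_response, context_data
```

Driven with `asyncio.run(stream_global_answer(config, "<artifacts-dir>", "<query>"))`, this mirrors what the updated CLI does when `--streaming` is passed.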
@@ -0,0 +1,4 @@
{
    "type": "minor",
    "description": "Add streaming support for local/global search to query cli"
}

@@ -19,6 +19,7 @@ python -m graphrag.query --config <config_file.yml> --data <path-to-data> --comm
- `--community_level <community-level>` - Community level in the Leiden community hierarchy from which we will load the community reports higher value means we use reports on smaller communities. Default: 2
- `--response_type <response-type>` - Free form text describing the response type and format, can be anything, e.g. `Multiple Paragraphs`, `Single Paragraph`, `Single Sentence`, `List of 3-7 Points`, `Single Page`, `Multi-Page Report`. Default: `Multiple Paragraphs`.
- `--method <"local"|"global">` - Method to use to answer the query, one of local or global. For more information check [Overview](overview.md)
- `--streaming` - Stream back the LLM response
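
For example, a streamed global answer might be requested with `python -m graphrag.query --config <config_file.yml> --data <path-to-data> --method global --streaming "<query>"` (a hypothetical composition of the flags documented above).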

## Env Variables

@@ -72,6 +72,12 @@ if __name__ == "__main__":
        default="Multiple Paragraphs",
    )

    parser.add_argument(
        "--streaming",
        help="Output response in a streaming (chunk-by-chunk) manner",
        action="store_true",
    )

    parser.add_argument(
        "query",
        nargs=1,

@@ -89,6 +95,7 @@ if __name__ == "__main__":
            args.root,
            args.community_level,
            args.response_type,
            args.streaming,
            args.query[0],
        )
    case SearchType.GLOBAL:

@@ -98,6 +105,7 @@ if __name__ == "__main__":
            args.root,
            args.community_level,
            args.response_type,
            args.streaming,
            args.query[0],
        )
    case _:

@@ -7,10 +7,17 @@ Query Engine API.
This API provides access to the query engine of graphrag, allowing external applications
to hook into graphrag and run queries over a knowledge graph generated by graphrag.

Contains the following functions:
- global_search: Perform a global search.
- global_search_streaming: Perform a global search and stream results via a generator.
- local_search: Perform a local search.
- local_search_streaming: Perform a local search and stream results via a generator.

WARNING: This API is under development and may undergo changes in future releases.
Backwards compatibility is not guaranteed at this time.
"""

from collections.abc import AsyncGenerator
from typing import Any

import pandas as pd

@@ -35,53 +42,6 @@ from .input.loaders.dfs import store_entity_semantic_embeddings
reporter = PrintProgressReporter("")


def __get_embedding_description_store(
    entities: list[Entity],
    vector_store_type: str = VectorStoreType.LanceDB,
    config_args: dict | None = None,
):
    """Get the embedding description store."""
    if not config_args:
        config_args = {}

    collection_name = config_args.get(
        "query_collection_name", "entity_description_embeddings"
    )
    config_args.update({"collection_name": collection_name})
    description_embedding_store = VectorStoreFactory.get_vector_store(
        vector_store_type=vector_store_type, kwargs=config_args
    )

    description_embedding_store.connect(**config_args)

    if config_args.get("overwrite", True):
        # this step assumes the embeddings were originally stored in a file rather
        # than a vector database

        # dump embeddings from the entities list to the description_embedding_store
        store_entity_semantic_embeddings(
            entities=entities, vectorstore=description_embedding_store
        )
    else:
        # load description embeddings to an in-memory lancedb vectorstore
        # and connect to a remote db, specify url and port values.
        description_embedding_store = LanceDBVectorStore(
            collection_name=collection_name
        )
        description_embedding_store.connect(
            db_uri=config_args.get("db_uri", "./lancedb")
        )

        # load data from an existing table
        description_embedding_store.document_collection = (
            description_embedding_store.db_connection.open_table(
                description_embedding_store.collection_name
            )
        )

    return description_embedding_store


@validate_call(config={"arbitrary_types_allowed": True})
async def global_search(
    config: GraphRagConfig,

@@ -125,6 +85,61 @@ async def global_search(
    return result.response


@validate_call(config={"arbitrary_types_allowed": True})
async def global_search_streaming(
    config: GraphRagConfig,
    nodes: pd.DataFrame,
    entities: pd.DataFrame,
    community_reports: pd.DataFrame,
    community_level: int,
    response_type: str,
    query: str,
) -> AsyncGenerator:
    """Perform a global search and return results as a generator.

    Context data is returned as a dictionary of lists, with one list entry for each record.

    Parameters
    ----------
    - config (GraphRagConfig): A graphrag configuration (from settings.yaml)
    - nodes (pd.DataFrame): A DataFrame containing the final nodes (from create_final_nodes.parquet)
    - entities (pd.DataFrame): A DataFrame containing the final entities (from create_final_entities.parquet)
    - community_reports (pd.DataFrame): A DataFrame containing the final community reports (from create_final_community_reports.parquet)
    - community_level (int): The community level to search at.
    - response_type (str): The type of response to return.
    - query (str): The user query to search for.

    Returns
    -------
    TODO: Document the search response type and format.

    Raises
    ------
    TODO: Document any exceptions to expect.
    """
    reports = read_indexer_reports(community_reports, nodes, community_level)
    _entities = read_indexer_entities(nodes, entities, community_level)
    search_engine = get_global_search_engine(
        config,
        reports=reports,
        entities=_entities,
        response_type=response_type,
    )
    search_result = search_engine.astream_search(query=query)

    # when streaming results, a context data object is returned as the first result
    # and the query response in subsequent tokens
    context_data = None
    get_context_data = True
    async for stream_chunk in search_result:
        if get_context_data:
            context_data = _reformat_context_data(stream_chunk)
            yield context_data
            get_context_data = False
        else:
            yield stream_chunk


@validate_call(config={"arbitrary_types_allowed": True})
async def local_search(
    config: GraphRagConfig,

@@ -164,16 +179,17 @@ async def local_search(
    vector_store_args = (
        config.embeddings.vector_store if config.embeddings.vector_store else {}
    )

    reporter.info(f"Vector Store Args: {vector_store_args}")

    vector_store_type = vector_store_args.get("type", VectorStoreType.LanceDB)

    _entities = read_indexer_entities(nodes, entities, community_level)
    description_embedding_store = __get_embedding_description_store(
    description_embedding_store = _get_embedding_description_store(
        entities=_entities,
        vector_store_type=vector_store_type,
        config_args=vector_store_args,
    )

    _covariates = read_indexer_covariates(covariates) if covariates is not None else []

    search_engine = get_local_search_engine(

@@ -190,3 +206,154 @@ async def local_search(
    result = await search_engine.asearch(query=query)
    reporter.success(f"Local Search Response: {result.response}")
    return result.response


@validate_call(config={"arbitrary_types_allowed": True})
async def local_search_streaming(
    config: GraphRagConfig,
    nodes: pd.DataFrame,
    entities: pd.DataFrame,
    community_reports: pd.DataFrame,
    text_units: pd.DataFrame,
    relationships: pd.DataFrame,
    covariates: pd.DataFrame | None,
    community_level: int,
    response_type: str,
    query: str,
) -> AsyncGenerator:
    """Perform a local search and return results as a generator.

    Parameters
    ----------
    - config (GraphRagConfig): A graphrag configuration (from settings.yaml)
    - nodes (pd.DataFrame): A DataFrame containing the final nodes (from create_final_nodes.parquet)
    - entities (pd.DataFrame): A DataFrame containing the final entities (from create_final_entities.parquet)
    - community_reports (pd.DataFrame): A DataFrame containing the final community reports (from create_final_community_reports.parquet)
    - text_units (pd.DataFrame): A DataFrame containing the final text units (from create_final_text_units.parquet)
    - relationships (pd.DataFrame): A DataFrame containing the final relationships (from create_final_relationships.parquet)
    - covariates (pd.DataFrame): A DataFrame containing the final covariates (from create_final_covariates.parquet)
    - community_level (int): The community level to search at.
    - response_type (str): The response type to return.
    - query (str): The user query to search for.

    Returns
    -------
    TODO: Document the search response type and format.

    Raises
    ------
    TODO: Document any exceptions to expect.
    """
    vector_store_args = (
        config.embeddings.vector_store if config.embeddings.vector_store else {}
    )
    reporter.info(f"Vector Store Args: {vector_store_args}")

    vector_store_type = vector_store_args.get("type", VectorStoreType.LanceDB)

    _entities = read_indexer_entities(nodes, entities, community_level)
    description_embedding_store = _get_embedding_description_store(
        entities=_entities,
        vector_store_type=vector_store_type,
        config_args=vector_store_args,
    )

    _covariates = read_indexer_covariates(covariates) if covariates is not None else []

    search_engine = get_local_search_engine(
        config=config,
        reports=read_indexer_reports(community_reports, nodes, community_level),
        text_units=read_indexer_text_units(text_units),
        entities=_entities,
        relationships=read_indexer_relationships(relationships),
        covariates={"claims": _covariates},
        description_embedding_store=description_embedding_store,
        response_type=response_type,
    )
    search_result = search_engine.astream_search(query=query)

    # when streaming results, a context data object is returned as the first result
    # and the query response in subsequent tokens
    context_data = None
    get_context_data = True
    async for stream_chunk in search_result:
        if get_context_data:
            context_data = _reformat_context_data(stream_chunk)
            yield context_data
            get_context_data = False
        else:
            yield stream_chunk


def _get_embedding_description_store(
    entities: list[Entity],
    vector_store_type: str = VectorStoreType.LanceDB,
    config_args: dict | None = None,
):
    """Get the embedding description store."""
    if not config_args:
        config_args = {}

    collection_name = config_args.get(
        "query_collection_name", "entity_description_embeddings"
    )
    config_args.update({"collection_name": collection_name})
    description_embedding_store = VectorStoreFactory.get_vector_store(
        vector_store_type=vector_store_type, kwargs=config_args
    )

    description_embedding_store.connect(**config_args)

    if config_args.get("overwrite", True):
        # this step assumes the embeddings were originally stored in a file rather
        # than a vector database

        # dump embeddings from the entities list to the description_embedding_store
        store_entity_semantic_embeddings(
            entities=entities, vectorstore=description_embedding_store
        )
    else:
        # load description embeddings to an in-memory lancedb vectorstore
        # and connect to a remote db, specify url and port values.
        description_embedding_store = LanceDBVectorStore(
            collection_name=collection_name
        )
        description_embedding_store.connect(
            db_uri=config_args.get("db_uri", "./lancedb")
        )

        # load data from an existing table
        description_embedding_store.document_collection = (
            description_embedding_store.db_connection.open_table(
                description_embedding_store.collection_name
            )
        )

    return description_embedding_store


def _reformat_context_data(context_data: dict) -> dict:
    """
    Reformats context_data for all query responses.

    Reformats a dictionary of dataframes into a dictionary of lists.
    One list entry for each record. Records are grouped by original
    dictionary keys.

    Note: depending on which query algorithm is used, the context_data may not
    contain the same information (keys). In this case, the default behavior will be to
    set these keys as empty lists to preserve a standard output format.
    """
    final_format = {
        "reports": [],
        "entities": [],
        "relationships": [],
        "claims": [],
        "sources": [],
    }
    for key in context_data:
        records = context_data[key].to_dict(orient="records")
        if len(records) < 1:
            continue
        final_format[key] = records
    return final_format
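
To make the reshaping above concrete, here is a small hypothetical example (the column names are invented; the fixed set of output keys comes from `_reformat_context_data`):

```python
import pandas as pd

# Hypothetical context_data as produced by a context builder:
context_data = {
    "reports": pd.DataFrame([{"id": "1", "title": "Community 1 report"}]),
    "entities": pd.DataFrame([{"id": "7", "entity": "ACME Corp"}]),
}

# _reformat_context_data(context_data) would return record lists for the keys
# that are present and keep the remaining standard keys as empty lists:
# {
#     "reports": [{"id": "1", "title": "Community 1 report"}],
#     "entities": [{"id": "7", "entity": "ACME Corp"}],
#     "relationships": [],
#     "claims": [],
#     "sources": [],
# }
```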

@@ -5,6 +5,7 @@

import asyncio
import re
import sys
from pathlib import Path
from typing import cast

@@ -22,11 +23,12 @@ reporter = PrintProgressReporter("")


def run_global_search(
    config_dir: str | None,
    config_filepath: str | None,
    data_dir: str | None,
    root_dir: str | None,
    community_level: int,
    response_type: str,
    streaming: bool,
    query: str,
):
    """Perform a global search with a given query.

@@ -34,7 +36,7 @@ def run_global_search(
    Loads index files required for global search and calls the Query API.
    """
    data_dir, root_dir, config = _configure_paths_and_settings(
        data_dir, root_dir, config_dir
        data_dir, root_dir, config_filepath
    )
    data_path = Path(data_dir)

@@ -48,6 +50,34 @@ def run_global_search(
        data_path / "create_final_community_reports.parquet"
    )

    # call the Query API
    if streaming:

        async def run_streaming_search():
            full_response = ""
            context_data = None
            get_context_data = True
            async for stream_chunk in api.global_search_streaming(
                config=config,
                nodes=final_nodes,
                entities=final_entities,
                community_reports=final_community_reports,
                community_level=community_level,
                response_type=response_type,
                query=query,
            ):
                if get_context_data:
                    context_data = stream_chunk
                    get_context_data = False
                else:
                    full_response += stream_chunk
                    print(stream_chunk, end="")  # noqa: T201
                    sys.stdout.flush()  # flush output buffer to display text immediately
            print()  # noqa: T201
            return full_response, context_data

        return asyncio.run(run_streaming_search())
    # not streaming
    return asyncio.run(
        api.global_search(
            config=config,

@@ -62,11 +92,12 @@ def run_global_search(


def run_local_search(
    config_dir: str | None,
    config_filepath: str | None,
    data_dir: str | None,
    root_dir: str | None,
    community_level: int,
    response_type: str,
    streaming: bool,
    query: str,
):
    """Perform a local search with a given query.

@@ -74,7 +105,7 @@ def run_local_search(
    Loads index files required for local search and calls the Query API.
    """
    data_dir, root_dir, config = _configure_paths_and_settings(
        data_dir, root_dir, config_dir
        data_dir, root_dir, config_filepath
    )
    data_path = Path(data_dir)

@@ -95,6 +126,36 @@ def run_local_search(
    )

    # call the Query API
    if streaming:

        async def run_streaming_search():
            full_response = ""
            context_data = None
            get_context_data = True
            async for stream_chunk in api.local_search_streaming(
                config=config,
                nodes=final_nodes,
                entities=final_entities,
                community_reports=final_community_reports,
                text_units=final_text_units,
                relationships=final_relationships,
                covariates=final_covariates,
                community_level=community_level,
                response_type=response_type,
                query=query,
            ):
                if get_context_data:
                    context_data = stream_chunk
                    get_context_data = False
                else:
                    full_response += stream_chunk
                    print(stream_chunk, end="")  # noqa: T201
                    sys.stdout.flush()  # flush output buffer to display text immediately
            print()  # noqa: T201
            return full_response, context_data

        return asyncio.run(run_streaming_search())
    # not streaming
    return asyncio.run(
        api.local_search(
            config=config,

@@ -114,14 +175,14 @@ def run_local_search(
def _configure_paths_and_settings(
    data_dir: str | None,
    root_dir: str | None,
    config_dir: str | None,
    config_filepath: str | None,
) -> tuple[str, str | None, GraphRagConfig]:
    if data_dir is None and root_dir is None:
        msg = "Either data_dir or root_dir must be provided."
        raise ValueError(msg)
    if data_dir is None:
        data_dir = _infer_data_dir(cast(str, root_dir))
    config = _create_graphrag_config(root_dir, config_dir)
    config = _create_graphrag_config(root_dir, config_filepath)
    return data_dir, root_dir, config

@@ -141,10 +202,10 @@ def _infer_data_dir(root: str) -> str:

def _create_graphrag_config(
    root: str | None,
    config_dir: str | None,
    config_filepath: str | None,
) -> GraphRagConfig:
    """Create a GraphRag configuration."""
    return _read_config_parameters(root or "./", config_dir)
    return _read_config_parameters(root or "./", config_filepath)


def _read_config_parameters(root: str, config: str | None):

@@ -4,6 +4,7 @@
"""Base classes for LLM and Embedding models."""

from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator, Generator
from typing import Any

@@ -31,6 +32,15 @@ class BaseLLM(ABC):
    ) -> str:
        """Generate a response."""

    @abstractmethod
    def stream_generate(
        self,
        messages: str | list[Any],
        callbacks: list[BaseLLMCallback] | None = None,
        **kwargs: Any,
    ) -> Generator[str, None, None]:
        """Generate a response with streaming."""

    @abstractmethod
    async def agenerate(
        self,

@@ -41,6 +51,16 @@ class BaseLLM(ABC):
    ) -> str:
        """Generate a response asynchronously."""

    @abstractmethod
    async def astream_generate(
        self,
        messages: str | list[Any],
        callbacks: list[BaseLLMCallback] | None = None,
        **kwargs: Any,
    ) -> AsyncGenerator[str, None]:
        """Generate a response asynchronously with streaming."""
        ...


class BaseTextEmbedding(ABC):
    """The text embedding interface."""

@@ -3,7 +3,7 @@

"""Chat-based OpenAI LLM implementation."""

from collections.abc import Callable
from collections.abc import AsyncGenerator, Callable, Generator
from typing import Any

from tenacity import (

@@ -92,6 +92,38 @@ class ChatOpenAI(BaseLLM, OpenAILLMImpl):
            # TODO: why not just throw in this case?
            return ""

    def stream_generate(
        self,
        messages: str | list[Any],
        callbacks: list[BaseLLMCallback] | None = None,
        **kwargs: Any,
    ) -> Generator[str, None, None]:
        """Generate text with streaming."""
        try:
            retryer = Retrying(
                stop=stop_after_attempt(self.max_retries),
                wait=wait_exponential_jitter(max=10),
                reraise=True,
                retry=retry_if_exception_type(self.retry_error_types),
            )
            for attempt in retryer:
                with attempt:
                    generator = self._stream_generate(
                        messages=messages,
                        callbacks=callbacks,
                        **kwargs,
                    )
                    yield from generator

        except RetryError as e:
            self._reporter.error(
                message="Error at stream_generate()",
                details={self.__class__.__name__: str(e)},
            )
            return
        else:
            return

    async def agenerate(
        self,
        messages: str | list[Any],

@@ -122,6 +154,35 @@ class ChatOpenAI(BaseLLM, OpenAILLMImpl):
            # TODO: why not just throw in this case?
            return ""

    async def astream_generate(
        self,
        messages: str | list[Any],
        callbacks: list[BaseLLMCallback] | None = None,
        **kwargs: Any,
    ) -> AsyncGenerator[str, None]:
        """Generate text asynchronously with streaming."""
        try:
            retryer = AsyncRetrying(
                stop=stop_after_attempt(self.max_retries),
                wait=wait_exponential_jitter(max=10),
                reraise=True,
                retry=retry_if_exception_type(self.retry_error_types),  # type: ignore
            )
            async for attempt in retryer:
                with attempt:
                    generator = self._astream_generate(
                        messages=messages,
                        callbacks=callbacks,
                        **kwargs,
                    )
                    async for response in generator:
                        yield response
        except RetryError as e:
            self._reporter.error(f"Error at astream_generate(): {e}")
            return
        else:
            return

    def _generate(
        self,
        messages: str | list[Any],

@@ -163,6 +224,37 @@ class ChatOpenAI(BaseLLM, OpenAILLMImpl):
            return full_response
        return response.choices[0].message.content or ""  # type: ignore

    def _stream_generate(
        self,
        messages: str | list[Any],
        callbacks: list[BaseLLMCallback] | None = None,
        **kwargs: Any,
    ) -> Generator[str, None, None]:
        model = self.model
        if not model:
            raise ValueError(_MODEL_REQUIRED_MSG)
        response = self.sync_client.chat.completions.create(  # type: ignore
            model=model,
            messages=messages,  # type: ignore
            stream=True,
            **kwargs,
        )
        for chunk in response:
            if not chunk or not chunk.choices:
                continue

            delta = (
                chunk.choices[0].delta.content
                if chunk.choices[0].delta and chunk.choices[0].delta.content
                else ""
            )

            yield delta

            if callbacks:
                for callback in callbacks:
                    callback.on_llm_new_token(delta)

    async def _agenerate(
        self,
        messages: str | list[Any],

@@ -204,3 +296,34 @@ class ChatOpenAI(BaseLLM, OpenAILLMImpl):
            return full_response

        return response.choices[0].message.content or ""  # type: ignore

    async def _astream_generate(
        self,
        messages: str | list[Any],
        callbacks: list[BaseLLMCallback] | None = None,
        **kwargs: Any,
    ) -> AsyncGenerator[str, None]:
        model = self.model
        if not model:
            raise ValueError(_MODEL_REQUIRED_MSG)
        response = await self.async_client.chat.completions.create(  # type: ignore
            model=model,
            messages=messages,  # type: ignore
            stream=True,
            **kwargs,
        )
        async for chunk in response:
            if not chunk or not chunk.choices:
                continue

            delta = (
                chunk.choices[0].delta.content
                if chunk.choices[0].delta and chunk.choices[0].delta.content
                else ""
            )  # type: ignore

            yield delta

            if callbacks:
                for callback in callbacks:
                    callback.on_llm_new_token(delta)
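
Both `_stream_generate` and `_astream_generate` above forward every delta to `callback.on_llm_new_token(delta)`, so callers can observe the token stream without consuming the generator themselves. A sketch of such a callback follows; the `TokenCollector` class and the `graphrag.query.llm.base` import path are assumptions and not part of this change.

```python
from graphrag.query.llm.base import BaseLLMCallback  # assumed module path for the callback base class


class TokenCollector(BaseLLMCallback):
    """Collect streamed deltas, e.g. to forward them to a UI or progress display."""

    def __init__(self):
        super().__init__()
        self.tokens: list[str] = []

    def on_llm_new_token(self, token: str):
        # called once per streamed delta by the streaming methods above
        self.tokens.append(token)


# Usage sketch: pass it via the callbacks list of stream_generate()/astream_generate():
# collector = TokenCollector()
# for delta in chat_llm.stream_generate(messages, callbacks=[collector]):
#     ...
```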

@@ -4,6 +4,7 @@
"""Base classes for search algos."""

from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator
from dataclasses import dataclass
from typing import Any

@@ -67,3 +68,11 @@ class BaseSearch(ABC):
        **kwargs,
    ) -> SearchResult:
        """Search for the given query asynchronously."""

    @abstractmethod
    def astream_search(
        self,
        query: str,
        conversation_history: ConversationHistory | None = None,
    ) -> AsyncGenerator[str, None]:
        """Stream search for the given query."""

@@ -7,6 +7,7 @@ import asyncio
import json
import logging
import time
from collections.abc import AsyncGenerator
from dataclasses import dataclass
from typing import Any

@@ -100,6 +101,37 @@ class GlobalSearch(BaseSearch):

        self.semaphore = asyncio.Semaphore(concurrent_coroutines)

    async def astream_search(
        self,
        query: str,
        conversation_history: ConversationHistory | None = None,
    ) -> AsyncGenerator:
        """Stream the global search response."""
        context_chunks, context_records = self.context_builder.build_context(
            conversation_history=conversation_history, **self.context_builder_params
        )
        if self.callbacks:
            for callback in self.callbacks:
                callback.on_map_response_start(context_chunks)  # type: ignore
        map_responses = await asyncio.gather(*[
            self._map_response_single_batch(
                context_data=data, query=query, **self.map_llm_params
            )
            for data in context_chunks
        ])
        if self.callbacks:
            for callback in self.callbacks:
                callback.on_map_response_end(map_responses)  # type: ignore

        # send context records first before sending the reduce response
        yield context_records
        async for response in self._stream_reduce_response(
            map_responses=map_responses,  # type: ignore
            query=query,
            **self.reduce_llm_params,
        ):
            yield response

    async def asearch(
        self,
        query: str,

@@ -357,3 +389,81 @@ class GlobalSearch(BaseSearch):
            llm_calls=1,
            prompt_tokens=num_tokens(search_prompt, self.token_encoder),
        )

    async def _stream_reduce_response(
        self,
        map_responses: list[SearchResult],
        query: str,
        **llm_kwargs,
    ) -> AsyncGenerator[str, None]:
        # collect all key points into a single list to prepare for sorting
        key_points = []
        for index, response in enumerate(map_responses):
            if not isinstance(response.response, list):
                continue
            for element in response.response:
                if not isinstance(element, dict):
                    continue
                if "answer" not in element or "score" not in element:
                    continue
                key_points.append({
                    "analyst": index,
                    "answer": element["answer"],
                    "score": element["score"],
                })

        # filter response with score = 0 and rank responses by descending order of score
        filtered_key_points = [
            point
            for point in key_points
            if point["score"] > 0  # type: ignore
        ]

        if len(filtered_key_points) == 0 and not self.allow_general_knowledge:
            # return no data answer if no key points are found
            log.warning(
                "Warning: All map responses have score 0 (i.e., no relevant information found from the dataset), returning a canned 'I do not know' answer. You can try enabling `allow_general_knowledge` to encourage the LLM to incorporate relevant general knowledge, at the risk of increasing hallucinations."
            )
            yield NO_DATA_ANSWER
            return

        filtered_key_points = sorted(
            filtered_key_points,
            key=lambda x: x["score"],  # type: ignore
            reverse=True,  # type: ignore
        )

        data = []
        total_tokens = 0
        for point in filtered_key_points:
            formatted_response_data = [
                f'----Analyst {point["analyst"] + 1}----',
                f'Importance Score: {point["score"]}',
                point["answer"],
            ]
            formatted_response_text = "\n".join(formatted_response_data)
            if (
                total_tokens + num_tokens(formatted_response_text, self.token_encoder)
                > self.max_data_tokens
            ):
                break
            data.append(formatted_response_text)
            total_tokens += num_tokens(formatted_response_text, self.token_encoder)
        text_data = "\n\n".join(data)

        search_prompt = self.reduce_system_prompt.format(
            report_data=text_data, response_type=self.response_type
        )
        if self.allow_general_knowledge:
            search_prompt += "\n" + self.general_knowledge_inclusion_prompt
        search_messages = [
            {"role": "system", "content": search_prompt},
            {"role": "user", "content": query},
        ]

        async for resp in self.llm.astream_generate(  # type: ignore
            search_messages,
            callbacks=self.callbacks,  # type: ignore
            **llm_kwargs,  # type: ignore
        ):
            yield resp

@@ -5,6 +5,7 @@

import logging
import time
from collections.abc import AsyncGenerator
from typing import Any

import tiktoken

@@ -106,6 +107,37 @@ class LocalSearch(BaseSearch):
            prompt_tokens=num_tokens(search_prompt, self.token_encoder),
        )

    async def astream_search(
        self,
        query: str,
        conversation_history: ConversationHistory | None = None,
    ) -> AsyncGenerator:
        """Build local search context that fits a single context window and generate answer for the user query."""
        start_time = time.time()

        context_text, context_records = self.context_builder.build_context(
            query=query,
            conversation_history=conversation_history,
            **self.context_builder_params,
        )
        log.info("GENERATE ANSWER: %s. QUERY: %s", start_time, query)
        search_prompt = self.system_prompt.format(
            context_data=context_text, response_type=self.response_type
        )
        search_messages = [
            {"role": "system", "content": search_prompt},
            {"role": "user", "content": query},
        ]

        # send context records first before sending the reduce response
        yield context_records
        async for response in self.llm.astream_generate(  # type: ignore
            messages=search_messages,
            callbacks=self.callbacks,
            **self.llm_params,
        ):
            yield response

    def search(
        self,
        query: str,