Add streaming support for local/global search (#944)

* Added streaming output support for global search. Introduced a `--streaming` flag to enable or disable streaming mode

* ran ruff format --preview

* update

* cleanup code and streaming api

* update cli argument

* remove whitespace

* checkpoint - add context data to streaming api

* cleanup help menu

* ruff format update

* add context data to streaming response

* add semversioner file

* rename variable for better readability

* rename variable for better readability

* ruff fixes

* fix abstract class type annotation

* add documentation for --streaming CLI flag

---------

Co-authored-by: 6GOD <55304045+6ixGODD@users.noreply.github.com>
Co-authored-by: Alonso Guevara <alonsog@microsoft.com>
Josh Bradley authored 2024-08-20 15:44:48 -04:00; committed by GitHub
parent a6238c654a
commit 62546a3c14
10 changed files with 593 additions and 58 deletions

View File (semversioner release note)

@@ -0,0 +1,4 @@
{
"type": "minor",
"description": "Add streaming support for local/global search to query cli"
}

View File (query CLI documentation)

@@ -19,6 +19,7 @@ python -m graphrag.query --config <config_file.yml> --data <path-to-data> --comm
- `--community_level <community-level>` - Community level in the Leiden community hierarchy from which we will load the community reports. A higher value means we use reports on smaller communities. Default: 2
- `--response_type <response-type>` - Free form text describing the response type and format, can be anything, e.g. `Multiple Paragraphs`, `Single Paragraph`, `Single Sentence`, `List of 3-7 Points`, `Single Page`, `Multi-Page Report`. Default: `Multiple Paragraphs`.
- `--method <"local"|"global">` - Method to use to answer the query, one of local or global. For more information check [Overview](overview.md)
- `--streaming` - Stream back the LLM response
## Env Variables

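For example, assuming an index has already been built under ./ragtest (the root path and query below are placeholders), the new flag combines with either documented search method:

python -m graphrag.query --root ./ragtest --method global --streaming "What are the top themes in this dataset?"

Omitting --streaming keeps the previous behavior of printing the full response only after the search completes.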
View File (query module __main__ / argument parser)

@@ -72,6 +72,12 @@ if __name__ == "__main__":
default="Multiple Paragraphs",
)
parser.add_argument(
"--streaming",
help="Output response in a streaming (chunk-by-chunk) manner",
action="store_true",
)
parser.add_argument(
"query",
nargs=1,
@@ -89,6 +95,7 @@ if __name__ == "__main__":
args.root,
args.community_level,
args.response_type,
args.streaming,
args.query[0],
)
case SearchType.GLOBAL:
@@ -98,6 +105,7 @@ if __name__ == "__main__":
args.root,
args.community_level,
args.response_type,
args.streaming,
args.query[0],
)
case _:

View File (query engine API)

@@ -7,10 +7,17 @@ Query Engine API.
This API provides access to the query engine of graphrag, allowing external applications
to hook into graphrag and run queries over a knowledge graph generated by graphrag.
Contains the following functions:
- global_search: Perform a global search.
- global_search_streaming: Perform a global search and stream results via a generator.
- local_search: Perform a local search.
- local_search_streaming: Perform a local search and stream results via a generator.
WARNING: This API is under development and may undergo changes in future releases.
Backwards compatibility is not guaranteed at this time.
"""
from collections.abc import AsyncGenerator
from typing import Any
import pandas as pd
@@ -35,53 +42,6 @@ from .input.loaders.dfs import store_entity_semantic_embeddings
reporter = PrintProgressReporter("")
def __get_embedding_description_store(
entities: list[Entity],
vector_store_type: str = VectorStoreType.LanceDB,
config_args: dict | None = None,
):
"""Get the embedding description store."""
if not config_args:
config_args = {}
collection_name = config_args.get(
"query_collection_name", "entity_description_embeddings"
)
config_args.update({"collection_name": collection_name})
description_embedding_store = VectorStoreFactory.get_vector_store(
vector_store_type=vector_store_type, kwargs=config_args
)
description_embedding_store.connect(**config_args)
if config_args.get("overwrite", True):
# this step assumes the embeddings were originally stored in a file rather
# than a vector database
# dump embeddings from the entities list to the description_embedding_store
store_entity_semantic_embeddings(
entities=entities, vectorstore=description_embedding_store
)
else:
# load description embeddings to an in-memory lancedb vectorstore
# and connect to a remote db, specify url and port values.
description_embedding_store = LanceDBVectorStore(
collection_name=collection_name
)
description_embedding_store.connect(
db_uri=config_args.get("db_uri", "./lancedb")
)
# load data from an existing table
description_embedding_store.document_collection = (
description_embedding_store.db_connection.open_table(
description_embedding_store.collection_name
)
)
return description_embedding_store
@validate_call(config={"arbitrary_types_allowed": True})
async def global_search(
config: GraphRagConfig,
@@ -125,6 +85,61 @@ async def global_search(
return result.response
@validate_call(config={"arbitrary_types_allowed": True})
async def global_search_streaming(
config: GraphRagConfig,
nodes: pd.DataFrame,
entities: pd.DataFrame,
community_reports: pd.DataFrame,
community_level: int,
response_type: str,
query: str,
) -> AsyncGenerator:
"""Perform a global search and return results as a generator.
Context data is returned as a dictionary of lists, with one list entry for each record.
Parameters
----------
- config (GraphRagConfig): A graphrag configuration (from settings.yaml)
- nodes (pd.DataFrame): A DataFrame containing the final nodes (from create_final_nodes.parquet)
- entities (pd.DataFrame): A DataFrame containing the final entities (from create_final_entities.parquet)
- community_reports (pd.DataFrame): A DataFrame containing the final community reports (from create_final_community_reports.parquet)
- community_level (int): The community level to search at.
- response_type (str): The type of response to return.
- query (str): The user query to search for.
Returns
-------
TODO: Document the search response type and format.
Raises
------
TODO: Document any exceptions to expect.
"""
reports = read_indexer_reports(community_reports, nodes, community_level)
_entities = read_indexer_entities(nodes, entities, community_level)
search_engine = get_global_search_engine(
config,
reports=reports,
entities=_entities,
response_type=response_type,
)
search_result = search_engine.astream_search(query=query)
# when streaming results, a context data object is returned as the first result
# and the query response in subsequent tokens
context_data = None
get_context_data = True
async for stream_chunk in search_result:
if get_context_data:
context_data = _reformat_context_data(stream_chunk)
yield context_data
get_context_data = False
else:
yield stream_chunk
@validate_call(config={"arbitrary_types_allowed": True})
async def local_search(
config: GraphRagConfig,
@@ -164,16 +179,17 @@ async def local_search(
vector_store_args = (
config.embeddings.vector_store if config.embeddings.vector_store else {}
)
reporter.info(f"Vector Store Args: {vector_store_args}")
vector_store_type = vector_store_args.get("type", VectorStoreType.LanceDB)
_entities = read_indexer_entities(nodes, entities, community_level)
description_embedding_store = __get_embedding_description_store(
description_embedding_store = _get_embedding_description_store(
entities=_entities,
vector_store_type=vector_store_type,
config_args=vector_store_args,
)
_covariates = read_indexer_covariates(covariates) if covariates is not None else []
search_engine = get_local_search_engine(
@@ -190,3 +206,154 @@ async def local_search(
result = await search_engine.asearch(query=query)
reporter.success(f"Local Search Response: {result.response}")
return result.response
@validate_call(config={"arbitrary_types_allowed": True})
async def local_search_streaming(
config: GraphRagConfig,
nodes: pd.DataFrame,
entities: pd.DataFrame,
community_reports: pd.DataFrame,
text_units: pd.DataFrame,
relationships: pd.DataFrame,
covariates: pd.DataFrame | None,
community_level: int,
response_type: str,
query: str,
) -> AsyncGenerator:
"""Perform a local search and return results as a generator.
Parameters
----------
- config (GraphRagConfig): A graphrag configuration (from settings.yaml)
- nodes (pd.DataFrame): A DataFrame containing the final nodes (from create_final_nodes.parquet)
- entities (pd.DataFrame): A DataFrame containing the final entities (from create_final_entities.parquet)
- community_reports (pd.DataFrame): A DataFrame containing the final community reports (from create_final_community_reports.parquet)
- text_units (pd.DataFrame): A DataFrame containing the final text units (from create_final_text_units.parquet)
- relationships (pd.DataFrame): A DataFrame containing the final relationships (from create_final_relationships.parquet)
- covariates (pd.DataFrame): A DataFrame containing the final covariates (from create_final_covariates.parquet)
- community_level (int): The community level to search at.
- response_type (str): The response type to return.
- query (str): The user query to search for.
Returns
-------
TODO: Document the search response type and format.
Raises
------
TODO: Document any exceptions to expect.
"""
vector_store_args = (
config.embeddings.vector_store if config.embeddings.vector_store else {}
)
reporter.info(f"Vector Store Args: {vector_store_args}")
vector_store_type = vector_store_args.get("type", VectorStoreType.LanceDB)
_entities = read_indexer_entities(nodes, entities, community_level)
description_embedding_store = _get_embedding_description_store(
entities=_entities,
vector_store_type=vector_store_type,
config_args=vector_store_args,
)
_covariates = read_indexer_covariates(covariates) if covariates is not None else []
search_engine = get_local_search_engine(
config=config,
reports=read_indexer_reports(community_reports, nodes, community_level),
text_units=read_indexer_text_units(text_units),
entities=_entities,
relationships=read_indexer_relationships(relationships),
covariates={"claims": _covariates},
description_embedding_store=description_embedding_store,
response_type=response_type,
)
search_result = search_engine.astream_search(query=query)
# when streaming results, a context data object is returned as the first result
# and the query response in subsequent tokens
context_data = None
get_context_data = True
async for stream_chunk in search_result:
if get_context_data:
context_data = _reformat_context_data(stream_chunk)
yield context_data
get_context_data = False
else:
yield stream_chunk
def _get_embedding_description_store(
entities: list[Entity],
vector_store_type: str = VectorStoreType.LanceDB,
config_args: dict | None = None,
):
"""Get the embedding description store."""
if not config_args:
config_args = {}
collection_name = config_args.get(
"query_collection_name", "entity_description_embeddings"
)
config_args.update({"collection_name": collection_name})
description_embedding_store = VectorStoreFactory.get_vector_store(
vector_store_type=vector_store_type, kwargs=config_args
)
description_embedding_store.connect(**config_args)
if config_args.get("overwrite", True):
# this step assumes the embeddings were originally stored in a file rather
# than a vector database
# dump embeddings from the entities list to the description_embedding_store
store_entity_semantic_embeddings(
entities=entities, vectorstore=description_embedding_store
)
else:
# load description embeddings to an in-memory lancedb vectorstore
# and connect to a remote db, specify url and port values.
description_embedding_store = LanceDBVectorStore(
collection_name=collection_name
)
description_embedding_store.connect(
db_uri=config_args.get("db_uri", "./lancedb")
)
# load data from an existing table
description_embedding_store.document_collection = (
description_embedding_store.db_connection.open_table(
description_embedding_store.collection_name
)
)
return description_embedding_store
def _reformat_context_data(context_data: dict) -> dict:
"""
Reformats context_data for all query responses.
Reformats a dictionary of dataframes into a dictionary of lists.
One list entry for each record. Records are grouped by original
dictionary keys.
Note: depending on which query algorithm is used, the context_data may not
contain the same information (keys). In this case, the default behavior will be to
set these keys as empty lists to preserve a standard output format.
"""
final_format = {
"reports": [],
"entities": [],
"relationships": [],
"claims": [],
"sources": [],
}
for key in context_data:
records = context_data[key].to_dict(orient="records")
if len(records) < 1:
continue
final_format[key] = records
return final_format

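For orientation, a minimal consumption sketch of the new streaming API (not part of this change). The import path for the API module and the parquet locations are assumptions, and config is expected to be an already-loaded GraphRagConfig; the first item yielded is the reformatted context-data dictionary, followed by response tokens:

import asyncio

import pandas as pd

from graphrag.query import api  # assumed import path for this API module


async def stream_global_answer(config) -> None:
    # parquet names follow the docstrings above; the directory is a placeholder
    data_dir = "output/artifacts"
    nodes = pd.read_parquet(f"{data_dir}/create_final_nodes.parquet")
    entities = pd.read_parquet(f"{data_dir}/create_final_entities.parquet")
    reports = pd.read_parquet(f"{data_dir}/create_final_community_reports.parquet")

    first_chunk = True
    async for chunk in api.global_search_streaming(
        config=config,
        nodes=nodes,
        entities=entities,
        community_reports=reports,
        community_level=2,
        response_type="Multiple Paragraphs",
        query="What are the top themes in this dataset?",
    ):
        if first_chunk:
            context_data = chunk  # dict of lists, one entry per context record
            first_chunk = False
        else:
            print(chunk, end="", flush=True)  # response tokens


# asyncio.run(stream_global_answer(config))

local_search_streaming follows the same first-chunk-is-context convention, with the additional text_units, relationships, and covariates DataFrames passed in.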
View File (query CLI implementation)

@@ -5,6 +5,7 @@
import asyncio
import re
import sys
from pathlib import Path
from typing import cast
@@ -22,11 +23,12 @@ reporter = PrintProgressReporter("")
def run_global_search(
config_dir: str | None,
config_filepath: str | None,
data_dir: str | None,
root_dir: str | None,
community_level: int,
response_type: str,
streaming: bool,
query: str,
):
"""Perform a global search with a given query.
@@ -34,7 +36,7 @@ def run_global_search(
Loads index files required for global search and calls the Query API.
"""
data_dir, root_dir, config = _configure_paths_and_settings(
data_dir, root_dir, config_dir
data_dir, root_dir, config_filepath
)
data_path = Path(data_dir)
@@ -48,6 +50,34 @@ def run_global_search(
data_path / "create_final_community_reports.parquet"
)
# call the Query API
if streaming:
async def run_streaming_search():
full_response = ""
context_data = None
get_context_data = True
async for stream_chunk in api.global_search_streaming(
config=config,
nodes=final_nodes,
entities=final_entities,
community_reports=final_community_reports,
community_level=community_level,
response_type=response_type,
query=query,
):
if get_context_data:
context_data = stream_chunk
get_context_data = False
else:
full_response += stream_chunk
print(stream_chunk, end="") # noqa: T201
sys.stdout.flush() # flush output buffer to display text immediately
print() # noqa: T201
return full_response, context_data
return asyncio.run(run_streaming_search())
# not streaming
return asyncio.run(
api.global_search(
config=config,
@@ -62,11 +92,12 @@ def run_global_search(
def run_local_search(
config_dir: str | None,
config_filepath: str | None,
data_dir: str | None,
root_dir: str | None,
community_level: int,
response_type: str,
streaming: bool,
query: str,
):
"""Perform a local search with a given query.
@@ -74,7 +105,7 @@ def run_local_search(
Loads index files required for local search and calls the Query API.
"""
data_dir, root_dir, config = _configure_paths_and_settings(
data_dir, root_dir, config_dir
data_dir, root_dir, config_filepath
)
data_path = Path(data_dir)
@@ -95,6 +126,36 @@ def run_local_search(
)
# call the Query API
if streaming:
async def run_streaming_search():
full_response = ""
context_data = None
get_context_data = True
async for stream_chunk in api.local_search_streaming(
config=config,
nodes=final_nodes,
entities=final_entities,
community_reports=final_community_reports,
text_units=final_text_units,
relationships=final_relationships,
covariates=final_covariates,
community_level=community_level,
response_type=response_type,
query=query,
):
if get_context_data:
context_data = stream_chunk
get_context_data = False
else:
full_response += stream_chunk
print(stream_chunk, end="") # noqa: T201
sys.stdout.flush() # flush output buffer to display text immediately
print() # noqa: T201
return full_response, context_data
return asyncio.run(run_streaming_search())
# not streaming
return asyncio.run(
api.local_search(
config=config,
@@ -114,14 +175,14 @@ def run_local_search(
def _configure_paths_and_settings(
data_dir: str | None,
root_dir: str | None,
config_dir: str | None,
config_filepath: str | None,
) -> tuple[str, str | None, GraphRagConfig]:
if data_dir is None and root_dir is None:
msg = "Either data_dir or root_dir must be provided."
raise ValueError(msg)
if data_dir is None:
data_dir = _infer_data_dir(cast(str, root_dir))
config = _create_graphrag_config(root_dir, config_dir)
config = _create_graphrag_config(root_dir, config_filepath)
return data_dir, root_dir, config
@@ -141,10 +202,10 @@ def _infer_data_dir(root: str) -> str:
def _create_graphrag_config(
root: str | None,
config_dir: str | None,
config_filepath: str | None,
) -> GraphRagConfig:
"""Create a GraphRag configuration."""
return _read_config_parameters(root or "./", config_dir)
return _read_config_parameters(root or "./", config_filepath)
def _read_config_parameters(root: str, config: str | None):

View File (LLM and embedding base classes)

@@ -4,6 +4,7 @@
"""Base classes for LLM and Embedding models."""
from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator, Generator
from typing import Any
@@ -31,6 +32,15 @@ class BaseLLM(ABC):
) -> str:
"""Generate a response."""
@abstractmethod
def stream_generate(
self,
messages: str | list[Any],
callbacks: list[BaseLLMCallback] | None = None,
**kwargs: Any,
) -> Generator[str, None, None]:
"""Generate a response with streaming."""
@abstractmethod
async def agenerate(
self,
@@ -41,6 +51,16 @@ class BaseLLM(ABC):
) -> str:
"""Generate a response asynchronously."""
@abstractmethod
async def astream_generate(
self,
messages: str | list[Any],
callbacks: list[BaseLLMCallback] | None = None,
**kwargs: Any,
) -> AsyncGenerator[str, None]:
"""Generate a response asynchronously with streaming."""
...
class BaseTextEmbedding(ABC):
"""The text embedding interface."""

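To illustrate the shape of the extended interface, a contrived subclass sketch (hypothetical, not in this change); it assumes BaseLLM from the module above is importable, the generate/agenerate signatures are abbreviated, and the token source is fake:

from collections.abc import AsyncGenerator, Generator
from typing import Any


class FakeStreamingLLM(BaseLLM):  # hypothetical class, for illustration only
    """Echo a canned token stream; real implementations call an LLM backend."""

    def generate(self, messages, streaming=True, callbacks=None, **kwargs: Any) -> str:
        return "".join(self.stream_generate(messages, callbacks=callbacks, **kwargs))

    async def agenerate(self, messages, streaming=True, callbacks=None, **kwargs: Any) -> str:
        return "".join([t async for t in self.astream_generate(messages, callbacks=callbacks, **kwargs)])

    def stream_generate(self, messages, callbacks=None, **kwargs: Any) -> Generator[str, None, None]:
        for token in ("streamed ", "response ", "tokens"):
            if callbacks:
                for callback in callbacks:
                    callback.on_llm_new_token(token)
            yield token

    async def astream_generate(self, messages, callbacks=None, **kwargs: Any) -> AsyncGenerator[str, None]:
        for token in self.stream_generate(messages, callbacks=callbacks, **kwargs):
            yield token

Concrete implementations such as ChatOpenAI (next file) wrap a real client plus retry logic around these same four methods.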
View File (chat-based OpenAI LLM implementation)

@@ -3,7 +3,7 @@
"""Chat-based OpenAI LLM implementation."""
from collections.abc import Callable
from collections.abc import AsyncGenerator, Callable, Generator
from typing import Any
from tenacity import (
@@ -92,6 +92,38 @@ class ChatOpenAI(BaseLLM, OpenAILLMImpl):
# TODO: why not just throw in this case?
return ""
def stream_generate(
self,
messages: str | list[Any],
callbacks: list[BaseLLMCallback] | None = None,
**kwargs: Any,
) -> Generator[str, None, None]:
"""Generate text with streaming."""
try:
retryer = Retrying(
stop=stop_after_attempt(self.max_retries),
wait=wait_exponential_jitter(max=10),
reraise=True,
retry=retry_if_exception_type(self.retry_error_types),
)
for attempt in retryer:
with attempt:
generator = self._stream_generate(
messages=messages,
callbacks=callbacks,
**kwargs,
)
yield from generator
except RetryError as e:
self._reporter.error(
message="Error at stream_generate()",
details={self.__class__.__name__: str(e)},
)
return
else:
return
async def agenerate(
self,
messages: str | list[Any],
@@ -122,6 +154,35 @@ class ChatOpenAI(BaseLLM, OpenAILLMImpl):
# TODO: why not just throw in this case?
return ""
async def astream_generate(
self,
messages: str | list[Any],
callbacks: list[BaseLLMCallback] | None = None,
**kwargs: Any,
) -> AsyncGenerator[str, None]:
"""Generate text asynchronously with streaming."""
try:
retryer = AsyncRetrying(
stop=stop_after_attempt(self.max_retries),
wait=wait_exponential_jitter(max=10),
reraise=True,
retry=retry_if_exception_type(self.retry_error_types), # type: ignore
)
async for attempt in retryer:
with attempt:
generator = self._astream_generate(
messages=messages,
callbacks=callbacks,
**kwargs,
)
async for response in generator:
yield response
except RetryError as e:
self._reporter.error(f"Error at astream_generate(): {e}")
return
else:
return
def _generate(
self,
messages: str | list[Any],
@@ -163,6 +224,37 @@ class ChatOpenAI(BaseLLM, OpenAILLMImpl):
return full_response
return response.choices[0].message.content or "" # type: ignore
def _stream_generate(
self,
messages: str | list[Any],
callbacks: list[BaseLLMCallback] | None = None,
**kwargs: Any,
) -> Generator[str, None, None]:
model = self.model
if not model:
raise ValueError(_MODEL_REQUIRED_MSG)
response = self.sync_client.chat.completions.create( # type: ignore
model=model,
messages=messages, # type: ignore
stream=True,
**kwargs,
)
for chunk in response:
if not chunk or not chunk.choices:
continue
delta = (
chunk.choices[0].delta.content
if chunk.choices[0].delta and chunk.choices[0].delta.content
else ""
)
yield delta
if callbacks:
for callback in callbacks:
callback.on_llm_new_token(delta)
async def _agenerate(
self,
messages: str | list[Any],
@@ -204,3 +296,34 @@ class ChatOpenAI(BaseLLM, OpenAILLMImpl):
return full_response
return response.choices[0].message.content or "" # type: ignore
async def _astream_generate(
self,
messages: str | list[Any],
callbacks: list[BaseLLMCallback] | None = None,
**kwargs: Any,
) -> AsyncGenerator[str, None]:
model = self.model
if not model:
raise ValueError(_MODEL_REQUIRED_MSG)
response = await self.async_client.chat.completions.create( # type: ignore
model=model,
messages=messages, # type: ignore
stream=True,
**kwargs,
)
async for chunk in response:
if not chunk or not chunk.choices:
continue
delta = (
chunk.choices[0].delta.content
if chunk.choices[0].delta and chunk.choices[0].delta.content
else ""
) # type: ignore
yield delta
if callbacks:
for callback in callbacks:
callback.on_llm_new_token(delta)

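A hedged usage sketch for the new async streaming path; the import paths, constructor arguments, and model name below are assumptions that may differ across graphrag versions:

import asyncio

from graphrag.query.llm.oai.chat_openai import ChatOpenAI  # assumed import path
from graphrag.query.llm.oai.typing import OpenaiApiType  # assumed import path


async def demo() -> None:
    llm = ChatOpenAI(
        api_key="<API_KEY>",            # placeholder
        model="gpt-4-turbo-preview",    # placeholder model name
        api_type=OpenaiApiType.OpenAI,  # assumed enum value
        max_retries=5,
    )
    messages = [
        {"role": "system", "content": "You are a terse assistant."},
        {"role": "user", "content": "Summarize streaming in one sentence."},
    ]
    # tokens are yielded as they arrive from the chat completions stream
    async for token in llm.astream_generate(messages):
        print(token, end="", flush=True)
    print()


# asyncio.run(demo())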
View File (structured search base classes)

@@ -4,6 +4,7 @@
"""Base classes for search algos."""
from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator
from dataclasses import dataclass
from typing import Any
@@ -67,3 +68,11 @@ class BaseSearch(ABC):
**kwargs,
) -> SearchResult:
"""Search for the given query asynchronously."""
@abstractmethod
def astream_search(
self,
query: str,
conversation_history: ConversationHistory | None = None,
) -> AsyncGenerator[str, None]:
"""Stream search for the given query."""

View File (global search implementation)

@@ -7,6 +7,7 @@ import asyncio
import json
import logging
import time
from collections.abc import AsyncGenerator
from dataclasses import dataclass
from typing import Any
@@ -100,6 +101,37 @@ class GlobalSearch(BaseSearch):
self.semaphore = asyncio.Semaphore(concurrent_coroutines)
async def astream_search(
self,
query: str,
conversation_history: ConversationHistory | None = None,
) -> AsyncGenerator:
"""Stream the global search response."""
context_chunks, context_records = self.context_builder.build_context(
conversation_history=conversation_history, **self.context_builder_params
)
if self.callbacks:
for callback in self.callbacks:
callback.on_map_response_start(context_chunks) # type: ignore
map_responses = await asyncio.gather(*[
self._map_response_single_batch(
context_data=data, query=query, **self.map_llm_params
)
for data in context_chunks
])
if self.callbacks:
for callback in self.callbacks:
callback.on_map_response_end(map_responses) # type: ignore
# send context records first before sending the reduce response
yield context_records
async for response in self._stream_reduce_response(
map_responses=map_responses, # type: ignore
query=query,
**self.reduce_llm_params,
):
yield response
async def asearch(
self,
query: str,
@@ -357,3 +389,81 @@ class GlobalSearch(BaseSearch):
llm_calls=1,
prompt_tokens=num_tokens(search_prompt, self.token_encoder),
)
async def _stream_reduce_response(
self,
map_responses: list[SearchResult],
query: str,
**llm_kwargs,
) -> AsyncGenerator[str, None]:
# collect all key points into a single list to prepare for sorting
key_points = []
for index, response in enumerate(map_responses):
if not isinstance(response.response, list):
continue
for element in response.response:
if not isinstance(element, dict):
continue
if "answer" not in element or "score" not in element:
continue
key_points.append({
"analyst": index,
"answer": element["answer"],
"score": element["score"],
})
# filter response with score = 0 and rank responses by descending order of score
filtered_key_points = [
point
for point in key_points
if point["score"] > 0 # type: ignore
]
if len(filtered_key_points) == 0 and not self.allow_general_knowledge:
# return no data answer if no key points are found
log.warning(
"Warning: All map responses have score 0 (i.e., no relevant information found from the dataset), returning a canned 'I do not know' answer. You can try enabling `allow_general_knowledge` to encourage the LLM to incorporate relevant general knowledge, at the risk of increasing hallucinations."
)
yield NO_DATA_ANSWER
return
filtered_key_points = sorted(
filtered_key_points,
key=lambda x: x["score"], # type: ignore
reverse=True, # type: ignore
)
data = []
total_tokens = 0
for point in filtered_key_points:
formatted_response_data = [
f'----Analyst {point["analyst"] + 1}----',
f'Importance Score: {point["score"]}',
point["answer"],
]
formatted_response_text = "\n".join(formatted_response_data)
if (
total_tokens + num_tokens(formatted_response_text, self.token_encoder)
> self.max_data_tokens
):
break
data.append(formatted_response_text)
total_tokens += num_tokens(formatted_response_text, self.token_encoder)
text_data = "\n\n".join(data)
search_prompt = self.reduce_system_prompt.format(
report_data=text_data, response_type=self.response_type
)
if self.allow_general_knowledge:
search_prompt += "\n" + self.general_knowledge_inclusion_prompt
search_messages = [
{"role": "system", "content": search_prompt},
{"role": "user", "content": query},
]
async for resp in self.llm.astream_generate( # type: ignore
search_messages,
callbacks=self.callbacks, # type: ignore
**llm_kwargs, # type: ignore
):
yield resp

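At this layer the first item yielded by astream_search is the raw context_records mapping from the context builder (DataFrames keyed by context type); the query API wrapper above converts it to a dictionary of lists via _reformat_context_data. A small consumption sketch, assuming search_engine is an already-configured GlobalSearch instance (e.g. from get_global_search_engine):

import asyncio


async def stream_global(search_engine, query: str) -> str:
    answer = ""
    first = True
    async for chunk in search_engine.astream_search(query=query):
        if first:
            context_records = chunk  # context data emitted before any response text
            first = False
            continue
        answer += chunk
        print(chunk, end="", flush=True)
    print()
    return answer


# asyncio.run(stream_global(search_engine, "What are the main topics?"))

LocalSearch.astream_search (next file) follows the same contract: context records first, then response tokens.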
View File (local search implementation)

@@ -5,6 +5,7 @@
import logging
import time
from collections.abc import AsyncGenerator
from typing import Any
import tiktoken
@@ -106,6 +107,37 @@ class LocalSearch(BaseSearch):
prompt_tokens=num_tokens(search_prompt, self.token_encoder),
)
async def astream_search(
self,
query: str,
conversation_history: ConversationHistory | None = None,
) -> AsyncGenerator:
"""Build local search context that fits a single context window and generate answer for the user query."""
start_time = time.time()
context_text, context_records = self.context_builder.build_context(
query=query,
conversation_history=conversation_history,
**self.context_builder_params,
)
log.info("GENERATE ANSWER: %s. QUERY: %s", start_time, query)
search_prompt = self.system_prompt.format(
context_data=context_text, response_type=self.response_type
)
search_messages = [
{"role": "system", "content": search_prompt},
{"role": "user", "content": query},
]
# send context records first before streaming the LLM response
yield context_records
async for response in self.llm.astream_generate( # type: ignore
messages=search_messages,
callbacks=self.callbacks,
**self.llm_params,
):
yield response
def search(
self,
query: str,