Fix/minor query fixes (#1893)

* fixed token count for drift search

* basic search fixes

* updated basic search prompt

* fixed text splitting logic

* Lint/format

* Semver

* Fix text splitting tests

---------

Co-authored-by: ha2trinh <trinhha@microsoft.com>
This commit is contained in:
Nathan Evans 2025-04-25 14:12:18 -07:00 committed by GitHub
parent ad4cdd685f
commit e2a448170a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 108 additions and 31 deletions

View File

@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Fixes to basic search."
}

View File

@ -42,6 +42,7 @@ class BasicSearchDefaults:
prompt: None = None
k: int = 10
max_context_tokens: int = 12_000
chat_model_id: str = DEFAULT_CHAT_MODEL_ID
embedding_model_id: str = DEFAULT_EMBEDDING_MODEL_ID

View File

@ -27,3 +27,7 @@ class BasicSearchConfig(BaseModel):
description="The number of text units to include in search context.",
default=graphrag_config_defaults.basic_search.k,
)
max_context_tokens: int = Field(
description="The maximum tokens.",
default=graphrag_config_defaults.basic_search.max_context_tokens,
)

View File

@ -152,6 +152,8 @@ def split_single_text_on_tokens(text: str, tokenizer: Tokenizer) -> list[str]:
while start_idx < len(input_ids):
chunk_text = tokenizer.decode(list(chunk_ids))
result.append(chunk_text) # Append chunked text as string
if cur_idx == len(input_ids):
break
start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap
cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
chunk_ids = input_ids[start_idx:cur_idx]
@ -186,6 +188,8 @@ def split_multiple_texts_on_tokens(
chunk_text = tokenizer.decode([id for _, id in chunk_ids])
doc_indices = list({doc_idx for doc_idx, _ in chunk_ids})
result.append(TextChunk(chunk_text, doc_indices, len(chunk_ids)))
if cur_idx == len(input_ids):
break
start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap
cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
chunk_ids = input_ids[start_idx:cur_idx]

View File

@ -11,23 +11,25 @@ You are a helpful assistant responding to questions about data in the tables pro
---Goal---
Generate a response of the target length and format that responds to the user's question, summarizing all information in the input data tables appropriate for the response length and format, and incorporating any relevant general knowledge.
Generate a response of the target length and format that responds to the user's question, summarizing all relevant information in the input data tables appropriate for the response length and format.
If you don't know the answer, just say so. Do not make anything up.
You should use the data provided in the data tables below as the primary context for generating the response.
If you don't know the answer or if the input data tables do not contain sufficient information to provide an answer, just say so. Do not make anything up.
Points supported by data should list their data references as follows:
"This is an example sentence supported by multiple text references [Data: Sources (record ids)]."
"This is an example sentence supported by multiple data references [Data: Sources (record ids)]."
Do not list more than 5 record ids in a single reference. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more.
For example:
"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Sources (15, 16)]."
"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Sources (2, 7, 64, 46, 34, +more)]. He is also CEO of company X [Data: Sources (1, 3)]"
where 15 and 16 represent the id (not the index) of the relevant data record.
where 1, 2, 3, 7, 34, 46, and 64 represent the source id taken from the "source_id" column in the provided tables.
Do not include information where the supporting text for it is not provided.
Do not include information where the supporting evidence for it is not provided.
---Target response length and format---
@ -42,23 +44,26 @@ Do not include information where the supporting text for it is not provided.
---Goal---
Generate a response of the target length and format that responds to the user's question, summarizing all information in the input data tables appropriate for the response length and format, and incorporating any relevant general knowledge.
Generate a response of the target length and format that responds to the user's question, summarizing all relevant information in the input data appropriate for the response length and format.
If you don't know the answer, just say so. Do not make anything up.
You should use the data provided in the data tables below as the primary context for generating the response.
If you don't know the answer or if the input data tables do not contain sufficient information to provide an answer, just say so. Do not make anything up.
Points supported by data should list their data references as follows:
"This is an example sentence supported by multiple text references [Data: Sources (record ids)]."
"This is an example sentence supported by multiple data references [Data: Sources (record ids)]."
Do not list more than 5 record ids in a single reference. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more.
For example:
"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Sources (15, 16)]."
"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Sources (2, 7, 64, 46, 34, +more)]. He is also CEO of company X [Data: Sources (1, 3)]"
where 15 and 16 represent the id (not the index) of the relevant data record.
where 1, 2, 3, 7, 34, 46, and 64 represent the source id taken from the "source_id" column in the provided tables.
Do not include information where the supporting evidence for it is not provided.
Do not include information where the supporting text for it is not provided.
---Target response length and format---

View File

@ -275,6 +275,7 @@ def get_basic_search_engine(
text_unit_embeddings: BaseVectorStore,
config: GraphRagConfig,
system_prompt: str | None = None,
response_type: str = "multiple paragraphs",
callbacks: list[QueryCallbacks] | None = None,
) -> BasicSearch:
"""Create a basic search engine based on data + configuration."""
@ -312,6 +313,7 @@ def get_basic_search_engine(
return BasicSearch(
model=chat_model,
system_prompt=system_prompt,
response_type=response_type,
context_builder=BasicSearchContext(
text_embedder=embedding_model,
text_unit_embeddings=text_unit_embeddings,
@ -323,6 +325,7 @@ def get_basic_search_engine(
context_builder_params={
"embedding_vectorstore_key": "id",
"k": bs_config.k,
"max_context_tokens": bs_config.max_context_tokens,
},
callbacks=callbacks,
)

View File

@ -3,6 +3,9 @@
"""Basic Context Builder implementation."""
import logging
from typing import cast
import pandas as pd
import tiktoken
@ -13,8 +16,11 @@ from graphrag.query.context_builder.builders import (
ContextBuilderResult,
)
from graphrag.query.context_builder.conversation_history import ConversationHistory
from graphrag.query.llm.text_utils import num_tokens
from graphrag.vector_stores.base import BaseVectorStore
log = logging.getLogger(__name__)
class BasicSearchContext(BasicContextBuilder):
"""Class representing the Basic Search Context Builder."""
@ -32,30 +38,76 @@ class BasicSearchContext(BasicContextBuilder):
self.text_units = text_units
self.text_unit_embeddings = text_unit_embeddings
self.embedding_vectorstore_key = embedding_vectorstore_key
self.text_id_map = self._map_ids()
def build_context(
self,
query: str,
conversation_history: ConversationHistory | None = None,
k: int = 10,
max_context_tokens: int = 12_000,
context_name: str = "Sources",
column_delimiter: str = "|",
text_id_col: str = "source_id",
text_col: str = "text",
**kwargs,
) -> ContextBuilderResult:
"""Build the context for the local search mode."""
search_results = self.text_unit_embeddings.similarity_search_by_text(
text=query,
text_embedder=lambda t: self.text_embedder.embed(t),
k=kwargs.get("k", 10),
)
# we don't have a friendly id on text_units, so just copy the index
sources = [
{"id": str(search_results.index(r)), "text": r.document.text}
for r in search_results
]
# make a delimited table for the context; this imitates graphrag context building
table = ["id|text"] + [f"{s['id']}|{s['text']}" for s in sources]
"""Build the context for the basic search mode."""
if query != "":
related_texts = self.text_unit_embeddings.similarity_search_by_text(
text=query,
text_embedder=lambda t: self.text_embedder.embed(t),
k=k,
)
related_text_list = [
{
text_id_col: self.text_id_map[f"{chunk.document.id}"],
text_col: chunk.document.text,
}
for chunk in related_texts
]
related_text_df = pd.DataFrame(related_text_list)
else:
related_text_df = pd.DataFrame({
text_id_col: [],
text_col: [],
})
columns = pd.Index(["id", "text"])
# add these related text chunks into context until we fill up the context window
current_tokens = 0
text_ids = []
current_tokens = num_tokens(
text_id_col + column_delimiter + text_col + "\n", self.token_encoder
)
for i, row in related_text_df.iterrows():
text = row[text_id_col] + column_delimiter + row[text_col] + "\n"
tokens = num_tokens(text, self.token_encoder)
if current_tokens + tokens > max_context_tokens:
msg = f"Reached token limit: {current_tokens + tokens}. Reverting to previous context state"
log.info(msg)
break
current_tokens += tokens
text_ids.append(i)
final_text_df = cast(
"pd.DataFrame",
related_text_df[related_text_df.index.isin(text_ids)].reset_index(
drop=True
),
)
final_text = final_text_df.to_csv(
index=False, escapechar="\\", sep=column_delimiter
)
return ContextBuilderResult(
context_chunks="\n\n".join(table),
context_records={"sources": pd.DataFrame(sources, columns=columns)},
context_chunks=final_text,
context_records={context_name: final_text_df},
)
def _map_ids(self) -> dict[str, str]:
"""Map id to short id in the text units."""
id_map = {}
text_units = self.text_units or []
for unit in text_units:
id_map[unit.id] = unit.short_id
return id_map

View File

@ -108,6 +108,9 @@ class BasicSearch(BaseSearch[BasicContextBuilder]):
llm_calls=1,
prompt_tokens=num_tokens(search_prompt, self.token_encoder),
output_tokens=sum(output_tokens.values()),
llm_calls_categories=llm_calls,
prompt_tokens_categories=prompt_tokens,
output_tokens_categories=output_tokens,
)
except Exception:
@ -120,6 +123,9 @@ class BasicSearch(BaseSearch[BasicContextBuilder]):
llm_calls=1,
prompt_tokens=num_tokens(search_prompt, self.token_encoder),
output_tokens=0,
llm_calls_categories=llm_calls,
prompt_tokens_categories=prompt_tokens,
output_tokens_categories=output_tokens,
)
async def stream_search(

View File

@ -213,7 +213,7 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
primer_context, token_ct = await self.context_builder.build_context(query)
llm_calls["build_context"] = token_ct["llm_calls"]
prompt_tokens["build_context"] = token_ct["prompt_tokens"]
output_tokens["build_context"] = token_ct["prompt_tokens"]
output_tokens["build_context"] = token_ct["output_tokens"]
primer_response = await self.primer.search(
query=query, top_k_reports=primer_context

View File

@ -136,7 +136,6 @@ def test_split_single_text_on_tokens():
" by this t",
"his test o",
"est only.",
"nly.",
]
result = split_single_text_on_tokens(text=text, tokenizer=tokenizer)
@ -197,7 +196,6 @@ def test_split_single_text_on_tokens_no_overlap():
" this test",
" test only",
" only.",
".",
]
result = split_single_text_on_tokens(text=text, tokenizer=tokenizer)