Mirror of https://github.com/microsoft/graphrag.git
synced 2025-12-29 16:09:25 +00:00
Fix/minor query fixes (#1893)
* fixed token count for drift search
* basic search fixes
* updated basic search prompt
* fixed text splitting logic
* Lint/format
* Semver
* Fix text splitting tests

---------

Co-authored-by: ha2trinh <trinhha@microsoft.com>
This commit is contained in:
parent ad4cdd685f
commit e2a448170a
@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "Fixes to basic search."
+}
@@ -42,6 +42,7 @@ class BasicSearchDefaults:

     prompt: None = None
     k: int = 10
+    max_context_tokens: int = 12_000
     chat_model_id: str = DEFAULT_CHAT_MODEL_ID
     embedding_model_id: str = DEFAULT_EMBEDDING_MODEL_ID

@@ -27,3 +27,7 @@ class BasicSearchConfig(BaseModel):
         description="The number of text units to include in search context.",
         default=graphrag_config_defaults.basic_search.k,
     )
+    max_context_tokens: int = Field(
+        description="The maximum tokens.",
+        default=graphrag_config_defaults.basic_search.max_context_tokens,
+    )
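For context, the new `max_context_tokens` field rides alongside `k` on the pydantic model. A minimal, self-contained sketch of how such a Field behaves — the defaults are inlined here as an assumption, whereas the real `BasicSearchConfig` reads them from `graphrag_config_defaults` as shown above:

```python
# Hedged sketch: defaults inlined; not graphrag's actual BasicSearchConfig.
from pydantic import BaseModel, Field

class BasicSearchConfigSketch(BaseModel):
    k: int = Field(
        description="The number of text units to include in search context.",
        default=10,
    )
    max_context_tokens: int = Field(
        description="The maximum tokens.",
        default=12_000,
    )

print(BasicSearchConfigSketch())                         # k=10 max_context_tokens=12000
print(BasicSearchConfigSketch(max_context_tokens=4_000)) # per-deployment override
```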
@@ -152,6 +152,8 @@ def split_single_text_on_tokens(text: str, tokenizer: Tokenizer) -> list[str]:
     while start_idx < len(input_ids):
         chunk_text = tokenizer.decode(list(chunk_ids))
         result.append(chunk_text)  # Append chunked text as string
+        if cur_idx == len(input_ids):
+            break
         start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap
         cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
         chunk_ids = input_ids[start_idx:cur_idx]
@@ -186,6 +188,8 @@ def split_multiple_texts_on_tokens(
         chunk_text = tokenizer.decode([id for _, id in chunk_ids])
         doc_indices = list({doc_idx for doc_idx, _ in chunk_ids})
         result.append(TextChunk(chunk_text, doc_indices, len(chunk_ids)))
+        if cur_idx == len(input_ids):
+            break
         start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap
         cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
         chunk_ids = input_ids[start_idx:cur_idx]
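Both hunks add the same early exit: once `cur_idx` has reached the end of the token stream, the chunk just appended was the final one, and advancing the window again would emit one more chunk made entirely of overlap. A self-contained sketch with a toy character-level tokenizer (a stand-in, not graphrag's actual `Tokenizer` class) demonstrates the behavior:

```python
# Stand-in tokenizer: one character = one token. Only the fields the loop
# touches are modeled here.
from dataclasses import dataclass
from typing import Callable

@dataclass
class ToyTokenizer:
    chunk_overlap: int
    tokens_per_chunk: int
    encode: Callable[[str], list[int]]
    decode: Callable[[list[int]], str]

def split_on_tokens(text: str, tokenizer: ToyTokenizer) -> list[str]:
    result: list[str] = []
    input_ids = tokenizer.encode(text)
    start_idx = 0
    cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
    chunk_ids = input_ids[start_idx:cur_idx]
    while start_idx < len(input_ids):
        result.append(tokenizer.decode(list(chunk_ids)))
        if cur_idx == len(input_ids):
            break  # the fix: the window already covers the last token
        start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap
        cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
        chunk_ids = input_ids[start_idx:cur_idx]
    return result

tok = ToyTokenizer(
    chunk_overlap=4,
    tokens_per_chunk=10,
    encode=lambda s: [ord(c) for c in s],
    decode=lambda ids: "".join(chr(i) for i in ids),
)
print(split_on_tokens("abcdefghijklmno", tok))  # ['abcdefghij', 'ghijklmno']
# Without the break, a third chunk 'mno' (pure overlap) would be appended --
# exactly the kind of stray trailing chunk removed from the tests below.
```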
@@ -11,23 +11,25 @@ You are a helpful assistant responding to questions about data in the tables provided.

 ---Goal---

-Generate a response of the target length and format that responds to the user's question, summarizing all information in the input data tables appropriate for the response length and format, and incorporating any relevant general knowledge.
+Generate a response of the target length and format that responds to the user's question, summarizing all relevant information in the input data tables appropriate for the response length and format.

-If you don't know the answer, just say so. Do not make anything up.
+You should use the data provided in the data tables below as the primary context for generating the response.
+
+If you don't know the answer or if the input data tables do not contain sufficient information to provide an answer, just say so. Do not make anything up.

 Points supported by data should list their data references as follows:

-"This is an example sentence supported by multiple text references [Data: Sources (record ids)]."
+"This is an example sentence supported by multiple data references [Data: Sources (record ids)]."

 Do not list more than 5 record ids in a single reference. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more.

 For example:

-"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Sources (15, 16)]."
+"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Sources (2, 7, 64, 46, 34, +more)]. He is also CEO of company X [Data: Sources (1, 3)]"

-where 15 and 16 represent the id (not the index) of the relevant data record.
+where 1, 2, 3, 7, 34, 46, and 64 represent the source id taken from the "source_id" column in the provided tables.

-Do not include information where the supporting text for it is not provided.
+Do not include information where the supporting evidence for it is not provided.


 ---Target response length and format---
@@ -42,23 +44,26 @@ Do not include information where the supporting text for it is not provided.

 ---Goal---

-Generate a response of the target length and format that responds to the user's question, summarizing all information in the input data tables appropriate for the response length and format, and incorporating any relevant general knowledge.
+Generate a response of the target length and format that responds to the user's question, summarizing all relevant information in the input data appropriate for the response length and format.

-If you don't know the answer, just say so. Do not make anything up.
+You should use the data provided in the data tables below as the primary context for generating the response.
+
+If you don't know the answer or if the input data tables do not contain sufficient information to provide an answer, just say so. Do not make anything up.

 Points supported by data should list their data references as follows:

-"This is an example sentence supported by multiple text references [Data: Sources (record ids)]."
+"This is an example sentence supported by multiple data references [Data: Sources (record ids)]."

 Do not list more than 5 record ids in a single reference. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more.

 For example:

-"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Sources (15, 16)]."
+"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Sources (2, 7, 64, 46, 34, +more)]. He is also CEO of company X [Data: Sources (1, 3)]"

-where 15 and 16 represent the id (not the index) of the relevant data record.
+where 1, 2, 3, 7, 34, 46, and 64 represent the source id taken from the "source_id" column in the provided tables.

-Do not include information where the supporting text for it is not provided.
+Do not include information where the supporting evidence for it is not provided.
+

 ---Target response length and format---

@@ -275,6 +275,7 @@ def get_basic_search_engine(
     text_unit_embeddings: BaseVectorStore,
     config: GraphRagConfig,
     system_prompt: str | None = None,
+    response_type: str = "multiple paragraphs",
     callbacks: list[QueryCallbacks] | None = None,
 ) -> BasicSearch:
     """Create a basic search engine based on data + configuration."""
@@ -312,6 +313,7 @@ def get_basic_search_engine(
     return BasicSearch(
         model=chat_model,
         system_prompt=system_prompt,
+        response_type=response_type,
         context_builder=BasicSearchContext(
             text_embedder=embedding_model,
             text_unit_embeddings=text_unit_embeddings,
@@ -323,6 +325,7 @@ def get_basic_search_engine(
         context_builder_params={
             "embedding_vectorstore_key": "id",
             "k": bs_config.k,
+            "max_context_tokens": bs_config.max_context_tokens,
         },
         callbacks=callbacks,
     )
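The `context_builder_params` dict is expanded into keyword arguments when the engine calls `build_context`, which is why the new setting has to be threaded through the factory as well as the config model. A simplified stand-in for that expansion (the builder function below is a toy, not graphrag's actual plumbing):

```python
# Parameter names match the diff; the function body is illustrative only.
def build_context(query: str, k: int = 10, max_context_tokens: int = 12_000, **kwargs) -> str:
    return f"retrieve {k} text units for {query!r}, trim context to {max_context_tokens} tokens"

context_builder_params = {
    "embedding_vectorstore_key": "id",  # absorbed by **kwargs here
    "k": 5,
    "max_context_tokens": 4_000,  # honored now that the factory forwards it
}
print(build_context("example query", **context_builder_params))
```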
@@ -3,6 +3,9 @@

 """Basic Context Builder implementation."""

+import logging
+from typing import cast
+
 import pandas as pd
 import tiktoken

@@ -13,8 +16,11 @@ from graphrag.query.context_builder.builders import (
     ContextBuilderResult,
 )
 from graphrag.query.context_builder.conversation_history import ConversationHistory
+from graphrag.query.llm.text_utils import num_tokens
 from graphrag.vector_stores.base import BaseVectorStore

+log = logging.getLogger(__name__)
+

 class BasicSearchContext(BasicContextBuilder):
     """Class representing the Basic Search Context Builder."""
@@ -32,30 +38,76 @@ class BasicSearchContext(BasicContextBuilder):
         self.text_units = text_units
         self.text_unit_embeddings = text_unit_embeddings
         self.embedding_vectorstore_key = embedding_vectorstore_key
+        self.text_id_map = self._map_ids()

     def build_context(
         self,
         query: str,
         conversation_history: ConversationHistory | None = None,
+        k: int = 10,
+        max_context_tokens: int = 12_000,
+        context_name: str = "Sources",
+        column_delimiter: str = "|",
+        text_id_col: str = "source_id",
+        text_col: str = "text",
         **kwargs,
     ) -> ContextBuilderResult:
-        """Build the context for the local search mode."""
-        search_results = self.text_unit_embeddings.similarity_search_by_text(
-            text=query,
-            text_embedder=lambda t: self.text_embedder.embed(t),
-            k=kwargs.get("k", 10),
-        )
-        # we don't have a friendly id on text_units, so just copy the index
-        sources = [
-            {"id": str(search_results.index(r)), "text": r.document.text}
-            for r in search_results
-        ]
-        # make a delimited table for the context; this imitates graphrag context building
-        table = ["id|text"] + [f"{s['id']}|{s['text']}" for s in sources]
-        columns = pd.Index(["id", "text"])
+        """Build the context for the basic search mode."""
+        if query != "":
+            related_texts = self.text_unit_embeddings.similarity_search_by_text(
+                text=query,
+                text_embedder=lambda t: self.text_embedder.embed(t),
+                k=k,
+            )
+            related_text_list = [
+                {
+                    text_id_col: self.text_id_map[f"{chunk.document.id}"],
+                    text_col: chunk.document.text,
+                }
+                for chunk in related_texts
+            ]
+            related_text_df = pd.DataFrame(related_text_list)
+        else:
+            related_text_df = pd.DataFrame({
+                text_id_col: [],
+                text_col: [],
+            })
+
+        # add these related text chunks into context until we fill up the context window
+        current_tokens = 0
+        text_ids = []
+        current_tokens = num_tokens(
+            text_id_col + column_delimiter + text_col + "\n", self.token_encoder
+        )
+        for i, row in related_text_df.iterrows():
+            text = row[text_id_col] + column_delimiter + row[text_col] + "\n"
+            tokens = num_tokens(text, self.token_encoder)
+            if current_tokens + tokens > max_context_tokens:
+                msg = f"Reached token limit: {current_tokens + tokens}. Reverting to previous context state"
+                log.info(msg)
+                break
+
+            current_tokens += tokens
+            text_ids.append(i)
+        final_text_df = cast(
+            "pd.DataFrame",
+            related_text_df[related_text_df.index.isin(text_ids)].reset_index(
+                drop=True
+            ),
+        )
+        final_text = final_text_df.to_csv(
+            index=False, escapechar="\\", sep=column_delimiter
+        )
+
         return ContextBuilderResult(
-            context_chunks="\n\n".join(table),
-            context_records={"sources": pd.DataFrame(sources, columns=columns)},
+            context_chunks=final_text,
+            context_records={context_name: final_text_df},
         )
+
+    def _map_ids(self) -> dict[str, str]:
+        """Map id to short id in the text units."""
+        id_map = {}
+        text_units = self.text_units or []
+        for unit in text_units:
+            id_map[unit.id] = unit.short_id
+        return id_map
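The heart of the rewritten builder is a greedy budget loop: retrieved rows are appended to the context until the next row would exceed `max_context_tokens`. A distilled, runnable sketch of just that loop, with a whitespace word count standing in for `num_tokens` and its tiktoken encoder:

```python
import pandas as pd

def trim_to_budget(df: pd.DataFrame, max_tokens: int, sep: str = "|") -> pd.DataFrame:
    """Keep leading rows of df until the delimited text would exceed max_tokens."""
    count = lambda s: len(s.split())  # stand-in for num_tokens(s, token_encoder)
    used = count(sep.join(df.columns))  # the header line spends budget too
    keep: list[int] = []
    for i, row in df.iterrows():
        tokens = count(row["source_id"] + sep + row["text"])
        if used + tokens > max_tokens:
            break  # same cutoff the builder logs before stopping
        used += tokens
        keep.append(i)
    return df[df.index.isin(keep)].reset_index(drop=True)

chunks = pd.DataFrame({
    "source_id": ["1", "2", "3"],
    "text": ["alpha beta gamma", "delta epsilon", "zeta eta theta iota"],
})
# Keeps rows 1 and 2; row 3 would push the total past the budget of 8.
print(trim_to_budget(chunks, max_tokens=8).to_csv(index=False, sep="|"))
```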
@@ -108,6 +108,9 @@ class BasicSearch(BaseSearch[BasicContextBuilder]):
                 llm_calls=1,
                 prompt_tokens=num_tokens(search_prompt, self.token_encoder),
                 output_tokens=sum(output_tokens.values()),
+                llm_calls_categories=llm_calls,
+                prompt_tokens_categories=prompt_tokens,
+                output_tokens_categories=output_tokens,
             )

         except Exception:
@@ -120,6 +123,9 @@ class BasicSearch(BaseSearch[BasicContextBuilder]):
                 llm_calls=1,
                 prompt_tokens=num_tokens(search_prompt, self.token_encoder),
                 output_tokens=0,
+                llm_calls_categories=llm_calls,
+                prompt_tokens_categories=prompt_tokens,
+                output_tokens_categories=output_tokens,
             )

     async def stream_search(
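Both hunks forward the per-category accounting dicts into the result object, so callers can attribute cost per stage instead of seeing only totals. A stand-in dataclass (graphrag's real `SearchResult` has more fields than this) showing the shape:

```python
from dataclasses import dataclass, field

@dataclass
class SearchResultSketch:
    # Only the accounting fields touched by the diff are modeled here.
    llm_calls: int
    prompt_tokens: int
    output_tokens: int
    llm_calls_categories: dict[str, int] = field(default_factory=dict)
    prompt_tokens_categories: dict[str, int] = field(default_factory=dict)
    output_tokens_categories: dict[str, int] = field(default_factory=dict)

result = SearchResultSketch(
    llm_calls=1,
    prompt_tokens=2048,
    output_tokens=256,
    llm_calls_categories={"response": 1},
    prompt_tokens_categories={"response": 2048},
    output_tokens_categories={"response": 256},
)
print(result.output_tokens_categories)  # {'response': 256}
```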
@@ -213,7 +213,7 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
         primer_context, token_ct = await self.context_builder.build_context(query)
         llm_calls["build_context"] = token_ct["llm_calls"]
         prompt_tokens["build_context"] = token_ct["prompt_tokens"]
-        output_tokens["build_context"] = token_ct["prompt_tokens"]
+        output_tokens["build_context"] = token_ct["output_tokens"]

         primer_response = await self.primer.search(
             query=query, top_k_reports=primer_context
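The drift fix is a one-key change with real reporting impact: output tokens for the build_context stage were being filled from the prompt bucket, inflating reported completion usage. In isolation (the shape of `token_ct` is inferred from the diff; the numbers are made up for illustration):

```python
# token_ct is the stats dict returned by the drift context builder.
token_ct = {"llm_calls": 1, "prompt_tokens": 1500, "output_tokens": 300}

llm_calls: dict[str, int] = {}
prompt_tokens: dict[str, int] = {}
output_tokens: dict[str, int] = {}

llm_calls["build_context"] = token_ct["llm_calls"]
prompt_tokens["build_context"] = token_ct["prompt_tokens"]
output_tokens["build_context"] = token_ct["output_tokens"]  # was ["prompt_tokens"]

print(sum(output_tokens.values()))  # 300 reported, not 1500
```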
@@ -136,7 +136,6 @@ def test_split_single_text_on_tokens():
         " by this t",
         "his test o",
         "est only.",
-        "nly.",
     ]

     result = split_single_text_on_tokens(text=text, tokenizer=tokenizer)
@@ -197,7 +196,6 @@ def test_split_single_text_on_tokens_no_overlap():
         " this test",
         " test only",
         " only.",
-        ".",
     ]

     result = split_single_text_on_tokens(text=text, tokenizer=tokenizer)