mirror of
https://github.com/microsoft/graphrag.git
synced 2025-12-03 02:21:01 +00:00
Fix sort_context max_tokens & max_tokens param in verb (#888)
* Fix sort_context max_tokens & max_tokens param in verb
* Fix sort_context for Windows test
* Add semversioner file
---------
Co-authored-by: Alonso Guevara <alonsog@microsoft.com>
This commit is contained in:
parent
238f1c2adc
commit
5a7dbaa051
@ -0,0 +1,4 @@
|
||||
{
|
||||
"type": "patch",
|
||||
"description": "fix sort_context & max_tokens params in verb"
|
||||
}
|
||||
@ -144,7 +144,7 @@ def sort_context(
|
||||
new_context_string = _get_context_string(
|
||||
sorted_nodes, sorted_edges, sorted_claims, sub_community_reports
|
||||
)
|
||||
if num_tokens(context_string) > max_tokens:
|
||||
if num_tokens(new_context_string) > max_tokens:
|
||||
break
|
||||
context_string = new_context_string
|
||||
|
||||
|
||||
@ -19,6 +19,10 @@ def build_steps(
|
||||
"""
|
||||
covariates_enabled = config.get("covariates_enabled", False)
|
||||
create_community_reports_config = config.get("create_community_reports", {})
|
||||
community_report_strategy = create_community_reports_config.get("strategy", {})
|
||||
community_report_max_input_length = community_report_strategy.get(
|
||||
"max_input_length", 16_000
|
||||
)
|
||||
base_text_embed = config.get("text_embed", {})
|
||||
community_report_full_content_embed_config = config.get(
|
||||
"community_report_full_content_embed", base_text_embed
|
||||
@ -77,6 +81,7 @@ def build_steps(
|
||||
{
|
||||
"id": "local_contexts",
|
||||
"verb": "prepare_community_reports",
|
||||
"args": {"max_tokens": community_report_max_input_length},
|
||||
"input": {
|
||||
"source": "nodes",
|
||||
"nodes": "nodes",
|
||||
|
||||
@ -1,14 +1,14 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
import math
|
||||
import platform
|
||||
|
||||
from graphrag.index.graph.extractors.community_reports import sort_context
|
||||
from graphrag.query.llm.text_utils import num_tokens
|
||||
|
||||
nan = math.nan
|
||||
|
||||
|
||||
def test_sort_context():
|
||||
context: list[dict] = [
|
||||
context: list[dict] = [
|
||||
{
|
||||
"title": "ALI BABA",
|
||||
"degree": 1,
|
||||
@ -198,7 +198,16 @@ def test_sort_context():
|
||||
],
|
||||
"claim_details": [nan],
|
||||
},
|
||||
]
|
||||
]
|
||||
|
||||
|
||||
def test_sort_context():
    """Check that sort_context returns a context string with the expected token count.

    The expected count is 827 on Windows and 826 on every other platform.
    """
    ctx = sort_context(context)
    # Fail with a clear message before attempting to tokenize a missing result.
    assert ctx is not None

    # The parentheses are required: a conditional expression binds more loosely
    # than `==`, so the unparenthesized form
    #     assert num_tokens(ctx) == 827 if <windows> else 826
    # evaluates the bare constant 826 on non-Windows platforms, which is always
    # truthy and therefore never checks the token count at all.
    expected_tokens = 827 if platform.system() == "Windows" else 826
    assert num_tokens(ctx) == expected_tokens
|
||||
|
||||
|
||||
def test_sort_context_max_tokens():
    """Check that sort_context keeps its output within the requested max_tokens."""
    token_budget = 800
    bounded_ctx = sort_context(context, max_tokens=token_budget)

    assert bounded_ctx is not None
    assert num_tokens(bounded_ctx) <= token_budget
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user