Fix sort_context max_tokens & max_tokens param in verb (#888)

* Fix sort_context max_tokens & max_tokens param in verb

* Fix sort_context for windows test

* add semversioner file

---------

Co-authored-by: Alonso Guevara <alonsog@microsoft.com>
This commit is contained in:
Andres Morales 2024-08-12 15:55:31 -06:00 committed by GitHub
parent 238f1c2adc
commit 5a7dbaa051
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 211 additions and 193 deletions

View File

@ -0,0 +1,4 @@
{
"type": "patch",
"description": "fix sort_context & max_tokens params in verb"
}

View File

@ -144,7 +144,7 @@ def sort_context(
new_context_string = _get_context_string(
sorted_nodes, sorted_edges, sorted_claims, sub_community_reports
)
if num_tokens(context_string) > max_tokens:
if num_tokens(new_context_string) > max_tokens:
break
context_string = new_context_string

View File

@ -19,6 +19,10 @@ def build_steps(
"""
covariates_enabled = config.get("covariates_enabled", False)
create_community_reports_config = config.get("create_community_reports", {})
community_report_strategy = create_community_reports_config.get("strategy", {})
community_report_max_input_length = community_report_strategy.get(
"max_input_length", 16_000
)
base_text_embed = config.get("text_embed", {})
community_report_full_content_embed_config = config.get(
"community_report_full_content_embed", base_text_embed
@ -77,6 +81,7 @@ def build_steps(
{
"id": "local_contexts",
"verb": "prepare_community_reports",
"args": {"max_tokens": community_report_max_input_length},
"input": {
"source": "nodes",
"nodes": "nodes",

View File

@ -1,14 +1,14 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
import math
import platform
from graphrag.index.graph.extractors.community_reports import sort_context
from graphrag.query.llm.text_utils import num_tokens
nan = math.nan
def test_sort_context():
context: list[dict] = [
context: list[dict] = [
{
"title": "ALI BABA",
"degree": 1,
@ -198,7 +198,16 @@ def test_sort_context():
],
"claim_details": [nan],
},
]
]
def test_sort_context():
    """Check the token count of the context built by ``sort_context``.

    Calls ``sort_context`` on the module-level ``context`` fixture with its
    default arguments and compares the token count of the result against a
    platform-specific expected value.
    """
    ctx = sort_context(context)
    # Verify a result was produced before measuring it.
    assert ctx is not None
    # The parentheses are required: without them the statement parses as
    # ``assert (count == 827) if <windows> else 826``, which on non-Windows
    # platforms asserts the truthy constant 826 and can never fail.
    # NOTE(review): the expected value is one token higher on Windows; the
    # cause is not visible here -- confirm 826 still holds on other platforms
    # now that the comparison is actually enforced.
    expected_tokens = 827 if platform.system() == "Windows" else 826
    assert num_tokens(ctx) == expected_tokens
def test_sort_context_max_tokens():
    """Check that ``sort_context`` keeps its output within ``max_tokens``.

    Passes a budget of 800 tokens for the module-level ``context`` fixture
    and asserts that the returned string exists and does not exceed it.
    """
    token_budget = 800
    limited_ctx = sort_context(context, max_tokens=token_budget)
    assert limited_ctx is not None
    used_tokens = num_tokens(limited_ctx)
    assert used_tokens <= token_budget