diff --git a/.semversioner/next-release/patch-20240823233325895089.json b/.semversioner/next-release/patch-20240823233325895089.json new file mode 100644 index 00000000..d4ee1bef --- /dev/null +++ b/.semversioner/next-release/patch-20240823233325895089.json @@ -0,0 +1,4 @@ +{ + "type": "patch", + "description": "Fix weight casting during graph extraction" +} diff --git a/graphrag/index/graph/extractors/graph/graph_extractor.py b/graphrag/index/graph/extractors/graph/graph_extractor.py index f1ba0011..49ca671a 100644 --- a/graphrag/index/graph/extractors/graph/graph_extractor.py +++ b/graphrag/index/graph/extractors/graph/graph_extractor.py @@ -4,7 +4,6 @@ """A module containing 'GraphExtractionResult' and 'GraphExtractor' models.""" import logging -import numbers import re import traceback from collections.abc import Mapping @@ -248,11 +247,11 @@ class GraphExtractor: target = clean_str(record_attributes[2].upper()) edge_description = clean_str(record_attributes[3]) edge_source_id = clean_str(str(source_doc_id)) - weight = ( - float(record_attributes[-1]) - if isinstance(record_attributes[-1], numbers.Number) - else 1.0 - ) + try: + weight = float(record_attributes[-1]) + except ValueError: + weight = 1.0 + if source not in graph.nodes(): graph.add_node( source,