Expose community levels to api (#187)

This commit is contained in:
KennyZhang1 2024-10-01 17:26:04 -04:00 committed by GitHub
parent d9708d53f3
commit 1892eb8a65
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 51 additions and 25 deletions

View File

@ -79,8 +79,12 @@ async def global_query(request: GraphRequest):
validate_index_file_exist(index_name, ENTITIES_TABLE)
validate_index_file_exist(index_name, NODES_TABLE)
# current investigations show that community level 1 is the most useful for global search
COMMUNITY_LEVEL = 1
if isinstance(request.community_level, int):
COMMUNITY_LEVEL = request.community_level
else:
# Current investigations show that community level 1 is the most useful for global search, so use it as the default when the request does not supply an integer community_level
COMMUNITY_LEVEL = 1
try:
links = {
"nodes": {},
@ -259,7 +263,11 @@ async def local_query(request: GraphRequest):
RELATIONSHIPS_TABLE = "output/create_final_relationships.parquet"
TEXT_UNITS_TABLE = "output/create_final_text_units.parquet"
COMMUNITY_LEVEL = 2
if isinstance(request.community_level, int):
COMMUNITY_LEVEL = request.community_level
else:
# Current investigations show that community level 2 is the most useful for local search, so use it as the default when the request does not supply an integer community_level
COMMUNITY_LEVEL = 2
for index_name in sanitized_index_names:
# check for existence of files the query relies on to validate the index is complete

View File

@ -64,8 +64,12 @@ async def global_search_streaming(request: GraphRequest):
ENTITIES_TABLE = "output/create_final_entities.parquet"
NODES_TABLE = "output/create_final_nodes.parquet"
# current investigations show that community level 1 is the most useful for global search
COMMUNITY_LEVEL = 1
if isinstance(request.community_level, int):
COMMUNITY_LEVEL = request.community_level
else:
# Current investigations show that community level 1 is the most useful for global search, so use it as the default when the request does not supply an integer community_level
COMMUNITY_LEVEL = 1
for index_name in sanitized_index_names:
validate_index_file_exist(index_name, COMMUNITY_REPORT_TABLE)
validate_index_file_exist(index_name, ENTITIES_TABLE)
@ -245,7 +249,12 @@ async def local_search_streaming(request: GraphRequest):
NODES_TABLE = "output/create_final_nodes.parquet"
RELATIONSHIPS_TABLE = "output/create_final_relationships.parquet"
TEXT_UNITS_TABLE = "output/create_final_text_units.parquet"
COMMUNITY_LEVEL = 2
if isinstance(request.community_level, int):
COMMUNITY_LEVEL = request.community_level
else:
# Current investigations show that community level 2 is the most useful for local search, so use it as the default when the request does not supply an integer community_level
COMMUNITY_LEVEL = 2
try:
for index_name in sanitized_index_names:

View File

@ -40,12 +40,13 @@ class EntityResponse(BaseModel):
class GraphRequest(BaseModel):
index_name: str | List[str]
query: str
community_level: int | None = None
class GraphResponse(BaseModel):
result: Any
context_data: Any
class GraphDataResponse(BaseModel):
nodes: int

View File

@ -7,4 +7,4 @@
"GRAPHRAG_LLM_MODEL": "__GRAPHRAG_LLM_MODEL__",
"LOCATION": "__LOCATION__",
"RESOURCE_GROUP": "__RESOURCE_GROUP__"
}
}

View File

@ -327,19 +327,20 @@
"%%time\n",
"\n",
"\n",
"def global_search(index_name: str | list[str], query: str) -> requests.Response:\n",
"def global_search(index_name: str | list[str], query: str, community_level: int) -> requests.Response:\n",
" \"\"\"Run a global query over the knowledge graph(s) associated with one or more indexes\"\"\"\n",
" url = endpoint + \"/query/global\"\n",
" request = {\"index_name\": index_name, \"query\": query}\n",
"    # community_level is optional in the API request; when omitted, the server defaults to 1 for global queries\n",
" request = {\"index_name\": index_name, \"query\": query, \"community_level\": community_level}\n",
" return requests.post(url, json=request, headers=headers)\n",
"\n",
"\n",
"# perform a global query\n",
"global_response = global_search(\n",
" index_name=index_name, query=\"Summarize the main topics found in this data\"\n",
" index_name=index_name, query=\"Summarize the main topics found in this data\", community_level=1\n",
")\n",
"global_response_data = parse_query_response(global_response, return_context_data=True)\n",
"global_response_data"
"global_response_data\n"
]
},
{
@ -360,16 +361,17 @@
"%%time\n",
"\n",
"\n",
"def local_search(index_name: str | list[str], query: str) -> requests.Response:\n",
"def local_search(index_name: str | list[str], query: str, community_level: int) -> requests.Response:\n",
" \"\"\"Run a local query over the knowledge graph(s) associated with one or more indexes\"\"\"\n",
" url = endpoint + \"/query/local\"\n",
" request = {\"index_name\": index_name, \"query\": query}\n",
"    # community_level is optional in the API request; when omitted, the server defaults to 2 for local queries\n",
" request = {\"index_name\": index_name, \"query\": query, \"community_level\": community_level}\n",
" return requests.post(url, json=request, headers=headers)\n",
"\n",
"\n",
"# perform a local query\n",
"local_response = local_search(\n",
" index_name=index_name, query=\"Summarize the main topics found in this data\"\n",
" index_name=index_name, query=\"Summarize the main topics found in this data\", community_level=2\n",
")\n",
"local_response_data = parse_query_response(local_response, return_context_data=True)\n",
"local_response_data"

View File

@ -324,19 +324,21 @@
" return requests.get(url, headers=headers)\n",
"\n",
"\n",
"def global_search(index_name: str | list[str], query: str) -> requests.Response:\n",
"def global_search(index_name: str | list[str], query: str, community_level: int) -> requests.Response:\n",
" \"\"\"Run a global query over the knowledge graph(s) associated with one or more indexes\"\"\"\n",
" url = endpoint + \"/query/global\"\n",
" request = {\"index_name\": index_name, \"query\": query}\n",
"    # community_level is optional in the API request; when omitted, the server defaults to 1 for global queries\n",
" request = {\"index_name\": index_name, \"query\": query, \"community_level\": community_level}\n",
" return requests.post(url, json=request, headers=headers)\n",
"\n",
"\n",
"def global_search_streaming(\n",
" index_name: str | list[str], query: str\n",
" index_name: str | list[str], query: str, community_level: int\n",
") -> requests.Response:\n",
" \"\"\"Run a global query across one or more indexes and stream back the response\"\"\"\n",
" url = endpoint + \"/query/streaming/global\"\n",
" request = {\"index_name\": index_name, \"query\": query}\n",
"    # community_level is optional in the API request; when omitted, the server defaults to 1 for global queries\n",
" request = {\"index_name\": index_name, \"query\": query, \"community_level\": community_level}\n",
" context_list = []\n",
" with requests.post(url, json=request, headers=headers, stream=True) as r:\n",
" r.raise_for_status()\n",
@ -356,19 +358,21 @@
" display(pd.DataFrame.from_dict(context_list[0][\"reports\"]).head(10))\n",
"\n",
"\n",
"def local_search(index_name: str | list[str], query: str) -> requests.Response:\n",
"def local_search(index_name: str | list[str], query: str, community_level: int) -> requests.Response:\n",
" \"\"\"Run a local query over the knowledge graph(s) associated with one or more indexes\"\"\"\n",
" url = endpoint + \"/query/local\"\n",
" request = {\"index_name\": index_name, \"query\": query}\n",
"    # community_level is optional in the API request; when omitted, the server defaults to 2 for local queries\n",
" request = {\"index_name\": index_name, \"query\": query, \"community_level\": community_level}\n",
" return requests.post(url, json=request, headers=headers)\n",
"\n",
"\n",
"def local_search_streaming(\n",
" index_name: str | list[str], query: str\n",
" index_name: str | list[str], query: str, community_level: int\n",
") -> requests.Response:\n",
"    \"\"\"Run a local query across one or more indexes and stream back the response\"\"\"\n",
" url = endpoint + \"/query/streaming/local\"\n",
" request = {\"index_name\": index_name, \"query\": query}\n",
"    # community_level is optional in the API request; when omitted, the server defaults to 2 for local queries\n",
" request = {\"index_name\": index_name, \"query\": query, \"community_level\": community_level}\n",
" context_list = []\n",
" with requests.post(url, json=request, headers=headers, stream=True) as r:\n",
" r.raise_for_status()\n",
@ -742,7 +746,7 @@
"%%time\n",
"# pass in a single index name as a string or to query across multiple indexes, set index_name=[myindex1, myindex2]\n",
"global_response = global_search(\n",
" index_name=index_name, query=\"Summarize the main topics found in this data\"\n",
" index_name=index_name, query=\"Summarize the main topics found in this data\", community_level=1\n",
")\n",
"# print the result and save context data in a variable\n",
"global_response_data = parse_query_response(global_response, return_context_data=True)\n",
@ -765,7 +769,7 @@
"outputs": [],
"source": [
"global_search_streaming(\n",
" index_name=index_name, query=\"Summarize the main topics found in this data\"\n",
" index_name=index_name, query=\"Summarize the main topics found in this data\", community_level=1\n",
")"
]
},
@ -793,6 +797,7 @@
"local_response = local_search(\n",
" index_name=index_name,\n",
" query=\"Who are the primary actors in these communities?\",\n",
" community_level=2\n",
")\n",
"# print the result and save context data in a variable\n",
"local_response_data = parse_query_response(local_response, return_context_data=True)\n",
@ -817,6 +822,7 @@
"local_search_streaming(\n",
" index_name=index_name,\n",
" query=\"Who are the primary actors in these communities?\",\n",
" community_level=2\n",
")"
]
},