fix(knext): add remote client addr (#376)

This commit is contained in:
royzhao 2024-10-25 13:26:06 +08:00 committed by GitHub
parent 593700b786
commit d0f19d1f6b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 22 additions and 12 deletions

View File

@ -13,15 +13,17 @@
from knext.builder import rest
from knext.builder.rest.models.writer_graph_request import WriterGraphRequest
from knext.common.base.client import Client
from knext.common.rest import ApiClient, Configuration
class BuilderClient(Client):
""" """
_rest_client = rest.BuilderApi()
def __init__(self, host_addr: str = None, project_id: int = None):
super().__init__(host_addr, project_id)
self._rest_client: rest.BuilderApi = rest.BuilderApi(
api_client=ApiClient(configuration=Configuration(host=host_addr))
)
def write_graph(self, sub_graph: dict, operation: str, lead_to_builder: bool):
request = WriterGraphRequest(

View File

@ -79,7 +79,7 @@ class Configuration(object):
from knext.common import env
self.host = (
os.environ.get("KAG_PROJECT_HOST_ADDR") or host or env.LOCAL_SCHEMA_URL
host or os.environ.get("KAG_PROJECT_HOST_ADDR") or env.LOCAL_SCHEMA_URL
)
"""Default Base url
"""

View File

@ -12,6 +12,7 @@
from typing import List, Dict
from knext.common.base.client import Client
from knext.common.rest import ApiClient, Configuration
from knext.graph_algo import (
GetPageRankScoresRequest,
GetPageRankScoresRequestStartNodes,
@ -23,10 +24,11 @@ from knext.graph_algo import rest
class GraphAlgoClient(Client):
""" """
_rest_client = rest.GraphApi()
def __init__(self, host_addr: str = None, project_id: int = None):
super().__init__(host_addr, project_id)
self._rest_client: rest.GraphApi = rest.GraphApi(
api_client=ApiClient(configuration=Configuration(host=host_addr))
)
def calculate_pagerank_scores(self, target_vertex_type, start_nodes: List[Dict]):
"""

View File

@ -12,16 +12,18 @@
import json
from knext.common.base.client import Client
from knext.common.rest import Configuration, ApiClient
from knext.project import rest
class ProjectClient(Client):
""" """
_rest_client = rest.ProjectApi()
def __init__(self, host_addr: str = None, project_id: int = None):
super().__init__(host_addr, project_id)
self._rest_client: rest.ProjectApi = rest.ProjectApi(
api_client=ApiClient(configuration=Configuration(host=host_addr))
)
def get_config(self, project_id: str):
project = self.get(id=int(project_id))
@ -36,7 +38,7 @@ class ProjectClient(Client):
for project in projects:
condition = True
for k, v in conditions.items():
condition = condition and getattr(project, k) == v
condition = condition and str(getattr(project, k)) == str(v)
if condition:
return project
return None

View File

@ -13,6 +13,7 @@ import os
import knext.common.cache
from knext.common.base.client import Client
from knext.common.rest import ApiClient, Configuration
from knext.project.client import ProjectClient
from knext.reasoner import ReasonTask
from knext.reasoner import rest
@ -26,10 +27,11 @@ reason_cache = knext.common.cache.SchemaCache()
class ReasonerClient(Client):
"""SPG Reasoner Client."""
_rest_client = rest.ReasonerApi()
def __init__(self, host_addr: str = None, project_id: int = None, namespace=None):
super().__init__(host_addr, project_id)
self._rest_client: rest.ReasonerApi = rest.ReasonerApi(
api_client=ApiClient(configuration=Configuration(host=host_addr))
)
self._namespace = namespace or os.environ.get("KAG_PROJECT_NAMESPACE")
self._session = None
# load schema cache

View File

@ -10,6 +10,7 @@
# or implied.
from knext.common.base.client import Client
from knext.common.rest import Configuration, ApiClient
from knext.search import rest, TextSearchRequest, VectorSearchRequest, IdxRecord
@ -21,10 +22,11 @@ def idx_record_to_dict(record: IdxRecord):
class SearchClient(Client):
""" """
_rest_client = rest.SearchApi()
def __init__(self, host_addr: str = None, project_id: int = None):
super().__init__(host_addr, project_id)
self._rest_client: rest.SearchApi = rest.SearchApi(
api_client=ApiClient(configuration=Configuration(host=host_addr))
)
def search_text(self, query_string, label_constraints=None, topk=10):
req = TextSearchRequest(self._project_id, query_string, label_constraints, topk)