fix(knext): add remote client addr (#376)

This commit is contained in:
royzhao 2024-10-25 13:26:06 +08:00 committed by GitHub
parent 593700b786
commit d0f19d1f6b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 22 additions and 12 deletions

View File

@ -13,15 +13,17 @@
from knext.builder import rest from knext.builder import rest
from knext.builder.rest.models.writer_graph_request import WriterGraphRequest from knext.builder.rest.models.writer_graph_request import WriterGraphRequest
from knext.common.base.client import Client from knext.common.base.client import Client
from knext.common.rest import ApiClient, Configuration
class BuilderClient(Client): class BuilderClient(Client):
""" """ """ """
_rest_client = rest.BuilderApi()
def __init__(self, host_addr: str = None, project_id: int = None): def __init__(self, host_addr: str = None, project_id: int = None):
super().__init__(host_addr, project_id) super().__init__(host_addr, project_id)
self._rest_client: rest.BuilderApi = rest.BuilderApi(
api_client=ApiClient(configuration=Configuration(host=host_addr))
)
def write_graph(self, sub_graph: dict, operation: str, lead_to_builder: bool): def write_graph(self, sub_graph: dict, operation: str, lead_to_builder: bool):
request = WriterGraphRequest( request = WriterGraphRequest(

View File

@ -79,7 +79,7 @@ class Configuration(object):
from knext.common import env from knext.common import env
self.host = ( self.host = (
os.environ.get("KAG_PROJECT_HOST_ADDR") or host or env.LOCAL_SCHEMA_URL host or os.environ.get("KAG_PROJECT_HOST_ADDR") or env.LOCAL_SCHEMA_URL
) )
"""Default Base url """Default Base url
""" """

View File

@ -12,6 +12,7 @@
from typing import List, Dict from typing import List, Dict
from knext.common.base.client import Client from knext.common.base.client import Client
from knext.common.rest import ApiClient, Configuration
from knext.graph_algo import ( from knext.graph_algo import (
GetPageRankScoresRequest, GetPageRankScoresRequest,
GetPageRankScoresRequestStartNodes, GetPageRankScoresRequestStartNodes,
@ -23,10 +24,11 @@ from knext.graph_algo import rest
class GraphAlgoClient(Client): class GraphAlgoClient(Client):
""" """ """ """
_rest_client = rest.GraphApi()
def __init__(self, host_addr: str = None, project_id: int = None): def __init__(self, host_addr: str = None, project_id: int = None):
super().__init__(host_addr, project_id) super().__init__(host_addr, project_id)
self._rest_client: rest.GraphApi = rest.GraphApi(
api_client=ApiClient(configuration=Configuration(host=host_addr))
)
def calculate_pagerank_scores(self, target_vertex_type, start_nodes: List[Dict]): def calculate_pagerank_scores(self, target_vertex_type, start_nodes: List[Dict]):
""" """

View File

@ -12,16 +12,18 @@
import json import json
from knext.common.base.client import Client from knext.common.base.client import Client
from knext.common.rest import Configuration, ApiClient
from knext.project import rest from knext.project import rest
class ProjectClient(Client): class ProjectClient(Client):
""" """ """ """
_rest_client = rest.ProjectApi()
def __init__(self, host_addr: str = None, project_id: int = None): def __init__(self, host_addr: str = None, project_id: int = None):
super().__init__(host_addr, project_id) super().__init__(host_addr, project_id)
self._rest_client: rest.ProjectApi = rest.ProjectApi(
api_client=ApiClient(configuration=Configuration(host=host_addr))
)
def get_config(self, project_id: str): def get_config(self, project_id: str):
project = self.get(id=int(project_id)) project = self.get(id=int(project_id))
@ -36,7 +38,7 @@ class ProjectClient(Client):
for project in projects: for project in projects:
condition = True condition = True
for k, v in conditions.items(): for k, v in conditions.items():
condition = condition and getattr(project, k) == v condition = condition and str(getattr(project, k)) == str(v)
if condition: if condition:
return project return project
return None return None

View File

@ -13,6 +13,7 @@ import os
import knext.common.cache import knext.common.cache
from knext.common.base.client import Client from knext.common.base.client import Client
from knext.common.rest import ApiClient, Configuration
from knext.project.client import ProjectClient from knext.project.client import ProjectClient
from knext.reasoner import ReasonTask from knext.reasoner import ReasonTask
from knext.reasoner import rest from knext.reasoner import rest
@ -26,10 +27,11 @@ reason_cache = knext.common.cache.SchemaCache()
class ReasonerClient(Client): class ReasonerClient(Client):
"""SPG Reasoner Client.""" """SPG Reasoner Client."""
_rest_client = rest.ReasonerApi()
def __init__(self, host_addr: str = None, project_id: int = None, namespace=None): def __init__(self, host_addr: str = None, project_id: int = None, namespace=None):
super().__init__(host_addr, project_id) super().__init__(host_addr, project_id)
self._rest_client: rest.ReasonerApi = rest.ReasonerApi(
api_client=ApiClient(configuration=Configuration(host=host_addr))
)
self._namespace = namespace or os.environ.get("KAG_PROJECT_NAMESPACE") self._namespace = namespace or os.environ.get("KAG_PROJECT_NAMESPACE")
self._session = None self._session = None
# load schema cache # load schema cache

View File

@ -10,6 +10,7 @@
# or implied. # or implied.
from knext.common.base.client import Client from knext.common.base.client import Client
from knext.common.rest import Configuration, ApiClient
from knext.search import rest, TextSearchRequest, VectorSearchRequest, IdxRecord from knext.search import rest, TextSearchRequest, VectorSearchRequest, IdxRecord
@ -21,10 +22,11 @@ def idx_record_to_dict(record: IdxRecord):
class SearchClient(Client): class SearchClient(Client):
""" """ """ """
_rest_client = rest.SearchApi()
def __init__(self, host_addr: str = None, project_id: int = None): def __init__(self, host_addr: str = None, project_id: int = None):
super().__init__(host_addr, project_id) super().__init__(host_addr, project_id)
self._rest_client: rest.SearchApi = rest.SearchApi(
api_client=ApiClient(configuration=Configuration(host=host_addr))
)
def search_text(self, query_string, label_constraints=None, topk=10): def search_text(self, query_string, label_constraints=None, topk=10):
req = TextSearchRequest(self._project_id, query_string, label_constraints, topk) req = TextSearchRequest(self._project_id, query_string, label_constraints, topk)