chore: remove BaseKnowledgeGraph (#4953)

* remove BaseKnowledgeGraph

* fix pylint
This commit is contained in:
Massimiliano Pippi 2023-05-21 10:42:02 +02:00 committed by GitHub
parent 5321d91f97
commit c6ea542b57
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 2 additions and 738 deletions

View File

@ -1,107 +0,0 @@
from pathlib import Path
from haystack.nodes import Text2SparqlRetriever
from haystack.document_stores import GraphDBKnowledgeGraph, InMemoryKnowledgeGraph
from haystack.utils import fetch_archive_from_http
def test_graph_retrieval():
    """
    End-to-end check of Text2SparqlRetriever on top of a GraphDBKnowledgeGraph.

    Downloads the tutorial triples/config and the pre-trained BART model, recreates the
    ``tutorial_10_index`` index, then verifies the stored triples and a few retrieved answers.
    """
    # we use a timeout double the default in the CI to account for slow runners
    timeout = 20
    index_name = "tutorial_10_index"

    # TODO rename doc_dir
    graph_dir = "../data/tutorial10_knowledge_graph/"
    fetch_archive_from_http(
        url="https://fandom-qa.s3-eu-west-1.amazonaws.com/triples_and_config.zip", output_dir=graph_dir
    )

    # Fetch a pre-trained BART model that translates natural language questions to SPARQL queries
    model_dir = "../saved_models/tutorial10_knowledge_graph/"
    fetch_archive_from_http(
        url="https://fandom-qa.s3-eu-west-1.amazonaws.com/saved_models/hp_v3.4.zip", output_dir=model_dir
    )

    # Start from a clean index and load the triples into it
    kg = GraphDBKnowledgeGraph(index=index_name)
    kg.delete_index(timeout=timeout)
    kg.create_index(config_path=Path(graph_dir + "repo-config.ttl"), timeout=timeout)
    kg.import_from_ttl_file(index=index_name, path=Path(graph_dir + "triples.ttl"), timeout=timeout)

    expected_triple = {
        "s": {"type": "uri", "value": "https://deepset.ai/harry_potter/Melody_fawley"},
        "p": {"type": "uri", "value": "https://deepset.ai/harry_potter/_paternalgrandfather"},
        "o": {"type": "uri", "value": "https://deepset.ai/harry_potter/Marshall_fawley"},
    }
    stored_triples = kg.get_all_triples()
    assert len(stored_triples) > 0
    assert expected_triple in stored_triples

    # Define prefixes for names of resources so that we can use shorter resource names in queries
    kg.prefixes = """PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
PREFIX xsd: <http://www.w3.org/2001/XMLSchema#>
PREFIX hp: <https://deepset.ai/harry_potter/>
"""

    retriever = Text2SparqlRetriever(knowledge_graph=kg, model_name_or_path=model_dir + "hp_v3.4")

    # Natural language question -> generated SPARQL -> answer
    predictions = retriever.retrieve(query="In which house is Harry Potter?")
    assert predictions[0] == {
        "answer": ["https://deepset.ai/harry_potter/Gryffindor"],
        "prediction_meta": {
            "model": "Text2SparqlRetriever",
            "sparql_query": "select ?a { hp:Harry_potter hp:house ?a . }",
        },
    }

    # Hand-written SPARQL using the hp: prefix; _query_kg returns (answers, executed query)
    answers, _ = retriever._query_kg(
        sparql_query="select distinct ?sbj where { ?sbj hp:job hp:Keeper_of_keys_and_grounds . }"
    )
    assert answers[0] == "https://deepset.ai/harry_potter/Rubeus_hagrid"

    # Hand-written SPARQL using full URIs
    answers, _ = retriever._query_kg(
        sparql_query="select distinct ?obj where { <https://deepset.ai/harry_potter/Hermione_granger> <https://deepset.ai/harry_potter/patronus> ?obj . }"
    )
    assert answers[0] == "https://deepset.ai/harry_potter/Otter"
def test_inmemory_graph_retrieval():
    """
    End-to-end check of Text2SparqlRetriever on top of an InMemoryKnowledgeGraph.

    Downloads the tutorial triples and the pre-trained BART model, loads the triples into the
    ``tutorial_10_index`` in-memory index, then verifies the stored triples and a few answers.
    """
    index_name = "tutorial_10_index"

    # TODO rename doc_dir
    graph_dir = "../data/tutorial10_knowledge_graph/"
    fetch_archive_from_http(
        url="https://fandom-qa.s3-eu-west-1.amazonaws.com/triples_and_config.zip", output_dir=graph_dir
    )

    # Fetch a pre-trained BART model that translates natural language questions to SPARQL queries
    model_dir = "../saved_models/tutorial10_knowledge_graph/"
    fetch_archive_from_http(
        url="https://fandom-qa.s3-eu-west-1.amazonaws.com/saved_models/hp_v3.4.zip", output_dir=model_dir
    )

    # Start from a clean index and load the triples into it
    kg = InMemoryKnowledgeGraph(index=index_name)
    kg.delete_index()
    kg.create_index()
    kg.import_from_ttl_file(index=index_name, path=Path(graph_dir + "triples.ttl"))

    expected_triple = {
        "s": {"type": "uri", "value": "https://deepset.ai/harry_potter/Melody_fawley"},
        "p": {"type": "uri", "value": "https://deepset.ai/harry_potter/_paternalgrandfather"},
        "o": {"type": "uri", "value": "https://deepset.ai/harry_potter/Marshall_fawley"},
    }
    stored_triples = kg.get_all_triples()
    assert len(stored_triples) > 0
    assert expected_triple in stored_triples

    retriever = Text2SparqlRetriever(knowledge_graph=kg, model_name_or_path=model_dir + "hp_v3.4")

    # Natural language question -> generated SPARQL -> answer
    predictions = retriever.retrieve(query="In which house is Harry Potter?")
    assert predictions[0] == {
        "answer": ["https://deepset.ai/harry_potter/Gryffindor"],
        "prediction_meta": {
            "model": "Text2SparqlRetriever",
            "sparql_query": "select ?a { hp:Harry_potter hp:house ?a . }",
        },
    }

    # Hand-written SPARQL using the hp: prefix; _query_kg returns (answers, executed query)
    answers, _ = retriever._query_kg(
        sparql_query="select distinct ?sbj where { ?sbj hp:job hp:Keeper_of_keys_and_grounds . }"
    )
    assert answers[0] == "https://deepset.ai/harry_potter/Rubeus_hagrid"

    # Hand-written SPARQL using full URIs
    answers, _ = retriever._query_kg(
        sparql_query="select distinct ?obj where { <https://deepset.ai/harry_potter/Hermione_granger> <https://deepset.ai/harry_potter/patronus> ?obj . }"
    )
    assert answers[0] == "https://deepset.ai/harry_potter/Otter"

View File

@ -1,5 +1,5 @@
from haystack.utils.import_utils import safe_import
from haystack.document_stores.base import BaseDocumentStore, BaseKnowledgeGraph, KeywordDocumentStore
from haystack.document_stores.base import BaseDocumentStore, KeywordDocumentStore
from haystack.document_stores.memory import InMemoryDocumentStore
from haystack.document_stores.deepsetcloud import DeepsetCloudDocumentStore
@ -19,7 +19,3 @@ SQLDocumentStore = safe_import("haystack.document_stores.sql", "SQLDocumentStore
FAISSDocumentStore = safe_import("haystack.document_stores.faiss", "FAISSDocumentStore", "faiss")
PineconeDocumentStore = safe_import("haystack.document_stores.pinecone", "PineconeDocumentStore", "pinecone")
WeaviateDocumentStore = safe_import("haystack.document_stores.weaviate", "WeaviateDocumentStore", "weaviate")
GraphDBKnowledgeGraph = safe_import("haystack.document_stores.graphdb", "GraphDBKnowledgeGraph", "graphdb")
InMemoryKnowledgeGraph = safe_import(
"haystack.document_stores.memory_knowledgegraph", "InMemoryKnowledgeGraph", "inmemorygraph"
)

View File

@ -2,7 +2,6 @@
from typing import Generator, Optional, Dict, List, Set, Union, Any
import warnings
import logging
import collections
from pathlib import Path
@ -32,35 +31,6 @@ except (ImportError, ModuleNotFoundError):
return f
class BaseKnowledgeGraph(BaseComponent):
    """
    Common parent of all Knowledge Graph components.

    Subclasses implement `query()`; `run()` wraps it so the graph can be used as a pipeline node.
    The BaseKnowledgeGraph component is deprecated and will be removed in future versions.
    """

    # Every result of `run()` is emitted on the single edge "output_1"
    outgoing_edges = 1

    def __init__(self):
        # Emit the deprecation notice before delegating to the parent initializer
        warnings.warn(
            "The BaseKnowledgeGraph component is deprecated and will be removed in future versions.",
            category=DeprecationWarning,
        )
        super().__init__()

    def run(self, sparql_query: str, index: Optional[str] = None, headers: Optional[Dict[str, str]] = None):  # type: ignore
        """
        Execute `sparql_query` through `query()` and package the outcome for a pipeline.

        :param sparql_query: SPARQL query to execute
        :param index: name of the index to query, forwarded unchanged to `query()`
        :param headers: custom HTTP headers, forwarded unchanged to `query()`
        :return: tuple of ({"sparql_result": <query result>}, "output_1")
        """
        sparql_result = self.query(sparql_query=sparql_query, index=index, headers=headers)
        return {"sparql_result": sparql_result}, "output_1"

    def run_batch(self):
        """
        Batch execution is not available for knowledge graphs.

        :raises NotImplementedError: always
        """
        raise NotImplementedError("run_batch is not implemented for KnowledgeGraphs.")

    @abstractmethod
    def query(self, sparql_query: str, index: Optional[str] = None, headers: Optional[Dict[str, str]] = None):
        """
        Execute a SPARQL query; concrete knowledge graphs must override this method.

        :param sparql_query: SPARQL query to execute
        :param index: name of the index to query
        :param headers: custom HTTP headers
        :raises NotImplementedError: when not overridden
        """
        raise NotImplementedError
class BaseDocumentStore(BaseComponent):
"""
Base class for implementing Document Stores.

View File

@ -1,206 +0,0 @@
from typing import Dict, Optional, Union, Tuple
import warnings
from pathlib import Path
import requests
from requests.auth import HTTPBasicAuth
try:
from SPARQLWrapper import SPARQLWrapper, JSON
except (ImportError, ModuleNotFoundError) as ie:
from haystack.utils.import_utils import _optional_component_not_installed
_optional_component_not_installed(__name__, "graphdb", ie)
from haystack.document_stores import BaseKnowledgeGraph
class GraphDBKnowledgeGraph(BaseKnowledgeGraph):
    """
    Knowledge graph store that runs on a GraphDB instance.

    The GraphDBKnowledgeGraph component is deprecated and will be removed in future versions.
    """

    def __init__(
        self,
        host: str = "localhost",
        port: int = 7200,
        username: str = "",
        password: str = "",
        index: Optional[str] = None,
        prefixes: str = "",
    ):
        """
        The GraphDBKnowledgeGraph component is deprecated and will be removed in future versions.

        Init the knowledge graph by defining the settings to connect with a GraphDB instance

        :param host: address of server where the GraphDB instance is running
        :param port: port where the GraphDB instance is running
        :param username: username to login to the GraphDB instance (if any)
        :param password: password to login to the GraphDB instance (if any)
        :param index: name of the index (also called repository) stored in the GraphDB instance
        :param prefixes: definitions of namespaces with a new line after each namespace, e.g., PREFIX hp: <https://deepset.ai/harry_potter/>
        """
        warnings.warn(
            "The GraphDBKnowledgeGraph component is deprecated and will be removed in future versions.",
            category=DeprecationWarning,
        )
        super().__init__()

        self.url = f"http://{host}:{port}"
        self.index = index
        self.username = username
        self.password = password
        self.prefixes = prefixes

    def create_index(
        self,
        config_path: Path,
        headers: Optional[Dict[str, str]] = None,
        timeout: Union[float, Tuple[float, float]] = 10.0,
    ) -> None:
        """
        Create a new index (also called repository) stored in the GraphDB instance

        :param config_path: path to a .ttl file with configuration settings, details:
            https://graphdb.ontotext.com/documentation/free/configuring-a-repository.html#configure-a-repository-programmatically
        :param headers: Custom HTTP headers to pass to http client (e.g. {'Authorization': 'Basic YWRtaW46cm9vdA=='})
        :param timeout: How many seconds to wait for the server to send data before giving up,
            as a float, or a :ref:`(connect timeout, read timeout) <timeouts>` tuple.
            Defaults to 10 seconds.
        :raises Exception: if the server answers with a status code above 299; the message is the response body
        """
        url = f"{self.url}/rest/repositories"
        # NOTE(review): unlike import_from_ttl_file, no basic auth is sent here - confirm this is intended.
        # The context manager guarantees the config file handle is closed once the upload is done.
        with open(config_path, "r", encoding="utf-8") as config_file:
            response = requests.post(url, files={"config": config_file}, headers=headers, timeout=timeout)
        if response.status_code > 299:
            raise Exception(response.text)

    def delete_index(
        self, headers: Optional[Dict[str, str]] = None, timeout: Union[float, Tuple[float, float]] = 10.0
    ) -> None:
        """
        Delete the index that GraphDBKnowledgeGraph is connected to. This method deletes all data stored in the index.

        :param headers: Custom HTTP headers to pass to http client (e.g. {'Authorization': 'Basic YWRtaW46cm9vdA=='})
        :param timeout: How many seconds to wait for the server to send data before giving up,
            as a float, or a :ref:`(connect timeout, read timeout) <timeouts>` tuple.
            Defaults to 10 seconds.
        :raises Exception: if the server answers with a status code above 299; the message is the response body
        """
        url = f"{self.url}/rest/repositories/{self.index}"
        response = requests.delete(url, headers=headers, timeout=timeout)
        if response.status_code > 299:
            raise Exception(response.text)

    def import_from_ttl_file(
        self,
        index: str,
        path: Path,
        headers: Optional[Dict[str, str]] = None,
        timeout: Union[float, Tuple[float, float]] = 10.0,
    ) -> None:
        """
        Load an existing knowledge graph represented in the form of triples of subject, predicate, and object from a .ttl file into an index of GraphDB

        :param index: name of the index (also called repository) in the GraphDB instance where the imported triples shall be stored
        :param path: path to a .ttl containing a knowledge graph
        :param headers: Custom HTTP headers to pass to http client (e.g. {'Authorization': 'Basic YWRtaW46cm9vdA=='})
        :param timeout: How many seconds to wait for the server to send data before giving up,
            as a float, or a :ref:`(connect timeout, read timeout) <timeouts>` tuple.
            Defaults to 10 seconds.
        :raises Exception: if the server answers with a status code above 299; the message is the response body
        """
        url = f"{self.url}/repositories/{index}/statements"
        # Caller-supplied headers are applied last, so they may override the default content type
        headers = {"Content-type": "application/x-turtle", **(headers or {})}
        # Read the whole file up front so the handle is closed before the request is sent
        with open(path, "r", encoding="utf-8") as ttl_file:
            payload = ttl_file.read().encode("utf-8")
        response = requests.post(
            url,
            headers=headers,
            data=payload,
            auth=HTTPBasicAuth(self.username, self.password),
            timeout=timeout,
        )
        if response.status_code > 299:
            raise Exception(response.text)

    def get_all_triples(self, index: Optional[str] = None, headers: Optional[Dict[str, str]] = None):
        """
        Query the given index in the GraphDB instance for all its stored triples. Duplicates are not filtered.

        :param index: name of the index (also called repository) in the GraphDB instance
        :param headers: Custom HTTP headers to pass to http client (e.g. {'Authorization': 'Basic YWRtaW46cm9vdA=='})
        :return: all triples stored in the index
        """
        return self.query(sparql_query="SELECT * WHERE { ?s ?p ?o. }", index=index, headers=headers)

    def get_all_subjects(self, index: Optional[str] = None, headers: Optional[Dict[str, str]] = None):
        """
        Query the given index in the GraphDB instance for all its stored subjects. Duplicates are not filtered.

        :param index: name of the index (also called repository) in the GraphDB instance
        :param headers: Custom HTTP headers to pass to http client (e.g. {'Authorization': 'Basic YWRtaW46cm9vdA=='})
        :return: all subjects stored in the index
        """
        return self.query(sparql_query="SELECT ?s WHERE { ?s ?p ?o. }", index=index, headers=headers)

    def get_all_predicates(self, index: Optional[str] = None, headers: Optional[Dict[str, str]] = None):
        """
        Query the given index in the GraphDB instance for all its stored predicates. Duplicates are not filtered.

        :param index: name of the index (also called repository) in the GraphDB instance
        :param headers: Custom HTTP headers to pass to http client (e.g. {'Authorization': 'Basic YWRtaW46cm9vdA=='})
        :return: all predicates stored in the index
        """
        return self.query(sparql_query="SELECT ?p WHERE { ?s ?p ?o. }", index=index, headers=headers)

    def _create_document_field_map(self) -> Dict:
        """
        There is no field mapping required
        """
        return {}

    def get_all_objects(self, index: Optional[str] = None, headers: Optional[Dict[str, str]] = None):
        """
        Query the given index in the GraphDB instance for all its stored objects. Duplicates are not filtered.

        :param index: name of the index (also called repository) in the GraphDB instance
        :param headers: Custom HTTP headers to pass to http client (e.g. {'Authorization': 'Basic YWRtaW46cm9vdA=='})
        :return: all objects stored in the index
        """
        return self.query(sparql_query="SELECT ?o WHERE { ?s ?p ?o. }", index=index, headers=headers)

    def query(self, sparql_query: str, index: Optional[str] = None, headers: Optional[Dict[str, str]] = None):
        """
        Execute a SPARQL query on the given index in the GraphDB instance

        The configured `prefixes` are prepended to `sparql_query` before it is sent.

        :param sparql_query: SPARQL query that shall be executed
        :param index: name of the index (also called repository) in the GraphDB instance;
            falls back to the index given at construction time
        :param headers: Custom HTTP headers to pass to http client (e.g. {'Authorization': 'Basic YWRtaW46cm9vdA=='})
        :return: the "bindings" of the result if the response contains "results",
            otherwise the value stored under "boolean"
        :raises Exception: if neither `index` nor the instance's index is set
        """
        if self.index is None and index is None:
            raise Exception("Index name is required")
        index = index or self.index

        sparql = SPARQLWrapper(f"{self.url}/repositories/{index}")
        sparql.setCredentials(self.username, self.password)
        sparql.setQuery(self.prefixes + sparql_query)
        sparql.setReturnFormat(JSON)
        if headers is not None:
            sparql.customHttpHeaders = headers
        results = sparql.query().convert()

        # if query is a boolean query, return boolean instead of text result
        # FIXME: 'results' likely doesn't support membership test (`"something" in results`).
        # Pylint raises unsupported-membership-test and unsubscriptable-object.
        # Silenced for now, keep in mind for future debugging.
        return (
            results["results"]["bindings"]  # type: ignore # pylint: disable=unsubscriptable-object
            if "results" in results  # type: ignore # pylint: disable=unsupported-membership-test
            else results["boolean"]  # type: ignore # pylint: disable=unsubscriptable-object
        )

View File

@ -1,144 +0,0 @@
from typing import Dict, Optional
import warnings
import logging
from collections import defaultdict
from pathlib import Path
from rdflib import Graph
from haystack.document_stores import BaseKnowledgeGraph
logger = logging.getLogger(__name__)
class InMemoryKnowledgeGraph(BaseKnowledgeGraph):
    """
    In memory Knowledge graph store, based on rdflib.

    The InMemoryKnowledgeGraph component is deprecated and will be removed in future versions.
    """

    def __init__(self, index: str = "document"):
        """
        The InMemoryKnowledgeGraph component is deprecated and will be removed in future versions.

        Init the in memory knowledge graph

        :param index: name of the index
        """
        warnings.warn(
            "The InMemoryKnowledgeGraph component is deprecated and will be removed in future versions.",
            category=DeprecationWarning,
        )
        super().__init__()

        # The default factory must build a Graph (not a dict): import_from_ttl_file() and query()
        # call Graph methods on whatever an index lookup returns, including for a missing index.
        self.indexes: Dict[str, Graph] = defaultdict(Graph)
        self.index: str = index

    def create_index(self, index: Optional[str] = None):
        """
        Create a new index stored in memory

        If the index already exists it is left untouched and a warning is logged.

        :param index: name of the index; defaults to the index given at construction time
        """
        index = index or self.index
        # A membership test does not trigger the defaultdict factory
        if index not in self.indexes:
            self.indexes[index] = Graph()
        else:
            logger.warning("Index '%s' is already present.", index)

    def delete_index(self, index: Optional[str] = None):
        """
        Delete an existing index. The index including all data will be removed.

        Nothing happens if the index does not exist.

        :param index: The name of the index to delete; defaults to the index given at construction time
        """
        index = index or self.index
        if index in self.indexes:
            del self.indexes[index]
            logger.info("Index '%s' deleted.", index)

    def import_from_ttl_file(self, path: Path, index: Optional[str] = None):
        """
        Load in memory an existing knowledge graph represented in the form of triples of subject, predicate, and object from a .ttl file

        :param path: path to a .ttl containing a knowledge graph
        :param index: name of the index; defaults to the index given at construction time
        """
        index = index or self.index
        self.indexes[index].parse(path)

    def get_all_triples(self, index: Optional[str] = None):
        """
        Query the given in memory index for all its stored triples. Duplicates are not filtered.

        :param index: name of the index
        :return: all triples stored in the index
        """
        return self.query(sparql_query="SELECT * WHERE { ?s ?p ?o. }", index=index)

    def get_all_subjects(self, index: Optional[str] = None):
        """
        Query the given in memory index for all its stored subjects. Duplicates are not filtered.

        :param index: name of the index
        :return: all subjects stored in the index
        """
        return self.query(sparql_query="SELECT ?s WHERE { ?s ?p ?o. }", index=index)

    def get_all_predicates(self, index: Optional[str] = None):
        """
        Query the given in memory index for all its stored predicates. Duplicates are not filtered.

        :param index: name of the index
        :return: all predicates stored in the index
        """
        return self.query(sparql_query="SELECT ?p WHERE { ?s ?p ?o. }", index=index)

    def _create_document_field_map(self) -> Dict:
        """
        There is no field mapping required
        """
        return {}

    def get_all_objects(self, index: Optional[str] = None):
        """
        Query the given in memory index for all its stored objects. Duplicates are not filtered.

        :param index: name of the index
        :return: all objects stored in the index
        """
        return self.query(sparql_query="SELECT ?o WHERE { ?s ?p ?o. }", index=index)

    def query(self, sparql_query: str, index: Optional[str] = None, headers: Optional[Dict[str, str]] = None):
        """
        Execute a SPARQL query on the given in memory index

        :param sparql_query: SPARQL query that shall be executed
        :param index: name of the index; defaults to the index given at construction time
        :param headers: not used by this implementation
        :return: the `askAnswer` of the result when it is not None, otherwise a list with one dict
            per result binding, mapping each variable name to {"type": "uri", "value": <python value>}
        """
        index = index or self.index
        raw_results = self.indexes[index].query(sparql_query)

        if raw_results.askAnswer is not None:
            return raw_results.askAnswer

        formatted_results = []
        for binding in raw_results.bindings:
            formatted_result = {}
            for variable, value in binding.items():
                # Drop the first character of the variable's Python form (presumably the leading "?")
                name = variable.toPython()[1:]
                # NOTE(review): "type" is always reported as "uri", whatever kind of term `value` is
                formatted_result[name] = {"type": "uri", "value": value.toPython()}  # type: ignore [attr-defined]
            formatted_results.append(formatted_result)
        return formatted_results

View File

@ -40,7 +40,6 @@ from haystack.nodes.retriever import (
FilterRetriever,
MultihopEmbeddingRetriever,
TfidfRetriever,
Text2SparqlRetriever,
TableTextRetriever,
MultiModalRetriever,
WebRetriever,

View File

@ -8,5 +8,4 @@ from haystack.nodes.retriever.dense import (
)
from haystack.nodes.retriever.multimodal import MultiModalRetriever
from haystack.nodes.retriever.sparse import BM25Retriever, FilterRetriever, TfidfRetriever
from haystack.nodes.retriever.text2sparql import Text2SparqlRetriever
from haystack.nodes.retriever.web import WebRetriever

View File

@ -11,42 +11,12 @@ from haystack.schema import Document, MultiLabel
from haystack.errors import HaystackError, PipelineError
from haystack.nodes.base import BaseComponent
from haystack.telemetry import send_event
from haystack.document_stores.base import BaseDocumentStore, BaseKnowledgeGraph, FilterType
from haystack.document_stores.base import BaseDocumentStore, FilterType
logger = logging.getLogger(__name__)
class BaseGraphRetriever(BaseComponent):
    """
    Base class for knowledge graph retrievers.

    Subclasses implement `retrieve()` and `retrieve_batch()`; `run()` and `run_batch()` wrap them
    so that the retriever can be used as a pipeline node.
    """

    # The knowledge graph the retriever operates on
    knowledge_graph: BaseKnowledgeGraph
    # Results are always emitted on the single edge "output_1"
    outgoing_edges = 1

    @abstractmethod
    def retrieve(self, query: str, top_k: Optional[int] = None):
        """
        Retrieve answers for a single query. To be implemented by subclasses.

        :param query: the query to answer
        :param top_k: optional limit, interpreted by the subclass
        """

    @abstractmethod
    def retrieve_batch(self, queries: List[str], top_k: Optional[int] = None):
        """
        Retrieve answers for several queries. To be implemented by subclasses.

        :param queries: the queries to answer
        :param top_k: optional limit, interpreted by the subclass
        """

    def eval(self):
        """
        Evaluation is not available on the base class.

        :raises NotImplementedError: always
        """
        raise NotImplementedError

    def run(self, query: str, top_k: Optional[int] = None):  # type: ignore
        """
        Pipeline entry point for a single query.

        :return: tuple of ({"answers": <result of retrieve()>}, "output_1")
        """
        return {"answers": self.retrieve(query=query, top_k=top_k)}, "output_1"

    def run_batch(self, queries: List[str], top_k: Optional[int] = None):  # type: ignore
        """
        Pipeline entry point for a list of queries.

        :return: tuple of ({"answers": <result of retrieve_batch()>}, "output_1")
        """
        return {"answers": self.retrieve_batch(queries=queries, top_k=top_k)}, "output_1"
class BaseRetriever(BaseComponent):
"""
Base class for regular retrievers.

View File

@ -1,150 +0,0 @@
from typing import Optional, List, Union
import warnings
import logging
from transformers import BartForConditionalGeneration, BartTokenizer
from haystack.document_stores import BaseKnowledgeGraph
from haystack.nodes.retriever.base import BaseGraphRetriever
logger = logging.getLogger(__name__)
class Text2SparqlRetriever(BaseGraphRetriever):
    """
    Graph retriever that uses a pre-trained Bart model to translate natural language questions
    given in text form to queries in SPARQL format.
    The generated SPARQL query is executed on a knowledge graph.

    The Text2SparqlRetriever component is deprecated and will be removed in future versions.
    """

    def __init__(
        self,
        knowledge_graph: BaseKnowledgeGraph,
        model_name_or_path: Optional[str] = None,
        model_version: Optional[str] = None,
        top_k: int = 1,
        use_auth_token: Optional[Union[str, bool]] = None,
    ):
        """
        The Text2SparqlRetriever component is deprecated and will be removed in future versions.

        Init the Retriever by providing a knowledge graph and a pre-trained BART model

        :param knowledge_graph: An instance of BaseKnowledgeGraph on which to execute SPARQL queries.
        :param model_name_or_path: Name of or path to a pre-trained BartForConditionalGeneration model.
        :param model_version: The version of the model to use for entity extraction.
        :param top_k: How many SPARQL queries to generate per text query.
        :param use_auth_token: The API token used to download private models from Huggingface.
                               If this parameter is set to `True`, then the token generated when running
                               `transformers-cli login` (stored in ~/.huggingface) will be used.
                               Additional information can be found here
                               https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained
        """
        warnings.warn(
            "The Text2SparqlRetriever component is deprecated and will be removed in future versions.",
            category=DeprecationWarning,
        )
        super().__init__()

        self.knowledge_graph = knowledge_graph
        # TODO We should extend this to any seq2seq models and use the AutoModel class
        self.model = BartForConditionalGeneration.from_pretrained(
            model_name_or_path, forced_bos_token_id=0, use_auth_token=use_auth_token, revision=model_version
        )
        self.tok = BartTokenizer.from_pretrained(model_name_or_path, use_auth_token=use_auth_token)
        self.top_k = top_k

    def retrieve(self, query: str, top_k: Optional[int] = None):
        """
        Translate a text query to SPARQL and execute it on the knowledge graph to retrieve a list of answers

        :param query: Text query that shall be translated to SPARQL and then executed on the knowledge graph
        :param top_k: How many SPARQL queries to generate per text query. Defaults to the value given
            at construction time.
        :return: list of at most `top_k` dicts as produced by `format_result()`; if no generated
            query yields an answer, a single entry with empty answer and empty query is returned
        """
        if top_k is None:
            top_k = self.top_k

        inputs = self.tok([query], max_length=100, truncation=True, return_tensors="pt")
        # generate top_k+2 SPARQL queries so that we can dismiss some queries with wrong syntax
        num_candidates = top_k + 2
        generated = self.model.generate(
            inputs["input_ids"],
            # Beam search cannot return more sequences than it has beams, so widen the beam
            # beyond the default of 5 when top_k requires it.
            num_beams=max(5, num_candidates),
            max_length=100,
            num_return_sequences=num_candidates,
            early_stopping=True,
        )
        sparql_queries = [
            self.tok.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in generated
        ]

        # Keep only the candidates that produced a non-empty answer
        answers = []
        for sparql_query in sparql_queries:
            ans, executed_query = self._query_kg(sparql_query=sparql_query)
            if len(ans) > 0:
                answers.append((ans, executed_query))

        # if there are no answers we still want to return something
        if len(answers) == 0:
            answers.append(("", ""))

        return [self.format_result(result) for result in answers[:top_k]]

    def retrieve_batch(self, queries: List[str], top_k: Optional[int] = None):
        """
        Translate a list of queries to SPARQL and execute it on the knowledge graph to retrieve
        one result per query.

        :param queries: List of queries that shall be translated to SPARQL and then executed on the
                        knowledge graph.
        :param top_k: How many SPARQL queries to generate per text query.
        :return: list with one entry per query, each being whatever `run()` returns for that query
        """
        # TODO: This method currently just processes the queries one by one, so there is room for improvement.
        # NOTE(review): this calls run() rather than retrieve(), so every entry is run()'s full
        # return value and not a plain list of answers - confirm whether callers rely on that shape.
        return [self.run(query=query, top_k=top_k) for query in queries]

    def _query_kg(self, sparql_query):
        """
        Execute a single SPARQL query on the knowledge graph to retrieve an answer and unpack
        different answer styles for boolean queries, count queries, and list queries.

        Any error while querying or unpacking is treated as "no answer": generated queries are
        expected to be invalid from time to time.

        :param sparql_query: SPARQL query that shall be executed on the knowledge graph
        :return: tuple (result, sparql_query) where result is a list of values, a string, or ""
            when there is no answer
        """
        try:
            response = self.knowledge_graph.query(sparql_query=sparql_query)

            # unpack different answer styles
            if isinstance(response, list):
                if len(response) == 0:
                    result = ""
                else:
                    result = [v["value"] for x in response for v in x.values()]
            elif isinstance(response, bool):
                result = str(response)
            elif "count" in response[0]:
                result = str(int(response[0]["count"]["value"]))
            else:
                result = ""

        except Exception:
            # Best effort by design, but leave a trace so that failures can be diagnosed
            logger.debug("SPARQL query could not be executed: %s", sparql_query, exc_info=True)
            result = ""

        return result, sparql_query

    def format_result(self, result):
        """
        Generate formatted dictionary output with text answer and additional info

        :param result: The result of a SPARQL query as retrieved from the knowledge graph,
            as a pair (answer, sparql query)
        :return: dict with the answer under "answer" and, under "prediction_meta", the name of
            this class and the SPARQL query
        """
        prediction, query = result[0], result[1]
        prediction_meta = {"model": self.__class__.__name__, "sparql_query": query}
        return {"answer": prediction, "prediction_meta": prediction_meta}

    def eval(self):
        """
        Evaluation is not supported by this retriever.

        :raises NotImplementedError: always
        """
        raise NotImplementedError

View File

@ -1,22 +0,0 @@
import pytest
from haystack.document_stores.graphdb import GraphDBKnowledgeGraph
from ..conftest import fail_at_version
@pytest.mark.unit
@fail_at_version(1, 17)
def test_graphdb_knowledge_graph_deprecation_warning():
    """
    Instantiating GraphDBKnowledgeGraph must emit exactly two deprecation warnings:
    its own first, then the one of BaseKnowledgeGraph.
    """
    expected_messages = [
        "The GraphDBKnowledgeGraph component is deprecated and will be removed in future versions.",
        "The BaseKnowledgeGraph component is deprecated and will be removed in future versions.",
    ]

    with pytest.warns(DeprecationWarning) as w:
        GraphDBKnowledgeGraph()

    # Comparing the full list checks both the number of warnings and their order
    assert [record.message.args[0] for record in w] == expected_messages

View File

@ -1,22 +0,0 @@
import pytest
from haystack.document_stores.memory_knowledgegraph import InMemoryKnowledgeGraph
from ..conftest import fail_at_version
@pytest.mark.unit
@fail_at_version(1, 17)
def test_in_memory_knowledge_graph_deprecation_warning():
    """
    Instantiating InMemoryKnowledgeGraph must emit exactly two deprecation warnings:
    its own first, then the one of BaseKnowledgeGraph.
    """
    expected_messages = [
        "The InMemoryKnowledgeGraph component is deprecated and will be removed in future versions.",
        "The BaseKnowledgeGraph component is deprecated and will be removed in future versions.",
    ]

    with pytest.warns(DeprecationWarning) as w:
        InMemoryKnowledgeGraph()

    # Comparing the full list checks both the number of warnings and their order
    assert [record.message.args[0] for record in w] == expected_messages

View File

@ -27,7 +27,6 @@ from haystack.document_stores import WeaviateDocumentStore
from haystack.nodes.retriever.base import BaseRetriever
from haystack.nodes.retriever.web import WebRetriever
from haystack.nodes.search_engine import WebSearch
from haystack.nodes.retriever import Text2SparqlRetriever
from haystack.pipelines import DocumentSearchPipeline
from haystack.schema import Document
from haystack.document_stores.elasticsearch import ElasticsearchDocumentStore
@ -1271,21 +1270,3 @@ def test_web_retriever_mode_snippets(monkeypatch):
web_retriever = WebRetriever(api_key="", top_search_results=2)
result = web_retriever.retrieve(query="Who is the father of Arya Stark?")
assert result == expected_search_results["documents"]
@fail_at_version(1, 17)
def test_text_2_sparql_retriever_deprecation():
    """
    Instantiating Text2SparqlRetriever must emit exactly one deprecation warning.

    The BART model and tokenizer classes are patched in the text2sparql module so that
    no model is loaded, and the knowledge graph is a Mock.
    """
    with patch.multiple(
        "haystack.nodes.retriever.text2sparql", BartForConditionalGeneration=DEFAULT, BartTokenizer=DEFAULT
    ):
        knowledge_graph = Mock()
        with pytest.warns(DeprecationWarning) as w:
            Text2SparqlRetriever(knowledge_graph)

        assert len(w) == 1
        assert (
            w[0].message.args[0]
            == "The Text2SparqlRetriever component is deprecated and will be removed in future versions."
        )