chore: remove BaseKnowledgeGraph (#4953)

* remove BaseKnowledgeGraph

* fix pylint
Massimiliano Pippi authored on 2023-05-21 10:42:02 +02:00, committed by GitHub
parent 5321d91f97
commit c6ea542b57
12 changed files with 2 additions and 738 deletions

View File

@@ -1,107 +0,0 @@
from pathlib import Path

from haystack.nodes import Text2SparqlRetriever
from haystack.document_stores import GraphDBKnowledgeGraph, InMemoryKnowledgeGraph
from haystack.utils import fetch_archive_from_http


def test_graph_retrieval():
    # we use a timeout double the default in the CI to account for slow runners
    timeout = 20

    # TODO rename doc_dir
    graph_dir = "../data/tutorial10_knowledge_graph/"
    s3_url = "https://fandom-qa.s3-eu-west-1.amazonaws.com/triples_and_config.zip"
    fetch_archive_from_http(url=s3_url, output_dir=graph_dir)

    # Fetch a pre-trained BART model that translates natural language questions to SPARQL queries
    model_dir = "../saved_models/tutorial10_knowledge_graph/"
    s3_url = "https://fandom-qa.s3-eu-west-1.amazonaws.com/saved_models/hp_v3.4.zip"
    fetch_archive_from_http(url=s3_url, output_dir=model_dir)

    kg = GraphDBKnowledgeGraph(index="tutorial_10_index")
    kg.delete_index(timeout=timeout)
    kg.create_index(config_path=Path(graph_dir + "repo-config.ttl"), timeout=timeout)
    kg.import_from_ttl_file(index="tutorial_10_index", path=Path(graph_dir + "triples.ttl"), timeout=timeout)

    triple = {
        "p": {"type": "uri", "value": "https://deepset.ai/harry_potter/_paternalgrandfather"},
        "s": {"type": "uri", "value": "https://deepset.ai/harry_potter/Melody_fawley"},
        "o": {"type": "uri", "value": "https://deepset.ai/harry_potter/Marshall_fawley"},
    }
    triples = kg.get_all_triples()
    assert len(triples) > 0
    assert triple in triples

    # Define prefixes for names of resources so that we can use shorter resource names in queries
    prefixes = """PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
    PREFIX xsd: <http://www.w3.org/2001/XMLSchema#>
    PREFIX hp: <https://deepset.ai/harry_potter/>
    """
    kg.prefixes = prefixes

    kgqa_retriever = Text2SparqlRetriever(knowledge_graph=kg, model_name_or_path=model_dir + "hp_v3.4")

    result = kgqa_retriever.retrieve(query="In which house is Harry Potter?")
    assert result[0] == {
        "answer": ["https://deepset.ai/harry_potter/Gryffindor"],
        "prediction_meta": {
            "model": "Text2SparqlRetriever",
            "sparql_query": "select ?a { hp:Harry_potter hp:house ?a . }",
        },
    }

    result = kgqa_retriever._query_kg(
        sparql_query="select distinct ?sbj where { ?sbj hp:job hp:Keeper_of_keys_and_grounds . }"
    )
    assert result[0][0] == "https://deepset.ai/harry_potter/Rubeus_hagrid"

    result = kgqa_retriever._query_kg(
        sparql_query="select distinct ?obj where { <https://deepset.ai/harry_potter/Hermione_granger> <https://deepset.ai/harry_potter/patronus> ?obj . }"
    )
    assert result[0][0] == "https://deepset.ai/harry_potter/Otter"


def test_inmemory_graph_retrieval():
    # TODO rename doc_dir
    graph_dir = "../data/tutorial10_knowledge_graph/"
    s3_url = "https://fandom-qa.s3-eu-west-1.amazonaws.com/triples_and_config.zip"
    fetch_archive_from_http(url=s3_url, output_dir=graph_dir)

    # Fetch a pre-trained BART model that translates natural language questions to SPARQL queries
    model_dir = "../saved_models/tutorial10_knowledge_graph/"
    s3_url = "https://fandom-qa.s3-eu-west-1.amazonaws.com/saved_models/hp_v3.4.zip"
    fetch_archive_from_http(url=s3_url, output_dir=model_dir)

    kg = InMemoryKnowledgeGraph(index="tutorial_10_index")
    kg.delete_index()
    kg.create_index()
    kg.import_from_ttl_file(index="tutorial_10_index", path=Path(graph_dir + "triples.ttl"))

    triple = {
        "p": {"type": "uri", "value": "https://deepset.ai/harry_potter/_paternalgrandfather"},
        "s": {"type": "uri", "value": "https://deepset.ai/harry_potter/Melody_fawley"},
        "o": {"type": "uri", "value": "https://deepset.ai/harry_potter/Marshall_fawley"},
    }
    triples = kg.get_all_triples()
    assert len(triples) > 0
    assert triple in triples

    kgqa_retriever = Text2SparqlRetriever(knowledge_graph=kg, model_name_or_path=model_dir + "hp_v3.4")

    result = kgqa_retriever.retrieve(query="In which house is Harry Potter?")
    assert result[0] == {
        "answer": ["https://deepset.ai/harry_potter/Gryffindor"],
        "prediction_meta": {
            "model": "Text2SparqlRetriever",
            "sparql_query": "select ?a { hp:Harry_potter hp:house ?a . }",
        },
    }

    result = kgqa_retriever._query_kg(
        sparql_query="select distinct ?sbj where { ?sbj hp:job hp:Keeper_of_keys_and_grounds . }"
    )
    assert result[0][0] == "https://deepset.ai/harry_potter/Rubeus_hagrid"

    result = kgqa_retriever._query_kg(
        sparql_query="select distinct ?obj where { <https://deepset.ai/harry_potter/Hermione_granger> <https://deepset.ai/harry_potter/patronus> ?obj . }"
    )
    assert result[0][0] == "https://deepset.ai/harry_potter/Otter"
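For readers following along without the Haystack codebase: both deleted tests lean on fetch_archive_from_http to pull test fixtures. A simplified sketch of what such a helper does (hypothetical implementation for illustration; the real one in haystack.utils adds caching and supports more archive formats):

import io
import zipfile
from pathlib import Path

import requests


def fetch_archive_from_http(url: str, output_dir: str) -> None:
    # download a .zip archive and unpack it into output_dir
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    response = requests.get(url, timeout=30)
    response.raise_for_status()
    with zipfile.ZipFile(io.BytesIO(response.content)) as archive:
        archive.extractall(output_dir)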

View File

@@ -1,5 +1,5 @@
 from haystack.utils.import_utils import safe_import
-from haystack.document_stores.base import BaseDocumentStore, BaseKnowledgeGraph, KeywordDocumentStore
+from haystack.document_stores.base import BaseDocumentStore, KeywordDocumentStore
 from haystack.document_stores.memory import InMemoryDocumentStore
 from haystack.document_stores.deepsetcloud import DeepsetCloudDocumentStore
@@ -19,7 +19,3 @@ SQLDocumentStore = safe_import("haystack.document_stores.sql", "SQLDocumentStore")
 FAISSDocumentStore = safe_import("haystack.document_stores.faiss", "FAISSDocumentStore", "faiss")
 PineconeDocumentStore = safe_import("haystack.document_stores.pinecone", "PineconeDocumentStore", "pinecone")
 WeaviateDocumentStore = safe_import("haystack.document_stores.weaviate", "WeaviateDocumentStore", "weaviate")
-GraphDBKnowledgeGraph = safe_import("haystack.document_stores.graphdb", "GraphDBKnowledgeGraph", "graphdb")
-InMemoryKnowledgeGraph = safe_import(
-    "haystack.document_stores.memory_knowledgegraph", "InMemoryKnowledgeGraph", "inmemorygraph"
-)
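For context: the two deleted assignments used Haystack's safe_import helper, which keeps optional document stores importable even when their extra dependencies are missing, failing only on instantiation. A simplified sketch of that pattern (hypothetical implementation; the real helper lives in haystack.utils.import_utils and differs in detail):

import importlib


def safe_import(module_path: str, class_name: str, dep_group: str):
    """Return the class if its optional dependency is installed, else a placeholder
    that raises only when someone actually tries to instantiate it."""
    try:
        module = importlib.import_module(module_path)
        return getattr(module, class_name)
    except ImportError as import_error:
        saved_error = import_error  # 'import_error' itself is cleared when the except block ends

        class _MissingDependency:
            def __init__(self, *args, **kwargs):
                raise ImportError(
                    f"Failed to import '{class_name}' from '{module_path}'. "
                    f"Install the relevant extra, e.g. pip install 'farm-haystack[{dep_group}]'."
                ) from saved_error

        _MissingDependency.__name__ = class_name
        return _MissingDependency

Removing the two assignments is therefore what actually drops GraphDBKnowledgeGraph and InMemoryKnowledgeGraph from the package's public namespace.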

View File

@@ -2,7 +2,6 @@
 from typing import Generator, Optional, Dict, List, Set, Union, Any

-import warnings
 import logging
 import collections
 from pathlib import Path
@@ -32,35 +31,6 @@ except (ImportError, ModuleNotFoundError):
         return f


-class BaseKnowledgeGraph(BaseComponent):
-    """
-    Base class for implementing Knowledge Graphs.
-
-    The BaseKnowledgeGraph component is deprecated and will be removed in future versions.
-    """
-
-    def __init__(self):
-        warnings.warn(
-            "The BaseKnowledgeGraph component is deprecated and will be removed in future versions.",
-            category=DeprecationWarning,
-        )
-        super().__init__()
-
-    outgoing_edges = 1
-
-    def run(self, sparql_query: str, index: Optional[str] = None, headers: Optional[Dict[str, str]] = None):  # type: ignore
-        result = self.query(sparql_query=sparql_query, index=index, headers=headers)
-        output = {"sparql_result": result}
-        return output, "output_1"
-
-    def run_batch(self):
-        raise NotImplementedError("run_batch is not implemented for KnowledgeGraphs.")
-
-    @abstractmethod
-    def query(self, sparql_query: str, index: Optional[str] = None, headers: Optional[Dict[str, str]] = None):
-        raise NotImplementedError
-
-
 class BaseDocumentStore(BaseComponent):
     """
     Base class for implementing Document Stores.
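As a reminder of the BaseComponent contract the deleted class implemented: run() must return an output dictionary plus the name of the outgoing edge to fire, here always "output_1". A self-contained toy version of that contract (hypothetical names, no Haystack imports):

from typing import Dict, Optional


class TinyKnowledgeGraph:  # stand-in for a BaseKnowledgeGraph subclass
    outgoing_edges = 1

    def query(self, sparql_query: str, index: Optional[str] = None, headers: Optional[Dict[str, str]] = None):
        # a real implementation would execute the SPARQL query against a triple store
        return [{"s": {"type": "uri", "value": "https://example.org/subject"}}]

    def run(self, sparql_query: str, index: Optional[str] = None, headers: Optional[Dict[str, str]] = None):
        result = self.query(sparql_query=sparql_query, index=index, headers=headers)
        return {"sparql_result": result}, "output_1"


print(TinyKnowledgeGraph().run("SELECT * WHERE { ?s ?p ?o. }"))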

View File

@@ -1,206 +0,0 @@
from typing import Dict, Optional, Union, Tuple

import warnings
from pathlib import Path

import requests
from requests.auth import HTTPBasicAuth

try:
    from SPARQLWrapper import SPARQLWrapper, JSON
except (ImportError, ModuleNotFoundError) as ie:
    from haystack.utils.import_utils import _optional_component_not_installed

    _optional_component_not_installed(__name__, "graphdb", ie)

from haystack.document_stores import BaseKnowledgeGraph


class GraphDBKnowledgeGraph(BaseKnowledgeGraph):
    """
    Knowledge graph store that runs on a GraphDB instance.
    """

    def __init__(
        self,
        host: str = "localhost",
        port: int = 7200,
        username: str = "",
        password: str = "",
        index: Optional[str] = None,
        prefixes: str = "",
    ):
        """
        The GraphDBKnowledgeGraph component is deprecated and will be removed in future versions.

        Init the knowledge graph by defining the settings to connect with a GraphDB instance

        :param host: address of server where the GraphDB instance is running
        :param port: port where the GraphDB instance is running
        :param username: username to login to the GraphDB instance (if any)
        :param password: password to login to the GraphDB instance (if any)
        :param index: name of the index (also called repository) stored in the GraphDB instance
        :param prefixes: definitions of namespaces with a new line after each namespace, e.g., PREFIX hp: <https://deepset.ai/harry_potter/>
        """
        warnings.warn(
            "The GraphDBKnowledgeGraph component is deprecated and will be removed in future versions.",
            category=DeprecationWarning,
        )
        super().__init__()
        self.url = f"http://{host}:{port}"
        self.index = index
        self.username = username
        self.password = password
        self.prefixes = prefixes

    def create_index(
        self,
        config_path: Path,
        headers: Optional[Dict[str, str]] = None,
        timeout: Union[float, Tuple[float, float]] = 10.0,
    ):
        """
        Create a new index (also called repository) stored in the GraphDB instance

        :param config_path: path to a .ttl file with configuration settings, details:
            https://graphdb.ontotext.com/documentation/free/configuring-a-repository.html#configure-a-repository-programmatically
        :param headers: Custom HTTP headers to pass to http client (e.g. {'Authorization': 'Basic YWRtaW46cm9vdA=='})
        :param timeout: How many seconds to wait for the server to send data before giving up,
            as a float, or a :ref:`(connect timeout, read timeout) <timeouts>` tuple.
            Defaults to 10 seconds.
        """
        url = f"{self.url}/rest/repositories"
        files = {"config": open(config_path, "r", encoding="utf-8")}
        response = requests.post(url, files=files, headers=headers, timeout=timeout)
        if response.status_code > 299:
            raise Exception(response.text)

    def delete_index(self, headers: Optional[Dict[str, str]] = None, timeout: Union[float, Tuple[float, float]] = 10.0):
        """
        Delete the index that GraphDBKnowledgeGraph is connected to. This method deletes all data stored in the index.

        :param headers: Custom HTTP headers to pass to http client (e.g. {'Authorization': 'Basic YWRtaW46cm9vdA=='})
        :param timeout: How many seconds to wait for the server to send data before giving up,
            as a float, or a :ref:`(connect timeout, read timeout) <timeouts>` tuple.
            Defaults to 10 seconds.
        """
        url = f"{self.url}/rest/repositories/{self.index}"
        response = requests.delete(url, headers=headers, timeout=timeout)
        if response.status_code > 299:
            raise Exception(response.text)

    def import_from_ttl_file(
        self,
        index: str,
        path: Path,
        headers: Optional[Dict[str, str]] = None,
        timeout: Union[float, Tuple[float, float]] = 10.0,
    ):
        """
        Load an existing knowledge graph represented in the form of triples of subject, predicate, and object from a .ttl file into an index of GraphDB

        :param index: name of the index (also called repository) in the GraphDB instance where the imported triples shall be stored
        :param path: path to a .ttl containing a knowledge graph
        :param headers: Custom HTTP headers to pass to http client (e.g. {'Authorization': 'Basic YWRtaW46cm9vdA=='})
        :param timeout: How many seconds to wait for the server to send data before giving up,
            as a float, or a :ref:`(connect timeout, read timeout) <timeouts>` tuple.
            Defaults to 10 seconds.
        """
        url = f"{self.url}/repositories/{index}/statements"
        headers = (
            {"Content-type": "application/x-turtle"}
            if headers is None
            else {**{"Content-type": "application/x-turtle"}, **headers}
        )
        response = requests.post(
            url,
            headers=headers,
            data=open(path, "r", encoding="utf-8").read().encode("utf-8"),
            auth=HTTPBasicAuth(self.username, self.password),
            timeout=timeout,
        )
        if response.status_code > 299:
            raise Exception(response.text)

    def get_all_triples(self, index: Optional[str] = None, headers: Optional[Dict[str, str]] = None):
        """
        Query the given index in the GraphDB instance for all its stored triples. Duplicates are not filtered.

        :param index: name of the index (also called repository) in the GraphDB instance
        :param headers: Custom HTTP headers to pass to http client (e.g. {'Authorization': 'Basic YWRtaW46cm9vdA=='})
        :return: all triples stored in the index
        """
        sparql_query = "SELECT * WHERE { ?s ?p ?o. }"
        results = self.query(sparql_query=sparql_query, index=index, headers=headers)
        return results

    def get_all_subjects(self, index: Optional[str] = None, headers: Optional[Dict[str, str]] = None):
        """
        Query the given index in the GraphDB instance for all its stored subjects. Duplicates are not filtered.

        :param index: name of the index (also called repository) in the GraphDB instance
        :param headers: Custom HTTP headers to pass to http client (e.g. {'Authorization': 'Basic YWRtaW46cm9vdA=='})
        :return: all subjects stored in the index
        """
        sparql_query = "SELECT ?s WHERE { ?s ?p ?o. }"
        results = self.query(sparql_query=sparql_query, index=index, headers=headers)
        return results

    def get_all_predicates(self, index: Optional[str] = None, headers: Optional[Dict[str, str]] = None):
        """
        Query the given index in the GraphDB instance for all its stored predicates. Duplicates are not filtered.

        :param index: name of the index (also called repository) in the GraphDB instance
        :param headers: Custom HTTP headers to pass to http client (e.g. {'Authorization': 'Basic YWRtaW46cm9vdA=='})
        :return: all predicates stored in the index
        """
        sparql_query = "SELECT ?p WHERE { ?s ?p ?o. }"
        results = self.query(sparql_query=sparql_query, index=index, headers=headers)
        return results

    def _create_document_field_map(self) -> Dict:
        """
        There is no field mapping required
        """
        return {}

    def get_all_objects(self, index: Optional[str] = None, headers: Optional[Dict[str, str]] = None):
        """
        Query the given index in the GraphDB instance for all its stored objects. Duplicates are not filtered.

        :param index: name of the index (also called repository) in the GraphDB instance
        :param headers: Custom HTTP headers to pass to http client (e.g. {'Authorization': 'Basic YWRtaW46cm9vdA=='})
        :return: all objects stored in the index
        """
        sparql_query = "SELECT ?o WHERE { ?s ?p ?o. }"
        results = self.query(sparql_query=sparql_query, index=index, headers=headers)
        return results

    def query(self, sparql_query: str, index: Optional[str] = None, headers: Optional[Dict[str, str]] = None):
        """
        Execute a SPARQL query on the given index in the GraphDB instance

        :param sparql_query: SPARQL query that shall be executed
        :param index: name of the index (also called repository) in the GraphDB instance
        :param headers: Custom HTTP headers to pass to http client (e.g. {'Authorization': 'Basic YWRtaW46cm9vdA=='})
        :return: query result
        """
        if self.index is None and index is None:
            raise Exception("Index name is required")
        index = index or self.index
        sparql = SPARQLWrapper(f"{self.url}/repositories/{index}")
        sparql.setCredentials(self.username, self.password)
        sparql.setQuery(self.prefixes + sparql_query)
        sparql.setReturnFormat(JSON)
        if headers is not None:
            sparql.customHttpHeaders = headers
        results = sparql.query().convert()
        # if the query is a boolean (ASK) query, return a boolean instead of the text result.
        # Pylint cannot infer that `results` is a dict here (SPARQLWrapper's convert() is loosely
        # typed), so it flags unsupported-membership-test and unsubscriptable-object even though
        # both operations are valid at runtime; hence the targeted disables below.
        return (
            results["results"]["bindings"]  # type: ignore # pylint: disable=unsubscriptable-object
            if "results" in results  # type: ignore # pylint: disable=unsupported-membership-test
            else results["boolean"]  # type: ignore # pylint: disable=unsubscriptable-object
        )
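To make that final branch concrete: with JSON as the return format, SPARQLWrapper's convert() yields a plain dict following the SPARQL 1.1 query results JSON format, so SELECT and ASK responses differ in shape. A small self-contained illustration (example data invented):

# SELECT queries return variable bindings, ASK queries return a bare boolean
select_result = {
    "head": {"vars": ["s"]},
    "results": {"bindings": [{"s": {"type": "uri", "value": "https://deepset.ai/harry_potter/Harry_potter"}}]},
}
ask_result = {"head": {}, "boolean": True}


def unpack(results):
    # mirrors the return expression at the end of GraphDBKnowledgeGraph.query above
    return results["results"]["bindings"] if "results" in results else results["boolean"]


assert unpack(select_result) == select_result["results"]["bindings"]
assert unpack(ask_result) is True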

View File

@@ -1,144 +0,0 @@
from typing import Dict, Optional

import warnings
import logging
from collections import defaultdict
from pathlib import Path

from rdflib import Graph

from haystack.document_stores import BaseKnowledgeGraph

logger = logging.getLogger(__name__)


class InMemoryKnowledgeGraph(BaseKnowledgeGraph):
    """
    In-memory knowledge graph store, based on rdflib.
    """

    def __init__(self, index: str = "document"):
        """
        The InMemoryKnowledgeGraph component is deprecated and will be removed in future versions.

        Init the in-memory knowledge graph

        :param index: name of the index
        """
        warnings.warn(
            "The InMemoryKnowledgeGraph component is deprecated and will be removed in future versions.",
            category=DeprecationWarning,
        )
        super().__init__()
        self.indexes: Dict[str, Graph] = defaultdict(dict)  # type: ignore [arg-type]
        self.index: str = index

    def create_index(self, index: Optional[str] = None):
        """
        Create a new index stored in memory

        :param index: name of the index
        """
        index = index or self.index
        if index not in self.indexes:
            self.indexes[index] = Graph()
        else:
            logger.warning("Index '%s' is already present.", index)

    def delete_index(self, index: Optional[str] = None):
        """
        Delete an existing index. The index including all data will be removed.

        :param index: The name of the index to delete.
        """
        index = index or self.index
        if index in self.indexes:
            del self.indexes[index]
            logger.info("Index '%s' deleted.", index)

    def import_from_ttl_file(self, path: Path, index: Optional[str] = None):
        """
        Load into memory an existing knowledge graph represented in the form of triples of subject, predicate, and object from a .ttl file

        :param path: path to a .ttl containing a knowledge graph
        :param index: name of the index
        """
        index = index or self.index
        self.indexes[index].parse(path)

    def get_all_triples(self, index: Optional[str] = None):
        """
        Query the given in-memory index for all its stored triples. Duplicates are not filtered.

        :param index: name of the index
        :return: all triples stored in the index
        """
        sparql_query = "SELECT * WHERE { ?s ?p ?o. }"
        results = self.query(sparql_query=sparql_query, index=index)
        return results

    def get_all_subjects(self, index: Optional[str] = None):
        """
        Query the given in-memory index for all its stored subjects. Duplicates are not filtered.

        :param index: name of the index
        :return: all subjects stored in the index
        """
        sparql_query = "SELECT ?s WHERE { ?s ?p ?o. }"
        results = self.query(sparql_query=sparql_query, index=index)
        return results

    def get_all_predicates(self, index: Optional[str] = None):
        """
        Query the given in-memory index for all its stored predicates. Duplicates are not filtered.

        :param index: name of the index
        :return: all predicates stored in the index
        """
        sparql_query = "SELECT ?p WHERE { ?s ?p ?o. }"
        results = self.query(sparql_query=sparql_query, index=index)
        return results

    def _create_document_field_map(self) -> Dict:
        """
        There is no field mapping required
        """
        return {}

    def get_all_objects(self, index: Optional[str] = None):
        """
        Query the given in-memory index for all its stored objects. Duplicates are not filtered.

        :param index: name of the index
        :return: all objects stored in the index
        """
        sparql_query = "SELECT ?o WHERE { ?s ?p ?o. }"
        results = self.query(sparql_query=sparql_query, index=index)
        return results

    def query(self, sparql_query: str, index: Optional[str] = None, headers: Optional[Dict[str, str]] = None):
        """
        Execute a SPARQL query on the given in-memory index

        :param sparql_query: SPARQL query that shall be executed
        :param index: name of the index
        :return: query result
        """
        index = index or self.index
        raw_results = self.indexes[index].query(sparql_query)
        if raw_results.askAnswer is not None:
            return raw_results.askAnswer
        else:
            formatted_results = []
            for b in raw_results.bindings:
                formatted_result = {}
                items = list(b.items())
                for item in items:
                    type_ = item[0].toPython()[1:]
                    uri = item[1].toPython()  # type: ignore [attr-defined]
                    formatted_result[type_] = {"type": "uri", "value": uri}
                formatted_results.append(formatted_result)
            return formatted_results
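Underneath, this store is a thin formatting layer over rdflib. A minimal standalone sketch of the same parse-and-query flow (example triples invented for illustration):

from rdflib import Graph

ttl = """
@prefix hp: <https://deepset.ai/harry_potter/> .
hp:Harry_potter hp:house hp:Gryffindor .
"""

g = Graph()
g.parse(data=ttl, format="turtle")

results = g.query("SELECT ?house WHERE { ?s <https://deepset.ai/harry_potter/house> ?house . }")
for binding in results.bindings:
    for var, term in binding.items():
        # var.toPython() yields e.g. '?house'; [1:] strips the '?', as in query() above
        print(var.toPython()[1:], "->", term.toPython())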

View File

@@ -40,7 +40,6 @@ from haystack.nodes.retriever import (
     FilterRetriever,
     MultihopEmbeddingRetriever,
     TfidfRetriever,
-    Text2SparqlRetriever,
     TableTextRetriever,
     MultiModalRetriever,
     WebRetriever,

View File

@@ -8,5 +8,4 @@ from haystack.nodes.retriever.dense import (
 )
 from haystack.nodes.retriever.multimodal import MultiModalRetriever
 from haystack.nodes.retriever.sparse import BM25Retriever, FilterRetriever, TfidfRetriever
-from haystack.nodes.retriever.text2sparql import Text2SparqlRetriever
 from haystack.nodes.retriever.web import WebRetriever

View File

@@ -11,42 +11,12 @@ from haystack.schema import Document, MultiLabel
 from haystack.errors import HaystackError, PipelineError
 from haystack.nodes.base import BaseComponent
 from haystack.telemetry import send_event
-from haystack.document_stores.base import BaseDocumentStore, BaseKnowledgeGraph, FilterType
+from haystack.document_stores.base import BaseDocumentStore, FilterType

 logger = logging.getLogger(__name__)


-class BaseGraphRetriever(BaseComponent):
-    """
-    Base class for knowledge graph retrievers.
-    """
-
-    knowledge_graph: BaseKnowledgeGraph
-    outgoing_edges = 1
-
-    @abstractmethod
-    def retrieve(self, query: str, top_k: Optional[int] = None):
-        pass
-
-    @abstractmethod
-    def retrieve_batch(self, queries: List[str], top_k: Optional[int] = None):
-        pass
-
-    def eval(self):
-        raise NotImplementedError
-
-    def run(self, query: str, top_k: Optional[int] = None):  # type: ignore
-        answers = self.retrieve(query=query, top_k=top_k)
-        results = {"answers": answers}
-        return results, "output_1"
-
-    def run_batch(self, queries: List[str], top_k: Optional[int] = None):  # type: ignore
-        answers = self.retrieve_batch(queries=queries, top_k=top_k)
-        results = {"answers": answers}
-        return results, "output_1"
-
-
 class BaseRetriever(BaseComponent):
     """
     Base class for regular retrievers.

View File

@@ -1,150 +0,0 @@
from typing import Optional, List, Union

import warnings
import logging

from transformers import BartForConditionalGeneration, BartTokenizer

from haystack.document_stores import BaseKnowledgeGraph
from haystack.nodes.retriever.base import BaseGraphRetriever

logger = logging.getLogger(__name__)


class Text2SparqlRetriever(BaseGraphRetriever):
    """
    Graph retriever that uses a pre-trained Bart model to translate natural language questions
    given in text form to queries in SPARQL format.
    The generated SPARQL query is executed on a knowledge graph.
    """

    def __init__(
        self,
        knowledge_graph: BaseKnowledgeGraph,
        model_name_or_path: Optional[str] = None,
        model_version: Optional[str] = None,
        top_k: int = 1,
        use_auth_token: Optional[Union[str, bool]] = None,
    ):
        """
        The Text2SparqlRetriever component is deprecated and will be removed in future versions.

        Init the Retriever by providing a knowledge graph and a pre-trained BART model

        :param knowledge_graph: An instance of BaseKnowledgeGraph on which to execute SPARQL queries.
        :param model_name_or_path: Name of or path to a pre-trained BartForConditionalGeneration model.
        :param model_version: The version of the model to use for entity extraction.
        :param top_k: How many SPARQL queries to generate per text query.
        :param use_auth_token: The API token used to download private models from Huggingface.
            If this parameter is set to `True`, then the token generated when running
            `transformers-cli login` (stored in ~/.huggingface) will be used.
            Additional information can be found here:
            https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained
        """
        warnings.warn(
            "The Text2SparqlRetriever component is deprecated and will be removed in future versions.",
            category=DeprecationWarning,
        )
        super().__init__()

        self.knowledge_graph = knowledge_graph
        # TODO We should extend this to any seq2seq models and use the AutoModel class
        self.model = BartForConditionalGeneration.from_pretrained(
            model_name_or_path, forced_bos_token_id=0, use_auth_token=use_auth_token, revision=model_version
        )
        self.tok = BartTokenizer.from_pretrained(model_name_or_path, use_auth_token=use_auth_token)
        self.top_k = top_k

    def retrieve(self, query: str, top_k: Optional[int] = None):
        """
        Translate a text query to SPARQL and execute it on the knowledge graph to retrieve a list of answers

        :param query: Text query that shall be translated to SPARQL and then executed on the knowledge graph
        :param top_k: How many SPARQL queries to generate per text query.
        """
        if top_k is None:
            top_k = self.top_k
        inputs = self.tok([query], max_length=100, truncation=True, return_tensors="pt")
        # generate top_k+2 SPARQL queries so that we can dismiss some queries with wrong syntax
        temp = self.model.generate(
            inputs["input_ids"], num_beams=5, max_length=100, num_return_sequences=top_k + 2, early_stopping=True
        )
        sparql_queries = [
            self.tok.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in temp
        ]
        answers = []
        for sparql_query in sparql_queries:
            ans, query = self._query_kg(sparql_query=sparql_query)
            if len(ans) > 0:
                answers.append((ans, query))
        # if there are no answers we still want to return something
        if len(answers) == 0:
            answers.append(("", ""))
        results = answers[:top_k]
        results = [self.format_result(result) for result in results]
        return results

    def retrieve_batch(self, queries: List[str], top_k: Optional[int] = None):
        """
        Translate a list of queries to SPARQL and execute them on the knowledge graph to retrieve
        a list of lists of answers (one per query).

        :param queries: List of queries that shall be translated to SPARQL and then executed on the
            knowledge graph.
        :param top_k: How many SPARQL queries to generate per text query.
        """
        # TODO: This method currently just calls the retrieve method multiple times, so there is room for improvement.
        results = []
        for query in queries:
            cur_result = self.run(query=query, top_k=top_k)
            results.append(cur_result)
        return results

    def _query_kg(self, sparql_query):
        """
        Execute a single SPARQL query on the knowledge graph to retrieve an answer and unpack
        different answer styles for boolean queries, count queries, and list queries.

        :param sparql_query: SPARQL query that shall be executed on the knowledge graph
        """
        try:
            response = self.knowledge_graph.query(sparql_query=sparql_query)
            # unpack different answer styles
            if isinstance(response, list):
                if len(response) == 0:
                    result = ""
                else:
                    result = []
                    for x in response:
                        for v in x.values():
                            result.append(v["value"])
            elif isinstance(response, bool):
                result = str(response)
            elif "count" in response[0]:
                result = str(int(response[0]["count"]["value"]))
            else:
                result = ""
        except Exception:
            result = ""
        return result, sparql_query

    def format_result(self, result):
        """
        Generate formatted dictionary output with text answer and additional info

        :param result: The result of a SPARQL query as retrieved from the knowledge graph
        """
        query = result[1]
        prediction = result[0]
        prediction_meta = {"model": self.__class__.__name__, "sparql_query": query}
        return {"answer": prediction, "prediction_meta": prediction_meta}

    def eval(self):
        raise NotImplementedError
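The generation pattern in retrieve() stands on its own and is easy to try outside Haystack. A standalone sketch using "facebook/bart-base" as a stand-in checkpoint (the removed code loaded a fine-tuned text-to-SPARQL model such as hp_v3.4, so the raw outputs here will not be valid SPARQL):

from transformers import BartForConditionalGeneration, BartTokenizer

tok = BartTokenizer.from_pretrained("facebook/bart-base")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-base", forced_bos_token_id=0)

inputs = tok(["In which house is Harry Potter?"], max_length=100, truncation=True, return_tensors="pt")
# beam search with num_return_sequences > 1 yields several candidate translations,
# so syntactically invalid SPARQL candidates can be discarded afterwards
candidates = model.generate(
    inputs["input_ids"], num_beams=5, max_length=100, num_return_sequences=3, early_stopping=True
)
for seq in candidates:
    print(tok.decode(seq, skip_special_tokens=True, clean_up_tokenization_spaces=False))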

View File

@@ -1,22 +0,0 @@
import pytest

from haystack.document_stores.graphdb import GraphDBKnowledgeGraph

from ..conftest import fail_at_version


@pytest.mark.unit
@fail_at_version(1, 17)
def test_graphdb_knowledge_graph_deprecation_warning():
    with pytest.warns(DeprecationWarning) as w:
        GraphDBKnowledgeGraph()

    assert len(w) == 2
    assert (
        w[0].message.args[0]
        == "The GraphDBKnowledgeGraph component is deprecated and will be removed in future versions."
    )
    assert (
        w[1].message.args[0]
        == "The BaseKnowledgeGraph component is deprecated and will be removed in future versions."
    )
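The len(w) == 2 assertion is worth spelling out: constructing the store triggers the subclass's own deprecation warning and then, via super().__init__(), the BaseKnowledgeGraph one. A self-contained toy reproduction of that double-warning pattern (hypothetical class names):

import warnings

import pytest


class Base:
    def __init__(self):
        warnings.warn("Base is deprecated.", category=DeprecationWarning)


class Child(Base):
    def __init__(self):
        warnings.warn("Child is deprecated.", category=DeprecationWarning)
        super().__init__()


def test_two_warnings():
    with pytest.warns(DeprecationWarning) as w:
        Child()
    assert len(w) == 2
    assert w[0].message.args[0] == "Child is deprecated."
    assert w[1].message.args[0] == "Base is deprecated."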

View File

@@ -1,22 +0,0 @@
import pytest

from haystack.document_stores.memory_knowledgegraph import InMemoryKnowledgeGraph

from ..conftest import fail_at_version


@pytest.mark.unit
@fail_at_version(1, 17)
def test_in_memory_knowledge_graph_deprecation_warning():
    with pytest.warns(DeprecationWarning) as w:
        InMemoryKnowledgeGraph()

    assert len(w) == 2
    assert (
        w[0].message.args[0]
        == "The InMemoryKnowledgeGraph component is deprecated and will be removed in future versions."
    )
    assert (
        w[1].message.args[0]
        == "The BaseKnowledgeGraph component is deprecated and will be removed in future versions."
    )

View File

@@ -27,7 +27,6 @@ from haystack.document_stores import WeaviateDocumentStore
 from haystack.nodes.retriever.base import BaseRetriever
 from haystack.nodes.retriever.web import WebRetriever
 from haystack.nodes.search_engine import WebSearch
-from haystack.nodes.retriever import Text2SparqlRetriever
 from haystack.pipelines import DocumentSearchPipeline
 from haystack.schema import Document
 from haystack.document_stores.elasticsearch import ElasticsearchDocumentStore
@@ -1271,21 +1270,3 @@ def test_web_retriever_mode_snippets(monkeypatch):
     web_retriever = WebRetriever(api_key="", top_search_results=2)
     result = web_retriever.retrieve(query="Who is the father of Arya Stark?")
     assert result == expected_search_results["documents"]
-
-
-@fail_at_version(1, 17)
-def test_text_2_sparql_retriever_deprecation():
-    BartForConditionalGeneration = object()
-    BartTokenizer = object()
-    with patch.multiple(
-        "haystack.nodes.retriever.text2sparql", BartForConditionalGeneration=DEFAULT, BartTokenizer=DEFAULT
-    ):
-        knowledge_graph = Mock()
-        with pytest.warns(DeprecationWarning) as w:
-            Text2SparqlRetriever(knowledge_graph)
-        assert len(w) == 1
-        assert (
-            w[0].message.args[0]
-            == "The Text2SparqlRetriever component is deprecated and will be removed in future versions."
-        )