Mirror of https://github.com/deepset-ai/haystack.git (synced 2025-11-02 10:49:30 +00:00)
chore: remove BaseKnowledgeGraph (#4953)

* remove BaseKnowledgeGraph
* fix pylint

parent 5321d91f97, commit c6ea542b57
@@ -1,107 +0,0 @@
from pathlib import Path

from haystack.nodes import Text2SparqlRetriever
from haystack.document_stores import GraphDBKnowledgeGraph, InMemoryKnowledgeGraph
from haystack.utils import fetch_archive_from_http


def test_graph_retrieval():
    # we use a timeout double the default in the CI to account for slow runners
    timeout = 20

    # TODO rename doc_dir
    graph_dir = "../data/tutorial10_knowledge_graph/"
    s3_url = "https://fandom-qa.s3-eu-west-1.amazonaws.com/triples_and_config.zip"
    fetch_archive_from_http(url=s3_url, output_dir=graph_dir)

    # Fetch a pre-trained BART model that translates natural language questions to SPARQL queries
    model_dir = "../saved_models/tutorial10_knowledge_graph/"
    s3_url = "https://fandom-qa.s3-eu-west-1.amazonaws.com/saved_models/hp_v3.4.zip"
    fetch_archive_from_http(url=s3_url, output_dir=model_dir)

    kg = GraphDBKnowledgeGraph(index="tutorial_10_index")
    kg.delete_index(timeout=timeout)
    kg.create_index(config_path=Path(graph_dir + "repo-config.ttl"), timeout=timeout)
    kg.import_from_ttl_file(index="tutorial_10_index", path=Path(graph_dir + "triples.ttl"), timeout=timeout)
    triple = {
        "p": {"type": "uri", "value": "https://deepset.ai/harry_potter/_paternalgrandfather"},
        "s": {"type": "uri", "value": "https://deepset.ai/harry_potter/Melody_fawley"},
        "o": {"type": "uri", "value": "https://deepset.ai/harry_potter/Marshall_fawley"},
    }
    triples = kg.get_all_triples()
    assert len(triples) > 0
    assert triple in triples

    # Define prefixes for names of resources so that we can use shorter resource names in queries
    prefixes = """PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
    PREFIX xsd: <http://www.w3.org/2001/XMLSchema#>
    PREFIX hp: <https://deepset.ai/harry_potter/>
    """
    kg.prefixes = prefixes

    kgqa_retriever = Text2SparqlRetriever(knowledge_graph=kg, model_name_or_path=model_dir + "hp_v3.4")

    result = kgqa_retriever.retrieve(query="In which house is Harry Potter?")
    assert result[0] == {
        "answer": ["https://deepset.ai/harry_potter/Gryffindor"],
        "prediction_meta": {
            "model": "Text2SparqlRetriever",
            "sparql_query": "select ?a { hp:Harry_potter hp:house ?a . }",
        },
    }

    result = kgqa_retriever._query_kg(
        sparql_query="select distinct ?sbj where { ?sbj hp:job hp:Keeper_of_keys_and_grounds . }"
    )
    assert result[0][0] == "https://deepset.ai/harry_potter/Rubeus_hagrid"

    result = kgqa_retriever._query_kg(
        sparql_query="select distinct ?obj where { <https://deepset.ai/harry_potter/Hermione_granger> <https://deepset.ai/harry_potter/patronus> ?obj . }"
    )
    assert result[0][0] == "https://deepset.ai/harry_potter/Otter"


def test_inmemory_graph_retrieval():
    # TODO rename doc_dir
    graph_dir = "../data/tutorial10_knowledge_graph/"
    s3_url = "https://fandom-qa.s3-eu-west-1.amazonaws.com/triples_and_config.zip"
    fetch_archive_from_http(url=s3_url, output_dir=graph_dir)

    # Fetch a pre-trained BART model that translates natural language questions to SPARQL queries
    model_dir = "../saved_models/tutorial10_knowledge_graph/"
    s3_url = "https://fandom-qa.s3-eu-west-1.amazonaws.com/saved_models/hp_v3.4.zip"
    fetch_archive_from_http(url=s3_url, output_dir=model_dir)

    kg = InMemoryKnowledgeGraph(index="tutorial_10_index")
    kg.delete_index()
    kg.create_index()
    kg.import_from_ttl_file(index="tutorial_10_index", path=Path(graph_dir + "triples.ttl"))
    triple = {
        "p": {"type": "uri", "value": "https://deepset.ai/harry_potter/_paternalgrandfather"},
        "s": {"type": "uri", "value": "https://deepset.ai/harry_potter/Melody_fawley"},
        "o": {"type": "uri", "value": "https://deepset.ai/harry_potter/Marshall_fawley"},
    }
    triples = kg.get_all_triples()
    assert len(triples) > 0
    assert triple in triples

    kgqa_retriever = Text2SparqlRetriever(knowledge_graph=kg, model_name_or_path=model_dir + "hp_v3.4")

    result = kgqa_retriever.retrieve(query="In which house is Harry Potter?")
    assert result[0] == {
        "answer": ["https://deepset.ai/harry_potter/Gryffindor"],
        "prediction_meta": {
            "model": "Text2SparqlRetriever",
            "sparql_query": "select ?a { hp:Harry_potter hp:house ?a . }",
        },
    }

    result = kgqa_retriever._query_kg(
        sparql_query="select distinct ?sbj where { ?sbj hp:job hp:Keeper_of_keys_and_grounds . }"
    )
    assert result[0][0] == "https://deepset.ai/harry_potter/Rubeus_hagrid"

    result = kgqa_retriever._query_kg(
        sparql_query="select distinct ?obj where { <https://deepset.ai/harry_potter/Hermione_granger> <https://deepset.ai/harry_potter/patronus> ?obj . }"
    )
    assert result[0][0] == "https://deepset.ai/harry_potter/Otter"
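Note on the `triple` dicts asserted above: `get_all_triples()` returns solutions in the SPARQL 1.1 JSON results shape, one dict per solution mapping each variable to a `{"type", "value"}` binding. A small standalone illustration (hypothetical data, same shape):

# One SPARQL solution in the JSON results format used by the assertions above:
binding = {
    "s": {"type": "uri", "value": "https://deepset.ai/harry_potter/Melody_fawley"},
    "p": {"type": "uri", "value": "https://deepset.ai/harry_potter/_paternalgrandfather"},
    "o": {"type": "uri", "value": "https://deepset.ai/harry_potter/Marshall_fawley"},
}
# Reading it back as a plain (subject, predicate, object) triple:
s, p, o = (binding[var]["value"] for var in ("s", "p", "o"))
print(s, p, o)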
@@ -1,5 +1,5 @@
 from haystack.utils.import_utils import safe_import
-from haystack.document_stores.base import BaseDocumentStore, BaseKnowledgeGraph, KeywordDocumentStore
+from haystack.document_stores.base import BaseDocumentStore, KeywordDocumentStore

 from haystack.document_stores.memory import InMemoryDocumentStore
 from haystack.document_stores.deepsetcloud import DeepsetCloudDocumentStore
@@ -19,7 +19,3 @@ SQLDocumentStore = safe_import("haystack.document_stores.sql", "SQLDocumentStore
 FAISSDocumentStore = safe_import("haystack.document_stores.faiss", "FAISSDocumentStore", "faiss")
 PineconeDocumentStore = safe_import("haystack.document_stores.pinecone", "PineconeDocumentStore", "pinecone")
 WeaviateDocumentStore = safe_import("haystack.document_stores.weaviate", "WeaviateDocumentStore", "weaviate")
-GraphDBKnowledgeGraph = safe_import("haystack.document_stores.graphdb", "GraphDBKnowledgeGraph", "graphdb")
-InMemoryKnowledgeGraph = safe_import(
-    "haystack.document_stores.memory_knowledgegraph", "InMemoryKnowledgeGraph", "inmemorygraph"
-)
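For context, `safe_import` binds a name to the real class when its optional dependencies are installed, and otherwise to a placeholder that fails on use. A minimal sketch of that pattern (an illustration only, not Haystack's actual implementation; the pip extra name in the message is assumed):

import importlib

def safe_import(module_path: str, class_name: str, dep_group: str):
    # Return the real class if its optional dependencies are importable,
    # otherwise return a stand-in that fails loudly when instantiated.
    try:
        module = importlib.import_module(module_path)
        return getattr(module, class_name)
    except ImportError:
        class MissingDependency:
            def __init__(self, *args, **kwargs):
                raise ImportError(
                    f"{class_name} requires the optional '{dep_group}' dependencies. "
                    f"Install them, e.g. pip install farm-haystack[{dep_group}]."
                )
        return MissingDependency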
@@ -2,7 +2,6 @@

 from typing import Generator, Optional, Dict, List, Set, Union, Any

-import warnings
 import logging
 import collections
 from pathlib import Path
@@ -32,35 +31,6 @@ except (ImportError, ModuleNotFoundError):
         return f


-class BaseKnowledgeGraph(BaseComponent):
-    """
-    Base class for implementing Knowledge Graphs.
-
-    The BaseKnowledgeGraph component is deprecated and will be removed in future versions.
-    """
-
-    def __init__(self):
-        warnings.warn(
-            "The BaseKnowledgeGraph component is deprecated and will be removed in future versions.",
-            category=DeprecationWarning,
-        )
-        super().__init__()
-
-    outgoing_edges = 1
-
-    def run(self, sparql_query: str, index: Optional[str] = None, headers: Optional[Dict[str, str]] = None):  # type: ignore
-        result = self.query(sparql_query=sparql_query, index=index, headers=headers)
-        output = {"sparql_result": result}
-        return output, "output_1"
-
-    def run_batch(self):
-        raise NotImplementedError("run_batch is not implemented for KnowledgeGraphs.")
-
-    @abstractmethod
-    def query(self, sparql_query: str, index: Optional[str] = None, headers: Optional[Dict[str, str]] = None):
-        raise NotImplementedError
-
-
 class BaseDocumentStore(BaseComponent):
     """
     Base class for implementing Document Stores.
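The removed base class left only `query` abstract; `run` merely wrapped it into the `(output, edge)` tuple that Haystack nodes return. A self-contained sketch of that contract with a toy subclass (hypothetical names, no Haystack imports):

from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Tuple


class KnowledgeGraphLike(ABC):
    # Hypothetical stand-in mirroring the removed contract: run() wraps query()
    # into the (output_dict, edge_name) tuple expected of a node.
    outgoing_edges = 1

    def run(self, sparql_query: str, index: Optional[str] = None, headers: Optional[Dict[str, str]] = None):
        result = self.query(sparql_query=sparql_query, index=index, headers=headers)
        return {"sparql_result": result}, "output_1"

    @abstractmethod
    def query(self, sparql_query: str, index: Optional[str] = None, headers: Optional[Dict[str, str]] = None):
        ...


class TinyKnowledgeGraph(KnowledgeGraphLike):
    # Toy subclass: ignores the SPARQL text and returns every stored triple.
    def __init__(self, triples: List[Tuple[str, str, str]]):
        self.triples = triples

    def query(self, sparql_query: str, index: Optional[str] = None, headers: Optional[Dict[str, str]] = None):
        return [{"s": s, "p": p, "o": o} for (s, p, o) in self.triples]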
@@ -1,206 +0,0 @@
from typing import Dict, Optional, Union, Tuple

import warnings
from pathlib import Path

import requests
from requests.auth import HTTPBasicAuth

try:
    from SPARQLWrapper import SPARQLWrapper, JSON
except (ImportError, ModuleNotFoundError) as ie:
    from haystack.utils.import_utils import _optional_component_not_installed

    _optional_component_not_installed(__name__, "graphdb", ie)

from haystack.document_stores import BaseKnowledgeGraph


class GraphDBKnowledgeGraph(BaseKnowledgeGraph):
    """
    Knowledge graph store that runs on a GraphDB instance.
    """

    def __init__(
        self,
        host: str = "localhost",
        port: int = 7200,
        username: str = "",
        password: str = "",
        index: Optional[str] = None,
        prefixes: str = "",
    ):
        """
        The GraphDBKnowledgeGraph component is deprecated and will be removed in future versions.

        Init the knowledge graph by defining the settings to connect with a GraphDB instance

        :param host: address of server where the GraphDB instance is running
        :param port: port where the GraphDB instance is running
        :param username: username to login to the GraphDB instance (if any)
        :param password: password to login to the GraphDB instance (if any)
        :param index: name of the index (also called repository) stored in the GraphDB instance
        :param prefixes: definitions of namespaces with a new line after each namespace, e.g., PREFIX hp: <https://deepset.ai/harry_potter/>
        """
        warnings.warn(
            "The GraphDBKnowledgeGraph component is deprecated and will be removed in future versions.",
            category=DeprecationWarning,
        )
        super().__init__()

        self.url = f"http://{host}:{port}"
        self.index = index
        self.username = username
        self.password = password
        self.prefixes = prefixes

    def create_index(
        self,
        config_path: Path,
        headers: Optional[Dict[str, str]] = None,
        timeout: Union[float, Tuple[float, float]] = 10.0,
    ):
        """
        Create a new index (also called repository) stored in the GraphDB instance

        :param config_path: path to a .ttl file with configuration settings; for details, see
            https://graphdb.ontotext.com/documentation/free/configuring-a-repository.html#configure-a-repository-programmatically
        :param headers: Custom HTTP headers to pass to http client (e.g. {'Authorization': 'Basic YWRtaW46cm9vdA=='})
        :param timeout: How many seconds to wait for the server to send data before giving up,
            as a float, or a :ref:`(connect timeout, read timeout) <timeouts>` tuple.
            Defaults to 10 seconds.
        """
        url = f"{self.url}/rest/repositories"
        files = {"config": open(config_path, "r", encoding="utf-8")}
        response = requests.post(url, files=files, headers=headers, timeout=timeout)
        if response.status_code > 299:
            raise Exception(response.text)

    def delete_index(self, headers: Optional[Dict[str, str]] = None, timeout: Union[float, Tuple[float, float]] = 10.0):
        """
        Delete the index that GraphDBKnowledgeGraph is connected to. This method deletes all data stored in the index.

        :param headers: Custom HTTP headers to pass to http client (e.g. {'Authorization': 'Basic YWRtaW46cm9vdA=='})
        :param timeout: How many seconds to wait for the server to send data before giving up,
            as a float, or a :ref:`(connect timeout, read timeout) <timeouts>` tuple.
            Defaults to 10 seconds.
        """
        url = f"{self.url}/rest/repositories/{self.index}"
        response = requests.delete(url, headers=headers, timeout=timeout)
        if response.status_code > 299:
            raise Exception(response.text)

    def import_from_ttl_file(
        self,
        index: str,
        path: Path,
        headers: Optional[Dict[str, str]] = None,
        timeout: Union[float, Tuple[float, float]] = 10.0,
    ):
        """
        Load an existing knowledge graph represented in the form of triples of subject, predicate, and object from a .ttl file into an index of GraphDB

        :param index: name of the index (also called repository) in the GraphDB instance where the imported triples shall be stored
        :param path: path to a .ttl file containing a knowledge graph
        :param headers: Custom HTTP headers to pass to http client (e.g. {'Authorization': 'Basic YWRtaW46cm9vdA=='})
        :param timeout: How many seconds to wait for the server to send data before giving up,
            as a float, or a :ref:`(connect timeout, read timeout) <timeouts>` tuple.
            Defaults to 10 seconds.
        """
        url = f"{self.url}/repositories/{index}/statements"
        headers = (
            {"Content-type": "application/x-turtle"}
            if headers is None
            else {**{"Content-type": "application/x-turtle"}, **headers}
        )
        response = requests.post(
            url,
            headers=headers,
            data=open(path, "r", encoding="utf-8").read().encode("utf-8"),
            auth=HTTPBasicAuth(self.username, self.password),
            timeout=timeout,
        )
        if response.status_code > 299:
            raise Exception(response.text)

    def get_all_triples(self, index: Optional[str] = None, headers: Optional[Dict[str, str]] = None):
        """
        Query the given index in the GraphDB instance for all its stored triples. Duplicates are not filtered.

        :param index: name of the index (also called repository) in the GraphDB instance
        :param headers: Custom HTTP headers to pass to http client (e.g. {'Authorization': 'Basic YWRtaW46cm9vdA=='})
        :return: all triples stored in the index
        """
        sparql_query = "SELECT * WHERE { ?s ?p ?o. }"
        results = self.query(sparql_query=sparql_query, index=index, headers=headers)
        return results

    def get_all_subjects(self, index: Optional[str] = None, headers: Optional[Dict[str, str]] = None):
        """
        Query the given index in the GraphDB instance for all its stored subjects. Duplicates are not filtered.

        :param index: name of the index (also called repository) in the GraphDB instance
        :param headers: Custom HTTP headers to pass to http client (e.g. {'Authorization': 'Basic YWRtaW46cm9vdA=='})
        :return: all subjects stored in the index
        """
        sparql_query = "SELECT ?s WHERE { ?s ?p ?o. }"
        results = self.query(sparql_query=sparql_query, index=index, headers=headers)
        return results

    def get_all_predicates(self, index: Optional[str] = None, headers: Optional[Dict[str, str]] = None):
        """
        Query the given index in the GraphDB instance for all its stored predicates. Duplicates are not filtered.

        :param index: name of the index (also called repository) in the GraphDB instance
        :param headers: Custom HTTP headers to pass to http client (e.g. {'Authorization': 'Basic YWRtaW46cm9vdA=='})
        :return: all predicates stored in the index
        """
        sparql_query = "SELECT ?p WHERE { ?s ?p ?o. }"
        results = self.query(sparql_query=sparql_query, index=index, headers=headers)
        return results

    def _create_document_field_map(self) -> Dict:
        """
        There is no field mapping required
        """
        return {}

    def get_all_objects(self, index: Optional[str] = None, headers: Optional[Dict[str, str]] = None):
        """
        Query the given index in the GraphDB instance for all its stored objects. Duplicates are not filtered.

        :param index: name of the index (also called repository) in the GraphDB instance
        :param headers: Custom HTTP headers to pass to http client (e.g. {'Authorization': 'Basic YWRtaW46cm9vdA=='})
        :return: all objects stored in the index
        """
        sparql_query = "SELECT ?o WHERE { ?s ?p ?o. }"
        results = self.query(sparql_query=sparql_query, index=index, headers=headers)
        return results

    def query(self, sparql_query: str, index: Optional[str] = None, headers: Optional[Dict[str, str]] = None):
        """
        Execute a SPARQL query on the given index in the GraphDB instance

        :param sparql_query: SPARQL query that shall be executed
        :param index: name of the index (also called repository) in the GraphDB instance
        :param headers: Custom HTTP headers to pass to http client (e.g. {'Authorization': 'Basic YWRtaW46cm9vdA=='})
        :return: query result
        """
        if self.index is None and index is None:
            raise Exception("Index name is required")
        index = index or self.index
        sparql = SPARQLWrapper(f"{self.url}/repositories/{index}")
        sparql.setCredentials(self.username, self.password)
        sparql.setQuery(self.prefixes + sparql_query)
        sparql.setReturnFormat(JSON)
        if headers is not None:
            sparql.customHttpHeaders = headers
        results = sparql.query().convert()
        # if query is a boolean query, return boolean instead of text result
        # FIXME: 'results' likely doesn't support membership test (`"something" in results`).
        # Pylint raises unsupported-membership-test and unsubscriptable-object.
        # Silenced for now, keep in mind for future debugging.
        return (
            results["results"]["bindings"]  # type: ignore # pylint: disable=unsubscriptable-object
            if "results" in results  # type: ignore # pylint: disable=unsupported-membership-test
            else results["boolean"]  # type: ignore # pylint: disable=unsubscriptable-object
        )
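For reference, the deleted `query` method is a thin wrapper over SPARQLWrapper; the direct equivalent looks roughly like this (assumes a local GraphDB on port 7200 and the `tutorial_10_index` repository from the tests above):

from SPARQLWrapper import JSON, SPARQLWrapper

sparql = SPARQLWrapper("http://localhost:7200/repositories/tutorial_10_index")
sparql.setQuery("SELECT * WHERE { ?s ?p ?o . } LIMIT 5")
sparql.setReturnFormat(JSON)
results = sparql.query().convert()

# SELECT results arrive as SPARQL-JSON bindings, the same shape the
# deleted query() method returned.
for binding in results["results"]["bindings"]:
    print(binding["s"]["value"], binding["p"]["value"], binding["o"]["value"])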
@@ -1,144 +0,0 @@
from typing import Dict, Optional

import warnings
import logging
from collections import defaultdict
from pathlib import Path

from rdflib import Graph

from haystack.document_stores import BaseKnowledgeGraph

logger = logging.getLogger(__name__)


class InMemoryKnowledgeGraph(BaseKnowledgeGraph):
    """
    In-memory knowledge graph store, based on rdflib.
    """

    def __init__(self, index: str = "document"):
        """
        The InMemoryKnowledgeGraph component is deprecated and will be removed in future versions.

        Init the in-memory knowledge graph

        :param index: name of the index
        """
        warnings.warn(
            "The InMemoryKnowledgeGraph component is deprecated and will be removed in future versions.",
            category=DeprecationWarning,
        )
        super().__init__()

        self.indexes: Dict[str, Graph] = defaultdict(dict)  # type: ignore [arg-type]
        self.index: str = index

    def create_index(self, index: Optional[str] = None):
        """
        Create a new index stored in memory

        :param index: name of the index
        """
        index = index or self.index
        if index not in self.indexes:
            self.indexes[index] = Graph()
        else:
            logger.warning("Index '%s' is already present.", index)

    def delete_index(self, index: Optional[str] = None):
        """
        Delete an existing index. The index including all data will be removed.

        :param index: The name of the index to delete.
        """
        index = index or self.index

        if index in self.indexes:
            del self.indexes[index]
            logger.info("Index '%s' deleted.", index)

    def import_from_ttl_file(self, path: Path, index: Optional[str] = None):
        """
        Load into memory an existing knowledge graph represented in the form of triples of subject, predicate, and object from a .ttl file

        :param path: path to a .ttl file containing a knowledge graph
        :param index: name of the index
        """
        index = index or self.index
        self.indexes[index].parse(path)

    def get_all_triples(self, index: Optional[str] = None):
        """
        Query the given in-memory index for all its stored triples. Duplicates are not filtered.

        :param index: name of the index
        :return: all triples stored in the index
        """
        sparql_query = "SELECT * WHERE { ?s ?p ?o. }"
        results = self.query(sparql_query=sparql_query, index=index)
        return results

    def get_all_subjects(self, index: Optional[str] = None):
        """
        Query the given in-memory index for all its stored subjects. Duplicates are not filtered.

        :param index: name of the index
        :return: all subjects stored in the index
        """
        sparql_query = "SELECT ?s WHERE { ?s ?p ?o. }"
        results = self.query(sparql_query=sparql_query, index=index)
        return results

    def get_all_predicates(self, index: Optional[str] = None):
        """
        Query the given in-memory index for all its stored predicates. Duplicates are not filtered.

        :param index: name of the index
        :return: all predicates stored in the index
        """
        sparql_query = "SELECT ?p WHERE { ?s ?p ?o. }"
        results = self.query(sparql_query=sparql_query, index=index)
        return results

    def _create_document_field_map(self) -> Dict:
        """
        There is no field mapping required
        """
        return {}

    def get_all_objects(self, index: Optional[str] = None):
        """
        Query the given in-memory index for all its stored objects. Duplicates are not filtered.

        :param index: name of the index
        :return: all objects stored in the index
        """
        sparql_query = "SELECT ?o WHERE { ?s ?p ?o. }"
        results = self.query(sparql_query=sparql_query, index=index)
        return results

    def query(self, sparql_query: str, index: Optional[str] = None, headers: Optional[Dict[str, str]] = None):
        """
        Execute a SPARQL query on the given in-memory index

        :param sparql_query: SPARQL query that shall be executed
        :param index: name of the index
        :return: query result
        """
        index = index or self.index
        raw_results = self.indexes[index].query(sparql_query)

        if raw_results.askAnswer is not None:
            return raw_results.askAnswer
        else:
            formatted_results = []
            for b in raw_results.bindings:
                formatted_result = {}
                items = list(b.items())
                for item in items:
                    type_ = item[0].toPython()[1:]
                    uri = item[1].toPython()  # type: ignore [attr-defined]
                    formatted_result[type_] = {"type": "uri", "value": uri}
                formatted_results.append(formatted_result)
            return formatted_results
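The deleted in-memory store delegates parsing and querying to rdflib; the round trip it wraps, sketched standalone with inline Turtle instead of a .ttl file:

from rdflib import Graph

g = Graph()
g.parse(
    data="""
    @prefix hp: <https://deepset.ai/harry_potter/> .
    hp:Harry_potter hp:house hp:Gryffindor .
    """,
    format="turtle",
)

# rdflib yields one row per solution; named variables are attributes.
for row in g.query("SELECT ?s ?o WHERE { ?s ?p ?o . }"):
    print(row.s, row.o)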
@@ -40,7 +40,6 @@ from haystack.nodes.retriever import (
     FilterRetriever,
     MultihopEmbeddingRetriever,
     TfidfRetriever,
-    Text2SparqlRetriever,
     TableTextRetriever,
     MultiModalRetriever,
     WebRetriever,
@@ -8,5 +8,4 @@ from haystack.nodes.retriever.dense import (
 )
 from haystack.nodes.retriever.multimodal import MultiModalRetriever
 from haystack.nodes.retriever.sparse import BM25Retriever, FilterRetriever, TfidfRetriever
-from haystack.nodes.retriever.text2sparql import Text2SparqlRetriever
 from haystack.nodes.retriever.web import WebRetriever
@@ -11,42 +11,12 @@ from haystack.schema import Document, MultiLabel
 from haystack.errors import HaystackError, PipelineError
 from haystack.nodes.base import BaseComponent
 from haystack.telemetry import send_event
-from haystack.document_stores.base import BaseDocumentStore, BaseKnowledgeGraph, FilterType
+from haystack.document_stores.base import BaseDocumentStore, FilterType


 logger = logging.getLogger(__name__)


-class BaseGraphRetriever(BaseComponent):
-    """
-    Base class for knowledge graph retrievers.
-    """
-
-    knowledge_graph: BaseKnowledgeGraph
-    outgoing_edges = 1
-
-    @abstractmethod
-    def retrieve(self, query: str, top_k: Optional[int] = None):
-        pass
-
-    @abstractmethod
-    def retrieve_batch(self, queries: List[str], top_k: Optional[int] = None):
-        pass
-
-    def eval(self):
-        raise NotImplementedError
-
-    def run(self, query: str, top_k: Optional[int] = None):  # type: ignore
-        answers = self.retrieve(query=query, top_k=top_k)
-        results = {"answers": answers}
-        return results, "output_1"
-
-    def run_batch(self, queries: List[str], top_k: Optional[int] = None):  # type: ignore
-        answers = self.retrieve_batch(queries=queries, top_k=top_k)
-        results = {"answers": answers}
-        return results, "output_1"
-
-
 class BaseRetriever(BaseComponent):
     """
     Base class for regular retrievers.
@@ -1,150 +0,0 @@
from typing import Optional, List, Union

import warnings
import logging
from transformers import BartForConditionalGeneration, BartTokenizer

from haystack.document_stores import BaseKnowledgeGraph
from haystack.nodes.retriever.base import BaseGraphRetriever


logger = logging.getLogger(__name__)


class Text2SparqlRetriever(BaseGraphRetriever):
    """
    Graph retriever that uses a pre-trained Bart model to translate natural language questions
    given in text form to queries in SPARQL format.
    The generated SPARQL query is executed on a knowledge graph.
    """

    def __init__(
        self,
        knowledge_graph: BaseKnowledgeGraph,
        model_name_or_path: Optional[str] = None,
        model_version: Optional[str] = None,
        top_k: int = 1,
        use_auth_token: Optional[Union[str, bool]] = None,
    ):
        """
        The Text2SparqlRetriever component is deprecated and will be removed in future versions.

        Init the Retriever by providing a knowledge graph and a pre-trained BART model

        :param knowledge_graph: An instance of BaseKnowledgeGraph on which to execute SPARQL queries.
        :param model_name_or_path: Name of or path to a pre-trained BartForConditionalGeneration model.
        :param model_version: The version of the model to use for entity extraction.
        :param top_k: How many SPARQL queries to generate per text query.
        :param use_auth_token: The API token used to download private models from Huggingface.
                               If this parameter is set to `True`, then the token generated when running
                               `transformers-cli login` (stored in ~/.huggingface) will be used.
                               Additional information can be found here
                               https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained
        """
        warnings.warn(
            "The Text2SparqlRetriever component is deprecated and will be removed in future versions.",
            category=DeprecationWarning,
        )
        super().__init__()

        self.knowledge_graph = knowledge_graph
        # TODO We should extend this to any seq2seq models and use the AutoModel class
        self.model = BartForConditionalGeneration.from_pretrained(
            model_name_or_path, forced_bos_token_id=0, use_auth_token=use_auth_token, revision=model_version
        )
        self.tok = BartTokenizer.from_pretrained(model_name_or_path, use_auth_token=use_auth_token)
        self.top_k = top_k

    def retrieve(self, query: str, top_k: Optional[int] = None):
        """
        Translate a text query to SPARQL and execute it on the knowledge graph to retrieve a list of answers

        :param query: Text query that shall be translated to SPARQL and then executed on the knowledge graph
        :param top_k: How many SPARQL queries to generate per text query.
        """
        if top_k is None:
            top_k = self.top_k
        inputs = self.tok([query], max_length=100, truncation=True, return_tensors="pt")
        # generate top_k+2 SPARQL queries so that we can dismiss some queries with wrong syntax
        temp = self.model.generate(
            inputs["input_ids"], num_beams=5, max_length=100, num_return_sequences=top_k + 2, early_stopping=True
        )
        sparql_queries = [
            self.tok.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in temp
        ]
        answers = []
        for sparql_query in sparql_queries:
            ans, query = self._query_kg(sparql_query=sparql_query)
            if len(ans) > 0:
                answers.append((ans, query))

        # if there are no answers we still want to return something
        if len(answers) == 0:
            answers.append(("", ""))
        results = answers[:top_k]
        results = [self.format_result(result) for result in results]
        return results

    def retrieve_batch(self, queries: List[str], top_k: Optional[int] = None):
        """
        Translate a list of queries to SPARQL and execute them on the knowledge graph to retrieve
        a list of lists of answers (one per query).

        :param queries: List of queries that shall be translated to SPARQL and then executed on the
                        knowledge graph.
        :param top_k: How many SPARQL queries to generate per text query.
        """
        # TODO: This method currently just calls the retrieve method multiple times, so there is room for improvement.

        results = []
        for query in queries:
            cur_result = self.run(query=query, top_k=top_k)
            results.append(cur_result)

        return results

    def _query_kg(self, sparql_query):
        """
        Execute a single SPARQL query on the knowledge graph to retrieve an answer and unpack
        different answer styles for boolean queries, count queries, and list queries.

        :param sparql_query: SPARQL query that shall be executed on the knowledge graph
        """
        try:
            response = self.knowledge_graph.query(sparql_query=sparql_query)

            # unpack different answer styles
            if isinstance(response, list):
                if len(response) == 0:
                    result = ""
                else:
                    result = []
                    for x in response:
                        for v in x.values():
                            result.append(v["value"])
            elif isinstance(response, bool):
                result = str(response)
            elif "count" in response[0]:
                result = str(int(response[0]["count"]["value"]))
            else:
                result = ""

        except Exception:
            result = ""

        return result, sparql_query

    def format_result(self, result):
        """
        Generate formatted dictionary output with text answer and additional info

        :param result: The result of a SPARQL query as retrieved from the knowledge graph
        """
        query = result[1]
        prediction = result[0]
        prediction_meta = {"model": self.__class__.__name__, "sparql_query": query}

        return {"answer": prediction, "prediction_meta": prediction_meta}

    def eval(self):
        raise NotImplementedError
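`retrieve` over-generates `top_k + 2` candidate queries and drops bad ones only when execution fails. An alternative, shown purely as an illustration (not what the deleted code did), is to pre-filter candidates by syntax with rdflib:

from rdflib.plugins.sparql import prepareQuery


def keep_parseable(candidates: list) -> list:
    # Keep only candidate strings that parse as valid SPARQL.
    valid = []
    for query in candidates:
        try:
            prepareQuery(query)  # raises on syntax errors
            valid.append(query)
        except Exception:
            pass
    return valid


# Example: the second candidate is truncated and gets dropped.
print(keep_parseable([
    "SELECT ?a WHERE { ?s ?p ?a . }",
    "SELECT ?a WHERE { ?s ?p",
]))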
@@ -1,22 +0,0 @@
import pytest

from haystack.document_stores.graphdb import GraphDBKnowledgeGraph

from ..conftest import fail_at_version


@pytest.mark.unit
@fail_at_version(1, 17)
def test_graphdb_knowledge_graph_deprecation_warning():
    with pytest.warns(DeprecationWarning) as w:
        GraphDBKnowledgeGraph()

    assert len(w) == 2
    assert (
        w[0].message.args[0]
        == "The GraphDBKnowledgeGraph component is deprecated and will be removed in future versions."
    )
    assert (
        w[1].message.args[0]
        == "The BaseKnowledgeGraph component is deprecated and will be removed in future versions."
    )
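The `len(w) == 2` assertion holds because the subclass emits its own DeprecationWarning and then `super().__init__()` emits BaseKnowledgeGraph's. A standalone sketch of that chaining (illustrative names):

import warnings


class Base:
    def __init__(self):
        warnings.warn("Base is deprecated.", category=DeprecationWarning)


class Concrete(Base):
    def __init__(self):
        warnings.warn("Concrete is deprecated.", category=DeprecationWarning)
        super().__init__()  # the second warning comes from the base class


with warnings.catch_warnings(record=True) as w:
    warnings.simplefilter("always")
    Concrete()

assert [str(x.message) for x in w] == ["Concrete is deprecated.", "Base is deprecated."]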
@@ -1,22 +0,0 @@
import pytest

from haystack.document_stores.memory_knowledgegraph import InMemoryKnowledgeGraph

from ..conftest import fail_at_version


@pytest.mark.unit
@fail_at_version(1, 17)
def test_in_memory_knowledge_graph_deprecation_warning():
    with pytest.warns(DeprecationWarning) as w:
        InMemoryKnowledgeGraph()

    assert len(w) == 2
    assert (
        w[0].message.args[0]
        == "The InMemoryKnowledgeGraph component is deprecated and will be removed in future versions."
    )
    assert (
        w[1].message.args[0]
        == "The BaseKnowledgeGraph component is deprecated and will be removed in future versions."
    )
@@ -27,7 +27,6 @@ from haystack.document_stores import WeaviateDocumentStore
 from haystack.nodes.retriever.base import BaseRetriever
 from haystack.nodes.retriever.web import WebRetriever
 from haystack.nodes.search_engine import WebSearch
-from haystack.nodes.retriever import Text2SparqlRetriever
 from haystack.pipelines import DocumentSearchPipeline
 from haystack.schema import Document
 from haystack.document_stores.elasticsearch import ElasticsearchDocumentStore
@@ -1271,21 +1270,3 @@ def test_web_retriever_mode_snippets(monkeypatch):
     web_retriever = WebRetriever(api_key="", top_search_results=2)
     result = web_retriever.retrieve(query="Who is the father of Arya Stark?")
     assert result == expected_search_results["documents"]
-
-
-@fail_at_version(1, 17)
-def test_text_2_sparql_retriever_deprecation():
-    BartForConditionalGeneration = object()
-    BartTokenizer = object()
-    with patch.multiple(
-        "haystack.nodes.retriever.text2sparql", BartForConditionalGeneration=DEFAULT, BartTokenizer=DEFAULT
-    ):
-        knowledge_graph = Mock()
-        with pytest.warns(DeprecationWarning) as w:
-            Text2SparqlRetriever(knowledge_graph)
-
-    assert len(w) == 1
-    assert (
-        w[0].message.args[0]
-        == "The Text2SparqlRetriever component is deprecated and will be removed in future versions."
-    )
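The removed test relies on `patch.multiple` with `DEFAULT`, which swaps each named attribute for an auto-created MagicMock for the duration of the block. A runnable stdlib-only example of the same mechanism:

import os
from unittest.mock import DEFAULT, patch

# DEFAULT tells patch.multiple to replace each named attribute with an
# auto-created MagicMock, returned in a dict keyed by attribute name.
with patch.multiple("os.path", exists=DEFAULT, isdir=DEFAULT) as mocks:
    mocks["exists"].return_value = True
    assert os.path.exists("/no/such/file") is True
# Outside the block the real functions are restored.
assert os.path.exists("/no/such/file") is False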