diff --git a/e2e/document_stores/test_knowledge_graph.py b/e2e/document_stores/test_knowledge_graph.py deleted file mode 100644 index cfb028b19..000000000 --- a/e2e/document_stores/test_knowledge_graph.py +++ /dev/null @@ -1,107 +0,0 @@ -from pathlib import Path - -from haystack.nodes import Text2SparqlRetriever -from haystack.document_stores import GraphDBKnowledgeGraph, InMemoryKnowledgeGraph -from haystack.utils import fetch_archive_from_http - - -def test_graph_retrieval(): - # we use a timeout double the default in the CI to account for slow runners - timeout = 20 - - # TODO rename doc_dir - graph_dir = "../data/tutorial10_knowledge_graph/" - s3_url = "https://fandom-qa.s3-eu-west-1.amazonaws.com/triples_and_config.zip" - fetch_archive_from_http(url=s3_url, output_dir=graph_dir) - - # Fetch a pre-trained BART model that translates natural language questions to SPARQL queries - model_dir = "../saved_models/tutorial10_knowledge_graph/" - s3_url = "https://fandom-qa.s3-eu-west-1.amazonaws.com/saved_models/hp_v3.4.zip" - fetch_archive_from_http(url=s3_url, output_dir=model_dir) - - kg = GraphDBKnowledgeGraph(index="tutorial_10_index") - kg.delete_index(timeout=timeout) - kg.create_index(config_path=Path(graph_dir + "repo-config.ttl"), timeout=timeout) - kg.import_from_ttl_file(index="tutorial_10_index", path=Path(graph_dir + "triples.ttl"), timeout=timeout) - triple = { - "p": {"type": "uri", "value": "https://deepset.ai/harry_potter/_paternalgrandfather"}, - "s": {"type": "uri", "value": "https://deepset.ai/harry_potter/Melody_fawley"}, - "o": {"type": "uri", "value": "https://deepset.ai/harry_potter/Marshall_fawley"}, - } - triples = kg.get_all_triples() - assert len(triples) > 0 - assert triple in triples - - # Define prefixes for names of resources so that we can use shorter resource names in queries - prefixes = """PREFIX rdf: - PREFIX xsd: - PREFIX hp: - """ - kg.prefixes = prefixes - - kgqa_retriever = Text2SparqlRetriever(knowledge_graph=kg, model_name_or_path=model_dir + "hp_v3.4") - - result = kgqa_retriever.retrieve(query="In which house is Harry Potter?") - assert result[0] == { - "answer": ["https://deepset.ai/harry_potter/Gryffindor"], - "prediction_meta": { - "model": "Text2SparqlRetriever", - "sparql_query": "select ?a { hp:Harry_potter hp:house ?a . }", - }, - } - - result = kgqa_retriever._query_kg( - sparql_query="select distinct ?sbj where { ?sbj hp:job hp:Keeper_of_keys_and_grounds . }" - ) - assert result[0][0] == "https://deepset.ai/harry_potter/Rubeus_hagrid" - - result = kgqa_retriever._query_kg( - sparql_query="select distinct ?obj where { ?obj . }" - ) - assert result[0][0] == "https://deepset.ai/harry_potter/Otter" - - -def test_inmemory_graph_retrieval(): - # TODO rename doc_dir - graph_dir = "../data/tutorial10_knowledge_graph/" - s3_url = "https://fandom-qa.s3-eu-west-1.amazonaws.com/triples_and_config.zip" - fetch_archive_from_http(url=s3_url, output_dir=graph_dir) - - # Fetch a pre-trained BART model that translates natural language questions to SPARQL queries - model_dir = "../saved_models/tutorial10_knowledge_graph/" - s3_url = "https://fandom-qa.s3-eu-west-1.amazonaws.com/saved_models/hp_v3.4.zip" - fetch_archive_from_http(url=s3_url, output_dir=model_dir) - - kg = InMemoryKnowledgeGraph(index="tutorial_10_index") - kg.delete_index() - kg.create_index() - kg.import_from_ttl_file(index="tutorial_10_index", path=Path(graph_dir + "triples.ttl")) - triple = { - "p": {"type": "uri", "value": "https://deepset.ai/harry_potter/_paternalgrandfather"}, - "s": {"type": "uri", "value": "https://deepset.ai/harry_potter/Melody_fawley"}, - "o": {"type": "uri", "value": "https://deepset.ai/harry_potter/Marshall_fawley"}, - } - triples = kg.get_all_triples() - assert len(triples) > 0 - assert triple in triples - - kgqa_retriever = Text2SparqlRetriever(knowledge_graph=kg, model_name_or_path=model_dir + "hp_v3.4") - - result = kgqa_retriever.retrieve(query="In which house is Harry Potter?") - assert result[0] == { - "answer": ["https://deepset.ai/harry_potter/Gryffindor"], - "prediction_meta": { - "model": "Text2SparqlRetriever", - "sparql_query": "select ?a { hp:Harry_potter hp:house ?a . }", - }, - } - - result = kgqa_retriever._query_kg( - sparql_query="select distinct ?sbj where { ?sbj hp:job hp:Keeper_of_keys_and_grounds . }" - ) - assert result[0][0] == "https://deepset.ai/harry_potter/Rubeus_hagrid" - - result = kgqa_retriever._query_kg( - sparql_query="select distinct ?obj where { ?obj . }" - ) - assert result[0][0] == "https://deepset.ai/harry_potter/Otter" diff --git a/haystack/document_stores/__init__.py b/haystack/document_stores/__init__.py index 11ea47b51..24825989b 100644 --- a/haystack/document_stores/__init__.py +++ b/haystack/document_stores/__init__.py @@ -1,5 +1,5 @@ from haystack.utils.import_utils import safe_import -from haystack.document_stores.base import BaseDocumentStore, BaseKnowledgeGraph, KeywordDocumentStore +from haystack.document_stores.base import BaseDocumentStore, KeywordDocumentStore from haystack.document_stores.memory import InMemoryDocumentStore from haystack.document_stores.deepsetcloud import DeepsetCloudDocumentStore @@ -19,7 +19,3 @@ SQLDocumentStore = safe_import("haystack.document_stores.sql", "SQLDocumentStore FAISSDocumentStore = safe_import("haystack.document_stores.faiss", "FAISSDocumentStore", "faiss") PineconeDocumentStore = safe_import("haystack.document_stores.pinecone", "PineconeDocumentStore", "pinecone") WeaviateDocumentStore = safe_import("haystack.document_stores.weaviate", "WeaviateDocumentStore", "weaviate") -GraphDBKnowledgeGraph = safe_import("haystack.document_stores.graphdb", "GraphDBKnowledgeGraph", "graphdb") -InMemoryKnowledgeGraph = safe_import( - "haystack.document_stores.memory_knowledgegraph", "InMemoryKnowledgeGraph", "inmemorygraph" -) diff --git a/haystack/document_stores/base.py b/haystack/document_stores/base.py index 8ed029894..487ba4af4 100644 --- a/haystack/document_stores/base.py +++ b/haystack/document_stores/base.py @@ -2,7 +2,6 @@ from typing import Generator, Optional, Dict, List, Set, Union, Any -import warnings import logging import collections from pathlib import Path @@ -32,35 +31,6 @@ except (ImportError, ModuleNotFoundError): return f -class BaseKnowledgeGraph(BaseComponent): - """ - Base class for implementing Knowledge Graphs. - - The BaseKnowledgeGraph component is deprecated and will be removed in future versions. - """ - - def __init__(self): - warnings.warn( - "The BaseKnowledgeGraph component is deprecated and will be removed in future versions.", - category=DeprecationWarning, - ) - super().__init__() - - outgoing_edges = 1 - - def run(self, sparql_query: str, index: Optional[str] = None, headers: Optional[Dict[str, str]] = None): # type: ignore - result = self.query(sparql_query=sparql_query, index=index, headers=headers) - output = {"sparql_result": result} - return output, "output_1" - - def run_batch(self): - raise NotImplementedError("run_batch is not implemented for KnowledgeGraphs.") - - @abstractmethod - def query(self, sparql_query: str, index: Optional[str] = None, headers: Optional[Dict[str, str]] = None): - raise NotImplementedError - - class BaseDocumentStore(BaseComponent): """ Base class for implementing Document Stores. diff --git a/haystack/document_stores/graphdb.py b/haystack/document_stores/graphdb.py deleted file mode 100644 index 10996b676..000000000 --- a/haystack/document_stores/graphdb.py +++ /dev/null @@ -1,206 +0,0 @@ -from typing import Dict, Optional, Union, Tuple - -import warnings -from pathlib import Path - -import requests -from requests.auth import HTTPBasicAuth - -try: - from SPARQLWrapper import SPARQLWrapper, JSON -except (ImportError, ModuleNotFoundError) as ie: - from haystack.utils.import_utils import _optional_component_not_installed - - _optional_component_not_installed(__name__, "graphdb", ie) - -from haystack.document_stores import BaseKnowledgeGraph - - -class GraphDBKnowledgeGraph(BaseKnowledgeGraph): - """ - Knowledge graph store that runs on a GraphDB instance. - """ - - def __init__( - self, - host: str = "localhost", - port: int = 7200, - username: str = "", - password: str = "", - index: Optional[str] = None, - prefixes: str = "", - ): - """ - The GraphDBKnowledgeGraph component is deprecated and will be removed in future versions. - - Init the knowledge graph by defining the settings to connect with a GraphDB instance - - :param host: address of server where the GraphDB instance is running - :param port: port where the GraphDB instance is running - :param username: username to login to the GraphDB instance (if any) - :param password: password to login to the GraphDB instance (if any) - :param index: name of the index (also called repository) stored in the GraphDB instance - :param prefixes: definitions of namespaces with a new line after each namespace, e.g., PREFIX hp: - """ - warnings.warn( - "The GraphDBKnowledgeGraph component is deprecated and will be removed in future versions.", - category=DeprecationWarning, - ) - super().__init__() - - self.url = f"http://{host}:{port}" - self.index = index - self.username = username - self.password = password - self.prefixes = prefixes - - def create_index( - self, - config_path: Path, - headers: Optional[Dict[str, str]] = None, - timeout: Union[float, Tuple[float, float]] = 10.0, - ): - """ - Create a new index (also called repository) stored in the GraphDB instance - - :param config_path: path to a .ttl file with configuration settings, details: - :param headers: Custom HTTP headers to pass to http client (e.g. {'Authorization': 'Basic YWRtaW46cm9vdA=='}) - :param timeout: How many seconds to wait for the server to send data before giving up, - as a float, or a :ref:`(connect timeout, read timeout) ` tuple. - Defaults to 10 seconds. - https://graphdb.ontotext.com/documentation/free/configuring-a-repository.html#configure-a-repository-programmatically - """ - url = f"{self.url}/rest/repositories" - files = {"config": open(config_path, "r", encoding="utf-8")} - response = requests.post(url, files=files, headers=headers, timeout=timeout) - if response.status_code > 299: - raise Exception(response.text) - - def delete_index(self, headers: Optional[Dict[str, str]] = None, timeout: Union[float, Tuple[float, float]] = 10.0): - """ - Delete the index that GraphDBKnowledgeGraph is connected to. This method deletes all data stored in the index. - :param headers: Custom HTTP headers to pass to http client (e.g. {'Authorization': 'Basic YWRtaW46cm9vdA=='}) - :param timeout: How many seconds to wait for the server to send data before giving up, - as a float, or a :ref:`(connect timeout, read timeout) ` tuple. - Defaults to 10 seconds. - """ - url = f"{self.url}/rest/repositories/{self.index}" - response = requests.delete(url, headers=headers, timeout=timeout) - if response.status_code > 299: - raise Exception(response.text) - - def import_from_ttl_file( - self, - index: str, - path: Path, - headers: Optional[Dict[str, str]] = None, - timeout: Union[float, Tuple[float, float]] = 10.0, - ): - """ - Load an existing knowledge graph represented in the form of triples of subject, predicate, and object from a .ttl file into an index of GraphDB - - :param index: name of the index (also called repository) in the GraphDB instance where the imported triples shall be stored - :param path: path to a .ttl containing a knowledge graph - :param headers: Custom HTTP headers to pass to http client (e.g. {'Authorization': 'Basic YWRtaW46cm9vdA=='}) - :param timeout: How many seconds to wait for the server to send data before giving up, - as a float, or a :ref:`(connect timeout, read timeout) ` tuple. - Defaults to 10 seconds. - """ - url = f"{self.url}/repositories/{index}/statements" - headers = ( - {"Content-type": "application/x-turtle"} - if headers is None - else {**{"Content-type": "application/x-turtle"}, **headers} - ) - response = requests.post( - url, - headers=headers, - data=open(path, "r", encoding="utf-8").read().encode("utf-8"), - auth=HTTPBasicAuth(self.username, self.password), - timeout=timeout, - ) - if response.status_code > 299: - raise Exception(response.text) - - def get_all_triples(self, index: Optional[str] = None, headers: Optional[Dict[str, str]] = None): - """ - Query the given index in the GraphDB instance for all its stored triples. Duplicates are not filtered. - - :param index: name of the index (also called repository) in the GraphDB instance - :param headers: Custom HTTP headers to pass to http client (e.g. {'Authorization': 'Basic YWRtaW46cm9vdA=='}) - :return: all triples stored in the index - """ - sparql_query = "SELECT * WHERE { ?s ?p ?o. }" - results = self.query(sparql_query=sparql_query, index=index, headers=headers) - return results - - def get_all_subjects(self, index: Optional[str] = None, headers: Optional[Dict[str, str]] = None): - """ - Query the given index in the GraphDB instance for all its stored subjects. Duplicates are not filtered. - - :param index: name of the index (also called repository) in the GraphDB instance - :param headers: Custom HTTP headers to pass to http client (e.g. {'Authorization': 'Basic YWRtaW46cm9vdA=='}) - :return: all subjects stored in the index - """ - sparql_query = "SELECT ?s WHERE { ?s ?p ?o. }" - results = self.query(sparql_query=sparql_query, index=index, headers=headers) - return results - - def get_all_predicates(self, index: Optional[str] = None, headers: Optional[Dict[str, str]] = None): - """ - Query the given index in the GraphDB instance for all its stored predicates. Duplicates are not filtered. - - :param index: name of the index (also called repository) in the GraphDB instance - :param headers: Custom HTTP headers to pass to http client (e.g. {'Authorization': 'Basic YWRtaW46cm9vdA=='}) - :return: all predicates stored in the index - """ - sparql_query = "SELECT ?p WHERE { ?s ?p ?o. }" - results = self.query(sparql_query=sparql_query, index=index, headers=headers) - return results - - def _create_document_field_map(self) -> Dict: - """ - There is no field mapping required - """ - return {} - - def get_all_objects(self, index: Optional[str] = None, headers: Optional[Dict[str, str]] = None): - """ - Query the given index in the GraphDB instance for all its stored objects. Duplicates are not filtered. - - :param index: name of the index (also called repository) in the GraphDB instance - :param headers: Custom HTTP headers to pass to http client (e.g. {'Authorization': 'Basic YWRtaW46cm9vdA=='}) - :return: all objects stored in the index - """ - sparql_query = "SELECT ?o WHERE { ?s ?p ?o. }" - results = self.query(sparql_query=sparql_query, index=index, headers=headers) - return results - - def query(self, sparql_query: str, index: Optional[str] = None, headers: Optional[Dict[str, str]] = None): - """ - Execute a SPARQL query on the given index in the GraphDB instance - - :param sparql_query: SPARQL query that shall be executed - :param index: name of the index (also called repository) in the GraphDB instance - :param headers: Custom HTTP headers to pass to http client (e.g. {'Authorization': 'Basic YWRtaW46cm9vdA=='}) - :return: query result - """ - if self.index is None and index is None: - raise Exception("Index name is required") - index = index or self.index - sparql = SPARQLWrapper(f"{self.url}/repositories/{index}") - sparql.setCredentials(self.username, self.password) - sparql.setQuery(self.prefixes + sparql_query) - sparql.setReturnFormat(JSON) - if headers is not None: - sparql.customHttpHeaders = headers - results = sparql.query().convert() - # if query is a boolean query, return boolean instead of text result - # FIXME: 'results' likely doesn't support membership test (`"something" in results`). - # Pylint raises unsupported-membership-test and unsubscriptable-object. - # Silenced for now, keep in mind for future debugging. - return ( - results["results"]["bindings"] # type: ignore # pylint: disable=unsubscriptable-object - if "results" in results # type: ignore # pylint: disable=unsupported-membership-test - else results["boolean"] # type: ignore # pylint: disable=unsubscriptable-object - ) diff --git a/haystack/document_stores/memory_knowledgegraph.py b/haystack/document_stores/memory_knowledgegraph.py deleted file mode 100644 index c1aad7e36..000000000 --- a/haystack/document_stores/memory_knowledgegraph.py +++ /dev/null @@ -1,144 +0,0 @@ -from typing import Dict, Optional - -import warnings -import logging -from collections import defaultdict -from pathlib import Path - -from rdflib import Graph - -from haystack.document_stores import BaseKnowledgeGraph - -logger = logging.getLogger(__name__) - - -class InMemoryKnowledgeGraph(BaseKnowledgeGraph): - """ - In memory Knowledge graph store, based on rdflib. - """ - - def __init__(self, index: str = "document"): - """ - The InMemoryKnowledgeGraph component is deprecated and will be removed in future versions. - - Init the in memory knowledge graph - - :param index: name of the index - """ - warnings.warn( - "The InMemoryKnowledgeGraph component is deprecated and will be removed in future versions.", - category=DeprecationWarning, - ) - super().__init__() - - self.indexes: Dict[str, Graph] = defaultdict(dict) # type: ignore [arg-type] - self.index: str = index - - def create_index(self, index: Optional[str] = None): - """ - Create a new index stored in memory - - :param index: name of the index - """ - index = index or self.index - if index not in self.indexes: - self.indexes[index] = Graph() - else: - logger.warning("Index '%s' is already present.", index) - - def delete_index(self, index: Optional[str] = None): - """ - Delete an existing index. The index including all data will be removed. - - :param index: The name of the index to delete. - """ - index = index or self.index - - if index in self.indexes: - del self.indexes[index] - logger.info("Index '%s' deleted.", index) - - def import_from_ttl_file(self, path: Path, index: Optional[str] = None): - """ - Load in memory an existing knowledge graph represented in the form of triples of subject, predicate, and object from a .ttl file - - :param path: path to a .ttl containing a knowledge graph - :param index: name of the index - """ - index = index or self.index - self.indexes[index].parse(path) - - def get_all_triples(self, index: Optional[str] = None): - """ - Query the given in memory index for all its stored triples. Duplicates are not filtered. - - :param index: name of the index - :return: all triples stored in the index - """ - sparql_query = "SELECT * WHERE { ?s ?p ?o. }" - results = self.query(sparql_query=sparql_query, index=index) - return results - - def get_all_subjects(self, index: Optional[str] = None): - """ - Query the given in memory index for all its stored subjects. Duplicates are not filtered. - - :param index: name of the index - :return: all subjects stored in the index - """ - sparql_query = "SELECT ?s WHERE { ?s ?p ?o. }" - results = self.query(sparql_query=sparql_query, index=index) - return results - - def get_all_predicates(self, index: Optional[str] = None): - """ - Query the given in memory index for all its stored predicates. Duplicates are not filtered. - - :param index: name of the index - :return: all predicates stored in the index - """ - sparql_query = "SELECT ?p WHERE { ?s ?p ?o. }" - results = self.query(sparql_query=sparql_query, index=index) - return results - - def _create_document_field_map(self) -> Dict: - """ - There is no field mapping required - """ - return {} - - def get_all_objects(self, index: Optional[str] = None): - """ - Query the given in memory index for all its stored objects. Duplicates are not filtered. - - :param index: name of the index - :return: all objects stored in the index - """ - sparql_query = "SELECT ?o WHERE { ?s ?p ?o. }" - results = self.query(sparql_query=sparql_query, index=index) - return results - - def query(self, sparql_query: str, index: Optional[str] = None, headers: Optional[Dict[str, str]] = None): - """ - Execute a SPARQL query on the given in memory index - - :param sparql_query: SPARQL query that shall be executed - :param index: name of the index - :return: query result - """ - index = index or self.index - raw_results = self.indexes[index].query(sparql_query) - - if raw_results.askAnswer is not None: - return raw_results.askAnswer - else: - formatted_results = [] - for b in raw_results.bindings: - formatted_result = {} - items = list(b.items()) - for item in items: - type_ = item[0].toPython()[1:] - uri = item[1].toPython() # type: ignore [attr-defined] - formatted_result[type_] = {"type": "uri", "value": uri} - formatted_results.append(formatted_result) - return formatted_results diff --git a/haystack/nodes/__init__.py b/haystack/nodes/__init__.py index 830327732..e0fa2f25b 100644 --- a/haystack/nodes/__init__.py +++ b/haystack/nodes/__init__.py @@ -40,7 +40,6 @@ from haystack.nodes.retriever import ( FilterRetriever, MultihopEmbeddingRetriever, TfidfRetriever, - Text2SparqlRetriever, TableTextRetriever, MultiModalRetriever, WebRetriever, diff --git a/haystack/nodes/retriever/__init__.py b/haystack/nodes/retriever/__init__.py index 1df133e99..25154a177 100644 --- a/haystack/nodes/retriever/__init__.py +++ b/haystack/nodes/retriever/__init__.py @@ -8,5 +8,4 @@ from haystack.nodes.retriever.dense import ( ) from haystack.nodes.retriever.multimodal import MultiModalRetriever from haystack.nodes.retriever.sparse import BM25Retriever, FilterRetriever, TfidfRetriever -from haystack.nodes.retriever.text2sparql import Text2SparqlRetriever from haystack.nodes.retriever.web import WebRetriever diff --git a/haystack/nodes/retriever/base.py b/haystack/nodes/retriever/base.py index 0157fabdb..c2b80e667 100644 --- a/haystack/nodes/retriever/base.py +++ b/haystack/nodes/retriever/base.py @@ -11,42 +11,12 @@ from haystack.schema import Document, MultiLabel from haystack.errors import HaystackError, PipelineError from haystack.nodes.base import BaseComponent from haystack.telemetry import send_event -from haystack.document_stores.base import BaseDocumentStore, BaseKnowledgeGraph, FilterType +from haystack.document_stores.base import BaseDocumentStore, FilterType logger = logging.getLogger(__name__) -class BaseGraphRetriever(BaseComponent): - """ - Base classfor knowledge graph retrievers. - """ - - knowledge_graph: BaseKnowledgeGraph - outgoing_edges = 1 - - @abstractmethod - def retrieve(self, query: str, top_k: Optional[int] = None): - pass - - @abstractmethod - def retrieve_batch(self, queries: List[str], top_k: Optional[int] = None): - pass - - def eval(self): - raise NotImplementedError - - def run(self, query: str, top_k: Optional[int] = None): # type: ignore - answers = self.retrieve(query=query, top_k=top_k) - results = {"answers": answers} - return results, "output_1" - - def run_batch(self, queries: List[str], top_k: Optional[int] = None): # type: ignore - answers = self.retrieve_batch(queries=queries, top_k=top_k) - results = {"answers": answers} - return results, "output_1" - - class BaseRetriever(BaseComponent): """ Base class for regular retrievers. diff --git a/haystack/nodes/retriever/text2sparql.py b/haystack/nodes/retriever/text2sparql.py deleted file mode 100644 index 3148a07e2..000000000 --- a/haystack/nodes/retriever/text2sparql.py +++ /dev/null @@ -1,150 +0,0 @@ -from typing import Optional, List, Union - -import warnings -import logging -from transformers import BartForConditionalGeneration, BartTokenizer - -from haystack.document_stores import BaseKnowledgeGraph -from haystack.nodes.retriever.base import BaseGraphRetriever - - -logger = logging.getLogger(__name__) - - -class Text2SparqlRetriever(BaseGraphRetriever): - """ - Graph retriever that uses a pre-trained Bart model to translate natural language questions - given in text form to queries in SPARQL format. - The generated SPARQL query is executed on a knowledge graph. - """ - - def __init__( - self, - knowledge_graph: BaseKnowledgeGraph, - model_name_or_path: Optional[str] = None, - model_version: Optional[str] = None, - top_k: int = 1, - use_auth_token: Optional[Union[str, bool]] = None, - ): - """ - The Text2SparqlRetriever component is deprecated and will be removed in future versions. - - Init the Retriever by providing a knowledge graph and a pre-trained BART model - - :param knowledge_graph: An instance of BaseKnowledgeGraph on which to execute SPARQL queries. - :param model_name_or_path: Name of or path to a pre-trained BartForConditionalGeneration model. - :param model_version: The version of the model to use for entity extraction. - :param top_k: How many SPARQL queries to generate per text query. - :param use_auth_token: The API token used to download private models from Huggingface. - If this parameter is set to `True`, then the token generated when running - `transformers-cli login` (stored in ~/.huggingface) will be used. - Additional information can be found here - https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained - """ - warnings.warn( - "The Text2SparqlRetriever component is deprecated and will be removed in future versions.", - category=DeprecationWarning, - ) - super().__init__() - - self.knowledge_graph = knowledge_graph - # TODO We should extend this to any seq2seq models and use the AutoModel class - self.model = BartForConditionalGeneration.from_pretrained( - model_name_or_path, forced_bos_token_id=0, use_auth_token=use_auth_token, revision=model_version - ) - self.tok = BartTokenizer.from_pretrained(model_name_or_path, use_auth_token=use_auth_token) - self.top_k = top_k - - def retrieve(self, query: str, top_k: Optional[int] = None): - """ - Translate a text query to SPARQL and execute it on the knowledge graph to retrieve a list of answers - - :param query: Text query that shall be translated to SPARQL and then executed on the knowledge graph - :param top_k: How many SPARQL queries to generate per text query. - """ - if top_k is None: - top_k = self.top_k - inputs = self.tok([query], max_length=100, truncation=True, return_tensors="pt") - # generate top_k+2 SPARQL queries so that we can dismiss some queries with wrong syntax - temp = self.model.generate( - inputs["input_ids"], num_beams=5, max_length=100, num_return_sequences=top_k + 2, early_stopping=True - ) - sparql_queries = [ - self.tok.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in temp - ] - answers = [] - for sparql_query in sparql_queries: - ans, query = self._query_kg(sparql_query=sparql_query) - if len(ans) > 0: - answers.append((ans, query)) - - # if there are no answers we still want to return something - if len(answers) == 0: - answers.append(("", "")) - results = answers[:top_k] - results = [self.format_result(result) for result in results] - return results - - def retrieve_batch(self, queries: List[str], top_k: Optional[int] = None): - """ - Translate a list of queries to SPARQL and execute it on the knowledge graph to retrieve - a list of lists of answers (one per query). - - :param queries: List of queries that shall be translated to SPARQL and then executed on the - knowledge graph. - :param top_k: How many SPARQL queries to generate per text query. - """ - # TODO: This method currently just calls the retrieve method multiple times, so there is room for improvement. - - results = [] - for query in queries: - cur_result = self.run(query=query, top_k=top_k) - results.append(cur_result) - - return results - - def _query_kg(self, sparql_query): - """ - Execute a single SPARQL query on the knowledge graph to retrieve an answer and unpack - different answer styles for boolean queries, count queries, and list queries. - - :param sparql_query: SPARQL query that shall be executed on the knowledge graph - """ - try: - response = self.knowledge_graph.query(sparql_query=sparql_query) - - # unpack different answer styles - if isinstance(response, list): - if len(response) == 0: - result = "" - else: - result = [] - for x in response: - for v in x.values(): - result.append(v["value"]) - elif isinstance(response, bool): - result = str(response) - elif "count" in response[0]: - result = str(int(response[0]["count"]["value"])) - else: - result = "" - - except Exception: - result = "" - - return result, sparql_query - - def format_result(self, result): - """ - Generate formatted dictionary output with text answer and additional info - - :param result: The result of a SPARQL query as retrieved from the knowledge graph - """ - query = result[1] - prediction = result[0] - prediction_meta = {"model": self.__class__.__name__, "sparql_query": query} - - return {"answer": prediction, "prediction_meta": prediction_meta} - - def eval(self): - raise NotImplementedError diff --git a/test/document_stores/test_graphdb.py b/test/document_stores/test_graphdb.py deleted file mode 100644 index e3ae98486..000000000 --- a/test/document_stores/test_graphdb.py +++ /dev/null @@ -1,22 +0,0 @@ -import pytest - -from haystack.document_stores.graphdb import GraphDBKnowledgeGraph - -from ..conftest import fail_at_version - - -@pytest.mark.unit -@fail_at_version(1, 17) -def test_graphdb_knowledge_graph_deprecation_warning(): - with pytest.warns(DeprecationWarning) as w: - GraphDBKnowledgeGraph() - - assert len(w) == 2 - assert ( - w[0].message.args[0] - == "The GraphDBKnowledgeGraph component is deprecated and will be removed in future versions." - ) - assert ( - w[1].message.args[0] - == "The BaseKnowledgeGraph component is deprecated and will be removed in future versions." - ) diff --git a/test/document_stores/test_memory_knowledgegraph.py b/test/document_stores/test_memory_knowledgegraph.py deleted file mode 100644 index f871439e6..000000000 --- a/test/document_stores/test_memory_knowledgegraph.py +++ /dev/null @@ -1,22 +0,0 @@ -import pytest - -from haystack.document_stores.memory_knowledgegraph import InMemoryKnowledgeGraph - -from ..conftest import fail_at_version - - -@pytest.mark.unit -@fail_at_version(1, 17) -def test_in_memory_knowledge_graph_deprecation_warning(): - with pytest.warns(DeprecationWarning) as w: - InMemoryKnowledgeGraph() - - assert len(w) == 2 - assert ( - w[0].message.args[0] - == "The InMemoryKnowledgeGraph component is deprecated and will be removed in future versions." - ) - assert ( - w[1].message.args[0] - == "The BaseKnowledgeGraph component is deprecated and will be removed in future versions." - ) diff --git a/test/nodes/test_retriever.py b/test/nodes/test_retriever.py index 16ea3eb05..8c5d9899d 100644 --- a/test/nodes/test_retriever.py +++ b/test/nodes/test_retriever.py @@ -27,7 +27,6 @@ from haystack.document_stores import WeaviateDocumentStore from haystack.nodes.retriever.base import BaseRetriever from haystack.nodes.retriever.web import WebRetriever from haystack.nodes.search_engine import WebSearch -from haystack.nodes.retriever import Text2SparqlRetriever from haystack.pipelines import DocumentSearchPipeline from haystack.schema import Document from haystack.document_stores.elasticsearch import ElasticsearchDocumentStore @@ -1271,21 +1270,3 @@ def test_web_retriever_mode_snippets(monkeypatch): web_retriever = WebRetriever(api_key="", top_search_results=2) result = web_retriever.retrieve(query="Who is the father of Arya Stark?") assert result == expected_search_results["documents"] - - -@fail_at_version(1, 17) -def test_text_2_sparql_retriever_deprecation(): - BartForConditionalGeneration = object() - BartTokenizer = object() - with patch.multiple( - "haystack.nodes.retriever.text2sparql", BartForConditionalGeneration=DEFAULT, BartTokenizer=DEFAULT - ): - knowledge_graph = Mock() - with pytest.warns(DeprecationWarning) as w: - Text2SparqlRetriever(knowledge_graph) - - assert len(w) == 1 - assert ( - w[0].message.args[0] - == "The Text2SparqlRetriever component is deprecated and will be removed in future versions." - )