diff --git a/docs/_src/api/api/reader.md b/docs/_src/api/api/reader.md index 1aa0c4883..a6d816255 100644 --- a/docs/_src/api/api/reader.md +++ b/docs/_src/api/api/reader.md @@ -609,3 +609,75 @@ WARNING: The answer scores are not reliable, as they are always extremely high, Dict containing query and answers + +## RCIReader + +```python +class RCIReader(BaseReader) +``` + +Table Reader model based on Glass et al. (2021)'s Row-Column-Intersection model. +See the original paper for more details: +Glass, Michael, et al. (2021): "Capturing Row and Column Semantics in Transformer Based Question Answering over Tables" +(https://aclanthology.org/2021.naacl-main.96/) + +Each row and each column is given a score with regard to the query by two separate models. The score of each cell +is then calculated as the sum of the corresponding row score and column score. Accordingly, the predicted answer is +the cell with the highest score. + + +#### \_\_init\_\_ + +```python + | __init__(row_model_name_or_path: str = "michaelrglass/albert-base-rci-wikisql-row", column_model_name_or_path: str = "michaelrglass/albert-base-rci-wikisql-col", row_model_version: Optional[str] = None, column_model_version: Optional[str] = None, row_tokenizer: Optional[str] = None, column_tokenizer: Optional[str] = None, use_gpu: bool = True, top_k: int = 10, max_seq_len: int = 256) +``` + +Load an RCI model from Transformers. +Available models include: + +- ``'michaelrglass/albert-base-rci-wikisql-row'`` + ``'michaelrglass/albert-base-rci-wikisql-col'`` +- ``'michaelrglass/albert-base-rci-wtq-row'`` + ``'michaelrglass/albert-base-rci-wtq-col'`` + + + +**Arguments**: + +- `row_model_name_or_path`: Directory of a saved row scoring model or the name of a public model +- `column_model_name_or_path`: Directory of a saved column scoring model or the name of a public model +- `row_model_version`: The version of row model to use from the HuggingFace model hub. + Can be tag name, branch name, or commit hash. 
+- `column_model_version`: The version of column model to use from the HuggingFace model hub. + Can be tag name, branch name, or commit hash. +- `row_tokenizer`: Name of the tokenizer for the row model (usually the same as model) +- `column_tokenizer`: Name of the tokenizer for the column model (usually the same as model) +- `use_gpu`: Whether to use GPU or CPU. Falls back on CPU if no GPU is available. +- `top_k`: The maximum number of answers to return +- `max_seq_len`: Max sequence length of one input table for the model. If the number of tokens of + query + table exceed max_seq_len, the table will be truncated by removing rows until the + input size fits the model. + + +#### predict + +```python + | predict(query: str, documents: List[Document], top_k: Optional[int] = None) -> Dict +``` + +Use loaded RCI models to find answers for a query in the supplied list of Documents +of content_type ``'table'``. + +Returns dictionary containing query and list of Answer objects sorted by (desc.) score. +The existing RCI models on the HF model hub don't allow aggregation, therefore, the answer will always be +composed of a single cell. + +**Arguments**: + +- `query`: Query string +- `documents`: List of Document in which to search for the answer. Documents should be + of content_type ``'table'``. 
+- `top_k`: The maximum number of answers to return + +**Returns**: + +Dict containing query and answers + diff --git a/haystack/nodes/__init__.py b/haystack/nodes/__init__.py index 807728fe9..f9cd1008b 100644 --- a/haystack/nodes/__init__.py +++ b/haystack/nodes/__init__.py @@ -23,7 +23,7 @@ from haystack.nodes.preprocessor import BasePreProcessor, PreProcessor from haystack.nodes.query_classifier import SklearnQueryClassifier, TransformersQueryClassifier from haystack.nodes.question_generator import QuestionGenerator from haystack.nodes.ranker import BaseRanker, SentenceTransformersRanker -from haystack.nodes.reader import BaseReader, FARMReader, TransformersReader, TableReader +from haystack.nodes.reader import BaseReader, FARMReader, TransformersReader, TableReader, RCIReader from haystack.nodes.retriever import ( BaseRetriever, DensePassageRetriever, diff --git a/haystack/nodes/reader/__init__.py b/haystack/nodes/reader/__init__.py index e06975298..0d80527e0 100644 --- a/haystack/nodes/reader/__init__.py +++ b/haystack/nodes/reader/__init__.py @@ -1,4 +1,4 @@ from haystack.nodes.reader.base import BaseReader from haystack.nodes.reader.farm import FARMReader from haystack.nodes.reader.transformers import TransformersReader -from haystack.nodes.reader.table import TableReader +from haystack.nodes.reader.table import TableReader, RCIReader diff --git a/haystack/nodes/reader/table.py b/haystack/nodes/reader/table.py index b4453383a..711e256e1 100644 --- a/haystack/nodes/reader/table.py +++ b/haystack/nodes/reader/table.py @@ -6,7 +6,8 @@ import torch import numpy as np import pandas as pd from quantulum3 import parser -from transformers import TapasTokenizer, TapasForQuestionAnswering, BatchEncoding +from transformers import TapasTokenizer, TapasForQuestionAnswering, AutoTokenizer, AutoModelForSequenceClassification, \ + BatchEncoding, AutoConfig from haystack.schema import Document, Answer, Span from haystack.nodes.reader.base import BaseReader @@ -252,3 +253,209 
class RCIReader(BaseReader):
    """
    Table Reader model based on Glass et al. (2021)'s Row-Column-Intersection model.
    See the original paper for more details:
    Glass, Michael, et al. (2021): "Capturing Row and Column Semantics in Transformer Based Question Answering over Tables"
    (https://aclanthology.org/2021.naacl-main.96/)

    Each row and each column is given a score with regard to the query by two separate models. The score of each cell
    is then calculated as the sum of the corresponding row score and column score. Accordingly, the predicted answer is
    the cell with the highest score.

    Pros and Cons of RCIReader compared to TableReader:
    + Provides meaningful confidence scores
    + Allows larger tables as input
    - Does not support aggregation over table cells
    - Slower
    """

    def __init__(self,
                 row_model_name_or_path: str = "michaelrglass/albert-base-rci-wikisql-row",
                 column_model_name_or_path: str = "michaelrglass/albert-base-rci-wikisql-col",
                 row_model_version: Optional[str] = None,
                 column_model_version: Optional[str] = None,
                 row_tokenizer: Optional[str] = None,
                 column_tokenizer: Optional[str] = None,
                 use_gpu: bool = True,
                 top_k: int = 10,
                 max_seq_len: int = 256,
                 ):
        """
        Load an RCI model from Transformers.
        Available models include:

        - ``'michaelrglass/albert-base-rci-wikisql-row'`` + ``'michaelrglass/albert-base-rci-wikisql-col'``
        - ``'michaelrglass/albert-base-rci-wtq-row'`` + ``'michaelrglass/albert-base-rci-wtq-col'``

        :param row_model_name_or_path: Directory of a saved row scoring model or the name of a public model
        :param column_model_name_or_path: Directory of a saved column scoring model or the name of a public model
        :param row_model_version: The version of row model to use from the HuggingFace model hub.
                                  Can be tag name, branch name, or commit hash.
        :param column_model_version: The version of column model to use from the HuggingFace model hub.
                                     Can be tag name, branch name, or commit hash.
        :param row_tokenizer: Name of the tokenizer for the row model (usually the same as model)
        :param column_tokenizer: Name of the tokenizer for the column model (usually the same as model)
        :param use_gpu: Whether to use GPU or CPU. Falls back on CPU if no GPU is available.
        :param top_k: The maximum number of answers to return
        :param max_seq_len: Max sequence length of one input table for the model. If the number of tokens of
                            query + table exceed max_seq_len, the table will be truncated by removing rows until the
                            input size fits the model.
        """
        # Save init parameters to enable export of component config as YAML
        self.set_config(row_model_name_or_path=row_model_name_or_path,
                        column_model_name_or_path=column_model_name_or_path, row_model_version=row_model_version,
                        column_model_version=column_model_version, row_tokenizer=row_tokenizer,
                        column_tokenizer=column_tokenizer, use_gpu=use_gpu, top_k=top_k, max_seq_len=max_seq_len)

        self.devices, _ = initialize_device_settings(use_cuda=use_gpu, multi_gpu=False)
        self.row_model = AutoModelForSequenceClassification.from_pretrained(row_model_name_or_path,
                                                                            revision=row_model_version)
        # The column scorer must be loaded from the *column* checkpoint. Loading it from the row
        # checkpoint would silently score columns with the row model.
        self.column_model = AutoModelForSequenceClassification.from_pretrained(column_model_name_or_path,
                                                                               revision=column_model_version)
        self.row_model.to(str(self.devices[0]))
        self.column_model.to(str(self.devices[0]))

        self.row_tokenizer = self._load_tokenizer(row_tokenizer, row_model_name_or_path)
        self.column_tokenizer = self._load_tokenizer(column_tokenizer, column_model_name_or_path)

        self.top_k = top_k
        self.max_seq_len = max_seq_len
        self.return_no_answers = False

    @staticmethod
    def _load_tokenizer(tokenizer_name: Optional[str], model_name_or_path: str):
        """
        Load the tokenizer for one of the two scoring models.

        :param tokenizer_name: Explicit tokenizer name. If None, the tokenizer is loaded from `model_name_or_path`.
        :param model_name_or_path: Model whose tokenizer is used when no explicit tokenizer name is given.
        :return: The loaded tokenizer.
        """
        if tokenizer_name is not None:
            return AutoTokenizer.from_pretrained(tokenizer_name)
        try:
            return AutoTokenizer.from_pretrained(model_name_or_path)
        # The existing RCI models on the model hub don't come with tokenizer vocab files.
        except TypeError:
            return AutoTokenizer.from_pretrained("albert-base-v2")

    def predict(self, query: str, documents: List[Document], top_k: Optional[int] = None) -> Dict:
        """
        Use loaded RCI models to find answers for a query in the supplied list of Documents
        of content_type ``'table'``.

        Returns dictionary containing query and list of Answer objects sorted by (desc.) score.
        The existing RCI models on the HF model hub don't allow aggregation, therefore, the answer will always be
        composed of a single cell.

        :param query: Query string
        :param documents: List of Document in which to search for the answer. Documents should be
                          of content_type ``'table'``.
        :param top_k: The maximum number of answers to return
        :return: Dict containing query and answers
        """
        if top_k is None:
            top_k = self.top_k

        answers = []
        for document in documents:
            if document.content_type != "table":
                logger.warning(f"Skipping document with id {document.id} in RCIReader, as it is not of type table.")
                continue

            table: pd.DataFrame = document.content
            if table.empty:
                # Without rows or columns there is no cell to score, and the tokenizers
                # would be handed an empty batch.
                logger.warning(f"Skipping document with id {document.id} in RCIReader, as its table is empty.")
                continue
            table = table.astype(str)

            # Create row and column representations and score each of them against the query
            row_reps, column_reps = self._create_row_column_representations(table)
            row_logits = self._score_representations(query, row_reps, self.row_tokenizer, self.row_model)
            column_logits = self._score_representations(query, column_reps, self.column_tokenizer,
                                                        self.column_model)

            # Calculate cell scores: each cell's score is its row score plus its column score
            current_answers: List[Answer] = []
            cell_scores_table: List[List[float]] = []
            for row_idx, row_score in enumerate(row_logits):
                row_cell_scores: List[float] = []
                for col_idx, col_score in enumerate(column_logits):
                    current_cell_score = float(row_score + col_score)
                    row_cell_scores.append(current_cell_score)

                    answer_offsets = self._calculate_answer_offsets(row_idx, col_idx, table)
                    current_answers.append(
                        Answer(
                            answer=table.iloc[row_idx, col_idx],
                            type="extractive",
                            score=current_cell_score,
                            context=table,
                            offsets_in_document=[answer_offsets],
                            offsets_in_context=[answer_offsets],
                            document_id=document.id,
                        )
                    )
                cell_scores_table.append(row_cell_scores)

            # Add cell scores to Answers' meta to be able to use as heatmap
            for answer in current_answers:
                answer.meta = {"table_scores": cell_scores_table}
            answers.extend(current_answers)

        # Sort answers by score and select top-k answers.
        # NOTE(review): this relies on the ordering defined by Answer (presumably by score) - confirm.
        answers = sorted(answers, reverse=True)
        answers = answers[:top_k]

        results = {"query": query,
                   "answers": answers}

        return results

    def _score_representations(self, query: str, representations: List[str], tokenizer, model):
        """
        Score a list of textual row or column representations against the query.

        :param query: Query string
        :param representations: Row or column representations of one table
        :param tokenizer: Tokenizer belonging to `model`
        :param model: Sequence classification model used for scoring
        :return: One-dimensional numpy array with one score per representation
        """
        inputs = tokenizer.batch_encode_plus(
            batch_text_or_text_pairs=[(query, representation) for representation in representations],
            max_length=self.max_seq_len,
            return_tensors="pt",
            add_special_tokens=True,
            truncation=True,
            padding=True
        )
        inputs.to(self.devices[0])
        # Inference only: no gradients are needed, so avoid building the autograd graph.
        with torch.no_grad():
            logits = model(**inputs)[0]
        # The logit at index 1 is used as the score (presumably the "relevant" class - confirm).
        return logits.cpu().numpy()[:, 1]

    @staticmethod
    def _create_row_column_representations(table: pd.DataFrame) -> Tuple[List[str], List[str]]:
        """
        Build the textual representations of all rows and columns of a table.

        A row is rendered as ``"header : cell * header : cell * ..."``, a column as
        ``"header * cell * cell * ..."``.

        :param table: Table whose cells are strings
        :return: Tuple of (row representations, column representations)
        """
        columns = table.columns

        row_reps = [
            " * ".join([header + " : " + cell for header, cell in zip(columns, row)])
            for _, row in table.iterrows()
        ]
        column_reps = [f"{col_name} * " + " * ".join(table[col_name]) for col_name in columns]

        return row_reps, column_reps

    @staticmethod
    def _calculate_answer_offsets(row_idx, column_index, table) -> Span:
        """
        Compute the offset of a cell when the cells of the table are counted row by row.

        :param row_idx: Positional index of the cell's row
        :param column_index: Positional index of the cell's column
        :param table: Table containing the cell
        :return: Span covering exactly the one cell
        """
        _, n_columns = table.shape
        answer_cell_offset = (row_idx * n_columns) + column_index

        return Span(start=answer_cell_offset, end=answer_cell_offset + 1)

    def predict_batch(self, query_doc_list: List[dict], top_k: Optional[int] = None, batch_size: Optional[int] = None):
        """Batch prediction is not implemented for RCIReader; always raises NotImplementedError."""
        raise NotImplementedError("Batch prediction not yet available in RCIReader.")
FARMReader from haystack.nodes.reader.transformers import TransformersReader -from haystack.nodes.reader.table import TableReader +from haystack.nodes.reader.table import TableReader, RCIReader from haystack.nodes.summarizer.transformers import TransformersSummarizer from haystack.nodes.translator import TransformersTranslator from haystack.nodes.question_generator import QuestionGenerator @@ -338,9 +338,13 @@ def reader(request): ) -@pytest.fixture(scope="function") -def table_reader(): - return TableReader(model_name_or_path="google/tapas-base-finetuned-wtq") +@pytest.fixture(params=["tapas", "rci"], scope="function") +def table_reader(request): + if request.param == "tapas": + return TableReader(model_name_or_path="google/tapas-base-finetuned-wtq") + elif request.param == "rci": + return RCIReader(row_model_name_or_path="michaelrglass/albert-base-rci-wikisql-row", + column_model_name_or_path="michaelrglass/albert-base-rci-wikisql-col") @pytest.fixture(scope="function") diff --git a/test/test_table_reader.py b/test/test_table_reader.py index 6d5f52c79..a99865d13 100644 --- a/test/test_table_reader.py +++ b/test/test_table_reader.py @@ -1,4 +1,5 @@ import pandas as pd +import pytest from haystack.schema import Document from haystack.pipelines.base import Pipeline @@ -7,53 +8,39 @@ from haystack.pipelines.base import Pipeline def test_table_reader(table_reader): data = { "actors": ["brad pitt", "leonardo di caprio", "george clooney"], - "age": ["57", "46", "60"], + "age": ["58", "47", "60"], "number of movies": ["87", "53", "69"], - "date of birth": ["7 february 1967", "10 june 1996", "28 november 1967"], + "date of birth": ["18 december 1963", "11 november 1974", "6 may 1961"], } table = pd.DataFrame(data) - query = "When was DiCaprio born?" + query = "When was Di Caprio born?" 
prediction = table_reader.predict(query=query, documents=[Document(content=table, content_type="table")]) - assert prediction["answers"][0].answer == "10 june 1996" + assert prediction["answers"][0].answer == "11 november 1974" assert prediction["answers"][0].offsets_in_context[0].start == 7 assert prediction["answers"][0].offsets_in_context[0].end == 8 - # test aggregation - query = "How old are DiCaprio and Pitt on average?" - prediction = table_reader.predict(query=query, documents=[Document(content=table, content_type="table")]) - assert prediction["answers"][0].answer == "51.5" - assert prediction["answers"][0].meta["answer_cells"] == ["57", "46"] - assert prediction["answers"][0].meta["aggregation_operator"] == "AVERAGE" - assert prediction["answers"][0].offsets_in_context[0].start == 1 - assert prediction["answers"][0].offsets_in_context[0].end == 2 - assert prediction["answers"][0].offsets_in_context[1].start == 5 - assert prediction["answers"][0].offsets_in_context[1].end == 6 - def test_table_reader_in_pipeline(table_reader): pipeline = Pipeline() pipeline.add_node(table_reader, "TableReader", ["Query"]) data = { "actors": ["brad pitt", "leonardo di caprio", "george clooney"], - "age": ["57", "46", "60"], + "age": ["58", "47", "60"], "number of movies": ["87", "53", "69"], - "date of birth": ["7 february 1967", "10 june 1996", "28 november 1967"], + "date of birth": ["18 december 1963", "11 november 1974", "6 may 1961"], } table = pd.DataFrame(data) - query = "Which actors played in more than 60 movies?" + query = "When was Di Caprio born?" 
prediction = pipeline.run(query=query, documents=[Document(content=table, content_type="table")]) - assert prediction["answers"][0].answer == "brad pitt, george clooney" - assert prediction["answers"][0].meta["aggregation_operator"] == "NONE" - assert prediction["answers"][0].offsets_in_context[0].start == 0 - assert prediction["answers"][0].offsets_in_context[0].end == 1 - assert prediction["answers"][0].offsets_in_context[1].start == 8 - assert prediction["answers"][0].offsets_in_context[1].end == 9 - + assert prediction["answers"][0].answer == "11 november 1974" + assert prediction["answers"][0].offsets_in_context[0].start == 7 + assert prediction["answers"][0].offsets_in_context[0].end == 8 +@pytest.mark.parametrize("table_reader", ["tapas"], indirect=True) def test_table_reader_aggregation(table_reader): data = { "Mountain": ["Mount Everest", "K2", "Kangchenjunga", "Lhotse", "Makalu"], @@ -72,4 +59,3 @@ def test_table_reader_aggregation(table_reader): assert prediction["answers"][0].answer == "43046.0 m" assert prediction["answers"][0].meta["aggregation_operator"] == "SUM" assert prediction["answers"][0].meta["answer_cells"] == ['8848m', '8,611 m', '8 586m', '8 516 m', '8,485m'] -