fix: Fix TableTextRetriever for input consisting of tables only (#4048)

* fix: update kwargs for TriAdaptiveModel

* fix: squeeze batch for TTR inference

* test: add test for ttr + dataframe case

* test: update and reorganise ttr tests

* refactor: make triadaptive model handle shapes

* refactor: remove duplicate reshaping

* refactor: rename test with duplicate name

* fix: add device assignment back to TTR

* fix: remove duplicated vars in test

---------

Co-authored-by: bogdankostic <bogdankostic@web.de>
Authored by Jack Butler on 2023-02-09 10:38:16 +00:00, committed via GitHub.
parent 986472c26f
commit e6b6f70ae2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 40 additions and 20 deletions

View File

@ -705,9 +705,9 @@ class DPREncoder(LanguageModel):
:param input_ids: The IDs of each token in the input sequence. It's a tensor of shape [batch_size, number_of_hard_negative, max_seq_len].
:param segment_ids: The ID of the segment. For example, in next sentence prediction, the tokens in the
first sentence are marked with 0 and the tokens in the second sentence are marked with 1.
It is a tensor of shape [batch_size, number_of_hard_negative_passages, max_seq_len].
It is a tensor of shape [batch_size, max_seq_len].
:param attention_mask: A mask that assigns 1 to valid input tokens and 0 to padding tokens
of shape [batch_size, number_of_hard_negative_passages, max_seq_len].
of shape [batch_size, max_seq_len].
:param output_hidden_states: whether to add the hidden states along with the pooled output
:param output_attentions: unused
:return: Embeddings for each token in the input sequence.

View File

@ -310,12 +310,18 @@ class TriAdaptiveModel(nn.Module):
if "passage_input_ids" in kwargs.keys():
table_mask = torch.flatten(kwargs["is_table"]) == 1
# Make input two-dimensional
max_seq_len = kwargs["passage_input_ids"].shape[-1]
passage_input_ids = kwargs["passage_input_ids"].view(-1, max_seq_len)
passage_attention_mask = kwargs["passage_attention_mask"].view(-1, max_seq_len)
passage_segment_ids = kwargs["passage_segment_ids"].view(-1, max_seq_len)
# Current batch consists of only tables
if all(table_mask):
pooled_output2, _ = self.language_model3(
passage_input_ids=kwargs["passage_input_ids"],
passage_segment_ids=kwargs["table_segment_ids"],
passage_attention_mask=kwargs["passage_attention_mask"],
input_ids=passage_input_ids,
segment_ids=passage_segment_ids,
attention_mask=passage_attention_mask,
output_hidden_states=False,
output_attentions=False,
)
@ -323,12 +329,6 @@ class TriAdaptiveModel(nn.Module):
# Current batch consists of tables and texts
elif any(table_mask):
# Make input two-dimensional
max_seq_len = kwargs["passage_input_ids"].shape[-1]
passage_input_ids = kwargs["passage_input_ids"].view(-1, max_seq_len)
passage_attention_mask = kwargs["passage_attention_mask"].view(-1, max_seq_len)
passage_segment_ids = kwargs["passage_segment_ids"].view(-1, max_seq_len)
table_segment_ids = kwargs["table_segment_ids"].view(-1, max_seq_len)
table_input_ids = passage_input_ids[table_mask]
table_segment_ids = table_segment_ids[table_mask]
@ -375,16 +375,10 @@ class TriAdaptiveModel(nn.Module):
# Current batch consists of only texts
else:
# Make input two-dimensional
max_seq_len = kwargs["passage_input_ids"].shape[-1]
input_ids = kwargs["passage_input_ids"].view(-1, max_seq_len)
attention_mask = kwargs["passage_attention_mask"].view(-1, max_seq_len)
segment_ids = kwargs["passage_segment_ids"].view(-1, max_seq_len)
pooled_output2, _ = self.language_model2(
input_ids=input_ids,
attention_mask=attention_mask,
segment_ids=segment_ids,
input_ids=passage_input_ids,
attention_mask=passage_attention_mask,
segment_ids=passage_segment_ids,
output_hidden_states=False,
output_attentions=False,
)

View File

@ -446,6 +446,32 @@ def test_table_text_retriever_embedding(document_store, retriever, docs):
assert isclose(doc.embedding[0], expected_value, rel_tol=0.001)
@pytest.mark.integration
@pytest.mark.parametrize("retriever", ["table_text_retriever"], indirect=True)
@pytest.mark.parametrize("document_store", ["memory"], indirect=True)
@pytest.mark.embedding_dim(512)
def test_table_text_retriever_embedding_only_text(document_store, retriever):
    """Smoke test: updating embeddings must succeed when the store contains text documents only."""
    text_documents = [
        Document(content=text, content_type="text")
        for text in ("This is a test", "This is another test")
    ]
    document_store.write_documents(text_documents)
    # Must not raise even though no table documents are present in the batch.
    document_store.update_embeddings(retriever)
@pytest.mark.integration
@pytest.mark.parametrize("retriever", ["table_text_retriever"], indirect=True)
@pytest.mark.parametrize("document_store", ["memory"], indirect=True)
@pytest.mark.embedding_dim(512)
def test_table_text_retriever_embedding_only_table(document_store, retriever):
    """Smoke test: updating embeddings must succeed when the store contains a table document only."""
    table = pd.DataFrame(
        columns=["id", "text"],
        data=[["1", "This is a test"], ["2", "This is another test"]],
    )
    table_document = Document(content=table, content_type="table")
    document_store.write_documents([table_document])
    # Must not raise even though the whole batch consists of tables.
    document_store.update_embeddings(retriever)
@pytest.mark.parametrize("retriever", ["dpr"], indirect=True)
@pytest.mark.parametrize("document_store", ["memory"], indirect=True)
def test_dpr_saving_and_loading(tmp_path, retriever, document_store):