Mirror of https://github.com/deepset-ai/haystack.git, synced 2025-11-02 10:49:30 +00:00
fix: Fix TableTextRetriever for input consisting of tables only (#4048)
* fix: update kwargs for TriAdaptiveModel
* fix: squeeze batch for TTR inference
* test: add test for ttr + dataframe case
* test: update and reorganise ttr tests
* refactor: make triadaptive model handle shapes
* refactor: remove duplicate reshaping
* refactor: rename test with duplicate name
* fix: add device assignment back to TTR
* fix: remove duplicated vars in test

Co-authored-by: bogdankostic <bogdankostic@web.de>
Parent: 986472c26f
Commit: e6b6f70ae2
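The heart of the fix: TriAdaptiveModel receives passage tensors of shape [batch_size, number_of_passages, max_seq_len], but its encoders consume two-dimensional [batch_size, max_seq_len] input. Below is a minimal sketch of the reshaping pattern the commit hoists to the top of the method, using toy shapes and plain torch; it is not the Haystack code itself:

import torch

# Toy batch: 2 queries, 3 candidate passages each, sequence length 8.
batch_size, num_passages, max_seq_len = 2, 3, 8
passage_input_ids = torch.randint(0, 1000, (batch_size, num_passages, max_seq_len))

# Fold the passage dimension into the batch dimension so downstream
# encoders see the 2-D [batch_size, max_seq_len] shape they expect.
flat_input_ids = passage_input_ids.view(-1, max_seq_len)
print(flat_input_ids.shape)  # torch.Size([6, 8])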
@@ -705,9 +705,9 @@ class DPREncoder(LanguageModel):
         :param input_ids: The IDs of each token in the input sequence. It's a tensor of shape [batch_size, number_of_hard_negative, max_seq_len].
         :param segment_ids: The ID of the segment. For example, in next sentence prediction, the tokens in the
             first sentence are marked with 0 and the tokens in the second sentence are marked with 1.
-            It is a tensor of shape [batch_size, number_of_hard_negative_passages, max_seq_len].
+            It is a tensor of shape [batch_size, max_seq_len].
         :param attention_mask: A mask that assigns 1 to valid input tokens and 0 to padding tokens
-            of shape [batch_size, number_of_hard_negative_passages, max_seq_len].
+            of shape [batch_size, max_seq_len].
         :param output_hidden_states: whether to add the hidden states along with the pooled output
         :param output_attentions: unused
         :return: Embeddings for each token in the input sequence.
@@ -310,12 +310,18 @@ class TriAdaptiveModel(nn.Module):
         if "passage_input_ids" in kwargs.keys():
             table_mask = torch.flatten(kwargs["is_table"]) == 1
 
+            # Make input two-dimensional
+            max_seq_len = kwargs["passage_input_ids"].shape[-1]
+            passage_input_ids = kwargs["passage_input_ids"].view(-1, max_seq_len)
+            passage_attention_mask = kwargs["passage_attention_mask"].view(-1, max_seq_len)
+            passage_segment_ids = kwargs["passage_segment_ids"].view(-1, max_seq_len)
+
             # Current batch consists of only tables
             if all(table_mask):
                 pooled_output2, _ = self.language_model3(
-                    passage_input_ids=kwargs["passage_input_ids"],
-                    passage_segment_ids=kwargs["table_segment_ids"],
-                    passage_attention_mask=kwargs["passage_attention_mask"],
+                    input_ids=passage_input_ids,
+                    segment_ids=passage_segment_ids,
+                    attention_mask=passage_attention_mask,
                     output_hidden_states=False,
                     output_attentions=False,
                 )
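The is_table flags drive a three-way dispatch over the flattened batch. A self-contained illustration of how the mask routes a batch to the table encoder, the text encoder, or the mixed path (toy values; the encoder attribute names follow the diff above):

import torch

# One flag per passage: 1 = table, 0 = text. Here: a tables-only batch.
is_table = torch.tensor([[1, 1, 1], [1, 1, 1]])
table_mask = torch.flatten(is_table) == 1

if all(table_mask):
    print("tables only -> language_model3 encodes the whole batch")
elif any(table_mask):
    print("mixed batch -> rows are split between table and text encoders")
else:
    print("texts only -> language_model2 encodes the whole batch")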
@@ -323,12 +329,6 @@ class TriAdaptiveModel(nn.Module):
 
             # Current batch consists of tables and texts
             elif any(table_mask):
-                # Make input two-dimensional
-                max_seq_len = kwargs["passage_input_ids"].shape[-1]
-                passage_input_ids = kwargs["passage_input_ids"].view(-1, max_seq_len)
-                passage_attention_mask = kwargs["passage_attention_mask"].view(-1, max_seq_len)
-                passage_segment_ids = kwargs["passage_segment_ids"].view(-1, max_seq_len)
-
                 table_segment_ids = kwargs["table_segment_ids"].view(-1, max_seq_len)
                 table_input_ids = passage_input_ids[table_mask]
                 table_segment_ids = table_segment_ids[table_mask]
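In the mixed branch, the rows belonging to tables are selected from the already-flattened tensors by boolean indexing, as in passage_input_ids[table_mask] above. A small sketch of that selection on toy tensors:

import torch

max_seq_len = 8
passage_input_ids = torch.randint(0, 1000, (6, max_seq_len))  # already flattened
table_mask = torch.tensor([True, False, True, True, False, False])

# Boolean indexing keeps only the flagged rows.
table_input_ids = passage_input_ids[table_mask]   # shape [3, 8] (table rows)
text_input_ids = passage_input_ids[~table_mask]   # shape [3, 8] (text rows)
print(table_input_ids.shape, text_input_ids.shape)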
@@ -375,16 +375,10 @@ class TriAdaptiveModel(nn.Module):
 
             # Current batch consists of only texts
             else:
-                # Make input two-dimensional
-                max_seq_len = kwargs["passage_input_ids"].shape[-1]
-                input_ids = kwargs["passage_input_ids"].view(-1, max_seq_len)
-                attention_mask = kwargs["passage_attention_mask"].view(-1, max_seq_len)
-                segment_ids = kwargs["passage_segment_ids"].view(-1, max_seq_len)
-
                 pooled_output2, _ = self.language_model2(
-                    input_ids=input_ids,
-                    attention_mask=attention_mask,
-                    segment_ids=segment_ids,
+                    input_ids=passage_input_ids,
+                    attention_mask=passage_attention_mask,
+                    segment_ids=passage_segment_ids,
                     output_hidden_states=False,
                     output_attentions=False,
                 )
@@ -446,6 +446,32 @@ def test_table_text_retriever_embedding(document_store, retriever, docs):
         assert isclose(doc.embedding[0], expected_value, rel_tol=0.001)
 
 
+@pytest.mark.integration
+@pytest.mark.parametrize("retriever", ["table_text_retriever"], indirect=True)
+@pytest.mark.parametrize("document_store", ["memory"], indirect=True)
+@pytest.mark.embedding_dim(512)
+def test_table_text_retriever_embedding_only_text(document_store, retriever):
+    docs = [
+        Document(content="This is a test", content_type="text"),
+        Document(content="This is another test", content_type="text"),
+    ]
+    document_store.write_documents(docs)
+    document_store.update_embeddings(retriever)
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize("retriever", ["table_text_retriever"], indirect=True)
+@pytest.mark.parametrize("document_store", ["memory"], indirect=True)
+@pytest.mark.embedding_dim(512)
+def test_table_text_retriever_embedding_only_table(document_store, retriever):
+    doc = Document(
+        content=pd.DataFrame(columns=["id", "text"], data=[["1", "This is a test"], ["2", "This is another test"]]),
+        content_type="table",
+    )
+    document_store.write_documents([doc])
+    document_store.update_embeddings(retriever)
+
+
 @pytest.mark.parametrize("retriever", ["dpr"], indirect=True)
 @pytest.mark.parametrize("document_store", ["memory"], indirect=True)
 def test_dpr_saving_and_loading(tmp_path, retriever, document_store):
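The two new tests cover exactly the regression this commit fixes: update_embeddings over a store that holds only texts or only tables. Below is a hedged end-to-end sketch of the tables-only scenario against Haystack v1's public API; the deepset/bert-small-mm_retrieval-* checkpoint names are an assumption based on the models these retriever tests commonly use:

import pandas as pd
from haystack.document_stores import InMemoryDocumentStore
from haystack.nodes import TableTextRetriever
from haystack.schema import Document

# Store dimension matches the 512-dim embeddings used by the tests above.
document_store = InMemoryDocumentStore(embedding_dim=512)
document_store.write_documents(
    [
        Document(
            content=pd.DataFrame(columns=["id", "text"], data=[["1", "This is a test"]]),
            content_type="table",
        )
    ]
)

retriever = TableTextRetriever(  # checkpoint names are assumptions, see lead-in
    document_store=document_store,
    query_embedding_model="deepset/bert-small-mm_retrieval-question_encoder",
    passage_embedding_model="deepset/bert-small-mm_retrieval-passage_encoder",
    table_embedding_model="deepset/bert-small-mm_retrieval-table_encoder",
    use_gpu=False,
)

# Before this fix, a tables-only batch hit a shape mismatch inside TriAdaptiveModel.
document_store.update_embeddings(retriever)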