Mirror of https://github.com/deepset-ai/haystack.git, synced 2025-11-12 08:03:50 +00:00
fix: Fix TableTextRetriever for input consisting of tables only (#4048)
* fix: update kwargs for TriAdaptiveModel
* fix: squeeze batch for TTR inference
* test: add test for ttr + dataframe case
* test: update and reorganise ttr tests
* refactor: make triadaptive model handle shapes
* refactor: remove duplicate reshaping
* refactor: rename test with duplicate name
* fix: add device assignment back to TTR
* fix: remove duplicated vars in test

---------

Co-authored-by: bogdankostic <bogdankostic@web.de>
parent 986472c26f
commit e6b6f70ae2
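
The heart of the fix is shape handling: TriAdaptiveModel receives passage tensors of shape [batch_size, number_of_passages, max_seq_len], while the per-modality encoders expect two-dimensional [rows, max_seq_len] input. The commit hoists this flattening to the top of the method so every branch, including the previously broken tables-only branch, sees the same shapes. A minimal sketch of the reshaping idea (tensor sizes are invented for illustration):

    import torch

    # Illustrative sizes: 2 queries, 3 candidate passages each, 10 tokens per passage.
    batch_size, num_passages, max_seq_len = 2, 3, 10
    passage_input_ids = torch.randint(0, 1000, (batch_size, num_passages, max_seq_len))

    # Collapse the batch and passage dimensions so each row is one passage.
    flat_ids = passage_input_ids.view(-1, max_seq_len)
    print(flat_ids.shape)  # torch.Size([6, 10])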
@@ -705,9 +705,9 @@ class DPREncoder(LanguageModel):
         :param input_ids: The IDs of each token in the input sequence. It's a tensor of shape [batch_size, number_of_hard_negative, max_seq_len].
         :param segment_ids: The ID of the segment. For example, in next sentence prediction, the tokens in the
                             first sentence are marked with 0 and the tokens in the second sentence are marked with 1.
-                            It is a tensor of shape [batch_size, number_of_hard_negative_passages, max_seq_len].
+                            It is a tensor of shape [batch_size, max_seq_len].
         :param attention_mask: A mask that assigns 1 to valid input tokens and 0 to padding tokens
-                               of shape [batch_size, number_of_hard_negative_passages, max_seq_len].
+                               of shape [batch_size, max_seq_len].
         :param output_hidden_states: whether to add the hidden states along with the pooled output
         :param output_attentions: unused
         :return: Embeddings for each token in the input sequence.
@@ -310,12 +310,18 @@ class TriAdaptiveModel(nn.Module):
         if "passage_input_ids" in kwargs.keys():
             table_mask = torch.flatten(kwargs["is_table"]) == 1
 
+            # Make input two-dimensional
+            max_seq_len = kwargs["passage_input_ids"].shape[-1]
+            passage_input_ids = kwargs["passage_input_ids"].view(-1, max_seq_len)
+            passage_attention_mask = kwargs["passage_attention_mask"].view(-1, max_seq_len)
+            passage_segment_ids = kwargs["passage_segment_ids"].view(-1, max_seq_len)
+
             # Current batch consists of only tables
             if all(table_mask):
                 pooled_output2, _ = self.language_model3(
-                    passage_input_ids=kwargs["passage_input_ids"],
-                    passage_segment_ids=kwargs["table_segment_ids"],
-                    passage_attention_mask=kwargs["passage_attention_mask"],
+                    input_ids=passage_input_ids,
+                    segment_ids=passage_segment_ids,
+                    attention_mask=passage_attention_mask,
                     output_hidden_states=False,
                     output_attentions=False,
                 )
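
After this hunk, the reshape runs exactly once per forward call, and table_mask is a boolean vector with one entry per flattened passage. A small standalone demonstration of the mask construction used above (flag values invented):

    import torch

    # One flag per passage: 1 marks a table, 0 marks plain text.
    is_table = torch.tensor([[1, 1], [1, 0]])
    table_mask = torch.flatten(is_table) == 1
    print(table_mask)       # tensor([ True,  True,  True, False])
    print(all(table_mask))  # False -> a mixed batch, so the elif branch below runs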
@@ -323,12 +329,6 @@ class TriAdaptiveModel(nn.Module):
 
             # Current batch consists of tables and texts
             elif any(table_mask):
-                # Make input two-dimensional
-                max_seq_len = kwargs["passage_input_ids"].shape[-1]
-                passage_input_ids = kwargs["passage_input_ids"].view(-1, max_seq_len)
-                passage_attention_mask = kwargs["passage_attention_mask"].view(-1, max_seq_len)
-                passage_segment_ids = kwargs["passage_segment_ids"].view(-1, max_seq_len)
-
                 table_segment_ids = kwargs["table_segment_ids"].view(-1, max_seq_len)
                 table_input_ids = passage_input_ids[table_mask]
                 table_segment_ids = table_segment_ids[table_mask]
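
With the flattening hoisted out, the mixed branch only keeps its mask-indexing step. Boolean indexing along the first axis selects exactly the rows whose mask entry is True; a self-contained sketch (sizes invented):

    import torch

    max_seq_len = 4
    passage_input_ids = torch.arange(12).view(3, max_seq_len)  # 3 flattened passages
    table_mask = torch.tensor([True, False, True])

    table_input_ids = passage_input_ids[table_mask]  # keeps rows 0 and 2
    print(table_input_ids.shape)  # torch.Size([2, 4])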
@@ -375,16 +375,10 @@ class TriAdaptiveModel(nn.Module):
 
             # Current batch consists of only texts
             else:
-                # Make input two-dimensional
-                max_seq_len = kwargs["passage_input_ids"].shape[-1]
-                input_ids = kwargs["passage_input_ids"].view(-1, max_seq_len)
-                attention_mask = kwargs["passage_attention_mask"].view(-1, max_seq_len)
-                segment_ids = kwargs["passage_segment_ids"].view(-1, max_seq_len)
-
                 pooled_output2, _ = self.language_model2(
-                    input_ids=input_ids,
-                    attention_mask=attention_mask,
-                    segment_ids=segment_ids,
+                    input_ids=passage_input_ids,
+                    attention_mask=passage_attention_mask,
+                    segment_ids=passage_segment_ids,
                     output_hidden_states=False,
                     output_attentions=False,
                 )
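
The text-only branch follows the same pattern: it now consumes the tensors prepared at the top of the method instead of re-deriving them, so all three branches dispatch identically shaped input to their encoder. The overall control flow of the refactor, reduced to an illustrative skeleton (names and return values are not the commit's code):

    import torch

    def encode_passages(passage_input_ids: torch.Tensor, is_table: torch.Tensor) -> str:
        # Reshape once, up front, instead of separately in every branch.
        max_seq_len = passage_input_ids.shape[-1]
        flat_ids = passage_input_ids.view(-1, max_seq_len)
        table_mask = torch.flatten(is_table) == 1

        if all(table_mask):
            return f"table encoder on {tuple(flat_ids.shape)}"
        elif any(table_mask):
            return f"split rows, then both encoders on {tuple(flat_ids.shape)}"
        return f"text encoder on {tuple(flat_ids.shape)}"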
@@ -446,6 +446,32 @@ def test_table_text_retriever_embedding(document_store, retriever, docs):
         assert isclose(doc.embedding[0], expected_value, rel_tol=0.001)
 
 
+@pytest.mark.integration
+@pytest.mark.parametrize("retriever", ["table_text_retriever"], indirect=True)
+@pytest.mark.parametrize("document_store", ["memory"], indirect=True)
+@pytest.mark.embedding_dim(512)
+def test_table_text_retriever_embedding_only_text(document_store, retriever):
+    docs = [
+        Document(content="This is a test", content_type="text"),
+        Document(content="This is another test", content_type="text"),
+    ]
+    document_store.write_documents(docs)
+    document_store.update_embeddings(retriever)
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize("retriever", ["table_text_retriever"], indirect=True)
+@pytest.mark.parametrize("document_store", ["memory"], indirect=True)
+@pytest.mark.embedding_dim(512)
+def test_table_text_retriever_embedding_only_table(document_store, retriever):
+    doc = Document(
+        content=pd.DataFrame(columns=["id", "text"], data=[["1", "This is a test"], ["2", "This is another test"]]),
+        content_type="table",
+    )
+    document_store.write_documents([doc])
+    document_store.update_embeddings(retriever)
+
+
 @pytest.mark.parametrize("retriever", ["dpr"], indirect=True)
 @pytest.mark.parametrize("document_store", ["memory"], indirect=True)
 def test_dpr_saving_and_loading(tmp_path, retriever, document_store):
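
Note that embedding_dim is not a built-in pytest marker. In a suite like this one it is typically registered in configuration and consumed by a fixture through request.node.get_closest_marker; a hedged sketch of that pattern (the fixture body is an assumption, not the repository's actual conftest):

    import pytest

    @pytest.fixture
    def document_store(request):
        # Read the optional @pytest.mark.embedding_dim(...) marker; fall back to 768.
        marker = request.node.get_closest_marker("embedding_dim")
        embedding_dim = marker.args[0] if marker else 768
        # Build and return a store configured with this dimension.
        return make_in_memory_store(embedding_dim)  # hypothetical helper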