diff --git a/haystack/modeling/model/language_model.py b/haystack/modeling/model/language_model.py
index a2290d466..82f3186fe 100644
--- a/haystack/modeling/model/language_model.py
+++ b/haystack/modeling/model/language_model.py
@@ -705,9 +705,9 @@ class DPREncoder(LanguageModel):
         :param input_ids: The IDs of each token in the input sequence. It's a tensor of shape [batch_size, number_of_hard_negative, max_seq_len].
         :param segment_ids: The ID of the segment. For example, in next sentence prediction, the tokens in the first sentence
             are marked with 0 and the tokens in the second sentence are marked with 1.
-            It is a tensor of shape [batch_size, number_of_hard_negative_passages, max_seq_len].
+            It is a tensor of shape [batch_size, max_seq_len].
         :param attention_mask: A mask that assigns 1 to valid input tokens and 0 to padding tokens
-            of shape [batch_size, number_of_hard_negative_passages, max_seq_len].
+            of shape [batch_size, max_seq_len].
         :param output_hidden_states: whether to add the hidden states along with the pooled output
         :param output_attentions: unused
         :return: Embeddings for each token in the input sequence.
diff --git a/haystack/modeling/model/triadaptive_model.py b/haystack/modeling/model/triadaptive_model.py
index 62255e756..6c649b3c5 100644
--- a/haystack/modeling/model/triadaptive_model.py
+++ b/haystack/modeling/model/triadaptive_model.py
@@ -310,12 +310,18 @@ class TriAdaptiveModel(nn.Module):
         if "passage_input_ids" in kwargs.keys():
             table_mask = torch.flatten(kwargs["is_table"]) == 1
 
+            # Make input two-dimensional
+            max_seq_len = kwargs["passage_input_ids"].shape[-1]
+            passage_input_ids = kwargs["passage_input_ids"].view(-1, max_seq_len)
+            passage_attention_mask = kwargs["passage_attention_mask"].view(-1, max_seq_len)
+            passage_segment_ids = kwargs["passage_segment_ids"].view(-1, max_seq_len)
+
             # Current batch consists of only tables
             if all(table_mask):
                 pooled_output2, _ = self.language_model3(
-                    passage_input_ids=kwargs["passage_input_ids"],
-                    passage_segment_ids=kwargs["table_segment_ids"],
-                    passage_attention_mask=kwargs["passage_attention_mask"],
+                    input_ids=passage_input_ids,
+                    segment_ids=passage_segment_ids,
+                    attention_mask=passage_attention_mask,
                     output_hidden_states=False,
                     output_attentions=False,
                 )
@@ -323,12 +329,6 @@
 
             # Current batch consists of tables and texts
             elif any(table_mask):
-                # Make input two-dimensional
-                max_seq_len = kwargs["passage_input_ids"].shape[-1]
-                passage_input_ids = kwargs["passage_input_ids"].view(-1, max_seq_len)
-                passage_attention_mask = kwargs["passage_attention_mask"].view(-1, max_seq_len)
-                passage_segment_ids = kwargs["passage_segment_ids"].view(-1, max_seq_len)
-
                 table_segment_ids = kwargs["table_segment_ids"].view(-1, max_seq_len)
                 table_input_ids = passage_input_ids[table_mask]
                 table_segment_ids = table_segment_ids[table_mask]
@@ -375,16 +375,10 @@
 
             # Current batch consists of only texts
             else:
-                # Make input two-dimensional
-                max_seq_len = kwargs["passage_input_ids"].shape[-1]
-                input_ids = kwargs["passage_input_ids"].view(-1, max_seq_len)
-                attention_mask = kwargs["passage_attention_mask"].view(-1, max_seq_len)
-                segment_ids = kwargs["passage_segment_ids"].view(-1, max_seq_len)
-
                 pooled_output2, _ = self.language_model2(
-                    input_ids=input_ids,
-                    attention_mask=attention_mask,
-                    segment_ids=segment_ids,
+                    input_ids=passage_input_ids,
+                    attention_mask=passage_attention_mask,
+                    segment_ids=passage_segment_ids,
                     output_hidden_states=False,
                     output_attentions=False,
                 )
diff --git a/test/nodes/test_retriever.py b/test/nodes/test_retriever.py
index a44cc9349..f4063be10 100644
--- a/test/nodes/test_retriever.py
+++ b/test/nodes/test_retriever.py
@@ -446,6 +446,32 @@ def test_table_text_retriever_embedding(document_store, retriever, docs):
         assert isclose(doc.embedding[0], expected_value, rel_tol=0.001)
 
 
+@pytest.mark.integration
+@pytest.mark.parametrize("retriever", ["table_text_retriever"], indirect=True)
+@pytest.mark.parametrize("document_store", ["memory"], indirect=True)
+@pytest.mark.embedding_dim(512)
+def test_table_text_retriever_embedding_only_text(document_store, retriever):
+    docs = [
+        Document(content="This is a test", content_type="text"),
+        Document(content="This is another test", content_type="text"),
+    ]
+    document_store.write_documents(docs)
+    document_store.update_embeddings(retriever)
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize("retriever", ["table_text_retriever"], indirect=True)
+@pytest.mark.parametrize("document_store", ["memory"], indirect=True)
+@pytest.mark.embedding_dim(512)
+def test_table_text_retriever_embedding_only_table(document_store, retriever):
+    doc = Document(
+        content=pd.DataFrame(columns=["id", "text"], data=[["1", "This is a test"], ["2", "This is another test"]]),
+        content_type="table",
+    )
+    document_store.write_documents([doc])
+    document_store.update_embeddings(retriever)
+
+
 @pytest.mark.parametrize("retriever", ["dpr"], indirect=True)
 @pytest.mark.parametrize("document_store", ["memory"], indirect=True)
 def test_dpr_saving_and_loading(tmp_path, retriever, document_store):
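Note on the fix: `TriAdaptiveModel` receives passage tensors of shape `[batch_size, num_passages, max_seq_len]`, while the encoders expect 2-D `[N, max_seq_len]` input. Before this patch, only the mixed and text-only branches flattened the tensors; the tables-only branch passed the raw 3-D kwargs, and under the wrong keyword names. The patch hoists the flattening above the branching so all three paths share it. A minimal sketch of that reshaping, with made-up sizes and mask values (not taken from the PR):

```python
import torch

# Assumed sizes for illustration: 2 queries, 3 passages each, seq length 5.
batch_size, num_passages, max_seq_len = 2, 3, 5
passage_input_ids = torch.randint(0, 100, (batch_size, num_passages, max_seq_len))

# The hoisted reshaping: collapse the batch and passage dimensions so every
# branch hands the encoder a 2-D [batch_size * num_passages, max_seq_len] tensor.
flat_input_ids = passage_input_ids.view(-1, max_seq_len)
assert flat_input_ids.shape == (batch_size * num_passages, max_seq_len)

# A boolean mask with one entry per flattened passage then routes rows,
# mirroring how table_mask splits tables from texts in the mixed branch.
table_mask = torch.tensor([1, 0, 1, 1, 0, 0]) == 1
table_input_ids = flat_input_ids[table_mask]   # rows for the table encoder
text_input_ids = flat_input_ids[~table_mask]   # rows for the text encoder
```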
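The two new regression tests assert only that `update_embeddings` runs to completion; the bug's failure mode was a crash, so a clean run is the meaningful signal. The only-table case drives the whole batch into the `all(table_mask)` branch fixed above, as this illustrative snippet (flag values assumed, not from the PR) shows:

```python
import torch

# Hypothetical is_table flags for a batch where every document is a table,
# the situation the new only-table test reproduces.
is_table = torch.tensor([[1], [1]])
table_mask = torch.flatten(is_table) == 1
assert all(table_mask)  # routes into the tables-only branch fixed above
```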