mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-09-23 07:03:45 +00:00
Enable batch mode for SAS cross encoders (#1987)
* Enable batch mode for SAS cross encoders
* Fix mypy
* Add comment on top_1 values
This commit is contained in:
parent
9c3d9b4885
commit
c861fdb2ce
@ -388,24 +388,31 @@ def semantic_answer_similarity(predictions: List[List[str]],
|
|||||||
# Compute similarities
|
# Compute similarities
|
||||||
top_1_sas = []
|
top_1_sas = []
|
||||||
top_k_sas = []
|
top_k_sas = []
|
||||||
|
lengths: List[Tuple[int,int]] = []
|
||||||
|
|
||||||
# Based on Modelstring we can load either Bi-Encoders or Cross Encoders.
|
# Based on Modelstring we can load either Bi-Encoders or Cross Encoders.
|
||||||
# Similarity computation changes for both approaches
|
# Similarity computation changes for both approaches
|
||||||
if cross_encoder_used:
|
if cross_encoder_used:
|
||||||
model = CrossEncoder(sas_model_name_or_path)
|
model = CrossEncoder(sas_model_name_or_path)
|
||||||
for preds, labels in zip (predictions,gold_labels):
|
|
||||||
# TODO add efficient batch mode: put all texts and labels into grid and extract scores afterwards
|
|
||||||
grid = []
|
grid = []
|
||||||
|
for preds, labels in zip (predictions,gold_labels):
|
||||||
for p in preds:
|
for p in preds:
|
||||||
for l in labels:
|
for l in labels:
|
||||||
grid.append((p,l))
|
grid.append((p,l))
|
||||||
|
lengths.append((len(preds), len(labels)))
|
||||||
scores = model.predict(grid)
|
scores = model.predict(grid)
|
||||||
top_1_sas.append(np.max(scores[:len(labels)]))
|
|
||||||
top_k_sas.append(np.max(scores))
|
current_position = 0
|
||||||
|
for len_p, len_l in lengths:
|
||||||
|
scores_window = scores[current_position:current_position+len_p*len_l]
|
||||||
|
# Per predicted doc there are len_l entries comparing it to all len_l labels.
|
||||||
|
# So to only consider the first doc we have to take the first len_l entries
|
||||||
|
top_1_sas.append(np.max(scores_window[:len_l]))
|
||||||
|
top_k_sas.append(np.max(scores_window))
|
||||||
|
current_position += len_p*len_l
|
||||||
else:
|
else:
|
||||||
# For Bi-encoders we can flatten predictions and labels into one list
|
# For Bi-encoders we can flatten predictions and labels into one list
|
||||||
model = SentenceTransformer(sas_model_name_or_path)
|
model = SentenceTransformer(sas_model_name_or_path)
|
||||||
lengths: List[Tuple[int,int]] = []
|
|
||||||
all_texts: List[str] = []
|
all_texts: List[str] = []
|
||||||
for p, l in zip(predictions, gold_labels): # type: ignore
|
for p, l in zip(predictions, gold_labels): # type: ignore
|
||||||
# TODO potentially exclude (near) exact matches from computations
|
# TODO potentially exclude (near) exact matches from computations
|
||||||
@ -417,7 +424,7 @@ def semantic_answer_similarity(predictions: List[List[str]],
|
|||||||
|
|
||||||
# then select which embeddings will be used for similarity computations
|
# then select which embeddings will be used for similarity computations
|
||||||
current_position = 0
|
current_position = 0
|
||||||
for i, (len_p, len_l) in enumerate(lengths):
|
for len_p, len_l in lengths:
|
||||||
pred_embeddings = embeddings[current_position:current_position + len_p, :]
|
pred_embeddings = embeddings[current_position:current_position + len_p, :]
|
||||||
current_position += len_p
|
current_position += len_p
|
||||||
label_embeddings = embeddings[current_position:current_position + len_l, :]
|
label_embeddings = embeddings[current_position:current_position + len_l, :]
|
||||||
|
Loading…
x
Reference in New Issue
Block a user