mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-09-22 22:53:41 +00:00
Enable batch mode for SAS cross encoders (#1987)
* enable batch mode for sas cross encoders * fix mypy * comment on top_1 values added
This commit is contained in:
parent
9c3d9b4885
commit
c861fdb2ce
@ -388,24 +388,31 @@ def semantic_answer_similarity(predictions: List[List[str]],
|
||||
# Compute similarities
|
||||
top_1_sas = []
|
||||
top_k_sas = []
|
||||
lengths: List[Tuple[int,int]] = []
|
||||
|
||||
# Based on Modelstring we can load either Bi-Encoders or Cross Encoders.
|
||||
# Similarity computation changes for both approaches
|
||||
if cross_encoder_used:
|
||||
model = CrossEncoder(sas_model_name_or_path)
|
||||
for preds, labels in zip (predictions,gold_labels):
|
||||
# TODO add efficient batch mode: put all texts and labels into grid and extract scores afterwards
|
||||
grid = []
|
||||
model = CrossEncoder(sas_model_name_or_path)
|
||||
grid = []
|
||||
for preds, labels in zip (predictions,gold_labels):
|
||||
for p in preds:
|
||||
for l in labels:
|
||||
grid.append((p,l))
|
||||
scores = model.predict(grid)
|
||||
top_1_sas.append(np.max(scores[:len(labels)]))
|
||||
top_k_sas.append(np.max(scores))
|
||||
lengths.append((len(preds), len(labels)))
|
||||
scores = model.predict(grid)
|
||||
|
||||
current_position = 0
|
||||
for len_p, len_l in lengths:
|
||||
scores_window = scores[current_position:current_position+len_p*len_l]
|
||||
# Per predicted doc there are len_l entries comparing it to all len_l labels.
|
||||
# So to only consider the first doc we have to take the first len_l entries
|
||||
top_1_sas.append(np.max(scores_window[:len_l]))
|
||||
top_k_sas.append(np.max(scores_window))
|
||||
current_position += len_p*len_l
|
||||
else:
|
||||
# For Bi-encoders we can flatten predictions and labels into one list
|
||||
model = SentenceTransformer(sas_model_name_or_path)
|
||||
lengths: List[Tuple[int,int]] = []
|
||||
all_texts: List[str] = []
|
||||
for p, l in zip(predictions, gold_labels): # type: ignore
|
||||
# TODO potentially exclude (near) exact matches from computations
|
||||
@ -417,7 +424,7 @@ def semantic_answer_similarity(predictions: List[List[str]],
|
||||
|
||||
# then select which embeddings will be used for similarity computations
|
||||
current_position = 0
|
||||
for i, (len_p, len_l) in enumerate(lengths):
|
||||
for len_p, len_l in lengths:
|
||||
pred_embeddings = embeddings[current_position:current_position + len_p, :]
|
||||
current_position += len_p
|
||||
label_embeddings = embeddings[current_position:current_position + len_l, :]
|
||||
|
Loading…
x
Reference in New Issue
Block a user