Enable batch mode for SAS cross encoders (#1987)

* enable batch mode for sas cross encoders

* fix mypy

* comment on top_1 values added
This commit is contained in:
tstadel 2022-01-11 17:54:43 +01:00 committed by GitHub
parent 9c3d9b4885
commit c861fdb2ce
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -388,24 +388,31 @@ def semantic_answer_similarity(predictions: List[List[str]],
# Compute similarities
top_1_sas = []
top_k_sas = []
lengths: List[Tuple[int,int]] = []
# Based on Modelstring we can load either Bi-Encoders or Cross Encoders.
# Similarity computation changes for both approaches
if cross_encoder_used:
model = CrossEncoder(sas_model_name_or_path)
for preds, labels in zip (predictions,gold_labels):
# TODO add efficient batch mode: put all texts and labels into grid and extract scores afterwards
grid = []
model = CrossEncoder(sas_model_name_or_path)
grid = []
for preds, labels in zip (predictions,gold_labels):
for p in preds:
for l in labels:
grid.append((p,l))
scores = model.predict(grid)
top_1_sas.append(np.max(scores[:len(labels)]))
top_k_sas.append(np.max(scores))
lengths.append((len(preds), len(labels)))
scores = model.predict(grid)
current_position = 0
for len_p, len_l in lengths:
scores_window = scores[current_position:current_position+len_p*len_l]
# Per predicted doc there are len_l entries comparing it to all len_l labels.
# So to only consider the first doc we have to take the first len_l entries
top_1_sas.append(np.max(scores_window[:len_l]))
top_k_sas.append(np.max(scores_window))
current_position += len_p*len_l
else:
# For Bi-encoders we can flatten predictions and labels into one list
model = SentenceTransformer(sas_model_name_or_path)
lengths: List[Tuple[int,int]] = []
all_texts: List[str] = []
for p, l in zip(predictions, gold_labels): # type: ignore
# TODO potentially exclude (near) exact matches from computations
@ -417,7 +424,7 @@ def semantic_answer_similarity(predictions: List[List[str]],
# then select which embeddings will be used for similarity computations
current_position = 0
for i, (len_p, len_l) in enumerate(lengths):
for len_p, len_l in lengths:
pred_embeddings = embeddings[current_position:current_position + len_p, :]
current_position += len_p
label_embeddings = embeddings[current_position:current_position + len_l, :]