mirror of
https://github.com/deepset-ai/haystack.git
synced 2026-01-07 20:46:31 +00:00
Sas gpu additions (#2308)
* Add batch_size and use_gpu to SAS from #2306 * Add batch_size and use_gpu to SAS from #2306 * Added docstrings for SAS-GPU to evluator.py * Added docstrings for SAS-GPU to pipelines/base.py * Typo fix in pipelines/base.py * streamline docstrings with related params in code base Co-authored-by: Thomas Stadelmann <thomas.stadelmann@deepset.ai>
This commit is contained in:
parent
8f7dd13eb9
commit
46fa166c36
@ -393,6 +393,8 @@ def semantic_answer_similarity(
|
||||
predictions: List[List[str]],
|
||||
gold_labels: List[List[str]],
|
||||
sas_model_name_or_path: str = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
|
||||
batch_size: int = 32,
|
||||
use_gpu: bool = True
|
||||
) -> Tuple[List[float], List[float]]:
|
||||
"""
|
||||
Computes Transformer-based similarity of predicted answer to gold labels to derive a more meaningful metric than EM or F1.
|
||||
@ -403,6 +405,9 @@ def semantic_answer_similarity(
|
||||
:param gold_labels: Labels as list of multiple possible answers per question
|
||||
:param sas_model_name_or_path: SentenceTransformers semantic textual similarity model, should be path or string
|
||||
pointing to downloadable models.
|
||||
:param batch_size: Number of prediction label pairs to encode at once.
|
||||
:param use_gpu: Whether to use a GPU or the CPU for calculating semantic answer similarity.
|
||||
Falls back to CPU if no GPU is available.
|
||||
:return: top_1_sas, top_k_sas
|
||||
"""
|
||||
assert len(predictions) == len(gold_labels)
|
||||
@ -411,6 +416,8 @@ def semantic_answer_similarity(
|
||||
cross_encoder_used = False
|
||||
if config.architectures is not None:
|
||||
cross_encoder_used = any(arch.endswith("ForSequenceClassification") for arch in config.architectures)
|
||||
|
||||
device = None if use_gpu else 'cpu'
|
||||
|
||||
# Compute similarities
|
||||
top_1_sas = []
|
||||
@ -420,14 +427,14 @@ def semantic_answer_similarity(
|
||||
# Based on Modelstring we can load either Bi-Encoders or Cross Encoders.
|
||||
# Similarity computation changes for both approaches
|
||||
if cross_encoder_used:
|
||||
model = CrossEncoder(sas_model_name_or_path)
|
||||
model = CrossEncoder(sas_model_name_or_path, device=device)
|
||||
grid = []
|
||||
for preds, labels in zip(predictions, gold_labels):
|
||||
for p in preds:
|
||||
for l in labels:
|
||||
grid.append((p, l))
|
||||
lengths.append((len(preds), len(labels)))
|
||||
scores = model.predict(grid)
|
||||
scores = model.predict(grid, batch_size=batch_size)
|
||||
|
||||
current_position = 0
|
||||
for len_p, len_l in lengths:
|
||||
@ -439,7 +446,7 @@ def semantic_answer_similarity(
|
||||
current_position += len_p * len_l
|
||||
else:
|
||||
# For Bi-encoders we can flatten predictions and labels into one list
|
||||
model = SentenceTransformer(sas_model_name_or_path)
|
||||
model = SentenceTransformer(sas_model_name_or_path, device=device)
|
||||
all_texts: List[str] = []
|
||||
for p, l in zip(predictions, gold_labels): # type: ignore
|
||||
# TODO potentially exclude (near) exact matches from computations
|
||||
@ -447,7 +454,7 @@ def semantic_answer_similarity(
|
||||
all_texts.extend(l)
|
||||
lengths.append((len(p), len(l)))
|
||||
# then compute embeddings
|
||||
embeddings = model.encode(all_texts)
|
||||
embeddings = model.encode(all_texts, batch_size=batch_size)
|
||||
|
||||
# then select which embeddings will be used for similarity computations
|
||||
current_position = 0
|
||||
|
||||
@ -696,6 +696,8 @@ class Pipeline(BasePipeline):
|
||||
documents: Optional[List[List[Document]]] = None,
|
||||
params: Optional[dict] = None,
|
||||
sas_model_name_or_path: str = None,
|
||||
sas_batch_size: int = 32,
|
||||
sas_use_gpu: bool = True,
|
||||
add_isolated_node_eval: bool = False,
|
||||
) -> EvaluationResult:
|
||||
"""
|
||||
@ -719,6 +721,9 @@ class Pipeline(BasePipeline):
|
||||
- Good default for multiple languages: "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
|
||||
- Large, powerful, but slow model for English only: "cross-encoder/stsb-roberta-large"
|
||||
- Large model for German only: "deepset/gbert-large-sts"
|
||||
:param sas_batch_size: Number of prediction label pairs to encode at once by CrossEncoder or SentenceTransformer while calculating SAS.
|
||||
:param sas_use_gpu: Whether to use a GPU or the CPU for calculating semantic answer similarity.
|
||||
Falls back to CPU if no GPU is available.
|
||||
:param add_isolated_node_eval: If set to True, in addition to the integrated evaluation of the pipeline, each node is evaluated in isolated evaluation mode.
|
||||
This mode helps to understand the bottlenecks of a pipeline in terms of output quality of each individual node.
|
||||
If a node performs much better in the isolated evaluation than in the integrated evaluation, the previous node needs to be optimized to improve the pipeline's performance.
|
||||
@ -761,7 +766,8 @@ class Pipeline(BasePipeline):
|
||||
gold_labels = df["gold_answers"].values
|
||||
predictions = [[a] for a in df["answer"].values]
|
||||
sas, _ = semantic_answer_similarity(
|
||||
predictions=predictions, gold_labels=gold_labels, sas_model_name_or_path=sas_model_name_or_path
|
||||
predictions=predictions, gold_labels=gold_labels, sas_model_name_or_path=sas_model_name_or_path,
|
||||
batch_size=sas_batch_size, use_gpu=sas_use_gpu
|
||||
)
|
||||
df["sas"] = sas
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user