mirror of
https://github.com/FlagOpen/FlagEmbedding.git
synced 2026-01-08 05:03:10 +00:00
evaluation docstring
This commit is contained in:
parent
134a1ade23
commit
7ae0ecf4f0
@ -111,7 +111,7 @@ class AbsEvaluator:
|
||||
dataset_name: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Called to the whole evaluation process.
|
||||
"""This is called during the evaluation process.
|
||||
|
||||
Args:
|
||||
splits (Union[str, List[str]]): Splits of datasets.
|
||||
|
||||
@ -14,6 +14,13 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AbsEvalRunner:
|
||||
"""
|
||||
Abstract class of evaluation runner.
|
||||
|
||||
Args:
|
||||
eval_args (AbsEvalArgs): :class:AbsEvalArgs object with the evaluation arguments.
|
||||
model_args (AbsEvalModelArgs): :class:AbsEvalModelArgs object with the model arguments.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
eval_args: AbsEvalArgs,
|
||||
@ -28,6 +35,15 @@ class AbsEvalRunner:
|
||||
|
||||
@staticmethod
|
||||
def get_models(model_args: AbsEvalModelArgs) -> Tuple[FlagAutoModel, Union[FlagAutoReranker, None]]:
|
||||
"""Get the embedding and reranker model
|
||||
|
||||
Args:
|
||||
model_args (AbsEvalModelArgs): :class:AbsEvalModelArgs object with the model arguments.
|
||||
|
||||
Returns:
|
||||
Tuple[FlagAutoModel, Union[FlagAutoReranker, None]]: A :class:FlagAutoModel object of embedding model, and
|
||||
:class:FlagAutoReranker object of reranker model if path provided.
|
||||
"""
|
||||
embedder = FlagAutoModel.from_finetuned(
|
||||
model_name_or_path=model_args.embedder_name_or_path,
|
||||
model_class=model_args.embedder_model_class,
|
||||
@ -74,6 +90,12 @@ class AbsEvalRunner:
|
||||
return embedder, reranker
|
||||
|
||||
def load_retriever_and_reranker(self) -> Tuple[EvalDenseRetriever, Union[EvalReranker, None]]:
|
||||
"""Load retriever and reranker for evaluation
|
||||
|
||||
Returns:
|
||||
Tuple[EvalDenseRetriever, Union[EvalReranker, None]]: A :class:EvalDenseRetriever object for retrieval, and a
|
||||
:class:EvalReranker object if reranker provided.
|
||||
"""
|
||||
embedder, reranker = self.get_models(self.model_args)
|
||||
retriever = EvalDenseRetriever(
|
||||
embedder,
|
||||
@ -85,6 +107,11 @@ class AbsEvalRunner:
|
||||
return retriever, reranker
|
||||
|
||||
def load_data_loader(self) -> AbsEvalDataLoader:
|
||||
"""Load the data loader
|
||||
|
||||
Returns:
|
||||
AbsEvalDataLoader: Data loader object for that specific task.
|
||||
"""
|
||||
data_loader = AbsEvalDataLoader(
|
||||
eval_name=self.eval_args.eval_name,
|
||||
dataset_dir=self.eval_args.dataset_dir,
|
||||
@ -95,6 +122,11 @@ class AbsEvalRunner:
|
||||
return data_loader
|
||||
|
||||
def load_evaluator(self) -> AbsEvaluator:
|
||||
"""Load the evaluator for evaluation
|
||||
|
||||
Returns:
|
||||
AbsEvaluator: the evaluator to run the evaluation.
|
||||
"""
|
||||
evaluator = AbsEvaluator(
|
||||
eval_name=self.eval_args.eval_name,
|
||||
data_loader=self.data_loader,
|
||||
@ -109,6 +141,18 @@ class AbsEvalRunner:
|
||||
output_path: str = "./eval_dev_results.md",
|
||||
metrics: Union[str, List[str]] = ["ndcg_at_10", "recall_at_10"]
|
||||
):
|
||||
"""Evaluate the provided metrics and write the results.
|
||||
|
||||
Args:
|
||||
search_results_save_dir (str): Path to save the search results.
|
||||
output_method (str, optional): Output results to `json` or `markdown`. Defaults to "markdown".
|
||||
output_path (str, optional): Path to write the output. Defaults to "./eval_dev_results.md".
|
||||
metrics (Union[str, List[str]], optional): metrics to use. Defaults to ["ndcg_at_10", "recall_at_10"].
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: Eval results not found
|
||||
ValueError: Invalid output method
|
||||
"""
|
||||
eval_results_dict = {}
|
||||
for model_name in sorted(os.listdir(search_results_save_dir)):
|
||||
model_search_results_save_dir = os.path.join(search_results_save_dir, model_name)
|
||||
@ -136,6 +180,9 @@ class AbsEvalRunner:
|
||||
raise ValueError(f"Invalid output method: {output_method}. Available methods: ['json', 'markdown']")
|
||||
|
||||
def run(self):
|
||||
"""
|
||||
Run the whole evaluation.
|
||||
"""
|
||||
if self.eval_args.dataset_names is None:
|
||||
dataset_names = self.data_loader.available_dataset_names()
|
||||
else:
|
||||
|
||||
@ -16,6 +16,9 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EvalRetriever(ABC):
|
||||
"""
|
||||
This is the base class for retriever.
|
||||
"""
|
||||
def __init__(self, embedder: AbsEmbedder, search_top_k: int = 1000, overwrite: bool = False):
|
||||
self.embedder = embedder
|
||||
self.search_top_k = search_top_k
|
||||
@ -45,7 +48,7 @@ class EvalRetriever(ABC):
|
||||
**kwargs,
|
||||
) -> Dict[str, Dict[str, float]]:
|
||||
"""
|
||||
This is called during the retrieval process.
|
||||
Abstract method to be overrode. This is called during the retrieval process.
|
||||
|
||||
Parameters:
|
||||
corpus: Dict[str, Dict[str, Any]]: Corpus of documents.
|
||||
@ -63,6 +66,9 @@ class EvalRetriever(ABC):
|
||||
|
||||
|
||||
class EvalDenseRetriever(EvalRetriever):
|
||||
"""
|
||||
Child class of :class:EvalRetriever for dense retrieval.
|
||||
"""
|
||||
def __call__(
|
||||
self,
|
||||
corpus: Dict[str, Dict[str, Any]],
|
||||
@ -144,6 +150,9 @@ class EvalDenseRetriever(EvalRetriever):
|
||||
|
||||
|
||||
class EvalReranker:
|
||||
"""
|
||||
Class for reranker during evaluation.
|
||||
"""
|
||||
def __init__(self, reranker: AbsReranker, rerank_top_k: int = 100):
|
||||
self.reranker = reranker
|
||||
self.rerank_top_k = rerank_top_k
|
||||
|
||||
@ -16,6 +16,16 @@ def evaluate_mrr(
|
||||
results: Dict[str, Dict[str, float]],
|
||||
k_values: List[int],
|
||||
) -> Tuple[Dict[str, float]]:
|
||||
"""Compute mean reciprocal rank (MRR).
|
||||
|
||||
Args:
|
||||
qrels (Dict[str, Dict[str, int]]): Ground truth relevance.
|
||||
results (Dict[str, Dict[str, float]]): Search results to evaluate.
|
||||
k_values (List[int]): Cutoffs.
|
||||
|
||||
Returns:
|
||||
Tuple[Dict[str, float]]: MRR results at provided k values.
|
||||
"""
|
||||
mrr = defaultdict(list)
|
||||
|
||||
k_max, top_hits = max(k_values), {}
|
||||
@ -53,6 +63,17 @@ def evaluate_metrics(
|
||||
Dict[str, float],
|
||||
Dict[str, float],
|
||||
]:
|
||||
"""Evaluate the main metrics.
|
||||
|
||||
Args:
|
||||
qrels (Dict[str, Dict[str, int]]): Ground truth relevance.
|
||||
results (Dict[str, Dict[str, float]]): Search results to evaluate.
|
||||
k_values (List[int]): Cutoffs.
|
||||
|
||||
Returns:
|
||||
Tuple[ Dict[str, float], Dict[str, float], Dict[str, float], Dict[str, float], ]: Results of different metrics at
|
||||
different provided k values.
|
||||
"""
|
||||
all_ndcgs, all_aps, all_recalls, all_precisions = defaultdict(list), defaultdict(list), defaultdict(list), defaultdict(list)
|
||||
|
||||
map_string = "map_cut." + ",".join([str(k) for k in k_values])
|
||||
@ -93,6 +114,17 @@ def index(
|
||||
load_path: Optional[str] = None,
|
||||
device: Optional[str] = None
|
||||
):
|
||||
"""Create and add embeddings into a Faiss index.
|
||||
|
||||
Args:
|
||||
index_factory (str, optional): Type of Faiss index to create. Defaults to "Flat".
|
||||
corpus_embeddings (Optional[np.ndarray], optional): The embedding vectors of the corpus. Defaults to None.
|
||||
load_path (Optional[str], optional): Path to load embeddings from. Defaults to None.
|
||||
device (Optional[str], optional): Device to hold Faiss index. Defaults to None.
|
||||
|
||||
Returns:
|
||||
faiss.Index: The Faiss index that contains all the corpus embeddings.
|
||||
"""
|
||||
if corpus_embeddings is None:
|
||||
corpus_embeddings = np.load(load_path)
|
||||
|
||||
@ -127,6 +159,15 @@ def search(
|
||||
"""
|
||||
1. Encode queries into dense embeddings;
|
||||
2. Search through faiss index
|
||||
|
||||
Args:
|
||||
faiss_index (faiss.Index): The Faiss index that contains all the corpus embeddings.
|
||||
k (int, optional): Top k numbers of closest neighbours. Defaults to 100.
|
||||
query_embeddings (Optional[np.ndarray], optional): The embedding vectors of queries. Defaults to None.
|
||||
load_path (Optional[str], optional): Path to load embeddings from. Defaults to None.
|
||||
|
||||
Returns:
|
||||
Tuple[np.ndarray, np.ndarray]: The scores of search results and their corresponding indices.
|
||||
"""
|
||||
if query_embeddings is None:
|
||||
query_embeddings = np.load(load_path)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user