evaluation docstring

This commit is contained in:
ZiyiXia 2024-10-31 16:14:13 +00:00
parent 134a1ade23
commit 7ae0ecf4f0
4 changed files with 99 additions and 2 deletions

View File

@ -111,7 +111,7 @@ class AbsEvaluator:
dataset_name: Optional[str] = None,
**kwargs,
):
"""Called to the whole evaluation process.
"""This is called during the evaluation process.
Args:
splits (Union[str, List[str]]): Splits of datasets.

View File

@ -14,6 +14,13 @@ logger = logging.getLogger(__name__)
class AbsEvalRunner:
"""
Abstract class of evaluation runner.
Args:
eval_args (AbsEvalArgs): :class:`AbsEvalArgs` object with the evaluation arguments.
model_args (AbsEvalModelArgs): :class:`AbsEvalModelArgs` object with the model arguments.
"""
def __init__(
self,
eval_args: AbsEvalArgs,
@ -28,6 +35,15 @@ class AbsEvalRunner:
@staticmethod
def get_models(model_args: AbsEvalModelArgs) -> Tuple[FlagAutoModel, Union[FlagAutoReranker, None]]:
"""Get the embedding and reranker model
Args:
model_args (AbsEvalModelArgs): :class:`AbsEvalModelArgs` object with the model arguments.
Returns:
Tuple[FlagAutoModel, Union[FlagAutoReranker, None]]: A :class:`FlagAutoModel` object of the embedding model, and
a :class:`FlagAutoReranker` object of the reranker model if a path is provided.
"""
embedder = FlagAutoModel.from_finetuned(
model_name_or_path=model_args.embedder_name_or_path,
model_class=model_args.embedder_model_class,
@ -74,6 +90,12 @@ class AbsEvalRunner:
return embedder, reranker
def load_retriever_and_reranker(self) -> Tuple[EvalDenseRetriever, Union[EvalReranker, None]]:
"""Load retriever and reranker for evaluation
Returns:
Tuple[EvalDenseRetriever, Union[EvalReranker, None]]: A :class:`EvalDenseRetriever` object for retrieval, and a
:class:`EvalReranker` object if a reranker is provided.
"""
embedder, reranker = self.get_models(self.model_args)
retriever = EvalDenseRetriever(
embedder,
@ -85,6 +107,11 @@ class AbsEvalRunner:
return retriever, reranker
def load_data_loader(self) -> AbsEvalDataLoader:
"""Load the data loader
Returns:
AbsEvalDataLoader: Data loader object for that specific task.
"""
data_loader = AbsEvalDataLoader(
eval_name=self.eval_args.eval_name,
dataset_dir=self.eval_args.dataset_dir,
@ -95,6 +122,11 @@ class AbsEvalRunner:
return data_loader
def load_evaluator(self) -> AbsEvaluator:
"""Load the evaluator for evaluation
Returns:
AbsEvaluator: the evaluator to run the evaluation.
"""
evaluator = AbsEvaluator(
eval_name=self.eval_args.eval_name,
data_loader=self.data_loader,
@ -109,6 +141,18 @@ class AbsEvalRunner:
output_path: str = "./eval_dev_results.md",
metrics: Union[str, List[str]] = ["ndcg_at_10", "recall_at_10"]
):
"""Evaluate the provided metrics and write the results.
Args:
search_results_save_dir (str): Path to save the search results.
output_method (str, optional): Output results to `json` or `markdown`. Defaults to "markdown".
output_path (str, optional): Path to write the output. Defaults to "./eval_dev_results.md".
metrics (Union[str, List[str]], optional): metrics to use. Defaults to ["ndcg_at_10", "recall_at_10"].
Raises:
FileNotFoundError: Eval results not found
ValueError: Invalid output method
"""
eval_results_dict = {}
for model_name in sorted(os.listdir(search_results_save_dir)):
model_search_results_save_dir = os.path.join(search_results_save_dir, model_name)
@ -136,6 +180,9 @@ class AbsEvalRunner:
raise ValueError(f"Invalid output method: {output_method}. Available methods: ['json', 'markdown']")
def run(self):
"""
Run the whole evaluation.
"""
if self.eval_args.dataset_names is None:
dataset_names = self.data_loader.available_dataset_names()
else:

View File

@ -16,6 +16,9 @@ logger = logging.getLogger(__name__)
class EvalRetriever(ABC):
"""
This is the base class for retriever.
"""
def __init__(self, embedder: AbsEmbedder, search_top_k: int = 1000, overwrite: bool = False):
self.embedder = embedder
self.search_top_k = search_top_k
@ -45,7 +48,7 @@ class EvalRetriever(ABC):
**kwargs,
) -> Dict[str, Dict[str, float]]:
"""
This is called during the retrieval process.
Abstract method to be overridden. This is called during the retrieval process.
Parameters:
corpus: Dict[str, Dict[str, Any]]: Corpus of documents.
@ -63,6 +66,9 @@ class EvalRetriever(ABC):
class EvalDenseRetriever(EvalRetriever):
"""
Child class of :class:`EvalRetriever` for dense retrieval.
"""
def __call__(
self,
corpus: Dict[str, Dict[str, Any]],
@ -144,6 +150,9 @@ class EvalDenseRetriever(EvalRetriever):
class EvalReranker:
"""
Class for reranker during evaluation.
"""
def __init__(self, reranker: AbsReranker, rerank_top_k: int = 100):
self.reranker = reranker
self.rerank_top_k = rerank_top_k

View File

@ -16,6 +16,16 @@ def evaluate_mrr(
results: Dict[str, Dict[str, float]],
k_values: List[int],
) -> Tuple[Dict[str, float]]:
"""Compute mean reciprocal rank (MRR).
Args:
qrels (Dict[str, Dict[str, int]]): Ground truth relevance.
results (Dict[str, Dict[str, float]]): Search results to evaluate.
k_values (List[int]): Cutoffs.
Returns:
Tuple[Dict[str, float]]: MRR results at provided k values.
"""
mrr = defaultdict(list)
k_max, top_hits = max(k_values), {}
@ -53,6 +63,17 @@ def evaluate_metrics(
Dict[str, float],
Dict[str, float],
]:
"""Evaluate the main metrics.
Args:
qrels (Dict[str, Dict[str, int]]): Ground truth relevance.
results (Dict[str, Dict[str, float]]): Search results to evaluate.
k_values (List[int]): Cutoffs.
Returns:
Tuple[Dict[str, float], Dict[str, float], Dict[str, float], Dict[str, float]]: Results of the different metrics at
the provided k values.
"""
all_ndcgs, all_aps, all_recalls, all_precisions = defaultdict(list), defaultdict(list), defaultdict(list), defaultdict(list)
map_string = "map_cut." + ",".join([str(k) for k in k_values])
@ -93,6 +114,17 @@ def index(
load_path: Optional[str] = None,
device: Optional[str] = None
):
"""Create and add embeddings into a Faiss index.
Args:
index_factory (str, optional): Type of Faiss index to create. Defaults to "Flat".
corpus_embeddings (Optional[np.ndarray], optional): The embedding vectors of the corpus. Defaults to None.
load_path (Optional[str], optional): Path to load embeddings from. Defaults to None.
device (Optional[str], optional): Device to hold Faiss index. Defaults to None.
Returns:
faiss.Index: The Faiss index that contains all the corpus embeddings.
"""
if corpus_embeddings is None:
corpus_embeddings = np.load(load_path)
@ -127,6 +159,15 @@ def search(
"""
1. Encode queries into dense embeddings;
2. Search through faiss index
Args:
faiss_index (faiss.Index): The Faiss index that contains all the corpus embeddings.
k (int, optional): Number of nearest neighbours to retrieve. Defaults to 100.
query_embeddings (Optional[np.ndarray], optional): The embedding vectors of queries. Defaults to None.
load_path (Optional[str], optional): Path to load embeddings from. Defaults to None.
Returns:
Tuple[np.ndarray, np.ndarray]: The scores of search results and their corresponding indices.
"""
if query_embeddings is None:
query_embeddings = np.load(load_path)