mirror of
https://github.com/FlagOpen/FlagEmbedding.git
synced 2026-01-06 04:01:35 +00:00
evaluation
This commit is contained in:
parent
1df8b3637e
commit
3f071736d4
@ -113,7 +113,7 @@ class AbsEvalDataLoader(ABC):
|
||||
return self._load_remote_corpus(dataset_name=dataset_name)
|
||||
|
||||
def load_qrels(self, dataset_name: Optional[str] = None, split: str = 'test') -> datasets.DatasetDict:
|
||||
"""Load the corpus from the dataset.
|
||||
"""Load the qrels from the dataset.
|
||||
|
||||
Args:
|
||||
dataset_name (Optional[str], optional): Name of the dataset. Defaults to :data:`None`.
|
||||
|
||||
@ -11,10 +11,28 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MIRACLEvalDataLoader(AbsEvalDataLoader):
|
||||
"""
|
||||
Data loader class for MIRACL.
|
||||
"""
|
||||
def available_dataset_names(self) -> List[str]:
|
||||
"""
|
||||
Get the available dataset names.
|
||||
|
||||
Returns:
|
||||
List[str]: All the available dataset names.
|
||||
"""
|
||||
return ["ar", "bn", "en", "es", "fa", "fi", "fr", "hi", "id", "ja", "ko", "ru", "sw", "te", "th", "zh", "de", "yo"]
|
||||
|
||||
def available_splits(self, dataset_name: str) -> List[str]:
|
||||
"""
|
||||
Get the available splits.
|
||||
|
||||
Args:
|
||||
dataset_name (str): Dataset name.
|
||||
|
||||
Returns:
|
||||
List[str]: All the available splits for the dataset.
|
||||
"""
|
||||
if dataset_name in ["de", "yo"]:
|
||||
return ["dev"]
|
||||
else:
|
||||
@ -25,6 +43,15 @@ class MIRACLEvalDataLoader(AbsEvalDataLoader):
|
||||
dataset_name: str,
|
||||
save_dir: Optional[str] = None
|
||||
) -> datasets.DatasetDict:
|
||||
"""Load the corpus dataset from HF.
|
||||
|
||||
Args:
|
||||
dataset_name (str): Name of the dataset.
|
||||
save_dir (Optional[str], optional): Directory to save the dataset. Defaults to ``None``.
|
||||
|
||||
Returns:
|
||||
datasets.DatasetDict: Loaded datasets instance of corpus.
|
||||
"""
|
||||
corpus = datasets.load_dataset(
|
||||
"miracl/miracl-corpus", dataset_name,
|
||||
cache_dir=self.cache_dir,
|
||||
@ -60,6 +87,16 @@ class MIRACLEvalDataLoader(AbsEvalDataLoader):
|
||||
split: str = 'dev',
|
||||
save_dir: Optional[str] = None
|
||||
) -> datasets.DatasetDict:
|
||||
"""Load the qrels from HF.
|
||||
|
||||
Args:
|
||||
dataset_name (str): Name of the dataset.
|
||||
split (str, optional): Split of the dataset. Defaults to ``'dev'``.
|
||||
save_dir (Optional[str], optional): Directory to save the dataset. Defaults to ``None``.
|
||||
|
||||
Returns:
|
||||
datasets.DatasetDict: Loaded datasets instance of qrel.
|
||||
"""
|
||||
endpoint = f"{os.getenv('HF_ENDPOINT', 'https://huggingface.co')}/datasets/miracl/miracl"
|
||||
qrels_download_url = f"{endpoint}/resolve/main/miracl-v1.0-{dataset_name}/qrels/qrels.miracl-v1.0-{dataset_name}-{split}.tsv"
|
||||
|
||||
@ -101,6 +138,16 @@ class MIRACLEvalDataLoader(AbsEvalDataLoader):
|
||||
split: str = 'dev',
|
||||
save_dir: Optional[str] = None
|
||||
) -> datasets.DatasetDict:
|
||||
"""Load the queries from HF.
|
||||
|
||||
Args:
|
||||
dataset_name (str): Name of the dataset.
|
||||
split (str, optional): Split of the dataset. Defaults to ``'dev'``.
|
||||
save_dir (Optional[str], optional): Directory to save the dataset. Defaults to ``None``.
|
||||
|
||||
Returns:
|
||||
datasets.DatasetDict: Loaded datasets instance of queries.
|
||||
"""
|
||||
endpoint = f"{os.getenv('HF_ENDPOINT', 'https://huggingface.co')}/datasets/miracl/miracl"
|
||||
queries_download_url = f"{endpoint}/resolve/main/miracl-v1.0-{dataset_name}/topics/topics.miracl-v1.0-{dataset_name}-{split}.tsv"
|
||||
|
||||
|
||||
@ -4,7 +4,15 @@ from .data_loader import MIRACLEvalDataLoader
|
||||
|
||||
|
||||
class MIRACLEvalRunner(AbsEvalRunner):
|
||||
"""
|
||||
Evaluation runner of MIRACL.
|
||||
"""
|
||||
def load_data_loader(self) -> MIRACLEvalDataLoader:
|
||||
"""Load the data loader instance by args.
|
||||
|
||||
Returns:
|
||||
MIRACLEvalDataLoader: The MIRACL data loader instance.
|
||||
"""
|
||||
data_loader = MIRACLEvalDataLoader(
|
||||
eval_name=self.eval_args.eval_name,
|
||||
dataset_dir=self.eval_args.dataset_dir,
|
||||
|
||||
@ -4,6 +4,7 @@ from FlagEmbedding.abc.evaluation import (
|
||||
)
|
||||
|
||||
from .data_loader import MKQAEvalDataLoader
|
||||
from .evaluator import MKQAEvaluator
|
||||
from .runner import MKQAEvalRunner
|
||||
|
||||
__all__ = [
|
||||
@ -11,4 +12,5 @@ __all__ = [
|
||||
"MKQAEvalModelArgs",
|
||||
"MKQAEvalRunner",
|
||||
"MKQAEvalDataLoader",
|
||||
"MKQAEvaluator"
|
||||
]
|
||||
|
||||
@ -13,13 +13,39 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MKQAEvalDataLoader(AbsEvalDataLoader):
|
||||
"""
|
||||
Data loader class for MKQA.
|
||||
"""
|
||||
def available_dataset_names(self) -> List[str]:
|
||||
"""
|
||||
Get the available dataset names.
|
||||
|
||||
Returns:
|
||||
List[str]: All the available dataset names.
|
||||
"""
|
||||
return ['en', 'ar', 'fi', 'ja', 'ko', 'ru', 'es', 'sv', 'he', 'th', 'da', 'de', 'fr', 'it', 'nl', 'pl', 'pt', 'hu', 'vi', 'ms', 'km', 'no', 'tr', 'zh_cn', 'zh_hk', 'zh_tw']
|
||||
|
||||
def available_splits(self, dataset_name: Optional[str] = None) -> List[str]:
|
||||
"""
|
||||
Get the available splits.
|
||||
|
||||
Args:
|
||||
dataset_name (str): Dataset name.
|
||||
|
||||
Returns:
|
||||
List[str]: All the available splits for the dataset.
|
||||
"""
|
||||
return ["test"]
|
||||
|
||||
def load_corpus(self, dataset_name: Optional[str] = None) -> datasets.DatasetDict:
|
||||
"""Load the corpus.
|
||||
|
||||
Args:
|
||||
dataset_name (Optional[str], optional): Name of the dataset. Defaults to None.
|
||||
|
||||
Returns:
|
||||
datasets.DatasetDict: Loaded datasets instance of corpus.
|
||||
"""
|
||||
if self.dataset_dir is not None:
|
||||
# same corpus for all languages
|
||||
save_dir = self.dataset_dir
|
||||
@ -28,6 +54,19 @@ class MKQAEvalDataLoader(AbsEvalDataLoader):
|
||||
return self._load_remote_corpus(dataset_name=dataset_name)
|
||||
|
||||
def _load_local_qrels(self, save_dir: str, dataset_name: Optional[str] = None, split: str = 'test') -> datasets.DatasetDict:
|
||||
"""Try to load qrels from local datasets.
|
||||
|
||||
Args:
|
||||
save_dir (str): Directory that saves the data files.
|
||||
dataset_name (Optional[str], optional): Name of the dataset. Defaults to ``None``.
|
||||
split (str, optional): Split of the dataset. Defaults to ``'test'``.
|
||||
|
||||
Raises:
|
||||
ValueError: No local qrels found, will try to download from remote.
|
||||
|
||||
Returns:
|
||||
datasets.DatasetDict: Loaded datasets instance of qrels.
|
||||
"""
|
||||
checked_split = self.check_splits(split)
|
||||
if len(checked_split) == 0:
|
||||
raise ValueError(f"Split {split} not found in the dataset.")
|
||||
@ -96,6 +135,16 @@ class MKQAEvalDataLoader(AbsEvalDataLoader):
|
||||
split: str = 'test',
|
||||
save_dir: Optional[str] = None
|
||||
) -> datasets.DatasetDict:
|
||||
"""Load remote qrels from HF.
|
||||
|
||||
Args:
|
||||
dataset_name (str): Name of the dataset.
|
||||
split (str, optional): Split of the dataset. Defaults to ``'test'``.
|
||||
save_dir (Optional[str], optional): Directory to save the dataset. Defaults to ``None``.
|
||||
|
||||
Returns:
|
||||
datasets.DatasetDict: Loaded datasets instance of qrel.
|
||||
"""
|
||||
endpoint = f"{os.getenv('HF_ENDPOINT', 'https://huggingface.co')}/datasets/Shitao/bge-m3-data"
|
||||
queries_download_url = f"{endpoint}/resolve/main/MKQA_test-data.zip"
|
||||
|
||||
@ -137,6 +186,16 @@ class MKQAEvalDataLoader(AbsEvalDataLoader):
|
||||
split: str = 'test',
|
||||
save_dir: Optional[str] = None
|
||||
) -> datasets.DatasetDict:
|
||||
"""Load the queries from HF.
|
||||
|
||||
Args:
|
||||
dataset_name (str): Name of the dataset.
|
||||
split (str, optional): Split of the dataset. Defaults to ``'test'``.
|
||||
save_dir (Optional[str], optional): Directory to save the dataset. Defaults to ``None``.
|
||||
|
||||
Returns:
|
||||
datasets.DatasetDict: Loaded datasets instance of queries.
|
||||
"""
|
||||
endpoint = f"{os.getenv('HF_ENDPOINT', 'https://huggingface.co')}/datasets/Shitao/bge-m3-data"
|
||||
queries_download_url = f"{endpoint}/resolve/main/MKQA_test-data.zip"
|
||||
|
||||
|
||||
@ -8,12 +8,25 @@ from .utils.compute_metrics import evaluate_qa_recall
|
||||
|
||||
|
||||
class MKQAEvaluator(AbsEvaluator):
|
||||
"""
|
||||
The evaluator class of MKQA.
|
||||
"""
|
||||
def get_corpus_embd_save_dir(
|
||||
self,
|
||||
retriever_name: str,
|
||||
corpus_embd_save_dir: Optional[str] = None,
|
||||
dataset_name: Optional[str] = None
|
||||
):
|
||||
"""Get the directory to save the corpus embedding.
|
||||
|
||||
Args:
|
||||
retriever_name (str): Name of the retriever.
|
||||
corpus_embd_save_dir (Optional[str], optional): Directory to save the corpus embedding. Defaults to ``None``.
|
||||
dataset_name (Optional[str], optional): Name of the dataset. Defaults to ``None``.
|
||||
|
||||
Returns:
|
||||
str: The final directory to save the corpus embedding.
|
||||
"""
|
||||
if corpus_embd_save_dir is not None:
|
||||
# Save the corpus embeddings in the same directory for all dataset_name
|
||||
corpus_embd_save_dir = os.path.join(corpus_embd_save_dir, retriever_name)
|
||||
@ -24,6 +37,15 @@ class MKQAEvaluator(AbsEvaluator):
|
||||
search_results_save_dir: str,
|
||||
k_values: List[int] = [1, 3, 5, 10, 100, 1000]
|
||||
):
|
||||
"""Compute the metrics and get the eval results.
|
||||
|
||||
Args:
|
||||
search_results_save_dir (str): Directory that saves the search results.
|
||||
k_values (List[int], optional): Cutoffs. Defaults to ``[1, 3, 5, 10, 100, 1000]``.
|
||||
|
||||
Returns:
|
||||
dict: The evaluation results.
|
||||
"""
|
||||
eval_results_dict = {}
|
||||
|
||||
corpus = self.data_loader.load_corpus()
|
||||
@ -70,6 +92,14 @@ class MKQAEvaluator(AbsEvaluator):
|
||||
):
|
||||
"""
|
||||
Compute Recall@k for QA task. The definition of recall in QA task is different from the one in IR task. Please refer to the paper of RocketQA: https://aclanthology.org/2021.naacl-main.466.pdf.
|
||||
|
||||
Args:
|
||||
corpus_dict (Dict[str, str]): Dictionary of the corpus with doc id and contents.
|
||||
qrels (Dict[str, List[str]]): Relevances of queries and passage.
|
||||
search_results (Dict[str, Dict[str, float]]): Search results of the model to evaluate.
|
||||
|
||||
Returns:
|
||||
dict: The model's scores of the metrics.
|
||||
"""
|
||||
contexts = []
|
||||
answers = []
|
||||
|
||||
@ -5,7 +5,15 @@ from .evaluator import MKQAEvaluator
|
||||
|
||||
|
||||
class MKQAEvalRunner(AbsEvalRunner):
|
||||
"""
|
||||
Evaluation runner of MKQA.
|
||||
"""
|
||||
def load_data_loader(self) -> MKQAEvalDataLoader:
|
||||
"""Load the data loader instance by args.
|
||||
|
||||
Returns:
|
||||
MKQAEvalDataLoader: The MKQA data loader instance.
|
||||
"""
|
||||
data_loader = MKQAEvalDataLoader(
|
||||
eval_name=self.eval_args.eval_name,
|
||||
dataset_dir=self.eval_args.dataset_dir,
|
||||
@ -16,6 +24,11 @@ class MKQAEvalRunner(AbsEvalRunner):
|
||||
return data_loader
|
||||
|
||||
def load_evaluator(self) -> MKQAEvaluator:
|
||||
"""Load the evaluator instance by args.
|
||||
|
||||
Returns:
|
||||
MKQAEvaluator: The MKQA evaluator instance.
|
||||
"""
|
||||
evaluator = MKQAEvaluator(
|
||||
eval_name=self.eval_args.eval_name,
|
||||
data_loader=self.data_loader,
|
||||
|
||||
@ -1,2 +1,6 @@
|
||||
Evaluation
|
||||
==========
|
||||
==========
|
||||
|
||||
.. toctree::
|
||||
evaluation/miracl
|
||||
evaluation/mkqa
|
||||
48
docs/source/API/evaluation/miracl.rst
Normal file
48
docs/source/API/evaluation/miracl.rst
Normal file
@ -0,0 +1,48 @@
|
||||
MIRACL
|
||||
======
|
||||
|
||||
`MIRACL <https://project-miracl.github.io/>`_ (Multilingual Information Retrieval Across a Continuum of Languages)
|
||||
is a WSDM 2023 Cup challenge that focuses on search across 18 different languages.
|
||||
The organizers released a multilingual retrieval dataset containing train and dev sets for 16 "known languages" and only a dev set for 2 "surprise languages".
|
||||
The topics are generated by native speakers of each language, who also label the relevance between the topics and a given document list.
|
||||
You can find the `dataset <https://huggingface.co/datasets/miracl/miracl-corpus>`_ on HuggingFace.
|
||||
|
||||
You can evaluate a model's performance on MIRACL simply by running our provided shell script:
|
||||
|
||||
.. code:: bash
|
||||
|
||||
chmod +x ./examples/evaluation/miracl/eval_miracl.sh
|
||||
./examples/evaluation/miracl/eval_miracl.sh
|
||||
|
||||
Or by running:
|
||||
|
||||
.. code:: bash
|
||||
|
||||
python -m FlagEmbedding.evaluation.miracl \
|
||||
--eval_name miracl \
|
||||
--dataset_dir ./miracl/data \
|
||||
--dataset_names bn hi sw te th yo \
|
||||
--splits dev \
|
||||
--corpus_embd_save_dir ./miracl/corpus_embd \
|
||||
--output_dir ./miracl/search_results \
|
||||
--search_top_k 1000 \
|
||||
--rerank_top_k 100 \
|
||||
--cache_path /root/.cache/huggingface/hub \
|
||||
--overwrite False \
|
||||
--k_values 10 100 \
|
||||
--eval_output_method markdown \
|
||||
--eval_output_path ./miracl/miracl_eval_results.md \
|
||||
--eval_metrics ndcg_at_10 recall_at_100 \
|
||||
--embedder_name_or_path BAAI/bge-m3 \
|
||||
--reranker_name_or_path BAAI/bge-reranker-v2-m3 \
|
||||
--devices cuda:0 cuda:1 \
|
||||
--cache_dir /root/.cache/huggingface/hub \
|
||||
--reranker_max_length 1024
|
||||
|
||||
Change the embedder, reranker, devices and cache directory to your preference.
|
||||
|
||||
.. toctree::
|
||||
:hidden:
|
||||
|
||||
miracl/data_loader
|
||||
miracl/runner
|
||||
13
docs/source/API/evaluation/miracl/data_loader.rst
Normal file
13
docs/source/API/evaluation/miracl/data_loader.rst
Normal file
@ -0,0 +1,13 @@
|
||||
data_loader
|
||||
===========
|
||||
|
||||
.. autoclass:: FlagEmbedding.evaluation.miracl.MIRACLEvalDataLoader
|
||||
|
||||
Methods
|
||||
-------
|
||||
|
||||
.. automethod:: FlagEmbedding.evaluation.miracl.MIRACLEvalDataLoader.available_dataset_names
|
||||
.. automethod:: FlagEmbedding.evaluation.miracl.MIRACLEvalDataLoader.available_splits
|
||||
.. automethod:: FlagEmbedding.evaluation.miracl.MIRACLEvalDataLoader._load_remote_corpus
|
||||
.. automethod:: FlagEmbedding.evaluation.miracl.MIRACLEvalDataLoader._load_remote_qrels
|
||||
.. automethod:: FlagEmbedding.evaluation.miracl.MIRACLEvalDataLoader._load_remote_queries
|
||||
5
docs/source/API/evaluation/miracl/runner.rst
Normal file
5
docs/source/API/evaluation/miracl/runner.rst
Normal file
@ -0,0 +1,5 @@
|
||||
runner
|
||||
======
|
||||
|
||||
.. autoclass:: FlagEmbedding.evaluation.miracl.MIRACLEvalRunner
|
||||
:members:
|
||||
87
docs/source/API/evaluation/mkqa.rst
Normal file
87
docs/source/API/evaluation/mkqa.rst
Normal file
@ -0,0 +1,87 @@
|
||||
MKQA
|
||||
====
|
||||
|
||||
`MKQA <https://github.com/apple/ml-mkqa>`_ is an open-domain question answering evaluation set comprising 10k question-answer pairs aligned across 26 typologically diverse languages.
|
||||
Each example in the dataset has the following structure:
|
||||
|
||||
.. code:: python
|
||||
|
||||
{
|
||||
'example_id': 563260143484355911,
|
||||
'queries': {
|
||||
'en': "who sings i hear you knocking but you can't come in",
|
||||
'ru': "кто поет i hear you knocking but you can't come in",
|
||||
'ja': '「 I hear you knocking」は誰が歌っていますか',
|
||||
'zh_cn': "《i hear you knocking but you can't come in》是谁演唱的",
|
||||
...
|
||||
},
|
||||
'query': "who sings i hear you knocking but you can't come in",
|
||||
'answers': {
|
||||
'en': [{
|
||||
'type': 'entity',
|
||||
'entity': 'Q545186',
|
||||
'text': 'Dave Edmunds',
|
||||
'aliases': [],
|
||||
}],
|
||||
'ru': [{
|
||||
'type': 'entity',
|
||||
'entity': 'Q545186',
|
||||
'text': 'Эдмундс, Дэйв',
|
||||
'aliases': ['Эдмундс', 'Дэйв Эдмундс', 'Эдмундс Дэйв', 'Dave Edmunds'],
|
||||
}],
|
||||
'ja': [{
|
||||
'type': 'entity',
|
||||
'entity': 'Q545186',
|
||||
'text': 'デイヴ・エドモンズ',
|
||||
'aliases': ['デーブ・エドモンズ', 'デイブ・エドモンズ'],
|
||||
}],
|
||||
'zh_cn': [{
|
||||
'type': 'entity',
|
||||
'text': '戴维·埃德蒙兹 ',
|
||||
'entity': 'Q545186',
|
||||
}],
|
||||
...
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
You can evaluate a model's performance on MKQA simply by running our provided shell script:
|
||||
|
||||
.. code:: bash
|
||||
|
||||
chmod +x ./examples/evaluation/mkqa/eval_mkqa.sh
|
||||
./examples/evaluation/mkqa/eval_mkqa.sh
|
||||
|
||||
Or by running:
|
||||
|
||||
.. code:: bash
|
||||
|
||||
python -m FlagEmbedding.evaluation.mkqa \
|
||||
--eval_name mkqa \
|
||||
--dataset_dir ./mkqa/data \
|
||||
--dataset_names en zh_cn \
|
||||
--splits test \
|
||||
--corpus_embd_save_dir ./mkqa/corpus_embd \
|
||||
--output_dir ./mkqa/search_results \
|
||||
--search_top_k 1000 \
|
||||
--rerank_top_k 100 \
|
||||
--cache_path /root/.cache/huggingface/hub \
|
||||
--overwrite False \
|
||||
--k_values 20 \
|
||||
--eval_output_method markdown \
|
||||
--eval_output_path ./mkqa/mkqa_eval_results.md \
|
||||
--eval_metrics qa_recall_at_20 \
|
||||
--embedder_name_or_path BAAI/bge-m3 \
|
||||
--reranker_name_or_path BAAI/bge-reranker-v2-m3 \
|
||||
--devices cuda:0 cuda:1 \
|
||||
--cache_dir /root/.cache/huggingface/hub \
|
||||
--reranker_max_length 1024
|
||||
|
||||
Change the embedder, reranker, devices and cache directory to your preference.
|
||||
|
||||
.. toctree::
|
||||
:hidden:
|
||||
|
||||
mkqa/data_loader
|
||||
mkqa/evaluator
|
||||
mkqa/runner
|
||||
15
docs/source/API/evaluation/mkqa/data_loader.rst
Normal file
15
docs/source/API/evaluation/mkqa/data_loader.rst
Normal file
@ -0,0 +1,15 @@
|
||||
data_loader
|
||||
===========
|
||||
|
||||
.. autoclass:: FlagEmbedding.evaluation.mkqa.MKQAEvalDataLoader
|
||||
|
||||
Methods
|
||||
-------
|
||||
|
||||
.. automethod:: FlagEmbedding.evaluation.mkqa.MKQAEvalDataLoader.available_dataset_names
|
||||
.. automethod:: FlagEmbedding.evaluation.mkqa.MKQAEvalDataLoader.available_splits
|
||||
.. automethod:: FlagEmbedding.evaluation.mkqa.MKQAEvalDataLoader.load_corpus
|
||||
.. automethod:: FlagEmbedding.evaluation.mkqa.MKQAEvalDataLoader._load_local_qrels
|
||||
.. automethod:: FlagEmbedding.evaluation.mkqa.MKQAEvalDataLoader._load_remote_corpus
|
||||
.. automethod:: FlagEmbedding.evaluation.mkqa.MKQAEvalDataLoader._load_remote_qrels
|
||||
.. automethod:: FlagEmbedding.evaluation.mkqa.MKQAEvalDataLoader._load_remote_queries
|
||||
5
docs/source/API/evaluation/mkqa/evaluator.rst
Normal file
5
docs/source/API/evaluation/mkqa/evaluator.rst
Normal file
@ -0,0 +1,5 @@
|
||||
evaluator
|
||||
=========
|
||||
|
||||
.. autoclass:: FlagEmbedding.evaluation.mkqa.MKQAEvaluator
|
||||
:members:
|
||||
4
docs/source/API/evaluation/mkqa/runner.rst
Normal file
4
docs/source/API/evaluation/mkqa/runner.rst
Normal file
@ -0,0 +1,4 @@
|
||||
runner
|
||||
======
|
||||
.. autoclass:: FlagEmbedding.evaluation.mkqa.MKQAEvalRunner
|
||||
:members:
|
||||
Loading…
x
Reference in New Issue
Block a user