96 lines
2.8 KiB
Python
Raw Normal View History

2024-12-03 11:49:43 +00:00
"""
Ref: https://github.com/facebookresearch/contriever
"""
import regex
import unicodedata
from functools import partial
from typing import List, Union
class SimpleTokenizer:
    """Regex word tokenizer: emits runs of alphanumeric characters, or,
    failing that, single non-whitespace characters.

    Ref: https://github.com/facebookresearch/contriever
    """
    # One or more Unicode letters, digits, or combining marks.
    ALPHA_NUM = r'[\p{L}\p{N}\p{M}]+'
    # Any single character that is neither whitespace nor a control char.
    NON_WS = r'[^\p{Z}\p{C}]'

    def __init__(self):
        """Compile the token pattern once; the tokenizer is stateless."""
        # Flags are bit masks: combine with `|`, not `+` (addition happens
        # to work when no flag repeats, but is fragile).
        self._regexp = regex.compile(
            '(%s)|(%s)' % (self.ALPHA_NUM, self.NON_WS),
            flags=regex.IGNORECASE | regex.UNICODE | regex.MULTILINE,
        )

    def tokenize(self, text, uncased=False):
        """Split `text` into tokens; lowercase them when `uncased` is True."""
        # finditer is consumed exactly once below; no need to materialize it.
        matches = self._regexp.finditer(text)
        if uncased:
            return [m.group().lower() for m in matches]
        return [m.group() for m in matches]
def _normalize(text):
return unicodedata.normalize('NFD', text)
def has_answer(answers, text, tokenizer) -> bool:
    """Return True if `text` contains any of the `answers` as a contiguous
    token subsequence (case-insensitive, NFD-normalized)."""
    doc_tokens = tokenizer.tokenize(_normalize(text), uncased=True)
    n = len(doc_tokens)
    for answer in answers:
        ans_tokens = tokenizer.tokenize(_normalize(answer), uncased=True)
        m = len(ans_tokens)
        # Slide a window of width m across the document tokens.
        if any(doc_tokens[start:start + m] == ans_tokens
               for start in range(n - m + 1)):
            return True
    return False
def check_answer(example, tokenizer) -> List[bool]:
    """Search through all the top docs to see if they have any of the answers.

    Args:
        example: dict with 'answers' (list of answer strings) and 'ctxs'
            (list of document texts; entries may be None).
        tokenizer: object exposing ``tokenize(text, uncased=...)``.

    Returns:
        One bool per context, True if that doc contains any answer.
    """
    answers = example['answers']
    hits = []
    # The loop index was previously bound via enumerate() but never used.
    for text in example['ctxs']:
        if text is None:  # cannot find the document for some reason
            hits.append(False)
        else:
            hits.append(has_answer(answers, text, tokenizer))
    return hits
def evaluate_qa_recall(ctxs, answers, k_values: Union[int, List[int]]=100):
    """Compute Recall@k for a QA retrieval task.

    Args:
        ctxs: per-question lists of retrieved document texts.
        answers: per-question lists of gold answer strings.
        k_values: a single cutoff k, or a list of cutoffs.

    Returns:
        A single recall fraction when `k_values` is an int, otherwise a
        list of fractions (one per cutoff).
    """
    assert len(ctxs) == len(answers)
    examples = [{'answers': a, 'ctxs': c} for c, a in zip(ctxs, answers)]

    tokenizer = SimpleTokenizer()
    per_question_hits = map(partial(check_answer, tokenizer=tokenizer), examples)

    # top_k_hits[r] = number of questions whose first hit occurs at rank <= r.
    n_docs = len(examples[0]['ctxs'])
    top_k_hits = [0] * n_docs
    for question_hits in per_question_hits:
        first_hit = next((rank for rank, hit in enumerate(question_hits) if hit), None)
        if first_hit is not None:
            for rank in range(first_hit, n_docs):
                top_k_hits[rank] += 1

    n_questions = len(examples)
    if isinstance(k_values, int):
        return top_k_hits[min(k_values, n_docs) - 1] / n_questions
    return [top_k_hits[min(k, n_docs) - 1] / n_questions for k in k_values]