mirror of https://github.com/deepset-ai/haystack.git (synced 2025-12-30 16:47:19 +00:00)
Improve dpr conversion (#826)

* Bugfix dpr conversion
* Add latest docstring and tutorial changes
* Fix preprocessor changes

commit 7b559fa4e8 (parent e9f0076dbd)
@@ -341,7 +341,7 @@ Embeddings of documents / passages shape (batch_size, embedding_dim)

#### train

```python
-| train(data_dir: str, train_filename: str, dev_filename: str = None, test_filename: str = None, batch_size: int = 2, embed_title: bool = True, num_hard_negatives: int = 1, num_positives: int = 1, n_epochs: int = 3, evaluate_every: int = 1000, n_gpu: int = 1, learning_rate: float = 1e-5, epsilon: float = 1e-08, weight_decay: float = 0.0, num_warmup_steps: int = 100, grad_acc_steps: int = 1, optimizer_name: str = "TransformersAdamW", optimizer_correct_bias: bool = True, save_dir: str = "../saved_models/dpr", query_encoder_save_dir: str = "query_encoder", passage_encoder_save_dir: str = "passage_encoder")
+| train(data_dir: str, train_filename: str, dev_filename: str = None, test_filename: str = None, dev_split: float = 0, batch_size: int = 2, embed_title: bool = True, num_hard_negatives: int = 1, num_positives: int = 1, n_epochs: int = 3, evaluate_every: int = 1000, n_gpu: int = 1, learning_rate: float = 1e-5, epsilon: float = 1e-08, weight_decay: float = 0.0, num_warmup_steps: int = 100, grad_acc_steps: int = 1, optimizer_name: str = "TransformersAdamW", optimizer_correct_bias: bool = True, save_dir: str = "../saved_models/dpr", query_encoder_save_dir: str = "query_encoder", passage_encoder_save_dir: str = "passage_encoder")
```

train a DensePassageRetrieval model
@@ -352,6 +352,7 @@ train a DensePassageRetrieval model

- `train_filename`: training filename
- `dev_filename`: development set filename, file to be used by model in eval step of training
- `test_filename`: test set filename, file to be used by model in test step after training
+- `dev_split`: the proportion of the train set that will be sliced off as a dev set. Only works if dev_filename is set to None
- `batch_size`: total number of samples in 1 batch of data
- `embed_title`: whether to concatenate passage title with each passage. The default setting in official DPR embeds passage title with the corresponding passage
- `num_hard_negatives`: number of hard negative passages (passages which are very similar (high score by BM25) to the query but do not contain the answer)
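For orientation, here is a minimal sketch of how the new `dev_split` argument can be exercised. The document store, data paths, and hyperparameter values below are illustrative placeholders, not part of this commit; import paths follow haystack v0.x:

```python
from haystack.document_store.memory import InMemoryDocumentStore
from haystack.retriever.dense import DensePassageRetriever

# Placeholder setup: any initialised DocumentStore works here
document_store = InMemoryDocumentStore()
retriever = DensePassageRetriever(document_store=document_store)

# With dev_filename=None, dev_split carves a dev set out of the train file,
# here 10% of the training examples
retriever.train(data_dir="data/dpr_training",   # assumed directory layout
                train_filename="train.json",    # placeholder filename
                dev_filename=None,
                dev_split=0.1,
                n_epochs=3,
                batch_size=2,
                save_dir="../saved_models/dpr")
```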
@@ -270,6 +270,7 @@ class DensePassageRetriever(BaseRetriever):
              train_filename: str,
              dev_filename: str = None,
              test_filename: str = None,
+             dev_split: float = 0,
              batch_size: int = 2,
              embed_title: bool = True,
              num_hard_negatives: int = 1,
@@ -294,6 +295,7 @@ class DensePassageRetriever(BaseRetriever):
        :param train_filename: training filename
        :param dev_filename: development set filename, file to be used by model in eval step of training
        :param test_filename: test set filename, file to be used by model in test step after training
+       :param dev_split: the proportion of the train set that will be sliced off as a dev set. Only works if dev_filename is set to None
        :param batch_size: total number of samples in 1 batch of data
        :param embed_title: whether to concatenate passage title with each passage. The default setting in official DPR embeds passage title with the corresponding passage
        :param num_hard_negatives: number of hard negative passages (passages which are very similar (high score by BM25) to the query but do not contain the answer)
@@ -324,6 +326,7 @@ class DensePassageRetriever(BaseRetriever):
                       train_filename=train_filename,
                       dev_filename=dev_filename,
                       test_filename=test_filename,
+                      dev_split=dev_split,
                       embed_title=self.embed_title,
                       num_hard_negatives=num_hard_negatives,
                       num_positives=num_positives)
@@ -73,6 +73,7 @@ from haystack.retriever.base import BaseRetriever
from tqdm import tqdm
import json

+logger = logging.getLogger(__name__)


class HaystackDocumentStore:
    def __init__(self,
@@ -96,7 +97,7 @@ class HaystackDocumentStore:
    def __prepare_ElasticsearchDocumentStore():
        es = Elasticsearch(['http://localhost:9200/'], verify_certs=True)
        if not es.ping():
-           logging.info("Starting Elasticsearch ...")
+           logger.info("Starting Elasticsearch ...")
            status = subprocess.run(
                ['docker run -d -p 9200:9200 -e "discovery.type=single-node" elasticsearch:7.9.2'], shell=True
            )
@@ -167,7 +168,7 @@ def create_dpr_training_dataset(squad_data: dict, retriever: BaseRetriever,
    n_non_added_questions = 0
    n_questions = 0
    for idx_article, article in enumerate(tqdm(squad_data, unit="article")):
-       article_title = article["title"]
+       article_title = article.get("title", "")
        for paragraph in article["paragraphs"]:
            context = paragraph["context"]
            for question in paragraph["qas"]:
@@ -178,10 +179,7 @@ def create_dpr_training_dataset(squad_data: dict, retriever: BaseRetriever,
                    question=question["question"],
                    answers=answers,
                    n_ctxs=num_hard_negative_ctxs)
-               positive_ctxs = [{
-                   "title": f"{article_title}_{i}",
-                   "text": c
-               } for i, c in enumerate([context for _ in question["answers"]])]
+               positive_ctxs = [{"title": article_title, "text": context, "passage_id": ""}]

                if not hard_negative_ctxs or not positive_ctxs:
                    logging.error(
@@ -193,12 +191,13 @@ def create_dpr_training_dataset(squad_data: dict, retriever: BaseRetriever,
                    "answers": answers,
                    "positive_ctxs": positive_ctxs,
                    "negative_ctxs": [],
-                   "hard_negative_ctxs": hard_negative_ctxs
+                   "hard_negative_ctxs": hard_negative_ctxs,
                }
                n_questions += 1
                yield dict_DPR

-   print(f"Number of not added questions : {n_non_added_questions} / {n_questions}")
+   logger.info(f"Number of skipped questions: {n_non_added_questions}")
+   logger.info(f"Number of added questions: {n_questions}")


def save_dataset(iter_dpr: Iterator, dpr_output_filename: Path,
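Taken together, the two hunks above change the shape of each emitted training record: a single positive context per question instead of one copy of the context per answer, and a `passage_id` field on every context. A sketch of a resulting DPR-style record, with made-up values for illustration:

```python
# Illustrative record as yielded by create_dpr_training_dataset after this fix
dict_DPR = {
    "question": "Who wrote Faust?",       # placeholder question
    "answers": ["Goethe"],                # placeholder answers
    "positive_ctxs": [
        # exactly one positive context, taken from the SQuAD paragraph
        {"title": "Faust", "text": "Faust is a tragic play ...", "passage_id": ""},
    ],
    "negative_ctxs": [],                  # left empty by this script
    "hard_negative_ctxs": [
        # BM25-similar passages that do not contain the answer
        {"title": "Doctor Faustus", "text": "Doctor Faustus is ...", "passage_id": ""},
    ],
}
```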
@@ -228,12 +227,12 @@ def get_hard_negative_contexts(retriever: BaseRetriever, question: str, answers:
    list_hard_neg_ctxs = []
    retrieved_docs = retriever.retrieve(query=question, top_k=n_ctxs, index="document")
    for retrieved_doc in retrieved_docs:
-       retrieved_doc_id = retrieved_doc.meta["name"]
+       retrieved_doc_id = retrieved_doc.meta.get("name", "")
        retrieved_doc_text = retrieved_doc.text
        if any([True if answer.lower() in retrieved_doc_text.lower() else False
                for answer in answers]):
            continue
-       list_hard_neg_ctxs.append({"title": retrieved_doc_id, "text": retrieved_doc_text})
+       list_hard_neg_ctxs.append({"title": retrieved_doc_id, "text": retrieved_doc_text, "passage_id": ""})

    return list_hard_neg_ctxs
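A usage sketch of `get_hard_negative_contexts` as amended above; `retriever` is a stand-in for whatever BaseRetriever the script has initialised over its "document" index:

```python
# Illustrative call with placeholder inputs
hard_negative_ctxs = get_hard_negative_contexts(
    retriever=retriever,          # any initialised BaseRetriever (placeholder)
    question="Who wrote Faust?",  # placeholder question
    answers=["Goethe"],           # docs containing any answer are skipped
    n_ctxs=30,                    # matches the script's num_hard_negative_ctxs default
)
# Each entry now carries an (empty) passage_id alongside title and text:
# {"title": "...", "text": "...", "passage_id": ""}
```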
@@ -255,11 +254,14 @@ def load_squad_file(squad_file_path: Path):
    return squad_file_path, squad_data["data"]


-def main(squad_input_filename: Path, dpr_output_filename: Path,
+def main(squad_input_filename: Path,
+        dpr_output_filename: Path,
+        preprocessor,
         document_store_type_config: Tuple[str, Dict] = ("ElasticsearchDocumentStore", {}),
         retriever_type_config: Tuple[str, Dict] = ("ElasticsearchRetriever", {}),
         num_hard_negative_ctxs: int = 30,
-        split_dataset: bool = False):
+        split_dataset: bool = False,
+        ):
    tqdm.write(f"Using SQuAD-like file {squad_input_filename}")

    # 1. Load squad file data
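Since `main` now takes the `PreProcessor` instance explicitly, a direct call under the new signature looks roughly like this. The paths are placeholders and the import path is an assumption based on haystack v0.x:

```python
from pathlib import Path
from haystack.preprocessor import PreProcessor  # v0.x import path, an assumption

preprocessor = PreProcessor(split_length=100,
                            split_overlap=0,
                            clean_empty_lines=False,
                            split_respect_sentence_boundary=False,
                            clean_whitespace=False)

main(squad_input_filename=Path("data/squad/train-v2.0.json"),  # placeholder path
     dpr_output_filename=Path("data/dpr/train.json"),          # placeholder path
     preprocessor=preprocessor,
     num_hard_negative_ctxs=30,
     split_dataset=True)
```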
@@ -315,8 +317,11 @@ if __name__ == '__main__':
    )

    args = parser.parse_args()

-   preprocessor = PreProcessor(split_length=100, split_overlap=0, clean_empty_lines=False,
+   preprocessor = PreProcessor(split_length=100,
+                               split_overlap=0,
+                               clean_empty_lines=False,
                                split_respect_sentence_boundary=False,
                                clean_whitespace=False)
    squad_input_filename = Path(args.squad_input_filename)
    dpr_output_filename = Path(args.dpr_output_filename)
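For context, a sketch of what this PreProcessor configuration does to a document before indexing; the `process` call and the dict document format follow haystack v0.x conventions as best I recall them, and the sample text is made up:

```python
# Split into 100-word passages, no overlap, no whitespace/empty-line cleaning,
# and without respecting sentence boundaries (plain fixed-size word windows)
doc = {"text": "Some long article text ...", "meta": {"name": "sample_article"}}
passages = preprocessor.process(doc)   # -> list of split documents
for i, passage in enumerate(passages):
    print(i, passage["text"][:50])
```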
@@ -335,8 +340,9 @@ if __name__ == '__main__':

    main(squad_input_filename=squad_input_filename,
         dpr_output_filename=dpr_output_filename,
+        preprocessor=preprocessor,
         document_store_type_config=("ElasticsearchDocumentStore", store_dpr_config),
-        retriever_type_config=("DensePassageRetriever", retriever_dpr_config), # dpr
-        # retriever_type_config=("ElasticsearchRetriever", retriever_bm25_config), # bm25
+        #retriever_type_config=("DensePassageRetriever", retriever_dpr_config), # dpr
+        retriever_type_config=("ElasticsearchRetriever", retriever_bm25_config), # bm25
         num_hard_negative_ctxs=num_hard_negative_ctxs,
         split_dataset=split_dataset)