mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-07-25 09:50:14 +00:00

* move logging config from haystack lib to application
* Update Documentation & Code Style
* config logging before importing haystack
* Update Documentation & Code Style
* add logging config to all tutorials
* Update Documentation & Code Style

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
259 lines
11 KiB
Python
# -*- coding: utf-8 -*-
"""Generative-Pseudo-Labeling-for-Domain-Adaptation-of-Dense-Retrievals.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1Tz9GSzre7JfvXDDKe7sCnO0FMuDViMnN

# Generative Pseudo Labeling for Domain Adaptation of Dense Retrievals
#### Note: Adapted to Haystack from Nils Reimers' original [notebook](https://colab.research.google.com/gist/jamescalam/d2c888775c87f9882bb7c379a96adbc8/gpl-domain-adaptation.ipynb#scrollTo=183ff7ab).

The NLP models we use every day were trained on a corpus of data that reflects the world of the past. In the meantime, we've experienced world-changing events, like the COVID-19 pandemic, and we'd like our models to know about them. Training a model from scratch is tedious work, but what if we could just update the models with new data? Generative Pseudo Labeling comes to the rescue.

The example below shows you how to use GPL to fine-tune a model so that it can answer the query: "How is COVID-19 transmitted?".

We're using TAS-B, a DistilBERT model that achieves state-of-the-art performance on MS MARCO (500k queries from the Bing search engine). Both DistilBERT and MS MARCO were created with data from 2018 and earlier, so the model has no knowledge of COVID-19.

For this example, we're using just four documents. When you ask the model "How is COVID-19 transmitted?", here are the answers that you get (dot-score and document):
- 94.84 Ebola is transmitted via direct contact with blood
- 92.87 HIV is transmitted via sex or sharing needles
- 92.31 Corona is transmitted via the air
- 91.54 Polio is transmitted via contaminated water or food

You can see that the correct document ranks only third, outranked by the Ebola and HIV documents. Let's see how we can make this better.
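
The ranking above orders documents by the dot product between the query embedding and each document embedding. As a minimal sketch of that idea, the snippet below ranks three documents by dot-score. The 3-dimensional vectors and the names `dot_score` and `rank_by_dot_score` are made up for illustration; they are not TAS-B outputs or part of any library.

```python
def dot_score(a, b):
    # Dot product of two equal-length vectors.
    return sum(x * y for x, y in zip(a, b))


def rank_by_dot_score(query_emb, doc_embs):
    # Score every document against the query and sort from highest to lowest.
    scored = [(doc, dot_score(query_emb, emb)) for doc, emb in doc_embs.items()]
    return sorted(scored, key=lambda pair: pair[1], reverse=True)


# Toy embeddings (illustrative values only, not produced by a real model).
query_emb = [1.0, 2.0, 0.5]
doc_embs = {
    "doc_a": [0.5, 1.0, 0.0],
    "doc_b": [1.0, 1.5, 1.0],
    "doc_c": [0.0, 0.5, 2.0],
}

ranking = rank_by_dot_score(query_emb, doc_embs)
for doc, score in ranking:
    print(f"{score:0.2f} {doc}")
```

With real embeddings the vectors have hundreds of dimensions, but the ranking step is the same: a higher dot-score means the model considers the document a better match for the query.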

## Efficient Domain Adaptation with GPL
This notebook demonstrates [Generative Pseudo Labeling (GPL)](https://arxiv.org/abs/2112.07577), an efficient approach to adapt existing dense retrieval models to new domains and data.

We get a collection of 10k scientific papers on COVID-19 and then fine-tune the model for 15-60 minutes (depending on your GPU) so that it picks up the COVID-19 knowledge.

If we search again with the updated model, we get the search results we would expect:
- Query: How is COVID-19 transmitted
- 97.70 Corona is transmitted via the air
- 96.71 Ebola is transmitted via direct contact with blood
- 95.14 Polio is transmitted via contaminated water or food
- 94.13 HIV is transmitted via sex or sharing needles
"""

import logging

# We configure how logging messages should be displayed and which log level should be used before importing Haystack.
# Example log message:
# INFO - haystack.utils.preprocessing - Converting data/tutorial1/218_Olenna_Tyrell.txt
# Default log level in basicConfig is WARNING so the explicit parameter is not necessary but can be changed easily:
logging.basicConfig(format="%(levelname)s - %(name)s - %(message)s", level=logging.WARNING)
logging.getLogger("haystack").setLevel(logging.INFO)

from sentence_transformers import SentenceTransformer, util
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from datasets import load_dataset


def tutorial18_gpl():
    # We load the TAS-B model, a state-of-the-art model trained on MS MARCO
    max_seq_length = 200
    model_name = "msmarco-distilbert-base-tas-b"

    org_model = SentenceTransformer(model_name)
    org_model.max_seq_length = max_seq_length

    # We define a simple query and some documents about how diseases are transmitted.
    # As TAS-B was trained on rather outdated data (2018 and older), it has no idea about COVID-19,
    # so in the example below it fails to recognize the relationship between COVID-19 and Corona.

    def show_examples(model):
        query = "How is COVID-19 transmitted"
        docs = [
            "Corona is transmitted via the air",
            "Ebola is transmitted via direct contact with blood",
            "HIV is transmitted via sex or sharing needles",
            "Polio is transmitted via contaminated water or food",
        ]

        query_emb = model.encode(query)
        docs_emb = model.encode(docs)
        scores = util.dot_score(query_emb, docs_emb)[0]
        doc_scores = sorted(zip(docs, scores), key=lambda x: x[1], reverse=True)

        print("Query:", query)
        for doc, score in doc_scores:
            # print(doc, score)
            print(f"{score:0.02f}\t{doc}")

    print("Original Model")
    show_examples(org_model)

    """# Get Some Data on COVID-19
    We select 10k scientific publications (title + abstract) that are connected to COVID-19. As a dataset, we use [TREC-COVID-19](https://huggingface.co/datasets/nreimers/trec-covid).
    """

    dataset = load_dataset("nreimers/trec-covid", split="train")
    num_documents = 10000
    corpus = []
    for row in dataset:
        if len(row["title"]) > 20 and len(row["text"]) > 100:
            text = row["title"] + " " + row["text"]

            text_lower = text.lower()

            # The dataset also contains many papers on other diseases. To make the training in this demo
            # more efficient, we focus on papers that talk about COVID.
            if "covid" in text_lower or "corona" in text_lower or "sars-cov-2" in text_lower:
                corpus.append(text)

        if len(corpus) >= num_documents:
            break

    print("Len Corpus:", len(corpus))

    """# Initialize Haystack Retriever and DocumentStore

    Let's add the corpus documents to `FAISSDocumentStore` and compute their embeddings with `EmbeddingRetriever`.
    """

    from haystack.nodes.retriever import EmbeddingRetriever
    from haystack.document_stores import FAISSDocumentStore

    document_store = FAISSDocumentStore(faiss_index_factory_str="Flat", similarity="cosine")
    document_store.write_documents([{"content": t} for t in corpus])

    retriever = EmbeddingRetriever(
        document_store=document_store,
        embedding_model="sentence-transformers/msmarco-distilbert-base-tas-b",
        model_format="sentence_transformers",
        max_seq_len=max_seq_length,
        progress_bar=False,
    )
    document_store.update_embeddings(retriever)

    """## (Optional) Download Pre-Generated Questions or Generate Them Outside of Haystack

    The first step of the GPL algorithm requires us to generate questions for a given text passage. Even though the question generation model was trained before COVID-19 and hasn't seen any COVID-related content, it can still produce sensible queries by copying words from the input text. As generating questions for 10k documents is a bit slow (depending on the GPU used), we'll download question/document pairs directly from the Hugging Face hub.
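
If you generate the questions yourself, a seq2seq model asked for several sequences per input returns one flat list, grouped by input document. The sketch below shows how such a flat list can be mapped back to question/document pairs. The helper name `pair_queries_with_docs` and the sample strings are hypothetical and only illustrate the index arithmetic; no model is called.

```python
def pair_queries_with_docs(docs, generated, queries_per_doc):
    # The generated list is expected to hold queries_per_doc entries per document,
    # in document order: [doc0_q0, doc0_q1, ..., doc1_q0, doc1_q1, ...].
    if len(generated) != len(docs) * queries_per_doc:
        raise ValueError("Expected queries_per_doc generated queries for every document")
    return [
        {"question": query, "document": docs[idx // queries_per_doc]}
        for idx, query in enumerate(generated)
    ]


# Hypothetical batch of two documents with three generated queries each.
docs = ["doc A", "doc B"]
generated = ["a1", "a2", "a3", "b1", "b2", "b3"]
pairs = pair_queries_with_docs(docs, generated, queries_per_doc=3)
print(pairs[3])
```

Integer division (`idx // queries_per_doc`) keeps the mapping exact, so each generated query is paired with the document it was produced from.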

    """

    from tqdm.auto import tqdm

    query_doc_pairs = []

    load_queries_from_hub = True

    # Generation of the queries is quite slow in Colab due to the old GPU and the limited CPU.
    # I pre-computed the queries and uploaded these to the HF dataset hub. Here we just download them.
    if load_queries_from_hub:
        generated_queries = load_dataset("nreimers/trec-covid-generated-queries", split="train")
        for row in generated_queries:
            query_doc_pairs.append({"question": row["query"], "document": row["doc"]})
    else:
        # Load doc2query model
        t5_name = "doc2query/msmarco-t5-base-v1"
        t5_tokenizer = AutoTokenizer.from_pretrained(t5_name)
        t5_model = AutoModelForSeq2SeqLM.from_pretrained(t5_name).cuda()

        batch_size = 32
        queries_per_doc = 3

        for start_idx in tqdm(range(0, len(corpus), batch_size)):
            corpus_batch = corpus[start_idx : start_idx + batch_size]
            enc_inp = t5_tokenizer(
                corpus_batch, max_length=max_seq_length, truncation=True, padding=True, return_tensors="pt"
            )

            outputs = t5_model.generate(
                input_ids=enc_inp["input_ids"].cuda(),
                attention_mask=enc_inp["attention_mask"].cuda(),
                max_length=64,
                do_sample=True,
                top_p=0.95,
                num_return_sequences=queries_per_doc,
            )

            decoded_output = t5_tokenizer.batch_decode(outputs, skip_special_tokens=True)

            # Each document yields queries_per_doc consecutive outputs, so integer division
            # maps an output index back to its position in the batch.
            for idx, query in enumerate(decoded_output):
                corpus_id = idx // queries_per_doc
                query_doc_pairs.append({"question": query, "document": corpus_batch[corpus_id]})

    print("Generated queries:", len(query_doc_pairs))

    """# Use PseudoLabelGenerator to Generate Retriever Adaptation Training Data

    Running `PseudoLabelGenerator` executes all three steps of the GPL [algorithm](https://github.com/UKPLab/gpl#how-does-gpl-work):
    1. Question generation - optional step
    2. Negative mining
    3. Pseudo labeling (margin scoring)

    The output of the `PseudoLabelGenerator` is the training data we'll use to adapt our `EmbeddingRetriever`.
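
To make step 3 more concrete: for each (question, positive document, mined negative document) triple, GPL scores both pairs and uses the difference of the two scores as the pseudo label. The sketch below is only an illustration of that margin. The function `toy_score` is a hypothetical word-overlap stand-in for the cross-encoder that GPL uses, and the sentences are made up.

```python
def toy_score(query, passage):
    # Hypothetical stand-in for a cross-encoder: counts the lowercase words
    # that the query and the passage have in common.
    query_words = set(query.lower().split())
    passage_words = set(passage.lower().split())
    return float(len(query_words & passage_words))


def margin_label(query, positive, negative, score_fn=toy_score):
    # Pseudo label = score(query, positive) - score(query, negative).
    return score_fn(query, positive) - score_fn(query, negative)


query = "how is covid-19 transmitted"
positive = "covid-19 is transmitted via the air"
negative = "polio is transmitted via contaminated water"

margin = margin_label(query, positive, negative)
print("margin:", margin)
```

A large positive margin tells the retriever that the positive document should score clearly higher than the negative one, while a margin near zero marks a negative that is almost as relevant as the positive.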

    """

    from haystack.nodes.question_generator import QuestionGenerator
    from haystack.nodes.label_generator import PseudoLabelGenerator

    use_question_generator = False

    if use_question_generator:
        questions_producer = QuestionGenerator(
            model_name_or_path="doc2query/msmarco-t5-base-v1",
            max_length=64,
            split_length=128,
            batch_size=32,
            num_queries_per_doc=3,
        )

    else:
        questions_producer = query_doc_pairs

    # We can use either QuestionGenerator or already generated questions in PseudoLabelGenerator
    psg = PseudoLabelGenerator(questions_producer, retriever, max_questions_per_document=10, batch_size=32, top_k=10)
    output, pipe_id = psg.run(documents=document_store.get_all_documents())

    """# Update the Retriever

    Now that we have the training data produced by `PseudoLabelGenerator`, we'll update the `EmbeddingRetriever`. Let's first take a peek at the training data.
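
The GPL paper trains the retriever on these labels with a MarginMSE loss: the retriever's own score margin between the positive and the negative document should match the pseudo-labeled margin. The sketch below computes that loss for two toy examples. The function name `margin_mse` and all numbers are illustrative assumptions, and whether `retriever.train` uses exactly this loss depends on your Haystack version and its arguments.

```python
def margin_mse(student_pos, student_neg, teacher_margins):
    # Mean squared error between the retriever's margins (positive score minus
    # negative score) and the pseudo-labeled teacher margins.
    errors = [
        ((pos - neg) - target) ** 2
        for pos, neg, target in zip(student_pos, student_neg, teacher_margins)
    ]
    return sum(errors) / len(errors)


# Toy scores for two training triples (made-up values).
student_pos = [10.0, 8.0]      # retriever score for (question, positive doc)
student_neg = [9.0, 9.0]       # retriever score for (question, negative doc)
teacher_margins = [3.0, 1.0]   # pseudo labels

loss = margin_mse(student_pos, student_neg, teacher_margins)
print("MarginMSE:", loss)
```

The loss reaches zero only when every retriever margin equals its pseudo label, which is how the margins transfer the labeling model's relevance judgments to the retriever.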
    """

    # Print the first training example and the total number of examples.
    print(output["gpl_labels"][0])

    print("Number of GPL labels:", len(output["gpl_labels"]))

    retriever.train(output["gpl_labels"])

    """## Verify that EmbeddingRetriever Is Adapted and Save It For Future Use

    Let's repeat our query to see if the Retriever learned about COVID and can now rank the Corona document as #1 among the answers.

    """

    print("Original Model")
    show_examples(org_model)

    print("\n\nAdapted Model")
    show_examples(retriever.embedding_encoder.embedding_model)

    retriever.save("adapted_retriever")


if __name__ == "__main__":
    tutorial18_gpl()

# ## About us
#
# This [Haystack](https://github.com/deepset-ai/haystack/) notebook was made with love by [deepset](https://deepset.ai/) in Berlin, Germany.
#
# We bring NLP to the industry via open source!
# Our focus: Industry specific language models & large scale QA systems.
#
# Some of our other work:
# - [German BERT](https://deepset.ai/german-bert)
# - [GermanQuAD and GermanDPR](https://deepset.ai/germanquad)
# - [FARM](https://github.com/deepset-ai/FARM)
#
# Get in touch:
# [Twitter](https://twitter.com/deepset_ai) | [LinkedIn](https://www.linkedin.com/company/deepset-ai/) | [Slack](https://haystack.deepset.ai/community/join) | [GitHub Discussions](https://github.com/deepset-ai/haystack/discussions) | [Website](https://deepset.ai)
#
# By the way: [we're hiring!](https://www.deepset.ai/jobs)
#