Mirror of https://github.com/deepset-ai/haystack.git, synced 2025-08-28 02:16:32 +00:00
Add FARMClassifier node for Document Classification (#1265)
* Add FARM classification node
* Add classification output to meta field of document
* Update usage example
* Add test case for FARMClassifier
* Replace FARMRanker with FARMClassifier in documentation strings
* Remove base method not implemented by any child class, etc.
parent f79d9bdca6
commit 4e6f7f349d
1 haystack/classifier/__init__.py Normal file
@@ -0,0 +1 @@
from haystack.classifier.farm import FARMClassifier
68 haystack/classifier/base.py Normal file
@@ -0,0 +1,68 @@
import logging
from abc import abstractmethod
from copy import deepcopy
from typing import List, Optional
from functools import wraps
from time import perf_counter

from tqdm import tqdm

from haystack import Document, BaseComponent

logger = logging.getLogger(__name__)


class BaseClassifier(BaseComponent):
    return_no_answers: bool
    outgoing_edges = 1
    query_count = 0
    query_time = 0

    @abstractmethod
    def predict(self, query: str, documents: List[Document], top_k: Optional[int] = None):
        pass

    @abstractmethod
    def predict_batch(self, query_doc_list: List[dict], top_k: Optional[int] = None, batch_size: Optional[int] = None):
        pass

    def run(self, query: str, documents: List[Document], top_k: Optional[int] = None, **kwargs):  # type: ignore
        self.query_count += 1
        if documents:
            # time the predict() call and accumulate the duration in self.query_time
            predict = self.timing(self.predict, "query_time")
            results = predict(query=query, documents=documents, top_k=top_k)
        else:
            results = []

        document_ids = [doc.id for doc in results]
        logger.debug(f"Retrieved documents with IDs: {document_ids}")
        output = {
            "query": query,
            "documents": results,
            **kwargs
        }

        return output, "output_1"

    def timing(self, fn, attr_name):
        """Wrapper method used to time functions."""
        @wraps(fn)
        def wrapper(*args, **kwargs):
            if attr_name not in self.__dict__:
                self.__dict__[attr_name] = 0
            tic = perf_counter()
            ret = fn(*args, **kwargs)
            toc = perf_counter()
            self.__dict__[attr_name] += toc - tic
            return ret
        return wrapper

    def print_time(self):
        print("Classifier (Speed)")
        print("---------------")
        if not self.query_count:
            print("No querying performed via Classifier.run()")
        else:
            print(f"Queries Performed: {self.query_count}")
            print(f"Query time: {self.query_time}s")
            print(f"{self.query_time / self.query_count} seconds per query")
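For orientation, a minimal concrete subclass might look like the sketch below. It is not part of this commit: the KeywordClassifier class and its keyword rule are invented for illustration, and a production node would also call self.set_config() so its configuration can be exported as YAML.

# Sketch only -- not part of this commit. A concrete classifier just has to
# implement predict() (and predict_batch()); run() and timing come from the base class.
from typing import List, Optional

from haystack import Document
from haystack.classifier.base import BaseClassifier


class KeywordClassifier(BaseClassifier):
    """Toy classifier: labels each document by the presence of a keyword (invented example)."""

    def __init__(self, keyword: str = "Bahn"):
        self.keyword = keyword

    def predict(self, query: str, documents: List[Document], top_k: Optional[int] = None) -> List[Document]:
        for doc in documents:
            label = "match" if self.keyword in doc.text else "no_match"
            # same contract as FARMClassifier: the classification result goes into doc.meta
            doc.meta["classification"] = {"label": label, "probability": 1.0}
        return documents[:top_k] if top_k else documents

    def predict_batch(self, query_doc_list: List[dict], top_k: Optional[int] = None, batch_size: Optional[int] = None):
        raise NotImplementedError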
285 haystack/classifier/farm.py Normal file
@@ -0,0 +1,285 @@
import logging
import multiprocessing
from pathlib import Path
from typing import List, Optional, Union

from farm.data_handler.data_silo import DataSilo
from farm.data_handler.processor import TextClassificationProcessor
from farm.infer import Inferencer
from farm.modeling.optimization import initialize_optimizer
from farm.train import Trainer
from farm.utils import set_all_seeds, initialize_device_settings

from haystack import Document
from haystack.classifier.base import BaseClassifier

logger = logging.getLogger(__name__)


class FARMClassifier(BaseClassifier):
    """
    This node classifies documents and adds the output from the classification step to the document's meta data.
    The meta field of the document is a dictionary with the following format:
    'meta': {'name': '450_Baelor.txt', 'classification': {'label': 'neutral', 'probability': 0.9997646, ...}}

    | With a FARMClassifier, you can:
     - directly get predictions via predict()
     - fine-tune the model on text classification training data via train()

    Usage example:
    ...
    retriever = ElasticsearchRetriever(document_store=document_store)
    classifier = FARMClassifier(model_name_or_path="deepset/bert-base-german-cased-sentiment-Germeval17")
    p = Pipeline()
    p.add_node(component=retriever, name="ESRetriever", inputs=["Query"])
    p.add_node(component=classifier, name="Classifier", inputs=["ESRetriever"])

    res = p.run(
        query="Who is the father of Arya Stark?",
        top_k_retriever=10
    )

    print(res["documents"][0].to_dict()["meta"]["classification"]["label"])
    # Note that print_documents() does not output the content of the classification field in the meta data
    # document_dicts = [doc.to_dict() for doc in res["documents"]]
    # res["documents"] = document_dicts
    # print_documents(res, max_text_len=100)
    """

    def __init__(
        self,
        model_name_or_path: Union[str, Path],
        model_version: Optional[str] = None,
        batch_size: int = 50,
        use_gpu: bool = True,
        top_k: int = 10,
        num_processes: Optional[int] = None,
        max_seq_len: int = 256,
        progress_bar: bool = True
    ):

        """
        :param model_name_or_path: Directory of a saved model or the name of a public model e.g. 'deepset/bert-base-german-cased-sentiment-Germeval17'.
                                   See https://huggingface.co/models for a full list of available models.
        :param model_version: The version of the model to use from the HuggingFace model hub. Can be a tag name, branch name, or commit hash.
        :param batch_size: Number of samples the model receives in one batch for inference.
                           Memory consumption is much lower in inference mode. Recommendation: Increase the batch size
                           to a value so only a single batch is used.
        :param use_gpu: Whether to use GPU (if available)
        :param top_k: The maximum number of documents to return
        :param num_processes: The number of processes for `multiprocessing.Pool`. Set to a value of 0 to disable
                              multiprocessing. Set to None to let the Inferencer determine the optimum number. If you
                              want to debug the Language Model, you might need to disable multiprocessing!
        :param max_seq_len: Max sequence length of one input text for the model
        :param progress_bar: Whether to show a tqdm progress bar or not.
                             Can be helpful to disable in production deployments to keep the logs clean.
        """

        # save init parameters to enable export of component config as YAML
        self.set_config(
            model_name_or_path=model_name_or_path, model_version=model_version,
            batch_size=batch_size, use_gpu=use_gpu, top_k=top_k,
            num_processes=num_processes, max_seq_len=max_seq_len, progress_bar=progress_bar,
        )

        self.top_k = top_k

        self.inferencer = Inferencer.load(model_name_or_path, batch_size=batch_size, gpu=use_gpu,
                                          task_type="text_classification", max_seq_len=max_seq_len,
                                          num_processes=num_processes, revision=model_version,
                                          disable_tqdm=not progress_bar,
                                          strict=False)

        self.max_seq_len = max_seq_len
        self.use_gpu = use_gpu
        self.progress_bar = progress_bar

    def train(
        self,
        data_dir: str,
        train_filename: str,
        label_list: List[str],
        delimiter: str,
        metric: str,
        dev_filename: Optional[str] = None,
        test_filename: Optional[str] = None,
        use_gpu: Optional[bool] = None,
        batch_size: int = 10,
        n_epochs: int = 2,
        learning_rate: float = 1e-5,
        max_seq_len: Optional[int] = None,
        warmup_proportion: float = 0.2,
        dev_split: float = 0,
        evaluate_every: int = 300,
        save_dir: Optional[str] = None,
        num_processes: Optional[int] = None,
        use_amp: str = None,
    ):
        """
        Fine-tune a model on a TextClassification dataset.
        The dataset needs to be in tabular format (CSV, TSV, etc.), with columns called "label" and "text" in no specific order.
        Options:

        - Take a plain language model (e.g. `bert-base-cased`) and train it for TextClassification
        - Take a TextClassification model and fine-tune it for your domain

        :param data_dir: Path to directory containing your training data
        :param label_list: List of labels in the training dataset, e.g., ["0", "1"]
        :param delimiter: Delimiter that separates columns in the training dataset, e.g., "\t"
        :param metric: Evaluation metric to be used while training, e.g., "f1_macro"
        :param train_filename: Filename of training data
        :param dev_filename: Filename of dev / eval data
        :param test_filename: Filename of test data
        :param dev_split: Instead of specifying a dev_filename, you can also specify a ratio (e.g. 0.1) here
                          that gets split off from training data for eval.
        :param use_gpu: Whether to use GPU (if available)
        :param batch_size: Number of samples the model receives in one batch for training
        :param n_epochs: Number of iterations on the whole training data set
        :param learning_rate: Learning rate of the optimizer
        :param max_seq_len: Maximum text length (in tokens). Everything longer gets cut down.
        :param warmup_proportion: Proportion of training steps until maximum learning rate is reached.
                                  Until that point the LR is increasing linearly. After that it's decreasing again linearly.
                                  Options for different schedules are available in FARM.
        :param evaluate_every: Evaluate the model every X steps on the hold-out eval dataset
        :param save_dir: Path to store the final model
        :param num_processes: The number of processes for `multiprocessing.Pool` during preprocessing.
                              Set to a value of 1 to disable multiprocessing. When set to 1, you cannot split away a dev set from the train set.
                              Set to None to use all CPU cores minus one.
        :param use_amp: Optimization level of NVIDIA's automatic mixed precision (AMP). The higher the level, the faster the model.
                        Available options:
                        None (Don't use AMP)
                        "O0" (Normal FP32 training)
                        "O1" (Mixed Precision => Recommended)
                        "O2" (Almost FP16)
                        "O3" (Pure FP16).
                        See details on: https://nvidia.github.io/apex/amp.html
        :return: None
        """

        if dev_filename:
            dev_split = 0

        if num_processes is None:
            num_processes = multiprocessing.cpu_count() - 1 or 1

        set_all_seeds(seed=42)

        # For these variables, by default, we use the value set when initializing the FARMClassifier.
        # These can also be manually set when train() is called if you want a different value at train vs inference
        if use_gpu is None:
            use_gpu = self.use_gpu
        if max_seq_len is None:
            max_seq_len = self.max_seq_len

        device, n_gpu = initialize_device_settings(use_cuda=use_gpu, use_amp=use_amp)

        if not save_dir:
            save_dir = f"saved_models/{self.inferencer.model.language_model.name}"

        # 1. Create a DataProcessor that handles all the conversion from raw text into a pytorch Dataset
        processor = TextClassificationProcessor(
            tokenizer=self.inferencer.processor.tokenizer,
            max_seq_len=max_seq_len,
            label_list=label_list,
            metric=metric,
            train_filename=train_filename,
            dev_filename=dev_filename,
            dev_split=dev_split,
            test_filename=test_filename,
            data_dir=Path(data_dir),
            delimiter=delimiter,
        )

        # 2. Create a DataSilo that loads several datasets (train/dev/test), provides DataLoaders for them
        # and calculates a few descriptive statistics of our datasets
        data_silo = DataSilo(processor=processor, batch_size=batch_size, distributed=False, max_processes=num_processes)

        # 3. Create an optimizer and pass the already initialized model
        model, optimizer, lr_schedule = initialize_optimizer(
            model=self.inferencer.model,
            learning_rate=learning_rate,
            schedule_opts={"name": "LinearWarmup", "warmup_proportion": warmup_proportion},
            n_batches=len(data_silo.loaders["train"]),
            n_epochs=n_epochs,
            device=device,
            use_amp=use_amp,
        )
        # 4. Feed everything to the Trainer, which takes care of growing our model and evaluating it from time to time
        trainer = Trainer(
            model=model,
            optimizer=optimizer,
            data_silo=data_silo,
            epochs=n_epochs,
            n_gpu=n_gpu,
            lr_schedule=lr_schedule,
            evaluate_every=evaluate_every,
            device=device,
            use_amp=use_amp,
            disable_tqdm=not self.progress_bar
        )

        # 5. Let it grow!
        self.inferencer.model = trainer.train()
        self.save(Path(save_dir))

    def update_parameters(
        self,
        max_seq_len: Optional[int] = None,
    ):
        """
        Hot update parameters of a loaded FARMClassifier. It may not be safe when processing concurrent requests.
        """
        if max_seq_len is not None:
            self.inferencer.processor.max_seq_len = max_seq_len
            self.max_seq_len = max_seq_len

    def save(self, directory: Path):
        """
        Saves the FARMClassifier model so that it can be reused at a later point in time.

        :param directory: Directory where the FARMClassifier model should be saved
        """
        logger.info(f"Saving classifier model to {directory}")
        self.inferencer.model.save(directory)
        self.inferencer.processor.save(directory)

    def predict_batch(self, query_doc_list: List[dict], top_k: int = None, batch_size: int = None):
        """
        Use the loaded FARMClassifier model to classify, for each query in a list, that query's supplied list of Documents.

        Returns a list of dictionaries, each containing a query and its list of Documents sorted by (descending) similarity to the query.

        :param query_doc_list: List of dictionaries containing queries with their retrieved documents
        :param top_k: The maximum number of answers to return for each query
        :param batch_size: Number of samples the model receives in one batch for inference
        :return: List of dictionaries containing query and list of Documents with class probabilities in the meta field
        """
        raise NotImplementedError

    def predict(self, query: str, documents: List[Document], top_k: Optional[int] = None) -> List[Document]:
        """
        Use the loaded classification model to classify the supplied list of Documents.

        Returns the list of Documents enriched with class label and probability, which are stored in Document.meta["classification"].

        :param query: Query string (is not used at the moment)
        :param documents: List of Documents to be classified
        :param top_k: The maximum number of documents to return
        :return: List of Documents with class probabilities in the meta field
        """
        if top_k is None:
            top_k = self.top_k

        # each document should follow the structure {"text": "Schartau sagte dem Tagesspiegel, dass Fischer ein ... sei"}
        docs = [{"text": doc.text} for doc in documents]
        results = self.inferencer.inference_from_dicts(dicts=docs)[0]["predictions"]

        classified_docs: List[Document] = []

        for result, doc in zip(results, documents):
            cur_doc = doc
            cur_doc.meta["classification"] = result
            classified_docs.append(cur_doc)

        return classified_docs[:top_k]
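To make the train() data contract concrete, the following hedged sketch shows a possible fine-tuning call. It is not part of this commit: the data directory, file name, label values, epoch count, and save path are all invented; any delimiter-separated file with "text" and "label" columns in the format expected by TextClassificationProcessor should work.

# Sketch only -- paths, file names, and labels below are invented for illustration.
from haystack.classifier import FARMClassifier

# start from a plain language model or an existing text classification model
classifier = FARMClassifier(model_name_or_path="deepset/bert-base-german-cased-sentiment-Germeval17")

# data/train.tsv is assumed to be tab-separated with a header row, e.g.:
#   text<TAB>label
#   Ich liebe die Bahn.<TAB>positive
#   Die Bahn ist schon wieder zu spaet.<TAB>negative
classifier.train(
    data_dir="data",                      # invented directory
    train_filename="train.tsv",           # invented file name
    label_list=["negative", "neutral", "positive"],
    delimiter="\t",
    metric="f1_macro",
    n_epochs=1,
    save_dir="saved_models/my_classifier",
)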
haystack/ranker/farm.py
@@ -112,7 +112,7 @@ class FARMRanker(BaseRanker):
         - Take a plain language model (e.g. `bert-base-cased`) and train it for TextPairClassification
         - Take a TextPairClassification model and fine-tune it for your domain

-        :param data_dir: Path to directory containing your training data in SQuAD style
+        :param data_dir: Path to directory containing your training data
         :param train_filename: Filename of training data
         :param dev_filename: Filename of dev / eval data
         :param test_filename: Filename of test data
test/conftest.py
@@ -7,6 +7,7 @@ import pytest
 import requests
 from elasticsearch import Elasticsearch

+from haystack.classifier import FARMClassifier
 from haystack.generator.transformers import Seq2SeqGenerator
 from haystack.knowledge_graph.graphdb import GraphDBKnowledgeGraph
 from milvus import Milvus
@@ -299,6 +300,14 @@ def ranker(request):
     )


+@pytest.fixture(params=["farm"], scope="module")
+def classifier(request):
+    if request.param == "farm":
+        return FARMClassifier(
+            model_name_or_path="deepset/bert-base-german-cased-sentiment-Germeval17"
+        )
+
+
 # TODO Fix bug in test_no_answer_output when using
 # @pytest.fixture(params=["farm", "transformers"])
 @pytest.fixture(params=["farm"], scope="module")
24 test/test_classifier.py Normal file
@@ -0,0 +1,24 @@
from haystack import Document
from haystack.classifier.base import BaseClassifier


def test_classifier(classifier):
    assert isinstance(classifier, BaseClassifier)

    query = "not used at the moment"
    docs = [
        Document(
            text="""Fragen und Antworten - Bitte auf Themen beschränken welche einen Bezug zur Bahn aufweisen. Persönliche Unterhaltungen bitte per PN führen. Links bitte mit kurzer Erklärung zum verlinkten Inhalt versehen""",
            meta={"name": "0"},
            id="1",
        ),
        Document(
            text="""Ich liebe es wenn die Bahn selbstverschuldete unnötig lange Aufenthaltszeiten durch Verspätung wieder rausfährt.""",
            meta={"name": "1"},
            id="2",
        ),
    ]
    results = classifier.predict(query=query, documents=docs)
    expected_labels = ["neutral", "negative"]
    for i, doc in enumerate(results):
        assert doc.to_dict()["meta"]["classification"]["label"] == expected_labels[i]