fix pip backtracking issue (#2281)

* fix pip backtracking issue
* restrict azure-core version
* Remove the trailing comma
* Add skip_magic_trailing_comma in pyproject.toml for pydoc compatibility
* Pin pydoc-markdown _again_

Co-authored-by: Sara Zan <sarazanzo94@gmail.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

parent 5951fc463e
commit dde9d59271
.github/utils/generate_json_schema.py (vendored): 12 changed lines
@@ -107,11 +107,7 @@ def get_json_schema():
             if param.default != param.empty:
                 default = param.default
             param_fields_kwargs[param.name] = (annotation, default)
-        model = create_model(
-            f"{node.__name__}ComponentParams",
-            __config__=Config,
-            **param_fields_kwargs,
-        )
+        model = create_model(f"{node.__name__}ComponentParams", __config__=Config, **param_fields_kwargs)
         model.update_forward_refs(**model.__dict__)
         params_schema = model.schema()
         params_schema["title"] = "Parameters"
@@ -172,11 +168,7 @@ def get_json_schema():
                 "items": {
                     "type": "object",
                     "properties": {
-                        "name": {
-                            "title": "Name",
-                            "description": "Name of the pipeline.",
-                            "type": "string",
-                        },
+                        "name": {"title": "Name", "description": "Name of the pipeline.", "type": "string"},
                         "nodes": {
                             "title": "Nodes",
                             "description": "Nodes to be used by this particular pipeline",
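For reference, the script diffed above builds a pydantic model per node from its `__init__` signature and reads the JSON schema off it. A standalone sketch of that mechanism, with a hypothetical parameter dict standing in for the introspected signature (pydantic v1, which Haystack used at the time):

```python
from typing import Optional
from pydantic import create_model

# Hypothetical stand-in for param_fields_kwargs built from a node's __init__ signature
param_fields_kwargs = {"top_k": (int, 10), "model_name_or_path": (Optional[str], None)}
Params = create_model("MyReaderComponentParams", **param_fields_kwargs)
print(Params.schema()["properties"]["top_k"])  # -> {'title': 'Top K', 'default': 10, 'type': 'integer'}
```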
@@ -71,11 +71,7 @@ html_logo = "img/logo.png"
 html_additional_pages = {"index": "pages/index.html"}
 
 # The file extensions of source files.
-source_suffix = {
-    ".rst": "restructuredtext",
-    ".txt": "restructuredtext",
-    ".md": "markdown",
-}
+source_suffix = {".rst": "restructuredtext", ".txt": "restructuredtext", ".md": "markdown"}
 
 # -- Add autodocs for __init__() methods -------------------------------------
 
@@ -131,8 +131,7 @@ from haystack.pipelines import ExtractiveQAPipeline
 # Prebuilt pipeline
 p_extractive_premade = ExtractiveQAPipeline(reader=reader, retriever=es_retriever)
 res = p_extractive_premade.run(
-    query="Who is the father of Arya Stark?",
-    params={"Retriever": {"top_k": 10}, "Reader": {"top_k": 5}},
+    query="Who is the father of Arya Stark?", params={"Retriever": {"top_k": 10}, "Reader": {"top_k": 5}}
 )
 print_answers(res, details="minimum")
 ```
@@ -144,10 +143,7 @@ If you want to just do the retrieval step, you can use a `DocumentSearchPipeline`
 from haystack.pipelines import DocumentSearchPipeline
 
 p_retrieval = DocumentSearchPipeline(es_retriever)
-res = p_retrieval.run(
-    query="Who is the father of Arya Stark?",
-    params={"Retriever": {"top_k": 10}},
-)
+res = p_retrieval.run(query="Who is the father of Arya Stark?", params={"Retriever": {"top_k": 10}})
 print_documents(res, max_text_len=200)
 ```
 
@@ -143,7 +143,7 @@ retriever = ElasticsearchRetriever(document_store=document_store)
 
 # from haystack.retriever import EmbeddingRetriever, DensePassageRetriever
 # retriever = EmbeddingRetriever(document_store=document_store, model_format="sentence_transformers",
-#                                embedding_model="sentence-transformers/multi-qa-mpnet-base-dot-v1")
+#                                embedding_model="sentence-transformers/multi-qa-mpnet-base-dot-v1")
 # retriever = DensePassageRetriever(document_store=document_store,
 #                                   query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
 #                                   passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
@@ -606,7 +606,7 @@ set the `address` parameter when creating the RayPipeline instance.
 
 ```python
 @classmethod
-def load_from_yaml(cls, path: Path, pipeline_name: Optional[str] = None, overwrite_with_env_variables: bool = True, address: Optional[str] = None, **kwargs, ,)
+def load_from_yaml(cls, path: Path, pipeline_name: Optional[str] = None, overwrite_with_env_variables: bool = True, address: Optional[str] = None, **kwargs)
 ```
 
 Load Pipeline from a YAML file defining the individual components and how they're tied together to form
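For orientation, a minimal usage sketch of the signature documented above; the YAML path and pipeline name are placeholders, and `address` points at an already running Ray cluster:

```python
from pathlib import Path
from haystack.pipelines import RayPipeline

# Hypothetical YAML file and pipeline name; adjust to your own configuration.
pipeline = RayPipeline.load_from_yaml(
    path=Path("rest_api/pipeline/pipelines.haystack-pipeline.yml"),
    pipeline_name="ray_query_pipeline",
    address="auto",  # connect to an existing Ray cluster; omit to start a local one
)
res = pipeline.run(query="Who is the father of Arya Stark?", params={"Retriever": {"top_k": 10}})
```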
@@ -383,15 +383,7 @@ class ElasticsearchDocumentStore(KeywordDocumentStore):
                         {"strings": {"path_match": "*", "match_mapping_type": "string", "mapping": {"type": "keyword"}}}
                     ],
                 },
-                "settings": {
-                    "analysis": {
-                        "analyzer": {
-                            "default": {
-                                "type": self.analyzer,
-                            }
-                        }
-                    }
-                },
+                "settings": {"analysis": {"analyzer": {"default": {"type": self.analyzer}}}},
             }
 
             if self.synonyms:
@@ -538,15 +530,7 @@ class ElasticsearchDocumentStore(KeywordDocumentStore):
         if query:
             body["query"] = {
                 "bool": {
-                    "should": [
-                        {
-                            "multi_match": {
-                                "query": query,
-                                "type": "most_fields",
-                                "fields": self.search_fields,
-                            }
-                        }
-                    ]
+                    "should": [{"multi_match": {"query": query, "type": "most_fields", "fields": self.search_fields}}]
                 }
             }
         if filters:
@@ -1291,10 +1275,7 @@ class ElasticsearchDocumentStore(KeywordDocumentStore):
         return query
 
     def _convert_es_hit_to_document(
-        self,
-        hit: dict,
-        return_embedding: bool,
-        adapt_score_for_embedding: bool = False,
+        self, hit: dict, return_embedding: bool, adapt_score_for_embedding: bool = False
     ) -> Document:
         # We put all additional data of the doc into meta_data and return it in the API
         meta_data = {
@@ -1799,10 +1780,7 @@ class OpenSearchDocumentStore(ElasticsearchDocumentStore):
         if not self.embedding_field:
             raise RuntimeError("Please specify arg `embedding_field` in ElasticsearchDocumentStore()")
         # +1 in similarity to avoid negative numbers (for cosine sim)
-        body: Dict[str, Any] = {
-            "size": top_k,
-            "query": self._get_vector_similarity_query(query_emb, top_k),
-        }
+        body: Dict[str, Any] = {"size": top_k, "query": self._get_vector_similarity_query(query_emb, top_k)}
         if filters:
             body["query"]["bool"]["filter"] = LogicalFilterClause.parse(filters).convert_to_elasticsearch()
 
@@ -1915,15 +1893,7 @@ class OpenSearchDocumentStore(ElasticsearchDocumentStore):
                         {"strings": {"path_match": "*", "match_mapping_type": "string", "mapping": {"type": "keyword"}}}
                     ],
                 },
-                "settings": {
-                    "analysis": {
-                        "analyzer": {
-                            "default": {
-                                "type": self.analyzer,
-                            }
-                        }
-                    }
-                },
+                "settings": {"analysis": {"analyzer": {"default": {"type": self.analyzer}}}},
             }
 
             if self.synonyms:
@@ -299,9 +299,7 @@ class FAISSDocumentStore(SQLDocumentStore):
         progress_bar.close()
 
     def _create_document_field_map(self) -> Dict:
-        return {
-            self.index: self.embedding_field,
-        }
+        return {self.index: self.embedding_field}
 
     def update_embeddings(
         self,
@@ -150,9 +150,7 @@ class InMemoryDocumentStore(BaseDocumentStore):
             self.indexes[index][document.id] = document
 
     def _create_document_field_map(self):
-        return {
-            self.embedding_field: "embedding",
-        }
+        return {self.embedding_field: "embedding"}
 
     def write_labels(
         self,
@@ -161,10 +161,7 @@ class Milvus1DocumentStore(SQLDocumentStore):
         self.progress_bar = progress_bar
 
         super().__init__(
-            url=sql_url,
-            index=index,
-            duplicate_documents=duplicate_documents,
-            isolation_level=isolation_level,
+            url=sql_url, index=index, duplicate_documents=duplicate_documents, isolation_level=isolation_level
         )
 
     def __del__(self):
@@ -194,9 +191,7 @@ class Milvus1DocumentStore(SQLDocumentStore):
             raise RuntimeError(f"Index creation on Milvus server failed: {status}")
 
     def _create_document_field_map(self) -> Dict:
-        return {
-            self.index: self.embedding_field,
-        }
+        return {self.index: self.embedding_field}
 
     def write_documents(
         self,
@@ -193,17 +193,11 @@ class Milvus2DocumentStore(SQLDocumentStore):
         self.progress_bar = progress_bar
 
         super().__init__(
-            url=sql_url,
-            index=index,
-            duplicate_documents=duplicate_documents,
-            isolation_level=isolation_level,
+            url=sql_url, index=index, duplicate_documents=duplicate_documents, isolation_level=isolation_level
         )
 
     def _create_collection_and_index_if_not_exist(
-        self,
-        index: Optional[str] = None,
-        consistency_level: int = 0,
-        index_param: Optional[Dict[str, Any]] = None,
+        self, index: Optional[str] = None, consistency_level: int = 0, index_param: Optional[Dict[str, Any]] = None
     ):
         index = index or self.index
         index_param = index_param or self.index_param
@@ -240,9 +234,7 @@ class Milvus2DocumentStore(SQLDocumentStore):
         return collection
 
     def _create_document_field_map(self) -> Dict:
-        return {
-            self.index: self.embedding_field,
-        }
+        return {self.index: self.embedding_field}
 
     def write_documents(
         self,
@@ -626,8 +618,7 @@ class Milvus2DocumentStore(SQLDocumentStore):
             vector_id_map[int(vector_id)] = doc
 
         search_result: QueryResult = self.collection.query(
-            expr=f'{self.id_field} in [ {",".join(ids)} ]',
-            output_fields=[self.embedding_field],
+            expr=f'{self.id_field} in [ {",".join(ids)} ]', output_fields=[self.embedding_field]
         )
 
         for result in search_result:
@ -257,11 +257,7 @@ class SQLDocumentStore(BaseDocumentStore):
|
||||
|
||||
if return_embedding is True:
|
||||
raise Exception("return_embeddings is not supported by SQLDocumentStore.")
|
||||
result = self._query(
|
||||
index=index,
|
||||
filters=filters,
|
||||
batch_size=batch_size,
|
||||
)
|
||||
result = self._query(index=index, filters=filters, batch_size=batch_size)
|
||||
yield from result
|
||||
|
||||
def _create_document_field_map(self) -> Dict:
|
||||
@ -482,13 +478,7 @@ class SQLDocumentStore(BaseDocumentStore):
|
||||
index = index or self.index
|
||||
for chunk_map in self.chunked_dict(vector_id_map, size=batch_size):
|
||||
self.session.query(DocumentORM).filter(DocumentORM.id.in_(chunk_map), DocumentORM.index == index).update(
|
||||
{
|
||||
DocumentORM.vector_id: case(
|
||||
chunk_map,
|
||||
value=DocumentORM.id,
|
||||
)
|
||||
},
|
||||
synchronize_session=False,
|
||||
{DocumentORM.vector_id: case(chunk_map, value=DocumentORM.id)}, synchronize_session=False
|
||||
)
|
||||
try:
|
||||
self.session.commit()
|
||||
@ -538,8 +528,7 @@ class SQLDocumentStore(BaseDocumentStore):
|
||||
if filters:
|
||||
for key, values in filters.items():
|
||||
query = query.join(MetaDocumentORM, aliased=True).filter(
|
||||
MetaDocumentORM.name == key,
|
||||
MetaDocumentORM.value.in_(values),
|
||||
MetaDocumentORM.name == key, MetaDocumentORM.value.in_(values)
|
||||
)
|
||||
|
||||
if only_documents_without_embedding:
|
||||
@ -658,8 +647,7 @@ class SQLDocumentStore(BaseDocumentStore):
|
||||
if filters:
|
||||
for key, values in filters.items():
|
||||
document_ids_to_delete = document_ids_to_delete.join(MetaDocumentORM, aliased=True).filter(
|
||||
MetaDocumentORM.name == key,
|
||||
MetaDocumentORM.value.in_(values),
|
||||
MetaDocumentORM.name == key, MetaDocumentORM.value.in_(values)
|
||||
)
|
||||
if ids:
|
||||
document_ids_to_delete = document_ids_to_delete.filter(DocumentORM.id.in_(ids))
|
||||
|
||||
@ -171,10 +171,7 @@ class WeaviateDocumentStore(BaseDocumentStore):
|
||||
else:
|
||||
return index[0].upper() + index[1:]
|
||||
|
||||
def _create_schema_and_index_if_not_exist(
|
||||
self,
|
||||
index: Optional[str] = None,
|
||||
):
|
||||
def _create_schema_and_index_if_not_exist(self, index: Optional[str] = None):
|
||||
"""
|
||||
Create a new index (schema/class in Weaviate) for storing documents in case if an
|
||||
index (schema) with the name doesn't exist already.
|
||||
@ -1035,11 +1032,7 @@ class WeaviateDocumentStore(BaseDocumentStore):
|
||||
"All the documents in Weaviate store have an embedding by default. Only update is allowed!"
|
||||
)
|
||||
|
||||
result = self._get_all_documents_in_index(
|
||||
index=index,
|
||||
filters=filters,
|
||||
batch_size=batch_size,
|
||||
)
|
||||
result = self._get_all_documents_in_index(index=index, filters=filters, batch_size=batch_size)
|
||||
|
||||
for result_batch in get_batches_from_generator(result, batch_size):
|
||||
document_batch = [
|
||||
|
||||
@ -140,9 +140,7 @@ class DataSilo:
|
||||
|
||||
num_dicts = len(dicts)
|
||||
multiprocessing_chunk_size, num_cpus_used = calc_chunksize(
|
||||
num_dicts=num_dicts,
|
||||
max_processes=self.max_processes,
|
||||
max_chunksize=self.max_multiprocessing_chunksize,
|
||||
num_dicts=num_dicts, max_processes=self.max_processes, max_chunksize=self.max_multiprocessing_chunksize
|
||||
)
|
||||
|
||||
with ExitStack() as stack:
|
||||
@ -383,11 +381,7 @@ class DataSilo:
|
||||
else:
|
||||
data_loader_test = None
|
||||
|
||||
self.loaders = {
|
||||
"train": data_loader_train,
|
||||
"dev": data_loader_dev,
|
||||
"test": data_loader_test,
|
||||
}
|
||||
self.loaders = {"train": data_loader_train, "dev": data_loader_dev, "test": data_loader_test}
|
||||
|
||||
def _create_dev_from_train(self):
|
||||
"""
|
||||
@ -594,10 +588,7 @@ class DataSiloForCrossVal:
|
||||
sampler_train = RandomSampler(trainset)
|
||||
|
||||
self.data_loader_train = NamedDataLoader(
|
||||
dataset=trainset,
|
||||
sampler=sampler_train,
|
||||
batch_size=self.batch_size,
|
||||
tensor_names=self.tensor_names,
|
||||
dataset=trainset, sampler=sampler_train, batch_size=self.batch_size, tensor_names=self.tensor_names
|
||||
)
|
||||
self.data_loader_dev = NamedDataLoader(
|
||||
dataset=devset,
|
||||
@ -611,11 +602,7 @@ class DataSiloForCrossVal:
|
||||
batch_size=self.batch_size,
|
||||
tensor_names=self.tensor_names,
|
||||
)
|
||||
self.loaders = {
|
||||
"train": self.data_loader_train,
|
||||
"dev": self.data_loader_dev,
|
||||
"test": self.data_loader_test,
|
||||
}
|
||||
self.loaders = {"train": self.data_loader_train, "dev": self.data_loader_dev, "test": self.data_loader_test}
|
||||
|
||||
def get_data_loader(self, which):
|
||||
return self.loaders[which]
|
||||
@ -694,11 +681,7 @@ class DataSiloForCrossVal:
|
||||
documents.append(list(document))
|
||||
|
||||
xval_split = cls._split_for_qa(
|
||||
documents=documents,
|
||||
id_index=id_index,
|
||||
n_splits=n_splits,
|
||||
shuffle=shuffle,
|
||||
random_state=random_state,
|
||||
documents=documents, id_index=id_index, n_splits=n_splits, shuffle=shuffle, random_state=random_state
|
||||
)
|
||||
silos = []
|
||||
|
||||
|
||||
@ -81,11 +81,7 @@ def sample_to_features_text(sample, tasks, max_seq_len, tokenizer):
|
||||
assert len(padding_mask) == max_seq_len
|
||||
assert len(segment_ids) == max_seq_len
|
||||
|
||||
feat_dict = {
|
||||
"input_ids": input_ids,
|
||||
"padding_mask": padding_mask,
|
||||
"segment_ids": segment_ids,
|
||||
}
|
||||
feat_dict = {"input_ids": input_ids, "padding_mask": padding_mask, "segment_ids": segment_ids}
|
||||
|
||||
# Add Labels for different tasks
|
||||
for task_name, task in tasks.items():
|
||||
|
||||
@ -357,10 +357,7 @@ class Processor(ABC):
|
||||
logger.debug(random_sample)
|
||||
|
||||
def _log_params(self):
|
||||
params = {
|
||||
"processor": self.__class__.__name__,
|
||||
"tokenizer": self.tokenizer.__class__.__name__,
|
||||
}
|
||||
params = {"processor": self.__class__.__name__, "tokenizer": self.tokenizer.__class__.__name__}
|
||||
names = ["max_seq_len", "dev_split"]
|
||||
for name in names:
|
||||
value = getattr(self, name)
|
||||
@ -1925,12 +1922,7 @@ class InferenceProcessor(TextClassificationProcessor):
|
||||
- Doesn't read from file, but only consumes dictionaries (e.g. coming from API requests)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer,
|
||||
max_seq_len,
|
||||
**kwargs,
|
||||
):
|
||||
def __init__(self, tokenizer, max_seq_len, **kwargs):
|
||||
|
||||
super(InferenceProcessor, self).__init__(
|
||||
tokenizer=tokenizer,
|
||||
@ -2020,10 +2012,7 @@ class InferenceProcessor(TextClassificationProcessor):
|
||||
# Private method to keep s3e pooling and embedding extraction working
|
||||
def _sample_to_features(self, sample: Sample) -> Dict:
|
||||
features = sample_to_features_text(
|
||||
sample=sample,
|
||||
tasks=self.tasks,
|
||||
max_seq_len=self.max_seq_len,
|
||||
tokenizer=self.tokenizer,
|
||||
sample=sample, tasks=self.tasks, max_seq_len=self.max_seq_len, tokenizer=self.tokenizer
|
||||
)
|
||||
return features
|
||||
|
||||
|
||||
@ -152,8 +152,7 @@ class Evaluator:
|
||||
if not metric_name in ["preds", "labels"] and not metric_name.startswith("_"):
|
||||
if isinstance(metric_val, numbers.Number):
|
||||
MlLogger.log_metrics(
|
||||
metrics={f"{dataset_name}_{metric_name}_{head['task_name']}": metric_val},
|
||||
step=steps,
|
||||
metrics={f"{dataset_name}_{metric_name}_{head['task_name']}": metric_val}, step=steps
|
||||
)
|
||||
# print via standard python logger
|
||||
if print:
|
||||
|
||||
@ -62,11 +62,7 @@ def f1_macro(preds, labels):
|
||||
def pearson_and_spearman(preds, labels):
|
||||
pearson_corr = pearsonr(preds, labels)[0]
|
||||
spearman_corr = spearmanr(preds, labels)[0]
|
||||
return {
|
||||
"pearson": pearson_corr,
|
||||
"spearman": spearman_corr,
|
||||
"corr": (pearson_corr + spearman_corr) / 2,
|
||||
}
|
||||
return {"pearson": pearson_corr, "spearman": spearman_corr, "corr": (pearson_corr + spearman_corr) / 2}
|
||||
|
||||
|
||||
def compute_metrics(metric: str, preds, labels):
|
||||
|
||||
@ -301,9 +301,7 @@ class Inferencer:
|
||||
"""
|
||||
dicts = self.processor.file_to_dicts(file)
|
||||
preds_all = self.inference_from_dicts(
|
||||
dicts,
|
||||
return_json=return_json,
|
||||
multiprocessing_chunksize=multiprocessing_chunksize,
|
||||
dicts, return_json=return_json, multiprocessing_chunksize=multiprocessing_chunksize
|
||||
)
|
||||
return list(preds_all)
|
||||
|
||||
@ -343,10 +341,7 @@ class Inferencer:
|
||||
multiprocessing_chunksize = _chunk_size
|
||||
|
||||
predictions = self._inference_with_multiprocessing(
|
||||
dicts,
|
||||
return_json,
|
||||
aggregate_preds,
|
||||
multiprocessing_chunksize,
|
||||
dicts, return_json, aggregate_preds, multiprocessing_chunksize
|
||||
)
|
||||
|
||||
self.processor.log_problematic(self.problematic_sample_ids)
|
||||
|
||||
@ -554,9 +554,7 @@ class AdaptiveModel(nn.Module, BaseAdaptiveModel):
|
||||
if language_model_class not in ["Bert", "Roberta", "XLMRoberta"]:
|
||||
raise Exception("The current ONNX conversion only support 'BERT', 'RoBERTa', and 'XLMRoberta' models.")
|
||||
|
||||
task_type_to_pipeline_map = {
|
||||
"question_answering": "question-answering",
|
||||
}
|
||||
task_type_to_pipeline_map = {"question_answering": "question-answering"}
|
||||
|
||||
convert(
|
||||
pipeline_name=task_type_to_pipeline_map[task_type],
|
||||
|
||||
@ -102,13 +102,7 @@ class LanguageModel(nn.Module):
|
||||
super().__init_subclass__(**kwargs)
|
||||
cls.subclasses[cls.__name__] = cls
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
segment_ids: torch.Tensor,
|
||||
padding_mask: torch.Tensor,
|
||||
**kwargs,
|
||||
):
|
||||
def forward(self, input_ids: torch.Tensor, segment_ids: torch.Tensor, padding_mask: torch.Tensor, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
@ -347,16 +341,7 @@ class LanguageModel(nn.Module):
|
||||
|
||||
@classmethod
|
||||
def _infer_language_from_name(cls, name):
|
||||
known_languages = (
|
||||
"german",
|
||||
"english",
|
||||
"chinese",
|
||||
"indian",
|
||||
"french",
|
||||
"polish",
|
||||
"spanish",
|
||||
"multilingual",
|
||||
)
|
||||
known_languages = ("german", "english", "chinese", "indian", "french", "polish", "spanish", "multilingual")
|
||||
matches = [lang for lang in known_languages if lang in name]
|
||||
if "camembert" in name:
|
||||
language = "french"
|
||||
|
||||
@ -162,12 +162,7 @@ def initialize_optimizer(
|
||||
schedule_opts["num_training_steps"] = num_train_optimization_steps
|
||||
|
||||
# Log params
|
||||
MlLogger.log_params(
|
||||
{
|
||||
"use_amp": use_amp,
|
||||
"num_train_optimization_steps": schedule_opts["num_training_steps"],
|
||||
}
|
||||
)
|
||||
MlLogger.log_params({"use_amp": use_amp, "num_train_optimization_steps": schedule_opts["num_training_steps"]})
|
||||
|
||||
# Get optimizer from pytorch, transformers or apex
|
||||
optimizer = _get_optim(model, optimizer_opts)
|
||||
|
||||
@ -378,10 +378,7 @@ class Trainer:
|
||||
loss = self.adjust_loss(loss)
|
||||
if self.global_step % self.log_loss_every == 0 and self.local_rank in [-1, 0]:
|
||||
if self.local_rank in [-1, 0]:
|
||||
MlLogger.log_metrics(
|
||||
{"Train_loss_total": float(loss.detach().cpu().numpy())},
|
||||
step=self.global_step,
|
||||
)
|
||||
MlLogger.log_metrics({"Train_loss_total": float(loss.detach().cpu().numpy())}, step=self.global_step)
|
||||
if self.log_learning_rate:
|
||||
MlLogger.log_metrics({"learning_rate": self.lr_schedule.get_last_lr()[0]}, step=self.global_step)
|
||||
if self.use_amp:
|
||||
|
||||
@ -132,10 +132,7 @@ class AzureConverter(BaseConverter):
|
||||
return docs
|
||||
|
||||
def convert_azure_json(
|
||||
self,
|
||||
file_path: Path,
|
||||
meta: Optional[Dict[str, str]] = None,
|
||||
valid_languages: Optional[List[str]] = None,
|
||||
self, file_path: Path, meta: Optional[Dict[str, str]] = None, valid_languages: Optional[List[str]] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Extract text and tables from the JSON output of Azure's Form Recognizer service.
|
||||
|
||||
@ -14,11 +14,7 @@ class BaseConverter(BaseComponent):
|
||||
|
||||
outgoing_edges = 1
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
remove_numeric_tables: bool = False,
|
||||
valid_languages: Optional[List[str]] = None,
|
||||
):
|
||||
def __init__(self, remove_numeric_tables: bool = False, valid_languages: Optional[List[str]] = None):
|
||||
"""
|
||||
:param remove_numeric_tables: This option uses heuristics to remove numeric rows from the tables.
|
||||
The tabular structures in documents might be noise for the reader model if it
|
||||
|
||||
@ -20,11 +20,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ImageToTextConverter(BaseConverter):
|
||||
def __init__(
|
||||
self,
|
||||
remove_numeric_tables: bool = False,
|
||||
valid_languages: Optional[List[str]] = ["eng"],
|
||||
):
|
||||
def __init__(self, remove_numeric_tables: bool = False, valid_languages: Optional[List[str]] = ["eng"]):
|
||||
"""
|
||||
:param remove_numeric_tables: This option uses heuristics to remove numeric rows from the tables.
|
||||
The tabular structures in documents might be noise for the reader model if it
|
||||
|
||||
@ -20,11 +20,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PDFToTextConverter(BaseConverter):
|
||||
def __init__(
|
||||
self,
|
||||
remove_numeric_tables: bool = False,
|
||||
valid_languages: Optional[List[str]] = None,
|
||||
):
|
||||
def __init__(self, remove_numeric_tables: bool = False, valid_languages: Optional[List[str]] = None):
|
||||
"""
|
||||
:param remove_numeric_tables: This option uses heuristics to remove numeric rows from the tables.
|
||||
The tabular structures in documents might be noise for the reader model if it
|
||||
@ -156,11 +152,7 @@ class PDFToTextConverter(BaseConverter):
|
||||
|
||||
|
||||
class PDFToTextOCRConverter(BaseConverter):
|
||||
def __init__(
|
||||
self,
|
||||
remove_numeric_tables: bool = False,
|
||||
valid_languages: Optional[List[str]] = ["eng"],
|
||||
):
|
||||
def __init__(self, remove_numeric_tables: bool = False, valid_languages: Optional[List[str]] = ["eng"]):
|
||||
"""
|
||||
Extract text from image file using the pytesseract library (https://github.com/madmaze/pytesseract)
|
||||
|
||||
|
||||
@ -61,11 +61,7 @@ class TikaConverter(BaseConverter):
|
||||
"""
|
||||
|
||||
# save init parameters to enable export of component config as YAML
|
||||
self.set_config(
|
||||
tika_url=tika_url,
|
||||
remove_numeric_tables=remove_numeric_tables,
|
||||
valid_languages=valid_languages,
|
||||
)
|
||||
self.set_config(tika_url=tika_url, remove_numeric_tables=remove_numeric_tables, valid_languages=valid_languages)
|
||||
|
||||
ping = requests.get(tika_url)
|
||||
if ping.status_code != 200:
|
||||
|
||||
@ -35,12 +35,7 @@ class Docs2Answers(BaseComponent):
|
||||
else:
|
||||
# Regular docs
|
||||
cur_answer = Answer(
|
||||
answer="",
|
||||
type="other",
|
||||
score=doc.score,
|
||||
context=doc.content,
|
||||
document_id=doc.id,
|
||||
meta=doc.meta,
|
||||
answer="", type="other", score=doc.score, context=doc.content, document_id=doc.id, meta=doc.meta
|
||||
)
|
||||
answers.append(cur_answer)
|
||||
|
||||
|
||||
@ -23,11 +23,7 @@ class BasePreProcessor(BaseComponent):
|
||||
raise NotImplementedError
|
||||
|
||||
def clean(
|
||||
self,
|
||||
document: dict,
|
||||
clean_whitespace: bool,
|
||||
clean_header_footer: bool,
|
||||
clean_empty_lines: bool,
|
||||
self, document: dict, clean_whitespace: bool, clean_header_footer: bool, clean_empty_lines: bool
|
||||
) -> Dict[str, Any]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@ -181,13 +181,7 @@ class PreProcessor(BasePreProcessor):
|
||||
nested_docs = [self._process_single(d, **kwargs) for d in tqdm(documents, unit="docs")]
|
||||
return [d for x in nested_docs for d in x]
|
||||
|
||||
def clean(
|
||||
self,
|
||||
document: dict,
|
||||
clean_whitespace: bool,
|
||||
clean_header_footer: bool,
|
||||
clean_empty_lines: bool,
|
||||
) -> dict:
|
||||
def clean(self, document: dict, clean_whitespace: bool, clean_header_footer: bool, clean_empty_lines: bool) -> dict:
|
||||
"""
|
||||
Perform document cleaning on a single document and return a single document. This method will deal with whitespaces, headers, footers
|
||||
and empty lines. Its exact functionality is defined by the parameters passed into PreProcessor.__init__().
|
||||
|
||||
@ -54,11 +54,7 @@ class SentenceTransformersRanker(BaseRanker):
|
||||
"""
|
||||
|
||||
# save init parameters to enable export of component config as YAML
|
||||
self.set_config(
|
||||
model_name_or_path=model_name_or_path,
|
||||
model_version=model_version,
|
||||
top_k=top_k,
|
||||
)
|
||||
self.set_config(model_name_or_path=model_name_or_path, model_version=model_version, top_k=top_k)
|
||||
|
||||
self.top_k = top_k
|
||||
|
||||
|
||||
@ -922,10 +922,7 @@ class FARMReader(BaseReader):
|
||||
)
|
||||
continue
|
||||
aggregated_per_question[aggregation_key]["answers"].append(
|
||||
{
|
||||
"text": label.answer.answer,
|
||||
"answer_start": label.answer.offsets_in_document[0].start,
|
||||
}
|
||||
{"text": label.answer.answer, "answer_start": label.answer.offsets_in_document[0].start}
|
||||
)
|
||||
aggregated_per_question[aggregation_key]["is_impossible"] = False
|
||||
# create new one
|
||||
|
||||
@ -180,10 +180,7 @@ class TfidfRetriever(BaseRetriever):
|
||||
self.set_config(document_store=document_store, top_k=top_k, auto_fit=auto_fit)
|
||||
|
||||
self.vectorizer = TfidfVectorizer(
|
||||
lowercase=True,
|
||||
stop_words=None,
|
||||
token_pattern=r"(?u)\b\w\w+\b",
|
||||
ngram_range=(1, 1),
|
||||
lowercase=True, stop_words=None, token_pattern=r"(?u)\b\w\w+\b", ngram_range=(1, 1)
|
||||
)
|
||||
|
||||
self.document_store = document_store
|
||||
|
||||
@ -19,15 +19,8 @@ from haystack.nodes.evaluator.evaluator import (
|
||||
calculate_f1_str_multi,
|
||||
semantic_answer_similarity,
|
||||
)
|
||||
from haystack.pipelines.config import (
|
||||
get_component_definitions,
|
||||
get_pipeline_definition,
|
||||
read_pipeline_config_from_yaml,
|
||||
)
|
||||
from haystack.pipelines.utils import (
|
||||
generate_code,
|
||||
print_eval_report,
|
||||
)
|
||||
from haystack.pipelines.config import get_component_definitions, get_pipeline_definition, read_pipeline_config_from_yaml
|
||||
from haystack.pipelines.utils import generate_code, print_eval_report
|
||||
from haystack.utils import DeepsetCloud
|
||||
|
||||
try:
|
||||
@ -80,10 +73,7 @@ class BasePipeline:
|
||||
raise NotImplementedError
|
||||
|
||||
def to_code(
|
||||
self,
|
||||
pipeline_variable_name: str = "pipeline",
|
||||
generate_imports: bool = True,
|
||||
add_comment: bool = False,
|
||||
self, pipeline_variable_name: str = "pipeline", generate_imports: bool = True, add_comment: bool = False
|
||||
) -> str:
|
||||
"""
|
||||
Returns the code to create this pipeline as string.
|
||||
@ -105,10 +95,7 @@ class BasePipeline:
|
||||
return code
|
||||
|
||||
def to_notebook_cell(
|
||||
self,
|
||||
pipeline_variable_name: str = "pipeline",
|
||||
generate_imports: bool = True,
|
||||
add_comment: bool = True,
|
||||
self, pipeline_variable_name: str = "pipeline", generate_imports: bool = True, add_comment: bool = True
|
||||
):
|
||||
"""
|
||||
Creates a new notebook cell with the code to create this pipeline.
|
||||
@ -324,10 +311,7 @@ class BasePipeline:
|
||||
|
||||
@classmethod
|
||||
def list_pipelines_on_deepset_cloud(
|
||||
cls,
|
||||
workspace: str = "default",
|
||||
api_key: Optional[str] = None,
|
||||
api_endpoint: Optional[str] = None,
|
||||
cls, workspace: str = "default", api_key: Optional[str] = None, api_endpoint: Optional[str] = None
|
||||
) -> List[dict]:
|
||||
"""
|
||||
Lists all pipeline configs available on Deepset Cloud.
|
||||
@ -1088,12 +1072,7 @@ class Pipeline(BasePipeline):
|
||||
}
|
||||
return config
|
||||
|
||||
def _generate_component_name(
|
||||
self,
|
||||
type_name: str,
|
||||
params: Dict[str, Any],
|
||||
existing_components: Dict[str, Any],
|
||||
):
|
||||
def _generate_component_name(self, type_name: str, params: Dict[str, Any], existing_components: Dict[str, Any]):
|
||||
component_name: str = type_name
|
||||
# add number if there are multiple distinct ones of the same type
|
||||
while component_name in existing_components and params != existing_components[component_name]["params"]:
|
||||
|
||||
@ -54,7 +54,7 @@ def convert_files_to_dicts(
|
||||
if encoding is None and suffix == ".pdf":
|
||||
encoding = "Latin1"
|
||||
logger.info("Converting {}".format(path))
|
||||
document = suffix2converter[suffix].convert(file_path=path, meta=None, encoding=encoding,)[
|
||||
document = suffix2converter[suffix].convert(file_path=path, meta=None, encoding=encoding)[
|
||||
0
|
||||
] # PDFToTextConverter, TextConverter, and DocxToTextConverter return a list containing a single dict
|
||||
text = document["content"]
|
||||
|
||||
@ -334,13 +334,8 @@ if __name__ == "__main__":
|
||||
num_hard_negative_ctxs = args.num_hard_negative_ctxs
|
||||
split_dataset = args.split_dataset
|
||||
|
||||
retriever_dpr_config = {
|
||||
"use_gpu": True,
|
||||
}
|
||||
store_dpr_config = {
|
||||
"embedding_field": "embedding",
|
||||
"embedding_dim": 768,
|
||||
}
|
||||
retriever_dpr_config = {"use_gpu": True}
|
||||
store_dpr_config = {"embedding_field": "embedding", "embedding_dim": 768}
|
||||
|
||||
retriever_bm25_config: dict = {}
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ build-backend = "setuptools.build_meta"
 
 [tool.black]
 line-length = 120
+skip_magic_trailing_comma = true  # For compatibility with pydoc>=4.6, check if still needed.
 
 
 [tool.pylint.'MESSAGES CONTROL']
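For illustration only (not part of the diff): with `skip_magic_trailing_comma` enabled, a trailing comma no longer forces Black to keep a call exploded one argument per line, which is why so many call sites in this commit collapse back onto a single line within the 120-character limit. A minimal before/after sketch, using a call taken from the test changes in this commit:

```python
# Before: the magic trailing comma keeps the call exploded across lines.
eval_result = pipeline.eval(
    labels=labels,
    params={"Retriever": {"top_k": 5}},
)

# After (skip_magic_trailing_comma = true): Black collapses it, since it fits in 120 columns.
eval_result = pipeline.eval(labels=labels, params={"Retriever": {"top_k": 5}})
```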
@ -23,11 +23,7 @@ def get_application() -> FastAPI:
|
||||
# This middleware enables allow all cross-domain requests to the API from a browser. For production
|
||||
# deployments, it could be made more restrictive.
|
||||
application.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"]
|
||||
)
|
||||
application.add_exception_handler(HTTPException, http_error_handler)
|
||||
application.include_router(api_router)
|
||||
|
||||
@ -144,11 +144,7 @@ def test_file_upload_with_no_meta(client: TestClient):
|
||||
assert len(response.json()) == 0
|
||||
|
||||
file_to_upload = {"files": (Path(__file__).parent / "samples" / "pdf" / "sample_pdf_1.pdf").open("rb")}
|
||||
response = client.post(
|
||||
url="/file-upload",
|
||||
files=file_to_upload,
|
||||
data={"meta": ""},
|
||||
)
|
||||
response = client.post(url="/file-upload", files=file_to_upload, data={"meta": ""})
|
||||
assert 200 == response.status_code
|
||||
|
||||
response = client.post(url="/documents/get_by_filters", data='{"filters": {}}')
|
||||
@ -160,11 +156,7 @@ def test_file_upload_with_wrong_meta(client: TestClient):
|
||||
assert len(response.json()) == 0
|
||||
|
||||
file_to_upload = {"files": (Path(__file__).parent / "samples" / "pdf" / "sample_pdf_1.pdf").open("rb")}
|
||||
response = client.post(
|
||||
url="/file-upload",
|
||||
files=file_to_upload,
|
||||
data={"meta": "1"},
|
||||
)
|
||||
response = client.post(url="/file-upload", files=file_to_upload, data={"meta": "1"})
|
||||
assert 500 == response.status_code
|
||||
|
||||
response = client.post(url="/documents/get_by_filters", data='{"filters": {}}')
|
||||
|
||||
@@ -67,6 +67,13 @@ install_requires =
     mmh3 # fast hashing function (murmurhash3)
     quantulum3 # quantities extraction from text
     azure-ai-formrecognizer==3.2.0b2 # forms reader
+    # azure-core is a dependency of azure-ai-formrecognizer
+    # In order to stop malicious pip backtracking during pip install farm-haystack[all] documented in https://github.com/deepset-ai/haystack/issues/2280
+    # we have to resolve a dependency version conflict ourself.
+    # azure-core>=1.23 conflicts with pydoc-markdown's dependency on databind>=1.5.0 which itself requires typing-extensions<4.0.0
+    # azure-core>=1.23 needs typing-extensions>=4.0.1
+    # pip unfortunately backtracks into the databind direction ultimately getting lost.
+    azure-core<1.23
 
     # Preprocessing
     more_itertools # for windowing
@@ -167,7 +174,7 @@ dev =
     # Code formatting
     black[jupyter]
     # Documentation
-    pydoc-markdown>=4,<5
+    pydoc-markdown==4.5.1 # FIXME Unpin!
     mkdocs
     jupytercontrib
     watchdog #==1.0.2
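As context for the pin above (not part of the diff): the setup.cfg comments describe two requirements on `typing-extensions` that cannot both hold, which is what sends pip's resolver into backtracking. A small sketch with the `packaging` library, plugging in the bounds quoted in those comments:

```python
from packaging.specifiers import SpecifierSet
from packaging.version import Version

# Bounds on typing-extensions quoted in the setup.cfg comments above (for illustration only):
needed_by_azure_core = SpecifierSet(">=4.0.1")  # pulled in by azure-core>=1.23
needed_by_databind = SpecifierSet("<4.0.0")     # pulled in by databind>=1.5.0 via pydoc-markdown

combined = needed_by_azure_core & needed_by_databind
print(combined)                         # the combined range is unsatisfiable
print(Version("4.1.1") in combined)     # False
print(Version("3.10.0.2") in combined)  # False
```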
@ -315,16 +315,12 @@ def summarizer():
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def en_to_de_translator():
|
||||
return TransformersTranslator(
|
||||
model_name_or_path="Helsinki-NLP/opus-mt-en-de",
|
||||
)
|
||||
return TransformersTranslator(model_name_or_path="Helsinki-NLP/opus-mt-en-de")
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def de_to_en_translator():
|
||||
return TransformersTranslator(
|
||||
model_name_or_path="Helsinki-NLP/opus-mt-de-en",
|
||||
)
|
||||
return TransformersTranslator(model_name_or_path="Helsinki-NLP/opus-mt-de-en")
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
@ -400,16 +396,12 @@ def table_reader(request):
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def ranker_two_logits():
|
||||
return SentenceTransformersRanker(
|
||||
model_name_or_path="deepset/gbert-base-germandpr-reranking",
|
||||
)
|
||||
return SentenceTransformersRanker(model_name_or_path="deepset/gbert-base-germandpr-reranking")
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def ranker():
|
||||
return SentenceTransformersRanker(
|
||||
model_name_or_path="cross-encoder/ms-marco-MiniLM-L-12-v2",
|
||||
)
|
||||
return SentenceTransformersRanker(model_name_or_path="cross-encoder/ms-marco-MiniLM-L-12-v2")
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
@ -734,10 +726,7 @@ def get_document_store(
|
||||
|
||||
elif document_store_type == "weaviate":
|
||||
document_store = WeaviateDocumentStore(
|
||||
weaviate_url="http://localhost:8080",
|
||||
index=index,
|
||||
similarity=similarity,
|
||||
embedding_dim=embedding_dim,
|
||||
weaviate_url="http://localhost:8080", index=index, similarity=similarity, embedding_dim=embedding_dim
|
||||
)
|
||||
document_store.weaviate_client.schema.delete_all()
|
||||
document_store._create_schema_and_index_if_not_exist()
|
||||
|
||||
@ -14,11 +14,7 @@ def test_document_classifier(document_classifier):
|
||||
meta={"name": "0"},
|
||||
id="1",
|
||||
),
|
||||
Document(
|
||||
content="""That's bad. I don't like it.""",
|
||||
meta={"name": "1"},
|
||||
id="2",
|
||||
),
|
||||
Document(content="""That's bad. I don't like it.""", meta={"name": "1"}, id="2"),
|
||||
]
|
||||
results = document_classifier.predict(documents=docs)
|
||||
expected_labels = ["joy", "sadness"]
|
||||
@ -36,11 +32,7 @@ def test_zero_shot_document_classifier(zero_shot_document_classifier):
|
||||
meta={"name": "0"},
|
||||
id="1",
|
||||
),
|
||||
Document(
|
||||
content="""That's bad. I don't like it.""",
|
||||
meta={"name": "1"},
|
||||
id="2",
|
||||
),
|
||||
Document(content="""That's bad. I don't like it.""", meta={"name": "1"}, id="2"),
|
||||
]
|
||||
results = zero_shot_document_classifier.predict(documents=docs)
|
||||
expected_labels = ["positive", "negative"]
|
||||
@ -58,11 +50,7 @@ def test_document_classifier_batch_size(batched_document_classifier):
|
||||
meta={"name": "0"},
|
||||
id="1",
|
||||
),
|
||||
Document(
|
||||
content="""That's bad. I don't like it.""",
|
||||
meta={"name": "1"},
|
||||
id="2",
|
||||
),
|
||||
Document(content="""That's bad. I don't like it.""", meta={"name": "1"}, id="2"),
|
||||
]
|
||||
results = batched_document_classifier.predict(documents=docs)
|
||||
expected_labels = ["joy", "sadness"]
|
||||
@ -99,11 +87,7 @@ def test_document_classifier_as_query_node(document_classifier):
|
||||
meta={"name": "0"},
|
||||
id="1",
|
||||
),
|
||||
Document(
|
||||
content="""That's bad. I don't like it.""",
|
||||
meta={"name": "1"},
|
||||
id="2",
|
||||
),
|
||||
Document(content="""That's bad. I don't like it.""", meta={"name": "1"}, id="2"),
|
||||
]
|
||||
output, output_name = document_classifier.run(documents=docs, root_node="Query")
|
||||
expected_labels = ["joy", "sadness"]
|
||||
|
||||
@ -393,13 +393,9 @@ def test_get_all_documents_generator(document_store):
|
||||
|
||||
@pytest.mark.parametrize("update_existing_documents", [True, False])
|
||||
def test_update_existing_documents(document_store, update_existing_documents):
|
||||
original_docs = [
|
||||
{"content": "text1_orig", "id": "1", "meta_field_for_count": "a"},
|
||||
]
|
||||
original_docs = [{"content": "text1_orig", "id": "1", "meta_field_for_count": "a"}]
|
||||
|
||||
updated_docs = [
|
||||
{"content": "text1_new", "id": "1", "meta_field_for_count": "a"},
|
||||
]
|
||||
updated_docs = [{"content": "text1_new", "id": "1", "meta_field_for_count": "a"}]
|
||||
|
||||
document_store.write_documents(original_docs)
|
||||
assert document_store.get_document_count() == 1
|
||||
@ -436,10 +432,7 @@ def test_write_document_meta(document_store):
|
||||
|
||||
|
||||
def test_write_document_index(document_store):
|
||||
documents = [
|
||||
{"content": "text1", "id": "1"},
|
||||
{"content": "text2", "id": "2"},
|
||||
]
|
||||
documents = [{"content": "text1", "id": "1"}, {"content": "text2", "id": "2"}]
|
||||
document_store.write_documents([documents[0]], index="haystack_test_one")
|
||||
assert len(document_store.get_all_documents(index="haystack_test_one")) == 1
|
||||
|
||||
@ -492,9 +485,7 @@ def test_update_embeddings(document_store, retriever):
|
||||
assert type(doc.embedding) is np.ndarray
|
||||
|
||||
documents = document_store.get_all_documents(
|
||||
index="haystack_test_one",
|
||||
filters={"meta_field": ["value_0"]},
|
||||
return_embedding=True,
|
||||
index="haystack_test_one", filters={"meta_field": ["value_0"]}, return_embedding=True
|
||||
)
|
||||
assert len(documents) == 2
|
||||
for doc in documents:
|
||||
@ -502,9 +493,7 @@ def test_update_embeddings(document_store, retriever):
|
||||
np.testing.assert_array_almost_equal(documents[0].embedding, documents[1].embedding, decimal=4)
|
||||
|
||||
documents = document_store.get_all_documents(
|
||||
index="haystack_test_one",
|
||||
filters={"meta_field": ["value_0", "value_5"]},
|
||||
return_embedding=True,
|
||||
index="haystack_test_one", filters={"meta_field": ["value_0", "value_5"]}, return_embedding=True
|
||||
)
|
||||
documents_with_value_0 = [doc for doc in documents if doc.meta["meta_field"] == "value_0"]
|
||||
documents_with_value_5 = [doc for doc in documents if doc.meta["meta_field"] == "value_5"]
|
||||
@ -624,9 +613,7 @@ def test_update_embeddings_table_text_retriever(document_store, retriever):
|
||||
|
||||
# Check if Documents with same content (text) get same embedding
|
||||
documents = document_store.get_all_documents(
|
||||
index="haystack_test_one",
|
||||
filters={"meta_field": ["value_text_0"]},
|
||||
return_embedding=True,
|
||||
index="haystack_test_one", filters={"meta_field": ["value_text_0"]}, return_embedding=True
|
||||
)
|
||||
assert len(documents) == 2
|
||||
for doc in documents:
|
||||
@ -635,9 +622,7 @@ def test_update_embeddings_table_text_retriever(document_store, retriever):
|
||||
|
||||
# Check if Documents with same content (table) get same embedding
|
||||
documents = document_store.get_all_documents(
|
||||
index="haystack_test_one",
|
||||
filters={"meta_field": ["value_table_0"]},
|
||||
return_embedding=True,
|
||||
index="haystack_test_one", filters={"meta_field": ["value_table_0"]}, return_embedding=True
|
||||
)
|
||||
assert len(documents) == 2
|
||||
for doc in documents:
|
||||
@ -646,9 +631,7 @@ def test_update_embeddings_table_text_retriever(document_store, retriever):
|
||||
|
||||
# Check if Documents wih different content (text) get different embedding
|
||||
documents = document_store.get_all_documents(
|
||||
index="haystack_test_one",
|
||||
filters={"meta_field": ["value_text_1", "value_text_2"]},
|
||||
return_embedding=True,
|
||||
index="haystack_test_one", filters={"meta_field": ["value_text_1", "value_text_2"]}, return_embedding=True
|
||||
)
|
||||
np.testing.assert_raises(
|
||||
AssertionError, np.testing.assert_array_equal, documents[0].embedding, documents[1].embedding
|
||||
@ -656,9 +639,7 @@ def test_update_embeddings_table_text_retriever(document_store, retriever):
|
||||
|
||||
# Check if Documents with different content (table) get different embeddings
|
||||
documents = document_store.get_all_documents(
|
||||
index="haystack_test_one",
|
||||
filters={"meta_field": ["value_table_1", "value_table_2"]},
|
||||
return_embedding=True,
|
||||
index="haystack_test_one", filters={"meta_field": ["value_table_1", "value_table_2"]}, return_embedding=True
|
||||
)
|
||||
np.testing.assert_raises(
|
||||
AssertionError, np.testing.assert_array_equal, documents[0].embedding, documents[1].embedding
|
||||
@ -666,9 +647,7 @@ def test_update_embeddings_table_text_retriever(document_store, retriever):
|
||||
|
||||
# Check if Documents with different content (table + text) get different embeddings
|
||||
documents = document_store.get_all_documents(
|
||||
index="haystack_test_one",
|
||||
filters={"meta_field": ["value_text_1", "value_table_1"]},
|
||||
return_embedding=True,
|
||||
index="haystack_test_one", filters={"meta_field": ["value_text_1", "value_table_1"]}, return_embedding=True
|
||||
)
|
||||
np.testing.assert_raises(
|
||||
AssertionError, np.testing.assert_array_equal, documents[0].embedding, documents[1].embedding
|
||||
|
||||
@ -305,10 +305,7 @@ def test_extractive_qa_eval(reader, retriever_with_docs, tmp_path):
|
||||
labels = EVAL_LABELS[:1]
|
||||
|
||||
pipeline = ExtractiveQAPipeline(reader=reader, retriever=retriever_with_docs)
|
||||
eval_result = pipeline.eval(
|
||||
labels=labels,
|
||||
params={"Retriever": {"top_k": 5}},
|
||||
)
|
||||
eval_result = pipeline.eval(labels=labels, params={"Retriever": {"top_k": 5}})
|
||||
|
||||
metrics = eval_result.calculate_metrics()
|
||||
|
||||
@ -469,10 +466,7 @@ def test_extractive_qa_labels_with_filters(reader, retriever_with_docs, tmp_path
|
||||
]
|
||||
|
||||
pipeline = ExtractiveQAPipeline(reader=reader, retriever=retriever_with_docs)
|
||||
eval_result = pipeline.eval(
|
||||
labels=labels,
|
||||
params={"Retriever": {"top_k": 5}},
|
||||
)
|
||||
eval_result = pipeline.eval(labels=labels, params={"Retriever": {"top_k": 5}})
|
||||
|
||||
metrics = eval_result.calculate_metrics()
|
||||
|
||||
@ -541,10 +535,7 @@ def test_reader_eval_in_pipeline(reader):
|
||||
@pytest.mark.parametrize("document_store_with_docs", ["memory"], indirect=True)
|
||||
def test_extractive_qa_eval_doc_relevance_col(reader, retriever_with_docs):
|
||||
pipeline = ExtractiveQAPipeline(reader=reader, retriever=retriever_with_docs)
|
||||
eval_result: EvaluationResult = pipeline.eval(
|
||||
labels=EVAL_LABELS,
|
||||
params={"Retriever": {"top_k": 5}},
|
||||
)
|
||||
eval_result: EvaluationResult = pipeline.eval(labels=EVAL_LABELS, params={"Retriever": {"top_k": 5}})
|
||||
|
||||
metrics = eval_result.calculate_metrics(doc_relevance_col="gold_id_or_answer_match")
|
||||
|
||||
@ -773,10 +764,7 @@ def test_extractive_qa_eval_wrong_examples(reader, retriever_with_docs):
|
||||
]
|
||||
|
||||
pipeline = ExtractiveQAPipeline(reader=reader, retriever=retriever_with_docs)
|
||||
eval_result: EvaluationResult = pipeline.eval(
|
||||
labels=labels,
|
||||
params={"Retriever": {"top_k": 5}},
|
||||
)
|
||||
eval_result: EvaluationResult = pipeline.eval(labels=labels, params={"Retriever": {"top_k": 5}})
|
||||
|
||||
wrongs_retriever = eval_result.wrong_examples(node="Retriever", n=1)
|
||||
wrongs_reader = eval_result.wrong_examples(node="Reader", n=1)
|
||||
|
||||
@ -20,11 +20,7 @@ def test_extractor(document_store_with_docs):
|
||||
pipeline.add_node(component=reader, name="Reader", inputs=["NER"])
|
||||
|
||||
prediction = pipeline.run(
|
||||
query="Who lives in Berlin?",
|
||||
params={
|
||||
"ESRetriever": {"top_k": 1},
|
||||
"Reader": {"top_k": 1},
|
||||
},
|
||||
query="Who lives in Berlin?", params={"ESRetriever": {"top_k": 1}, "Reader": {"top_k": 1}}
|
||||
)
|
||||
entities = [entity["word"] for entity in prediction["answers"][0].meta["entities"]]
|
||||
assert "Carla" in entities
|
||||
@ -44,11 +40,7 @@ def test_extractor_output_simplifier(document_store_with_docs):
|
||||
pipeline.add_node(component=reader, name="Reader", inputs=["NER"])
|
||||
|
||||
prediction = pipeline.run(
|
||||
query="Who lives in Berlin?",
|
||||
params={
|
||||
"ESRetriever": {"top_k": 1},
|
||||
"Reader": {"top_k": 1},
|
||||
},
|
||||
query="Who lives in Berlin?", params={"ESRetriever": {"top_k": 1}, "Reader": {"top_k": 1}}
|
||||
)
|
||||
simplified = simplify_ner_for_qa(prediction)
|
||||
assert simplified[0] == {"answer": "Carla", "entities": ["Carla"]}
|
||||
|
||||
@ -16,11 +16,7 @@ from conftest import DOCS_WITH_EMBEDDINGS
|
||||
@pytest.mark.skipif(sys.platform in ["win32", "cygwin"], reason="Causes OOM on windows github runner")
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.generator
|
||||
@pytest.mark.parametrize(
|
||||
"retriever,document_store",
|
||||
[("embedding", "memory")],
|
||||
indirect=True,
|
||||
)
|
||||
@pytest.mark.parametrize("retriever,document_store", [("embedding", "memory")], indirect=True)
|
||||
def test_generator_pipeline_with_translator(
|
||||
document_store, retriever, rag_generator, en_to_de_translator, de_to_en_translator
|
||||
):
|
||||
|
||||
@ -511,14 +511,8 @@ def test_dpr_problematic():
|
||||
|
||||
def test_dpr_query_only():
|
||||
erroneous_dicts = [
|
||||
{
|
||||
"query": "where is castle on the hill based on",
|
||||
"answers": ["Framlingham Castle"],
|
||||
},
|
||||
{
|
||||
"query": "where is castle on the hill 2 based on",
|
||||
"answers": ["Framlingham Castle 2"],
|
||||
},
|
||||
{"query": "where is castle on the hill based on", "answers": ["Framlingham Castle"]},
|
||||
{"query": "where is castle on the hill 2 based on", "answers": ["Framlingham Castle 2"]},
|
||||
]
|
||||
|
||||
query_tok = "facebook/dpr-question_encoder-single-nq-base"
|
||||
@ -688,7 +682,7 @@ def test_dpr_processor_save_load_non_bert_tokenizer(query_and_passage_model):
|
||||
"text": "Etalab est une administration publique française qui fait notamment office de Chief Data Officer de l'État et coordonne la conception et la mise en œuvre de sa stratégie dans le domaine de la donnée (ouverture et partage des données publiques ou open data, exploitation des données et intelligence artificielle...). Ainsi, Etalab développe et maintient le portail des données ouvertes du gouvernement français data.gouv.fr. Etalab promeut également une plus grande ouverture l'administration sur la société (gouvernement ouvert) : transparence de l'action publique, innovation ouverte, participation citoyenne... elle promeut l’innovation, l’expérimentation, les méthodes de travail ouvertes, agiles et itératives, ainsi que les synergies avec la société civile pour décloisonner l’administration et favoriser l’adoption des meilleures pratiques professionnelles dans le domaine du numérique. À ce titre elle étudie notamment l’opportunité de recourir à des technologies en voie de maturation issues du monde de la recherche. Cette entité chargée de l'innovation au sein de l'administration doit contribuer à l'amélioration du service public grâce au numérique. Elle est rattachée à la Direction interministérielle du numérique, dont les missions et l’organisation ont été fixées par le décret du 30 octobre 2019. Dirigé par Laure Lucchesi depuis 2016, elle rassemble une équipe pluridisciplinaire d'une trentaine de personnes.",
|
||||
"label": "positive",
|
||||
"external_id": "1",
|
||||
},
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
@ -23,8 +23,7 @@ def test_qa_format_and_results(adaptive_model_qa, multiprocessing_chunksize):
|
||||
ground_truths = ["France", "GameTrailers"]
|
||||
|
||||
results = adaptive_model_qa.inference_from_dicts(
|
||||
dicts=qa_inputs_dicts,
|
||||
multiprocessing_chunksize=multiprocessing_chunksize,
|
||||
dicts=qa_inputs_dicts, multiprocessing_chunksize=multiprocessing_chunksize
|
||||
)
|
||||
# sample results
|
||||
# [
|
||||
|
||||
@ -328,10 +328,7 @@ def test_generate_code_simple_pipeline():
|
||||
def test_generate_code_imports():
|
||||
pipeline_config = {
|
||||
"components": [
|
||||
{
|
||||
"name": "DocumentStore",
|
||||
"type": "ElasticsearchDocumentStore",
|
||||
},
|
||||
{"name": "DocumentStore", "type": "ElasticsearchDocumentStore"},
|
||||
{"name": "retri", "type": "ElasticsearchRetriever", "params": {"document_store": "DocumentStore"}},
|
||||
{"name": "retri2", "type": "EmbeddingRetriever", "params": {"document_store": "DocumentStore"}},
|
||||
],
|
||||
@ -363,10 +360,7 @@ def test_generate_code_imports():
|
||||
def test_generate_code_imports_no_pipeline_cls():
|
||||
pipeline_config = {
|
||||
"components": [
|
||||
{
|
||||
"name": "DocumentStore",
|
||||
"type": "ElasticsearchDocumentStore",
|
||||
},
|
||||
{"name": "DocumentStore", "type": "ElasticsearchDocumentStore"},
|
||||
{"name": "retri", "type": "ElasticsearchRetriever", "params": {"document_store": "DocumentStore"}},
|
||||
],
|
||||
"pipelines": [{"name": "Query", "type": "Pipeline", "nodes": [{"name": "retri", "inputs": ["Query"]}]}],
|
||||
@ -393,10 +387,7 @@ def test_generate_code_imports_no_pipeline_cls():
|
||||
def test_generate_code_comment():
|
||||
pipeline_config = {
|
||||
"components": [
|
||||
{
|
||||
"name": "DocumentStore",
|
||||
"type": "ElasticsearchDocumentStore",
|
||||
},
|
||||
{"name": "DocumentStore", "type": "ElasticsearchDocumentStore"},
|
||||
{"name": "retri", "type": "ElasticsearchRetriever", "params": {"document_store": "DocumentStore"}},
|
||||
],
|
||||
"pipelines": [{"name": "Query", "type": "Pipeline", "nodes": [{"name": "retri", "inputs": ["Query"]}]}],
|
||||
@ -434,10 +425,7 @@ def test_generate_code_is_component_order_invariant():
|
||||
]
|
||||
}
|
||||
|
||||
doc_store = {
|
||||
"name": "ElasticsearchDocumentStore",
|
||||
"type": "ElasticsearchDocumentStore",
|
||||
}
|
||||
doc_store = {"name": "ElasticsearchDocumentStore", "type": "ElasticsearchDocumentStore"}
|
||||
es_retriever = {
|
||||
"name": "EsRetriever",
|
||||
"type": "ElasticsearchRetriever",
|
||||
@ -451,10 +439,7 @@ def test_generate_code_is_component_order_invariant():
|
||||
"embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
|
||||
},
|
||||
}
|
||||
join_node = {
|
||||
"name": "JoinResults",
|
||||
"type": "JoinDocuments",
|
||||
}
|
||||
join_node = {"name": "JoinResults", "type": "JoinDocuments"}
|
||||
|
||||
component_orders = [
|
||||
[doc_store, es_retriever, emb_retriever, join_node],
|
||||
@ -515,31 +500,13 @@ def test_validate_pipeline_config_invalid_component_param_key():
|
||||
|
||||
def test_validate_pipeline_config_invalid_pipeline_name():
with pytest.raises(ValueError, match="is not a valid config variable name"):
validate_config(
{
"components": [
{
"name": "test",
"type": "test",
}
],
"pipelines": [{"name": "\btest"}],
}
)
validate_config({"components": [{"name": "test", "type": "test"}], "pipelines": [{"name": "\btest"}]})


def test_validate_pipeline_config_invalid_pipeline_type():
with pytest.raises(ValueError, match="is not a valid config variable name"):
validate_config(
{
"components": [
{
"name": "test",
"type": "test",
}
],
"pipelines": [{"name": "test", "type": "\btest"}],
}
{"components": [{"name": "test", "type": "test"}], "pipelines": [{"name": "test", "type": "\btest"}]}
)


@ -547,12 +514,7 @@ def test_validate_pipeline_config_invalid_pipeline_node_name():
with pytest.raises(ValueError, match="is not a valid config variable name"):
validate_config(
{
"components": [
{
"name": "test",
"type": "test",
}
],
"components": [{"name": "test", "type": "test"}],
"pipelines": [{"name": "test", "type": "test", "nodes": [{"name": "\btest"}]}],
}
)
@ -562,12 +524,7 @@ def test_validate_pipeline_config_invalid_pipeline_node_inputs():
with pytest.raises(ValueError, match="is not a valid config variable name"):
validate_config(
{
"components": [
{
"name": "test",
"type": "test",
}
],
"components": [{"name": "test", "type": "test"}],
"pipelines": [{"name": "test", "type": "test", "nodes": [{"name": "test", "inputs": ["\btest"]}]}],
}
)
@ -1138,8 +1095,7 @@ def test_documentsearch_es_authentication(retriever_with_docs, document_store_wi
auth_headers = {"Authorization": "Basic YWRtaW46cm9vdA=="}
pipeline = DocumentSearchPipeline(retriever=retriever_with_docs)
prediction = pipeline.run(
query="Who lives in Berlin?",
params={"Retriever": {"top_k": 10, "headers": auth_headers}},
query="Who lives in Berlin?", params={"Retriever": {"top_k": 10, "headers": auth_headers}}
)
assert prediction is not None
assert len(prediction["documents"]) == 5
@ -1162,13 +1118,11 @@ def test_documentsearch_document_store_authentication(retriever_with_docs, docum
if not mock_client:
with pytest.raises(Exception):
prediction = pipeline.run(
query="Who lives in Berlin?",
params={"Retriever": {"top_k": 10, "headers": auth_headers}},
query="Who lives in Berlin?", params={"Retriever": {"top_k": 10, "headers": auth_headers}}
)
else:
prediction = pipeline.run(
query="Who lives in Berlin?",
params={"Retriever": {"top_k": 10, "headers": auth_headers}},
query="Who lives in Berlin?", params={"Retriever": {"top_k": 10, "headers": auth_headers}}
)
assert prediction is not None
assert len(prediction["documents"]) == 5

@ -3,14 +3,8 @@ from pathlib import Path
import json
import pytest

from haystack.pipelines import (
Pipeline,
RootNode,
)
from haystack.nodes import (
FARMReader,
ElasticsearchRetriever,
)
from haystack.pipelines import Pipeline, RootNode
from haystack.nodes import FARMReader, ElasticsearchRetriever

from conftest import SAMPLES_PATH

@ -88,8 +82,7 @@ def test_debug_attributes_per_node(document_store_with_docs, tmp_path):
pipeline.add_node(component=reader, name="Reader", inputs=["ESRetriever"])

prediction = pipeline.run(
query="Who lives in Berlin?",
params={"ESRetriever": {"top_k": 10, "debug": True}, "Reader": {"top_k": 3}},
query="Who lives in Berlin?", params={"ESRetriever": {"top_k": 10, "debug": True}, "Reader": {"top_k": 3}}
)
assert "_debug" in prediction.keys()
assert "ESRetriever" in prediction["_debug"].keys()

@ -9,10 +9,7 @@ from haystack.schema import Answer
@pytest.mark.parametrize("retriever_with_docs", ["tfidf"], indirect=True)
def test_extractive_qa_answers(reader, retriever_with_docs, document_store_with_docs):
pipeline = ExtractiveQAPipeline(reader=reader, retriever=retriever_with_docs)
prediction = pipeline.run(
query="Who lives in Berlin?",
params={"Retriever": {"top_k": 10}, "Reader": {"top_k": 3}},
)
prediction = pipeline.run(query="Who lives in Berlin?", params={"Retriever": {"top_k": 10}, "Reader": {"top_k": 3}})
assert prediction is not None
assert type(prediction["answers"][0]) == Answer
assert prediction["query"] == "Who lives in Berlin?"
@ -70,9 +67,7 @@ def test_extractive_qa_answers_single_result(reader, retriever_with_docs):
def test_extractive_qa_answers_with_translator(reader, retriever_with_docs, en_to_de_translator, de_to_en_translator):
base_pipeline = ExtractiveQAPipeline(reader=reader, retriever=retriever_with_docs)
pipeline = TranslationWrapperPipeline(
input_translator=de_to_en_translator,
output_translator=en_to_de_translator,
pipeline=base_pipeline,
input_translator=de_to_en_translator, output_translator=en_to_de_translator, pipeline=base_pipeline
)

prediction = pipeline.run(query="Wer lebt in Berlin?", params={"Reader": {"top_k": 3}})

@ -27,10 +27,7 @@ def test_preprocess_sentence_split():
assert len(documents) == 15

preprocessor = PreProcessor(
split_length=10,
split_overlap=0,
split_by="sentence",
split_respect_sentence_boundary=False,
split_length=10, split_overlap=0, split_by="sentence", split_respect_sentence_boundary=False
)
documents = preprocessor.process(document)
assert len(documents) == 2

@ -11,9 +11,7 @@ from conftest import SAMPLES_PATH
@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
def test_load_pipeline(document_store_with_docs):
pipeline = RayPipeline.load_from_yaml(
SAMPLES_PATH / "pipeline" / "test_pipeline.yaml",
pipeline_name="ray_query_pipeline",
num_cpus=8,
SAMPLES_PATH / "pipeline" / "test_pipeline.yaml", pipeline_name="ray_query_pipeline", num_cpus=8
)
prediction = pipeline.run(query="Who lives in Berlin?", params={"Retriever": {"top_k": 10}, "Reader": {"top_k": 3}})


@ -6,13 +6,7 @@ import math
import pytest

from haystack.document_stores.elasticsearch import ElasticsearchDocumentStore
from haystack.pipelines import (
Pipeline,
FAQPipeline,
DocumentSearchPipeline,
RootNode,
MostSimilarDocumentsPipeline,
)
from haystack.pipelines import Pipeline, FAQPipeline, DocumentSearchPipeline, RootNode, MostSimilarDocumentsPipeline
from haystack.nodes import (
DensePassageRetriever,
ElasticsearchRetriever,
@ -27,36 +21,16 @@ from conftest import SAMPLES_PATH

@pytest.mark.parametrize(
"retriever,document_store",
[
("embedding", "memory"),
("embedding", "faiss"),
("embedding", "milvus1"),
("embedding", "elasticsearch"),
],
[("embedding", "memory"), ("embedding", "faiss"), ("embedding", "milvus1"), ("embedding", "elasticsearch")],
indirect=True,
)
def test_faq_pipeline(retriever, document_store):
documents = [
{
"content": "How to test module-1?",
"meta": {"source": "wiki1", "answer": "Using tests for module-1"},
},
{
"content": "How to test module-2?",
"meta": {"source": "wiki2", "answer": "Using tests for module-2"},
},
{
"content": "How to test module-3?",
"meta": {"source": "wiki3", "answer": "Using tests for module-3"},
},
{
"content": "How to test module-4?",
"meta": {"source": "wiki4", "answer": "Using tests for module-4"},
},
{
"content": "How to test module-5?",
"meta": {"source": "wiki5", "answer": "Using tests for module-5"},
},
{"content": "How to test module-1?", "meta": {"source": "wiki1", "answer": "Using tests for module-1"}},
{"content": "How to test module-2?", "meta": {"source": "wiki2", "answer": "Using tests for module-2"}},
{"content": "How to test module-3?", "meta": {"source": "wiki3", "answer": "Using tests for module-3"}},
{"content": "How to test module-4?", "meta": {"source": "wiki4", "answer": "Using tests for module-4"}},
{"content": "How to test module-5?", "meta": {"source": "wiki5", "answer": "Using tests for module-5"}},
]

document_store.write_documents(documents)
@ -103,11 +77,7 @@ def test_document_search_pipeline(retriever, document_store):

@pytest.mark.parametrize(
"retriever,document_store",
[
("embedding", "faiss"),
("embedding", "milvus1"),
("embedding", "elasticsearch"),
],
[("embedding", "faiss"), ("embedding", "milvus1"), ("embedding", "elasticsearch")],
indirect=True,
)
def test_most_similar_documents_pipeline(retriever, document_store):
@ -308,20 +278,12 @@ def test_query_keyword_statement_classifier():
return kwargs, "output_2"

pipeline = Pipeline()
pipeline.add_node(name="SkQueryKeywordQuestionClassifier", component=SklearnQueryClassifier(), inputs=["Query"])
pipeline.add_node(
name="SkQueryKeywordQuestionClassifier",
component=SklearnQueryClassifier(),
inputs=["Query"],
name="KeywordNode", component=KeywordOutput(), inputs=["SkQueryKeywordQuestionClassifier.output_2"]
)
pipeline.add_node(
name="KeywordNode",
component=KeywordOutput(),
inputs=["SkQueryKeywordQuestionClassifier.output_2"],
)
pipeline.add_node(
name="QuestionNode",
component=QuestionOutput(),
inputs=["SkQueryKeywordQuestionClassifier.output_1"],
name="QuestionNode", component=QuestionOutput(), inputs=["SkQueryKeywordQuestionClassifier.output_1"]
)
output = pipeline.run(query="morse code")
assert output["output"] == "keyword"
@ -331,19 +293,13 @@ def test_query_keyword_statement_classifier():

pipeline = Pipeline()
pipeline.add_node(
name="TfQueryKeywordQuestionClassifier",
component=TransformersQueryClassifier(),
inputs=["Query"],
name="TfQueryKeywordQuestionClassifier", component=TransformersQueryClassifier(), inputs=["Query"]
)
pipeline.add_node(
name="KeywordNode",
component=KeywordOutput(),
inputs=["TfQueryKeywordQuestionClassifier.output_2"],
name="KeywordNode", component=KeywordOutput(), inputs=["TfQueryKeywordQuestionClassifier.output_2"]
)
pipeline.add_node(
name="QuestionNode",
component=QuestionOutput(),
inputs=["TfQueryKeywordQuestionClassifier.output_1"],
name="QuestionNode", component=QuestionOutput(), inputs=["TfQueryKeywordQuestionClassifier.output_1"]
)
output = pipeline.run(query="morse code")
assert output["output"] == "keyword"

@ -6,7 +6,7 @@ from haystack.nodes import DensePassageRetriever, EmbeddingRetriever

DOCS = [
Document(
content="""PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow.""",
content="""PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."""
),
Document(
content="""The tower is 324 metres (1,063 ft) tall, about the same height as an 81-storey building, and the tallest structure in Paris. Its base is square, measuring 125 metres (410 ft) on each side. During its construction, the Eiffel Tower surpassed the Washington Monument to become the tallest man-made structure in the world, a title it held for 41 years until the Chrysler Building in New York City was finished in 1930. It was the first structure to reach a height of 300 metres. Due to the addition of a broadcasting aerial at the top of the tower in 1957, it is now taller than the Chrysler Building by 5.2 metres (17 ft). Excluding transmitters, the Eiffel Tower is the second tallest free-standing structure in France after the Millau Viaduct."""
@ -55,9 +55,7 @@ def test_summarization_one_summary(summarizer):
@pytest.mark.slow
@pytest.mark.summarizer
@pytest.mark.parametrize(
"retriever,document_store",
[("embedding", "memory"), ("elasticsearch", "elasticsearch")],
indirect=True,
"retriever,document_store", [("embedding", "memory"), ("elasticsearch", "elasticsearch")], indirect=True
)
def test_summarization_pipeline(document_store, retriever, summarizer):
document_store.write_documents(DOCS)
@ -76,9 +74,7 @@ def test_summarization_pipeline(document_store, retriever, summarizer):
@pytest.mark.slow
@pytest.mark.summarizer
@pytest.mark.parametrize(
"retriever,document_store",
[("embedding", "memory"), ("elasticsearch", "elasticsearch")],
indirect=True,
"retriever,document_store", [("embedding", "memory"), ("elasticsearch", "elasticsearch")], indirect=True
)
def test_summarization_pipeline_one_summary(document_store, retriever, summarizer):
document_store.write_documents(SPLIT_DOCS)

@ -9,9 +9,7 @@ from test_summarizer import SPLIT_DOCS
@pytest.mark.elasticsearch
@pytest.mark.summarizer
@pytest.mark.parametrize(
"retriever,document_store",
[("embedding", "memory"), ("elasticsearch", "elasticsearch")],
indirect=True,
"retriever,document_store", [("embedding", "memory"), ("elasticsearch", "elasticsearch")], indirect=True
)
def test_summarization_pipeline_with_translator(
document_store, retriever, summarizer, en_to_de_translator, de_to_en_translator

@ -40,11 +40,7 @@ TEXTS = [
def test_basic_loading(caplog):
caplog.set_level(logging.CRITICAL)
# slow tokenizers
tokenizer = Tokenizer.load(
pretrained_model_name_or_path="bert-base-cased",
do_lower_case=True,
use_fast=False,
)
tokenizer = Tokenizer.load(pretrained_model_name_or_path="bert-base-cased", do_lower_case=True, use_fast=False)
assert type(tokenizer) == BertTokenizer
assert tokenizer.basic_tokenizer.do_lower_case == True

@ -227,13 +223,7 @@ def test_save_load(caplog):
assert data_before == data_after


@pytest.mark.parametrize(
"model_name",
[
"bert-base-german-cased",
"google/electra-small-discriminator",
],
)
@pytest.mark.parametrize("model_name", ["bert-base-german-cased", "google/electra-small-discriminator"])
def test_fast_tokenizer_with_examples(caplog, model_name):
fast_tokenizer = Tokenizer.load(model_name, lower_case=False, use_fast=True)
tokenizer = Tokenizer.load(model_name, lower_case=False, use_fast=False)
@ -306,10 +296,7 @@ def test_all_tokenizer_on_special_cases(caplog):
# token offsets are originally relative to the beginning of the word
# These lines convert them so they are relative to the beginning of the sentence
token_offsets = []
for (
(start, end),
w_index,
) in zip(encoded.offsets, encoded.words):
for ((start, end), w_index) in zip(encoded.offsets, encoded.words):
word_start_ch = word_spans[w_index][0]
token_offsets.append((start + word_start_ch, end + word_start_ch))
if getattr(tokenizer, "add_prefix_space", None):
@ -468,10 +455,7 @@ def test_fast_bert_custom_vocab(caplog):

@pytest.mark.parametrize(
"model_name, tokenizer_type",
[
("bert-base-german-cased", BertTokenizerFast),
("google/electra-small-discriminator", ElectraTokenizerFast),
],
[("bert-base-german-cased", BertTokenizerFast), ("google/electra-small-discriminator", ElectraTokenizerFast)],
)
def test_fast_tokenizer_type(caplog, model_name, tokenizer_type):
caplog.set_level(logging.CRITICAL)

@ -260,8 +260,7 @@
"# Prebuilt pipeline\n",
"p_extractive_premade = ExtractiveQAPipeline(reader=reader, retriever=es_retriever)\n",
"res = p_extractive_premade.run(\n",
" query=\"Who is the father of Arya Stark?\",\n",
" params={\"Retriever\": {\"top_k\": 10}, \"Reader\": {\"top_k\": 5}},\n",
" query=\"Who is the father of Arya Stark?\", params={\"Retriever\": {\"top_k\": 10}, \"Reader\": {\"top_k\": 5}}\n",
")\n",
"print_answers(res, details=\"minimum\")"
]
@ -289,10 +288,7 @@
"from haystack.pipelines import DocumentSearchPipeline\n",
"\n",
"p_retrieval = DocumentSearchPipeline(es_retriever)\n",
"res = p_retrieval.run(\n",
" query=\"Who is the father of Arya Stark?\",\n",
" params={\"Retriever\": {\"top_k\": 10}},\n",
")\n",
"res = p_retrieval.run(query=\"Who is the father of Arya Stark?\", params={\"Retriever\": {\"top_k\": 10}})\n",
"print_documents(res, max_text_len=200)"
]
},

@ -55,10 +55,7 @@ def tutorial11_pipelines():

query = "Who is the father of Arya Stark?"
p_extractive_premade = ExtractiveQAPipeline(reader=reader, retriever=es_retriever)
res = p_extractive_premade.run(
query=query,
params={"Retriever": {"top_k": 10}, "Reader": {"top_k": 5}},
)
res = p_extractive_premade.run(query=query, params={"Retriever": {"top_k": 10}, "Reader": {"top_k": 5}})
print("\nQuery: ", query)
print("Answers:")
print_answers(res, details="minimum")
@ -69,10 +66,7 @@ def tutorial11_pipelines():

query = "Who is the father of Arya Stark?"
p_retrieval = DocumentSearchPipeline(es_retriever)
res = p_retrieval.run(
query=query,
params={"Retriever": {"top_k": 10}},
)
res = p_retrieval.run(query=query, params={"Retriever": {"top_k": 10}})
print()
print_documents(res, max_text_len=200)

@ -90,10 +84,7 @@ def tutorial11_pipelines():
# Generative QA
query = "Who is the father of Arya Stark?"
p_generator = GenerativeQAPipeline(generator=rag_generator, retriever=dpr_retriever)
res = p_generator.run(
query=query,
params={"Retriever": {"top_k": 10}},
)
res = p_generator.run(query=query, params={"Retriever": {"top_k": 10}})
print()
print_answers(res, details="minimum")

@ -125,10 +116,7 @@ def tutorial11_pipelines():

# Now we can run it
query = "Who is the father of Arya Stark?"
res = p_extractive.run(
query=query,
params={"Retriever": {"top_k": 10}, "Reader": {"top_k": 5}},
)
res = p_extractive.run(query=query, params={"Retriever": {"top_k": 10}, "Reader": {"top_k": 5}})
print("\nQuery: ", query)
print("Answers:")
print_answers(res, details="minimum")
@ -151,8 +139,7 @@ def tutorial11_pipelines():
# Run pipeline
query = "Who is the father of Arya Stark?"
res = p_ensemble.run(
query="Who is the father of Arya Stark?",
params={"ESRetriever": {"top_k": 5}, "DPRRetriever": {"top_k": 5}},
query="Who is the father of Arya Stark?", params={"ESRetriever": {"top_k": 5}, "DPRRetriever": {"top_k": 5}}
)
print("\nQuery: ", query)
print("Answers:")
@ -186,9 +173,7 @@ def tutorial11_pipelines():

# Run only the dense retriever on the full sentence query
query = "Who is the father of Arya Stark?"
res_1 = p_classifier.run(
query=query,
)
res_1 = p_classifier.run(query=query)
print()
print("\nQuery: ", query)
print(" * DPR Answers:")
@ -196,9 +181,7 @@ def tutorial11_pipelines():

# Run only the sparse retriever on a keyword based query
query = "Arya Stark father"
res_2 = p_classifier.run(
query=query,
)
res_2 = p_classifier.run(query=query)
print()
print("\nQuery: ", query)
print(" * ES Answers:")

@ -51,49 +51,37 @@ def tutorial14_query_classifier():
sklearn_keyword_classifier.draw("pipeline_classifier.png")

# Run only the dense retriever on the full sentence query
res_1 = sklearn_keyword_classifier.run(
query="Who is the father of Arya Stark?",
)
res_1 = sklearn_keyword_classifier.run(query="Who is the father of Arya Stark?")
print("\n===============================")
print("DPR Results" + "\n" + "=" * 15)
print_answers(res_1, details="minimum")

# Run only the sparse retriever on a keyword based query
res_2 = sklearn_keyword_classifier.run(
query="arya stark father",
)
res_2 = sklearn_keyword_classifier.run(query="arya stark father")
print("\n===============================")
print("ES Results" + "\n" + "=" * 15)
print_answers(res_2, details="minimum")

# Run only the dense retriever on the full sentence query
res_3 = sklearn_keyword_classifier.run(
query="which country was jon snow filmed ?",
)
res_3 = sklearn_keyword_classifier.run(query="which country was jon snow filmed ?")
print("\n===============================")
print("DPR Results" + "\n" + "=" * 15)
print_answers(res_3, details="minimum")

# Run only the sparse retriever on a keyword based query
res_4 = sklearn_keyword_classifier.run(
query="jon snow country",
)
res_4 = sklearn_keyword_classifier.run(query="jon snow country")
print("\n===============================")
print("ES Results" + "\n" + "=" * 15)
print_answers(res_4, details="minimum")

# Run only the dense retriever on the full sentence query
res_5 = sklearn_keyword_classifier.run(
query="who are the younger brothers of arya stark ?",
)
res_5 = sklearn_keyword_classifier.run(query="who are the younger brothers of arya stark ?")
print("\n===============================")
print("DPR Results" + "\n" + "=" * 15)
print_answers(res_5, details="minimum")

# Run only the sparse retriever on a keyword based query
res_6 = sklearn_keyword_classifier.run(
query="arya stark younger brothers",
)
res_6 = sklearn_keyword_classifier.run(query="arya stark younger brothers")
print("\n===============================")
print("ES Results" + "\n" + "=" * 15)
print_answers(res_6, details="minimum")
@ -116,49 +104,37 @@ def tutorial14_query_classifier():
transformer_keyword_classifier.draw("pipeline_classifier.png")

# Run only the dense retriever on the full sentence query
res_1 = transformer_keyword_classifier.run(
query="Who is the father of Arya Stark?",
)
res_1 = transformer_keyword_classifier.run(query="Who is the father of Arya Stark?")
print("\n===============================")
print("DPR Results" + "\n" + "=" * 15)
print_answers(res_1, details="minimum")

# Run only the sparse retriever on a keyword based query
res_2 = transformer_keyword_classifier.run(
query="arya stark father",
)
res_2 = transformer_keyword_classifier.run(query="arya stark father")
print("\n===============================")
print("ES Results" + "\n" + "=" * 15)
print_answers(res_2, details="minimum")

# Run only the dense retriever on the full sentence query
res_3 = transformer_keyword_classifier.run(
query="which country was jon snow filmed ?",
)
res_3 = transformer_keyword_classifier.run(query="which country was jon snow filmed ?")
print("\n===============================")
print("DPR Results" + "\n" + "=" * 15)
print_answers(res_3, details="minimum")

# Run only the sparse retriever on a keyword based query
res_4 = transformer_keyword_classifier.run(
query="jon snow country",
)
res_4 = transformer_keyword_classifier.run(query="jon snow country")
print("\n===============================")
print("ES Results" + "\n" + "=" * 15)
print_answers(res_4, details="minimum")

# Run only the dense retriever on the full sentence query
res_5 = transformer_keyword_classifier.run(
query="who are the younger brothers of arya stark ?",
)
res_5 = transformer_keyword_classifier.run(query="who are the younger brothers of arya stark ?")
print("\n===============================")
print("DPR Results" + "\n" + "=" * 15)
print_answers(res_5, details="minimum")

# Run only the sparse retriever on a keyword based query
res_6 = transformer_keyword_classifier.run(
query="arya stark younger brothers",
)
res_6 = transformer_keyword_classifier.run(query="arya stark younger brothers")
print("\n===============================")
print("ES Results" + "\n" + "=" * 15)
print_answers(res_6, details="minimum")
@ -179,17 +155,13 @@ def tutorial14_query_classifier():
transformer_question_classifier.draw("question_classifier.png")

# Run only the QA reader on the question query
res_1 = transformer_question_classifier.run(
query="Who is the father of Arya Stark?",
)
res_1 = transformer_question_classifier.run(query="Who is the father of Arya Stark?")
print("\n===============================")
print("DPR Results" + "\n" + "=" * 15)
print_answers(res_1, details="minimum")

# Show only DPR results
res_2 = transformer_question_classifier.run(
query="Arya Stark was the daughter of a Lord.",
)
res_2 = transformer_question_classifier.run(query="Arya Stark was the daughter of a Lord.")
print("\n===============================")
print("ES Results" + "\n" + "=" * 15)
print_answers(res_2, details="minimum")