Fix mypy typing (#792)

Tanay Soni 2021-02-01 12:15:36 +01:00 committed by GitHub
parent 1dc74c7067
commit d62355ca88
12 changed files with 55 additions and 48 deletions
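The recurring fix in this diff swaps `np.array` for `np.ndarray` in annotations: `np.array` is the factory function that builds arrays, while `np.ndarray` is the array class, so only the latter is a type that mypy can check. A minimal sketch of the distinction (function names are illustrative):

```python
import numpy as np

def embed_ok(texts: list) -> np.ndarray:
    # np.ndarray is a class, so this annotation type-checks
    return np.array([float(len(t)) for t in texts], dtype="float32")

# def embed_bad(texts: list) -> np.array:
#     ...
# mypy rejects the annotation above: np.array is a function, not a type
```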

View File

@ -295,7 +295,7 @@ that are most relevant to the query as defined by the BM25 algorithm.
#### query\_by\_embedding
```python
| query_by_embedding(query_emb: np.array, filters: Optional[Dict[str, List[str]]] = None, top_k: int = 10, index: Optional[str] = None, return_embedding: Optional[bool] = None) -> List[Document]
| query_by_embedding(query_emb: np.ndarray, filters: Optional[Dict[str, List[str]]] = None, top_k: int = 10, index: Optional[str] = None, return_embedding: Optional[bool] = None) -> List[Document]
```
Find the document that is most similar to the provided `query_emb` by using a vector similarity metric.
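A hedged usage sketch of the updated signature; the store setup, import path, and embedding values are assumptions for illustration:

```python
import numpy as np
from haystack.document_store.elasticsearch import ElasticsearchDocumentStore  # assumed path in this snapshot

document_store = ElasticsearchDocumentStore()          # assumes a local Elasticsearch with embedded documents
query_emb = np.random.rand(768).astype("float32")      # placeholder; normally produced by a retriever's embed_queries()
docs = document_store.query_by_embedding(query_emb, top_k=5, return_embedding=False)
for doc in docs:
    print(doc.id, doc.text[:80])
```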
@ -453,7 +453,7 @@ Fetch documents by specifying a list of text id strings
#### query\_by\_embedding
```python
| query_by_embedding(query_emb: List[float], filters: Optional[Dict[str, List[str]]] = None, top_k: int = 10, index: Optional[str] = None, return_embedding: Optional[bool] = None) -> List[Document]
| query_by_embedding(query_emb: np.ndarray, filters: Optional[Dict[str, List[str]]] = None, top_k: int = 10, index: Optional[str] = None, return_embedding: Optional[bool] = None) -> List[Document]
```
Find the document that is most similar to the provided `query_emb` by using a vector similarity metric.
@ -863,7 +863,7 @@ Example: {"name": ["some", "more"], "category": ["only_one"]}
#### train\_index
```python
| train_index(documents: Optional[Union[List[dict], List[Document]]], embeddings: Optional[np.array] = None)
| train_index(documents: Optional[Union[List[dict], List[Document]]], embeddings: Optional[np.ndarray] = None)
```
Some FAISS indices (e.g. IVF) require initial "training" on a sample of vectors before you can add your final vectors.
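A hedged sketch of training such an index before adding the final vectors; the factory string, dimensionality, and placeholder documents are assumptions:

```python
import numpy as np
from haystack import Document
from haystack.document_store.faiss import FAISSDocumentStore  # assumed path in this snapshot

# "IVF256,Flat" is just an example of a FAISS index type that requires training
document_store = FAISSDocumentStore(faiss_index_factory_str="IVF256,Flat")

# Placeholder training sample; it should come from the same distribution as the final vectors
train_docs = [
    Document(text=f"training passage {i}", embedding=np.random.rand(768).astype("float32"))
    for i in range(1000)
]
document_store.train_index(documents=train_docs)
# afterwards, write_documents() / update_embeddings() can add the real vectors
```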
@ -892,7 +892,7 @@ Delete all documents from the document store.
#### query\_by\_embedding
```python
| query_by_embedding(query_emb: np.array, filters: Optional[dict] = None, top_k: int = 10, index: Optional[str] = None, return_embedding: Optional[bool] = None) -> List[Document]
| query_by_embedding(query_emb: np.ndarray, filters: Optional[dict] = None, top_k: int = 10, index: Optional[str] = None, return_embedding: Optional[bool] = None) -> List[Document]
```
Find the document that is most similar to the provided `query_emb` by using a vector similarity metric.

View File

@ -286,7 +286,7 @@ that are most relevant to the query.
#### embed\_queries
```python
| embed_queries(texts: List[str]) -> List[np.array]
| embed_queries(texts: List[str]) -> List[np.ndarray]
```
Create embeddings for a list of queries using the query encoder
@ -303,7 +303,7 @@ Embeddings, one per input queries
#### embed\_passages
```python
| embed_passages(docs: List[Document]) -> List[np.array]
| embed_passages(docs: List[Document]) -> List[np.ndarray]
```
Create embeddings for a list of passages using the passage encoder
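A hedged sketch using both encoders; the model names are the usual DPR checkpoints, and the store and passage text are placeholders:

```python
from haystack import Document
from haystack.document_store.memory import InMemoryDocumentStore  # assumed path in this snapshot
from haystack.retriever.dense import DensePassageRetriever        # assumed path in this snapshot

retriever = DensePassageRetriever(
    document_store=InMemoryDocumentStore(),
    query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
    passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
    use_gpu=False,
)

query_embs = retriever.embed_queries(["Who wrote Faust?"])                       # List[np.ndarray]
passage_embs = retriever.embed_passages([Document(text="Goethe wrote Faust.")])  # List[np.ndarray]
print(query_embs[0].shape, passage_embs[0].shape)                                # typically (768,) each
```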
@ -434,7 +434,7 @@ that are most relevant to the query.
#### embed
```python
| embed(texts: Union[List[str], str]) -> List[np.array]
| embed(texts: Union[List[str], str]) -> List[np.ndarray]
```
Create embeddings for each text in a list of texts using the retrievers model (`self.embedding_model`)
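A hedged sketch for this retriever type; the embedding model is an example and the store is a placeholder:

```python
from haystack.document_store.memory import InMemoryDocumentStore  # assumed path in this snapshot
from haystack.retriever.dense import EmbeddingRetriever           # assumed path in this snapshot

retriever = EmbeddingRetriever(
    document_store=InMemoryDocumentStore(),
    embedding_model="deepset/sentence_bert",   # example model name
    use_gpu=False,
)

embs = retriever.embed(["What is the capital of France?", "Facts about Paris"])  # List[np.ndarray], one per text
single = retriever.embed("What is the capital of France?")                       # a single str is also accepted
```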
@ -451,7 +451,7 @@ List of embeddings (one per input text). Each embedding is a list of floats.
#### embed\_queries
```python
| embed_queries(texts: List[str]) -> List[np.array]
| embed_queries(texts: List[str]) -> List[np.ndarray]
```
Create embeddings for a list of queries. For this Retriever type: The same as calling .embed()
@ -468,7 +468,7 @@ Embeddings, one per input queries
#### embed\_passages
```python
| embed_passages(docs: List[Document]) -> List[np.array]
| embed_passages(docs: List[Document]) -> List[np.ndarray]
```
Create embeddings for a list of passages. For this Retriever type: The same as calling .embed()

View File

@ -1,11 +1,13 @@
import logging
from abc import abstractmethod, ABC
from pathlib import Path
from typing import Any, Optional, Dict, List, Union
from haystack import Document, Label, MultiLabel
from haystack.preprocessor.utils import eval_data_from_json, eval_data_from_jsonl, squad_json_to_jsonl
from haystack.preprocessor.preprocessor import PreProcessor
from typing import Optional, Dict, List, Union
import numpy as np
from haystack import Document, Label, MultiLabel
from haystack.preprocessor.preprocessor import PreProcessor
from haystack.preprocessor.utils import eval_data_from_json, eval_data_from_jsonl, squad_json_to_jsonl
logger = logging.getLogger(__name__)
@ -64,7 +66,7 @@ class BaseDocumentStore(ABC):
all_labels = self.get_all_labels(index=index, filters=filters)
# Collect all answers to a question in a dict
question_ans_dict = {} # type: ignore
question_ans_dict: dict = {}
for l in all_labels:
# only aggregate labels with correct answers, as only those can be currently used in evaluation
if not l.is_correct_answer:
@ -125,7 +127,7 @@ class BaseDocumentStore(ABC):
@abstractmethod
def query_by_embedding(self,
query_emb: List[float],
query_emb: np.ndarray,
filters: Optional[Optional[Dict[str, List[str]]]] = None,
top_k: int = 10,
index: Optional[str] = None,
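Matching the abstract signature to the concrete stores matters to mypy: an override's parameter types must be compatible with the base class, so subclasses annotating `query_emb` as `np.ndarray` against a base declaring `List[float]` would be reported as incompatible overrides. A minimal hedged sketch of that rule (class and method names are illustrative):

```python
import numpy as np
from abc import ABC, abstractmethod
from typing import List

class BaseStore(ABC):
    @abstractmethod
    def query_by_embedding(self, query_emb: np.ndarray) -> List[str]:
        ...

class ConcreteStore(BaseStore):
    # Accepted: the parameter type matches the base declaration
    def query_by_embedding(self, query_emb: np.ndarray) -> List[str]:
        return []

# Had BaseStore annotated `query_emb: List[float]`, mypy would flag this override as
# incompatible with the supertype, since overrides may not narrow parameter types.
```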

View File

@ -568,7 +568,7 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
return documents
def query_by_embedding(self,
query_emb: np.array,
query_emb: np.ndarray,
filters: Optional[Dict[str, List[str]]] = None,
top_k: int = 10,
index: Optional[str] = None,
@ -631,7 +631,7 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
]
return documents
def _get_vector_similarity_query(self, query_emb: np.array, top_k: int):
def _get_vector_similarity_query(self, query_emb: np.ndarray, top_k: int):
"""
Generate Elasticsearch query for vector similarity.
"""
@ -849,7 +849,7 @@ class OpenDistroElasticsearchDocumentStore(ElasticsearchDocumentStore):
if not self.client.indices.exists(index=index_name):
raise e
def _get_vector_similarity_query(self, query_emb: np.array, top_k: int):
def _get_vector_similarity_query(self, query_emb: np.ndarray, top_k: int):
"""
Generate Elasticsearch query for vector similarity.
"""

View File

@ -142,8 +142,8 @@ class FAISSDocumentStore(SQLDocumentStore):
for i in range(0, len(document_objects), batch_size):
if add_vectors:
embeddings = [doc.embedding for doc in document_objects[i: i + batch_size]]
embeddings = np.array(embeddings, dtype="float32")
self.faiss_index.add(embeddings)
embeddings_to_index = np.array(embeddings, dtype="float32")
self.faiss_index.add(embeddings_to_index)
docs_to_write_in_sql = []
for doc in document_objects[i: i + batch_size]:
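The new names (`embeddings_to_index` here, `doc_embeddings` / `embeddings_for_train` further down) sidestep a common mypy complaint: re-assigning a name that already holds a list (or, in `train_index`, a parameter annotated `Optional[np.ndarray]`) with a value of a different type is reported as an incompatible assignment. A minimal hedged sketch:

```python
import numpy as np

vectors = [np.random.rand(4) for _ in range(3)]    # mypy infers a list type for this name
# vectors = np.array(vectors, dtype="float32")     # with typed numpy stubs: incompatible types in assignment
vectors_arr = np.array(vectors, dtype="float32")   # a separate name keeps both bindings consistently typed
```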
@ -259,7 +259,9 @@ class FAISSDocumentStore(SQLDocumentStore):
doc.embedding = self.faiss_index.reconstruct(int(doc.meta["vector_id"]))
return documents
def train_index(self, documents: Optional[Union[List[dict], List[Document]]], embeddings: Optional[np.array] = None):
def train_index(
self, documents: Optional[Union[List[dict], List[Document]]], embeddings: Optional[np.ndarray] = None
):
"""
Some FAISS indices (e.g. IVF) require initial "training" on a sample of vectors before you can add your final vectors.
The train vectors should come from the same distribution as your final ones.
@ -274,9 +276,11 @@ class FAISSDocumentStore(SQLDocumentStore):
raise ValueError("Either pass `documents` or `embeddings`. You passed both.")
if documents:
document_objects = [Document.from_dict(d) if isinstance(d, dict) else d for d in documents]
embeddings = [doc.embedding for doc in document_objects]
embeddings = np.array(embeddings, dtype="float32")
self.faiss_index.train(embeddings)
doc_embeddings = [doc.embedding for doc in document_objects]
embeddings_for_train = np.array(doc_embeddings, dtype="float32")
self.faiss_index.train(embeddings_for_train)
if embeddings:
self.faiss_index.train(embeddings)
def delete_all_documents(self, index: Optional[str] = None, filters: Optional[Dict[str, List[str]]] = None):
"""
@ -287,7 +291,7 @@ class FAISSDocumentStore(SQLDocumentStore):
super().delete_all_documents(index=index, filters=filters)
def query_by_embedding(self,
query_emb: np.array,
query_emb: np.ndarray,
filters: Optional[dict] = None,
top_k: int = 10,
index: Optional[str] = None,

View File

@ -1,15 +1,15 @@
import logging
from collections import defaultdict
from copy import deepcopy
from typing import Dict, List, Optional, Union, Generator
from uuid import uuid4
from collections import defaultdict
from haystack.document_store.base import BaseDocumentStore
from haystack import Document, Label
from haystack.retriever.base import BaseRetriever
import numpy as np
from scipy.spatial.distance import cosine
import logging
from haystack import Document, Label
from haystack.document_store.base import BaseDocumentStore
from haystack.retriever.base import BaseRetriever
logger = logging.getLogger(__name__)
@ -94,7 +94,7 @@ class InMemoryDocumentStore(BaseDocumentStore):
return documents
def query_by_embedding(self,
query_emb: List[float],
query_emb: np.ndarray,
filters: Optional[Dict[str, List[str]]] = None,
top_k: int = 10,
index: Optional[str] = None,

View File

@ -258,7 +258,7 @@ class MilvusDocumentStore(SQLDocumentStore):
self.milvus_server.compact(collection_name=index)
def query_by_embedding(self,
query_emb: np.array,
query_emb: np.ndarray,
filters: Optional[dict] = None,
top_k: int = 10,
index: Optional[str] = None,
@ -458,7 +458,7 @@ class MilvusDocumentStore(SQLDocumentStore):
if status.code != Status.SUCCESS:
raise RuntimeError("E existing vector ids deletion failed: {status}")
def get_all_vectors(self, index=None) -> List[np.array]:
def get_all_vectors(self, index: Optional[str] = None) -> List[np.ndarray]:
"""
Helper function to dump all vectors stored in Milvus server.

View File

@ -3,6 +3,7 @@ import logging
from typing import Any, Dict, Union, List, Optional, Generator
from uuid import uuid4
import numpy as np
from sqlalchemy import and_, func, create_engine, Column, Integer, String, DateTime, ForeignKey, Boolean, Text, text
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import relationship, sessionmaker
@ -136,7 +137,7 @@ class SQLDocumentStore(BaseDocumentStore):
for row in query.all():
documents.append(self._convert_sql_row_to_document(row))
sorted_documents = sorted(documents, key=lambda doc: vector_ids.index(doc.meta["vector_id"])) # type: ignore
sorted_documents = sorted(documents, key=lambda doc: vector_ids.index(doc.meta["vector_id"]))
return sorted_documents
def get_all_documents(
@ -196,7 +197,7 @@ class SQLDocumentStore(BaseDocumentStore):
documents_map[row.id] = Document(
id=row.id,
text=row.text,
meta=None if row.vector_id is None else {"vector_id": row.vector_id} # type: ignore
meta=None if row.vector_id is None else {"vector_id": row.vector_id}
)
if i % batch_size == 0:
documents_map = self._get_documents_meta(documents_map)
@ -215,7 +216,7 @@ class SQLDocumentStore(BaseDocumentStore):
).filter(MetaORM.document_id.in_(doc_ids))
for row in meta_query.all():
documents_map[row.document_id].meta[row.name] = row.value # type: ignore
documents_map[row.document_id].meta[row.name] = row.value
return documents_map
def get_all_labels(self, index=None, filters: Optional[dict] = None):
@ -389,7 +390,7 @@ class SQLDocumentStore(BaseDocumentStore):
return label
def query_by_embedding(self,
query_emb: List[float],
query_emb: np.ndarray,
filters: Optional[dict] = None,
top_k: int = 10,
index: Optional[str] = None,

View File

@ -166,7 +166,7 @@ class RAGenerator(BaseGenerator):
return contextualized_inputs["input_ids"].to(self.device), \
contextualized_inputs["attention_mask"].to(self.device)
def _prepare_passage_embeddings(self, docs: List[Document], embeddings: List[Optional[numpy.ndarray]]) -> torch.Tensor:
def _prepare_passage_embeddings(self, docs: List[Document], embeddings: List[numpy.ndarray]) -> torch.Tensor:
# If document missing embedding, then need embedding for all the documents
is_embedding_required = embeddings is None or any(embedding is None for embedding in embeddings)

View File

@ -28,10 +28,10 @@ class BaseReader(ABC):
# the most significant difference between scores.
# Most significant difference: a model switching from predicting an answer to "no answer" (or vice versa).
# No_ans_gap is a list of this most significant difference per document
no_ans_gaps = np.array(no_ans_gaps)
max_no_ans_gap = np.max(no_ans_gaps)
no_ans_gap_array = np.array(no_ans_gaps)
max_no_ans_gap = np.max(no_ans_gap_array)
# all passages "no answer" as top score
if (np.sum(no_ans_gaps < 0) == len(no_ans_gaps)): # type: ignore
if np.sum(no_ans_gap_array < 0) == len(no_ans_gap_array):
no_ans_score = best_score_answer - max_no_ans_gap # max_no_ans_gap is negative, so it increases best pos score
else: # case: at least one passage predicts an answer (positive no_ans_gap)
no_ans_score = best_score_answer - max_no_ans_gap
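A small hedged numeric illustration of the branch above (all values are made up):

```python
import numpy as np

no_ans_gaps = np.array([-1.5, -0.2, -3.0])   # every passage ranks "no answer" on top (all gaps negative)
best_score_answer = 7.0                      # illustrative best positive-answer score

all_no_answer = np.sum(no_ans_gaps < 0) == len(no_ans_gaps)   # True for this example
max_no_ans_gap = np.max(no_ans_gaps)                          # -0.2, the gap closest to flipping to an answer
no_ans_score = best_score_answer - max_no_ans_gap             # 7.0 - (-0.2) = 7.2, so "no answer" outranks the best answer
```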

View File

@ -210,7 +210,7 @@ class DensePassageRetriever(BaseRetriever):
all_embeddings["query"] = np.concatenate(all_embeddings["query"])
return all_embeddings
def embed_queries(self, texts: List[str]) -> List[np.array]:
def embed_queries(self, texts: List[str]) -> List[np.ndarray]:
"""
Create embeddings for a list of queries using the query encoder
@ -221,7 +221,7 @@ class DensePassageRetriever(BaseRetriever):
result = self._get_predictions(queries)["query"]
return result
def embed_passages(self, docs: List[Document]) -> List[np.array]:
def embed_passages(self, docs: List[Document]) -> List[np.ndarray]:
"""
Create embeddings for a list of passages using the passage encoder
@ -483,7 +483,7 @@ class EmbeddingRetriever(BaseRetriever):
top_k=top_k, index=index)
return documents
def embed(self, texts: Union[List[str], str]) -> List[np.array]:
def embed(self, texts: Union[List[str], str]) -> List[np.ndarray]:
"""
Create embeddings for each text in a list of texts using the retrievers model (`self.embedding_model`)
@ -508,7 +508,7 @@ class EmbeddingRetriever(BaseRetriever):
emb = [r for r in emb]
return emb
def embed_queries(self, texts: List[str]) -> List[np.array]:
def embed_queries(self, texts: List[str]) -> List[np.ndarray]:
"""
Create embeddings for a list of queries. For this Retriever type: The same as calling .embed()
@ -517,7 +517,7 @@ class EmbeddingRetriever(BaseRetriever):
"""
return self.embed(texts)
def embed_passages(self, docs: List[Document]) -> List[np.array]:
def embed_passages(self, docs: List[Document]) -> List[np.ndarray]:
"""
Create embeddings for a list of passages. For this Retriever type: The same as calling .embed()

View File

@ -11,7 +11,7 @@ class Document:
probability: Optional[float] = None,
question: Optional[str] = None,
meta: Dict[str, Any] = None,
embedding: Optional[np.array] = None):
embedding: Optional[np.ndarray] = None):
"""
Object used to represent documents / passages in a standardized way within Haystack.
For example, this is what the retriever will return from the DocumentStore,
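A hedged construction sketch for the updated `embedding` annotation; the text, metadata, and vector are placeholders:

```python
import numpy as np
from haystack import Document

doc = Document(
    text="Berlin is the capital of Germany.",
    meta={"source": "example"},
    embedding=np.random.rand(768).astype("float32"),   # Optional[np.ndarray] after this change
)
print(doc.embedding.shape)   # (768,)
```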