Fix normalize_embedding using numba (#2347)

* fix normalize_embedding using numba

* Update Documentation & Code Style

* fix too-many-public-methods pylint msg

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
This commit is contained in:
tstadel 2022-03-22 23:04:55 +01:00 committed by GitHub
parent 7e6ff8a205
commit 851fe1cf07
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 19 additions and 13 deletions

View File

@ -212,8 +212,6 @@ TODO drop params
#### normalize\_embedding
```python
@staticmethod
@njit
def normalize_embedding(emb: np.ndarray) -> None
```

View File

@ -341,9 +341,7 @@ class BaseDocumentStore(BaseComponent):
) -> int:
pass
@staticmethod
@njit # (fastmath=True)
def normalize_embedding(emb: np.ndarray) -> None:
def normalize_embedding(self, emb: np.ndarray) -> None:
"""
Performs L2 normalization of embeddings vector inplace. Input can be a single vector (1D array) or a matrix
(2D array).
@ -352,16 +350,26 @@ class BaseDocumentStore(BaseComponent):
# Single vec
if len(emb.shape) == 1:
norm = np.sqrt(emb.dot(emb)) # faster than np.linalg.norm()
if norm != 0.0:
emb /= norm
self._normalize_embedding_1D(emb)
# 2D matrix
else:
for vec in emb:
vec = np.ascontiguousarray(vec)
norm = np.sqrt(vec.dot(vec))
if norm != 0.0:
vec /= norm
self._normalize_embedding_2D(emb)
@staticmethod
@njit # (fastmath=True)
def _normalize_embedding_1D(emb: np.ndarray) -> None:
norm = np.sqrt(emb.dot(emb)) # faster than np.linalg.norm()
if norm != 0.0:
emb /= norm
@staticmethod
@njit # (fastmath=True)
def _normalize_embedding_2D(emb: np.ndarray) -> None:
for vec in emb:
vec = np.ascontiguousarray(vec)
norm = np.sqrt(vec.dot(vec))
if norm != 0.0:
vec /= norm
def finalize_raw_score(self, raw_score: float, similarity: Optional[str]) -> float:
if similarity == "cosine":