Fix Numba TypingError in normalize_embedding for cosine similarity (#1933)

* Fix Numba TypingError

* Add latest docstring and tutorial changes

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
This commit is contained in:
bogdankostic 2022-01-03 17:14:51 +01:00 committed by GitHub
parent 202ef276ee
commit 3e0ef1cc8a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 13 additions and 6 deletions

View File

@@ -130,11 +130,13 @@ object, provided that they have the same product_id (to be found in Label.meta["
#### normalize\_embedding
```python
| @staticmethod
| @njit
| normalize_embedding(emb: np.ndarray) -> None
```
Performs L2 normalization of embeddings vector inplace. Input can be a single vector (1D array) or a matrix (2D array).
Performs L2 normalization of embeddings vector inplace. Input can be a single vector (1D array) or a matrix
(2D array).
<a name="base.BaseDocumentStore.add_eval_data"></a>
#### add\_eval\_data

View File

@@ -232,22 +232,27 @@ class BaseDocumentStore(BaseComponent):
headers: Optional[Dict[str, str]] = None) -> int:
pass
@staticmethod
@njit#(fastmath=True)
def normalize_embedding(self, emb: np.ndarray) -> None:
def normalize_embedding(emb: np.ndarray) -> None:
"""
Performs L2 normalization of embeddings vector inplace. Input can be a single vector (1D array) or a matrix (2D array).
Performs L2 normalization of embeddings vector inplace. Input can be a single vector (1D array) or a matrix
(2D array).
"""
# Might be extended to other normalizations in future
# Single vec
if len(emb.shape) == 1:
norm = np.sqrt(emb.dot(emb)) #faster than np.linalg.norm()
norm = np.sqrt(emb.dot(emb)) # faster than np.linalg.norm()
if norm != 0.0:
emb /= norm
# 2D matrix
else:
norm = np.linalg.norm(emb, axis=1)
emb /= norm[:, None]
for vec in emb:
vec = np.ascontiguousarray(vec)
norm = np.sqrt(vec.dot(vec))
if norm != 0.0:
vec /= norm
def finalize_raw_score(self, raw_score: float, similarity: Optional[str]) -> float:
if similarity == "cosine":