mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-12-26 22:48:29 +00:00
Fix normalize_embedding using numba (#2347)
* fix normalize_embedding using numba * Update Documentation & Code Style * fix too-many-public-methods pylint msg Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
This commit is contained in:
parent
7e6ff8a205
commit
851fe1cf07
@ -212,8 +212,6 @@ TODO drop params
|
||||
#### normalize\_embedding
|
||||
|
||||
```python
|
||||
@staticmethod
|
||||
@njit
|
||||
def normalize_embedding(emb: np.ndarray) -> None
|
||||
```
|
||||
|
||||
|
||||
@ -341,9 +341,7 @@ class BaseDocumentStore(BaseComponent):
|
||||
) -> int:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
@njit # (fastmath=True)
|
||||
def normalize_embedding(emb: np.ndarray) -> None:
|
||||
def normalize_embedding(self, emb: np.ndarray) -> None:
|
||||
"""
|
||||
Performs L2 normalization of embeddings vector inplace. Input can be a single vector (1D array) or a matrix
|
||||
(2D array).
|
||||
@ -352,16 +350,26 @@ class BaseDocumentStore(BaseComponent):
|
||||
|
||||
# Single vec
|
||||
if len(emb.shape) == 1:
|
||||
norm = np.sqrt(emb.dot(emb)) # faster than np.linalg.norm()
|
||||
if norm != 0.0:
|
||||
emb /= norm
|
||||
self._normalize_embedding_1D(emb)
|
||||
# 2D matrix
|
||||
else:
|
||||
for vec in emb:
|
||||
vec = np.ascontiguousarray(vec)
|
||||
norm = np.sqrt(vec.dot(vec))
|
||||
if norm != 0.0:
|
||||
vec /= norm
|
||||
self._normalize_embedding_2D(emb)
|
||||
|
||||
@staticmethod
|
||||
@njit # (fastmath=True)
|
||||
def _normalize_embedding_1D(emb: np.ndarray) -> None:
|
||||
norm = np.sqrt(emb.dot(emb)) # faster than np.linalg.norm()
|
||||
if norm != 0.0:
|
||||
emb /= norm
|
||||
|
||||
@staticmethod
|
||||
@njit # (fastmath=True)
|
||||
def _normalize_embedding_2D(emb: np.ndarray) -> None:
|
||||
for vec in emb:
|
||||
vec = np.ascontiguousarray(vec)
|
||||
norm = np.sqrt(vec.dot(vec))
|
||||
if norm != 0.0:
|
||||
vec /= norm
|
||||
|
||||
def finalize_raw_score(self, raw_score: float, similarity: Optional[str]) -> float:
|
||||
if similarity == "cosine":
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user