mirror of
https://github.com/Unstructured-IO/unstructured.git
synced 2026-01-03 10:46:56 +00:00
fix: _add_embeddings_to_elements bug resulting in duplicated elements (#1719)
Currently when the OpenAIEmbeddingEncoder adds embeddings to Elements in `_add_embeddings_to_elements` it overwrites each Element's `to_dict` method, mistakenly resulting in each Element having identical values with the exception of the actual embedding value. This was due to the way it leverages a nested `new_to_dict` method to overwrite. Instead, this updates the original definition of Element itself to accommodate the `embeddings` field when available. This also adds a test to validate that values are not duplicated.
This commit is contained in:
parent
ebf0722dcc
commit
40523061ca
@ -10,6 +10,7 @@
|
||||
### Fixes
|
||||
|
||||
* **Fixes PDF list parsing creating duplicate list items** Previously a bug in PDF list item parsing caused removal of other elements and duplication of the list items
|
||||
* **Fixes duplicated elements** Fixes issue where elements are duplicated when embeddings are generated. This will allow users to generate embeddings for their list of Elements without duplicating/breaking the original content.
|
||||
* **Fixes failure when flagging for embeddings through unstructured-ingest** Currently adding the embedding parameter to any connector results in a failure on the copy stage. This resolves the issue by adding the IngestDoc to the context map in the embedding node's `run` method. This allows users to specify that connectors fetch embeddings without failure.
|
||||
* **Fix ingest pipeline reformat nodes not discoverable** Fixes issue where reformat nodes raise ModuleNotFoundError on import. This was due to the directory missing an `__init__.py`, which is required to make it discoverable.
|
||||
* **Fix default language in ingest CLI** Previously the default was being set to English, which injected potentially incorrect information into downstream language detection libraries. Setting the default to None allows those libraries to better detect what language the text in the processed document is in.
|
||||
|
||||
19
test_unstructured/embed/test_openai.py
Normal file
19
test_unstructured/embed/test_openai.py
Normal file
@ -0,0 +1,19 @@
|
||||
from unstructured.documents.elements import Text
|
||||
from unstructured.embed.openai import OpenAIEmbeddingEncoder
|
||||
|
||||
|
||||
def test_embed_documents_does_not_break_element_to_dict(mocker):
    """Embedding a batch of elements must leave each Element's to_dict intact.

    Regression test: a previous implementation monkey-patched ``to_dict`` per
    element, which made every element serialize identical text.
    """
    # Stub out the OpenAI client so no network call is made; the fake returns
    # one "embedding" per input element.
    fake_client = mocker.MagicMock()
    fake_client.embed_documents.return_value = [1, 2]
    mocker.patch.object(
        OpenAIEmbeddingEncoder,
        "get_openai_client",
        return_value=fake_client,
    )

    encoder = OpenAIEmbeddingEncoder(api_key="api_key")
    sentences = ["This is sentence 1", "This is sentence 2"]
    elements = encoder.embed_documents(
        elements=[Text(sentence) for sentence in sentences],
    )

    # Each element must keep its own text after embeddings are attached.
    assert len(elements) == 2
    for element, expected_text in zip(elements, sentences):
        assert element.to_dict()["text"] == expected_text
|
||||
@ -469,9 +469,11 @@ class Text(Element):
|
||||
coordinate_system: Optional[CoordinateSystem] = None,
|
||||
metadata: Optional[ElementMetadata] = None,
|
||||
detection_origin: Optional[str] = None,
|
||||
embeddings: Optional[List[float]] = None,
|
||||
):
|
||||
metadata = metadata if metadata else ElementMetadata()
|
||||
self.text: str = text
|
||||
self.embeddings: Optional[List[float]] = embeddings
|
||||
|
||||
if isinstance(element_id, NoID):
|
||||
# NOTE(robinson) - Cut the SHA256 hex in half to get the first 128 bits
|
||||
@ -497,6 +499,7 @@ class Text(Element):
|
||||
(self.text == other.text),
|
||||
(self.metadata.coordinates == other.metadata.coordinates),
|
||||
(self.category == other.category),
|
||||
(self.embeddings == other.embeddings),
|
||||
],
|
||||
)
|
||||
|
||||
@ -505,6 +508,8 @@ class Text(Element):
|
||||
out["element_id"] = self.id
|
||||
out["type"] = self.category
|
||||
out["text"] = self.text
|
||||
if self.embeddings:
|
||||
out["embeddings"] = self.embeddings
|
||||
return out
|
||||
|
||||
def apply(self, *cleaners: Callable):
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
import types
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
@ -37,18 +36,9 @@ class OpenAIEmbeddingEncoder(BaseEmbeddingEncoder):
|
||||
def _add_embeddings_to_elements(self, elements, embeddings) -> List[Element]:
    """Attach each embedding vector to its matching element, in order.

    Sets ``element.embeddings`` directly instead of monkey-patching a
    ``new_to_dict`` closure onto each element: the previous per-element
    ``to_dict`` override caused every element to serialize with duplicated
    values. ``Element.to_dict`` itself is responsible for emitting the
    ``embeddings`` field when it is set.

    Args:
        elements: elements to enrich; they are mutated in place.
        embeddings: one embedding vector per element, in the same order.

    Returns:
        The same ``elements`` list, with embeddings attached.
    """
    # One embedding per element is a precondition of the pairing below.
    assert len(elements) == len(embeddings)
    for element, embedding in zip(elements, embeddings):
        element.embeddings = embedding
    return elements
|
||||
|
||||
@EmbeddingEncoderConnectionError.wrap
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user