fix: _add_embeddings_to_elements bug resulting in duplicated elements (#1719)

Currently when the OpenAIEmbeddingEncoder adds embeddings to Elements in
`_add_embeddings_to_elements` it overwrites each Element's `to_dict`
method, mistakenly resulting in each Element having identical values
with the exception of the actual embedding value. This was due to the
way it leverages a nested `new_to_dict` method to overwrite. Instead,
this updates the original definition of Element itself to accommodate
the `embeddings` field when available. This also adds a test to validate
that values are not duplicated.
This commit is contained in:
ryannikolaidis 2023-10-12 14:47:32 -07:00 committed by GitHub
parent ebf0722dcc
commit 40523061ca
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 25 additions and 10 deletions

View File

@ -10,6 +10,7 @@
### Fixes
* **Fixes PDF list parsing creating duplicate list items** Previously a bug in PDF list item parsing caused removal of other elements and duplication of the list items
* **Fixes duplicated elements** Fixes issue where elements are duplicated when embeddings are generated. This will allow users to generate embeddings for their list of Elements without duplicating/breaking the orginal content.
* **Fixes failure when flagging for embeddings through unstructured-ingest** Currently adding the embedding parameter to any connector results in a failure on the copy stage. This is resolves the issue by adding the IngestDoc to the context map in the embedding node's `run` method. This allows users to specify that connectors fetch embeddings without failure.
* **Fix ingest pipeline reformat nodes not discoverable** Fixes issue where reformat nodes raise ModuleNotFoundError on import. This was due to the directory was missing `__init__.py` in order to make it discoverable.
* **Fix default language in ingest CLI** Previously the default was being set to english which injected potentially incorrect information to downstream language detection libraries. By setting the default to None allows those libraries to better detect what language the text is in the doc being processed.

View File

@ -0,0 +1,19 @@
from unstructured.documents.elements import Text
from unstructured.embed.openai import OpenAIEmbeddingEncoder
def test_embed_documents_does_not_break_element_to_dict(mocker):
# Mocked client with the desired behavior for embed_documents
mock_client = mocker.MagicMock()
mock_client.embed_documents.return_value = [1, 2]
# Mock get_openai_client to return our mock_client
mocker.patch.object(OpenAIEmbeddingEncoder, "get_openai_client", return_value=mock_client)
encoder = OpenAIEmbeddingEncoder(api_key="api_key")
elements = encoder.embed_documents(
elements=[Text("This is sentence 1"), Text("This is sentence 2")],
)
assert len(elements) == 2
assert elements[0].to_dict()["text"] == "This is sentence 1"
assert elements[1].to_dict()["text"] == "This is sentence 2"

View File

@ -469,9 +469,11 @@ class Text(Element):
coordinate_system: Optional[CoordinateSystem] = None,
metadata: Optional[ElementMetadata] = None,
detection_origin: Optional[str] = None,
embeddings: Optional[List[float]] = None,
):
metadata = metadata if metadata else ElementMetadata()
self.text: str = text
self.embeddings: Optional[List[float]] = embeddings
if isinstance(element_id, NoID):
# NOTE(robinson) - Cut the SHA256 hex in half to get the first 128 bits
@ -497,6 +499,7 @@ class Text(Element):
(self.text == other.text),
(self.metadata.coordinates == other.metadata.coordinates),
(self.category == other.category),
(self.embeddings == other.embeddings),
],
)
@ -505,6 +508,8 @@ class Text(Element):
out["element_id"] = self.id
out["type"] = self.category
out["text"] = self.text
if self.embeddings:
out["embeddings"] = self.embeddings
return out
def apply(self, *cleaners: Callable):

View File

@ -1,4 +1,3 @@
import types
from typing import List
import numpy as np
@ -37,18 +36,9 @@ class OpenAIEmbeddingEncoder(BaseEmbeddingEncoder):
def _add_embeddings_to_elements(self, elements, embeddings) -> List[Element]:
assert len(elements) == len(embeddings)
elements_w_embedding = []
for i, element in enumerate(elements):
original_method = element.to_dict
def new_to_dict(self):
d = original_method()
d["embeddings"] = self.embeddings
return d
element.embeddings = embeddings[i]
elements_w_embedding.append(element)
element.to_dict = types.MethodType(new_to_dict, element)
return elements
@EmbeddingEncoderConnectionError.wrap