!fix: return a list of elements in stage_for_transformers (#420)

* update stage_for_transformers to return a list of elements

* bump changelog and version

* flag breaking change

* fix last word bug in chunk_by_attention_window
Matt Robinson 2023-03-30 12:27:11 -04:00 committed by GitHub
parent 2f5c61c178
commit e5dd9d5676
4 changed files with 47 additions and 14 deletions
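For context on the headline change: `stage_for_transformers` previously returned a list of chunked strings and now returns a list of document elements. A minimal usage sketch, assuming the `unstructured` and `transformers` packages are installed (the model name is arbitrary):

```python
from transformers import AutoTokenizer

from unstructured.documents.elements import Text
from unstructured.staging.huggingface import stage_for_transformers

# Any HuggingFace tokenizer works here; "bert-base-uncased" is just an example.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

elements = [Text(text="hello " * 200)]

# Before this commit: returned List[str] (raw text chunks).
# After this commit: returns List[Element], one element per chunk, so the
# output plugs back into the rest of the unstructured pipeline.
staged = stage_for_transformers(elements, tokenizer, buffer=10)
for element in staged:
    print(type(element).__name__, element.text[:40])
```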

CHANGELOG.md

@@ -1,10 +1,11 @@
-## 0.5.8-dev4
+## 0.5.8-dev5

 ### Enhancements

 * Update `elements_to_json` to return string when filename is not specified
 * `elements_from_json` may take a string instead of a filename with the `text` kwarg
 * `detect_filetype` now does a final fallback to file extension.
+* Empty tags are now skipped during the depth check for HTML processing.

 ### Features
@@ -18,6 +19,13 @@
 * Partitioning functions that accept a `text` kwarg no longer raise an error if an empty
   string is passed (an empty list of elements is returned instead).
 * `partition_json` no longer fails if the input is an empty list.
+* Fixed bug in `chunk_by_attention_window` that caused the last word in segments to be cut off
+  in some cases.
+
+### BREAKING CHANGES
+
+* `stage_for_transformers` now returns a list of elements, making it consistent with other
+  staging bricks.

 ## 0.5.7
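The BREAKING CHANGES entry above implies a small migration for existing callers. A hedged before/after sketch (variable names and the model are illustrative, not from this diff):

```python
from transformers import AutoTokenizer

from unstructured.documents.elements import Text
from unstructured.staging.huggingface import stage_for_transformers

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # illustrative model
elements = [Text(text="hello " * 200)]

chunk_elements = stage_for_transformers(elements, tokenizer, buffer=10)

# Before 0.5.8: entries were plain strings, e.g. text = chunk_elements[0]
# From 0.5.8 on: entries are elements, so read the .text attribute instead.
texts = [element.text for element in chunk_elements]
```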

test_unstructured/staging/test_huggingface.py

@@ -1,6 +1,6 @@
 import pytest
-from unstructured.documents.elements import Text
+from unstructured.documents.elements import Text, Title

 from unstructured.staging import huggingface
@@ -12,14 +12,23 @@ class MockTokenizer:

 def test_stage_for_transformers():
-    elements = [Text(text="hello " * 20), Text(text="there " * 20)]
+    title_element = Title(text="Here is a wonderful story")
+    elements = [title_element, Text(text="hello " * 20 + "there " * 20)]

     tokenizer = MockTokenizer()
-    chunks = huggingface.stage_for_transformers(elements, tokenizer, buffer=10)
+    chunk_elements = huggingface.stage_for_transformers(elements, tokenizer, buffer=10)

-    hello_chunk = ("hello " * 10).strip()
-    there_chunk = ("there " * 10).strip()
+    hello_chunk = Text(("hello " * 10).strip())
+    there_chunk = Text(("there " * 10).strip())

-    assert chunks == [hello_chunk, hello_chunk, "\n\n" + there_chunk, there_chunk]
+    assert chunk_elements == [
+        title_element,
+        hello_chunk,
+        hello_chunk,
+        there_chunk,
+        there_chunk,
+    ]


 def test_chunk_by_attention_window():
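The `MockTokenizer` named in the hunk header is defined earlier in the test file and is not shown in this diff. A plausible minimal stand-in (the attribute and method names are assumptions based on what `chunk_by_attention_window` needs, not taken from the diff):

```python
class MockTokenizer:
    # Assumed shape: the chunker reads model_max_length to size the attention
    # window and calls tokenize(); whitespace splitting stands in for real
    # subword tokenization.
    model_max_length = 20

    def tokenize(self, text):
        return text.split(" ")
```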

unstructured/__version__.py

@@ -1 +1 @@
-__version__ = "0.5.8-dev4"  # pragma: no cover
+__version__ = "0.5.8-dev5"  # pragma: no cover

unstructured/staging/huggingface.py

@@ -1,19 +1,32 @@
+from copy import deepcopy
 from typing import Callable, List, Optional

 from transformers import PreTrainedTokenizer

-from unstructured.documents.elements import Text
+from unstructured.documents.elements import Element, NarrativeText, Text


 def stage_for_transformers(
     elements: List[Text],
     tokenizer: PreTrainedTokenizer,
     **chunk_kwargs,
-) -> List[str]:
+) -> List[Element]:
     """Stages text elements for transformers pipelines by chunking them into sections that can
     fit into the attention window for the model associated with the tokenizer."""
-    combined_text = "\n\n".join([str(element) for element in elements])
-    return chunk_by_attention_window(combined_text, tokenizer, **chunk_kwargs)
+    chunked_elements: List[Element] = []
+    for element in elements:
+        # NOTE(robinson) - Only chunk potentially lengthy text. Shorter text (like titles)
+        # should already fit into the attention window just fine.
+        if isinstance(element, (NarrativeText, Text)):
+            chunked_text = chunk_by_attention_window(element.text, tokenizer, **chunk_kwargs)
+            for chunk in chunked_text:
+                _chunk_element = deepcopy(element)
+                _chunk_element.text = chunk
+                chunked_elements.append(_chunk_element)
+        else:
+            chunked_elements.append(element)
+    return chunked_elements


 def chunk_by_attention_window(
@@ -68,8 +81,8 @@ def chunk_by_attention_window(
                 f"error is: \n\n{segment}",
             )

-        if chunk_size + num_tokens > max_chunk_size or i == (num_splits - 1):
-            chunks.append(chunk_text)
+        if chunk_size + num_tokens > max_chunk_size:
+            chunks.append(chunk_text + chunk_separator.strip())
             chunk_text = ""
             chunk_size = 0
@@ -79,4 +92,7 @@ def chunk_by_attention_window(
         chunk_text += segment
         chunk_size += num_tokens

+        if i == (num_splits - 1) and len(chunk_text) > 0:
+            chunks.append(chunk_text)

     return chunks
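To make the last-word fix concrete, here is a simplified, self-contained model of the fixed loop. It is not the library code: whole words stand in for tokens and separator handling is reduced to a single space.

```python
from typing import List


def chunk_words(words: List[str], max_chunk_size: int) -> List[str]:
    """Simplified model of the fixed chunking loop."""
    chunks: List[str] = []
    chunk_text = ""
    chunk_size = 0
    num_splits = len(words)
    for i, segment in enumerate(words):
        # Flush the current chunk only when adding this segment would overflow.
        if chunk_size + 1 > max_chunk_size:
            chunks.append(chunk_text)
            chunk_text = ""
            chunk_size = 0
        if chunk_size > 0:
            chunk_text += " "
        chunk_text += segment
        chunk_size += 1
        # The fix: the final segment is appended to chunk_text *before* the
        # last flush. The old logic flushed on the final index before this
        # segment was added, so the last word could be cut off.
        if i == (num_splits - 1) and len(chunk_text) > 0:
            chunks.append(chunk_text)
    return chunks


print(chunk_words("one two three four five".split(), max_chunk_size=2))
# -> ['one two', 'three four', 'five']  (the old logic dropped 'five')
```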