!fix: return a list of elements in stage_for_transformers (#420)

* update stage_for_transformers to return a list of elements

* bump changelog and version

* flag breaking change

* fix last word bug in chunk_by_attention_window
This commit is contained in:
Matt Robinson 2023-03-30 12:27:11 -04:00 committed by GitHub
parent 2f5c61c178
commit e5dd9d5676
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 47 additions and 14 deletions

View File

@ -1,10 +1,11 @@
## 0.5.8-dev4
## 0.5.8-dev5
### Enhancements
* Update `elements_to_json` to return string when filename is not specified
* `elements_from_json` may take a string instead of a filename with the `text` kwarg
* `detect_filetype` now does a final fallback to file extension.
* Empty tags are now skipped during the depth check for HTML processing.
### Features
@ -18,6 +19,13 @@
* Partitioning functions that accept a `text` kwarg no longer raise an error if an empty
string is passed (and empty list of elements is returned instead).
* `partition_json` no longer fails if the input is an empty list.
* Fixed bug in `chunk_by_attention_window` that caused the last word in segments to be cut-off
in some cases.
### BREAKING CHANGES
* `stage_for_transformers` now returns a list of elements, making it consistent with other
staging bricks
## 0.5.7

View File

@ -1,6 +1,6 @@
import pytest
from unstructured.documents.elements import Text
from unstructured.documents.elements import Text, Title
from unstructured.staging import huggingface
@ -12,14 +12,23 @@ class MockTokenizer:
def test_stage_for_transformers():
elements = [Text(text="hello " * 20), Text(text="there " * 20)]
title_element = (Title(text="Here is a wonderful story"),)
elements = [title_element, Text(text="hello " * 20 + "there " * 20)]
tokenizer = MockTokenizer()
chunks = huggingface.stage_for_transformers(elements, tokenizer, buffer=10)
chunk_elements = huggingface.stage_for_transformers(elements, tokenizer, buffer=10)
hello_chunk = ("hello " * 10).strip()
there_chunk = ("there " * 10).strip()
assert chunks == [hello_chunk, hello_chunk, "\n\n" + there_chunk, there_chunk]
hello_chunk = Text(("hello " * 10).strip())
there_chunk = Text(("there " * 10).strip())
assert chunk_elements == [
title_element,
hello_chunk,
hello_chunk,
there_chunk,
there_chunk,
]
def test_chunk_by_attention_window():

View File

@ -1 +1 @@
__version__ = "0.5.8-dev4" # pragma: no cover
__version__ = "0.5.8-dev5" # pragma: no cover

View File

@ -1,19 +1,32 @@
from copy import deepcopy
from typing import Callable, List, Optional
from transformers import PreTrainedTokenizer
from unstructured.documents.elements import Text
from unstructured.documents.elements import Element, NarrativeText, Text
def stage_for_transformers(
elements: List[Text],
tokenizer: PreTrainedTokenizer,
**chunk_kwargs,
) -> List[str]:
) -> List[Element]:
"""Stages text elements for transformers pipelines by chunking them into sections that can
fit into the attention window for the model associated with the tokenizer."""
combined_text = "\n\n".join([str(element) for element in elements])
return chunk_by_attention_window(combined_text, tokenizer, **chunk_kwargs)
chunked_elements: List[Element] = []
for element in elements:
# NOTE(robinson) - Only chunk potentially lengthy text. Shorter text (like titles)
# should already fit into the attention window just fine.
if isinstance(element, (NarrativeText, Text)):
chunked_text = chunk_by_attention_window(element.text, tokenizer, **chunk_kwargs)
for chunk in chunked_text:
_chunk_element = deepcopy(element)
_chunk_element.text = chunk
chunked_elements.append(_chunk_element)
else:
chunked_elements.append(element)
return chunked_elements
def chunk_by_attention_window(
@ -68,8 +81,8 @@ def chunk_by_attention_window(
f"error is: \n\n{segment}",
)
if chunk_size + num_tokens > max_chunk_size or i == (num_splits - 1):
chunks.append(chunk_text)
if chunk_size + num_tokens > max_chunk_size:
chunks.append(chunk_text + chunk_separator.strip())
chunk_text = ""
chunk_size = 0
@ -79,4 +92,7 @@ def chunk_by_attention_window(
chunk_text += segment
chunk_size += num_tokens
if i == (num_splits - 1) and len(chunk_text) > 0:
chunks.append(chunk_text)
return chunks