Mirror of https://github.com/Unstructured-IO/unstructured.git
Synced 2025-11-03 11:34:07 +00:00
fix: return a list of elements in stage_for_transformers (#420)

* update stage_for_transformers to return a list of elements
* bump changelog and version
* flag breaking change
* fix last word bug in chunk_by_attention_window
This commit is contained in:
parent 2f5c61c178
commit e5dd9d5676
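For context, a minimal sketch of the call pattern after this change (the tokenizer and model name are illustrative, not part of the commit):

    from transformers import AutoTokenizer

    from unstructured.documents.elements import NarrativeText, Title
    from unstructured.staging.huggingface import stage_for_transformers

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

    elements = [
        Title(text="A Wonderful Story"),                  # short: passes through intact
        NarrativeText(text="Once upon a time. " * 500),   # long: split into chunks
    ]

    # As of this commit the result is a list of Element objects; each long
    # element is split into attention-window-sized pieces, one element per piece.
    chunk_elements = stage_for_transformers(elements, tokenizer)
    for element in chunk_elements:
        print(type(element).__name__, len(element.text))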
CHANGELOG.md
@@ -1,10 +1,11 @@
-## 0.5.8-dev4
+## 0.5.8-dev5

 ### Enhancements

 * Update `elements_to_json` to return string when filename is not specified
 * `elements_from_json` may take a string instead of a filename with the `text` kwarg
 * `detect_filetype` now does a final fallback to file extension.
 * Empty tags are now skipped during the depth check for HTML processing.

 ### Features
@@ -18,6 +19,13 @@
 * Partitioning functions that accept a `text` kwarg no longer raise an error if an empty
   string is passed (an empty list of elements is returned instead).
 * `partition_json` no longer fails if the input is an empty list.
+* Fixed bug in `chunk_by_attention_window` that caused the last word in segments to be cut off
+  in some cases.
+
+### BREAKING CHANGES
+
+* `stage_for_transformers` now returns a list of elements, making it consistent with other
+  staging bricks

 ## 0.5.7
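Because this is a breaking change, callers that consumed the old list-of-strings return value need a small update. A hedged before/after sketch, continuing the example above:

    # Before (0.5.7 and earlier): a list of strings came back.
    # chunks = stage_for_transformers(elements, tokenizer)
    # first_text = chunks[0]

    # After (0.5.8-dev5): a list of Element objects comes back; read the
    # text via the .text attribute (or str(element)).
    chunk_elements = stage_for_transformers(elements, tokenizer)
    texts = [element.text for element in chunk_elements]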
test_unstructured/staging/test_huggingface.py

@@ -1,6 +1,6 @@
 import pytest

-from unstructured.documents.elements import Text
+from unstructured.documents.elements import Text, Title
 from unstructured.staging import huggingface
@@ -12,14 +12,23 @@ class MockTokenizer:


 def test_stage_for_transformers():
-    elements = [Text(text="hello " * 20), Text(text="there " * 20)]
+    title_element = Title(text="Here is a wonderful story")
+    elements = [title_element, Text(text="hello " * 20 + "there " * 20)]

     tokenizer = MockTokenizer()

-    chunks = huggingface.stage_for_transformers(elements, tokenizer, buffer=10)
+    chunk_elements = huggingface.stage_for_transformers(elements, tokenizer, buffer=10)

-    hello_chunk = ("hello " * 10).strip()
-    there_chunk = ("there " * 10).strip()
-    assert chunks == [hello_chunk, hello_chunk, "\n\n" + there_chunk, there_chunk]
+    hello_chunk = Text(("hello " * 10).strip())
+    there_chunk = Text(("there " * 10).strip())
+
+    assert chunk_elements == [
+        title_element,
+        hello_chunk,
+        hello_chunk,
+        there_chunk,
+        there_chunk,
+    ]


 def test_chunk_by_attention_window():
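The hunk header references a MockTokenizer defined earlier in the test module but not shown in this diff. A hypothetical minimal version, consistent with the whitespace-based chunks the test expects (the real class may differ):

    class MockTokenizer:
        # Stands in for a HuggingFace tokenizer: the staging brick only needs
        # model_max_length and tokenize(). A window of 20 minus buffer=10 gives
        # a max chunk size of 10 tokens, so the 40-word Text element in the
        # test splits into four 10-word chunks.
        model_max_length = 20

        def tokenize(self, text):
            return text.split(" ")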
unstructured/__version__.py

@@ -1 +1 @@
-__version__ = "0.5.8-dev4"  # pragma: no cover
+__version__ = "0.5.8-dev5"  # pragma: no cover
unstructured/staging/huggingface.py

@@ -1,19 +1,32 @@
+from copy import deepcopy
 from typing import Callable, List, Optional

 from transformers import PreTrainedTokenizer

-from unstructured.documents.elements import Text
+from unstructured.documents.elements import Element, NarrativeText, Text


 def stage_for_transformers(
     elements: List[Text],
     tokenizer: PreTrainedTokenizer,
     **chunk_kwargs,
-) -> List[str]:
+) -> List[Element]:
     """Stages text elements for transformers pipelines by chunking them into sections that can
     fit into the attention window for the model associated with the tokenizer."""
-    combined_text = "\n\n".join([str(element) for element in elements])
-    return chunk_by_attention_window(combined_text, tokenizer, **chunk_kwargs)
+    chunked_elements: List[Element] = []
+    for element in elements:
+        # NOTE(robinson) - Only chunk potentially lengthy text. Shorter text (like titles)
+        # should already fit into the attention window just fine.
+        if isinstance(element, (NarrativeText, Text)):
+            chunked_text = chunk_by_attention_window(element.text, tokenizer, **chunk_kwargs)
+            for chunk in chunked_text:
+                _chunk_element = deepcopy(element)
+                _chunk_element.text = chunk
+                chunked_elements.append(_chunk_element)
+        else:
+            chunked_elements.append(element)
+
+    return chunked_elements


 def chunk_by_attention_window(
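The deepcopy in the loop above means every chunk inherits the type and metadata of the element it came from, rather than collapsing everything into anonymous strings. A small sketch of the effect, reusing the hypothetical MockTokenizer from the test diff above:

    from unstructured.documents.elements import Text, Title
    from unstructured.staging.huggingface import stage_for_transformers

    elements = [Title(text="Short title"), Text(text="word " * 40)]
    chunk_elements = stage_for_transformers(elements, MockTokenizer(), buffer=10)

    # The short Title comes back as a Title; the long Text becomes several
    # Text elements, one per attention-window-sized chunk.
    assert isinstance(chunk_elements[0], Title)
    assert all(isinstance(element, Text) for element in chunk_elements[1:])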
@@ -68,8 +81,8 @@ def chunk_by_attention_window(
             f"error is: \n\n{segment}",
         )

-        if chunk_size + num_tokens > max_chunk_size or i == (num_splits - 1):
-            chunks.append(chunk_text)
+        if chunk_size + num_tokens > max_chunk_size:
+            chunks.append(chunk_text + chunk_separator.strip())
             chunk_text = ""
             chunk_size = 0

@@ -79,4 +92,7 @@ def chunk_by_attention_window(
         chunk_text += segment
         chunk_size += num_tokens

+    if i == (num_splits - 1) and len(chunk_text) > 0:
+        chunks.append(chunk_text)
+
     return chunks
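The last-word bug fixed here: the old loop flushed the running chunk as soon as it reached the final split, before that split's own text had been appended, so the trailing segment could be dropped. The fix appends first and flushes any leftover text after the loop. A standalone toy reproduction of the two patterns (simplified to word counts; not the library's actual code):

    def chunk_old(words, max_words):
        # Buggy pattern: flush on the final word *before* appending it.
        chunks, current = [], []
        for i, word in enumerate(words):
            if len(current) + 1 > max_words or i == len(words) - 1:
                chunks.append(" ".join(current))
                current = []
            current.append(word)
        return chunks

    def chunk_new(words, max_words):
        # Fixed pattern: append first, then flush the remainder after the loop.
        chunks, current = [], []
        for word in words:
            if len(current) + 1 > max_words:
                chunks.append(" ".join(current))
                current = []
            current.append(word)
        if current:
            chunks.append(" ".join(current))
        return chunks

    words = "the quick brown fox jumps".split()
    assert chunk_old(words, 3) == ["the quick brown", "fox"]        # "jumps" is lost
    assert chunk_new(words, 3) == ["the quick brown", "fox jumps"]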