feat: ImageToText (caption generator) (#3859)
* first draft
* fix pylint and mypy
* retry w mypy
* mypy :-)
* rem unused import
* incorporate feedback and initial tests
* better tests
* fix import order
* fix docstring
* other fix docstring
* more and better tests

Co-authored-by: ZanSara <sarazanzo94@gmail.com>
This commit is contained in:
parent d2bba4935b
commit b910df7ec7
haystack/errors.py
@@ -147,3 +147,10 @@ class CohereError(NodeError):
     ):
         super().__init__(message=message, send_message_in_event=send_message_in_event)
         self.status_code = status_code
+
+
+class ImageToTextError(NodeError):
+    """Exception for issues that occur in the ImageToText node"""
+
+    def __init__(self, message: Optional[str] = None):
+        super().__init__(message=message)
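As a rough usage sketch (not part of this commit), downstream code can catch the new exception around a captioning call; `image_to_text` below is assumed to be a `TransformersImageToText` instance and the file path is a placeholder:

```python
from haystack.errors import ImageToTextError

try:
    captions = image_to_text.generate_captions(image_file_paths=["images/broken.jpg"])
except ImageToTextError as err:
    # The node wraps file and inference failures (see transformers.py below) in ImageToTextError
    print(f"Captioning failed: {err}")
```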
haystack/nodes/image_to_text/__init__.py (new file, 2 lines added)
@@ -0,0 +1,2 @@
from haystack.nodes.image_to_text.base import BaseImageToText
from haystack.nodes.image_to_text.transformers import TransformersImageToText
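Because the new `__init__.py` re-exports both classes, they can be imported from the subpackage directly, for example:

```python
from haystack.nodes.image_to_text import BaseImageToText, TransformersImageToText
```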
haystack/nodes/image_to_text/base.py (new file, 53 lines added)
@@ -0,0 +1,53 @@
from typing import List, Optional

from abc import abstractmethod

from haystack.schema import Document
from haystack.nodes.base import BaseComponent


class BaseImageToText(BaseComponent):
    """
    Abstract class for ImageToText
    """

    outgoing_edges = 1

    @abstractmethod
    def generate_captions(
        self, image_file_paths: List[str], generation_kwargs: Optional[dict] = None, batch_size: Optional[int] = None
    ) -> List[Document]:
        """
        Abstract method for generating captions.

        :param image_file_paths: Paths of the images.
        :param generation_kwargs: Dictionary containing arguments for the generate method of the Hugging Face model.
                                  See https://huggingface.co/docs/transformers/en/main_classes/text_generation#transformers.GenerationMixin.generate
        :param batch_size: Number of images to process at a time.
        :return: List of Documents. Document.content is the caption. Document.meta["image_path"] contains the image file path.
        """
        pass

    def run(self, file_paths: Optional[List[str]] = None, documents: Optional[List[Document]] = None):  # type: ignore

        if file_paths is None and documents is None:
            raise ValueError("You must either specify documents or image file_paths to process.")

        image_file_paths = []
        if file_paths is not None:
            image_file_paths.extend(file_paths)
        if documents is not None:
            if any(doc.content_type != "image" for doc in documents):
                raise ValueError("The ImageToText node only supports image documents.")
            image_file_paths.extend([doc.content for doc in documents])

        results: dict = {}
        results["documents"] = self.generate_captions(image_file_paths=image_file_paths)

        return results, "output_1"

    def run_batch(  # type: ignore
        self, file_paths: Optional[List[str]] = None, documents: Optional[List[Document]] = None
    ):

        return self.run(file_paths=file_paths, documents=documents)
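A minimal, hedged sketch of how the base class's `run()` is meant to be called on a concrete subclass (the paths below are placeholders): file paths and image-type Documents are merged into a single `generate_captions()` call and the resulting Documents are returned on the `output_1` edge.

```python
from haystack import Document
from haystack.nodes.image_to_text import TransformersImageToText

image_to_text = TransformersImageToText()  # defaults defined in transformers.py below

# Mix plain file paths and image-type Documents; run() merges them into one list.
image_doc = Document(content="/path/to/images/cat.jpg", content_type="image")
results, edge = image_to_text.run(file_paths=["/path/to/images/apple.jpg"], documents=[image_doc])

for doc in results["documents"]:
    print(doc.meta["image_path"], "->", doc.content)  # image path and generated caption
```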
haystack/nodes/image_to_text/transformers.py (new file, 160 lines added)
@@ -0,0 +1,160 @@
from typing import List, Optional, Union

import logging

import torch
from tqdm.auto import tqdm
from transformers import pipeline

from haystack.schema import Document
from haystack.nodes.image_to_text.base import BaseImageToText
from haystack.modeling.utils import initialize_device_settings
from haystack.utils.torch_utils import ListDataset
from haystack.errors import ImageToTextError

logger = logging.getLogger(__name__)


# The supported model classes should be extended when the HF image-to-text pipeline supports more classes.
# See https://github.com/huggingface/transformers/issues/21110
SUPPORTED_MODELS_CLASSES = ["VisionEncoderDecoderModel"]


class TransformersImageToText(BaseImageToText):
    """
    Transformer-based model to generate captions for images using Hugging Face's transformers framework.

    See the up-to-date list of available models on
    `huggingface.co/models <https://huggingface.co/models?pipeline_tag=image-to-text>`__

    **Example**

    ```python
    image_to_text = TransformersImageToText()

    image_file_paths = ["/path/to/images/apple.jpg",
                        "/path/to/images/cat.jpg"]

    # Generate captions
    documents = image_to_text.generate_captions(image_file_paths=image_file_paths)

    # Show results (List of Documents, containing caption and image file_path)
    print(documents)

    [
        {
            "content": "a red apple is sitting on a pile of hay",
            ...
            "meta": {
                "image_path": "/path/to/images/apple.jpg",
                ...
            },
            ...
        },
        ...
    ]
    ```
    """

    def __init__(
        self,
        model_name_or_path: str = "nlpconnect/vit-gpt2-image-captioning",
        model_version: Optional[str] = None,
        generation_kwargs: Optional[dict] = None,
        use_gpu: bool = True,
        batch_size: int = 16,
        progress_bar: bool = True,
        use_auth_token: Optional[Union[str, bool]] = None,
        devices: Optional[List[Union[str, torch.device]]] = None,
    ):
        """
        Load an ImageToText model from Transformers.
        See the up-to-date list of available models at
        https://huggingface.co/models?pipeline_tag=image-to-text

        :param model_name_or_path: Directory of a saved model or the name of a public model.
                                   See https://huggingface.co/models?pipeline_tag=image-to-text for a full list of available models.
        :param model_version: The version of the model to use from the Hugging Face model hub. Can be a tag name, branch name, or commit hash.
        :param generation_kwargs: Dictionary containing arguments for the generate method of the Hugging Face model.
                                  See https://huggingface.co/docs/transformers/en/main_classes/text_generation#transformers.GenerationMixin.generate
        :param use_gpu: Whether to use GPU (if available).
        :param batch_size: Number of images to process at a time.
        :param progress_bar: Whether to show a progress bar.
        :param use_auth_token: The API token used to download private models from Hugging Face.
                               If this parameter is set to `True`, then the token generated when running
                               `transformers-cli login` (stored in ~/.huggingface) will be used.
                               Additional information can be found here:
                               https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained
        :param devices: List of torch devices (e.g. cuda, cpu, mps) to limit inference to specific devices.
                        A list containing torch device objects and/or strings is supported (for example
                        [torch.device('cuda:0'), "mps", "cuda:1"]). When specifying `use_gpu=False`, the devices
                        parameter is not used and a single cpu device is used for inference.
        """
        super().__init__()

        self.devices, _ = initialize_device_settings(devices=devices, use_cuda=use_gpu, multi_gpu=False)
        if len(self.devices) > 1:
            logger.warning(
                "Multiple devices are not supported in %s inference, using the first device %s.",
                self.__class__.__name__,
                self.devices[0],
            )

        self.model = pipeline(
            task="image-to-text",
            model=model_name_or_path,
            revision=model_version,
            device=self.devices[0],
            use_auth_token=use_auth_token,
        )

        model_class_name = self.model.model.__class__.__name__
        if model_class_name not in SUPPORTED_MODELS_CLASSES:
            raise ValueError(
                f"The model of class '{model_class_name}' is not supported for ImageToText. "
                f"The supported classes are: {SUPPORTED_MODELS_CLASSES}. "
                f"You can find the available models here: https://huggingface.co/models?pipeline_tag=image-to-text."
            )

        self.generation_kwargs = generation_kwargs
        self.batch_size = batch_size
        self.progress_bar = progress_bar

    def generate_captions(
        self, image_file_paths: List[str], generation_kwargs: Optional[dict] = None, batch_size: Optional[int] = None
    ) -> List[Document]:
        """
        Generate captions for the provided image files.

        :param image_file_paths: Paths of the images.
        :param generation_kwargs: Dictionary containing arguments for the generate method of the Hugging Face model.
                                  See https://huggingface.co/docs/transformers/en/main_classes/text_generation#transformers.GenerationMixin.generate
        :param batch_size: Number of images to process at a time.
        :return: List of Documents. Document.content is the caption. Document.meta["image_path"] contains the image file path.
        """
        generation_kwargs = generation_kwargs or self.generation_kwargs
        batch_size = batch_size or self.batch_size

        if len(image_file_paths) == 0:
            raise ImageToTextError("ImageToText needs at least one filepath to produce a caption.")

        images_dataset = ListDataset(image_file_paths)

        captions: List[str] = []

        try:
            for captions_batch in tqdm(
                self.model(images_dataset, generate_kwargs=generation_kwargs, batch_size=batch_size),
                disable=not self.progress_bar,
                total=len(images_dataset),
                desc="Generating captions",
            ):
                captions.append("".join([el["generated_text"] for el in captions_batch]).strip())

        except Exception as exc:
            raise ImageToTextError(str(exc)) from exc

        result: List[Document] = []
        for caption, image_file_path in zip(captions, image_file_paths):
            document = Document(content=caption, content_type="text", meta={"image_path": image_file_path})
            result.append(document)

        return result
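For reference, a standalone usage sketch of the new node (file paths are placeholders and the keyword values are illustrative); as implemented above, `generation_kwargs` and `batch_size` passed to `generate_captions()` override the values given at init time:

```python
from haystack.nodes.image_to_text import TransformersImageToText

image_to_text = TransformersImageToText(
    model_name_or_path="nlpconnect/vit-gpt2-image-captioning",
    use_gpu=False,
    batch_size=8,
    generation_kwargs={"max_new_tokens": 50},
)

documents = image_to_text.generate_captions(
    image_file_paths=["/path/to/images/apple.jpg", "/path/to/images/cat.jpg"],
    generation_kwargs={"max_new_tokens": 20},  # overrides the init-time value for this call only
)

for doc in documents:
    print(doc.meta["image_path"], "->", doc.content)
```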
test/nodes/test_image_to_text.py (new file, 89 lines added)
@@ -0,0 +1,89 @@
import os
import pytest

from haystack import Document
from haystack.nodes.image_to_text.transformers import TransformersImageToText
from haystack.nodes.image_to_text.base import BaseImageToText
from haystack.errors import ImageToTextError


from ..conftest import SAMPLES_PATH


IMAGE_FILE_NAMES = ["apple.jpg", "car.jpg", "cat.jpg", "galaxy.jpg", "paris.jpg"]
IMAGE_FILE_PATHS = [os.path.join(SAMPLES_PATH, "images", file_name) for file_name in IMAGE_FILE_NAMES]
IMAGE_DOCS = [Document(content=image_path, content_type="image") for image_path in IMAGE_FILE_PATHS]

EXPECTED_CAPTIONS = [
    "a red apple is sitting on a pile of hay",
    "a white car parked in a parking lot",
    "a cat laying in the grass",
    "a blurry photo of a blurry shot of a black object",
    "a city with a large building and a clock tower",
]


@pytest.fixture
def image_to_text():
    return TransformersImageToText(
        model_name_or_path="nlpconnect/vit-gpt2-image-captioning",
        devices=["cpu"],
        generation_kwargs={"max_new_tokens": 50},
    )


@pytest.mark.integration
def test_image_to_text_from_files(image_to_text):
    assert isinstance(image_to_text, BaseImageToText)

    results = image_to_text.run(file_paths=IMAGE_FILE_PATHS)
    image_paths = [doc.meta["image_path"] for doc in results[0]["documents"]]
    assert image_paths == IMAGE_FILE_PATHS
    generated_captions = [doc.content for doc in results[0]["documents"]]
    assert generated_captions == EXPECTED_CAPTIONS


@pytest.mark.integration
def test_image_to_text_from_documents(image_to_text):
    results = image_to_text.run(documents=IMAGE_DOCS)
    image_paths = [doc.meta["image_path"] for doc in results[0]["documents"]]
    assert image_paths == IMAGE_FILE_PATHS
    generated_captions = [doc.content for doc in results[0]["documents"]]
    assert generated_captions == EXPECTED_CAPTIONS


@pytest.mark.integration
def test_image_to_text_from_files_and_documents(image_to_text):
    results = image_to_text.run(file_paths=IMAGE_FILE_PATHS[:3], documents=IMAGE_DOCS[3:])
    image_paths = [doc.meta["image_path"] for doc in results[0]["documents"]]
    assert image_paths == IMAGE_FILE_PATHS
    generated_captions = [doc.content for doc in results[0]["documents"]]
    assert generated_captions == EXPECTED_CAPTIONS


@pytest.mark.integration
def test_image_to_text_invalid_image(image_to_text):
    markdown_path = str(SAMPLES_PATH / "markdown" / "sample.md")
    with pytest.raises(ImageToTextError, match="cannot identify image file"):
        image_to_text.run(file_paths=[markdown_path])


@pytest.mark.integration
def test_image_to_text_incorrect_path(image_to_text):
    with pytest.raises(ImageToTextError, match="Incorrect path"):
        image_to_text.run(file_paths=["wrong_path.jpg"])


@pytest.mark.integration
def test_image_to_text_not_image_document(image_to_text):
    textual_document = Document(content="this document is textual", content_type="text")
    with pytest.raises(ValueError, match="The ImageToText node only supports image documents."):
        image_to_text.run(documents=[textual_document])


@pytest.mark.integration
def test_image_to_text_unsupported_model():
    with pytest.raises(
        ValueError, match="The model of class 'BertForQuestionAnswering' is not supported for ImageToText"
    ):
        _ = TransformersImageToText(model_name_or_path="deepset/minilm-uncased-squad2")
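All of the new tests are marked `integration` (they download and run the captioning model), so they can be selected with pytest's standard marker filter, for example `pytest -m integration test/nodes/test_image_to_text.py`.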