mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-11-02 02:39:51 +00:00
feat: ImageToText (caption generator) (#3859)
* first draft * fix pylint and mypy * retry w mypy * mypy :-) * rem unused import * incorporate feedback and initial tests * better tests * fix import order * fix docstring * other fix docstring * more and better tests Co-authored-by: ZanSara <sarazanzo94@gmail.com>
This commit is contained in:
parent
d2bba4935b
commit
b910df7ec7
@ -147,3 +147,10 @@ class CohereError(NodeError):
|
||||
):
|
||||
super().__init__(message=message, send_message_in_event=send_message_in_event)
|
||||
self.status_code = status_code
|
||||
|
||||
|
||||
class ImageToTextError(NodeError):
|
||||
"""Exception for issues that occur in the ImageToText node"""
|
||||
|
||||
def __init__(self, message: Optional[str] = None):
|
||||
super().__init__(message=message)
|
||||
|
||||
2
haystack/nodes/image_to_text/__init__.py
Normal file
2
haystack/nodes/image_to_text/__init__.py
Normal file
@ -0,0 +1,2 @@
|
||||
from haystack.nodes.image_to_text.base import BaseImageToText
|
||||
from haystack.nodes.image_to_text.transformers import TransformersImageToText
|
||||
53
haystack/nodes/image_to_text/base.py
Normal file
53
haystack/nodes/image_to_text/base.py
Normal file
@ -0,0 +1,53 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from abc import abstractmethod
|
||||
|
||||
from haystack.schema import Document
|
||||
from haystack.nodes.base import BaseComponent
|
||||
|
||||
|
||||
class BaseImageToText(BaseComponent):
|
||||
"""
|
||||
Abstract class for ImageToText
|
||||
"""
|
||||
|
||||
outgoing_edges = 1
|
||||
|
||||
@abstractmethod
|
||||
def generate_captions(
|
||||
self, image_file_paths: List[str], generation_kwargs: Optional[dict] = None, batch_size: Optional[int] = None
|
||||
) -> List[Document]:
|
||||
"""
|
||||
Abstract method for generating captions.
|
||||
|
||||
:param image_file_paths: Paths of the images
|
||||
:param generation_kwargs: Dictionary containing arguments for the generate method of the Hugging Face model.
|
||||
See https://huggingface.co/docs/transformers/en/main_classes/text_generation#transformers.GenerationMixin.generate
|
||||
:param batch_size: Number of images to process at a time.
|
||||
:return: List of Documents. Document.content is the caption. Document.meta["image_file_path"] contains the image file path.
|
||||
"""
|
||||
pass
|
||||
|
||||
def run(self, file_paths: Optional[List[str]] = None, documents: Optional[List[Document]] = None): # type: ignore
|
||||
|
||||
if file_paths is None and documents is None:
|
||||
raise ValueError("You must either specify documents or image file_paths to process.")
|
||||
|
||||
image_file_paths = []
|
||||
if file_paths is not None:
|
||||
image_file_paths.extend(file_paths)
|
||||
if documents is not None:
|
||||
if any((doc.content_type != "image" for doc in documents)):
|
||||
raise ValueError("The ImageToText node only supports image documents.")
|
||||
image_file_paths.extend([doc.content for doc in documents])
|
||||
|
||||
results: dict = {}
|
||||
results["documents"] = self.generate_captions(image_file_paths=image_file_paths)
|
||||
|
||||
return results, "output_1"
|
||||
|
||||
def run_batch( # type: ignore
|
||||
self, file_paths: Optional[List[str]] = None, documents: Optional[List[Document]] = None
|
||||
):
|
||||
|
||||
return self.run(file_paths=file_paths, documents=documents)
|
||||
160
haystack/nodes/image_to_text/transformers.py
Normal file
160
haystack/nodes/image_to_text/transformers.py
Normal file
@ -0,0 +1,160 @@
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import logging
|
||||
|
||||
import torch
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import pipeline
|
||||
|
||||
from haystack.schema import Document
|
||||
from haystack.nodes.image_to_text.base import BaseImageToText
|
||||
from haystack.modeling.utils import initialize_device_settings
|
||||
from haystack.utils.torch_utils import ListDataset
|
||||
from haystack.errors import ImageToTextError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# supported models classes should be extended when HF image-to-text pipeline willl support more classes
|
||||
# see https://github.com/huggingface/transformers/issues/21110
|
||||
SUPPORTED_MODELS_CLASSES = ["VisionEncoderDecoderModel"]
|
||||
|
||||
|
||||
class TransformersImageToText(BaseImageToText):
|
||||
"""
|
||||
Transformer based model to generate captions for images using the HuggingFace's transformers framework
|
||||
|
||||
See the up-to-date list of available models on
|
||||
`huggingface.co/models <https://huggingface.co/models?pipeline_tag=image-to-text>`__
|
||||
|
||||
**Example**
|
||||
|
||||
```python
|
||||
image_file_paths = ["/path/to/images/apple.jpg",
|
||||
"/path/to/images/cat.jpg", ]
|
||||
|
||||
# Generate captions
|
||||
documents = image_to_text.generate_captions(image_file_paths=image_file_paths)
|
||||
|
||||
# Show results (List of Documents, containing caption and image file_path)
|
||||
print(documents)
|
||||
|
||||
[
|
||||
{
|
||||
"content": "a red apple is sitting on a pile of hay",
|
||||
...
|
||||
"meta": {
|
||||
"image_path": "/path/to/images/apple.jpg",
|
||||
...
|
||||
},
|
||||
...
|
||||
},
|
||||
...
|
||||
]
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name_or_path: str = "nlpconnect/vit-gpt2-image-captioning",
|
||||
model_version: Optional[str] = None,
|
||||
generation_kwargs: Optional[dict] = None,
|
||||
use_gpu: bool = True,
|
||||
batch_size: int = 16,
|
||||
progress_bar: bool = True,
|
||||
use_auth_token: Optional[Union[str, bool]] = None,
|
||||
devices: Optional[List[Union[str, torch.device]]] = None,
|
||||
):
|
||||
"""
|
||||
Load an Image To Text model from Transformers.
|
||||
See the up-to-date list of available models at
|
||||
https://huggingface.co/models?pipeline_tag=image-to-text
|
||||
|
||||
:param model_name_or_path: Directory of a saved model or the name of a public model.
|
||||
See https://huggingface.co/models?pipeline_tag=image-to-text for full list of available models.
|
||||
:param model_version: The version of model to use from the HuggingFace model hub. Can be tag name, branch name, or commit hash.
|
||||
:param generation_kwargs: Dictionary containing arguments for the generate method of the Hugging Face model.
|
||||
See https://huggingface.co/docs/transformers/en/main_classes/text_generation#transformers.GenerationMixin.generate
|
||||
:param use_gpu: Whether to use GPU (if available).
|
||||
:param batch_size: Number of documents to process at a time.
|
||||
:param progress_bar: Whether to show a progress bar.
|
||||
:param use_auth_token: The API token used to download private models from Huggingface.
|
||||
If this parameter is set to `True`, then the token generated when running
|
||||
`transformers-cli login` (stored in ~/.huggingface) will be used.
|
||||
Additional information can be found here
|
||||
https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained
|
||||
:param devices: List of torch devices (e.g. cuda, cpu, mps) to limit inference to specific devices.
|
||||
A list containing torch device objects and/or strings is supported (For example
|
||||
[torch.device('cuda:0'), "mps", "cuda:1"]). When specifying `use_gpu=False` the devices
|
||||
parameter is not used and a single cpu device is used for inference.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.devices, _ = initialize_device_settings(devices=devices, use_cuda=use_gpu, multi_gpu=False)
|
||||
if len(self.devices) > 1:
|
||||
logger.warning(
|
||||
"Multiple devices are not supported in %s inference, using the first device %s.",
|
||||
self.__class__.__name__,
|
||||
self.devices[0],
|
||||
)
|
||||
|
||||
self.model = pipeline(
|
||||
task="image-to-text",
|
||||
model=model_name_or_path,
|
||||
revision=model_version,
|
||||
device=self.devices[0],
|
||||
use_auth_token=use_auth_token,
|
||||
)
|
||||
|
||||
model_class_name = self.model.model.__class__.__name__
|
||||
if model_class_name not in SUPPORTED_MODELS_CLASSES:
|
||||
raise ValueError(
|
||||
f"The model of class '{model_class_name}' is not supported for ImageToText."
|
||||
f"The supported classes are: {SUPPORTED_MODELS_CLASSES}."
|
||||
f"You can find the availaible models here: https://huggingface.co/models?pipeline_tag=image-to-text."
|
||||
)
|
||||
|
||||
self.generation_kwargs = generation_kwargs
|
||||
self.batch_size = batch_size
|
||||
self.progress_bar = progress_bar
|
||||
|
||||
def generate_captions(
|
||||
self, image_file_paths: List[str], generation_kwargs: Optional[dict] = None, batch_size: Optional[int] = None
|
||||
) -> List[Document]:
|
||||
"""
|
||||
Generate captions for provided image files
|
||||
|
||||
:param image_file_paths: Paths of the images
|
||||
:param generation_kwargs: Dictionary containing arguments for the generate method of the Hugging Face model.
|
||||
See https://huggingface.co/docs/transformers/en/main_classes/text_generation#transformers.GenerationMixin.generate
|
||||
:param batch_size: Number of images to process at a time.
|
||||
:return: List of Documents. Document.content is the caption. Document.meta["image_file_path"] contains the image file path.
|
||||
"""
|
||||
generation_kwargs = generation_kwargs or self.generation_kwargs
|
||||
batch_size = batch_size or self.batch_size
|
||||
|
||||
if len(image_file_paths) == 0:
|
||||
raise ImageToTextError("ImageToText needs at least one filepath to produce a caption.")
|
||||
|
||||
images_dataset = ListDataset(image_file_paths)
|
||||
|
||||
captions: List[str] = []
|
||||
|
||||
try:
|
||||
for captions_batch in tqdm(
|
||||
self.model(images_dataset, generate_kwargs=generation_kwargs, batch_size=batch_size),
|
||||
disable=not self.progress_bar,
|
||||
total=len(images_dataset),
|
||||
desc="Generating captions",
|
||||
):
|
||||
captions.append("".join([el["generated_text"] for el in captions_batch]).strip())
|
||||
|
||||
except Exception as exc:
|
||||
raise ImageToTextError(str(exc)) from exc
|
||||
|
||||
result: List[Document] = []
|
||||
for caption, image_file_path in zip(captions, image_file_paths):
|
||||
document = Document(content=caption, content_type="text", meta={"image_path": image_file_path})
|
||||
result.append(document)
|
||||
|
||||
return result
|
||||
89
test/nodes/test_image_to_text.py
Normal file
89
test/nodes/test_image_to_text.py
Normal file
@ -0,0 +1,89 @@
|
||||
import os
|
||||
import pytest
|
||||
|
||||
from haystack import Document
|
||||
from haystack.nodes.image_to_text.transformers import TransformersImageToText
|
||||
from haystack.nodes.image_to_text.base import BaseImageToText
|
||||
from haystack.errors import ImageToTextError
|
||||
|
||||
|
||||
from ..conftest import SAMPLES_PATH
|
||||
|
||||
|
||||
IMAGE_FILE_NAMES = ["apple.jpg", "car.jpg", "cat.jpg", "galaxy.jpg", "paris.jpg"]
|
||||
IMAGE_FILE_PATHS = [os.path.join(SAMPLES_PATH, "images", file_name) for file_name in IMAGE_FILE_NAMES]
|
||||
IMAGE_DOCS = [Document(content=image_path, content_type="image") for image_path in IMAGE_FILE_PATHS]
|
||||
|
||||
EXPECTED_CAPTIONS = [
|
||||
"a red apple is sitting on a pile of hay",
|
||||
"a white car parked in a parking lot",
|
||||
"a cat laying in the grass",
|
||||
"a blurry photo of a blurry shot of a black object",
|
||||
"a city with a large building and a clock tower",
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def image_to_text():
|
||||
return TransformersImageToText(
|
||||
model_name_or_path="nlpconnect/vit-gpt2-image-captioning",
|
||||
devices=["cpu"],
|
||||
generation_kwargs={"max_new_tokens": 50},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_image_to_text_from_files(image_to_text):
|
||||
assert isinstance(image_to_text, BaseImageToText)
|
||||
|
||||
results = image_to_text.run(file_paths=IMAGE_FILE_PATHS)
|
||||
image_paths = [doc.meta["image_path"] for doc in results[0]["documents"]]
|
||||
assert image_paths == IMAGE_FILE_PATHS
|
||||
generated_captions = [doc.content for doc in results[0]["documents"]]
|
||||
assert generated_captions == EXPECTED_CAPTIONS
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_image_to_text_from_documents(image_to_text):
|
||||
results = image_to_text.run(documents=IMAGE_DOCS)
|
||||
image_paths = [doc.meta["image_path"] for doc in results[0]["documents"]]
|
||||
assert image_paths == IMAGE_FILE_PATHS
|
||||
generated_captions = [doc.content for doc in results[0]["documents"]]
|
||||
assert generated_captions == EXPECTED_CAPTIONS
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_image_to_text_from_files_and_documents(image_to_text):
|
||||
results = image_to_text.run(file_paths=IMAGE_FILE_PATHS[:3], documents=IMAGE_DOCS[3:])
|
||||
image_paths = [doc.meta["image_path"] for doc in results[0]["documents"]]
|
||||
assert image_paths == IMAGE_FILE_PATHS
|
||||
generated_captions = [doc.content for doc in results[0]["documents"]]
|
||||
assert generated_captions == EXPECTED_CAPTIONS
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_image_to_text_invalid_image(image_to_text):
|
||||
markdown_path = str(SAMPLES_PATH / "markdown" / "sample.md")
|
||||
with pytest.raises(ImageToTextError, match="cannot identify image file"):
|
||||
image_to_text.run(file_paths=[markdown_path])
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_image_to_text_incorrect_path(image_to_text):
|
||||
with pytest.raises(ImageToTextError, match="Incorrect path"):
|
||||
image_to_text.run(file_paths=["wrong_path.jpg"])
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_image_to_text_not_image_document(image_to_text):
|
||||
textual_document = Document(content="this document is textual", content_type="text")
|
||||
with pytest.raises(ValueError, match="The ImageToText node only supports image documents."):
|
||||
image_to_text.run(documents=[textual_document])
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_image_to_text_unsupported_model():
|
||||
with pytest.raises(
|
||||
ValueError, match="The model of class 'BertForQuestionAnswering' is not supported for ImageToText"
|
||||
):
|
||||
_ = TransformersImageToText(model_name_or_path="deepset/minilm-uncased-squad2")
|
||||
Loading…
x
Reference in New Issue
Block a user