from typing import List, Optional, Union
import logging

from tqdm.auto import tqdm

from haystack.schema import Document
from haystack.nodes.image_to_text.base import BaseImageToText
from haystack.errors import ImageToTextError
from haystack.lazy_imports import LazyImport

logger = logging.getLogger(__name__)

with LazyImport(message="Run 'pip install farm-haystack[inference]'") as torch_and_transformers_import:
    import torch
    from transformers import pipeline
    from haystack.modeling.utils import initialize_device_settings  # pylint: disable=ungrouped-imports
    from haystack.utils.torch_utils import ListDataset

# Supported model classes; extend this list when the HF image-to-text pipeline supports more classes.
# See https://github.com/huggingface/transformers/issues/21110
SUPPORTED_MODELS_CLASSES = [
    "VisionEncoderDecoderModel",
    "BlipForConditionalGeneration",
    "Blip2ForConditionalGeneration",
]

UNSUPPORTED_MODEL_MESSAGE = (
    f"The supported classes are: {SUPPORTED_MODELS_CLASSES}.\n"
    f"To find the supported models:\n"
    f"1. Visit [image-to-text models on Hugging Face](https://huggingface.co/models?pipeline_tag=image-to-text).\n"
    f"2. Open the model you want to check.\n"
    f'3. On the model page, go to the "Files and Versions" tab.\n'
    f"4. Open the `config.json` file and make sure the `architectures` field contains one of the supported classes: {SUPPORTED_MODELS_CLASSES}."
)
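
# A programmatic version of the manual check described above (a sketch, not used by this module;
# AutoConfig reads the model's config.json, whose `architectures` field lists the model class):
#
#     from transformers import AutoConfig
#
#     config = AutoConfig.from_pretrained("Salesforce/blip-image-captioning-base")
#     is_supported = any(arch in SUPPORTED_MODELS_CLASSES for arch in config.architectures)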


class TransformersImageToText(BaseImageToText):
    """
    A transformer-based model to generate captions for images, using Hugging Face's transformers framework.

    **Example**

    ```python
    image_file_paths = ["/path/to/images/apple.jpg", "/path/to/images/cat.jpg"]

    # Generate captions
    documents = image_to_text.generate_captions(image_file_paths=image_file_paths)

    # Show results (a list of Documents, each containing the caption and the image file path)
    print(documents)

    [
        {
            "content": "a red apple is sitting on a pile of hay",
            ...
            "meta": {
                "image_path": "/path/to/images/apple.jpg",
                ...
            },
            ...
        },
        ...
    ]
    ```
    """

    def __init__(
        self,
        model_name_or_path: str = "Salesforce/blip-image-captioning-base",
        model_version: Optional[str] = None,
        generation_kwargs: Optional[dict] = None,
        use_gpu: bool = True,
        batch_size: int = 16,
        progress_bar: bool = True,
        use_auth_token: Optional[Union[str, bool]] = None,
        devices: Optional[List[Union[str, "torch.device"]]] = None,
    ):
        """
        Load an image-to-text model from transformers.

        :param model_name_or_path: Directory of a saved model or the name of a public model.
            To find these models:
            1. Visit [Hugging Face image-to-text models](https://huggingface.co/models?pipeline_tag=image-to-text).
            2. Open the model you want to check.
            3. On the model page, go to the "Files and Versions" tab.
            4. Open the `config.json` file and make sure the `architectures` field contains `VisionEncoderDecoderModel`, `BlipForConditionalGeneration`, or `Blip2ForConditionalGeneration`.
        :param model_version: The version of the model to use from the Hugging Face model hub. This can be the tag name, branch name, or commit hash.
        :param generation_kwargs: Dictionary containing arguments for the `generate()` method of the Hugging Face model.
            See [generate()](https://huggingface.co/docs/transformers/en/main_classes/text_generation#transformers.GenerationMixin.generate) in the Hugging Face documentation.
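            For example, `generation_kwargs={"max_new_tokens": 30}` caps the length of the generated captions.
            (`max_new_tokens` is a standard `generate()` argument; which other keys apply depends on the model.)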
        :param use_gpu: Whether to use GPU (if available).
        :param batch_size: Number of documents to process at a time.
        :param progress_bar: Whether to show a progress bar.
        :param use_auth_token: The API token used to download private models from Hugging Face.
            If set to `True`, the token generated when running
            `transformers-cli login` (stored in ~/.huggingface) is used.
            For more information, see [from_pretrained()](https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained) in the Hugging Face documentation.
        :param devices: List of torch devices (for example, cuda, cpu, mps) to limit inference to specific devices.
            A list containing torch device objects or strings is supported (for example
            [torch.device('cuda:0'), "mps", "cuda:1"]). If you set `use_gpu=False`, the `devices`
            parameter is not used and a single CPU device is used for inference.
        """
        torch_and_transformers_import.check()
        super().__init__()

        self.devices, _ = initialize_device_settings(devices=devices, use_cuda=use_gpu, multi_gpu=False)
        if len(self.devices) > 1:
            logger.warning(
                "Multiple devices are not supported in %s inference, using the first device %s.",
                self.__class__.__name__,
                self.devices[0],
            )

        try:
            self.model = pipeline(
                task="image-to-text",
                model=model_name_or_path,
                revision=model_version,
                device=self.devices[0],
                use_auth_token=use_auth_token,
            )
        except KeyError as err:
            raise ValueError(
                f"The model '{model_name_or_path}' is not supported for ImageToText. " f"{UNSUPPORTED_MODEL_MESSAGE}"
            ) from err

        # For some unsupported models, initializing the HF pipeline doesn't raise an error, but the pipeline doesn't work.
        model_class_name = self.model.model.__class__.__name__
        if model_class_name not in SUPPORTED_MODELS_CLASSES:
            raise ValueError(
                f"The model '{model_name_or_path}' (class '{model_class_name}') is not supported for ImageToText. "
                f"{UNSUPPORTED_MODEL_MESSAGE}"
            )

        self.generation_kwargs = generation_kwargs
        self.batch_size = batch_size
        self.progress_bar = progress_bar

    def generate_captions(
        self, image_file_paths: List[str], generation_kwargs: Optional[dict] = None, batch_size: Optional[int] = None
    ) -> List[Document]:
        """
        Generate captions for the image files you specify.

        :param image_file_paths: Paths to the images for which you want to generate captions.
        :param generation_kwargs: Dictionary containing arguments for the `generate()` method of the Hugging Face model.
            See [generate()](https://huggingface.co/docs/transformers/en/main_classes/text_generation#transformers.GenerationMixin.generate) in the Hugging Face documentation.
        :param batch_size: Number of images to process at a time.
        :return: List of Documents. `Document.content` is the caption. `Document.meta["image_path"]` contains the path to the image file.
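
        **Example**

        ```python
        # A minimal sketch; it assumes `image_to_text` is an initialized
        # TransformersImageToText node and that the image file exists on disk.
        documents = image_to_text.generate_captions(image_file_paths=["/path/to/images/apple.jpg"])
        print(documents[0].content)             # the generated caption
        print(documents[0].meta["image_path"])  # the captioned image's path
        ```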
        """
        generation_kwargs = generation_kwargs or self.generation_kwargs
        batch_size = batch_size or self.batch_size

        if len(image_file_paths) == 0:
            raise ImageToTextError("ImageToText needs at least one file path to produce a caption.")

        # Wrapping the paths in a ListDataset lets the HF pipeline stream batched results.
        images_dataset = ListDataset(image_file_paths)

        captions: List[str] = []

        try:
            for captions_batch in tqdm(
                self.model(images_dataset, generate_kwargs=generation_kwargs, batch_size=batch_size),
                disable=not self.progress_bar,
                total=len(images_dataset),
                desc="Generating captions",
            ):
                captions.append("".join([el["generated_text"] for el in captions_batch]).strip())
        except Exception as exc:
            raise ImageToTextError(str(exc)) from exc

        # Pair each caption with its source image path, keeping the path in the Document metadata.
        result: List[Document] = []
        for caption, image_file_path in zip(captions, image_file_paths):
            document = Document(content=caption, content_type="text", meta={"image_path": image_file_path})
            result.append(document)

        return result