96 lines
3.7 KiB
Python

"""S3 file and directory reader.
A loader that fetches a file or iterates through a directory on AWS S3.
"""
import tempfile
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from llama_index import download_loader
from llama_index.readers.base import BaseReader
from llama_index.readers.schema.base import Document
class S3Reader(BaseReader):
    """General reader for any S3 file or directory.

    Downloads either a single object (``key``) or every object under
    ``prefix`` in ``bucket`` into a temporary directory, then parses the
    downloaded files with ``SimpleDirectoryReader``.
    """

    def __init__(
        self,
        *args: Any,
        bucket: str,
        key: Optional[str] = None,
        prefix: Optional[str] = "",
        file_extractor: Optional[Dict[str, Union[str, BaseReader]]] = None,
        aws_access_id: Optional[str] = None,
        aws_access_secret: Optional[str] = None,
        aws_session_token: Optional[str] = None,
        s3_endpoint_url: Optional[str] = "https://s3.amazonaws.com",
        **kwargs: Any,
    ) -> None:
        """Initialize S3 bucket and key, along with credentials if needed.

        If key is not set, the entire bucket (filtered by prefix) is parsed.

        Args:
            bucket (str): the name of your S3 bucket
            key (Optional[str]): the name of the specific file. If none is
                provided, this loader will iterate through the entire bucket.
            prefix (Optional[str]): the prefix to filter by in the case that
                the loader iterates through the entire bucket. Defaults to
                empty string; ``None`` is treated as empty string.
            file_extractor (Optional[Dict[str, BaseReader]]): A mapping of file
                extension to a BaseReader class that specifies how to convert
                that file to text. See `SimpleDirectoryReader` for more
                details.
            aws_access_id (Optional[str]): provide AWS access key ID directly.
                When unset, boto3's default session is used and the other
                credential arguments are ignored.
            aws_access_secret (Optional[str]): provide AWS secret access key
                directly.
            aws_session_token (Optional[str]): provide AWS session token
                directly (for temporary credentials).
            s3_endpoint_url (Optional[str]): provide S3 endpoint URL directly.
                Only applied when ``aws_access_id`` is set.
        """
        super().__init__(*args, **kwargs)

        self.bucket = bucket
        self.key = key
        self.prefix = prefix
        self.file_extractor = file_extractor

        self.aws_access_id = aws_access_id
        self.aws_access_secret = aws_access_secret
        self.aws_session_token = aws_session_token
        self.s3_endpoint_url = s3_endpoint_url

    def _download_object(self, s3_client: Any, key: str, temp_dir: str) -> None:
        """Download one object from ``self.bucket`` into ``temp_dir``.

        The local file gets a random unique name so that objects with the
        same basename under different prefixes cannot overwrite each other.
        The object's suffix is kept because ``file_extractor`` is keyed by
        file extension.
        """
        import uuid

        local_path = Path(temp_dir) / f"{uuid.uuid4().hex}{Path(key).suffix}"
        s3_client.download_file(self.bucket, key, str(local_path))

    def load_data(self) -> List[Document]:
        """Load file(s) from S3.

        Returns:
            List[Document]: the documents produced by ``SimpleDirectoryReader``
            from the downloaded file(s).
        """
        import boto3

        if self.aws_access_id:
            session = boto3.Session(
                aws_access_key_id=self.aws_access_id,
                aws_secret_access_key=self.aws_access_secret,
                aws_session_token=self.aws_session_token,
            )
            # Listing (resource) and downloading (client) must talk to the
            # same endpoint, so both receive the configured URL.
            s3 = session.resource("s3", endpoint_url=self.s3_endpoint_url)
            s3_client = session.client("s3", endpoint_url=self.s3_endpoint_url)
        else:
            # No explicit credentials: rely on boto3's default session.
            s3 = boto3.resource("s3")
            s3_client = boto3.client("s3")

        with tempfile.TemporaryDirectory() as temp_dir:
            if self.key:
                self._download_object(s3_client, self.key, temp_dir)
            else:
                bucket = s3.Bucket(self.bucket)
                # ``prefix`` is Optional; fall back to "" so None is not
                # forwarded to the S3 API.
                for obj in bucket.objects.filter(Prefix=self.prefix or ""):
                    if obj.key.endswith("/"):  # skip folders
                        continue
                    self._download_object(s3_client, obj.key, temp_dir)

            SimpleDirectoryReader = download_loader("SimpleDirectoryReader")
            loader = SimpleDirectoryReader(
                temp_dir, file_extractor=self.file_extractor
            )

            # Parse inside the ``with`` block: the temporary directory is
            # removed as soon as the block exits.
            return loader.load_data()