feat: add connection check method to all source and destination connectors (#2000)

### Description
Add a `check_connection` method to each connector so that a connection can be validated without running the full ingest process. As part of this PR, some refactoring was done so that clients are shared between the `check_connection` and `initialize` methods and populated lazily, allowing `check_connection` to be called without relying on `initialize` having been called first.

* bonus: fix the changelog
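
As a rough illustration of the shared-client pattern described above (a hedged sketch only; `MySourceConnector`, `MyClient`, and the `ping()` call are hypothetical placeholders, not connectors from this PR): the client lives on a private field, is built lazily through a property, and both `check_connection` and `initialize` go through that property.

```python
import typing as t
from dataclasses import dataclass, field

from unstructured.ingest.error import SourceConnectionError
from unstructured.ingest.logger import logger


class MyClient:
    """Stand-in for a third-party SDK client (hypothetical)."""

    def __init__(self, api_key: str):
        self.api_key = api_key

    def ping(self) -> None:
        # A real client would issue a lightweight authenticated request here.
        if not self.api_key:
            raise ValueError("missing api key")


@dataclass
class MySourceConnector:
    api_key: str
    _client: t.Optional[MyClient] = field(init=False, default=None)

    @property
    def client(self) -> MyClient:
        # Built on first access, so check_connection() does not depend on
        # initialize() having been called first.
        if self._client is None:
            self._client = MyClient(self.api_key)
        return self._client

    def check_connection(self):
        # Lightweight validation against the shared client, wrapping failures
        # the same way the connectors in this PR do.
        try:
            self.client.ping()
        except Exception as e:
            logger.error(f"failed to validate connection: {e}", exc_info=True)
            raise SourceConnectionError(f"failed to validate connection: {e}")

    def initialize(self):
        _ = self.client
```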

---------

Co-authored-by: ryannikolaidis <1208590+ryannikolaidis@users.noreply.github.com>
Roman Isecke 2023-11-07 22:11:39 -05:00 committed by GitHub
parent 92ddf3a337
commit 03f62faf9b
28 changed files with 475 additions and 100 deletions

View File

@ -1,19 +1,14 @@
-## 0.10.30-dev1
## 0.10.30-dev2
-### Enhancements
-### Features
-* **Adds ability to pass timeout for a request when partitioning via a `url`.** `partition` now accepts a new optional parameter `request_timeout` which if set will prevent any `requests.get` from hanging indefinitely and instead will raise a timeout error. This is useful when partitioning a url that may be slow to respond or may not respond at all.
-### Fixes
### Enhancements
* **Support nested DOCX tables.** In DOCX, like HTML, a table cell can itself contain a table. In this case, create nested HTML tables to reflect that structure and create a plain-text table with captures all the text in nested tables, formatting it as a reasonable facsimile of a table.
* **Add connection check to ingest connectors** Each source and destination connector now support a `check_connection()` method which makes sure a valid connection can be established with the source/destination given any authentication credentials in a lightweight request.
### Features
* **Adds ability to pass timeout for a request when partitioning via a `url`.** `partition` now accepts a new optional parameter `request_timeout` which if set will prevent any `requests.get` from hanging indefinitely and instead will raise a timeout error. This is useful when partitioning a url that may be slow to respond or may not respond at all.
### Fixes
## 0.10.29

View File

@ -1 +1 @@
-__version__ = "0.10.30-dev1"  # pragma: no cover
__version__ = "0.10.30-dev2"  # pragma: no cover

View File

@ -1,9 +1,11 @@
import os
import typing as t
-from dataclasses import dataclass
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
import requests
from unstructured.ingest.error import SourceConnectionError, SourceConnectionNetworkError
from unstructured.ingest.interfaces import (
BaseConnectorConfig,
@ -16,6 +18,9 @@ from unstructured.ingest.interfaces import (
from unstructured.ingest.logger import logger
from unstructured.utils import requires_dependencies
if t.TYPE_CHECKING:
from pyairtable import Api
@dataclass
class SimpleAirtableConfig(BaseConnectorConfig):
@ -200,6 +205,24 @@ class AirtableSourceConnector(SourceConnectorCleanupMixin, BaseSourceConnector):
"""Fetches tables or views from an Airtable org."""
connector_config: SimpleAirtableConfig
_api: t.Optional["Api"] = field(init=False, default=None)
@property
def api(self):
if self._api is None:
self._api = Api(self.connector_config.personal_access_token)
return self._api
@api.setter
def api(self, api: "Api"):
self._api = api
def check_connection(self):
try:
self.api.request(method="HEAD", url=self.api.build_url("meta", "bases"))
except requests.HTTPError as http_error:
logger.error(f"failed to validate connection: {http_error}", exc_info=True)
raise SourceConnectionError(f"failed to validate connection: {http_error}")
@requires_dependencies(["pyairtable"], extras="airtable") @requires_dependencies(["pyairtable"], extras="airtable")
def initialize(self): def initialize(self):

View File

@ -7,7 +7,8 @@ from unstructured.ingest.connector.fsspec import (
FsspecSourceConnector,
SimpleFsspecConfig,
)
-from unstructured.ingest.error import SourceConnectionError
from unstructured.ingest.error import DestinationConnectionError, SourceConnectionError
from unstructured.ingest.logger import logger
from unstructured.utils import requires_dependencies
@ -31,6 +32,16 @@ class AzureBlobStorageIngestDoc(FsspecIngestDoc):
class AzureBlobStorageSourceConnector(FsspecSourceConnector):
connector_config: SimpleAzureBlobStorageConfig
@requires_dependencies(["adlfs"], extras="azure")
def check_connection(self):
from adlfs import AzureBlobFileSystem
try:
AzureBlobFileSystem(**self.connector_config.access_kwargs)
except ValueError as connection_error:
logger.error(f"failed to validate connection: {connection_error}", exc_info=True)
raise SourceConnectionError(f"failed to validate connection: {connection_error}")
def __post_init__(self):
self.ingest_doc_cls: t.Type[AzureBlobStorageIngestDoc] = AzureBlobStorageIngestDoc
@ -39,3 +50,13 @@ class AzureBlobStorageSourceConnector(FsspecSourceConnector):
@dataclass
class AzureBlobStorageDestinationConnector(FsspecDestinationConnector):
connector_config: SimpleAzureBlobStorageConfig
@requires_dependencies(["adlfs"], extras="azure")
def check_connection(self):
from adlfs import AzureBlobFileSystem
try:
AzureBlobFileSystem(**self.connector_config.access_kwargs)
except ValueError as connection_error:
logger.error(f"failed to validate connection: {connection_error}", exc_info=True)
raise DestinationConnectionError(f"failed to validate connection: {connection_error}")

View File

@ -1,11 +1,11 @@
import json
import typing as t
import uuid
-from dataclasses import dataclass
from dataclasses import dataclass, field
import azure.core.exceptions
-from unstructured.ingest.error import WriteError
from unstructured.ingest.error import DestinationConnectionError, WriteError
from unstructured.ingest.interfaces import (
BaseConnectorConfig,
BaseDestinationConnector,
@ -15,6 +15,9 @@ from unstructured.ingest.interfaces import (
from unstructured.ingest.logger import logger
from unstructured.utils import requires_dependencies
if t.TYPE_CHECKING:
from azure.search.documents import SearchClient
@dataclass
class SimpleAzureCognitiveSearchStorageConfig(BaseConnectorConfig):
@ -31,20 +34,37 @@ class AzureCognitiveSearchWriteConfig(WriteConfig):
class AzureCognitiveSearchDestinationConnector(BaseDestinationConnector):
write_config: AzureCognitiveSearchWriteConfig
connector_config: SimpleAzureCognitiveSearchStorageConfig
_client: t.Optional["SearchClient"] = field(init=False, default=None)
@requires_dependencies(["azure"], extras="azure-cognitive-search") @requires_dependencies(["azure"], extras="azure-cognitive-search")
def initialize(self): def generate_client(self) -> "SearchClient":
from azure.core.credentials import AzureKeyCredential from azure.core.credentials import AzureKeyCredential
from azure.search.documents import SearchClient from azure.search.documents import SearchClient
# Create a client # Create a client
credential = AzureKeyCredential(self.connector_config.key) credential = AzureKeyCredential(self.connector_config.key)
self.client = SearchClient( return SearchClient(
endpoint=self.connector_config.endpoint, endpoint=self.connector_config.endpoint,
index_name=self.write_config.index, index_name=self.write_config.index,
credential=credential, credential=credential,
) )
@property
def client(self) -> "SearchClient":
if self._client is None:
self._client = self.generate_client()
return self._client
def check_connection(self):
try:
self.client.get_document_count()
except Exception as e:
logger.error(f"failed to validate connection: {e}", exc_info=True)
raise DestinationConnectionError(f"failed to validate connection: {e}")
def initialize(self):
_ = self.client
def conform_dict(self, data: dict) -> None:
"""
updates the dictionary that is from each Element being converted into a dict/json

View File

@ -40,11 +40,11 @@ class SimpleBiomedConfig(BaseConnectorConfig):
"""Connector config where path is the FTP directory path and """Connector config where path is the FTP directory path and
id_, from_, until, format are API parameters.""" id_, from_, until, format are API parameters."""
path: t.Optional[str] path: t.Optional[str] = None
# OA Web Service API Options # OA Web Service API Options
id_: t.Optional[str] id_: t.Optional[str] = None
from_: t.Optional[str] from_: t.Optional[str] = None
until: t.Optional[str] until: t.Optional[str] = None
request_timeout: int = 45 request_timeout: int = 45
def validate_api_inputs(self): def validate_api_inputs(self):
@ -152,6 +152,20 @@ class BiomedSourceConnector(SourceConnectorCleanupMixin, BaseSourceConnector):
connector_config: SimpleBiomedConfig
def get_base_endpoints_url(self) -> str:
endpoint_url = "https://www.ncbi.nlm.nih.gov/pmc/utils/oa/oa.fcgi?format=pdf"
if self.connector_config.id_:
endpoint_url += f"&id={self.connector_config.id_}"
if self.connector_config.from_:
endpoint_url += f"&from={self.connector_config.from_}"
if self.connector_config.until:
endpoint_url += f"&until={self.connector_config.until}"
return endpoint_url
def _list_objects_api(self) -> t.List[BiomedFileMeta]:
def urls_to_metadata(urls):
files = []
@ -175,16 +189,7 @@ class BiomedSourceConnector(SourceConnectorCleanupMixin, BaseSourceConnector):
files: t.List[BiomedFileMeta] = []
-endpoint_url = "https://www.ncbi.nlm.nih.gov/pmc/utils/oa/oa.fcgi?format=pdf"
endpoint_url = self.get_base_endpoints_url()
-if self.connector_config.id_:
-endpoint_url += f"&id={self.connector_config.id_}"
-if self.connector_config.from_:
-endpoint_url += f"&from={self.connector_config.from_}"
-if self.connector_config.until:
-endpoint_url += f"&until={self.connector_config.until}"
while endpoint_url:
session = requests.Session()
@ -287,6 +292,13 @@ class BiomedSourceConnector(SourceConnectorCleanupMixin, BaseSourceConnector):
def initialize(self):
pass
def check_connection(self):
resp = requests.head(self.get_base_endpoints_url())
try:
resp.raise_for_status()
except requests.HTTPError as http_error:
raise SourceConnectionError(f"failed to validate connection: {http_error}")
def get_ingest_docs(self):
files = self._list_objects_api() if self.connector_config.is_api else self._list_objects()
return [

View File

@ -17,7 +17,8 @@ from unstructured.ingest.connector.fsspec import (
FsspecSourceConnector,
SimpleFsspecConfig,
)
-from unstructured.ingest.error import SourceConnectionError
from unstructured.ingest.error import DestinationConnectionError, SourceConnectionError
from unstructured.ingest.logger import logger
from unstructured.utils import requires_dependencies
@ -57,6 +58,16 @@ class BoxIngestDoc(FsspecIngestDoc):
class BoxSourceConnector(FsspecSourceConnector):
connector_config: SimpleBoxConfig
@requires_dependencies(["boxfs"], extras="box")
def check_connection(self):
from boxfs import BoxFileSystem
try:
BoxFileSystem(**self.connector_config.access_kwargs)
except Exception as e:
logger.error(f"failed to validate connection: {e}", exc_info=True)
raise SourceConnectionError(f"failed to validate connection: {e}")
def __post_init__(self):
self.ingest_doc_cls: t.Type[BoxIngestDoc] = BoxIngestDoc
@ -65,3 +76,13 @@ class BoxSourceConnector(FsspecSourceConnector):
@dataclass
class BoxDestinationConnector(FsspecDestinationConnector):
connector_config: SimpleBoxConfig
@requires_dependencies(["boxfs"], extras="box")
def check_connection(self):
from boxfs import BoxFileSystem
try:
BoxFileSystem(**self.connector_config.access_kwargs)
except Exception as e:
logger.error(f"failed to validate connection: {e}", exc_info=True)
raise DestinationConnectionError(f"failed to validate connection: {e}")

View File

@ -5,6 +5,8 @@ from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
import requests
from unstructured.ingest.error import SourceConnectionError, SourceConnectionNetworkError
from unstructured.ingest.interfaces import (
BaseConnectorConfig,
@ -17,6 +19,9 @@ from unstructured.ingest.interfaces import (
from unstructured.ingest.logger import logger
from unstructured.utils import requires_dependencies
if t.TYPE_CHECKING:
from atlassian import Confluence
@dataclass
class SimpleConfluenceConfig(BaseConnectorConfig):
@ -185,17 +190,31 @@ class ConfluenceSourceConnector(SourceConnectorCleanupMixin, BaseSourceConnector
"""Fetches body fields from all documents within all spaces in a Confluence Cloud instance."""
connector_config: SimpleConfluenceConfig
_confluence: t.Optional["Confluence"] = field(init=False, default=None)
@requires_dependencies(["atlassian"], extras="Confluence") @property
def initialize(self): def confluence(self) -> "Confluence":
from atlassian import Confluence from atlassian import Confluence
self.confluence = Confluence( if self._confluence is None:
self._confluence = Confluence(
url=self.connector_config.url, url=self.connector_config.url,
username=self.connector_config.user_email, username=self.connector_config.user_email,
password=self.connector_config.api_token, password=self.connector_config.api_token,
) )
return self._confluence
@requires_dependencies(["atlassian"], extras="Confluence")
def check_connection(self):
url = "rest/api/space"
try:
self.confluence.request(method="HEAD", path=url)
except requests.HTTPError as http_error:
logger.error(f"failed to validate connection: {http_error}", exc_info=True)
raise SourceConnectionError(f"failed to validate connection: {http_error}")
@requires_dependencies(["atlassian"], extras="Confluence")
def initialize(self):
self.list_of_spaces = None
if self.connector_config.spaces:
self.list_of_spaces = self.connector_config.spaces

View File

@ -119,6 +119,9 @@ class DeltaTableSourceConnector(SourceConnectorCleanupMixin, BaseSourceConnector
connector_config: SimpleDeltaTableConfig
delta_table: t.Optional["DeltaTable"] = None
def check_connection(self):
pass
@requires_dependencies(["deltalake"], extras="delta-table") @requires_dependencies(["deltalake"], extras="delta-table")
def initialize(self): def initialize(self):
from deltalake import DeltaTable from deltalake import DeltaTable
@ -172,6 +175,9 @@ class DeltaTableDestinationConnector(BaseDestinationConnector):
def initialize(self):
pass
def check_connection(self):
pass
def write_dict(self, *args, elements_dict: t.List[t.Dict[str, t.Any]], **kwargs) -> None:
from deltalake.writer import write_deltalake

View File

@ -156,6 +156,21 @@ class DiscordSourceConnector(SourceConnectorCleanupMixin, BaseSourceConnector):
def initialize(self):
pass
@requires_dependencies(dependencies=["discord"], extras="discord")
def check_connection(self):
import asyncio
import discord
from discord.client import Client
intents = discord.Intents.default()
try:
client = Client(intents=intents)
asyncio.run(client.start(token=self.connector_config.token))
except Exception as e:
logger.error(f"failed to validate connection: {e}", exc_info=True)
raise SourceConnectionError(f"failed to validate connection: {e}")
def get_ingest_docs(self):
return [
DiscordIngestDoc(

View File

@ -20,6 +20,7 @@ from unstructured.ingest.connector.fsspec import (
SimpleFsspecConfig,
)
from unstructured.ingest.error import SourceConnectionError
from unstructured.ingest.logger import logger
from unstructured.utils import requires_dependencies
@ -90,12 +91,16 @@ class DropboxSourceConnector(FsspecSourceConnector):
def initialize(self):
from fsspec import AbstractFileSystem, get_filesystem_class
try:
self.fs: AbstractFileSystem = get_filesystem_class(self.connector_config.protocol)(
**self.connector_config.get_access_kwargs(),
)
# Dropbox requires a forward slash at the front of the folder path. This
# creates some complications in path joining so a custom path is created here.
ls_output = self.fs.ls(f"/{self.connector_config.path_without_protocol}")
except Exception as e:
logger.error(f"failed to validate connection: {e}", exc_info=True)
raise SourceConnectionError(f"failed to validate connection: {e}")
if ls_output and len(ls_output) >= 1:
return
elif ls_output:

View File

@ -2,7 +2,7 @@ import hashlib
import json
import os
import typing as t
-from dataclasses import dataclass
from dataclasses import dataclass, field
from pathlib import Path
from unstructured.ingest.error import SourceConnectionError, SourceConnectionNetworkError
@ -17,6 +17,9 @@ from unstructured.ingest.interfaces import (
from unstructured.ingest.logger import logger
from unstructured.utils import requires_dependencies
if t.TYPE_CHECKING:
from elasticsearch import Elasticsearch
@dataclass
class SimpleElasticsearchConfig(BaseConnectorConfig):
@ -30,7 +33,7 @@ class SimpleElasticsearchConfig(BaseConnectorConfig):
url: str
index_name: str
-jq_query: t.Optional[str]
jq_query: t.Optional[str] = None
@dataclass
@ -185,12 +188,25 @@ class ElasticsearchSourceConnector(SourceConnectorCleanupMixin, BaseSourceConnec
"""Fetches particular fields from all documents in a given elasticsearch cluster and index""" """Fetches particular fields from all documents in a given elasticsearch cluster and index"""
connector_config: SimpleElasticsearchConfig connector_config: SimpleElasticsearchConfig
_es: t.Optional["Elasticsearch"] = field(init=False, default=None)
@property
def es(self):
from elasticsearch import Elasticsearch
if self._es is None:
self._es = Elasticsearch(self.connector_config.url)
return self._es
def check_connection(self):
try:
self.es.perform_request("HEAD", "/", headers={"accept": "application/json"})
except Exception as e:
logger.error(f"failed to validate connection: {e}", exc_info=True)
raise SourceConnectionError(f"failed to validate connection: {e}")
@requires_dependencies(["elasticsearch"], extras="elasticsearch") @requires_dependencies(["elasticsearch"], extras="elasticsearch")
def initialize(self): def initialize(self):
from elasticsearch import Elasticsearch
self.es = Elasticsearch(self.connector_config.url)
self.scan_query: dict = {"query": {"match_all": {}}} self.scan_query: dict = {"query": {"match_all": {}}}
self.search_query: dict = {"match_all": {}} self.search_query: dict = {"match_all": {}}
self.es.search(index=self.connector_config.index_name, query=self.search_query, size=1) self.es.search(index=self.connector_config.index_name, query=self.search_query, size=1)

View File

@ -5,7 +5,11 @@ from contextlib import suppress
from dataclasses import dataclass, field
from pathlib import Path, PurePath
-from unstructured.ingest.error import SourceConnectionError, SourceConnectionNetworkError
from unstructured.ingest.error import (
DestinationConnectionError,
SourceConnectionError,
SourceConnectionNetworkError,
)
from unstructured.ingest.interfaces import (
BaseConnectorConfig,
BaseDestinationConnector,
@ -147,6 +151,18 @@ class FsspecSourceConnector(
connector_config: SimpleFsspecConfig
def check_connection(self):
from fsspec import get_filesystem_class
try:
fs = get_filesystem_class(self.connector_config.protocol)(
**self.connector_config.get_access_kwargs(),
)
fs.ls(path=self.connector_config.path_without_protocol)
except Exception as e:
logger.error(f"failed to validate connection: {e}", exc_info=True)
raise SourceConnectionError(f"failed to validate connection: {e}")
def __post_init__(self):
self.ingest_doc_cls: t.Type[FsspecIngestDoc] = FsspecIngestDoc
@ -244,6 +260,18 @@ class FsspecDestinationConnector(BaseDestinationConnector):
**self.connector_config.get_access_kwargs(),
)
def check_connection(self):
from fsspec import get_filesystem_class
try:
fs = get_filesystem_class(self.connector_config.protocol)(
**self.connector_config.get_access_kwargs(),
)
fs.ls(path=self.connector_config.path_without_protocol)
except Exception as e:
logger.error(f"failed to validate connection: {e}", exc_info=True)
raise DestinationConnectionError(f"failed to validate connection: {e}")
def write_dict(
self,
*args,

View File

@ -18,9 +18,9 @@ from unstructured.ingest.logger import logger
@dataclass
class SimpleGitConfig(BaseConnectorConfig):
url: str
-access_token: t.Optional[str]
access_token: t.Optional[str] = None
-branch: t.Optional[str]
branch: t.Optional[str] = None
-file_glob: t.Optional[str]
file_glob: t.Optional[str] = None
repo_path: str = field(init=False, repr=False)
@ -76,6 +76,9 @@ class GitSourceConnector(SourceConnectorCleanupMixin, BaseSourceConnector):
def initialize(self):
pass
def check_connection(self):
pass
def is_file_type_supported(self, path: str) -> bool:
# Workaround to ensure that auto.partition isn't fed with .yaml, .py, etc. files
# TODO: What to do with no filenames? e.g. LICENSE, Makefile, etc.

View File

@ -127,6 +127,33 @@ class GitHubIngestDoc(GitIngestDoc):
class GitHubSourceConnector(GitSourceConnector):
connector_config: SimpleGitHubConfig
@requires_dependencies(["github"], extras="github")
def check_connection(self):
from github import Consts
from github.GithubRetry import GithubRetry
from github.Requester import Requester
try:
requester = Requester(
auth=self.connector_config.access_token,
base_url=Consts.DEFAULT_BASE_URL,
timeout=Consts.DEFAULT_TIMEOUT,
user_agent=Consts.DEFAULT_USER_AGENT,
per_page=Consts.DEFAULT_PER_PAGE,
verify=True,
retry=GithubRetry(),
pool_size=None,
)
url_base = (
"/repositories/" if isinstance(self.connector_config.repo_path, int) else "/repos/"
)
url = f"{url_base}{self.connector_config.repo_path}"
headers, _ = requester.requestJsonAndCheck("HEAD", url)
logger.debug(f"headers from HEAD request: {headers}")
except Exception as e:
logger.error(f"failed to validate connection: {e}", exc_info=True)
raise SourceConnectionError(f"failed to validate connection: {e}")
def get_ingest_docs(self):
repo = self.connector_config.get_repo()
# Load the Git tree with all files, and then create Ingest docs

View File

@ -103,6 +103,20 @@ class GitLabIngestDoc(GitIngestDoc):
class GitLabSourceConnector(GitSourceConnector):
connector_config: SimpleGitLabConfig
@requires_dependencies(["gitlab"], extras="gitlab")
def check_connection(self):
from gitlab import Gitlab
from gitlab.exceptions import GitlabError
try:
gitlab = Gitlab(
self.connector_config.base_url, private_token=self.connector_config.access_token
)
gitlab.auth()
except GitlabError as gitlab_error:
logger.error(f"failed to validate connection: {gitlab_error}", exc_info=True)
raise SourceConnectionError(f"failed to validate connection: {gitlab_error}")
def get_ingest_docs(self):
# Load the Git tree with all files, and then create Ingest docs
# for all blobs, i.e. all files, ignoring directories

View File

@ -324,6 +324,13 @@ class GoogleDriveSourceConnector(SourceConnectorCleanupMixin, BaseSourceConnecto
def initialize(self):
pass
def check_connection(self):
try:
self.connector_config.create_session_handle().service
except Exception as e:
logger.error(f"failed to validate connection: {e}", exc_info=True)
raise SourceConnectionError(f"failed to validate connection: {e}")
def get_ingest_docs(self):
files = self._list_objects(self.connector_config.drive_id, self.connector_config.recursive)
return [

View File

@ -2,7 +2,7 @@ import math
import os
import typing as t
from collections import abc
-from dataclasses import dataclass
from dataclasses import dataclass, field
from datetime import datetime
from functools import cached_property
from pathlib import Path
@ -81,9 +81,9 @@ class SimpleJiraConfig(ConfigSessionHandleMixin, BaseConnectorConfig):
user_email: str
api_token: str
url: str
-projects: t.Optional[t.List[str]]
projects: t.Optional[t.List[str]] = None
-boards: t.Optional[t.List[str]]
boards: t.Optional[t.List[str]] = None
-issues: t.Optional[t.List[str]]
issues: t.Optional[t.List[str]] = None
def create_session_handle(
self,
@ -342,10 +342,24 @@ class JiraSourceConnector(SourceConnectorCleanupMixin, BaseSourceConnector):
"""Fetches issues from projects in an Atlassian (Jira) Cloud instance.""" """Fetches issues from projects in an Atlassian (Jira) Cloud instance."""
connector_config: SimpleJiraConfig connector_config: SimpleJiraConfig
_jira: t.Optional["Jira"] = field(init=False, default=None)
@property
def jira(self) -> "Jira":
if self._jira is None:
try:
self._jira = self.connector_config.create_session_handle().service
except Exception as e:
logger.error(f"failed to validate connection: {e}", exc_info=True)
raise SourceConnectionError(f"failed to validate connection: {e}")
return self._jira
@requires_dependencies(["atlassian"], extras="jira") @requires_dependencies(["atlassian"], extras="jira")
def initialize(self): def initialize(self):
self.jira = self.connector_config.create_session_handle().service _ = self.jira
def check_connection(self):
_ = self.jira
@requires_dependencies(["atlassian"], extras="jira") @requires_dependencies(["atlassian"], extras="jira")
def _get_all_project_ids(self): def _get_all_project_ids(self):

View File

@ -89,6 +89,9 @@ class LocalIngestDoc(BaseIngestDoc):
class LocalSourceConnector(BaseSourceConnector):
"""Objects of this class support fetching document(s) from local file system"""
def check_connection(self):
pass
connector_config: SimpleLocalConfig
def __post_init__(self):

View File

@ -1,8 +1,11 @@
import os
import typing as t
-from dataclasses import dataclass
from dataclasses import dataclass, field
from pathlib import Path
import httpx
from unstructured.ingest.error import SourceConnectionError
from unstructured.ingest.interfaces import (
BaseConnectorConfig,
BaseIngestDoc,
@ -17,16 +20,18 @@ from unstructured.utils import (
)
NOTION_API_VERSION = "2022-06-28"
if t.TYPE_CHECKING:
from unstructured.ingest.connector.notion.client import Client as NotionClient
@dataclass
class SimpleNotionConfig(BaseConnectorConfig):
"""Connector config to process all messages by channel id's."""
-page_ids: t.List[str]
-database_ids: t.List[str]
-recursive: bool
notion_api_key: str
page_ids: t.List[str] = field(default_factory=list)
database_ids: t.List[str] = field(default_factory=list)
recursive: bool = False
@dataclass
@ -284,20 +289,35 @@ class NotionSourceConnector(SourceConnectorCleanupMixin, BaseSourceConnector):
connector_config: SimpleNotionConfig
retry_strategy_config: t.Optional[RetryStrategyConfig] = None
_client: t.Optional["NotionClient"] = field(init=False, default=None)
@requires_dependencies(dependencies=["notion_client"], extras="notion") @property
def initialize(self): def client(self) -> "NotionClient":
"""Verify that can get metadata for an object, validates connections info."""
from unstructured.ingest.connector.notion.client import Client as NotionClient from unstructured.ingest.connector.notion.client import Client as NotionClient
# Pin the version of the api to avoid schema changes if self._client is None:
self.client = NotionClient( self._client = NotionClient(
notion_version=NOTION_API_VERSION, notion_version=NOTION_API_VERSION,
auth=self.connector_config.notion_api_key, auth=self.connector_config.notion_api_key,
logger=logger, logger=logger,
log_level=logger.level, log_level=logger.level,
retry_strategy_config=self.retry_strategy_config, retry_strategy_config=self.retry_strategy_config,
) )
return self._client
def check_connection(self):
try:
request = self.client._build_request("HEAD", "users")
response = self.client.client.send(request)
response.raise_for_status()
except httpx.HTTPStatusError as http_error:
logger.error(f"failed to validate connection: {http_error}", exc_info=True)
raise SourceConnectionError(f"failed to validate connection: {http_error}")
@requires_dependencies(dependencies=["notion_client"], extras="notion")
def initialize(self):
"""Verify that can get metadata for an object, validates connections info."""
_ = self.client
@requires_dependencies(dependencies=["notion_client"], extras="notion") @requires_dependencies(dependencies=["notion_client"], extras="notion")
def get_child_page_content(self, page_id: str): def get_child_page_content(self, page_id: str):

View File

@ -17,6 +17,7 @@ from unstructured.ingest.logger import logger
from unstructured.utils import requires_dependencies
if t.TYPE_CHECKING:
from office365.graph_client import GraphClient
from office365.onedrive.driveitems.driveItem import DriveItem
MAX_MB_SIZE = 512_000_000
@ -28,7 +29,7 @@ class SimpleOneDriveConfig(BaseConnectorConfig):
client_credential: str = field(repr=False)
user_pname: str
tenant: str = field(repr=False)
-authority_url: t.Optional[str] = field(repr=False)
authority_url: t.Optional[str] = field(repr=False, default="https://login.microsoftonline.com")
path: t.Optional[str] = field(default="")
recursive: bool = False
@ -177,12 +178,32 @@ class OneDriveIngestDoc(IngestDocCleanupMixin, BaseIngestDoc):
@dataclass
class OneDriveSourceConnector(SourceConnectorCleanupMixin, BaseSourceConnector):
connector_config: SimpleOneDriveConfig
_client: t.Optional["GraphClient"] = field(init=False, default=None)
@requires_dependencies(["office365"], extras="onedrive") @property
def _set_client(self): def client(self) -> "GraphClient":
from office365.graph_client import GraphClient from office365.graph_client import GraphClient
self.client = GraphClient(self.connector_config.token_factory) if self._client is None:
self._client = GraphClient(self.connector_config.token_factory)
return self._client
@requires_dependencies(["office365"], extras="onedrive")
def initialize(self):
_ = self.client
@requires_dependencies(["office365"], extras="onedrive")
def check_connection(self):
try:
token_resp: dict = self.connector_config.token_factory()
if error := token_resp.get("error"):
raise SourceConnectionError(
"{} ({})".format(error, token_resp.get("error_description"))
)
_ = self.client
except Exception as e:
logger.error(f"failed to validate connection: {e}", exc_info=True)
raise SourceConnectionError(f"failed to validate connection: {e}")
def _list_objects(self, folder, recursive) -> t.List["DriveItem"]:
drive_items = folder.children.get().execute_query()
@ -205,9 +226,6 @@ class OneDriveSourceConnector(SourceConnectorCleanupMixin, BaseSourceConnector):
file_path=file_path,
)
-def initialize(self):
-self._set_client()
def get_ingest_docs(self):
root = self.client.users[self.connector_config.user_pname].drive.get().execute_query().root
if fpath := self.connector_config.path:

View File

@ -19,6 +19,8 @@ from unstructured.ingest.logger import logger
from unstructured.utils import requires_dependencies
MAX_NUM_EMAILS = 1000000 # Maximum number of emails per folder MAX_NUM_EMAILS = 1000000 # Maximum number of emails per folder
if t.TYPE_CHECKING:
from office365.graph_client import GraphClient
class MissingFolderError(Exception):
@ -29,12 +31,12 @@ class MissingFolderError(Exception):
class SimpleOutlookConfig(BaseConnectorConfig):
"""This class is getting the token."""
-client_id: t.Optional[str]
-client_credential: t.Optional[str] = field(repr=False)
-tenant: t.Optional[str] = field(repr=False)
-authority_url: t.Optional[str] = field(repr=False)
-ms_outlook_folders: t.List[str]
user_email: str
client_id: str
client_credential: str = field(repr=False)
tenant: t.Optional[str] = field(repr=False, default="common")
authority_url: t.Optional[str] = field(repr=False, default="https://login.microsoftonline.com")
ms_outlook_folders: t.List[str] = field(default_factory=list)
recursive: bool = False
registry_name: str = "outlook"
@ -42,7 +44,7 @@ class SimpleOutlookConfig(BaseConnectorConfig):
if not (self.client_id and self.client_credential and self.user_email):
raise ValueError(
"Please provide one of the following mandatory values:"
-"\n--client_id\n--client_cred\n--user-email",
"\nclient_id\nclient_cred\nuser_email",
)
self.token_factory = self._acquire_token
@ -180,10 +182,26 @@ class OutlookIngestDoc(IngestDocCleanupMixin, BaseIngestDoc):
@dataclass
class OutlookSourceConnector(SourceConnectorCleanupMixin, BaseSourceConnector):
connector_config: SimpleOutlookConfig
_client: t.Optional["GraphClient"] = field(init=False, default=None)
@property
def client(self) -> "GraphClient":
if self._client is None:
self._client = self.connector_config._get_client()
return self._client
def initialize(self):
-self.client = self.connector_config._get_client()
try:
self.get_folder_ids()
except Exception as e:
raise SourceConnectionError(f"failed to validate connection: {e}")
def check_connection(self):
try:
_ = self.client
except Exception as e:
logger.error(f"failed to validate connection: {e}", exc_info=True)
raise SourceConnectionError(f"failed to validate connection: {e}")
def recurse_folders(self, folder_id, main_folder_dict):
"""We only get a count of subfolders for any folder.

View File

@ -16,15 +16,18 @@ from unstructured.ingest.interfaces import (
from unstructured.ingest.logger import logger
from unstructured.utils import requires_dependencies
if t.TYPE_CHECKING:
from praw import Reddit
@dataclass
class SimpleRedditConfig(BaseConnectorConfig):
subreddit_name: str
-client_id: t.Optional[str]
-client_secret: t.Optional[str]
-user_agent: str
-search_query: t.Optional[str]
num_posts: int
user_agent: str
client_id: str
client_secret: t.Optional[str] = None
search_query: t.Optional[str] = None
def __post_init__(self):
if self.num_posts <= 0:
@ -110,16 +113,33 @@ class RedditIngestDoc(IngestDocCleanupMixin, BaseIngestDoc):
@dataclass
class RedditSourceConnector(SourceConnectorCleanupMixin, BaseSourceConnector):
connector_config: SimpleRedditConfig
_reddit: t.Optional["Reddit"] = field(init=False, default=None)
@requires_dependencies(["praw"], extras="reddit") @property
def initialize(self): def reddit(self) -> "Reddit":
from praw import Reddit from praw import Reddit
self.reddit = Reddit( if self._reddit is None:
self._reddit = Reddit(
client_id=self.connector_config.client_id, client_id=self.connector_config.client_id,
client_secret=self.connector_config.client_secret, client_secret=self.connector_config.client_secret,
user_agent=self.connector_config.user_agent, user_agent=self.connector_config.user_agent,
) )
return self._reddit
@requires_dependencies(["praw"], extras="reddit")
def initialize(self):
_ = self.reddit
def check_connection(self):
from praw.endpoints import API_PATH
from prawcore import ResponseException
try:
self.reddit._objectify_request(method="HEAD", params=None, path=API_PATH["me"])
except ResponseException as response_error:
logger.error(f"failed to validate connection: {response_error}", exc_info=True)
raise SourceConnectionError(f"failed to validate connection: {response_error}")
def get_ingest_docs(self):
subreddit = self.reddit.subreddit(self.connector_config.subreddit_name)

View File

@ -224,6 +224,16 @@ class SalesforceSourceConnector(SourceConnectorCleanupMixin, BaseSourceConnector
def initialize(self):
pass
@requires_dependencies(["simple_salesforce"], extras="salesforce")
def check_connection(self):
from simple_salesforce.exceptions import SalesforceError
try:
self.connector_config.get_client()
except SalesforceError as salesforce_error:
logger.error(f"failed to validate connection: {salesforce_error}", exc_info=True)
raise SourceConnectionError(f"failed to validate connection: {salesforce_error}")
@requires_dependencies(["simple_salesforce"], extras="salesforce") @requires_dependencies(["simple_salesforce"], extras="salesforce")
def get_ingest_docs(self) -> t.List[SalesforceIngestDoc]: def get_ingest_docs(self) -> t.List[SalesforceIngestDoc]:
"""Get Salesforce Ids for the records. """Get Salesforce Ids for the records.

View File

@ -291,6 +291,14 @@ class SharepointIngestDoc(IngestDocCleanupMixin, BaseIngestDoc):
class SharepointSourceConnector(SourceConnectorCleanupMixin, BaseSourceConnector):
connector_config: SimpleSharepointConfig
def check_connection(self):
try:
site_client = self.connector_config.get_site_client()
site_client.site_pages.pages.get().execute_query()
except Exception as e:
logger.error(f"failed to validate connection: {e}", exc_info=True)
raise SourceConnectionError(f"failed to validate connection: {e}")
@requires_dependencies(["office365"], extras="sharepoint") @requires_dependencies(["office365"], extras="sharepoint")
def _list_files(self, folder, recursive) -> t.List["File"]: def _list_files(self, folder, recursive) -> t.List["File"]:
from office365.runtime.client_request_exception import ClientRequestException from office365.runtime.client_request_exception import ClientRequestException

View File

@ -198,6 +198,18 @@ class SlackSourceConnector(SourceConnectorCleanupMixin, BaseSourceConnector):
connector_config: SimpleSlackConfig
@requires_dependencies(dependencies=["slack_sdk"], extras="slack")
def check_connection(self):
from slack_sdk import WebClient
from slack_sdk.errors import SlackClientError
try:
client = WebClient(token=self.connector_config.token)
client.users_identity()
except SlackClientError as slack_error:
logger.error(f"failed to validate connection: {slack_error}", exc_info=True)
raise SourceConnectionError(f"failed to validate connection: {slack_error}")
def initialize(self):
"""Verify that can get metadata for an object, validates connections info."""

View File

@ -177,6 +177,19 @@ class WikipediaSourceConnector(SourceConnectorCleanupMixin, BaseSourceConnector)
def initialize(self):
pass
@requires_dependencies(["wikipedia"], extras="wikipedia")
def check_connection(self):
import wikipedia
try:
wikipedia.page(
self.connector_config.title,
auto_suggest=self.connector_config.auto_suggest,
)
except Exception as e:
logger.error(f"failed to validate connection: {e}", exc_info=True)
raise SourceConnectionError(f"failed to validate connection: {e}")
def get_ingest_docs(self):
return [
WikipediaIngestTextDoc(

View File

@ -523,7 +523,14 @@ class BaseIngestDoc(IngestDocJsonMixin, ABC):
@dataclass
-class BaseSourceConnector(DataClassJsonMixin, ABC):
class BaseConnector(DataClassJsonMixin, ABC):
@abstractmethod
def check_connection(self):
pass
@dataclass
class BaseSourceConnector(BaseConnector, ABC):
"""Abstract Base Class for a connector to a remote source, e.g. S3 or Google Drive.""" """Abstract Base Class for a connector to a remote source, e.g. S3 or Google Drive."""
processor_config: ProcessorConfig processor_config: ProcessorConfig
@ -551,7 +558,7 @@ class BaseSourceConnector(DataClassJsonMixin, ABC):
@dataclass
-class BaseDestinationConnector(DataClassJsonMixin, ABC):
class BaseDestinationConnector(BaseConnector, ABC):
write_config: WriteConfig
connector_config: BaseConnectorConfig
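
Since `check_connection` is now declared on the shared `BaseConnector` base class, every source and destination connector has to provide an implementation; connectors with nothing remote to verify (local, git, delta-table in this diff) satisfy the contract with a no-op. A minimal sketch of what a new subclass might look like (`NoopCheckSourceConnector` is a hypothetical class, and any other abstract members of `BaseSourceConnector` would still need real implementations):

```python
import typing as t
from dataclasses import dataclass

from unstructured.ingest.interfaces import BaseSourceConnector


@dataclass
class NoopCheckSourceConnector(BaseSourceConnector):
    """Hypothetical connector with no remote service to validate."""

    def initialize(self):
        pass

    def check_connection(self):
        # Nothing remote to verify, but the method must exist now that
        # BaseConnector declares it as abstract.
        pass

    def get_ingest_docs(self) -> t.List:
        return []
```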