feat: add connection check method to all source and destination connectors (#2000)

### Description
Add a `check_connection` method to each connector to easily be able to
check it without running the full ingest process. As part of this PR,
some refactoring was done to allow clients to be shared and populated across
the `check_connection` method and the `initialize` method, allowing for
the `check_connection` method to be called without having to rely on the
`initialize` one to be called first.

* bonus: fix the changelog

---------

Co-authored-by: ryannikolaidis <1208590+ryannikolaidis@users.noreply.github.com>
This commit is contained in:
Roman Isecke 2023-11-07 22:11:39 -05:00 committed by GitHub
parent 92ddf3a337
commit 03f62faf9b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
28 changed files with 475 additions and 100 deletions

View File

@ -1,19 +1,14 @@
## 0.10.30-dev1
### Enhancements
### Features
* **Adds ability to pass timeout for a request when partitioning via a `url`.** `partition` now accepts a new optional parameter `request_timeout` which if set will prevent any `requests.get` from hanging indefinitely and instead will raise a timeout error. This is useful when partitioning a url that may be slow to respond or may not respond at all.
### Fixes
## 0.10.30-dev2
### Enhancements
* **Support nested DOCX tables.** In DOCX, like HTML, a table cell can itself contain a table. In this case, create nested HTML tables to reflect that structure and create a plain-text table which captures all the text in nested tables, formatting it as a reasonable facsimile of a table.
* **Add connection check to ingest connectors.** Each source and destination connector now supports a `check_connection()` method which makes sure a valid connection can be established with the source/destination given any authentication credentials in a lightweight request.
### Features
* **Adds ability to pass timeout for a request when partitioning via a `url`.** `partition` now accepts a new optional parameter `request_timeout` which if set will prevent any `requests.get` from hanging indefinitely and instead will raise a timeout error. This is useful when partitioning a url that may be slow to respond or may not respond at all.
### Fixes
## 0.10.29

View File

@ -1 +1 @@
__version__ = "0.10.30-dev1" # pragma: no cover
__version__ = "0.10.30-dev2" # pragma: no cover

View File

@ -1,9 +1,11 @@
import os
import typing as t
from dataclasses import dataclass
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
import requests
from unstructured.ingest.error import SourceConnectionError, SourceConnectionNetworkError
from unstructured.ingest.interfaces import (
BaseConnectorConfig,
@ -16,6 +18,9 @@ from unstructured.ingest.interfaces import (
from unstructured.ingest.logger import logger
from unstructured.utils import requires_dependencies
if t.TYPE_CHECKING:
from pyairtable import Api
@dataclass
class SimpleAirtableConfig(BaseConnectorConfig):
@ -200,6 +205,24 @@ class AirtableSourceConnector(SourceConnectorCleanupMixin, BaseSourceConnector):
"""Fetches tables or views from an Airtable org."""
connector_config: SimpleAirtableConfig
_api: t.Optional["Api"] = field(init=False, default=None)
@property
def api(self):
if self._api is None:
self._api = Api(self.connector_config.personal_access_token)
return self._api
@api.setter
def api(self, api: "Api"):
self._api = api
def check_connection(self):
try:
self.api.request(method="HEAD", url=self.api.build_url("meta", "bases"))
except requests.HTTPError as http_error:
logger.error(f"failed to validate connection: {http_error}", exc_info=True)
raise SourceConnectionError(f"failed to validate connection: {http_error}")
@requires_dependencies(["pyairtable"], extras="airtable")
def initialize(self):

View File

@ -7,7 +7,8 @@ from unstructured.ingest.connector.fsspec import (
FsspecSourceConnector,
SimpleFsspecConfig,
)
from unstructured.ingest.error import SourceConnectionError
from unstructured.ingest.error import DestinationConnectionError, SourceConnectionError
from unstructured.ingest.logger import logger
from unstructured.utils import requires_dependencies
@ -31,6 +32,16 @@ class AzureBlobStorageIngestDoc(FsspecIngestDoc):
class AzureBlobStorageSourceConnector(FsspecSourceConnector):
connector_config: SimpleAzureBlobStorageConfig
@requires_dependencies(["adlfs"], extras="azure")
def check_connection(self):
from adlfs import AzureBlobFileSystem
try:
AzureBlobFileSystem(**self.connector_config.access_kwargs)
except ValueError as connection_error:
logger.error(f"failed to validate connection: {connection_error}", exc_info=True)
raise SourceConnectionError(f"failed to validate connection: {connection_error}")
def __post_init__(self):
self.ingest_doc_cls: t.Type[AzureBlobStorageIngestDoc] = AzureBlobStorageIngestDoc
@ -39,3 +50,13 @@ class AzureBlobStorageSourceConnector(FsspecSourceConnector):
@dataclass
class AzureBlobStorageDestinationConnector(FsspecDestinationConnector):
connector_config: SimpleAzureBlobStorageConfig
@requires_dependencies(["adlfs"], extras="azure")
def check_connection(self):
from adlfs import AzureBlobFileSystem
try:
AzureBlobFileSystem(**self.connector_config.access_kwargs)
except ValueError as connection_error:
logger.error(f"failed to validate connection: {connection_error}", exc_info=True)
raise DestinationConnectionError(f"failed to validate connection: {connection_error}")

View File

@ -1,11 +1,11 @@
import json
import typing as t
import uuid
from dataclasses import dataclass
from dataclasses import dataclass, field
import azure.core.exceptions
from unstructured.ingest.error import WriteError
from unstructured.ingest.error import DestinationConnectionError, WriteError
from unstructured.ingest.interfaces import (
BaseConnectorConfig,
BaseDestinationConnector,
@ -15,6 +15,9 @@ from unstructured.ingest.interfaces import (
from unstructured.ingest.logger import logger
from unstructured.utils import requires_dependencies
if t.TYPE_CHECKING:
from azure.search.documents import SearchClient
@dataclass
class SimpleAzureCognitiveSearchStorageConfig(BaseConnectorConfig):
@ -31,20 +34,37 @@ class AzureCognitiveSearchWriteConfig(WriteConfig):
class AzureCognitiveSearchDestinationConnector(BaseDestinationConnector):
write_config: AzureCognitiveSearchWriteConfig
connector_config: SimpleAzureCognitiveSearchStorageConfig
_client: t.Optional["SearchClient"] = field(init=False, default=None)
@requires_dependencies(["azure"], extras="azure-cognitive-search")
def initialize(self):
def generate_client(self) -> "SearchClient":
from azure.core.credentials import AzureKeyCredential
from azure.search.documents import SearchClient
# Create a client
credential = AzureKeyCredential(self.connector_config.key)
self.client = SearchClient(
return SearchClient(
endpoint=self.connector_config.endpoint,
index_name=self.write_config.index,
credential=credential,
)
@property
def client(self) -> "SearchClient":
if self._client is None:
self._client = self.generate_client()
return self._client
def check_connection(self):
try:
self.client.get_document_count()
except Exception as e:
logger.error(f"failed to validate connection: {e}", exc_info=True)
raise DestinationConnectionError(f"failed to validate connection: {e}")
def initialize(self):
_ = self.client
def conform_dict(self, data: dict) -> None:
"""
updates the dictionary that is from each Element being converted into a dict/json

View File

@ -40,11 +40,11 @@ class SimpleBiomedConfig(BaseConnectorConfig):
"""Connector config where path is the FTP directory path and
id_, from_, until, format are API parameters."""
path: t.Optional[str]
path: t.Optional[str] = None
# OA Web Service API Options
id_: t.Optional[str]
from_: t.Optional[str]
until: t.Optional[str]
id_: t.Optional[str] = None
from_: t.Optional[str] = None
until: t.Optional[str] = None
request_timeout: int = 45
def validate_api_inputs(self):
@ -152,6 +152,20 @@ class BiomedSourceConnector(SourceConnectorCleanupMixin, BaseSourceConnector):
connector_config: SimpleBiomedConfig
def get_base_endpoints_url(self) -> str:
endpoint_url = "https://www.ncbi.nlm.nih.gov/pmc/utils/oa/oa.fcgi?format=pdf"
if self.connector_config.id_:
endpoint_url += f"&id={self.connector_config.id_}"
if self.connector_config.from_:
endpoint_url += f"&from={self.connector_config.from_}"
if self.connector_config.until:
endpoint_url += f"&until={self.connector_config.until}"
return endpoint_url
def _list_objects_api(self) -> t.List[BiomedFileMeta]:
def urls_to_metadata(urls):
files = []
@ -175,16 +189,7 @@ class BiomedSourceConnector(SourceConnectorCleanupMixin, BaseSourceConnector):
files: t.List[BiomedFileMeta] = []
endpoint_url = "https://www.ncbi.nlm.nih.gov/pmc/utils/oa/oa.fcgi?format=pdf"
if self.connector_config.id_:
endpoint_url += f"&id={self.connector_config.id_}"
if self.connector_config.from_:
endpoint_url += f"&from={self.connector_config.from_}"
if self.connector_config.until:
endpoint_url += f"&until={self.connector_config.until}"
endpoint_url = self.get_base_endpoints_url()
while endpoint_url:
session = requests.Session()
@ -287,6 +292,13 @@ class BiomedSourceConnector(SourceConnectorCleanupMixin, BaseSourceConnector):
def initialize(self):
pass
def check_connection(self):
resp = requests.head(self.get_base_endpoints_url())
try:
resp.raise_for_status()
except requests.HTTPError as http_error:
raise SourceConnectionError(f"failed to validate connection: {http_error}")
def get_ingest_docs(self):
files = self._list_objects_api() if self.connector_config.is_api else self._list_objects()
return [

View File

@ -17,7 +17,8 @@ from unstructured.ingest.connector.fsspec import (
FsspecSourceConnector,
SimpleFsspecConfig,
)
from unstructured.ingest.error import SourceConnectionError
from unstructured.ingest.error import DestinationConnectionError, SourceConnectionError
from unstructured.ingest.logger import logger
from unstructured.utils import requires_dependencies
@ -57,6 +58,16 @@ class BoxIngestDoc(FsspecIngestDoc):
class BoxSourceConnector(FsspecSourceConnector):
connector_config: SimpleBoxConfig
@requires_dependencies(["boxfs"], extras="box")
def check_connection(self):
from boxfs import BoxFileSystem
try:
BoxFileSystem(**self.connector_config.access_kwargs)
except Exception as e:
logger.error(f"failed to validate connection: {e}", exc_info=True)
raise SourceConnectionError(f"failed to validate connection: {e}")
def __post_init__(self):
self.ingest_doc_cls: t.Type[BoxIngestDoc] = BoxIngestDoc
@ -65,3 +76,13 @@ class BoxSourceConnector(FsspecSourceConnector):
@dataclass
class BoxDestinationConnector(FsspecDestinationConnector):
connector_config: SimpleBoxConfig
@requires_dependencies(["boxfs"], extras="box")
def check_connection(self):
from boxfs import BoxFileSystem
try:
BoxFileSystem(**self.connector_config.access_kwargs)
except Exception as e:
logger.error(f"failed to validate connection: {e}", exc_info=True)
raise DestinationConnectionError(f"failed to validate connection: {e}")

View File

@ -5,6 +5,8 @@ from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
import requests
from unstructured.ingest.error import SourceConnectionError, SourceConnectionNetworkError
from unstructured.ingest.interfaces import (
BaseConnectorConfig,
@ -17,6 +19,9 @@ from unstructured.ingest.interfaces import (
from unstructured.ingest.logger import logger
from unstructured.utils import requires_dependencies
if t.TYPE_CHECKING:
from atlassian import Confluence
@dataclass
class SimpleConfluenceConfig(BaseConnectorConfig):
@ -185,17 +190,31 @@ class ConfluenceSourceConnector(SourceConnectorCleanupMixin, BaseSourceConnector
"""Fetches body fields from all documents within all spaces in a Confluence Cloud instance."""
connector_config: SimpleConfluenceConfig
_confluence: t.Optional["Confluence"] = field(init=False, default=None)
@requires_dependencies(["atlassian"], extras="Confluence")
def initialize(self):
@property
def confluence(self) -> "Confluence":
from atlassian import Confluence
self.confluence = Confluence(
if self._confluence is None:
self._confluence = Confluence(
url=self.connector_config.url,
username=self.connector_config.user_email,
password=self.connector_config.api_token,
)
return self._confluence
@requires_dependencies(["atlassian"], extras="Confluence")
def check_connection(self):
url = "rest/api/space"
try:
self.confluence.request(method="HEAD", path=url)
except requests.HTTPError as http_error:
logger.error(f"failed to validate connection: {http_error}", exc_info=True)
raise SourceConnectionError(f"failed to validate connection: {http_error}")
@requires_dependencies(["atlassian"], extras="Confluence")
def initialize(self):
self.list_of_spaces = None
if self.connector_config.spaces:
self.list_of_spaces = self.connector_config.spaces

View File

@ -119,6 +119,9 @@ class DeltaTableSourceConnector(SourceConnectorCleanupMixin, BaseSourceConnector
connector_config: SimpleDeltaTableConfig
delta_table: t.Optional["DeltaTable"] = None
def check_connection(self):
pass
@requires_dependencies(["deltalake"], extras="delta-table")
def initialize(self):
from deltalake import DeltaTable
@ -172,6 +175,9 @@ class DeltaTableDestinationConnector(BaseDestinationConnector):
def initialize(self):
pass
def check_connection(self):
pass
def write_dict(self, *args, elements_dict: t.List[t.Dict[str, t.Any]], **kwargs) -> None:
from deltalake.writer import write_deltalake

View File

@ -156,6 +156,21 @@ class DiscordSourceConnector(SourceConnectorCleanupMixin, BaseSourceConnector):
def initialize(self):
pass
@requires_dependencies(dependencies=["discord"], extras="discord")
def check_connection(self):
import asyncio
import discord
from discord.client import Client
intents = discord.Intents.default()
try:
client = Client(intents=intents)
asyncio.run(client.start(token=self.connector_config.token))
except Exception as e:
logger.error(f"failed to validate connection: {e}", exc_info=True)
raise SourceConnectionError(f"failed to validate connection: {e}")
def get_ingest_docs(self):
return [
DiscordIngestDoc(

View File

@ -20,6 +20,7 @@ from unstructured.ingest.connector.fsspec import (
SimpleFsspecConfig,
)
from unstructured.ingest.error import SourceConnectionError
from unstructured.ingest.logger import logger
from unstructured.utils import requires_dependencies
@ -90,12 +91,16 @@ class DropboxSourceConnector(FsspecSourceConnector):
def initialize(self):
from fsspec import AbstractFileSystem, get_filesystem_class
try:
self.fs: AbstractFileSystem = get_filesystem_class(self.connector_config.protocol)(
**self.connector_config.get_access_kwargs(),
)
# Dropbox requires a forward slash at the front of the folder path. This
# creates some complications in path joining so a custom path is created here.
ls_output = self.fs.ls(f"/{self.connector_config.path_without_protocol}")
except Exception as e:
logger.error(f"failed to validate connection: {e}", exc_info=True)
raise SourceConnectionError(f"failed to validate connection: {e}")
if ls_output and len(ls_output) >= 1:
return
elif ls_output:

View File

@ -2,7 +2,7 @@ import hashlib
import json
import os
import typing as t
from dataclasses import dataclass
from dataclasses import dataclass, field
from pathlib import Path
from unstructured.ingest.error import SourceConnectionError, SourceConnectionNetworkError
@ -17,6 +17,9 @@ from unstructured.ingest.interfaces import (
from unstructured.ingest.logger import logger
from unstructured.utils import requires_dependencies
if t.TYPE_CHECKING:
from elasticsearch import Elasticsearch
@dataclass
class SimpleElasticsearchConfig(BaseConnectorConfig):
@ -30,7 +33,7 @@ class SimpleElasticsearchConfig(BaseConnectorConfig):
url: str
index_name: str
jq_query: t.Optional[str]
jq_query: t.Optional[str] = None
@dataclass
@ -185,12 +188,25 @@ class ElasticsearchSourceConnector(SourceConnectorCleanupMixin, BaseSourceConnec
"""Fetches particular fields from all documents in a given elasticsearch cluster and index"""
connector_config: SimpleElasticsearchConfig
_es: t.Optional["Elasticsearch"] = field(init=False, default=None)
@property
def es(self):
from elasticsearch import Elasticsearch
if self._es is None:
self._es = Elasticsearch(self.connector_config.url)
return self._es
def check_connection(self):
try:
self.es.perform_request("HEAD", "/", headers={"accept": "application/json"})
except Exception as e:
logger.error(f"failed to validate connection: {e}", exc_info=True)
raise SourceConnectionError(f"failed to validate connection: {e}")
@requires_dependencies(["elasticsearch"], extras="elasticsearch")
def initialize(self):
from elasticsearch import Elasticsearch
self.es = Elasticsearch(self.connector_config.url)
self.scan_query: dict = {"query": {"match_all": {}}}
self.search_query: dict = {"match_all": {}}
self.es.search(index=self.connector_config.index_name, query=self.search_query, size=1)

View File

@ -5,7 +5,11 @@ from contextlib import suppress
from dataclasses import dataclass, field
from pathlib import Path, PurePath
from unstructured.ingest.error import SourceConnectionError, SourceConnectionNetworkError
from unstructured.ingest.error import (
DestinationConnectionError,
SourceConnectionError,
SourceConnectionNetworkError,
)
from unstructured.ingest.interfaces import (
BaseConnectorConfig,
BaseDestinationConnector,
@ -147,6 +151,18 @@ class FsspecSourceConnector(
connector_config: SimpleFsspecConfig
def check_connection(self):
from fsspec import get_filesystem_class
try:
fs = get_filesystem_class(self.connector_config.protocol)(
**self.connector_config.get_access_kwargs(),
)
fs.ls(path=self.connector_config.path_without_protocol)
except Exception as e:
logger.error(f"failed to validate connection: {e}", exc_info=True)
raise SourceConnectionError(f"failed to validate connection: {e}")
def __post_init__(self):
self.ingest_doc_cls: t.Type[FsspecIngestDoc] = FsspecIngestDoc
@ -244,6 +260,18 @@ class FsspecDestinationConnector(BaseDestinationConnector):
**self.connector_config.get_access_kwargs(),
)
def check_connection(self):
from fsspec import get_filesystem_class
try:
fs = get_filesystem_class(self.connector_config.protocol)(
**self.connector_config.get_access_kwargs(),
)
fs.ls(path=self.connector_config.path_without_protocol)
except Exception as e:
logger.error(f"failed to validate connection: {e}", exc_info=True)
raise DestinationConnectionError(f"failed to validate connection: {e}")
def write_dict(
self,
*args,

View File

@ -18,9 +18,9 @@ from unstructured.ingest.logger import logger
@dataclass
class SimpleGitConfig(BaseConnectorConfig):
url: str
access_token: t.Optional[str]
branch: t.Optional[str]
file_glob: t.Optional[str]
access_token: t.Optional[str] = None
branch: t.Optional[str] = None
file_glob: t.Optional[str] = None
repo_path: str = field(init=False, repr=False)
@ -76,6 +76,9 @@ class GitSourceConnector(SourceConnectorCleanupMixin, BaseSourceConnector):
def initialize(self):
pass
def check_connection(self):
pass
def is_file_type_supported(self, path: str) -> bool:
# Workaround to ensure that auto.partition isn't fed with .yaml, .py, etc. files
# TODO: What to do with no filenames? e.g. LICENSE, Makefile, etc.

View File

@ -127,6 +127,33 @@ class GitHubIngestDoc(GitIngestDoc):
class GitHubSourceConnector(GitSourceConnector):
connector_config: SimpleGitHubConfig
@requires_dependencies(["github"], extras="github")
def check_connection(self):
from github import Consts
from github.GithubRetry import GithubRetry
from github.Requester import Requester
try:
requester = Requester(
auth=self.connector_config.access_token,
base_url=Consts.DEFAULT_BASE_URL,
timeout=Consts.DEFAULT_TIMEOUT,
user_agent=Consts.DEFAULT_USER_AGENT,
per_page=Consts.DEFAULT_PER_PAGE,
verify=True,
retry=GithubRetry(),
pool_size=None,
)
url_base = (
"/repositories/" if isinstance(self.connector_config.repo_path, int) else "/repos/"
)
url = f"{url_base}{self.connector_config.repo_path}"
headers, _ = requester.requestJsonAndCheck("HEAD", url)
logger.debug(f"headers from HEAD request: {headers}")
except Exception as e:
logger.error(f"failed to validate connection: {e}", exc_info=True)
raise SourceConnectionError(f"failed to validate connection: {e}")
def get_ingest_docs(self):
repo = self.connector_config.get_repo()
# Load the Git tree with all files, and then create Ingest docs

View File

@ -103,6 +103,20 @@ class GitLabIngestDoc(GitIngestDoc):
class GitLabSourceConnector(GitSourceConnector):
connector_config: SimpleGitLabConfig
@requires_dependencies(["gitlab"], extras="gitlab")
def check_connection(self):
from gitlab import Gitlab
from gitlab.exceptions import GitlabError
try:
gitlab = Gitlab(
self.connector_config.base_url, private_token=self.connector_config.access_token
)
gitlab.auth()
except GitlabError as gitlab_error:
logger.error(f"failed to validate connection: {gitlab_error}", exc_info=True)
raise SourceConnectionError(f"failed to validate connection: {gitlab_error}")
def get_ingest_docs(self):
# Load the Git tree with all files, and then create Ingest docs
# for all blobs, i.e. all files, ignoring directories

View File

@ -324,6 +324,13 @@ class GoogleDriveSourceConnector(SourceConnectorCleanupMixin, BaseSourceConnecto
def initialize(self):
pass
def check_connection(self):
try:
self.connector_config.create_session_handle().service
except Exception as e:
logger.error(f"failed to validate connection: {e}", exc_info=True)
raise SourceConnectionError(f"failed to validate connection: {e}")
def get_ingest_docs(self):
files = self._list_objects(self.connector_config.drive_id, self.connector_config.recursive)
return [

View File

@ -2,7 +2,7 @@ import math
import os
import typing as t
from collections import abc
from dataclasses import dataclass
from dataclasses import dataclass, field
from datetime import datetime
from functools import cached_property
from pathlib import Path
@ -81,9 +81,9 @@ class SimpleJiraConfig(ConfigSessionHandleMixin, BaseConnectorConfig):
user_email: str
api_token: str
url: str
projects: t.Optional[t.List[str]]
boards: t.Optional[t.List[str]]
issues: t.Optional[t.List[str]]
projects: t.Optional[t.List[str]] = None
boards: t.Optional[t.List[str]] = None
issues: t.Optional[t.List[str]] = None
def create_session_handle(
self,
@ -342,10 +342,24 @@ class JiraSourceConnector(SourceConnectorCleanupMixin, BaseSourceConnector):
"""Fetches issues from projects in an Atlassian (Jira) Cloud instance."""
connector_config: SimpleJiraConfig
_jira: t.Optional["Jira"] = field(init=False, default=None)
@property
def jira(self) -> "Jira":
if self._jira is None:
try:
self._jira = self.connector_config.create_session_handle().service
except Exception as e:
logger.error(f"failed to validate connection: {e}", exc_info=True)
raise SourceConnectionError(f"failed to validate connection: {e}")
return self._jira
@requires_dependencies(["atlassian"], extras="jira")
def initialize(self):
self.jira = self.connector_config.create_session_handle().service
_ = self.jira
def check_connection(self):
_ = self.jira
@requires_dependencies(["atlassian"], extras="jira")
def _get_all_project_ids(self):

View File

@ -89,6 +89,9 @@ class LocalIngestDoc(BaseIngestDoc):
class LocalSourceConnector(BaseSourceConnector):
"""Objects of this class support fetching document(s) from local file system"""
def check_connection(self):
pass
connector_config: SimpleLocalConfig
def __post_init__(self):

View File

@ -1,8 +1,11 @@
import os
import typing as t
from dataclasses import dataclass
from dataclasses import dataclass, field
from pathlib import Path
import httpx
from unstructured.ingest.error import SourceConnectionError
from unstructured.ingest.interfaces import (
BaseConnectorConfig,
BaseIngestDoc,
@ -17,16 +20,18 @@ from unstructured.utils import (
)
NOTION_API_VERSION = "2022-06-28"
if t.TYPE_CHECKING:
from unstructured.ingest.connector.notion.client import Client as NotionClient
@dataclass
class SimpleNotionConfig(BaseConnectorConfig):
"""Connector config to process all messages by channel id's."""
page_ids: t.List[str]
database_ids: t.List[str]
recursive: bool
notion_api_key: str
page_ids: t.List[str] = field(default_factory=list)
database_ids: t.List[str] = field(default_factory=list)
recursive: bool = False
@dataclass
@ -284,20 +289,35 @@ class NotionSourceConnector(SourceConnectorCleanupMixin, BaseSourceConnector):
connector_config: SimpleNotionConfig
retry_strategy_config: t.Optional[RetryStrategyConfig] = None
_client: t.Optional["NotionClient"] = field(init=False, default=None)
@requires_dependencies(dependencies=["notion_client"], extras="notion")
def initialize(self):
"""Verify that can get metadata for an object, validates connections info."""
@property
def client(self) -> "NotionClient":
from unstructured.ingest.connector.notion.client import Client as NotionClient
# Pin the version of the api to avoid schema changes
self.client = NotionClient(
if self._client is None:
self._client = NotionClient(
notion_version=NOTION_API_VERSION,
auth=self.connector_config.notion_api_key,
logger=logger,
log_level=logger.level,
retry_strategy_config=self.retry_strategy_config,
)
return self._client
def check_connection(self):
try:
request = self.client._build_request("HEAD", "users")
response = self.client.client.send(request)
response.raise_for_status()
except httpx.HTTPStatusError as http_error:
logger.error(f"failed to validate connection: {http_error}", exc_info=True)
raise SourceConnectionError(f"failed to validate connection: {http_error}")
@requires_dependencies(dependencies=["notion_client"], extras="notion")
def initialize(self):
"""Verify that can get metadata for an object, validates connections info."""
_ = self.client
@requires_dependencies(dependencies=["notion_client"], extras="notion")
def get_child_page_content(self, page_id: str):

View File

@ -17,6 +17,7 @@ from unstructured.ingest.logger import logger
from unstructured.utils import requires_dependencies
if t.TYPE_CHECKING:
from office365.graph_client import GraphClient
from office365.onedrive.driveitems.driveItem import DriveItem
MAX_MB_SIZE = 512_000_000
@ -28,7 +29,7 @@ class SimpleOneDriveConfig(BaseConnectorConfig):
client_credential: str = field(repr=False)
user_pname: str
tenant: str = field(repr=False)
authority_url: t.Optional[str] = field(repr=False)
authority_url: t.Optional[str] = field(repr=False, default="https://login.microsoftonline.com")
path: t.Optional[str] = field(default="")
recursive: bool = False
@ -177,12 +178,32 @@ class OneDriveIngestDoc(IngestDocCleanupMixin, BaseIngestDoc):
@dataclass
class OneDriveSourceConnector(SourceConnectorCleanupMixin, BaseSourceConnector):
connector_config: SimpleOneDriveConfig
_client: t.Optional["GraphClient"] = field(init=False, default=None)
@requires_dependencies(["office365"], extras="onedrive")
def _set_client(self):
@property
def client(self) -> "GraphClient":
from office365.graph_client import GraphClient
self.client = GraphClient(self.connector_config.token_factory)
if self._client is None:
self._client = GraphClient(self.connector_config.token_factory)
return self._client
@requires_dependencies(["office365"], extras="onedrive")
def initialize(self):
_ = self.client
@requires_dependencies(["office365"], extras="onedrive")
def check_connection(self):
try:
token_resp: dict = self.connector_config.token_factory()
if error := token_resp.get("error"):
raise SourceConnectionError(
"{} ({})".format(error, token_resp.get("error_description"))
)
_ = self.client
except Exception as e:
logger.error(f"failed to validate connection: {e}", exc_info=True)
raise SourceConnectionError(f"failed to validate connection: {e}")
def _list_objects(self, folder, recursive) -> t.List["DriveItem"]:
drive_items = folder.children.get().execute_query()
@ -205,9 +226,6 @@ class OneDriveSourceConnector(SourceConnectorCleanupMixin, BaseSourceConnector):
file_path=file_path,
)
def initialize(self):
self._set_client()
def get_ingest_docs(self):
root = self.client.users[self.connector_config.user_pname].drive.get().execute_query().root
if fpath := self.connector_config.path:

View File

@ -19,6 +19,8 @@ from unstructured.ingest.logger import logger
from unstructured.utils import requires_dependencies
MAX_NUM_EMAILS = 1000000 # Maximum number of emails per folder
if t.TYPE_CHECKING:
from office365.graph_client import GraphClient
class MissingFolderError(Exception):
@ -29,12 +31,12 @@ class MissingFolderError(Exception):
class SimpleOutlookConfig(BaseConnectorConfig):
"""This class is getting the token."""
client_id: t.Optional[str]
client_credential: t.Optional[str] = field(repr=False)
user_email: str
tenant: t.Optional[str] = field(repr=False)
authority_url: t.Optional[str] = field(repr=False)
ms_outlook_folders: t.List[str]
client_id: str
client_credential: str = field(repr=False)
tenant: t.Optional[str] = field(repr=False, default="common")
authority_url: t.Optional[str] = field(repr=False, default="https://login.microsoftonline.com")
ms_outlook_folders: t.List[str] = field(default_factory=list)
recursive: bool = False
registry_name: str = "outlook"
@ -42,7 +44,7 @@ class SimpleOutlookConfig(BaseConnectorConfig):
if not (self.client_id and self.client_credential and self.user_email):
raise ValueError(
"Please provide one of the following mandatory values:"
"\n--client_id\n--client_cred\n--user-email",
"\nclient_id\nclient_cred\nuser_email",
)
self.token_factory = self._acquire_token
@ -180,10 +182,26 @@ class OutlookIngestDoc(IngestDocCleanupMixin, BaseIngestDoc):
@dataclass
class OutlookSourceConnector(SourceConnectorCleanupMixin, BaseSourceConnector):
connector_config: SimpleOutlookConfig
_client: t.Optional["GraphClient"] = field(init=False, default=None)
@property
def client(self) -> "GraphClient":
if self._client is None:
self._client = self.connector_config._get_client()
return self._client
def initialize(self):
self.client = self.connector_config._get_client()
try:
self.get_folder_ids()
except Exception as e:
raise SourceConnectionError(f"failed to validate connection: {e}")
def check_connection(self):
try:
_ = self.client
except Exception as e:
logger.error(f"failed to validate connection: {e}", exc_info=True)
raise SourceConnectionError(f"failed to validate connection: {e}")
def recurse_folders(self, folder_id, main_folder_dict):
"""We only get a count of subfolders for any folder.

View File

@ -16,15 +16,18 @@ from unstructured.ingest.interfaces import (
from unstructured.ingest.logger import logger
from unstructured.utils import requires_dependencies
if t.TYPE_CHECKING:
from praw import Reddit
@dataclass
class SimpleRedditConfig(BaseConnectorConfig):
subreddit_name: str
client_id: t.Optional[str]
client_secret: t.Optional[str]
user_agent: str
search_query: t.Optional[str]
num_posts: int
user_agent: str
client_id: str
client_secret: t.Optional[str] = None
search_query: t.Optional[str] = None
def __post_init__(self):
if self.num_posts <= 0:
@ -110,16 +113,33 @@ class RedditIngestDoc(IngestDocCleanupMixin, BaseIngestDoc):
@dataclass
class RedditSourceConnector(SourceConnectorCleanupMixin, BaseSourceConnector):
connector_config: SimpleRedditConfig
_reddit: t.Optional["Reddit"] = field(init=False, default=None)
@requires_dependencies(["praw"], extras="reddit")
def initialize(self):
@property
def reddit(self) -> "Reddit":
from praw import Reddit
self.reddit = Reddit(
if self._reddit is None:
self._reddit = Reddit(
client_id=self.connector_config.client_id,
client_secret=self.connector_config.client_secret,
user_agent=self.connector_config.user_agent,
)
return self._reddit
@requires_dependencies(["praw"], extras="reddit")
def initialize(self):
_ = self.reddit
def check_connection(self):
from praw.endpoints import API_PATH
from prawcore import ResponseException
try:
self.reddit._objectify_request(method="HEAD", params=None, path=API_PATH["me"])
except ResponseException as response_error:
logger.error(f"failed to validate connection: {response_error}", exc_info=True)
raise SourceConnectionError(f"failed to validate connection: {response_error}")
def get_ingest_docs(self):
subreddit = self.reddit.subreddit(self.connector_config.subreddit_name)

View File

@ -224,6 +224,16 @@ class SalesforceSourceConnector(SourceConnectorCleanupMixin, BaseSourceConnector
def initialize(self):
pass
@requires_dependencies(["simple_salesforce"], extras="salesforce")
def check_connection(self):
from simple_salesforce.exceptions import SalesforceError
try:
self.connector_config.get_client()
except SalesforceError as salesforce_error:
logger.error(f"failed to validate connection: {salesforce_error}", exc_info=True)
raise SourceConnectionError(f"failed to validate connection: {salesforce_error}")
@requires_dependencies(["simple_salesforce"], extras="salesforce")
def get_ingest_docs(self) -> t.List[SalesforceIngestDoc]:
"""Get Salesforce Ids for the records.

View File

@ -291,6 +291,14 @@ class SharepointIngestDoc(IngestDocCleanupMixin, BaseIngestDoc):
class SharepointSourceConnector(SourceConnectorCleanupMixin, BaseSourceConnector):
connector_config: SimpleSharepointConfig
def check_connection(self):
try:
site_client = self.connector_config.get_site_client()
site_client.site_pages.pages.get().execute_query()
except Exception as e:
logger.error(f"failed to validate connection: {e}", exc_info=True)
raise SourceConnectionError(f"failed to validate connection: {e}")
@requires_dependencies(["office365"], extras="sharepoint")
def _list_files(self, folder, recursive) -> t.List["File"]:
from office365.runtime.client_request_exception import ClientRequestException

View File

@ -198,6 +198,18 @@ class SlackSourceConnector(SourceConnectorCleanupMixin, BaseSourceConnector):
connector_config: SimpleSlackConfig
@requires_dependencies(dependencies=["slack_sdk"], extras="slack")
def check_connection(self):
from slack_sdk import WebClient
from slack_sdk.errors import SlackClientError
try:
client = WebClient(token=self.connector_config.token)
client.users_identity()
except SlackClientError as slack_error:
logger.error(f"failed to validate connection: {slack_error}", exc_info=True)
raise SourceConnectionError(f"failed to validate connection: {slack_error}")
def initialize(self):
"""Verify that can get metadata for an object, validates connections info."""

View File

@ -177,6 +177,19 @@ class WikipediaSourceConnector(SourceConnectorCleanupMixin, BaseSourceConnector)
def initialize(self):
pass
@requires_dependencies(["wikipedia"], extras="wikipedia")
def check_connection(self):
import wikipedia
try:
wikipedia.page(
self.connector_config.title,
auto_suggest=self.connector_config.auto_suggest,
)
except Exception as e:
logger.error(f"failed to validate connection: {e}", exc_info=True)
raise SourceConnectionError(f"failed to validate connection: {e}")
def get_ingest_docs(self):
return [
WikipediaIngestTextDoc(

View File

@ -523,7 +523,14 @@ class BaseIngestDoc(IngestDocJsonMixin, ABC):
@dataclass
class BaseSourceConnector(DataClassJsonMixin, ABC):
class BaseConnector(DataClassJsonMixin, ABC):
@abstractmethod
def check_connection(self):
pass
@dataclass
class BaseSourceConnector(BaseConnector, ABC):
"""Abstract Base Class for a connector to a remote source, e.g. S3 or Google Drive."""
processor_config: ProcessorConfig
@ -551,7 +558,7 @@ class BaseSourceConnector(DataClassJsonMixin, ABC):
@dataclass
class BaseDestinationConnector(DataClassJsonMixin, ABC):
class BaseDestinationConnector(BaseConnector, ABC):
write_config: WriteConfig
connector_config: BaseConnectorConfig