mirror of
https://github.com/Unstructured-IO/unstructured.git
synced 2025-09-16 12:03:34 +00:00
feat: add connection check method to all source and destination connectors (#2000)
### Description Add a `check_connection` method to each connector to easily be able to check it without running the full ingest process. As part of this PR, some refactoring was done to allow clients to be shared and populated across the `check_connection` method and the `initialize` method, allowing for the `check_connection` method to be called without having to rely on the `initialize` one to be called first. * bonus: fix the changelog --------- Co-authored-by: ryannikolaidis <1208590+ryannikolaidis@users.noreply.github.com>
This commit is contained in:
parent
92ddf3a337
commit
03f62faf9b
13
CHANGELOG.md
13
CHANGELOG.md
@ -1,19 +1,14 @@
|
||||
## 0.10.30-dev1
|
||||
|
||||
### Enhancements
|
||||
|
||||
### Features
|
||||
|
||||
* **Adds ability to pass timeout for a request when partitioning via a `url`.** `partition` now accepts a new optional parameter `request_timeout` which if set will prevent any `requests.get` from hanging indefinitely and instead will raise a timeout error. This is useful when partitioning a url that may be slow to respond or may not respond at all.
|
||||
|
||||
### Fixes
|
||||
## 0.10.30-dev2
|
||||
|
||||
### Enhancements
|
||||
|
||||
* **Support nested DOCX tables.** In DOCX, like HTML, a table cell can itself contain a table. In this case, create nested HTML tables to reflect that structure and create a plain-text table which captures all the text in nested tables, formatting it as a reasonable facsimile of a table.
|
||||
* **Add connection check to ingest connectors** Each source and destination connector now supports a `check_connection()` method which makes sure a valid connection can be established with the source/destination given any authentication credentials in a lightweight request.
|
||||
|
||||
### Features
|
||||
|
||||
* **Adds ability to pass timeout for a request when partitioning via a `url`.** `partition` now accepts a new optional parameter `request_timeout` which if set will prevent any `requests.get` from hanging indefinitely and instead will raise a timeout error. This is useful when partitioning a url that may be slow to respond or may not respond at all.
|
||||
|
||||
### Fixes
|
||||
|
||||
## 0.10.29
|
||||
|
@ -1 +1 @@
|
||||
__version__ = "0.10.30-dev1" # pragma: no cover
|
||||
__version__ = "0.10.30-dev2" # pragma: no cover
|
||||
|
@ -1,9 +1,11 @@
|
||||
import os
|
||||
import typing as t
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import requests
|
||||
|
||||
from unstructured.ingest.error import SourceConnectionError, SourceConnectionNetworkError
|
||||
from unstructured.ingest.interfaces import (
|
||||
BaseConnectorConfig,
|
||||
@ -16,6 +18,9 @@ from unstructured.ingest.interfaces import (
|
||||
from unstructured.ingest.logger import logger
|
||||
from unstructured.utils import requires_dependencies
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from pyairtable import Api
|
||||
|
||||
|
||||
@dataclass
|
||||
class SimpleAirtableConfig(BaseConnectorConfig):
|
||||
@ -200,6 +205,24 @@ class AirtableSourceConnector(SourceConnectorCleanupMixin, BaseSourceConnector):
|
||||
"""Fetches tables or views from an Airtable org."""
|
||||
|
||||
connector_config: SimpleAirtableConfig
|
||||
_api: t.Optional["Api"] = field(init=False, default=None)
|
||||
|
||||
@property
|
||||
def api(self):
|
||||
if self._api is None:
|
||||
self._api = Api(self.connector_config.personal_access_token)
|
||||
return self._api
|
||||
|
||||
@api.setter
|
||||
def api(self, api: "Api"):
|
||||
self._api = api
|
||||
|
||||
def check_connection(self):
|
||||
try:
|
||||
self.api.request(method="HEAD", url=self.api.build_url("meta", "bases"))
|
||||
except requests.HTTPError as http_error:
|
||||
logger.error(f"failed to validate connection: {http_error}", exc_info=True)
|
||||
raise SourceConnectionError(f"failed to validate connection: {http_error}")
|
||||
|
||||
@requires_dependencies(["pyairtable"], extras="airtable")
|
||||
def initialize(self):
|
||||
|
@ -7,7 +7,8 @@ from unstructured.ingest.connector.fsspec import (
|
||||
FsspecSourceConnector,
|
||||
SimpleFsspecConfig,
|
||||
)
|
||||
from unstructured.ingest.error import SourceConnectionError
|
||||
from unstructured.ingest.error import DestinationConnectionError, SourceConnectionError
|
||||
from unstructured.ingest.logger import logger
|
||||
from unstructured.utils import requires_dependencies
|
||||
|
||||
|
||||
@ -31,6 +32,16 @@ class AzureBlobStorageIngestDoc(FsspecIngestDoc):
|
||||
class AzureBlobStorageSourceConnector(FsspecSourceConnector):
|
||||
connector_config: SimpleAzureBlobStorageConfig
|
||||
|
||||
@requires_dependencies(["adlfs"], extras="azure")
|
||||
def check_connection(self):
|
||||
from adlfs import AzureBlobFileSystem
|
||||
|
||||
try:
|
||||
AzureBlobFileSystem(**self.connector_config.access_kwargs)
|
||||
except ValueError as connection_error:
|
||||
logger.error(f"failed to validate connection: {connection_error}", exc_info=True)
|
||||
raise SourceConnectionError(f"failed to validate connection: {connection_error}")
|
||||
|
||||
def __post_init__(self):
|
||||
self.ingest_doc_cls: t.Type[AzureBlobStorageIngestDoc] = AzureBlobStorageIngestDoc
|
||||
|
||||
@ -39,3 +50,13 @@ class AzureBlobStorageSourceConnector(FsspecSourceConnector):
|
||||
@dataclass
|
||||
class AzureBlobStorageDestinationConnector(FsspecDestinationConnector):
|
||||
connector_config: SimpleAzureBlobStorageConfig
|
||||
|
||||
@requires_dependencies(["adlfs"], extras="azure")
|
||||
def check_connection(self):
|
||||
from adlfs import AzureBlobFileSystem
|
||||
|
||||
try:
|
||||
AzureBlobFileSystem(**self.connector_config.access_kwargs)
|
||||
except ValueError as connection_error:
|
||||
logger.error(f"failed to validate connection: {connection_error}", exc_info=True)
|
||||
raise DestinationConnectionError(f"failed to validate connection: {connection_error}")
|
||||
|
@ -1,11 +1,11 @@
|
||||
import json
|
||||
import typing as t
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import azure.core.exceptions
|
||||
|
||||
from unstructured.ingest.error import WriteError
|
||||
from unstructured.ingest.error import DestinationConnectionError, WriteError
|
||||
from unstructured.ingest.interfaces import (
|
||||
BaseConnectorConfig,
|
||||
BaseDestinationConnector,
|
||||
@ -15,6 +15,9 @@ from unstructured.ingest.interfaces import (
|
||||
from unstructured.ingest.logger import logger
|
||||
from unstructured.utils import requires_dependencies
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from azure.search.documents import SearchClient
|
||||
|
||||
|
||||
@dataclass
|
||||
class SimpleAzureCognitiveSearchStorageConfig(BaseConnectorConfig):
|
||||
@ -31,20 +34,37 @@ class AzureCognitiveSearchWriteConfig(WriteConfig):
|
||||
class AzureCognitiveSearchDestinationConnector(BaseDestinationConnector):
|
||||
write_config: AzureCognitiveSearchWriteConfig
|
||||
connector_config: SimpleAzureCognitiveSearchStorageConfig
|
||||
_client: t.Optional["SearchClient"] = field(init=False, default=None)
|
||||
|
||||
@requires_dependencies(["azure"], extras="azure-cognitive-search")
|
||||
def initialize(self):
|
||||
def generate_client(self) -> "SearchClient":
|
||||
from azure.core.credentials import AzureKeyCredential
|
||||
from azure.search.documents import SearchClient
|
||||
|
||||
# Create a client
|
||||
credential = AzureKeyCredential(self.connector_config.key)
|
||||
self.client = SearchClient(
|
||||
return SearchClient(
|
||||
endpoint=self.connector_config.endpoint,
|
||||
index_name=self.write_config.index,
|
||||
credential=credential,
|
||||
)
|
||||
|
||||
@property
|
||||
def client(self) -> "SearchClient":
|
||||
if self._client is None:
|
||||
self._client = self.generate_client()
|
||||
return self._client
|
||||
|
||||
def check_connection(self):
|
||||
try:
|
||||
self.client.get_document_count()
|
||||
except Exception as e:
|
||||
logger.error(f"failed to validate connection: {e}", exc_info=True)
|
||||
raise DestinationConnectionError(f"failed to validate connection: {e}")
|
||||
|
||||
def initialize(self):
|
||||
_ = self.client
|
||||
|
||||
def conform_dict(self, data: dict) -> None:
|
||||
"""
|
||||
updates the dictionary that is from each Element being converted into a dict/json
|
||||
|
@ -40,11 +40,11 @@ class SimpleBiomedConfig(BaseConnectorConfig):
|
||||
"""Connector config where path is the FTP directory path and
|
||||
id_, from_, until, format are API parameters."""
|
||||
|
||||
path: t.Optional[str]
|
||||
path: t.Optional[str] = None
|
||||
# OA Web Service API Options
|
||||
id_: t.Optional[str]
|
||||
from_: t.Optional[str]
|
||||
until: t.Optional[str]
|
||||
id_: t.Optional[str] = None
|
||||
from_: t.Optional[str] = None
|
||||
until: t.Optional[str] = None
|
||||
request_timeout: int = 45
|
||||
|
||||
def validate_api_inputs(self):
|
||||
@ -152,6 +152,20 @@ class BiomedSourceConnector(SourceConnectorCleanupMixin, BaseSourceConnector):
|
||||
|
||||
connector_config: SimpleBiomedConfig
|
||||
|
||||
def get_base_endpoints_url(self) -> str:
|
||||
endpoint_url = "https://www.ncbi.nlm.nih.gov/pmc/utils/oa/oa.fcgi?format=pdf"
|
||||
|
||||
if self.connector_config.id_:
|
||||
endpoint_url += f"&id={self.connector_config.id_}"
|
||||
|
||||
if self.connector_config.from_:
|
||||
endpoint_url += f"&from={self.connector_config.from_}"
|
||||
|
||||
if self.connector_config.until:
|
||||
endpoint_url += f"&until={self.connector_config.until}"
|
||||
|
||||
return endpoint_url
|
||||
|
||||
def _list_objects_api(self) -> t.List[BiomedFileMeta]:
|
||||
def urls_to_metadata(urls):
|
||||
files = []
|
||||
@ -175,16 +189,7 @@ class BiomedSourceConnector(SourceConnectorCleanupMixin, BaseSourceConnector):
|
||||
|
||||
files: t.List[BiomedFileMeta] = []
|
||||
|
||||
endpoint_url = "https://www.ncbi.nlm.nih.gov/pmc/utils/oa/oa.fcgi?format=pdf"
|
||||
|
||||
if self.connector_config.id_:
|
||||
endpoint_url += f"&id={self.connector_config.id_}"
|
||||
|
||||
if self.connector_config.from_:
|
||||
endpoint_url += f"&from={self.connector_config.from_}"
|
||||
|
||||
if self.connector_config.until:
|
||||
endpoint_url += f"&until={self.connector_config.until}"
|
||||
endpoint_url = self.get_base_endpoints_url()
|
||||
|
||||
while endpoint_url:
|
||||
session = requests.Session()
|
||||
@ -287,6 +292,13 @@ class BiomedSourceConnector(SourceConnectorCleanupMixin, BaseSourceConnector):
|
||||
def initialize(self):
|
||||
pass
|
||||
|
||||
def check_connection(self):
|
||||
resp = requests.head(self.get_base_endpoints_url())
|
||||
try:
|
||||
resp.raise_for_status()
|
||||
except requests.HTTPError as http_error:
|
||||
raise SourceConnectionError(f"failed to validate connection: {http_error}")
|
||||
|
||||
def get_ingest_docs(self):
|
||||
files = self._list_objects_api() if self.connector_config.is_api else self._list_objects()
|
||||
return [
|
||||
|
@ -17,7 +17,8 @@ from unstructured.ingest.connector.fsspec import (
|
||||
FsspecSourceConnector,
|
||||
SimpleFsspecConfig,
|
||||
)
|
||||
from unstructured.ingest.error import SourceConnectionError
|
||||
from unstructured.ingest.error import DestinationConnectionError, SourceConnectionError
|
||||
from unstructured.ingest.logger import logger
|
||||
from unstructured.utils import requires_dependencies
|
||||
|
||||
|
||||
@ -57,6 +58,16 @@ class BoxIngestDoc(FsspecIngestDoc):
|
||||
class BoxSourceConnector(FsspecSourceConnector):
|
||||
connector_config: SimpleBoxConfig
|
||||
|
||||
@requires_dependencies(["boxfs"], extras="box")
|
||||
def check_connection(self):
|
||||
from boxfs import BoxFileSystem
|
||||
|
||||
try:
|
||||
BoxFileSystem(**self.connector_config.access_kwargs)
|
||||
except Exception as e:
|
||||
logger.error(f"failed to validate connection: {e}", exc_info=True)
|
||||
raise SourceConnectionError(f"failed to validate connection: {e}")
|
||||
|
||||
def __post_init__(self):
|
||||
self.ingest_doc_cls: t.Type[BoxIngestDoc] = BoxIngestDoc
|
||||
|
||||
@ -65,3 +76,13 @@ class BoxSourceConnector(FsspecSourceConnector):
|
||||
@dataclass
|
||||
class BoxDestinationConnector(FsspecDestinationConnector):
|
||||
connector_config: SimpleBoxConfig
|
||||
|
||||
@requires_dependencies(["boxfs"], extras="box")
|
||||
def check_connection(self):
|
||||
from boxfs import BoxFileSystem
|
||||
|
||||
try:
|
||||
BoxFileSystem(**self.connector_config.access_kwargs)
|
||||
except Exception as e:
|
||||
logger.error(f"failed to validate connection: {e}", exc_info=True)
|
||||
raise DestinationConnectionError(f"failed to validate connection: {e}")
|
||||
|
@ -5,6 +5,8 @@ from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import requests
|
||||
|
||||
from unstructured.ingest.error import SourceConnectionError, SourceConnectionNetworkError
|
||||
from unstructured.ingest.interfaces import (
|
||||
BaseConnectorConfig,
|
||||
@ -17,6 +19,9 @@ from unstructured.ingest.interfaces import (
|
||||
from unstructured.ingest.logger import logger
|
||||
from unstructured.utils import requires_dependencies
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from atlassian import Confluence
|
||||
|
||||
|
||||
@dataclass
|
||||
class SimpleConfluenceConfig(BaseConnectorConfig):
|
||||
@ -185,17 +190,31 @@ class ConfluenceSourceConnector(SourceConnectorCleanupMixin, BaseSourceConnector
|
||||
"""Fetches body fields from all documents within all spaces in a Confluence Cloud instance."""
|
||||
|
||||
connector_config: SimpleConfluenceConfig
|
||||
_confluence: t.Optional["Confluence"] = field(init=False, default=None)
|
||||
|
||||
@property
|
||||
def confluence(self) -> "Confluence":
|
||||
from atlassian import Confluence
|
||||
|
||||
if self._confluence is None:
|
||||
self._confluence = Confluence(
|
||||
url=self.connector_config.url,
|
||||
username=self.connector_config.user_email,
|
||||
password=self.connector_config.api_token,
|
||||
)
|
||||
return self._confluence
|
||||
|
||||
@requires_dependencies(["atlassian"], extras="Confluence")
|
||||
def check_connection(self):
|
||||
url = "rest/api/space"
|
||||
try:
|
||||
self.confluence.request(method="HEAD", path=url)
|
||||
except requests.HTTPError as http_error:
|
||||
logger.error(f"failed to validate connection: {http_error}", exc_info=True)
|
||||
raise SourceConnectionError(f"failed to validate connection: {http_error}")
|
||||
|
||||
@requires_dependencies(["atlassian"], extras="Confluence")
|
||||
def initialize(self):
|
||||
from atlassian import Confluence
|
||||
|
||||
self.confluence = Confluence(
|
||||
url=self.connector_config.url,
|
||||
username=self.connector_config.user_email,
|
||||
password=self.connector_config.api_token,
|
||||
)
|
||||
|
||||
self.list_of_spaces = None
|
||||
if self.connector_config.spaces:
|
||||
self.list_of_spaces = self.connector_config.spaces
|
||||
|
@ -119,6 +119,9 @@ class DeltaTableSourceConnector(SourceConnectorCleanupMixin, BaseSourceConnector
|
||||
connector_config: SimpleDeltaTableConfig
|
||||
delta_table: t.Optional["DeltaTable"] = None
|
||||
|
||||
def check_connection(self):
|
||||
pass
|
||||
|
||||
@requires_dependencies(["deltalake"], extras="delta-table")
|
||||
def initialize(self):
|
||||
from deltalake import DeltaTable
|
||||
@ -172,6 +175,9 @@ class DeltaTableDestinationConnector(BaseDestinationConnector):
|
||||
def initialize(self):
|
||||
pass
|
||||
|
||||
def check_connection(self):
|
||||
pass
|
||||
|
||||
def write_dict(self, *args, elements_dict: t.List[t.Dict[str, t.Any]], **kwargs) -> None:
|
||||
from deltalake.writer import write_deltalake
|
||||
|
||||
|
@ -156,6 +156,21 @@ class DiscordSourceConnector(SourceConnectorCleanupMixin, BaseSourceConnector):
|
||||
def initialize(self):
|
||||
pass
|
||||
|
||||
@requires_dependencies(dependencies=["discord"], extras="discord")
|
||||
def check_connection(self):
|
||||
import asyncio
|
||||
|
||||
import discord
|
||||
from discord.client import Client
|
||||
|
||||
intents = discord.Intents.default()
|
||||
try:
|
||||
client = Client(intents=intents)
|
||||
asyncio.run(client.start(token=self.connector_config.token))
|
||||
except Exception as e:
|
||||
logger.error(f"failed to validate connection: {e}", exc_info=True)
|
||||
raise SourceConnectionError(f"failed to validate connection: {e}")
|
||||
|
||||
def get_ingest_docs(self):
|
||||
return [
|
||||
DiscordIngestDoc(
|
||||
|
@ -20,6 +20,7 @@ from unstructured.ingest.connector.fsspec import (
|
||||
SimpleFsspecConfig,
|
||||
)
|
||||
from unstructured.ingest.error import SourceConnectionError
|
||||
from unstructured.ingest.logger import logger
|
||||
from unstructured.utils import requires_dependencies
|
||||
|
||||
|
||||
@ -90,12 +91,16 @@ class DropboxSourceConnector(FsspecSourceConnector):
|
||||
def initialize(self):
|
||||
from fsspec import AbstractFileSystem, get_filesystem_class
|
||||
|
||||
self.fs: AbstractFileSystem = get_filesystem_class(self.connector_config.protocol)(
|
||||
**self.connector_config.get_access_kwargs(),
|
||||
)
|
||||
# Dropbox requires a forward slash at the front of the folder path. This
|
||||
# creates some complications in path joining so a custom path is created here.
|
||||
ls_output = self.fs.ls(f"/{self.connector_config.path_without_protocol}")
|
||||
try:
|
||||
self.fs: AbstractFileSystem = get_filesystem_class(self.connector_config.protocol)(
|
||||
**self.connector_config.get_access_kwargs(),
|
||||
)
|
||||
# Dropbox requires a forward slash at the front of the folder path. This
|
||||
# creates some complications in path joining so a custom path is created here.
|
||||
ls_output = self.fs.ls(f"/{self.connector_config.path_without_protocol}")
|
||||
except Exception as e:
|
||||
logger.error(f"failed to validate connection: {e}", exc_info=True)
|
||||
raise SourceConnectionError(f"failed to validate connection: {e}")
|
||||
if ls_output and len(ls_output) >= 1:
|
||||
return
|
||||
elif ls_output:
|
||||
|
@ -2,7 +2,7 @@ import hashlib
|
||||
import json
|
||||
import os
|
||||
import typing as t
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
|
||||
from unstructured.ingest.error import SourceConnectionError, SourceConnectionNetworkError
|
||||
@ -17,6 +17,9 @@ from unstructured.ingest.interfaces import (
|
||||
from unstructured.ingest.logger import logger
|
||||
from unstructured.utils import requires_dependencies
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from elasticsearch import Elasticsearch
|
||||
|
||||
|
||||
@dataclass
|
||||
class SimpleElasticsearchConfig(BaseConnectorConfig):
|
||||
@ -30,7 +33,7 @@ class SimpleElasticsearchConfig(BaseConnectorConfig):
|
||||
|
||||
url: str
|
||||
index_name: str
|
||||
jq_query: t.Optional[str]
|
||||
jq_query: t.Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -185,12 +188,25 @@ class ElasticsearchSourceConnector(SourceConnectorCleanupMixin, BaseSourceConnec
|
||||
"""Fetches particular fields from all documents in a given elasticsearch cluster and index"""
|
||||
|
||||
connector_config: SimpleElasticsearchConfig
|
||||
_es: t.Optional["Elasticsearch"] = field(init=False, default=None)
|
||||
|
||||
@property
|
||||
def es(self):
|
||||
from elasticsearch import Elasticsearch
|
||||
|
||||
if self._es is None:
|
||||
self._es = Elasticsearch(self.connector_config.url)
|
||||
return self._es
|
||||
|
||||
def check_connection(self):
|
||||
try:
|
||||
self.es.perform_request("HEAD", "/", headers={"accept": "application/json"})
|
||||
except Exception as e:
|
||||
logger.error(f"failed to validate connection: {e}", exc_info=True)
|
||||
raise SourceConnectionError(f"failed to validate connection: {e}")
|
||||
|
||||
@requires_dependencies(["elasticsearch"], extras="elasticsearch")
|
||||
def initialize(self):
|
||||
from elasticsearch import Elasticsearch
|
||||
|
||||
self.es = Elasticsearch(self.connector_config.url)
|
||||
self.scan_query: dict = {"query": {"match_all": {}}}
|
||||
self.search_query: dict = {"match_all": {}}
|
||||
self.es.search(index=self.connector_config.index_name, query=self.search_query, size=1)
|
||||
|
@ -5,7 +5,11 @@ from contextlib import suppress
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path, PurePath
|
||||
|
||||
from unstructured.ingest.error import SourceConnectionError, SourceConnectionNetworkError
|
||||
from unstructured.ingest.error import (
|
||||
DestinationConnectionError,
|
||||
SourceConnectionError,
|
||||
SourceConnectionNetworkError,
|
||||
)
|
||||
from unstructured.ingest.interfaces import (
|
||||
BaseConnectorConfig,
|
||||
BaseDestinationConnector,
|
||||
@ -147,6 +151,18 @@ class FsspecSourceConnector(
|
||||
|
||||
connector_config: SimpleFsspecConfig
|
||||
|
||||
def check_connection(self):
|
||||
from fsspec import get_filesystem_class
|
||||
|
||||
try:
|
||||
fs = get_filesystem_class(self.connector_config.protocol)(
|
||||
**self.connector_config.get_access_kwargs(),
|
||||
)
|
||||
fs.ls(path=self.connector_config.path_without_protocol)
|
||||
except Exception as e:
|
||||
logger.error(f"failed to validate connection: {e}", exc_info=True)
|
||||
raise SourceConnectionError(f"failed to validate connection: {e}")
|
||||
|
||||
def __post_init__(self):
|
||||
self.ingest_doc_cls: t.Type[FsspecIngestDoc] = FsspecIngestDoc
|
||||
|
||||
@ -244,6 +260,18 @@ class FsspecDestinationConnector(BaseDestinationConnector):
|
||||
**self.connector_config.get_access_kwargs(),
|
||||
)
|
||||
|
||||
def check_connection(self):
|
||||
from fsspec import get_filesystem_class
|
||||
|
||||
try:
|
||||
fs = get_filesystem_class(self.connector_config.protocol)(
|
||||
**self.connector_config.get_access_kwargs(),
|
||||
)
|
||||
fs.ls(path=self.connector_config.path_without_protocol)
|
||||
except Exception as e:
|
||||
logger.error(f"failed to validate connection: {e}", exc_info=True)
|
||||
raise DestinationConnectionError(f"failed to validate connection: {e}")
|
||||
|
||||
def write_dict(
|
||||
self,
|
||||
*args,
|
||||
|
@ -18,9 +18,9 @@ from unstructured.ingest.logger import logger
|
||||
@dataclass
|
||||
class SimpleGitConfig(BaseConnectorConfig):
|
||||
url: str
|
||||
access_token: t.Optional[str]
|
||||
branch: t.Optional[str]
|
||||
file_glob: t.Optional[str]
|
||||
access_token: t.Optional[str] = None
|
||||
branch: t.Optional[str] = None
|
||||
file_glob: t.Optional[str] = None
|
||||
repo_path: str = field(init=False, repr=False)
|
||||
|
||||
|
||||
@ -76,6 +76,9 @@ class GitSourceConnector(SourceConnectorCleanupMixin, BaseSourceConnector):
|
||||
def initialize(self):
|
||||
pass
|
||||
|
||||
def check_connection(self):
|
||||
pass
|
||||
|
||||
def is_file_type_supported(self, path: str) -> bool:
|
||||
# Workaround to ensure that auto.partition isn't fed with .yaml, .py, etc. files
|
||||
# TODO: What to do with no filenames? e.g. LICENSE, Makefile, etc.
|
||||
|
@ -127,6 +127,33 @@ class GitHubIngestDoc(GitIngestDoc):
|
||||
class GitHubSourceConnector(GitSourceConnector):
|
||||
connector_config: SimpleGitHubConfig
|
||||
|
||||
@requires_dependencies(["github"], extras="github")
|
||||
def check_connection(self):
|
||||
from github import Consts
|
||||
from github.GithubRetry import GithubRetry
|
||||
from github.Requester import Requester
|
||||
|
||||
try:
|
||||
requester = Requester(
|
||||
auth=self.connector_config.access_token,
|
||||
base_url=Consts.DEFAULT_BASE_URL,
|
||||
timeout=Consts.DEFAULT_TIMEOUT,
|
||||
user_agent=Consts.DEFAULT_USER_AGENT,
|
||||
per_page=Consts.DEFAULT_PER_PAGE,
|
||||
verify=True,
|
||||
retry=GithubRetry(),
|
||||
pool_size=None,
|
||||
)
|
||||
url_base = (
|
||||
"/repositories/" if isinstance(self.connector_config.repo_path, int) else "/repos/"
|
||||
)
|
||||
url = f"{url_base}{self.connector_config.repo_path}"
|
||||
headers, _ = requester.requestJsonAndCheck("HEAD", url)
|
||||
logger.debug(f"headers from HEAD request: {headers}")
|
||||
except Exception as e:
|
||||
logger.error(f"failed to validate connection: {e}", exc_info=True)
|
||||
raise SourceConnectionError(f"failed to validate connection: {e}")
|
||||
|
||||
def get_ingest_docs(self):
|
||||
repo = self.connector_config.get_repo()
|
||||
# Load the Git tree with all files, and then create Ingest docs
|
||||
|
@ -103,6 +103,20 @@ class GitLabIngestDoc(GitIngestDoc):
|
||||
class GitLabSourceConnector(GitSourceConnector):
|
||||
connector_config: SimpleGitLabConfig
|
||||
|
||||
@requires_dependencies(["gitlab"], extras="gitlab")
|
||||
def check_connection(self):
|
||||
from gitlab import Gitlab
|
||||
from gitlab.exceptions import GitlabError
|
||||
|
||||
try:
|
||||
gitlab = Gitlab(
|
||||
self.connector_config.base_url, private_token=self.connector_config.access_token
|
||||
)
|
||||
gitlab.auth()
|
||||
except GitlabError as gitlab_error:
|
||||
logger.error(f"failed to validate connection: {gitlab_error}", exc_info=True)
|
||||
raise SourceConnectionError(f"failed to validate connection: {gitlab_error}")
|
||||
|
||||
def get_ingest_docs(self):
|
||||
# Load the Git tree with all files, and then create Ingest docs
|
||||
# for all blobs, i.e. all files, ignoring directories
|
||||
|
@ -324,6 +324,13 @@ class GoogleDriveSourceConnector(SourceConnectorCleanupMixin, BaseSourceConnecto
|
||||
def initialize(self):
|
||||
pass
|
||||
|
||||
def check_connection(self):
|
||||
try:
|
||||
self.connector_config.create_session_handle().service
|
||||
except Exception as e:
|
||||
logger.error(f"failed to validate connection: {e}", exc_info=True)
|
||||
raise SourceConnectionError(f"failed to validate connection: {e}")
|
||||
|
||||
def get_ingest_docs(self):
|
||||
files = self._list_objects(self.connector_config.drive_id, self.connector_config.recursive)
|
||||
return [
|
||||
|
@ -2,7 +2,7 @@ import math
|
||||
import os
|
||||
import typing as t
|
||||
from collections import abc
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from functools import cached_property
|
||||
from pathlib import Path
|
||||
@ -81,9 +81,9 @@ class SimpleJiraConfig(ConfigSessionHandleMixin, BaseConnectorConfig):
|
||||
user_email: str
|
||||
api_token: str
|
||||
url: str
|
||||
projects: t.Optional[t.List[str]]
|
||||
boards: t.Optional[t.List[str]]
|
||||
issues: t.Optional[t.List[str]]
|
||||
projects: t.Optional[t.List[str]] = None
|
||||
boards: t.Optional[t.List[str]] = None
|
||||
issues: t.Optional[t.List[str]] = None
|
||||
|
||||
def create_session_handle(
|
||||
self,
|
||||
@ -342,10 +342,24 @@ class JiraSourceConnector(SourceConnectorCleanupMixin, BaseSourceConnector):
|
||||
"""Fetches issues from projects in an Atlassian (Jira) Cloud instance."""
|
||||
|
||||
connector_config: SimpleJiraConfig
|
||||
_jira: t.Optional["Jira"] = field(init=False, default=None)
|
||||
|
||||
@property
|
||||
def jira(self) -> "Jira":
|
||||
if self._jira is None:
|
||||
try:
|
||||
self._jira = self.connector_config.create_session_handle().service
|
||||
except Exception as e:
|
||||
logger.error(f"failed to validate connection: {e}", exc_info=True)
|
||||
raise SourceConnectionError(f"failed to validate connection: {e}")
|
||||
return self._jira
|
||||
|
||||
@requires_dependencies(["atlassian"], extras="jira")
|
||||
def initialize(self):
|
||||
self.jira = self.connector_config.create_session_handle().service
|
||||
_ = self.jira
|
||||
|
||||
def check_connection(self):
|
||||
_ = self.jira
|
||||
|
||||
@requires_dependencies(["atlassian"], extras="jira")
|
||||
def _get_all_project_ids(self):
|
||||
|
@ -89,6 +89,9 @@ class LocalIngestDoc(BaseIngestDoc):
|
||||
class LocalSourceConnector(BaseSourceConnector):
|
||||
"""Objects of this class support fetching document(s) from local file system"""
|
||||
|
||||
def check_connection(self):
|
||||
pass
|
||||
|
||||
connector_config: SimpleLocalConfig
|
||||
|
||||
def __post_init__(self):
|
||||
|
@ -1,8 +1,11 @@
|
||||
import os
|
||||
import typing as t
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
|
||||
from unstructured.ingest.error import SourceConnectionError
|
||||
from unstructured.ingest.interfaces import (
|
||||
BaseConnectorConfig,
|
||||
BaseIngestDoc,
|
||||
@ -17,16 +20,18 @@ from unstructured.utils import (
|
||||
)
|
||||
|
||||
NOTION_API_VERSION = "2022-06-28"
|
||||
if t.TYPE_CHECKING:
|
||||
from unstructured.ingest.connector.notion.client import Client as NotionClient
|
||||
|
||||
|
||||
@dataclass
|
||||
class SimpleNotionConfig(BaseConnectorConfig):
|
||||
"""Connector config to process all messages by channel id's."""
|
||||
|
||||
page_ids: t.List[str]
|
||||
database_ids: t.List[str]
|
||||
recursive: bool
|
||||
notion_api_key: str
|
||||
page_ids: t.List[str] = field(default_factory=list)
|
||||
database_ids: t.List[str] = field(default_factory=list)
|
||||
recursive: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -284,20 +289,35 @@ class NotionSourceConnector(SourceConnectorCleanupMixin, BaseSourceConnector):
|
||||
|
||||
connector_config: SimpleNotionConfig
|
||||
retry_strategy_config: t.Optional[RetryStrategyConfig] = None
|
||||
_client: t.Optional["NotionClient"] = field(init=False, default=None)
|
||||
|
||||
@property
|
||||
def client(self) -> "NotionClient":
|
||||
from unstructured.ingest.connector.notion.client import Client as NotionClient
|
||||
|
||||
if self._client is None:
|
||||
self._client = NotionClient(
|
||||
notion_version=NOTION_API_VERSION,
|
||||
auth=self.connector_config.notion_api_key,
|
||||
logger=logger,
|
||||
log_level=logger.level,
|
||||
retry_strategy_config=self.retry_strategy_config,
|
||||
)
|
||||
return self._client
|
||||
|
||||
def check_connection(self):
|
||||
try:
|
||||
request = self.client._build_request("HEAD", "users")
|
||||
response = self.client.client.send(request)
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPStatusError as http_error:
|
||||
logger.error(f"failed to validate connection: {http_error}", exc_info=True)
|
||||
raise SourceConnectionError(f"failed to validate connection: {http_error}")
|
||||
|
||||
@requires_dependencies(dependencies=["notion_client"], extras="notion")
|
||||
def initialize(self):
|
||||
"""Verify that can get metadata for an object, validates connections info."""
|
||||
from unstructured.ingest.connector.notion.client import Client as NotionClient
|
||||
|
||||
# Pin the version of the api to avoid schema changes
|
||||
self.client = NotionClient(
|
||||
notion_version=NOTION_API_VERSION,
|
||||
auth=self.connector_config.notion_api_key,
|
||||
logger=logger,
|
||||
log_level=logger.level,
|
||||
retry_strategy_config=self.retry_strategy_config,
|
||||
)
|
||||
_ = self.client
|
||||
|
||||
@requires_dependencies(dependencies=["notion_client"], extras="notion")
|
||||
def get_child_page_content(self, page_id: str):
|
||||
|
@ -17,6 +17,7 @@ from unstructured.ingest.logger import logger
|
||||
from unstructured.utils import requires_dependencies
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from office365.graph_client import GraphClient
|
||||
from office365.onedrive.driveitems.driveItem import DriveItem
|
||||
|
||||
MAX_MB_SIZE = 512_000_000
|
||||
@ -28,7 +29,7 @@ class SimpleOneDriveConfig(BaseConnectorConfig):
|
||||
client_credential: str = field(repr=False)
|
||||
user_pname: str
|
||||
tenant: str = field(repr=False)
|
||||
authority_url: t.Optional[str] = field(repr=False)
|
||||
authority_url: t.Optional[str] = field(repr=False, default="https://login.microsoftonline.com")
|
||||
path: t.Optional[str] = field(default="")
|
||||
recursive: bool = False
|
||||
|
||||
@ -177,12 +178,32 @@ class OneDriveIngestDoc(IngestDocCleanupMixin, BaseIngestDoc):
|
||||
@dataclass
|
||||
class OneDriveSourceConnector(SourceConnectorCleanupMixin, BaseSourceConnector):
|
||||
connector_config: SimpleOneDriveConfig
|
||||
_client: t.Optional["GraphClient"] = field(init=False, default=None)
|
||||
|
||||
@requires_dependencies(["office365"], extras="onedrive")
|
||||
def _set_client(self):
|
||||
@property
|
||||
def client(self) -> "GraphClient":
|
||||
from office365.graph_client import GraphClient
|
||||
|
||||
self.client = GraphClient(self.connector_config.token_factory)
|
||||
if self._client is None:
|
||||
self._client = GraphClient(self.connector_config.token_factory)
|
||||
return self._client
|
||||
|
||||
@requires_dependencies(["office365"], extras="onedrive")
|
||||
def initialize(self):
|
||||
_ = self.client
|
||||
|
||||
@requires_dependencies(["office365"], extras="onedrive")
|
||||
def check_connection(self):
|
||||
try:
|
||||
token_resp: dict = self.connector_config.token_factory()
|
||||
if error := token_resp.get("error"):
|
||||
raise SourceConnectionError(
|
||||
"{} ({})".format(error, token_resp.get("error_description"))
|
||||
)
|
||||
_ = self.client
|
||||
except Exception as e:
|
||||
logger.error(f"failed to validate connection: {e}", exc_info=True)
|
||||
raise SourceConnectionError(f"failed to validate connection: {e}")
|
||||
|
||||
def _list_objects(self, folder, recursive) -> t.List["DriveItem"]:
|
||||
drive_items = folder.children.get().execute_query()
|
||||
@ -205,9 +226,6 @@ class OneDriveSourceConnector(SourceConnectorCleanupMixin, BaseSourceConnector):
|
||||
file_path=file_path,
|
||||
)
|
||||
|
||||
def initialize(self):
|
||||
self._set_client()
|
||||
|
||||
def get_ingest_docs(self):
|
||||
root = self.client.users[self.connector_config.user_pname].drive.get().execute_query().root
|
||||
if fpath := self.connector_config.path:
|
||||
|
@ -19,6 +19,8 @@ from unstructured.ingest.logger import logger
|
||||
from unstructured.utils import requires_dependencies
|
||||
|
||||
MAX_NUM_EMAILS = 1000000 # Maximum number of emails per folder
|
||||
if t.TYPE_CHECKING:
|
||||
from office365.graph_client import GraphClient
|
||||
|
||||
|
||||
class MissingFolderError(Exception):
|
||||
@ -29,12 +31,12 @@ class MissingFolderError(Exception):
|
||||
class SimpleOutlookConfig(BaseConnectorConfig):
|
||||
"""This class is getting the token."""
|
||||
|
||||
client_id: t.Optional[str]
|
||||
client_credential: t.Optional[str] = field(repr=False)
|
||||
user_email: str
|
||||
tenant: t.Optional[str] = field(repr=False)
|
||||
authority_url: t.Optional[str] = field(repr=False)
|
||||
ms_outlook_folders: t.List[str]
|
||||
client_id: str
|
||||
client_credential: str = field(repr=False)
|
||||
tenant: t.Optional[str] = field(repr=False, default="common")
|
||||
authority_url: t.Optional[str] = field(repr=False, default="https://login.microsoftonline.com")
|
||||
ms_outlook_folders: t.List[str] = field(default_factory=list)
|
||||
recursive: bool = False
|
||||
registry_name: str = "outlook"
|
||||
|
||||
@ -42,7 +44,7 @@ class SimpleOutlookConfig(BaseConnectorConfig):
|
||||
if not (self.client_id and self.client_credential and self.user_email):
|
||||
raise ValueError(
|
||||
"Please provide one of the following mandatory values:"
|
||||
"\n--client_id\n--client_cred\n--user-email",
|
||||
"\nclient_id\nclient_cred\nuser_email",
|
||||
)
|
||||
self.token_factory = self._acquire_token
|
||||
|
||||
@ -180,10 +182,26 @@ class OutlookIngestDoc(IngestDocCleanupMixin, BaseIngestDoc):
|
||||
@dataclass
|
||||
class OutlookSourceConnector(SourceConnectorCleanupMixin, BaseSourceConnector):
|
||||
connector_config: SimpleOutlookConfig
|
||||
_client: t.Optional["GraphClient"] = field(init=False, default=None)
|
||||
|
||||
@property
|
||||
def client(self) -> "GraphClient":
|
||||
if self._client is None:
|
||||
self._client = self.connector_config._get_client()
|
||||
return self._client
|
||||
|
||||
def initialize(self):
|
||||
self.client = self.connector_config._get_client()
|
||||
self.get_folder_ids()
|
||||
try:
|
||||
self.get_folder_ids()
|
||||
except Exception as e:
|
||||
raise SourceConnectionError(f"failed to validate connection: {e}")
|
||||
|
||||
def check_connection(self):
|
||||
try:
|
||||
_ = self.client
|
||||
except Exception as e:
|
||||
logger.error(f"failed to validate connection: {e}", exc_info=True)
|
||||
raise SourceConnectionError(f"failed to validate connection: {e}")
|
||||
|
||||
def recurse_folders(self, folder_id, main_folder_dict):
|
||||
"""We only get a count of subfolders for any folder.
|
||||
|
@ -16,15 +16,18 @@ from unstructured.ingest.interfaces import (
|
||||
from unstructured.ingest.logger import logger
|
||||
from unstructured.utils import requires_dependencies
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from praw import Reddit
|
||||
|
||||
|
||||
@dataclass
|
||||
class SimpleRedditConfig(BaseConnectorConfig):
|
||||
subreddit_name: str
|
||||
client_id: t.Optional[str]
|
||||
client_secret: t.Optional[str]
|
||||
user_agent: str
|
||||
search_query: t.Optional[str]
|
||||
num_posts: int
|
||||
user_agent: str
|
||||
client_id: str
|
||||
client_secret: t.Optional[str] = None
|
||||
search_query: t.Optional[str] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.num_posts <= 0:
|
||||
@ -110,16 +113,33 @@ class RedditIngestDoc(IngestDocCleanupMixin, BaseIngestDoc):
|
||||
@dataclass
|
||||
class RedditSourceConnector(SourceConnectorCleanupMixin, BaseSourceConnector):
|
||||
connector_config: SimpleRedditConfig
|
||||
_reddit: t.Optional["Reddit"] = field(init=False, default=None)
|
||||
|
||||
@property
|
||||
def reddit(self) -> "Reddit":
|
||||
from praw import Reddit
|
||||
|
||||
if self._reddit is None:
|
||||
self._reddit = Reddit(
|
||||
client_id=self.connector_config.client_id,
|
||||
client_secret=self.connector_config.client_secret,
|
||||
user_agent=self.connector_config.user_agent,
|
||||
)
|
||||
return self._reddit
|
||||
|
||||
@requires_dependencies(["praw"], extras="reddit")
|
||||
def initialize(self):
|
||||
from praw import Reddit
|
||||
_ = self.reddit
|
||||
|
||||
self.reddit = Reddit(
|
||||
client_id=self.connector_config.client_id,
|
||||
client_secret=self.connector_config.client_secret,
|
||||
user_agent=self.connector_config.user_agent,
|
||||
)
|
||||
def check_connection(self):
|
||||
from praw.endpoints import API_PATH
|
||||
from prawcore import ResponseException
|
||||
|
||||
try:
|
||||
self.reddit._objectify_request(method="HEAD", params=None, path=API_PATH["me"])
|
||||
except ResponseException as response_error:
|
||||
logger.error(f"failed to validate connection: {response_error}", exc_info=True)
|
||||
raise SourceConnectionError(f"failed to validate connection: {response_error}")
|
||||
|
||||
def get_ingest_docs(self):
|
||||
subreddit = self.reddit.subreddit(self.connector_config.subreddit_name)
|
||||
|
@ -224,6 +224,16 @@ class SalesforceSourceConnector(SourceConnectorCleanupMixin, BaseSourceConnector
|
||||
def initialize(self):
|
||||
pass
|
||||
|
||||
@requires_dependencies(["simple_salesforce"], extras="salesforce")
|
||||
def check_connection(self):
|
||||
from simple_salesforce.exceptions import SalesforceError
|
||||
|
||||
try:
|
||||
self.connector_config.get_client()
|
||||
except SalesforceError as salesforce_error:
|
||||
logger.error(f"failed to validate connection: {salesforce_error}", exc_info=True)
|
||||
raise SourceConnectionError(f"failed to validate connection: {salesforce_error}")
|
||||
|
||||
@requires_dependencies(["simple_salesforce"], extras="salesforce")
|
||||
def get_ingest_docs(self) -> t.List[SalesforceIngestDoc]:
|
||||
"""Get Salesforce Ids for the records.
|
||||
|
@ -291,6 +291,14 @@ class SharepointIngestDoc(IngestDocCleanupMixin, BaseIngestDoc):
|
||||
class SharepointSourceConnector(SourceConnectorCleanupMixin, BaseSourceConnector):
|
||||
connector_config: SimpleSharepointConfig
|
||||
|
||||
def check_connection(self):
|
||||
try:
|
||||
site_client = self.connector_config.get_site_client()
|
||||
site_client.site_pages.pages.get().execute_query()
|
||||
except Exception as e:
|
||||
logger.error(f"failed to validate connection: {e}", exc_info=True)
|
||||
raise SourceConnectionError(f"failed to validate connection: {e}")
|
||||
|
||||
@requires_dependencies(["office365"], extras="sharepoint")
|
||||
def _list_files(self, folder, recursive) -> t.List["File"]:
|
||||
from office365.runtime.client_request_exception import ClientRequestException
|
||||
|
@ -198,6 +198,18 @@ class SlackSourceConnector(SourceConnectorCleanupMixin, BaseSourceConnector):
|
||||
|
||||
connector_config: SimpleSlackConfig
|
||||
|
||||
@requires_dependencies(dependencies=["slack_sdk"], extras="slack")
|
||||
def check_connection(self):
|
||||
from slack_sdk import WebClient
|
||||
from slack_sdk.errors import SlackClientError
|
||||
|
||||
try:
|
||||
client = WebClient(token=self.connector_config.token)
|
||||
client.users_identity()
|
||||
except SlackClientError as slack_error:
|
||||
logger.error(f"failed to validate connection: {slack_error}", exc_info=True)
|
||||
raise SourceConnectionError(f"failed to validate connection: {slack_error}")
|
||||
|
||||
def initialize(self):
|
||||
"""Verify that can get metadata for an object, validates connections info."""
|
||||
|
||||
|
@ -177,6 +177,19 @@ class WikipediaSourceConnector(SourceConnectorCleanupMixin, BaseSourceConnector)
|
||||
def initialize(self):
|
||||
pass
|
||||
|
||||
@requires_dependencies(["wikipedia"], extras="wikipedia")
|
||||
def check_connection(self):
|
||||
import wikipedia
|
||||
|
||||
try:
|
||||
wikipedia.page(
|
||||
self.connector_config.title,
|
||||
auto_suggest=self.connector_config.auto_suggest,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"failed to validate connection: {e}", exc_info=True)
|
||||
raise SourceConnectionError(f"failed to validate connection: {e}")
|
||||
|
||||
def get_ingest_docs(self):
|
||||
return [
|
||||
WikipediaIngestTextDoc(
|
||||
|
@ -523,7 +523,14 @@ class BaseIngestDoc(IngestDocJsonMixin, ABC):
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseSourceConnector(DataClassJsonMixin, ABC):
|
||||
class BaseConnector(DataClassJsonMixin, ABC):
|
||||
@abstractmethod
|
||||
def check_connection(self):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseSourceConnector(BaseConnector, ABC):
|
||||
"""Abstract Base Class for a connector to a remote source, e.g. S3 or Google Drive."""
|
||||
|
||||
processor_config: ProcessorConfig
|
||||
@ -551,7 +558,7 @@ class BaseSourceConnector(DataClassJsonMixin, ABC):
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseDestinationConnector(DataClassJsonMixin, ABC):
|
||||
class BaseDestinationConnector(BaseConnector, ABC):
|
||||
write_config: WriteConfig
|
||||
connector_config: BaseConnectorConfig
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user