ragflow/common/data_source/interfaces.py
"""Interface definitions"""
import abc
import uuid
from abc import ABC, abstractmethod
from enum import IntFlag, auto
from types import TracebackType
from typing import Any, Dict, Generator, TypeVar, Generic, Callable, TypeAlias
from pydantic import BaseModel
from common.data_source.models import (
Document,
SlimDocument,
ConnectorCheckpoint,
ConnectorFailure,
    SecondsSinceUnixEpoch,
    GenerateSlimDocumentOutput,
)
class LoadConnector(ABC):
"""Load connector interface"""
@abstractmethod
def load_credentials(self, credentials: Dict[str, Any]) -> Dict[str, Any] | None:
"""Load credentials"""
pass
@abstractmethod
def load_from_state(self) -> Generator[list[Document], None, None]:
"""Load documents from state"""
pass
@abstractmethod
def validate_connector_settings(self) -> None:
"""Validate connector settings"""
pass
class PollConnector(ABC):
"""Poll connector interface"""
@abstractmethod
def poll_source(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch) -> Generator[list[Document], None, None]:
"""Poll source to get documents"""
pass
class CredentialsConnector(ABC):
"""Credentials connector interface"""
@abstractmethod
def load_credentials(self, credentials: Dict[str, Any]) -> Dict[str, Any] | None:
"""Load credentials"""
pass
class SlimConnectorWithPermSync(ABC):
"""Simplified connector interface (with permission sync)"""
@abstractmethod
def retrieve_all_slim_docs_perm_sync(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
callback: Any = None,
) -> Generator[list[SlimDocument], None, None]:
"""Retrieve all simplified documents (with permission sync)"""
pass
class CheckpointedConnectorWithPermSync(ABC):
"""Checkpointed connector interface (with permission sync)"""
@abstractmethod
def load_from_checkpoint(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: ConnectorCheckpoint,
) -> Generator[Document | ConnectorFailure, None, ConnectorCheckpoint]:
"""Load documents from checkpoint"""
pass
@abstractmethod
def load_from_checkpoint_with_perm_sync(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: ConnectorCheckpoint,
) -> Generator[Document | ConnectorFailure, None, ConnectorCheckpoint]:
"""Load documents from checkpoint (with permission sync)"""
pass
@abstractmethod
def build_dummy_checkpoint(self) -> ConnectorCheckpoint:
"""Build dummy checkpoint"""
pass
@abstractmethod
def validate_checkpoint_json(self, checkpoint_json: str) -> ConnectorCheckpoint:
"""Validate checkpoint JSON"""
pass
T = TypeVar("T", bound="CredentialsProviderInterface")
class CredentialsProviderInterface(abc.ABC, Generic[T]):
@abc.abstractmethod
def __enter__(self) -> T:
raise NotImplementedError
@abc.abstractmethod
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
) -> None:
raise NotImplementedError
@abc.abstractmethod
def get_tenant_id(self) -> str | None:
raise NotImplementedError
@abc.abstractmethod
def get_provider_key(self) -> str:
"""a unique key that the connector can use to lock around a credential
that might be used simultaneously.
Will typically be the credential id, but can also just be something random
in cases when there is nothing to lock (aka static credentials)
"""
raise NotImplementedError
@abc.abstractmethod
def get_credentials(self) -> dict[str, Any]:
raise NotImplementedError
@abc.abstractmethod
def set_credentials(self, credential_json: dict[str, Any]) -> None:
raise NotImplementedError
@abc.abstractmethod
def is_dynamic(self) -> bool:
"""If dynamic, the credentials may change during usage ... maening the client
needs to use the locking features of the credentials provider to operate
correctly.
If static, the client can simply reference the credentials once and use them
through the entire indexing run.
"""
raise NotImplementedError
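# Example (illustrative sketch): one way a caller might honor the dynamic/static
# contract above. `lock_factory` is a stand-in for whatever locking primitive the
# deployment provides (e.g. a Redis lock), keyed by the provider key.
def _example_read_credentials(
    provider: CredentialsProviderInterface,
    lock_factory: Callable[[str], Any],
) -> dict[str, Any]:
    if not provider.is_dynamic():
        # Static credentials: safe to read once and reuse for the whole run.
        return provider.get_credentials()
    # Dynamic credentials: serialize access around the provider key so a
    # concurrent refresh cannot race with this read.
    with lock_factory(provider.get_provider_key()):
        return provider.get_credentials()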
class StaticCredentialsProvider(
CredentialsProviderInterface["StaticCredentialsProvider"]
):
"""Implementation (a very simple one!) to handle static credentials."""
def __init__(
self,
tenant_id: str | None,
connector_name: str,
credential_json: dict[str, Any],
):
self._tenant_id = tenant_id
self._connector_name = connector_name
self._credential_json = credential_json
self._provider_key = str(uuid.uuid4())
def __enter__(self) -> "StaticCredentialsProvider":
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
) -> None:
pass
def get_tenant_id(self) -> str | None:
return self._tenant_id
def get_provider_key(self) -> str:
return self._provider_key
def get_credentials(self) -> dict[str, Any]:
return self._credential_json
def set_credentials(self, credential_json: dict[str, Any]) -> None:
self._credential_json = credential_json
def is_dynamic(self) -> bool:
return False
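# Example (illustrative sketch): the connector name and credential keys below are
# placeholders, not a documented schema.
def _example_static_provider_usage() -> dict[str, Any]:
    provider = StaticCredentialsProvider(
        tenant_id=None,
        connector_name="confluence",
        credential_json={"confluence_username": "<user>", "confluence_access_token": "<token>"},
    )
    with provider:
        # is_dynamic() is False, so no locking is needed before reading.
        return provider.get_credentials()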
CT = TypeVar("CT", bound=ConnectorCheckpoint)
class BaseConnector(abc.ABC, Generic[CT]):
REDIS_KEY_PREFIX = "da_connector_data:"
# Common image file extensions supported across connectors
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".webp", ".gif"}
@abc.abstractmethod
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
raise NotImplementedError
@staticmethod
def parse_metadata(metadata: dict[str, Any]) -> list[str]:
"""Parse the metadata for a document/chunk into a string to pass to Generative AI as additional context"""
custom_parser_req_msg = (
"Specific metadata parsing required, connector has not implemented it."
)
metadata_lines = []
for metadata_key, metadata_value in metadata.items():
if isinstance(metadata_value, str):
metadata_lines.append(f"{metadata_key}: {metadata_value}")
elif isinstance(metadata_value, list):
if not all([isinstance(val, str) for val in metadata_value]):
raise RuntimeError(custom_parser_req_msg)
metadata_lines.append(f'{metadata_key}: {", ".join(metadata_value)}')
else:
raise RuntimeError(custom_parser_req_msg)
return metadata_lines
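    # Example (illustrative values): a metadata dict such as
    #     {"author": "alice", "labels": ["design", "draft"]}
    # is rendered by parse_metadata as
    #     ["author: alice", "labels: design, draft"]
    # while any value that is neither a string nor a list of strings raises RuntimeError.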
def validate_connector_settings(self) -> None:
"""
Override this if your connector needs to validate credentials or settings.
Raise an exception if invalid, otherwise do nothing.
Default is a no-op (always successful).
"""
    def validate_perm_sync(self) -> None:
        """
        Don't override this; add a function to perm_sync_valid.py in the ee package
        to do permission sync validation.
        """
        # Upstream reference implementation, kept commented out because the ee hook
        # is not wired up in this module:
        # validate_connector_settings_fn = fetch_ee_implementation_or_noop(
        #     "onyx.connectors.perm_sync_valid",
        #     "validate_perm_sync",
        #     noop_return_value=None,
        # )
        # validate_connector_settings_fn(self)
def set_allow_images(self, value: bool) -> None:
"""Implement if the underlying connector wants to skip/allow image downloading
based on the application level image analysis setting."""
def build_dummy_checkpoint(self) -> CT:
# TODO: find a way to make this work without type: ignore
return ConnectorCheckpoint(has_more=True) # type: ignore
CheckpointOutput: TypeAlias = Generator[Document | ConnectorFailure, None, CT]
LoadFunction = Callable[[CT], CheckpointOutput[CT]]
class CheckpointedConnector(BaseConnector[CT]):
@abc.abstractmethod
def load_from_checkpoint(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: CT,
) -> CheckpointOutput[CT]:
"""Yields back documents or failures. Final return is the new checkpoint.
        Final return can be accessed via either:
```
try:
for document_or_failure in connector.load_from_checkpoint(start, end, checkpoint):
print(document_or_failure)
except StopIteration as e:
checkpoint = e.value # Extracting the return value
print(checkpoint)
```
OR
```
checkpoint = yield from connector.load_from_checkpoint(start, end, checkpoint)
```
"""
raise NotImplementedError
@abc.abstractmethod
def build_dummy_checkpoint(self) -> CT:
raise NotImplementedError
@abc.abstractmethod
def validate_checkpoint_json(self, checkpoint_json: str) -> CT:
"""Validate the checkpoint json and return the checkpoint object"""
raise NotImplementedError
class CheckpointOutputWrapper(Generic[CT]):
"""
Wraps a CheckpointOutput generator to give things back in a more digestible format,
specifically for Document outputs.
    The connector format is easier for the connector implementor (e.g. it enforces that
    exactly one new checkpoint is returned AND that the checkpoint comes at the end),
    hence the two different formats.
"""
def __init__(self) -> None:
self.next_checkpoint: CT | None = None
def __call__(
self,
checkpoint_connector_generator: CheckpointOutput[CT],
) -> Generator[
tuple[Document | None, ConnectorFailure | None, CT | None],
None,
None,
]:
# grabs the final return value and stores it in the `next_checkpoint` variable
def _inner_wrapper(
checkpoint_connector_generator: CheckpointOutput[CT],
) -> CheckpointOutput[CT]:
self.next_checkpoint = yield from checkpoint_connector_generator
return self.next_checkpoint # not used
for document_or_failure in _inner_wrapper(checkpoint_connector_generator):
if isinstance(document_or_failure, Document):
yield document_or_failure, None, None
elif isinstance(document_or_failure, ConnectorFailure):
yield None, document_or_failure, None
else:
raise ValueError(
f"Invalid document_or_failure type: {type(document_or_failure)}"
)
if self.next_checkpoint is None:
raise RuntimeError(
"Checkpoint is None. This should never happen - the connector should always return a checkpoint."
)
yield None, None, self.next_checkpoint
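# Example (illustrative sketch): how a driver loop might consume a checkpointed
# connector through CheckpointOutputWrapper. The indexing/failure handling is
# elided; `connector`, `start` and `end` are assumed to come from the caller.
def _example_drive_checkpointed_connector(
    connector: CheckpointedConnector[CT],
    start: SecondsSinceUnixEpoch,
    end: SecondsSinceUnixEpoch,
) -> CT:
    checkpoint = connector.build_dummy_checkpoint()
    wrapper: CheckpointOutputWrapper[CT] = CheckpointOutputWrapper()
    for document, failure, next_checkpoint in wrapper(
        connector.load_from_checkpoint(start, end, checkpoint)
    ):
        if document is not None:
            pass  # index the document
        elif failure is not None:
            pass  # record the failure
        elif next_checkpoint is not None:
            checkpoint = next_checkpoint  # final item carries the new checkpoint
    return checkpoint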
# Slim connectors retrieve just the ids of documents
class SlimConnector(BaseConnector):
@abc.abstractmethod
def retrieve_all_slim_docs(
self,
) -> GenerateSlimDocumentOutput:
raise NotImplementedError
class ConfluenceUser(BaseModel):
user_id: str # accountId in Cloud, userKey in Server
username: str | None # Confluence Cloud doesn't give usernames
display_name: str
    # Confluence Data Center doesn't return the email by default;
    # it has to be fetched with a separate endpoint
email: str | None
type: str
class TokenResponse(BaseModel):
access_token: str
expires_in: int
token_type: str
refresh_token: str
scope: str
class OnyxExtensionType(IntFlag):
Plain = auto()
Document = auto()
Multimedia = auto()
All = Plain | Document | Multimedia
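# Example (illustrative): members combine and test like any IntFlag, e.g.
#     allowed = OnyxExtensionType.Plain | OnyxExtensionType.Document
#     bool(allowed & OnyxExtensionType.Multimedia)  # -> False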
class AttachmentProcessingResult(BaseModel):
"""
A container for results after processing a Confluence attachment.
'text' is the textual content of the attachment.
'file_name' is the final file name used in FileStore to store the content.
'error' holds an exception or string if something failed.
"""
text: str | None
file_blob: bytes | bytearray | None
file_name: str | None
error: str | None = None
model_config = {"arbitrary_types_allowed": True}
class IndexingHeartbeatInterface(ABC):
"""Defines a callback interface to be passed to
    run_indexing_entrypoint."""
@abstractmethod
def should_stop(self) -> bool:
"""Signal to stop the looping function in flight."""
@abstractmethod
def progress(self, tag: str, amount: int) -> None:
"""Send progress updates to the caller.
Amount can be a positive number to indicate progress or <= 0
just to act as a keep-alive.
"""