### What problem does this PR solve?

Add Jira connector.

*(Screenshots of the connector configuration and sync UI omitted.)*

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
#
#  Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#

# from beartype import BeartypeConf
# from beartype.claw import beartype_all  # <-- you didn't sign up for this
# beartype_all(conf=BeartypeConf(violation_type=UserWarning))  # <-- emit warnings from all code

import copy
import faulthandler
import logging
import os
import signal
import sys
import threading
import time
import traceback
from datetime import datetime, timezone
from typing import Any

import trio

from api.db.services.connector_service import ConnectorService, SyncLogsService
from api.db.services.knowledgebase_service import KnowledgebaseService
from common import settings
from common.config_utils import show_configs
from common.constants import FileSource, TaskStatus
from common.data_source import (
    BlobStorageConnector,
    DiscordConnector,
    GoogleDriveConnector,
    JiraConnector,
    NotionConnector,
)
from common.data_source.config import INDEX_BATCH_SIZE
from common.data_source.confluence_connector import ConfluenceConnector
from common.data_source.interfaces import CheckpointOutputWrapper
from common.data_source.utils import load_all_docs_from_checkpoint_connector
from common.log_utils import init_root_logger
from common.signal_utils import start_tracemalloc_and_snapshot, stop_tracemalloc
from common.versions import get_ragflow_version

MAX_CONCURRENT_TASKS = int(os.environ.get("MAX_CONCURRENT_TASKS", "5"))
task_limiter = trio.Semaphore(MAX_CONCURRENT_TASKS)
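
# `task_limiter` caps how many sync tasks run at once across all connectors.
# A minimal sketch of the same pattern (assumes nothing beyond trio itself):
#
#     limiter = trio.Semaphore(2)
#
#     async def worker(n):
#         async with limiter:  # at most 2 workers hold the semaphore at a time
#             await trio.sleep(1)
#
#     async def demo():
#         async with trio.open_nursery() as nursery:
#             for n in range(5):
#                 nursery.start_soon(worker, n)
#
#     trio.run(demo)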


class SyncBase:
    SOURCE_NAME: str = ""

    def __init__(self, conf: dict) -> None:
        self.conf = conf

    async def __call__(self, task: dict):
        SyncLogsService.start(task["id"], task["connector_id"])
        try:
            async with task_limiter:
                with trio.fail_after(task["timeout_secs"]):
                    document_batch_generator = await self._generate(task)
                    doc_num = 0
                    next_update = datetime(1970, 1, 1, tzinfo=timezone.utc)
                    if task["poll_range_start"]:
                        next_update = task["poll_range_start"]
                    for document_batch in document_batch_generator:
                        if not document_batch:
                            continue
                        min_update = min([doc.doc_updated_at for doc in document_batch])
                        max_update = max([doc.doc_updated_at for doc in document_batch])
                        next_update = max(next_update, max_update)
                        docs = [
                            {
                                "id": doc.id,
                                "connector_id": task["connector_id"],
                                "source": self.SOURCE_NAME,
                                "semantic_identifier": doc.semantic_identifier,
                                "extension": doc.extension,
                                "size_bytes": doc.size_bytes,
                                "doc_updated_at": doc.doc_updated_at,
                                "blob": doc.blob,
                            }
                            for doc in document_batch
                        ]

                        e, kb = KnowledgebaseService.get_by_id(task["kb_id"])
                        err, dids = SyncLogsService.duplicate_and_parse(kb, docs, task["tenant_id"], f"{self.SOURCE_NAME}/{task['connector_id']}", task["auto_parse"])
                        SyncLogsService.increase_docs(task["id"], min_update, max_update, len(docs), "\n".join(err), len(err))
                        doc_num += len(docs)

                    prefix = "[Jira] " if self.SOURCE_NAME == FileSource.JIRA else ""
                    logging.info(f"{prefix}{doc_num} docs synchronized till {next_update}")
                    SyncLogsService.done(task["id"], task["connector_id"])
                    task["poll_range_start"] = next_update

        except Exception as ex:
            msg = "\n".join(["".join(traceback.format_exception_only(None, ex)).strip(), "".join(traceback.format_exception(None, ex, ex.__traceback__)).strip()])
            SyncLogsService.update_by_id(task["id"], {"status": TaskStatus.FAIL, "full_exception_trace": msg, "error_msg": str(ex)})

        SyncLogsService.schedule(task["connector_id"], task["kb_id"], task["poll_range_start"])

    async def _generate(self, task: dict):
        raise NotImplementedError
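
# Concrete connectors below only implement `_generate(task)` and return an
# iterable of document batches; `SyncBase.__call__` handles start/done logging,
# the timeout, dedup/parsing, and rescheduling. A hypothetical minimal subclass
# following the same shape (names here are illustrative, not an existing API):
#
#     class MySource(SyncBase):
#         SOURCE_NAME: str = "my_source"  # illustrative source tag
#
#         async def _generate(self, task: dict):
#             connector = ...  # build a client from self.conf and its credentials
#             return connector.load_from_state()  # iterable of document batches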


class S3(SyncBase):
    SOURCE_NAME: str = FileSource.S3

    async def _generate(self, task: dict):
        self.connector = BlobStorageConnector(bucket_type=self.conf.get("bucket_type", "s3"), bucket_name=self.conf["bucket_name"], prefix=self.conf.get("prefix", ""))
        self.connector.load_credentials(self.conf["credentials"])
        document_batch_generator = (
            self.connector.load_from_state()
            if task["reindex"] == "1" or not task["poll_range_start"]
            else self.connector.poll_source(task["poll_range_start"].timestamp(), datetime.now(timezone.utc).timestamp())
        )

        begin_info = "totally" if task["reindex"] == "1" or not task["poll_range_start"] else "from {}".format(task["poll_range_start"])
        logging.info("Connect to {}: {}(prefix/{}) {}".format(self.conf.get("bucket_type", "s3"), self.conf["bucket_name"], self.conf.get("prefix", ""), begin_info))
        return document_batch_generator
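
# A sketch of the `conf` dict an S3 connector might receive; the keys are the
# ones read above, the values are illustrative:
#
#     {
#         "bucket_type": "s3",
#         "bucket_name": "my-bucket",  # hypothetical bucket
#         "prefix": "docs/",
#         "credentials": {...},        # passed through to load_credentials()
#     }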


class Confluence(SyncBase):
    SOURCE_NAME: str = FileSource.CONFLUENCE

    async def _generate(self, task: dict):
        from common.data_source.config import DocumentSource
        from common.data_source.interfaces import StaticCredentialsProvider

        self.connector = ConfluenceConnector(
            wiki_base=self.conf["wiki_base"],
            space=self.conf.get("space", ""),
            is_cloud=self.conf.get("is_cloud", True),
            # page_id=self.conf.get("page_id", ""),
        )

        credentials_provider = StaticCredentialsProvider(tenant_id=task["tenant_id"], connector_name=DocumentSource.CONFLUENCE, credential_json=self.conf["credentials"])
        self.connector.set_credentials_provider(credentials_provider)

        # Determine the time range for synchronization based on reindex or poll_range_start
        if task["reindex"] == "1" or not task["poll_range_start"]:
            start_time = 0.0
            begin_info = "totally"
        else:
            start_time = task["poll_range_start"].timestamp()
            begin_info = f"from {task['poll_range_start']}"

        end_time = datetime.now(timezone.utc).timestamp()

        document_generator = load_all_docs_from_checkpoint_connector(
            connector=self.connector,
            start=start_time,
            end=end_time,
        )

        logging.info("Connect to Confluence: {} {}".format(self.conf["wiki_base"], begin_info))
        return [document_generator]
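
# Note that Confluence returns `[document_generator]`: everything loaded from
# the checkpoint connector is handed to `SyncBase.__call__` as a single batch,
# unlike the size-capped batches produced for Google Drive and Jira below.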


class Notion(SyncBase):
    SOURCE_NAME: str = FileSource.NOTION

    async def _generate(self, task: dict):
        self.connector = NotionConnector(root_page_id=self.conf["root_page_id"])
        self.connector.load_credentials(self.conf["credentials"])
        document_generator = (
            self.connector.load_from_state()
            if task["reindex"] == "1" or not task["poll_range_start"]
            else self.connector.poll_source(task["poll_range_start"].timestamp(), datetime.now(timezone.utc).timestamp())
        )

        begin_info = "totally" if task["reindex"] == "1" or not task["poll_range_start"] else "from {}".format(task["poll_range_start"])
        logging.info("Connect to Notion: root({}) {}".format(self.conf["root_page_id"], begin_info))
        return document_generator


class Discord(SyncBase):
    SOURCE_NAME: str = FileSource.DISCORD

    async def _generate(self, task: dict):
        server_ids: str | None = self.conf.get("server_ids", None)
        # Comma-separated, e.g. "channel1,channel2"
        channel_names: str | None = self.conf.get("channel_names", None)

        self.connector = DiscordConnector(
            server_ids=server_ids.split(",") if server_ids else [],
            channel_names=channel_names.split(",") if channel_names else [],
            start_date=datetime(1970, 1, 1, tzinfo=timezone.utc).strftime("%Y-%m-%d"),
            batch_size=self.conf.get("batch_size", 1024),
        )
        self.connector.load_credentials(self.conf["credentials"])
        document_generator = (
            self.connector.load_from_state()
            if task["reindex"] == "1" or not task["poll_range_start"]
            else self.connector.poll_source(task["poll_range_start"].timestamp(), datetime.now(timezone.utc).timestamp())
        )

        begin_info = "totally" if task["reindex"] == "1" or not task["poll_range_start"] else "from {}".format(task["poll_range_start"])
        logging.info("Connect to Discord: servers({}), channel({}) {}".format(server_ids, channel_names, begin_info))
        return document_generator
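
# A sketch of a Discord `conf` dict; `server_ids` and `channel_names` are
# comma-separated strings as read above, and the values are illustrative:
#
#     {
#         "server_ids": "123456789012345678",   # hypothetical server id
#         "channel_names": "general,support",
#         "batch_size": 1024,
#         "credentials": {...},
#     }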


class Gmail(SyncBase):
    SOURCE_NAME: str = FileSource.GMAIL

    async def _generate(self, task: dict):
        # Not implemented yet
        pass


class GoogleDrive(SyncBase):
    SOURCE_NAME: str = FileSource.GOOGLE_DRIVE

    async def _generate(self, task: dict):
        connector_kwargs = {
            "include_shared_drives": self.conf.get("include_shared_drives", False),
            "include_my_drives": self.conf.get("include_my_drives", False),
            "include_files_shared_with_me": self.conf.get("include_files_shared_with_me", False),
            "shared_drive_urls": self.conf.get("shared_drive_urls"),
            "my_drive_emails": self.conf.get("my_drive_emails"),
            "shared_folder_urls": self.conf.get("shared_folder_urls"),
            "specific_user_emails": self.conf.get("specific_user_emails"),
            "batch_size": self.conf.get("batch_size", INDEX_BATCH_SIZE),
        }
        self.connector = GoogleDriveConnector(**connector_kwargs)
        self.connector.set_allow_images(self.conf.get("allow_images", False))

        credentials = self.conf.get("credentials")
        if not credentials:
            raise ValueError("Google Drive connector is missing credentials.")

        new_credentials = self.connector.load_credentials(credentials)
        if new_credentials:
            self._persist_rotated_credentials(task["connector_id"], new_credentials)

        if task["reindex"] == "1" or not task["poll_range_start"]:
            start_time = 0.0
            begin_info = "totally"
        else:
            start_time = task["poll_range_start"].timestamp()
            begin_info = f"from {task['poll_range_start']}"

        end_time = datetime.now(timezone.utc).timestamp()
        raw_batch_size = self.conf.get("sync_batch_size") or self.conf.get("batch_size") or INDEX_BATCH_SIZE
        try:
            batch_size = int(raw_batch_size)
        except (TypeError, ValueError):
            batch_size = INDEX_BATCH_SIZE
        if batch_size <= 0:
            batch_size = INDEX_BATCH_SIZE

        def document_batches():
            checkpoint = self.connector.build_dummy_checkpoint()
            pending_docs = []
            iterations = 0
            iteration_limit = 100_000

            while checkpoint.has_more:
                wrapper = CheckpointOutputWrapper()
                doc_generator = wrapper(self.connector.load_from_checkpoint(start_time, end_time, checkpoint))
                for document, failure, next_checkpoint in doc_generator:
                    if failure is not None:
                        logging.warning("Google Drive connector failure: %s", getattr(failure, "failure_message", failure))
                        continue
                    if document is not None:
                        pending_docs.append(document)
                        if len(pending_docs) >= batch_size:
                            yield pending_docs
                            pending_docs = []
                    if next_checkpoint is not None:
                        checkpoint = next_checkpoint

                iterations += 1
                if iterations > iteration_limit:
                    raise RuntimeError("Too many iterations while loading Google Drive documents.")

            if pending_docs:
                yield pending_docs

        try:
            admin_email = self.connector.primary_admin_email
        except RuntimeError:
            admin_email = "unknown"
        logging.info(f"Connect to Google Drive as {admin_email} {begin_info}")
        return document_batches()

    def _persist_rotated_credentials(self, connector_id: str, credentials: dict[str, Any]) -> None:
        try:
            updated_conf = copy.deepcopy(self.conf)
            updated_conf["credentials"] = credentials
            ConnectorService.update_by_id(connector_id, {"config": updated_conf})
            self.conf = updated_conf
            logging.info("Persisted refreshed Google Drive credentials for connector %s", connector_id)
        except Exception:
            logging.exception("Failed to persist refreshed Google Drive credentials for connector %s", connector_id)
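
# Both Google Drive and Jira drain checkpoint-based connectors through
# CheckpointOutputWrapper, which flattens each checkpoint step into
# (document, failure, next_checkpoint) triples. Documents are regrouped into
# batches of `batch_size`, and the iteration cap guards against a checkpoint
# that never reports has_more=False.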


class Jira(SyncBase):
    SOURCE_NAME: str = FileSource.JIRA

    async def _generate(self, task: dict):
        connector_kwargs = {
            "jira_base_url": self.conf["base_url"],
            "project_key": self.conf.get("project_key"),
            "jql_query": self.conf.get("jql_query"),
            "batch_size": self.conf.get("batch_size", INDEX_BATCH_SIZE),
            "include_comments": self.conf.get("include_comments", True),
            "include_attachments": self.conf.get("include_attachments", False),
            "labels_to_skip": self._normalize_list(self.conf.get("labels_to_skip")),
            "comment_email_blacklist": self._normalize_list(self.conf.get("comment_email_blacklist")),
            "scoped_token": self.conf.get("scoped_token", False),
            "attachment_size_limit": self.conf.get("attachment_size_limit"),
            "timezone_offset": self.conf.get("timezone_offset"),
        }

        self.connector = JiraConnector(**connector_kwargs)

        credentials = self.conf.get("credentials")
        if not credentials:
            raise ValueError("Jira connector is missing credentials.")

        self.connector.load_credentials(credentials)
        self.connector.validate_connector_settings()

        if task["reindex"] == "1" or not task["poll_range_start"]:
            start_time = 0.0
            begin_info = "totally"
        else:
            start_time = task["poll_range_start"].timestamp()
            begin_info = f"from {task['poll_range_start']}"

        end_time = datetime.now(timezone.utc).timestamp()

        raw_batch_size = self.conf.get("sync_batch_size") or self.conf.get("batch_size") or INDEX_BATCH_SIZE
        try:
            batch_size = int(raw_batch_size)
        except (TypeError, ValueError):
            batch_size = INDEX_BATCH_SIZE
        if batch_size <= 0:
            batch_size = INDEX_BATCH_SIZE

        def document_batches():
            checkpoint = self.connector.build_dummy_checkpoint()
            pending_docs = []
            iterations = 0
            iteration_limit = 100_000

            while checkpoint.has_more:
                wrapper = CheckpointOutputWrapper()
                generator = wrapper(
                    self.connector.load_from_checkpoint(
                        start_time,
                        end_time,
                        checkpoint,
                    )
                )
                for document, failure, next_checkpoint in generator:
                    if failure is not None:
                        logging.warning(f"[Jira] Jira connector failure: {getattr(failure, 'failure_message', failure)}")
                        continue
                    if document is not None:
                        pending_docs.append(document)
                        if len(pending_docs) >= batch_size:
                            yield pending_docs
                            pending_docs = []
                    if next_checkpoint is not None:
                        checkpoint = next_checkpoint

                iterations += 1
                if iterations > iteration_limit:
                    logging.error(f"[Jira] Task {task.get('id')} exceeded iteration limit ({iteration_limit}).")
                    raise RuntimeError("Too many iterations while loading Jira documents.")

            if pending_docs:
                yield pending_docs

        logging.info(f"[Jira] Connect to Jira {connector_kwargs['jira_base_url']} {begin_info}")
        return document_batches()

    @staticmethod
    def _normalize_list(values: Any) -> list[str] | None:
        if values is None:
            return None
        if isinstance(values, str):
            values = [item.strip() for item in values.split(",")]
        return [str(value).strip() for value in values if value is not None and str(value).strip()]
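
# `_normalize_list` accepts either a comma-separated string or an iterable and
# returns a clean list of non-empty strings, e.g.:
#
#     Jira._normalize_list("wontfix, duplicate,")  ->  ["wontfix", "duplicate"]
#     Jira._normalize_list(["a", None, " b "])     ->  ["a", "b"]
#     Jira._normalize_list(None)                   ->  None
#
# A sketch of a Jira `conf` dict; the keys are the ones read above, the values
# are illustrative:
#
#     {
#         "base_url": "https://example.atlassian.net",  # hypothetical site
#         "project_key": "PROJ",
#         "include_comments": True,
#         "labels_to_skip": "wontfix,duplicate",
#         "credentials": {...},
#     }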


class SharePoint(SyncBase):
    SOURCE_NAME: str = FileSource.SHAREPOINT

    async def _generate(self, task: dict):
        # Not implemented yet
        pass


class Slack(SyncBase):
    SOURCE_NAME: str = FileSource.SLACK

    async def _generate(self, task: dict):
        # Not implemented yet
        pass


class Teams(SyncBase):
    SOURCE_NAME: str = FileSource.TEAMS

    async def _generate(self, task: dict):
        # Not implemented yet
        pass


func_factory = {
    FileSource.S3: S3,
    FileSource.NOTION: Notion,
    FileSource.DISCORD: Discord,
    FileSource.CONFLUENCE: Confluence,
    FileSource.GMAIL: Gmail,
    FileSource.GOOGLE_DRIVE: GoogleDrive,
    FileSource.JIRA: Jira,
    FileSource.SHAREPOINT: SharePoint,
    FileSource.SLACK: Slack,
    FileSource.TEAMS: Teams,
}
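
# Dispatch goes through this registry: a task's `source` field selects the
# SyncBase subclass, which is instantiated with the connector's stored config
# and then awaited as a callable, as in `dispatch_tasks` below:
#
#     sync = func_factory[task["source"]](task["config"])
#     nursery.start_soon(sync, task)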


async def dispatch_tasks():
    async with trio.open_nursery() as nursery:
        while True:
            try:
                list(SyncLogsService.list_sync_tasks()[0])
                break
            except Exception as e:
                logging.warning(f"DB is not ready yet: {e}")
                await trio.sleep(3)

        for task in SyncLogsService.list_sync_tasks()[0]:
            if task["poll_range_start"]:
                task["poll_range_start"] = task["poll_range_start"].astimezone(timezone.utc)
            if task["poll_range_end"]:
                task["poll_range_end"] = task["poll_range_end"].astimezone(timezone.utc)
            func = func_factory[task["source"]](task["config"])
            nursery.start_soon(func, task)
    await trio.sleep(1)
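
# `main` calls `dispatch_tasks` in an endless loop: each pass waits for the DB
# to become reachable, schedules every pending sync task into a nursery
# (bounded by `task_limiter`), waits for the nursery to drain, then sleeps for
# a second before polling again.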


stop_event = threading.Event()


def signal_handler(sig, frame):
    logging.info("Received interrupt signal, shutting down...")
    stop_event.set()
    time.sleep(1)
    sys.exit(0)


CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1]
CONSUMER_NAME = "data_sync_" + CONSUMER_NO


async def main():
    logging.info(r"""
 _____        _           _____
|  __ \      | |         / ____|
| |  | | __ _| |_ __ _  | (___  _   _ _ __   ___
| |  | |/ _` | __/ _` |  \___ \| | | | '_ \ / __|
| |__| | (_| | || (_| |  ____) | |_| | | | | (__
|_____/ \__,_|\__\__,_| |_____/ \__, |_| |_|\___|
                                  __/ |
                                 |___/
    """)
    logging.info(f"RAGFlow version: {get_ragflow_version()}")
    show_configs()
    settings.init_settings()
    if sys.platform != "win32":
        signal.signal(signal.SIGUSR1, start_tracemalloc_and_snapshot)
        signal.signal(signal.SIGUSR2, stop_tracemalloc)
    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    while not stop_event.is_set():
        await dispatch_tasks()
    logging.error("BUG!!! You should not reach here!!!")


if __name__ == "__main__":
    faulthandler.enable()
    init_root_logger(CONSUMER_NAME)
    trio.run(main)