2024-06-30 17:38:39 +08:00
|
|
|
use crate::chat_manager::ChatUserService;
|
2024-07-08 13:19:13 +08:00
|
|
|
use crate::entities::{ChatStatePB, ModelTypePB};
|
2024-07-15 15:23:23 +08:00
|
|
|
use crate::local_ai::local_llm_chat::LocalAIController;
|
2024-07-18 20:54:35 +08:00
|
|
|
use crate::notification::{make_notification, ChatNotification, APPFLOWY_AI_NOTIFICATION_KEY};
|
2024-06-30 17:38:39 +08:00
|
|
|
use crate::persistence::select_single_message;
|
2024-07-08 13:19:13 +08:00
|
|
|
use appflowy_plugin::error::PluginError;
|
2024-07-15 15:23:23 +08:00
|
|
|
|
2024-06-30 17:38:39 +08:00
|
|
|
use flowy_chat_pub::cloud::{
|
2024-07-15 15:23:23 +08:00
|
|
|
ChatCloudService, ChatMessage, ChatMessageType, CompletionType, LocalAIConfig, MessageCursor,
|
2024-06-30 17:38:39 +08:00
|
|
|
RepeatedChatMessage, RepeatedRelatedQuestion, StreamAnswer, StreamComplete,
|
|
|
|
};
|
|
|
|
use flowy_error::{FlowyError, FlowyResult};
|
2024-07-08 13:19:13 +08:00
|
|
|
use futures::{stream, StreamExt, TryStreamExt};
|
2024-06-30 17:38:39 +08:00
|
|
|
use lib_infra::async_trait::async_trait;
|
|
|
|
use lib_infra::future::FutureResult;
|
2024-07-15 15:23:23 +08:00
|
|
|
|
|
|
|
use std::path::PathBuf;
|
2024-06-30 17:38:39 +08:00
|
|
|
use std::sync::Arc;
|
|
|
|
|
2024-07-15 15:23:23 +08:00
|
|
|
pub struct ChatServiceMiddleware {
|
2024-06-30 17:38:39 +08:00
|
|
|
pub cloud_service: Arc<dyn ChatCloudService>,
|
|
|
|
user_service: Arc<dyn ChatUserService>,
|
2024-07-15 15:23:23 +08:00
|
|
|
local_llm_controller: Arc<LocalAIController>,
|
2024-06-30 17:38:39 +08:00
|
|
|
}
|
|
|
|
|
2024-07-15 15:23:23 +08:00
|
|
|
impl ChatServiceMiddleware {
|
2024-06-30 17:38:39 +08:00
|
|
|
pub fn new(
|
|
|
|
user_service: Arc<dyn ChatUserService>,
|
|
|
|
cloud_service: Arc<dyn ChatCloudService>,
|
2024-07-15 15:23:23 +08:00
|
|
|
local_llm_controller: Arc<LocalAIController>,
|
2024-06-30 17:38:39 +08:00
|
|
|
) -> Self {
|
|
|
|
Self {
|
|
|
|
user_service,
|
|
|
|
cloud_service,
|
2024-07-15 15:23:23 +08:00
|
|
|
local_llm_controller,
|
2024-06-30 17:38:39 +08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
fn get_message_content(&self, message_id: i64) -> FlowyResult<String> {
|
|
|
|
let uid = self.user_service.user_id()?;
|
|
|
|
let conn = self.user_service.sqlite_connection(uid)?;
|
|
|
|
let content = select_single_message(conn, message_id)?
|
|
|
|
.map(|data| data.content)
|
|
|
|
.ok_or_else(|| {
|
|
|
|
FlowyError::record_not_found().with_context(format!("Message not found: {}", message_id))
|
|
|
|
})?;
|
|
|
|
|
|
|
|
Ok(content)
|
|
|
|
}
|
2024-07-08 13:19:13 +08:00
|
|
|
|
|
|
|
fn handle_plugin_error(&self, err: PluginError) {
|
|
|
|
if matches!(
|
|
|
|
err,
|
|
|
|
PluginError::PluginNotConnected | PluginError::PeerDisconnect
|
|
|
|
) {
|
2024-07-18 20:54:35 +08:00
|
|
|
make_notification(
|
|
|
|
APPFLOWY_AI_NOTIFICATION_KEY,
|
2024-07-15 15:23:23 +08:00
|
|
|
ChatNotification::UpdateChatPluginState,
|
|
|
|
)
|
|
|
|
.payload(ChatStatePB {
|
|
|
|
model_type: ModelTypePB::LocalAI,
|
|
|
|
available: false,
|
2024-07-18 20:54:35 +08:00
|
|
|
})
|
|
|
|
.send();
|
2024-07-08 13:19:13 +08:00
|
|
|
}
|
|
|
|
}
|
2024-06-30 17:38:39 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
#[async_trait]
|
2024-07-15 15:23:23 +08:00
|
|
|
impl ChatCloudService for ChatServiceMiddleware {
|
2024-06-30 17:38:39 +08:00
|
|
|
fn create_chat(
|
|
|
|
&self,
|
|
|
|
uid: &i64,
|
|
|
|
workspace_id: &str,
|
|
|
|
chat_id: &str,
|
|
|
|
) -> FutureResult<(), FlowyError> {
|
|
|
|
self.cloud_service.create_chat(uid, workspace_id, chat_id)
|
|
|
|
}
|
|
|
|
|
|
|
|
fn save_question(
|
|
|
|
&self,
|
|
|
|
workspace_id: &str,
|
|
|
|
chat_id: &str,
|
|
|
|
message: &str,
|
|
|
|
message_type: ChatMessageType,
|
|
|
|
) -> FutureResult<ChatMessage, FlowyError> {
|
|
|
|
self
|
|
|
|
.cloud_service
|
|
|
|
.save_question(workspace_id, chat_id, message, message_type)
|
|
|
|
}
|
|
|
|
|
|
|
|
fn save_answer(
|
|
|
|
&self,
|
|
|
|
workspace_id: &str,
|
|
|
|
chat_id: &str,
|
|
|
|
message: &str,
|
|
|
|
question_id: i64,
|
|
|
|
) -> FutureResult<ChatMessage, FlowyError> {
|
|
|
|
self
|
|
|
|
.cloud_service
|
|
|
|
.save_answer(workspace_id, chat_id, message, question_id)
|
|
|
|
}
|
|
|
|
|
|
|
|
async fn ask_question(
|
|
|
|
&self,
|
|
|
|
workspace_id: &str,
|
|
|
|
chat_id: &str,
|
|
|
|
message_id: i64,
|
|
|
|
) -> Result<StreamAnswer, FlowyError> {
|
2024-07-18 20:54:35 +08:00
|
|
|
if self.local_llm_controller.is_running() {
|
2024-06-30 17:38:39 +08:00
|
|
|
let content = self.get_message_content(message_id)?;
|
2024-07-15 15:23:23 +08:00
|
|
|
match self
|
|
|
|
.local_llm_controller
|
|
|
|
.stream_question(chat_id, &content)
|
|
|
|
.await
|
|
|
|
{
|
2024-07-08 13:19:13 +08:00
|
|
|
Ok(stream) => Ok(
|
|
|
|
stream
|
|
|
|
.map_err(|err| FlowyError::local_ai().with_context(err))
|
|
|
|
.boxed(),
|
|
|
|
),
|
|
|
|
Err(err) => {
|
|
|
|
self.handle_plugin_error(err);
|
|
|
|
Ok(stream::once(async { Err(FlowyError::local_ai_unavailable()) }).boxed())
|
|
|
|
},
|
|
|
|
}
|
2024-06-30 17:38:39 +08:00
|
|
|
} else {
|
|
|
|
self
|
|
|
|
.cloud_service
|
|
|
|
.ask_question(workspace_id, chat_id, message_id)
|
|
|
|
.await
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
async fn generate_answer(
|
|
|
|
&self,
|
|
|
|
workspace_id: &str,
|
|
|
|
chat_id: &str,
|
|
|
|
question_message_id: i64,
|
|
|
|
) -> Result<ChatMessage, FlowyError> {
|
2024-07-18 20:54:35 +08:00
|
|
|
if self.local_llm_controller.is_running() {
|
2024-06-30 17:38:39 +08:00
|
|
|
let content = self.get_message_content(question_message_id)?;
|
2024-07-15 15:23:23 +08:00
|
|
|
match self
|
|
|
|
.local_llm_controller
|
|
|
|
.ask_question(chat_id, &content)
|
|
|
|
.await
|
|
|
|
{
|
2024-07-08 13:19:13 +08:00
|
|
|
Ok(answer) => {
|
|
|
|
let message = self
|
|
|
|
.cloud_service
|
|
|
|
.save_answer(workspace_id, chat_id, &answer, question_message_id)
|
|
|
|
.await?;
|
|
|
|
Ok(message)
|
|
|
|
},
|
|
|
|
Err(err) => {
|
|
|
|
self.handle_plugin_error(err);
|
|
|
|
Err(FlowyError::local_ai_unavailable())
|
|
|
|
},
|
|
|
|
}
|
2024-06-30 17:38:39 +08:00
|
|
|
} else {
|
|
|
|
self
|
|
|
|
.cloud_service
|
|
|
|
.generate_answer(workspace_id, chat_id, question_message_id)
|
|
|
|
.await
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
fn get_chat_messages(
|
|
|
|
&self,
|
|
|
|
workspace_id: &str,
|
|
|
|
chat_id: &str,
|
|
|
|
offset: MessageCursor,
|
|
|
|
limit: u64,
|
|
|
|
) -> FutureResult<RepeatedChatMessage, FlowyError> {
|
|
|
|
self
|
|
|
|
.cloud_service
|
|
|
|
.get_chat_messages(workspace_id, chat_id, offset, limit)
|
|
|
|
}
|
|
|
|
|
|
|
|
fn get_related_message(
|
|
|
|
&self,
|
|
|
|
workspace_id: &str,
|
|
|
|
chat_id: &str,
|
|
|
|
message_id: i64,
|
|
|
|
) -> FutureResult<RepeatedRelatedQuestion, FlowyError> {
|
2024-07-18 20:54:35 +08:00
|
|
|
if self.local_llm_controller.is_running() {
|
2024-06-30 17:38:39 +08:00
|
|
|
FutureResult::new(async move {
|
|
|
|
Ok(RepeatedRelatedQuestion {
|
|
|
|
message_id,
|
|
|
|
items: vec![],
|
|
|
|
})
|
|
|
|
})
|
|
|
|
} else {
|
|
|
|
self
|
|
|
|
.cloud_service
|
|
|
|
.get_related_message(workspace_id, chat_id, message_id)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
async fn stream_complete(
|
|
|
|
&self,
|
|
|
|
workspace_id: &str,
|
|
|
|
text: &str,
|
|
|
|
complete_type: CompletionType,
|
|
|
|
) -> Result<StreamComplete, FlowyError> {
|
2024-07-18 20:54:35 +08:00
|
|
|
if self.local_llm_controller.is_running() {
|
2024-07-15 15:23:23 +08:00
|
|
|
return Err(
|
|
|
|
FlowyError::not_support().with_context("completion with local ai is not supported yet"),
|
|
|
|
);
|
2024-06-30 17:38:39 +08:00
|
|
|
} else {
|
|
|
|
self
|
|
|
|
.cloud_service
|
|
|
|
.stream_complete(workspace_id, text, complete_type)
|
|
|
|
.await
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2024-07-15 15:23:23 +08:00
|
|
|
async fn index_file(
|
|
|
|
&self,
|
|
|
|
workspace_id: &str,
|
|
|
|
file_path: PathBuf,
|
|
|
|
chat_id: &str,
|
|
|
|
) -> Result<(), FlowyError> {
|
2024-07-18 20:54:35 +08:00
|
|
|
if self.local_llm_controller.is_running() {
|
2024-07-15 15:23:23 +08:00
|
|
|
self
|
|
|
|
.local_llm_controller
|
|
|
|
.index_file(chat_id, file_path)
|
|
|
|
.await
|
|
|
|
.map_err(|err| FlowyError::local_ai().with_context(err))?;
|
|
|
|
Ok(())
|
|
|
|
} else {
|
|
|
|
self
|
|
|
|
.cloud_service
|
|
|
|
.index_file(workspace_id, file_path, chat_id)
|
|
|
|
.await
|
2024-06-30 17:38:39 +08:00
|
|
|
}
|
2024-07-15 15:23:23 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
async fn get_local_ai_config(&self, workspace_id: &str) -> Result<LocalAIConfig, FlowyError> {
|
|
|
|
self.cloud_service.get_local_ai_config(workspace_id).await
|
2024-06-30 17:38:39 +08:00
|
|
|
}
|
|
|
|
}
|