2024-06-03 14:27:28 +08:00
|
|
|
use flowy_chat_pub::cloud::ChatMessageType;
|
2024-07-15 15:23:23 +08:00
|
|
|
use std::path::PathBuf;
|
2024-06-30 17:38:39 +08:00
|
|
|
|
2024-07-15 15:23:23 +08:00
|
|
|
use allo_isolate::Isolate;
|
2024-06-03 14:27:28 +08:00
|
|
|
use std::sync::{Arc, Weak};
|
2024-07-15 15:23:23 +08:00
|
|
|
use tokio::sync::oneshot;
|
2024-06-03 14:27:28 +08:00
|
|
|
use validator::Validate;
|
|
|
|
|
2024-07-15 15:23:23 +08:00
|
|
|
use crate::chat_manager::ChatManager;
|
|
|
|
use crate::entities::*;
|
|
|
|
use crate::local_ai::local_llm_chat::LLMModelInfo;
|
2024-06-25 01:59:38 +02:00
|
|
|
use crate::tools::AITools;
|
2024-06-03 14:27:28 +08:00
|
|
|
use flowy_error::{FlowyError, FlowyResult};
|
|
|
|
use lib_dispatch::prelude::{data_result_ok, AFPluginData, AFPluginState, DataResult};
|
2024-07-15 15:23:23 +08:00
|
|
|
use lib_infra::isolate_stream::IsolateSink;
|
2024-06-03 14:27:28 +08:00
|
|
|
|
|
|
|
fn upgrade_chat_manager(
|
|
|
|
chat_manager: AFPluginState<Weak<ChatManager>>,
|
|
|
|
) -> FlowyResult<Arc<ChatManager>> {
|
|
|
|
let chat_manager = chat_manager
|
|
|
|
.upgrade()
|
|
|
|
.ok_or(FlowyError::internal().with_context("The chat manager is already dropped"))?;
|
|
|
|
Ok(chat_manager)
|
|
|
|
}
|
|
|
|
|
|
|
|
#[tracing::instrument(level = "debug", skip_all, err)]
|
2024-06-09 14:02:32 +08:00
|
|
|
pub(crate) async fn stream_chat_message_handler(
|
|
|
|
data: AFPluginData<StreamChatPayloadPB>,
|
2024-06-03 14:27:28 +08:00
|
|
|
chat_manager: AFPluginState<Weak<ChatManager>>,
|
2024-06-09 14:02:32 +08:00
|
|
|
) -> DataResult<ChatMessagePB, FlowyError> {
|
2024-06-03 14:27:28 +08:00
|
|
|
let chat_manager = upgrade_chat_manager(chat_manager)?;
|
|
|
|
let data = data.into_inner();
|
|
|
|
data.validate()?;
|
|
|
|
|
|
|
|
let message_type = match data.message_type {
|
|
|
|
ChatMessageTypePB::System => ChatMessageType::System,
|
|
|
|
ChatMessageTypePB::User => ChatMessageType::User,
|
|
|
|
};
|
2024-06-09 14:02:32 +08:00
|
|
|
|
|
|
|
let question = chat_manager
|
|
|
|
.stream_chat_message(
|
|
|
|
&data.chat_id,
|
|
|
|
&data.message,
|
|
|
|
message_type,
|
|
|
|
data.text_stream_port,
|
|
|
|
)
|
2024-06-03 14:27:28 +08:00
|
|
|
.await?;
|
2024-06-09 14:02:32 +08:00
|
|
|
data_result_ok(question)
|
2024-06-03 14:27:28 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
#[tracing::instrument(level = "debug", skip_all, err)]
|
|
|
|
pub(crate) async fn load_prev_message_handler(
|
|
|
|
data: AFPluginData<LoadPrevChatMessagePB>,
|
|
|
|
chat_manager: AFPluginState<Weak<ChatManager>>,
|
|
|
|
) -> DataResult<ChatMessageListPB, FlowyError> {
|
|
|
|
let chat_manager = upgrade_chat_manager(chat_manager)?;
|
|
|
|
let data = data.into_inner();
|
|
|
|
data.validate()?;
|
|
|
|
|
|
|
|
let messages = chat_manager
|
|
|
|
.load_prev_chat_messages(&data.chat_id, data.limit, data.before_message_id)
|
|
|
|
.await?;
|
|
|
|
data_result_ok(messages)
|
|
|
|
}
|
|
|
|
|
|
|
|
#[tracing::instrument(level = "debug", skip_all, err)]
|
|
|
|
pub(crate) async fn load_next_message_handler(
|
|
|
|
data: AFPluginData<LoadNextChatMessagePB>,
|
|
|
|
chat_manager: AFPluginState<Weak<ChatManager>>,
|
|
|
|
) -> DataResult<ChatMessageListPB, FlowyError> {
|
|
|
|
let chat_manager = upgrade_chat_manager(chat_manager)?;
|
|
|
|
let data = data.into_inner();
|
|
|
|
data.validate()?;
|
|
|
|
|
|
|
|
let messages = chat_manager
|
|
|
|
.load_latest_chat_messages(&data.chat_id, data.limit, data.after_message_id)
|
|
|
|
.await?;
|
|
|
|
data_result_ok(messages)
|
|
|
|
}
|
|
|
|
|
|
|
|
#[tracing::instrument(level = "debug", skip_all, err)]
|
|
|
|
pub(crate) async fn get_related_question_handler(
|
|
|
|
data: AFPluginData<ChatMessageIdPB>,
|
|
|
|
chat_manager: AFPluginState<Weak<ChatManager>>,
|
|
|
|
) -> DataResult<RepeatedRelatedQuestionPB, FlowyError> {
|
|
|
|
let chat_manager = upgrade_chat_manager(chat_manager)?;
|
|
|
|
let data = data.into_inner();
|
|
|
|
let messages = chat_manager
|
|
|
|
.get_related_questions(&data.chat_id, data.message_id)
|
|
|
|
.await?;
|
|
|
|
data_result_ok(messages)
|
|
|
|
}
|
|
|
|
|
|
|
|
#[tracing::instrument(level = "debug", skip_all, err)]
|
|
|
|
pub(crate) async fn get_answer_handler(
|
|
|
|
data: AFPluginData<ChatMessageIdPB>,
|
|
|
|
chat_manager: AFPluginState<Weak<ChatManager>>,
|
|
|
|
) -> DataResult<ChatMessagePB, FlowyError> {
|
|
|
|
let chat_manager = upgrade_chat_manager(chat_manager)?;
|
|
|
|
let data = data.into_inner();
|
2024-06-30 17:38:39 +08:00
|
|
|
let (tx, rx) = tokio::sync::oneshot::channel();
|
|
|
|
tokio::spawn(async move {
|
|
|
|
let message = chat_manager
|
|
|
|
.generate_answer(&data.chat_id, data.message_id)
|
|
|
|
.await?;
|
|
|
|
let _ = tx.send(message);
|
|
|
|
Ok::<_, FlowyError>(())
|
|
|
|
});
|
|
|
|
let message = rx.await?;
|
2024-06-03 14:27:28 +08:00
|
|
|
data_result_ok(message)
|
|
|
|
}
|
2024-06-09 14:02:32 +08:00
|
|
|
|
|
|
|
#[tracing::instrument(level = "debug", skip_all, err)]
|
|
|
|
pub(crate) async fn stop_stream_handler(
|
|
|
|
data: AFPluginData<StopStreamPB>,
|
|
|
|
chat_manager: AFPluginState<Weak<ChatManager>>,
|
|
|
|
) -> Result<(), FlowyError> {
|
|
|
|
let data = data.into_inner();
|
|
|
|
data.validate()?;
|
|
|
|
|
|
|
|
let chat_manager = upgrade_chat_manager(chat_manager)?;
|
|
|
|
chat_manager.stop_stream(&data.chat_id).await?;
|
|
|
|
Ok(())
|
|
|
|
}
|
2024-06-25 01:59:38 +02:00
|
|
|
|
|
|
|
#[tracing::instrument(level = "debug", skip_all, err)]
|
2024-07-15 15:23:23 +08:00
|
|
|
pub(crate) async fn refresh_local_ai_info_handler(
|
2024-06-30 17:38:39 +08:00
|
|
|
chat_manager: AFPluginState<Weak<ChatManager>>,
|
2024-07-15 15:23:23 +08:00
|
|
|
) -> DataResult<LLMModelInfoPB, FlowyError> {
|
2024-06-30 17:38:39 +08:00
|
|
|
let chat_manager = upgrade_chat_manager(chat_manager)?;
|
2024-07-15 15:23:23 +08:00
|
|
|
let (tx, rx) = oneshot::channel::<Result<LLMModelInfo, FlowyError>>();
|
|
|
|
tokio::spawn(async move {
|
|
|
|
let model_info = chat_manager.local_ai_controller.refresh().await;
|
|
|
|
let _ = tx.send(model_info);
|
|
|
|
});
|
|
|
|
|
|
|
|
let model_info = rx.await??;
|
|
|
|
data_result_ok(model_info.into())
|
2024-06-30 17:38:39 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
#[tracing::instrument(level = "debug", skip_all, err)]
|
2024-07-15 15:23:23 +08:00
|
|
|
pub(crate) async fn update_local_llm_model_handler(
|
|
|
|
data: AFPluginData<LLMModelPB>,
|
2024-06-30 17:38:39 +08:00
|
|
|
chat_manager: AFPluginState<Weak<ChatManager>>,
|
2024-07-15 15:23:23 +08:00
|
|
|
) -> DataResult<LocalModelResourcePB, FlowyError> {
|
2024-06-30 17:38:39 +08:00
|
|
|
let data = data.into_inner();
|
|
|
|
let chat_manager = upgrade_chat_manager(chat_manager)?;
|
2024-07-15 15:23:23 +08:00
|
|
|
let state = chat_manager
|
|
|
|
.local_ai_controller
|
|
|
|
.use_local_llm(data.llm_id)
|
|
|
|
.await?;
|
|
|
|
data_result_ok(state)
|
|
|
|
}
|
|
|
|
|
|
|
|
#[tracing::instrument(level = "debug", skip_all, err)]
|
|
|
|
pub(crate) async fn get_local_llm_state_handler(
|
|
|
|
chat_manager: AFPluginState<Weak<ChatManager>>,
|
|
|
|
) -> DataResult<LocalModelResourcePB, FlowyError> {
|
|
|
|
let chat_manager = upgrade_chat_manager(chat_manager)?;
|
|
|
|
let state = chat_manager
|
|
|
|
.local_ai_controller
|
|
|
|
.get_local_llm_state()
|
|
|
|
.await?;
|
|
|
|
data_result_ok(state)
|
2024-06-30 17:38:39 +08:00
|
|
|
}
|
|
|
|
|
2024-06-25 01:59:38 +02:00
|
|
|
pub(crate) async fn start_complete_text_handler(
|
|
|
|
data: AFPluginData<CompleteTextPB>,
|
|
|
|
tools: AFPluginState<Arc<AITools>>,
|
|
|
|
) -> DataResult<CompleteTextTaskPB, FlowyError> {
|
|
|
|
let task = tools.create_complete_task(data.into_inner()).await?;
|
|
|
|
data_result_ok(task)
|
|
|
|
}
|
|
|
|
|
|
|
|
#[tracing::instrument(level = "debug", skip_all, err)]
|
|
|
|
pub(crate) async fn stop_complete_text_handler(
|
|
|
|
data: AFPluginData<CompleteTextTaskPB>,
|
|
|
|
tools: AFPluginState<Arc<AITools>>,
|
|
|
|
) -> Result<(), FlowyError> {
|
|
|
|
let data = data.into_inner();
|
|
|
|
tools.cancel_complete_task(&data.task_id).await;
|
|
|
|
Ok(())
|
|
|
|
}
|
2024-07-15 15:23:23 +08:00
|
|
|
|
|
|
|
#[tracing::instrument(level = "debug", skip_all, err)]
|
|
|
|
pub(crate) async fn chat_file_handler(
|
|
|
|
data: AFPluginData<ChatFilePB>,
|
|
|
|
chat_manager: AFPluginState<Weak<ChatManager>>,
|
|
|
|
) -> Result<(), FlowyError> {
|
|
|
|
let data = data.try_into_inner()?;
|
|
|
|
let file_path = PathBuf::from(&data.file_path);
|
|
|
|
let (tx, rx) = oneshot::channel::<Result<(), FlowyError>>();
|
|
|
|
tokio::spawn(async move {
|
|
|
|
let chat_manager = upgrade_chat_manager(chat_manager)?;
|
|
|
|
chat_manager
|
|
|
|
.chat_with_file(&data.chat_id, file_path)
|
|
|
|
.await?;
|
|
|
|
let _ = tx.send(Ok(()));
|
|
|
|
Ok::<_, FlowyError>(())
|
|
|
|
});
|
|
|
|
|
|
|
|
rx.await?
|
|
|
|
}
|
|
|
|
|
|
|
|
#[tracing::instrument(level = "debug", skip_all, err)]
|
|
|
|
pub(crate) async fn download_llm_resource_handler(
|
|
|
|
data: AFPluginData<DownloadLLMPB>,
|
|
|
|
chat_manager: AFPluginState<Weak<ChatManager>>,
|
|
|
|
) -> DataResult<DownloadTaskPB, FlowyError> {
|
|
|
|
let data = data.into_inner();
|
|
|
|
let chat_manager = upgrade_chat_manager(chat_manager)?;
|
|
|
|
let text_sink = IsolateSink::new(Isolate::new(data.progress_stream));
|
|
|
|
let task_id = chat_manager
|
|
|
|
.local_ai_controller
|
|
|
|
.start_downloading(text_sink)
|
|
|
|
.await?;
|
|
|
|
data_result_ok(DownloadTaskPB { task_id })
|
|
|
|
}
|
|
|
|
|
|
|
|
#[tracing::instrument(level = "debug", skip_all, err)]
|
|
|
|
pub(crate) async fn cancel_download_llm_resource_handler(
|
|
|
|
chat_manager: AFPluginState<Weak<ChatManager>>,
|
|
|
|
) -> Result<(), FlowyError> {
|
|
|
|
let chat_manager = upgrade_chat_manager(chat_manager)?;
|
|
|
|
chat_manager.local_ai_controller.cancel_download()?;
|
|
|
|
Ok(())
|
|
|
|
}
|
|
|
|
|
|
|
|
#[tracing::instrument(level = "debug", skip_all, err)]
|
|
|
|
pub(crate) async fn get_plugin_state_handler(
|
|
|
|
chat_manager: AFPluginState<Weak<ChatManager>>,
|
|
|
|
) -> DataResult<PluginStatePB, FlowyError> {
|
|
|
|
let chat_manager = upgrade_chat_manager(chat_manager)?;
|
|
|
|
let state = chat_manager.local_ai_controller.get_plugin_state();
|
|
|
|
data_result_ok(state)
|
|
|
|
}
|