2024-08-01 23:13:35 +08:00
|
|
|
use crate::ai_manager::AIUserService;
|
2024-06-03 14:27:28 +08:00
|
|
|
use crate::entities::{
|
2025-01-08 10:43:03 +08:00
|
|
|
ChatMessageErrorPB, ChatMessageListPB, ChatMessagePB, PredefinedFormatPB,
|
|
|
|
RepeatedRelatedQuestionPB, StreamMessageParams,
|
2024-06-03 14:27:28 +08:00
|
|
|
};
|
2024-08-01 23:13:35 +08:00
|
|
|
use crate::middleware::chat_service_mw::AICloudServiceMiddleware;
|
2024-12-08 18:25:25 +08:00
|
|
|
use crate::notification::{chat_notification_builder, ChatNotification};
|
2024-12-19 14:13:53 +08:00
|
|
|
use crate::persistence::{
|
|
|
|
insert_chat_messages, select_chat_messages, select_message_where_match_reply_message_id,
|
|
|
|
ChatMessageTable,
|
|
|
|
};
|
2024-08-10 17:23:37 +08:00
|
|
|
use crate::stream_message::StreamMessage;
|
2024-06-09 14:02:32 +08:00
|
|
|
use allo_isolate::Isolate;
|
2024-08-06 07:56:13 +08:00
|
|
|
use flowy_ai_pub::cloud::{
|
2025-01-08 10:43:03 +08:00
|
|
|
ChatCloudService, ChatMessage, MessageCursor, QuestionStreamValue, ResponseFormat,
|
2024-08-06 07:56:13 +08:00
|
|
|
};
|
2024-06-03 14:27:28 +08:00
|
|
|
use flowy_error::{FlowyError, FlowyResult};
|
|
|
|
use flowy_sqlite::DBConnection;
|
2024-06-09 14:02:32 +08:00
|
|
|
use futures::{SinkExt, StreamExt};
|
|
|
|
use lib_infra::isolate_stream::IsolateSink;
|
2024-07-15 15:23:23 +08:00
|
|
|
use std::path::PathBuf;
|
2024-06-09 14:02:32 +08:00
|
|
|
use std::sync::atomic::{AtomicBool, AtomicI64};
|
2024-06-03 14:27:28 +08:00
|
|
|
use std::sync::Arc;
|
2024-06-09 14:02:32 +08:00
|
|
|
use tokio::sync::{Mutex, RwLock};
|
2024-06-03 14:27:28 +08:00
|
|
|
use tracing::{error, instrument, trace};
|
|
|
|
|
|
|
|
/// Tracks whether older ("previous") chat messages may still be fetched from the remote.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum PrevMessageState {
  /// The remote may still have older messages; a remote load may be started.
  HasMore,
  /// The remote reported that no older messages remain.
  NoMore,
  /// A remote load of older messages is currently in flight.
  Loading,
}
|
|
|
|
|
|
|
|
pub struct Chat {
|
|
|
|
chat_id: String,
|
|
|
|
uid: i64,
|
2024-08-01 23:13:35 +08:00
|
|
|
user_service: Arc<dyn AIUserService>,
|
|
|
|
chat_service: Arc<AICloudServiceMiddleware>,
|
2024-06-03 14:27:28 +08:00
|
|
|
prev_message_state: Arc<RwLock<PrevMessageState>>,
|
|
|
|
latest_message_id: Arc<AtomicI64>,
|
2024-06-09 14:02:32 +08:00
|
|
|
stop_stream: Arc<AtomicBool>,
|
2024-08-06 07:56:13 +08:00
|
|
|
stream_buffer: Arc<Mutex<StringBuffer>>,
|
2024-06-03 14:27:28 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
impl Chat {
|
|
|
|
pub fn new(
|
|
|
|
uid: i64,
|
|
|
|
chat_id: String,
|
2024-08-01 23:13:35 +08:00
|
|
|
user_service: Arc<dyn AIUserService>,
|
|
|
|
chat_service: Arc<AICloudServiceMiddleware>,
|
2024-06-03 14:27:28 +08:00
|
|
|
) -> Chat {
|
|
|
|
Chat {
|
|
|
|
uid,
|
|
|
|
chat_id,
|
2024-06-30 17:38:39 +08:00
|
|
|
chat_service,
|
2024-06-03 14:27:28 +08:00
|
|
|
user_service,
|
|
|
|
prev_message_state: Arc::new(RwLock::new(PrevMessageState::HasMore)),
|
|
|
|
latest_message_id: Default::default(),
|
2024-06-09 14:02:32 +08:00
|
|
|
stop_stream: Arc::new(AtomicBool::new(false)),
|
2024-08-06 07:56:13 +08:00
|
|
|
stream_buffer: Arc::new(Mutex::new(StringBuffer::default())),
|
2024-06-03 14:27:28 +08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
pub fn close(&self) {}
|
|
|
|
|
|
|
|
#[allow(dead_code)]
|
|
|
|
pub async fn pull_latest_message(&self, limit: i64) {
|
|
|
|
let latest_message_id = self
|
|
|
|
.latest_message_id
|
|
|
|
.load(std::sync::atomic::Ordering::Relaxed);
|
|
|
|
if latest_message_id > 0 {
|
|
|
|
let _ = self
|
|
|
|
.load_remote_chat_messages(limit, None, Some(latest_message_id))
|
|
|
|
.await;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2024-06-09 14:02:32 +08:00
|
|
|
pub async fn stop_stream_message(&self) {
|
|
|
|
self
|
|
|
|
.stop_stream
|
|
|
|
.store(true, std::sync::atomic::Ordering::SeqCst);
|
|
|
|
}
|
|
|
|
|
2024-06-03 14:27:28 +08:00
|
|
|
#[instrument(level = "info", skip_all, err)]
|
2025-01-08 10:43:03 +08:00
|
|
|
pub async fn stream_chat_message<'a>(
|
|
|
|
&'a self,
|
|
|
|
params: &'a StreamMessageParams<'a>,
|
2024-06-09 14:02:32 +08:00
|
|
|
) -> Result<ChatMessagePB, FlowyError> {
|
2024-08-11 20:39:25 +08:00
|
|
|
trace!(
|
2025-01-08 10:43:03 +08:00
|
|
|
"[Chat] stream chat message: chat_id={}, message={}, message_type={:?}, metadata={:?}, format={:?}",
|
2024-08-11 20:39:25 +08:00
|
|
|
self.chat_id,
|
2025-01-08 10:43:03 +08:00
|
|
|
params.message,
|
|
|
|
params.message_type,
|
|
|
|
params.metadata,
|
|
|
|
params.format,
|
2024-08-11 20:39:25 +08:00
|
|
|
);
|
|
|
|
|
2024-06-09 14:02:32 +08:00
|
|
|
// clear
|
|
|
|
self
|
|
|
|
.stop_stream
|
|
|
|
.store(false, std::sync::atomic::Ordering::SeqCst);
|
2024-08-06 07:56:13 +08:00
|
|
|
self.stream_buffer.lock().await.clear();
|
2024-06-03 14:27:28 +08:00
|
|
|
|
2025-01-08 10:43:03 +08:00
|
|
|
let mut question_sink = IsolateSink::new(Isolate::new(params.question_stream_port));
|
2024-08-10 17:23:37 +08:00
|
|
|
let answer_stream_buffer = self.stream_buffer.clone();
|
2024-06-03 14:27:28 +08:00
|
|
|
let uid = self.user_service.user_id()?;
|
|
|
|
let workspace_id = self.user_service.workspace_id()?;
|
2024-06-09 14:02:32 +08:00
|
|
|
|
2024-08-10 17:23:37 +08:00
|
|
|
let _ = question_sink
|
2025-01-08 10:43:03 +08:00
|
|
|
.send(StreamMessage::Text(params.message.to_string()).to_string())
|
2024-08-10 17:23:37 +08:00
|
|
|
.await;
|
2024-06-09 14:02:32 +08:00
|
|
|
let question = self
|
2024-06-30 17:38:39 +08:00
|
|
|
.chat_service
|
2024-08-06 07:56:13 +08:00
|
|
|
.create_question(
|
|
|
|
&workspace_id,
|
|
|
|
&self.chat_id,
|
2025-01-08 10:43:03 +08:00
|
|
|
params.message,
|
|
|
|
params.message_type.clone(),
|
2025-03-14 20:53:14 +08:00
|
|
|
&[],
|
2024-08-06 07:56:13 +08:00
|
|
|
)
|
2024-06-09 14:02:32 +08:00
|
|
|
.await
|
|
|
|
.map_err(|err| {
|
|
|
|
error!("Failed to send question: {}", err);
|
|
|
|
FlowyError::server_error()
|
|
|
|
})?;
|
|
|
|
|
2024-08-10 17:23:37 +08:00
|
|
|
let _ = question_sink
|
2024-08-11 20:39:25 +08:00
|
|
|
.send(StreamMessage::MessageId(question.message_id).to_string())
|
2024-08-10 17:23:37 +08:00
|
|
|
.await;
|
2025-03-14 20:53:14 +08:00
|
|
|
|
2024-08-11 20:39:25 +08:00
|
|
|
if let Err(err) = self
|
|
|
|
.chat_service
|
2025-01-08 10:43:03 +08:00
|
|
|
.index_message_metadata(&self.chat_id, ¶ms.metadata, &mut question_sink)
|
2024-08-11 20:39:25 +08:00
|
|
|
.await
|
|
|
|
{
|
|
|
|
error!("Failed to index file: {}", err);
|
2024-08-09 21:55:20 +08:00
|
|
|
}
|
2024-08-10 17:23:37 +08:00
|
|
|
let _ = question_sink.send(StreamMessage::Done.to_string()).await;
|
2024-08-09 21:55:20 +08:00
|
|
|
|
2024-08-11 20:39:25 +08:00
|
|
|
// Save message to disk
|
2024-08-14 16:58:56 +08:00
|
|
|
save_and_notify_message(uid, &self.chat_id, &self.user_service, question.clone())?;
|
2024-06-09 14:02:32 +08:00
|
|
|
|
2025-02-06 18:10:23 +08:00
|
|
|
let format = params.format.clone().map(Into::into).unwrap_or_default();
|
2025-01-08 10:43:03 +08:00
|
|
|
|
2024-12-19 14:13:53 +08:00
|
|
|
self.stream_response(
|
2025-01-08 10:43:03 +08:00
|
|
|
params.answer_stream_port,
|
2024-12-19 14:13:53 +08:00
|
|
|
answer_stream_buffer,
|
|
|
|
uid,
|
|
|
|
workspace_id,
|
|
|
|
question.message_id,
|
2025-01-08 10:43:03 +08:00
|
|
|
format,
|
2024-12-19 14:13:53 +08:00
|
|
|
);
|
|
|
|
|
|
|
|
let question_pb = ChatMessagePB::from(question);
|
|
|
|
Ok(question_pb)
|
|
|
|
}
|
|
|
|
|
|
|
|
#[instrument(level = "info", skip_all, err)]
|
|
|
|
pub async fn stream_regenerate_response(
|
|
|
|
&self,
|
|
|
|
question_id: i64,
|
|
|
|
answer_stream_port: i64,
|
2025-01-08 10:43:03 +08:00
|
|
|
format: Option<PredefinedFormatPB>,
|
2024-12-19 14:13:53 +08:00
|
|
|
) -> FlowyResult<()> {
|
|
|
|
trace!(
|
|
|
|
"[Chat] regenerate and stream chat message: chat_id={}",
|
|
|
|
self.chat_id,
|
|
|
|
);
|
|
|
|
|
|
|
|
// clear
|
|
|
|
self
|
|
|
|
.stop_stream
|
|
|
|
.store(false, std::sync::atomic::Ordering::SeqCst);
|
|
|
|
self.stream_buffer.lock().await.clear();
|
|
|
|
|
2025-02-06 18:10:23 +08:00
|
|
|
let format = format.map(Into::into).unwrap_or_default();
|
2025-01-08 10:43:03 +08:00
|
|
|
|
2024-12-19 14:13:53 +08:00
|
|
|
let answer_stream_buffer = self.stream_buffer.clone();
|
|
|
|
let uid = self.user_service.user_id()?;
|
|
|
|
let workspace_id = self.user_service.workspace_id()?;
|
|
|
|
|
|
|
|
self.stream_response(
|
|
|
|
answer_stream_port,
|
|
|
|
answer_stream_buffer,
|
|
|
|
uid,
|
|
|
|
workspace_id,
|
|
|
|
question_id,
|
2025-01-08 10:43:03 +08:00
|
|
|
format,
|
2024-12-19 14:13:53 +08:00
|
|
|
);
|
|
|
|
|
|
|
|
Ok(())
|
|
|
|
}
|
|
|
|
|
|
|
|
fn stream_response(
|
|
|
|
&self,
|
|
|
|
answer_stream_port: i64,
|
|
|
|
answer_stream_buffer: Arc<Mutex<StringBuffer>>,
|
|
|
|
uid: i64,
|
|
|
|
workspace_id: String,
|
|
|
|
question_id: i64,
|
2025-01-08 10:43:03 +08:00
|
|
|
format: ResponseFormat,
|
2024-12-19 14:13:53 +08:00
|
|
|
) {
|
2024-06-09 14:02:32 +08:00
|
|
|
let stop_stream = self.stop_stream.clone();
|
|
|
|
let chat_id = self.chat_id.clone();
|
2024-06-30 17:38:39 +08:00
|
|
|
let cloud_service = self.chat_service.clone();
|
2024-06-09 14:02:32 +08:00
|
|
|
let user_service = self.user_service.clone();
|
|
|
|
tokio::spawn(async move {
|
2024-08-10 17:23:37 +08:00
|
|
|
let mut answer_sink = IsolateSink::new(Isolate::new(answer_stream_port));
|
2024-06-09 14:02:32 +08:00
|
|
|
match cloud_service
|
2025-01-08 10:43:03 +08:00
|
|
|
.stream_answer(&workspace_id, &chat_id, question_id, format)
|
2024-06-09 14:02:32 +08:00
|
|
|
.await
|
|
|
|
{
|
|
|
|
Ok(mut stream) => {
|
|
|
|
while let Some(message) = stream.next().await {
|
|
|
|
match message {
|
2024-06-14 09:02:06 +08:00
|
|
|
Ok(message) => {
|
|
|
|
if stop_stream.load(std::sync::atomic::Ordering::Relaxed) {
|
2024-09-25 11:44:19 +08:00
|
|
|
trace!("[Chat] client stop streaming message");
|
2024-06-14 09:02:06 +08:00
|
|
|
break;
|
|
|
|
}
|
2024-08-06 07:56:13 +08:00
|
|
|
match message {
|
|
|
|
QuestionStreamValue::Answer { value } => {
|
2024-08-10 17:23:37 +08:00
|
|
|
answer_stream_buffer.lock().await.push_str(&value);
|
2025-01-08 10:43:03 +08:00
|
|
|
// trace!("[Chat] stream answer: {}", value);
|
|
|
|
if let Err(err) = answer_sink.send(format!("data:{}", value)).await {
|
2025-03-05 10:23:28 +08:00
|
|
|
error!("Failed to stream answer via IsolateSink: {}", err);
|
2025-01-08 10:43:03 +08:00
|
|
|
}
|
2024-08-06 07:56:13 +08:00
|
|
|
},
|
|
|
|
QuestionStreamValue::Metadata { value } => {
|
|
|
|
if let Ok(s) = serde_json::to_string(&value) {
|
2025-01-08 10:43:03 +08:00
|
|
|
// trace!("[Chat] stream metadata: {}", s);
|
2024-08-10 17:23:37 +08:00
|
|
|
answer_stream_buffer.lock().await.set_metadata(value);
|
|
|
|
let _ = answer_sink.send(format!("metadata:{}", s)).await;
|
2024-08-06 07:56:13 +08:00
|
|
|
}
|
|
|
|
},
|
2025-01-08 10:43:03 +08:00
|
|
|
QuestionStreamValue::KeepAlive => {
|
|
|
|
// trace!("[Chat] stream keep alive");
|
|
|
|
},
|
2024-08-06 07:56:13 +08:00
|
|
|
}
|
2024-06-09 14:02:32 +08:00
|
|
|
},
|
|
|
|
Err(err) => {
|
|
|
|
error!("[Chat] failed to stream answer: {}", err);
|
2024-08-10 17:23:37 +08:00
|
|
|
let _ = answer_sink.send(format!("error:{}", err)).await;
|
2024-06-09 14:02:32 +08:00
|
|
|
let pb = ChatMessageErrorPB {
|
|
|
|
chat_id: chat_id.clone(),
|
|
|
|
error_message: err.to_string(),
|
|
|
|
};
|
2024-12-08 18:25:25 +08:00
|
|
|
chat_notification_builder(&chat_id, ChatNotification::StreamChatMessageError)
|
2024-06-09 14:02:32 +08:00
|
|
|
.payload(pb)
|
|
|
|
.send();
|
2024-07-02 13:26:53 +08:00
|
|
|
return Err(err);
|
2024-06-09 14:02:32 +08:00
|
|
|
},
|
|
|
|
}
|
|
|
|
}
|
|
|
|
},
|
|
|
|
Err(err) => {
|
2024-09-25 11:44:19 +08:00
|
|
|
error!("[Chat] failed to start streaming: {}", err);
|
2024-07-24 14:23:09 +08:00
|
|
|
if err.is_ai_response_limit_exceeded() {
|
2024-08-10 17:23:37 +08:00
|
|
|
let _ = answer_sink.send("AI_RESPONSE_LIMIT".to_string()).await;
|
2025-01-22 09:42:24 +08:00
|
|
|
} else if err.is_ai_image_response_limit_exceeded() {
|
|
|
|
let _ = answer_sink
|
|
|
|
.send("AI_IMAGE_RESPONSE_LIMIT".to_string())
|
|
|
|
.await;
|
2025-02-03 20:52:08 +08:00
|
|
|
} else if err.is_ai_max_required() {
|
|
|
|
let _ = answer_sink
|
|
|
|
.send(format!("AI_MAX_REQUIRED:{}", err.msg))
|
|
|
|
.await;
|
2025-03-12 20:29:03 +08:00
|
|
|
} else if err.is_local_ai_not_ready() {
|
|
|
|
let _ = answer_sink
|
|
|
|
.send(format!("LOCAL_AI_NOT_READY:{}", err.msg))
|
|
|
|
.await;
|
2024-07-24 14:23:09 +08:00
|
|
|
} else {
|
2024-08-10 17:23:37 +08:00
|
|
|
let _ = answer_sink.send(format!("error:{}", err)).await;
|
2024-07-24 14:23:09 +08:00
|
|
|
}
|
|
|
|
|
2024-06-09 14:02:32 +08:00
|
|
|
let pb = ChatMessageErrorPB {
|
|
|
|
chat_id: chat_id.clone(),
|
|
|
|
error_message: err.to_string(),
|
|
|
|
};
|
2024-12-08 18:25:25 +08:00
|
|
|
chat_notification_builder(&chat_id, ChatNotification::StreamChatMessageError)
|
2024-06-09 14:02:32 +08:00
|
|
|
.payload(pb)
|
|
|
|
.send();
|
2024-07-02 13:26:53 +08:00
|
|
|
return Err(err);
|
2024-06-09 14:02:32 +08:00
|
|
|
},
|
|
|
|
}
|
2024-06-14 09:02:06 +08:00
|
|
|
|
2024-12-08 18:25:25 +08:00
|
|
|
chat_notification_builder(&chat_id, ChatNotification::FinishStreaming).send();
|
2025-01-10 09:43:18 +08:00
|
|
|
trace!("[Chat] finish streaming");
|
|
|
|
|
2024-08-10 17:23:37 +08:00
|
|
|
if answer_stream_buffer.lock().await.is_empty() {
|
2024-07-02 13:26:53 +08:00
|
|
|
return Ok(());
|
|
|
|
}
|
2024-08-10 17:23:37 +08:00
|
|
|
let content = answer_stream_buffer.lock().await.take_content();
|
|
|
|
let metadata = answer_stream_buffer.lock().await.take_metadata();
|
2024-06-14 09:02:06 +08:00
|
|
|
let answer = cloud_service
|
2025-01-20 11:47:03 +08:00
|
|
|
.create_answer(
|
|
|
|
&workspace_id,
|
|
|
|
&chat_id,
|
|
|
|
content.trim(),
|
|
|
|
question_id,
|
|
|
|
metadata,
|
|
|
|
)
|
2024-06-14 09:02:06 +08:00
|
|
|
.await?;
|
2024-08-14 16:58:56 +08:00
|
|
|
save_and_notify_message(uid, &chat_id, &user_service, answer)?;
|
2024-06-09 14:02:32 +08:00
|
|
|
Ok::<(), FlowyError>(())
|
|
|
|
});
|
|
|
|
}
|
|
|
|
|
2024-06-03 14:27:28 +08:00
|
|
|
/// Load chat messages for a given `chat_id`.
|
|
|
|
///
|
|
|
|
/// 1. When opening a chat:
|
|
|
|
/// - Loads local chat messages.
|
|
|
|
/// - `after_message_id` and `before_message_id` are `None`.
|
|
|
|
/// - Spawns a task to load messages from the remote server, notifying the user when the remote messages are loaded.
|
|
|
|
///
|
|
|
|
/// 2. Loading more messages in an existing chat with `after_message_id`:
|
|
|
|
/// - `after_message_id` is the last message ID in the current chat messages.
|
|
|
|
///
|
|
|
|
/// 3. Loading more messages in an existing chat with `before_message_id`:
|
|
|
|
/// - `before_message_id` is the first message ID in the current chat messages.
|
|
|
|
pub async fn load_prev_chat_messages(
|
|
|
|
&self,
|
|
|
|
limit: i64,
|
|
|
|
before_message_id: Option<i64>,
|
|
|
|
) -> Result<ChatMessageListPB, FlowyError> {
|
|
|
|
trace!(
|
2024-06-09 14:02:32 +08:00
|
|
|
"[Chat] Loading messages from disk: chat_id={}, limit={}, before_message_id={:?}",
|
2024-06-03 14:27:28 +08:00
|
|
|
self.chat_id,
|
|
|
|
limit,
|
|
|
|
before_message_id
|
|
|
|
);
|
|
|
|
let messages = self
|
|
|
|
.load_local_chat_messages(limit, None, before_message_id)
|
|
|
|
.await?;
|
|
|
|
|
|
|
|
// If the number of messages equals the limit, then no need to load more messages from remote
|
|
|
|
if messages.len() == limit as usize {
|
2024-06-09 14:02:32 +08:00
|
|
|
let pb = ChatMessageListPB {
|
2024-06-03 14:27:28 +08:00
|
|
|
messages,
|
2024-06-09 14:02:32 +08:00
|
|
|
has_more: true,
|
2024-06-03 14:27:28 +08:00
|
|
|
total: 0,
|
2024-06-09 14:02:32 +08:00
|
|
|
};
|
2024-12-08 18:25:25 +08:00
|
|
|
chat_notification_builder(&self.chat_id, ChatNotification::DidLoadPrevChatMessage)
|
2024-06-09 14:02:32 +08:00
|
|
|
.payload(pb.clone())
|
|
|
|
.send();
|
|
|
|
return Ok(pb);
|
2024-06-03 14:27:28 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
if matches!(
|
|
|
|
*self.prev_message_state.read().await,
|
|
|
|
PrevMessageState::HasMore
|
|
|
|
) {
|
|
|
|
*self.prev_message_state.write().await = PrevMessageState::Loading;
|
|
|
|
if let Err(err) = self
|
|
|
|
.load_remote_chat_messages(limit, before_message_id, None)
|
|
|
|
.await
|
|
|
|
{
|
|
|
|
error!("Failed to load previous chat messages: {}", err);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
Ok(ChatMessageListPB {
|
|
|
|
messages,
|
2024-06-09 14:02:32 +08:00
|
|
|
has_more: true,
|
2024-06-03 14:27:28 +08:00
|
|
|
total: 0,
|
|
|
|
})
|
|
|
|
}
|
|
|
|
|
|
|
|
pub async fn load_latest_chat_messages(
|
|
|
|
&self,
|
|
|
|
limit: i64,
|
|
|
|
after_message_id: Option<i64>,
|
|
|
|
) -> Result<ChatMessageListPB, FlowyError> {
|
|
|
|
trace!(
|
2024-06-09 14:02:32 +08:00
|
|
|
"[Chat] Loading new messages: chat_id={}, limit={}, after_message_id={:?}",
|
2024-06-03 14:27:28 +08:00
|
|
|
self.chat_id,
|
|
|
|
limit,
|
|
|
|
after_message_id,
|
|
|
|
);
|
|
|
|
let messages = self
|
|
|
|
.load_local_chat_messages(limit, after_message_id, None)
|
|
|
|
.await?;
|
|
|
|
|
|
|
|
trace!(
|
2024-06-09 14:02:32 +08:00
|
|
|
"[Chat] Loaded local chat messages: chat_id={}, messages={}",
|
2024-06-03 14:27:28 +08:00
|
|
|
self.chat_id,
|
|
|
|
messages.len()
|
|
|
|
);
|
|
|
|
|
|
|
|
// If the number of messages equals the limit, then no need to load more messages from remote
|
|
|
|
let has_more = !messages.is_empty();
|
|
|
|
let _ = self
|
|
|
|
.load_remote_chat_messages(limit, None, after_message_id)
|
|
|
|
.await;
|
|
|
|
Ok(ChatMessageListPB {
|
|
|
|
messages,
|
|
|
|
has_more,
|
|
|
|
total: 0,
|
|
|
|
})
|
|
|
|
}
|
|
|
|
|
|
|
|
async fn load_remote_chat_messages(
|
|
|
|
&self,
|
|
|
|
limit: i64,
|
|
|
|
before_message_id: Option<i64>,
|
|
|
|
after_message_id: Option<i64>,
|
|
|
|
) -> FlowyResult<()> {
|
|
|
|
trace!(
|
2024-06-09 14:02:32 +08:00
|
|
|
"[Chat] start loading messages from remote: chat_id={}, limit={}, before_message_id={:?}, after_message_id={:?}",
|
2024-06-03 14:27:28 +08:00
|
|
|
self.chat_id,
|
|
|
|
limit,
|
|
|
|
before_message_id,
|
|
|
|
after_message_id
|
|
|
|
);
|
|
|
|
let workspace_id = self.user_service.workspace_id()?;
|
|
|
|
let chat_id = self.chat_id.clone();
|
2024-06-30 17:38:39 +08:00
|
|
|
let cloud_service = self.chat_service.clone();
|
2024-06-03 14:27:28 +08:00
|
|
|
let user_service = self.user_service.clone();
|
|
|
|
let uid = self.uid;
|
|
|
|
let prev_message_state = self.prev_message_state.clone();
|
|
|
|
let latest_message_id = self.latest_message_id.clone();
|
|
|
|
tokio::spawn(async move {
|
|
|
|
let cursor = match (before_message_id, after_message_id) {
|
|
|
|
(Some(bid), _) => MessageCursor::BeforeMessageId(bid),
|
|
|
|
(_, Some(aid)) => MessageCursor::AfterMessageId(aid),
|
|
|
|
_ => MessageCursor::NextBack,
|
|
|
|
};
|
|
|
|
match cloud_service
|
|
|
|
.get_chat_messages(&workspace_id, &chat_id, cursor.clone(), limit as u64)
|
|
|
|
.await
|
|
|
|
{
|
|
|
|
Ok(resp) => {
|
|
|
|
// Save chat messages to local disk
|
2024-08-14 19:44:15 +08:00
|
|
|
if let Err(err) = save_chat_message_disk(
|
2024-06-03 14:27:28 +08:00
|
|
|
user_service.sqlite_connection(uid)?,
|
|
|
|
&chat_id,
|
|
|
|
resp.messages.clone(),
|
|
|
|
) {
|
|
|
|
error!("Failed to save chat:{} messages: {}", chat_id, err);
|
|
|
|
}
|
|
|
|
|
|
|
|
// Update latest message ID
|
|
|
|
if !resp.messages.is_empty() {
|
|
|
|
latest_message_id.store(
|
|
|
|
resp.messages[0].message_id,
|
|
|
|
std::sync::atomic::Ordering::Relaxed,
|
|
|
|
);
|
|
|
|
}
|
|
|
|
|
|
|
|
let pb = ChatMessageListPB::from(resp);
|
|
|
|
trace!(
|
2024-06-09 14:02:32 +08:00
|
|
|
"[Chat] Loaded messages from remote: chat_id={}, messages={}, hasMore: {}, cursor:{:?}",
|
2024-06-03 14:27:28 +08:00
|
|
|
chat_id,
|
2024-06-09 14:02:32 +08:00
|
|
|
pb.messages.len(),
|
|
|
|
pb.has_more,
|
|
|
|
cursor,
|
2024-06-03 14:27:28 +08:00
|
|
|
);
|
|
|
|
if matches!(cursor, MessageCursor::BeforeMessageId(_)) {
|
|
|
|
if pb.has_more {
|
|
|
|
*prev_message_state.write().await = PrevMessageState::HasMore;
|
|
|
|
} else {
|
|
|
|
*prev_message_state.write().await = PrevMessageState::NoMore;
|
|
|
|
}
|
2024-12-08 18:25:25 +08:00
|
|
|
chat_notification_builder(&chat_id, ChatNotification::DidLoadPrevChatMessage)
|
2024-06-03 14:27:28 +08:00
|
|
|
.payload(pb)
|
|
|
|
.send();
|
|
|
|
} else {
|
2024-12-08 18:25:25 +08:00
|
|
|
chat_notification_builder(&chat_id, ChatNotification::DidLoadLatestChatMessage)
|
2024-06-03 14:27:28 +08:00
|
|
|
.payload(pb)
|
|
|
|
.send();
|
|
|
|
}
|
|
|
|
},
|
|
|
|
Err(err) => error!("Failed to load chat messages: {}", err),
|
|
|
|
}
|
|
|
|
Ok::<(), FlowyError>(())
|
|
|
|
});
|
|
|
|
Ok(())
|
|
|
|
}
|
|
|
|
|
2024-12-19 14:13:53 +08:00
|
|
|
pub async fn get_question_id_from_answer_id(
|
|
|
|
&self,
|
|
|
|
answer_message_id: i64,
|
|
|
|
) -> Result<i64, FlowyError> {
|
|
|
|
let conn = self.user_service.sqlite_connection(self.uid)?;
|
|
|
|
|
|
|
|
let local_result = select_message_where_match_reply_message_id(conn, answer_message_id)?
|
|
|
|
.map(|message| message.message_id);
|
|
|
|
|
|
|
|
if let Some(message_id) = local_result {
|
|
|
|
return Ok(message_id);
|
|
|
|
}
|
|
|
|
|
|
|
|
let workspace_id = self.user_service.workspace_id()?;
|
|
|
|
let chat_id = self.chat_id.clone();
|
|
|
|
let cloud_service = self.chat_service.clone();
|
|
|
|
|
|
|
|
let question = cloud_service
|
|
|
|
.get_question_from_answer_id(&workspace_id, &chat_id, answer_message_id)
|
|
|
|
.await?;
|
|
|
|
|
|
|
|
Ok(question.message_id)
|
|
|
|
}
|
|
|
|
|
2024-06-03 14:27:28 +08:00
|
|
|
pub async fn get_related_question(
|
|
|
|
&self,
|
|
|
|
message_id: i64,
|
|
|
|
) -> Result<RepeatedRelatedQuestionPB, FlowyError> {
|
|
|
|
let workspace_id = self.user_service.workspace_id()?;
|
|
|
|
let resp = self
|
2024-06-30 17:38:39 +08:00
|
|
|
.chat_service
|
2024-06-03 14:27:28 +08:00
|
|
|
.get_related_message(&workspace_id, &self.chat_id, message_id)
|
|
|
|
.await?;
|
|
|
|
|
|
|
|
trace!(
|
2024-06-09 14:02:32 +08:00
|
|
|
"[Chat] related messages: chat_id={}, message_id={}, messages:{:?}",
|
2024-06-03 14:27:28 +08:00
|
|
|
self.chat_id,
|
|
|
|
message_id,
|
|
|
|
resp.items
|
|
|
|
);
|
|
|
|
Ok(RepeatedRelatedQuestionPB::from(resp))
|
|
|
|
}
|
|
|
|
|
|
|
|
#[instrument(level = "debug", skip_all, err)]
|
|
|
|
pub async fn generate_answer(&self, question_message_id: i64) -> FlowyResult<ChatMessagePB> {
|
2024-06-09 14:02:32 +08:00
|
|
|
trace!(
|
|
|
|
"[Chat] generate answer: chat_id={}, question_message_id={}",
|
|
|
|
self.chat_id,
|
|
|
|
question_message_id
|
|
|
|
);
|
2024-06-03 14:27:28 +08:00
|
|
|
let workspace_id = self.user_service.workspace_id()?;
|
2024-06-09 14:02:32 +08:00
|
|
|
let answer = self
|
2024-06-30 17:38:39 +08:00
|
|
|
.chat_service
|
2024-08-06 07:56:13 +08:00
|
|
|
.get_answer(&workspace_id, &self.chat_id, question_message_id)
|
2024-06-03 14:27:28 +08:00
|
|
|
.await?;
|
|
|
|
|
2024-08-14 16:58:56 +08:00
|
|
|
save_and_notify_message(self.uid, &self.chat_id, &self.user_service, answer.clone())?;
|
2024-06-09 14:02:32 +08:00
|
|
|
let pb = ChatMessagePB::from(answer);
|
2024-06-03 14:27:28 +08:00
|
|
|
Ok(pb)
|
|
|
|
}
|
|
|
|
|
|
|
|
async fn load_local_chat_messages(
|
|
|
|
&self,
|
|
|
|
limit: i64,
|
|
|
|
after_message_id: Option<i64>,
|
|
|
|
before_message_id: Option<i64>,
|
|
|
|
) -> Result<Vec<ChatMessagePB>, FlowyError> {
|
|
|
|
let conn = self.user_service.sqlite_connection(self.uid)?;
|
|
|
|
let records = select_chat_messages(
|
|
|
|
conn,
|
|
|
|
&self.chat_id,
|
|
|
|
limit,
|
|
|
|
after_message_id,
|
|
|
|
before_message_id,
|
|
|
|
)?;
|
|
|
|
let messages = records
|
|
|
|
.into_iter()
|
|
|
|
.map(|record| ChatMessagePB {
|
|
|
|
message_id: record.message_id,
|
|
|
|
content: record.content,
|
|
|
|
created_at: record.created_at,
|
|
|
|
author_type: record.author_type,
|
|
|
|
author_id: record.author_id,
|
|
|
|
reply_message_id: record.reply_message_id,
|
2024-08-06 07:56:13 +08:00
|
|
|
metadata: record.metadata,
|
2024-06-03 14:27:28 +08:00
|
|
|
})
|
|
|
|
.collect::<Vec<_>>();
|
|
|
|
|
|
|
|
Ok(messages)
|
|
|
|
}
|
2024-07-15 15:23:23 +08:00
|
|
|
|
|
|
|
#[instrument(level = "debug", skip_all, err)]
|
|
|
|
pub async fn index_file(&self, file_path: PathBuf) -> FlowyResult<()> {
|
|
|
|
if !file_path.exists() {
|
|
|
|
return Err(
|
|
|
|
FlowyError::record_not_found().with_context(format!("{:?} not exist", file_path)),
|
|
|
|
);
|
|
|
|
}
|
|
|
|
|
|
|
|
if !file_path.is_file() {
|
|
|
|
return Err(
|
|
|
|
FlowyError::invalid_data().with_context(format!("{:?} is not a file ", file_path)),
|
|
|
|
);
|
|
|
|
}
|
|
|
|
|
|
|
|
trace!(
|
|
|
|
"[Chat] index file: chat_id={}, file_path={:?}",
|
|
|
|
self.chat_id,
|
|
|
|
file_path
|
|
|
|
);
|
|
|
|
self
|
|
|
|
.chat_service
|
2025-03-13 10:53:20 +08:00
|
|
|
.embed_file(
|
2024-08-08 12:07:00 +08:00
|
|
|
&self.user_service.workspace_id()?,
|
|
|
|
&file_path,
|
|
|
|
&self.chat_id,
|
2024-08-12 15:43:17 +08:00
|
|
|
None,
|
2024-08-08 12:07:00 +08:00
|
|
|
)
|
2024-07-15 15:23:23 +08:00
|
|
|
.await?;
|
|
|
|
|
2024-08-08 12:07:00 +08:00
|
|
|
trace!(
|
|
|
|
"[Chat] created index file record: chat_id={}, file_path={:?}",
|
|
|
|
self.chat_id,
|
|
|
|
file_path
|
|
|
|
);
|
|
|
|
|
2024-07-15 15:23:23 +08:00
|
|
|
Ok(())
|
|
|
|
}
|
2024-06-03 14:27:28 +08:00
|
|
|
}
|
|
|
|
|
2024-08-14 19:44:15 +08:00
|
|
|
fn save_chat_message_disk(
|
2024-06-03 14:27:28 +08:00
|
|
|
conn: DBConnection,
|
|
|
|
chat_id: &str,
|
|
|
|
messages: Vec<ChatMessage>,
|
|
|
|
) -> FlowyResult<()> {
|
|
|
|
let records = messages
|
|
|
|
.into_iter()
|
|
|
|
.map(|message| ChatMessageTable {
|
|
|
|
message_id: message.message_id,
|
|
|
|
chat_id: chat_id.to_string(),
|
|
|
|
content: message.content,
|
|
|
|
created_at: message.created_at.timestamp(),
|
|
|
|
author_type: message.author.author_type as i64,
|
|
|
|
author_id: message.author.author_id.to_string(),
|
|
|
|
reply_message_id: message.reply_message_id,
|
2024-08-06 07:56:13 +08:00
|
|
|
metadata: Some(serde_json::to_string(&message.meta_data).unwrap_or_default()),
|
2024-06-03 14:27:28 +08:00
|
|
|
})
|
|
|
|
.collect::<Vec<_>>();
|
|
|
|
insert_chat_messages(conn, &records)?;
|
|
|
|
Ok(())
|
|
|
|
}
|
2024-08-06 07:56:13 +08:00
|
|
|
|
|
|
|
#[derive(Debug, Default)]
|
|
|
|
struct StringBuffer {
|
|
|
|
content: String,
|
|
|
|
metadata: Option<serde_json::Value>,
|
|
|
|
}
|
|
|
|
|
|
|
|
impl StringBuffer {
|
|
|
|
fn clear(&mut self) {
|
|
|
|
self.content.clear();
|
|
|
|
self.metadata = None;
|
|
|
|
}
|
|
|
|
|
|
|
|
fn push_str(&mut self, value: &str) {
|
|
|
|
self.content.push_str(value);
|
|
|
|
}
|
|
|
|
|
|
|
|
fn set_metadata(&mut self, value: serde_json::Value) {
|
|
|
|
self.metadata = Some(value);
|
|
|
|
}
|
|
|
|
|
|
|
|
fn is_empty(&self) -> bool {
|
|
|
|
self.content.is_empty()
|
|
|
|
}
|
|
|
|
|
|
|
|
fn take_metadata(&mut self) -> Option<serde_json::Value> {
|
|
|
|
self.metadata.take()
|
|
|
|
}
|
|
|
|
|
|
|
|
fn take_content(&mut self) -> String {
|
|
|
|
std::mem::take(&mut self.content)
|
|
|
|
}
|
|
|
|
}
|
2024-08-14 16:58:56 +08:00
|
|
|
|
|
|
|
pub(crate) fn save_and_notify_message(
|
|
|
|
uid: i64,
|
|
|
|
chat_id: &str,
|
|
|
|
user_service: &Arc<dyn AIUserService>,
|
|
|
|
message: ChatMessage,
|
|
|
|
) -> Result<(), FlowyError> {
|
|
|
|
trace!("[Chat] save answer: answer={:?}", message);
|
2024-08-14 19:44:15 +08:00
|
|
|
save_chat_message_disk(
|
2024-08-14 16:58:56 +08:00
|
|
|
user_service.sqlite_connection(uid)?,
|
|
|
|
chat_id,
|
|
|
|
vec![message.clone()],
|
|
|
|
)?;
|
|
|
|
let pb = ChatMessagePB::from(message);
|
2024-12-08 18:25:25 +08:00
|
|
|
chat_notification_builder(chat_id, ChatNotification::DidReceiveChatMessage)
|
2024-08-14 16:58:56 +08:00
|
|
|
.payload(pb)
|
|
|
|
.send();
|
|
|
|
|
|
|
|
Ok(())
|
|
|
|
}
|