diff --git a/frontend/rust-lib/flowy-ai/src/ai_manager.rs b/frontend/rust-lib/flowy-ai/src/ai_manager.rs
index 7d8b48e11a..56316ca71a 100644
--- a/frontend/rust-lib/flowy-ai/src/ai_manager.rs
+++ b/frontend/rust-lib/flowy-ai/src/ai_manager.rs
@@ -217,7 +217,6 @@ impl AIManager {
     let summary = select_chat_summary(&mut conn, chat_id).unwrap_or_default();
     let model = self.get_active_model(&chat_id.to_string()).await;
-    trace!("[AI Plugin] notify open chat: {}", chat_id);
     self
       .local_ai
       .open_chat(&workspace_id, chat_id, &model.name, rag_ids, summary)
@@ -240,7 +239,7 @@ impl AIManager {
       .await
     {
       Ok(settings) => {
-        local_ai.set_rag_ids(&chat_id, &settings.rag_ids);
+        local_ai.set_rag_ids(&chat_id, &settings.rag_ids).await;
         let rag_ids = settings
           .rag_ids
           .into_iter()
@@ -712,7 +711,7 @@ impl AIManager {
     let user_service = self.user_service.clone();
     let external_service = self.external_service.clone();
-    self.local_ai.set_rag_ids(chat_id, &rag_ids);
+    self.local_ai.set_rag_ids(chat_id, &rag_ids).await;
     let rag_ids = rag_ids
       .into_iter()
diff --git a/frontend/rust-lib/flowy-ai/src/embeddings/store.rs b/frontend/rust-lib/flowy-ai/src/embeddings/store.rs
index 75054055a0..031d0fb7b8 100644
--- a/frontend/rust-lib/flowy-ai/src/embeddings/store.rs
+++ b/frontend/rust-lib/flowy-ai/src/embeddings/store.rs
@@ -1,6 +1,7 @@
 use crate::embeddings::document_indexer::split_text_into_chunks;
 use crate::embeddings::embedder::{Embedder, OllamaEmbedder};
 use crate::embeddings::indexer::{EmbeddingModel, IndexerProvider};
+use crate::local_ai::chat::retriever::MultipleSourceRetrieverStore;
 use async_trait::async_trait;
 use flowy_ai_pub::cloud::CollabType;
 use flowy_ai_pub::entities::{RAG_IDS, SOURCE_ID};
@@ -9,10 +10,8 @@ use flowy_sqlite_vec::db::VectorSqliteDB;
 use flowy_sqlite_vec::entities::{EmbeddedContent, SqliteEmbeddedDocument};
 use futures::stream::{self, StreamExt};
 use langchain_rust::llm::client::OllamaClient;
-use langchain_rust::{
-  schemas::Document,
-  vectorstore::{VecStoreOptions, VectorStore},
-};
+use langchain_rust::schemas::Document;
+use langchain_rust::vectorstore::{VecStoreOptions, VectorStore};
 use ollama_rs::generation::embeddings::request::{EmbeddingsInput, GenerateEmbeddingsRequest};
 use serde_json::Value;
 use std::collections::HashMap;
@@ -86,6 +85,80 @@ impl SqliteVectorStore {
   }
 }
 
+#[async_trait]
+impl MultipleSourceRetrieverStore for SqliteVectorStore {
+  fn retriever_name(&self) -> &'static str {
+    "Sqlite Multiple Source Retriever"
+  }
+
+  async fn read_documents(
+    &self,
+    workspace_id: &Uuid,
+    query: &str,
+    limit: usize,
+    rag_ids: &[String],
+    score_threshold: f32,
+    _full_search: bool,
+  ) -> FlowyResult<Vec<Document>> {
+    let vector_db = match self.vector_db.upgrade() {
+      Some(db) => db,
+      None => return Err(FlowyError::internal().with_context("Vector database not initialized")),
+    };
+
+    // Create embedder and generate embedding for query
+    let embedder = self.create_embedder()?;
+    let request = GenerateEmbeddingsRequest::new(
+      embedder.model().name().to_string(),
+      EmbeddingsInput::Single(query.to_string()),
+    );
+
+    let embedding = embedder.embed(request).await?.embeddings;
+    if embedding.is_empty() {
+      return Ok(Vec::new());
+    }
+
+    debug_assert!(embedding.len() == 1);
+    let query_embedding = embedding.first().unwrap();
+
+    // Perform similarity search in the database
+    let results = vector_db
+      .search_with_score(
+        &workspace_id.to_string(),
+        rag_ids,
+        query_embedding,
+        limit as i32,
+        score_threshold,
+      )
+      .await?;
+
+    trace!(
+      "[VectorStore] Found {} results for query: {}, rag_ids: {:?}, score_threshold: {}",
+      results.len(),
+      query,
+      rag_ids,
+      score_threshold
+    );
+
+    // Convert results to Documents
+    let documents = results
+      .into_iter()
+      .map(|result| {
+        let mut metadata = HashMap::new();
+
+        if let Some(map) = result.metadata.as_ref().and_then(|v| v.as_object()) {
+          for (key, value) in map {
+            metadata.insert(key.clone(), value.clone());
+          }
+        }
+
+        Document::new(result.content).with_metadata(metadata)
+      })
+      .collect();
+
+    Ok(documents)
+  }
+}
+
 #[async_trait]
 impl VectorStore for SqliteVectorStore {
   type Options = VecStoreOptions<Value>;
@@ -215,74 +288,23 @@ impl VectorStore for SqliteVectorStore {
     // Return empty result if workspace_id is missing
     let workspace_id = match workspace_id {
-      Some(id) => id.to_string(),
+      Some(id) => id,
       None => {
         warn!("[VectorStore] Missing workspace_id in filters. Returning empty result.");
         return Ok(Vec::new());
       },
     };
 
-    // Get the vector database
-    let vector_db = match self.vector_db.upgrade() {
-      Some(db) => db,
-      None => return Err("Vector database not initialized".into()),
-    };
-
-    // Create embedder and generate embedding for query
-    let embedder = self.create_embedder()?;
-    let request = GenerateEmbeddingsRequest::new(
-      embedder.model().name().to_string(),
-      EmbeddingsInput::Single(query.to_string()),
-    );
-
-    let embedding = match embedder.embed(request).await {
-      Ok(result) => result.embeddings,
-      Err(e) => return Err(Box::new(e)),
-    };
-
-    if embedding.is_empty() {
-      return Ok(Vec::new());
-    }
-
-    let score_threshold = opt.score_threshold.unwrap_or(0.4);
-    debug_assert!(embedding.len() == 1);
-    let query_embedding = embedding.first().unwrap();
-
-    // Perform similarity search in the database
-    let results = vector_db
-      .search_with_score(
+    self
+      .read_documents(
         &workspace_id,
+        query,
+        limit,
         &rag_ids,
-        query_embedding,
-        limit as i32,
-        score_threshold,
+        opt.score_threshold.unwrap_or(0.4),
+        true,
       )
-      .await?;
-
-    trace!(
-      "[VectorStore] Found {} results for query:{}, rag_ids: {:?}, score_threshold: {}",
-      results.len(),
-      query,
-      rag_ids,
-      score_threshold
-    );
-
-    // Convert results to Documents
-    let documents = results
-      .into_iter()
-      .map(|result| {
-        let mut metadata = HashMap::new();
-
-        if let Some(map) = result.metadata.as_ref().and_then(|v| v.as_object()) {
-          for (key, value) in map {
-            metadata.insert(key.clone(), value.clone());
-          }
-        }
-
-        Document::new(result.content).with_metadata(metadata)
-      })
-      .collect();
-
-    Ok(documents)
+      .await
+      .map_err(|err| Box::new(err) as Box<dyn Error>)
   }
 }
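With this impl, the SQLite vector store can be queried through the same `MultipleSourceRetrieverStore` interface as any other registered source. A minimal sketch of driving it through the trait; the helper name, workspace id, rag ids, and query string are illustrative, not part of this change:

```rust
use uuid::Uuid;

// Hypothetical helper: `store` could be the SqliteVectorStore above or any
// other source later registered with the chat controller.
async fn preview_results(
  store: &dyn MultipleSourceRetrieverStore,
  workspace_id: &Uuid,
) -> flowy_error::FlowyResult<()> {
  let rag_ids = vec!["doc-1".to_string(), "doc-2".to_string()];
  // limit = 5 and score_threshold = 0.4 mirror values used elsewhere in this
  // change; the SQLite implementation ignores `full_search`.
  let docs = store
    .read_documents(workspace_id, "export a grid to CSV", 5, &rag_ids, 0.4, false)
    .await?;
  for doc in docs {
    println!("score {:.2}: {}", doc.score, doc.page_content);
  }
  Ok(())
}
```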
diff --git a/frontend/rust-lib/flowy-ai/src/local_ai/chat/chains/context_question_chain.rs b/frontend/rust-lib/flowy-ai/src/local_ai/chat/chains/context_question_chain.rs
index e98877a73c..7886007030 100644
--- a/frontend/rust-lib/flowy-ai/src/local_ai/chat/chains/context_question_chain.rs
+++ b/frontend/rust-lib/flowy-ai/src/local_ai/chat/chains/context_question_chain.rs
@@ -1,6 +1,7 @@
 use crate::local_ai::chat::llm::LLMOllama;
 use crate::SqliteVectorStore;
 use flowy_error::{FlowyError, FlowyResult};
+use flowy_sqlite_vec::entities::EmbeddedContent;
 use langchain_rust::language_models::llm::LLM;
 use langchain_rust::prompt::TemplateFormat;
 use langchain_rust::prompt::{PromptFromatter, PromptTemplate};
@@ -10,6 +11,7 @@ use ollama_rs::generation::parameters::{FormatType, JsonStructure};
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 use serde_json::json;
+use std::fmt::Debug;
 use tracing::trace;
 use uuid::Uuid;
 
@@ -60,13 +62,13 @@ pub struct ContextQuestion {
   pub object_id: String,
 }
 
-pub struct RelatedQuestionChain {
+pub struct ContextRelatedQuestionChain {
   workspace_id: Uuid,
   llm: LLMOllama,
   store: SqliteVectorStore,
 }
 
-impl RelatedQuestionChain {
+impl ContextRelatedQuestionChain {
   pub fn new(workspace_id: Uuid, ollama: LLMOllama, store: SqliteVectorStore) -> Self {
     let format = FormatType::StructuredJson(JsonStructure::new::());
     Self {
@@ -76,25 +78,16 @@ impl RelatedQuestionChain {
     }
   }
 
-  pub async fn generate_questions(&self, rag_ids: &[String]) -> FlowyResult<Vec<ContextQuestion>> {
-    trace!(
-      "[embedding] Generating context related questions for RAG IDs: {:?}",
-      rag_ids
-    );
-
-    let context = self
-      .store
-      .select_all_embedded_content(&self.workspace_id.to_string(), rag_ids, 3)
-      .await?;
-
-    trace!(
-      "[embedding] Generating related questions base on: {:?}",
-      context,
-    );
-
-    let context_str = json!(context).to_string();
+  pub async fn generate_questions_from_context<T>(
+    &self,
+    rag_ids: &[T],
+    context: &str,
+  ) -> FlowyResult<Vec<ContextQuestion>>
+  where
+    T: AsRef<str>,
+  {
     let input_variables = prompt_args! {
-      "context" => context_str,
+      "context" => context,
     };
 
     let template = PromptTemplate::new(
@@ -116,8 +109,42 @@ impl RelatedQuestionChain {
     // filter out questions that are not in the rag_ids
     parsed_result
       .questions
-      .retain(|v| rag_ids.contains(&v.object_id));
+      .retain(|v| rag_ids.iter().any(|id| id.as_ref() == v.object_id));
 
     Ok(parsed_result.questions)
   }
+
+  pub async fn generate_questions<T>(
+    &self,
+    rag_ids: &[T],
+  ) -> FlowyResult<(String, Vec<ContextQuestion>)>
+  where
+    T: AsRef<str> + Debug,
+  {
+    trace!(
+      "[embedding] Generating context related questions for RAG IDs: {:?}",
+      rag_ids
+    );
+
+    let rag_ids_str: Vec<String> = rag_ids.iter().map(|id| id.as_ref().to_string()).collect();
+    let context = self
+      .store
+      .select_all_embedded_content(&self.workspace_id.to_string(), &rag_ids_str, 3)
+      .await?;
+
+    trace!(
+      "[embedding] Generating related questions based on: {:?}",
+      context,
+    );
+
+    let context_str = embedded_documents_to_context_str(context);
+    self
+      .generate_questions_from_context(rag_ids, &context_str)
+      .await
+      .map(|questions| (context_str, questions))
+  }
+}
+
+pub fn embedded_documents_to_context_str(documents: Vec<EmbeddedContent>) -> String {
+  json!(documents).to_string()
+}
diff --git a/frontend/rust-lib/flowy-ai/src/local_ai/chat/chains/conversation_chain.rs b/frontend/rust-lib/flowy-ai/src/local_ai/chat/chains/conversation_chain.rs
index d48860d563..5e1ebfb6c5 100644
--- a/frontend/rust-lib/flowy-ai/src/local_ai/chat/chains/conversation_chain.rs
+++ b/frontend/rust-lib/flowy-ai/src/local_ai/chat/chains/conversation_chain.rs
@@ -1,11 +1,17 @@
-use crate::local_ai::chat::chains::context_question_chain::RelatedQuestionChain;
+use crate::local_ai::chat::chains::context_question_chain::{
+  embedded_documents_to_context_str, ContextRelatedQuestionChain,
+};
+use crate::local_ai::chat::chains::related_question_chain::RelatedQuestionChain;
 use crate::local_ai::chat::llm::LLMOllama;
+use crate::local_ai::chat::retriever::AFRetriever;
 use crate::SqliteVectorStore;
+use arc_swap::ArcSwap;
 use async_stream::stream;
 use async_trait::async_trait;
 use flowy_ai_pub::cloud::{ContextSuggestedQuestion, QuestionStreamValue};
-use flowy_ai_pub::entities::{RAG_IDS, SOURCE_ID};
+use flowy_ai_pub::entities::SOURCE_ID;
 use flowy_error::{FlowyError, FlowyResult};
+use flowy_sqlite_vec::entities::EmbeddedContent;
 use futures::Stream;
 use futures_util::{pin_mut, StreamExt};
 use langchain_rust::chain::{
@@ -15,11 +21,9 @@ use langchain_rust::chain::{
 use langchain_rust::language_models::{GenerateResult, TokenUsage};
 use langchain_rust::memory::SimpleMemory;
 use langchain_rust::prompt::{FormatPrompter, PromptArgs};
-use langchain_rust::schemas::{BaseMemory, Document, Message, Retriever, StreamData};
-use langchain_rust::vectorstore::{VecStoreOptions, VectorStore};
+use langchain_rust::schemas::{BaseMemory, Document, Message, StreamData};
 use serde::{Deserialize, Serialize};
 use serde_json::{json, Value};
-use std::error::Error;
 use std::{collections::HashMap, pin::Pin, sync::Arc};
 use tokio::sync::Mutex;
 use tokio_util::either::Either;
@@ -37,15 +41,16 @@
 const CONVERSATIONAL_RETRIEVAL_QA_DEFAULT_INPUT_KEY: &str = "question";
 
 pub struct ConversationalRetrieverChain {
   pub(crate) ollama: LLMOllama,
-  pub(crate) retriever: AFRetriever,
+  pub(crate) retriever: Box<dyn AFRetriever>,
   pub memory: Arc<Mutex<dyn BaseMemory>>,
   pub(crate) combine_documents_chain: Box<dyn Chain>,
   pub(crate) condense_question_chain: Box<dyn Chain>,
-  pub(crate) context_question_chain: Option<RelatedQuestionChain>,
+  pub(crate) context_question_chain: Option<ContextRelatedQuestionChain>,
   pub(crate) rephrase_question: bool,
   pub(crate) return_source_documents: bool,
   pub(crate) input_key: String,
   pub(crate) output_key: String,
+  latest_context: ArcSwap<String>,
 }
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
@@ -91,6 +96,27 @@ impl ConversationalRetrieverChain {
     Ok((question, token_usage))
   }
 
+  pub async fn get_related_questions(&self, question: &str) -> Result<Vec<String>, FlowyError> {
+    let context = self.latest_context.load();
+    let rag_ids = self.retriever.get_rag_ids();
+
+    if context.is_empty() {
+      trace!("[Chat] No context available. Generating related questions");
+      let chain = RelatedQuestionChain::new(self.ollama.clone());
+      chain.generate_related_question(question).await
+    } else if let Some(c) = self.context_question_chain.as_ref() {
+      trace!(
+        "[Chat] found context:{}. Generating context related questions",
+        context
+      );
+      c.generate_questions_from_context(&rag_ids, &context)
+        .await
+        .map(|questions| questions.into_iter().map(|q| q.content).collect())
+    } else {
+      Ok(vec![])
+    }
+  }
+
   async fn get_documents_or_result(
     &self,
     question: &str,
@@ -101,7 +127,7 @@ impl ConversationalRetrieverChain {
     } else {
       let documents = self
         .retriever
-        .get_relevant_documents(question)
+        .retrieve_documents(question)
        .await
        .map_err(|e| ChainError::RetrieverError(e.to_string()))?;
 
@@ -115,7 +141,8 @@ impl ConversationalRetrieverChain {
       if let Some(c) = self.context_question_chain.as_ref() {
         let rag_ids = rag_ids.iter().map(|v| v.to_string()).collect::<Vec<_>>();
         match c.generate_questions(&rag_ids).await {
-          Ok(questions) => {
+          Ok((context, questions)) => {
+            self.latest_context.store(Arc::new(context));
             trace!("[embedding]: context related questions: {:?}", questions);
             suggested_questions = questions
               .into_iter()
@@ -134,7 +161,7 @@
         }
       }
 
-      return if suggested_questions.is_empty() {
+      if suggested_questions.is_empty() {
         Ok(Either::Right(StreamValue::ContextSuggested {
           value: CAN_NOT_ANSWER_WITH_CONTEXT.to_string(),
           suggested_questions,
@@ -144,10 +171,27 @@
           value: ANSWER_WITH_SUGGESTED_QUESTION.to_string(),
           suggested_questions,
         }))
-      };
-    }
+      }
+    } else {
+      let embedded_docs = documents
+        .iter()
+        .flat_map(|d| {
+          let object_id = d
+            .metadata
+            .get(SOURCE_ID)
+            .and_then(|v| v.as_str().map(|v| v.to_string()))?;
+          Some(EmbeddedContent {
+            content: d.page_content.clone(),
+            object_id,
+          })
+        })
+        .collect::<Vec<_>>();
 
-    Ok(Either::Left(documents))
+      let context = embedded_documents_to_context_str(embedded_docs);
+      self.latest_context.store(Arc::new(context));
+
+      Ok(Either::Left(documents))
+    }
   }
 }
@@ -426,7 +470,7 @@ impl Chain for ConversationalRetrieverChain {
 pub struct ConversationalRetrieverChainBuilder {
   workspace_id: Uuid,
   llm: LLMOllama,
-  retriever: AFRetriever,
+  retriever: Box<dyn AFRetriever>,
   memory: Option<Arc<Mutex<dyn BaseMemory>>>,
   prompt: Option<Box<dyn FormatPrompter>>,
   rephrase_question: bool,
@@ -439,7 +483,7 @@ impl ConversationalRetrieverChainBuilder {
   pub fn new(
     workspace_id: Uuid,
     llm: LLMOllama,
-    retriever: AFRetriever,
+    retriever: Box<dyn AFRetriever>,
     store: Option<SqliteVectorStore>,
   ) -> Self {
     ConversationalRetrieverChainBuilder {
@@ -496,7 +540,7 @@ impl ConversationalRetrieverChainBuilder {
 
     let context_question_chain = self
       .store
-      .map(|store| RelatedQuestionChain::new(self.workspace_id, self.llm.clone(), store));
+      .map(|store| ContextRelatedQuestionChain::new(self.workspace_id, self.llm.clone(), store));
 
     Ok(ConversationalRetrieverChain {
       ollama: self.llm,
@@ -509,67 +553,11 @@
       return_source_documents: self.return_source_documents,
       input_key: self.input_key,
       output_key: self.output_key,
+      latest_context: Default::default(),
     })
   }
 }
 
-// Retriever is a retriever for vector stores.
-pub type RetrieverOption = VecStoreOptions<Value>;
-pub struct AFRetriever {
-  vector_store: Option<Box<dyn VectorStore<Options = RetrieverOption>>>,
-  num_docs: usize,
-  options: RetrieverOption,
-}
-impl AFRetriever {
-  pub fn new<V: Into<Box<dyn VectorStore<Options = RetrieverOption>>>>(
-    vector_store: Option<V>,
-    num_docs: usize,
-    options: RetrieverOption,
-  ) -> Self {
-    AFRetriever {
-      vector_store: vector_store.map(Into::into),
-      num_docs,
-      options,
-    }
-  }
-  pub fn set_rag_ids(&mut self, new_rag_ids: Vec<String>) {
-    trace!("[VectorStore] retriever {:p}", self);
-    let filters = self.options.filters.get_or_insert_with(|| json!({}));
-    filters[RAG_IDS] = json!(new_rag_ids);
-  }
-
-  pub fn get_rag_ids(&self) -> Vec<&str> {
-    trace!("[VectorStore] retriever {:p}", self);
-    self
-      .options
-      .filters
-      .as_ref()
-      .and_then(|filters| filters.get(RAG_IDS).and_then(|rag_ids| rag_ids.as_array()))
-      .map(|rag_ids| rag_ids.iter().filter_map(|id| id.as_str()).collect())
-      .unwrap_or_default()
-  }
-}
-
-#[async_trait]
-impl Retriever for AFRetriever {
-  async fn get_relevant_documents(&self, query: &str) -> Result<Vec<Document>, Box<dyn Error>> {
-    trace!(
-      "[VectorStore] filters: {:?}, retrieving documents for query: {}",
-      self.options.filters,
-      query,
-    );
-
-    match self.vector_store.as_ref() {
-      None => Ok(vec![]),
-      Some(vector_store) => {
-        vector_store
-          .similarity_search(query, self.num_docs, &self.options)
-          .await
-      },
-    }
-  }
-}
-
 /// Deduplicates metadata from a list of documents by merging metadata entries with the same keys
 fn deduplicate_metadata(documents: &[Document]) -> Vec {
   let mut merged_metadata: HashMap<String, Value> = HashMap::new();
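The `latest_context` field above leans on `arc_swap` so the answer path can publish fresh context while `get_related_questions` reads it without taking a lock. A minimal sketch of that pattern, with illustrative names:

```rust
use arc_swap::ArcSwap;
use std::sync::Arc;

// Illustrative stand-in for the chain's `latest_context` field.
struct ContextCache {
  latest: ArcSwap<String>,
}

impl ContextCache {
  fn new() -> Self {
    Self {
      latest: ArcSwap::from_pointee(String::new()),
    }
  }

  // Writers swap in a freshly built context string.
  fn publish(&self, context: String) {
    self.latest.store(Arc::new(context));
  }

  // Readers get a cheap guard; no mutex is involved.
  fn has_context(&self) -> bool {
    !self.latest.load().is_empty()
  }
}
```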
diff --git a/frontend/rust-lib/flowy-ai/src/local_ai/chat/chains/related_question_chain.rs b/frontend/rust-lib/flowy-ai/src/local_ai/chat/chains/related_question_chain.rs
index 73a48b0cb5..b326c61c62 100644
--- a/frontend/rust-lib/flowy-ai/src/local_ai/chat/chains/related_question_chain.rs
+++ b/frontend/rust-lib/flowy-ai/src/local_ai/chat/chains/related_question_chain.rs
@@ -6,17 +6,9 @@ use ollama_rs::generation::parameters::{FormatType, JsonStructure};
 use schemars::JsonSchema;
 use serde::Deserialize;
 
-const SUMMARIZE_SYSTEM_PROMPT: &str = r#"
-As an AppFlowy AI assistant, your task is to generate three medium-length, relevant, and informative questions based on the provided conversation history.
-The output should only return a JSON instance that conforms to the JSON schema below.
-
-{
-  "questions": [
-    "What are the key skills needed to tackle a black diamond slope in snowboarding?",
-    "How does the difficulty of black diamond trails compare across different ski resorts?",
-    "Can you provide tips for snowboarders preparing to try a black diamond trail for the first time?"
-  ]
-}
+const SYSTEM_PROMPT: &str = r#"
+You are the AppFlowy AI assistant. Given the conversation history, generate exactly three medium-length, relevant, and informative questions.
+Respond with a single JSON object that conforms to the required schema, and nothing else. If you can't generate questions, return {}.
 "#;
 
 #[derive(Debug, Deserialize, JsonSchema)]
@@ -36,9 +28,9 @@ impl RelatedQuestionChain {
     }
   }
 
-  pub async fn related_question(&self, question: &str) -> FlowyResult<Vec<String>> {
+  pub async fn generate_related_question(&self, question: &str) -> FlowyResult<Vec<String>> {
     let messages = vec![
-      Message::new_system_message(SUMMARIZE_SYSTEM_PROMPT),
+      Message::new_system_message(SYSTEM_PROMPT),
       Message::new_human_message(question),
     ];
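The prompt relies on Ollama's structured-output mode: the model is constrained to JSON matching a schema derived from a Rust type, then parsed back with serde. A sketch of that contract under assumed type and field names (the actual derive target is elided from this diff):

```rust
use schemars::JsonSchema;
use serde::Deserialize;

// Assumed shape; the real schema type is not shown in this diff.
#[derive(Debug, Deserialize, JsonSchema)]
struct RelatedQuestions {
  // `default` lets the prompt's `{}` fallback parse as zero questions.
  #[serde(default)]
  questions: Vec<String>,
}

fn parse_llm_output(raw: &str) -> Vec<String> {
  serde_json::from_str::<RelatedQuestions>(raw)
    .map(|r| r.questions)
    .unwrap_or_default()
}
```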
diff --git a/frontend/rust-lib/flowy-ai/src/local_ai/chat/format_prompt.rs b/frontend/rust-lib/flowy-ai/src/local_ai/chat/format_prompt.rs
index 17683e9d08..2813f6d696 100644
--- a/frontend/rust-lib/flowy-ai/src/local_ai/chat/format_prompt.rs
+++ b/frontend/rust-lib/flowy-ai/src/local_ai/chat/format_prompt.rs
@@ -9,7 +9,7 @@ use langchain_rust::template_jinja2;
 use std::sync::{Arc, RwLock};
 
 const QA_CONTEXT_TEMPLATE: &str = r#"
-Only Use the context provided below to formulate your answer. Do not use any other information. If the context doesn't contain sufficient information to answer the question, respond with "I don't know".
+Only use the context provided below to formulate your answer. Do not use any other information. Do not reference external knowledge or information outside the context.
 
 ##Context##
diff --git a/frontend/rust-lib/flowy-ai/src/local_ai/chat/llm_chat.rs b/frontend/rust-lib/flowy-ai/src/local_ai/chat/llm_chat.rs
index 289e31d389..3e3ba3da47 100644
--- a/frontend/rust-lib/flowy-ai/src/local_ai/chat/llm_chat.rs
+++ b/frontend/rust-lib/flowy-ai/src/local_ai/chat/llm_chat.rs
@@ -1,9 +1,11 @@
 use crate::local_ai::chat::chains::conversation_chain::{
-  AFRetriever, ConversationalRetrieverChain, ConversationalRetrieverChainBuilder, RetrieverOption,
+  ConversationalRetrieverChain, ConversationalRetrieverChainBuilder,
 };
-use crate::local_ai::chat::chains::related_question_chain::RelatedQuestionChain;
 use crate::local_ai::chat::format_prompt::AFContextPrompt;
 use crate::local_ai::chat::llm::LLMOllama;
+use crate::local_ai::chat::retriever::multi_source_retriever::MultipleSourceRetriever;
+use crate::local_ai::chat::retriever::sqlite_retriever::RetrieverOption;
+use crate::local_ai::chat::retriever::{AFRetriever, MultipleSourceRetrieverStore};
 use crate::local_ai::chat::summary_memory::SummaryMemory;
 use crate::local_ai::chat::LLMChatInfo;
 use crate::SqliteVectorStore;
@@ -27,6 +29,7 @@ use uuid::Uuid;
 pub struct LLMChat {
   store: Option<SqliteVectorStore>,
   chain: ConversationalRetrieverChain,
+  #[allow(dead_code)]
   client: Arc<OllamaClient>,
   prompt: AFContextPrompt,
   info: LLMChatInfo,
@@ -38,6 +41,7 @@ impl LLMChat {
     client: Arc<OllamaClient>,
     store: Option<SqliteVectorStore>,
     user_service: Option<Weak<dyn AIUserService>>,
+    retriever_sources: Vec<Weak<dyn MultipleSourceRetrieverStore>>,
   ) -> FlowyResult<Self> {
     let response_format = ResponseFormat::default();
     let formatter = create_formatter_prompt_with_format(&response_format, &info.rag_ids);
@@ -47,7 +51,12 @@ impl LLMChat {
       .map(|v| v.into())
       .unwrap_or(SimpleMemory::new().into());
 
-    let retriever = create_retriever(&info.workspace_id, info.rag_ids.clone(), store.clone());
+    let retriever = create_retriever(
+      &info.workspace_id,
+      info.rag_ids.clone(),
+      store.clone(),
+      retriever_sources,
+    );
     let builder =
       ConversationalRetrieverChainBuilder::new(info.workspace_id, llm, retriever, store.clone())
         .rephrase_question(false)
@@ -64,19 +73,7 @@ impl LLMChat {
   }
 
   pub async fn get_related_question(&self, user_message: String) -> FlowyResult<Vec<String>> {
-    let chain = RelatedQuestionChain::new(LLMOllama::new(
-      &self.info.model,
-      self.client.clone(),
-      None,
-      None,
-    ));
-    let questions = chain.related_question(&user_message).await?;
-    trace!(
-      "related questions: {:?} for message: {}",
-      questions,
-      user_message
-    );
-    Ok(questions)
+    self.chain.get_related_questions(&user_message).await
   }
 
   pub fn set_chat_model(&mut self, model: &str) {
@@ -198,17 +195,40 @@ fn create_retriever(
   workspace_id: &Uuid,
   rag_ids: Vec<String>,
   store: Option<SqliteVectorStore>,
-) -> AFRetriever {
+  retrievers_sources: Vec<Weak<dyn MultipleSourceRetrieverStore>>,
+) -> Box<dyn AFRetriever> {
   trace!(
     "[VectorStore]: {} create retriever with rag_ids: {:?}",
     workspace_id,
     rag_ids,
   );
-  let options = VecStoreOptions::default()
-    .with_score_threshold(0.2)
-    .with_filters(json!({RAG_IDS: rag_ids, "workspace_id": workspace_id}));
 
-  AFRetriever::new(store, 5, options)
+  let mut stores: Vec<Arc<dyn MultipleSourceRetrieverStore>> = vec![];
+  if let Some(store) = store {
+    stores.push(Arc::new(store));
+  }
+
+  for source in retrievers_sources {
+    if let Some(source) = source.upgrade() {
+      stores.push(source);
+    }
+  }
+
+  trace!(
+    "[VectorStore]: use retrievers sources: {:?}",
+    stores
+      .iter()
+      .map(|s| s.retriever_name())
+      .collect::<Vec<_>>()
+  );
+
+  Box::new(MultipleSourceRetriever::new(
+    *workspace_id,
+    stores,
+    rag_ids.clone(),
+    5,
+    0.2,
+  ))
 }
 
 fn map_chain_error(err: ChainError) -> FlowyError {
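`create_retriever` upgrades each `Weak` source and silently drops the dead ones, so the controller can retire a source without coordinating with open chats. A sketch of that ownership pattern (helper names are illustrative):

```rust
use std::sync::{Arc, Weak};

// The controller keeps strong references and hands chats weak ones.
fn downgrade_sources(
  strong: &[Arc<dyn MultipleSourceRetrieverStore>],
) -> Vec<Weak<dyn MultipleSourceRetrieverStore>> {
  strong.iter().map(Arc::downgrade).collect()
}

// At query time, dead sources upgrade to None and are skipped.
fn live_sources(
  weak: &[Weak<dyn MultipleSourceRetrieverStore>],
) -> Vec<Arc<dyn MultipleSourceRetrieverStore>> {
  weak.iter().filter_map(Weak::upgrade).collect()
}
```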
diff --git a/frontend/rust-lib/flowy-ai/src/local_ai/chat/mod.rs b/frontend/rust-lib/flowy-ai/src/local_ai/chat/mod.rs
index 6dbca1155e..68dda7a88a 100644
--- a/frontend/rust-lib/flowy-ai/src/local_ai/chat/mod.rs
+++ b/frontend/rust-lib/flowy-ai/src/local_ai/chat/mod.rs
@@ -2,11 +2,12 @@ pub mod chains;
 mod format_prompt;
 pub mod llm;
 pub mod llm_chat;
+pub mod retriever;
 mod summary_memory;
 
-use crate::local_ai::chat::chains::related_question_chain::RelatedQuestionChain;
 use crate::local_ai::chat::llm::LLMOllama;
 use crate::local_ai::chat::llm_chat::LLMChat;
+use crate::local_ai::chat::retriever::MultipleSourceRetrieverStore;
 use crate::local_ai::completion::chain::CompletionChain;
 use crate::local_ai::database::summary::DatabaseSummaryChain;
 use crate::local_ai::database::translate::DatabaseTranslateChain;
@@ -28,7 +29,7 @@ use std::collections::HashMap;
 use std::path::PathBuf;
 use std::sync::{Arc, Weak};
 use tokio::sync::RwLock;
-use tracing::trace;
+use tracing::warn;
 use uuid::Uuid;
 
 type OllamaClientRef = Arc<RwLock<Option<Weak<OllamaClient>>>>;
@@ -41,11 +42,14 @@ pub struct LLMChatInfo {
   pub summary: String,
 }
 
+pub type RetrieversSources = RwLock<Vec<Arc<dyn MultipleSourceRetrieverStore>>>;
+
 pub struct LLMChatController {
-  chat_by_id: DashMap<Uuid, LLMChat>,
+  chat_by_id: DashMap<Uuid, Arc<RwLock<LLMChat>>>,
   store: RwLock<Option<SqliteVectorStore>>,
   client: OllamaClientRef,
   user_service: Weak<dyn AIUserService>,
+  retriever_sources: RetrieversSources,
 }
 impl LLMChatController {
   pub fn new(user_service: Weak<dyn AIUserService>) -> Self {
@@ -54,9 +58,14 @@ impl LLMChatController {
       chat_by_id: DashMap::new(),
       client: Default::default(),
       user_service,
+      retriever_sources: Default::default(),
     }
   }
 
+  pub async fn set_retriever_sources(&self, sources: Vec<Arc<dyn MultipleSourceRetrieverStore>>) {
+    *self.retriever_sources.write().await = sources;
+  }
+
   pub async fn is_ready(&self) -> bool {
     self.client.read().await.is_some()
   }
@@ -72,9 +81,9 @@ impl LLMChatController {
     *self.store.write().await = Some(store);
   }
 
-  pub fn set_rag_ids(&self, chat_id: &Uuid, rag_ids: &[String]) {
-    if let Some(mut chat) = self.chat_by_id.get_mut(chat_id) {
-      chat.set_rag_ids(rag_ids.to_vec());
+  pub async fn set_rag_ids(&self, chat_id: &Uuid, rag_ids: &[String]) {
+    if let Some(chat) = self.get_chat(chat_id) {
+      chat.write().await.set_rag_ids(rag_ids.to_vec());
     }
   }
 
@@ -90,10 +99,22 @@ impl LLMChatController {
       .ok_or_else(|| FlowyError::local_ai().with_context("Ollama client has been dropped"))?
       .clone();
     let entry = self.chat_by_id.entry(info.chat_id);
-
+    let retriever_sources = self
+      .retriever_sources
+      .read()
+      .await
+      .iter()
+      .map(Arc::downgrade)
+      .collect();
     if let Entry::Vacant(e) = entry {
-      let chat = LLMChat::new(info, client, store, Some(self.user_service.clone()))?;
-      e.insert(chat);
+      let chat = LLMChat::new(
+        info,
+        client,
+        store,
+        Some(self.user_service.clone()),
+        retriever_sources,
+      )?;
+      e.insert(Arc::new(RwLock::new(chat)));
     }
     Ok(())
   }
@@ -107,6 +128,10 @@ impl LLMChatController {
     self.chat_by_id.remove(chat_id);
   }
 
+  pub fn get_chat(&self, chat_id: &Uuid) -> Option<Arc<RwLock<LLMChat>>> {
+    self.chat_by_id.get(chat_id).map(|c| c.value().clone())
+  }
+
   pub async fn summarize_database_row(
     &self,
     model_name: &str,
@@ -185,38 +210,36 @@ impl LLMChatController {
 
   pub async fn get_related_question(
     &self,
-    model_name: &str,
+    _model_name: &str,
     chat_id: &Uuid,
     _message_id: i64,
   ) -> FlowyResult<Vec<String>> {
-    let client = self
-      .client
-      .read()
-      .await
-      .clone()
-      .ok_or(FlowyError::local_ai())?
-      .upgrade()
-      .ok_or(FlowyError::local_ai())?;
-
-    let user_service = self.user_service.upgrade().ok_or(FlowyError::local_ai())?;
-    let uid = user_service.user_id()?;
-    let conn = user_service.sqlite_connection(uid)?;
-    let message = select_latest_user_message(conn, &chat_id.to_string(), ChatAuthorType::Human)?;
-
-    let chain = RelatedQuestionChain::new(LLMOllama::new(model_name, client, None, None));
-    let questions = chain.related_question(&message.content).await?;
-    trace!(
-      "related questions: {:?} for message: {}",
-      questions,
-      message.content
-    );
-    Ok(questions)
+    match self.get_chat(chat_id) {
+      None => {
+        warn!(
+          "[Chat] Chat with id {} not found, unable to get related question",
+          chat_id
+        );
+        Ok(vec![])
+      },
+      Some(chat) => {
+        let user_service = self.user_service.upgrade().ok_or(FlowyError::local_ai())?;
+        let uid = user_service.user_id()?;
+        let conn = user_service.sqlite_connection(uid)?;
+        let message =
+          select_latest_user_message(conn, &chat_id.to_string(), ChatAuthorType::Human)?;
+        chat
+          .read()
+          .await
+          .get_related_question(message.content)
+          .await
+      },
+    }
   }
 
   pub async fn ask_question(&self, chat_id: &Uuid, question: &str) -> FlowyResult {
-    if let Some(chat) = self.chat_by_id.get(chat_id) {
-      let chat = chat.value();
-      let response = chat.ask_question(question).await;
+    if let Some(chat) = self.get_chat(chat_id) {
+      let response = chat.read().await.ask_question(question).await;
       return response;
     }
 
@@ -230,10 +253,9 @@ impl LLMChatController {
     format: ResponseFormat,
     model_name: &str,
   ) -> FlowyResult {
-    if let Some(mut chat) = self.chat_by_id.get_mut(chat_id) {
-      chat.set_chat_model(model_name);
-
-      let response = chat.stream_question(question, format).await;
+    if let Some(chat) = self.get_chat(chat_id) {
+      chat.write().await.set_chat_model(model_name);
+      let response = chat.write().await.stream_question(question, format).await;
       return response;
     }
diff --git a/frontend/rust-lib/flowy-ai/src/local_ai/chat/retriever/mod.rs b/frontend/rust-lib/flowy-ai/src/local_ai/chat/retriever/mod.rs
new file mode 100644
index 0000000000..1e899e9f97
--- /dev/null
+++ b/frontend/rust-lib/flowy-ai/src/local_ai/chat/retriever/mod.rs
@@ -0,0 +1,31 @@
+use async_trait::async_trait;
+use flowy_error::FlowyResult;
+pub use langchain_rust::schemas::Document as LangchainDocument;
+use std::error::Error;
+use uuid::Uuid;
+
+pub mod multi_source_retriever;
+pub mod sqlite_retriever;
+
+#[async_trait]
+pub trait AFRetriever: Send + Sync + 'static {
+  fn get_rag_ids(&self) -> Vec<&str>;
+  fn set_rag_ids(&mut self, new_rag_ids: Vec<String>);
+
+  async fn retrieve_documents(&self, query: &str)
+    -> Result<Vec<LangchainDocument>, Box<dyn Error>>;
+}
+
+#[async_trait]
+pub trait MultipleSourceRetrieverStore: Send + Sync {
+  fn retriever_name(&self) -> &'static str;
+  async fn read_documents(
+    &self,
+    workspace_id: &Uuid,
+    query: &str,
+    limit: usize,
+    rag_ids: &[String],
+    score_threshold: f32,
+    full_search: bool,
+  ) -> FlowyResult<Vec<LangchainDocument>>;
+}
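Any store that can answer a scored, rag-id-filtered query can plug into retrieval by implementing `MultipleSourceRetrieverStore`. A minimal in-memory implementation of the kind a test might register via `set_retriever_sources`; entirely hypothetical, nothing like it ships in this change:

```rust
use async_trait::async_trait;
use flowy_error::FlowyResult;
use uuid::Uuid;

// Hypothetical test double returning canned documents.
struct StaticStore {
  docs: Vec<LangchainDocument>,
}

#[async_trait]
impl MultipleSourceRetrieverStore for StaticStore {
  fn retriever_name(&self) -> &'static str {
    "Static Test Retriever"
  }

  async fn read_documents(
    &self,
    _workspace_id: &Uuid,
    _query: &str,
    limit: usize,
    _rag_ids: &[String],
    _score_threshold: f32,
    _full_search: bool,
  ) -> FlowyResult<Vec<LangchainDocument>> {
    // Ignore the query and hand back up to `limit` canned documents.
    Ok(self.docs.iter().take(limit).cloned().collect())
  }
}
```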
diff --git a/frontend/rust-lib/flowy-ai/src/local_ai/chat/retriever/multi_source_retriever.rs b/frontend/rust-lib/flowy-ai/src/local_ai/chat/retriever/multi_source_retriever.rs
new file mode 100644
index 0000000000..f6a311a832
--- /dev/null
+++ b/frontend/rust-lib/flowy-ai/src/local_ai/chat/retriever/multi_source_retriever.rs
@@ -0,0 +1,120 @@
+use crate::local_ai::chat::retriever::{AFRetriever, MultipleSourceRetrieverStore};
+use async_trait::async_trait;
+use futures::future::join_all;
+use langchain_rust::schemas::Document;
+use std::error::Error;
+use std::sync::Arc;
+use tracing::{error, trace};
+use uuid::Uuid;
+
+pub struct MultipleSourceRetriever {
+  workspace_id: Uuid,
+  vector_stores: Vec<Arc<dyn MultipleSourceRetrieverStore>>,
+  num_docs: usize,
+  rag_ids: Vec<String>,
+  full_search: bool,
+  score_threshold: f32,
+}
+
+impl MultipleSourceRetriever {
+  pub fn new<V: Into<Arc<dyn MultipleSourceRetrieverStore>>>(
+    workspace_id: Uuid,
+    vector_stores: Vec<V>,
+    rag_ids: Vec<String>,
+    num_docs: usize,
+    score_threshold: f32,
+  ) -> Self {
+    MultipleSourceRetriever {
+      workspace_id,
+      vector_stores: vector_stores.into_iter().map(|v| v.into()).collect(),
+      num_docs,
+      rag_ids,
+      full_search: false,
+      score_threshold,
+    }
+  }
+
+  pub fn set_rag_ids(&mut self, new_rag_ids: Vec<String>) {
+    self.rag_ids = new_rag_ids;
+  }
+
+  pub fn get_rag_ids(&self) -> Vec<&str> {
+    self
+      .rag_ids
+      .iter()
+      .map(|id| id.as_str())
+      .collect::<Vec<&str>>()
+  }
+}
+
+#[async_trait]
+impl AFRetriever for MultipleSourceRetriever {
+  fn get_rag_ids(&self) -> Vec<&str> {
+    self
+      .rag_ids
+      .iter()
+      .map(|id| id.as_str())
+      .collect::<Vec<&str>>()
+  }
+
+  fn set_rag_ids(&mut self, new_rag_ids: Vec<String>) {
+    self.rag_ids = new_rag_ids;
+  }
+
+  async fn retrieve_documents(&self, query: &str) -> Result<Vec<Document>, Box<dyn Error>> {
+    trace!(
+      "[VectorStore] filters: {:?}, retrieving documents for query: {}",
+      self.rag_ids,
+      query,
+    );
+
+    // Create futures for each vector store search
+    let search_futures = self
+      .vector_stores
+      .iter()
+      .map(|vector_store| {
+        let vector_store = vector_store.clone();
+        let query = query.to_string();
+        let num_docs = self.num_docs;
+        let full_search = self.full_search;
+        let rag_ids = self.rag_ids.clone();
+        let workspace_id = self.workspace_id;
+        let score_threshold = self.score_threshold;
+
+        async move {
+          vector_store
+            .read_documents(
+              &workspace_id,
+              &query,
+              num_docs,
+              &rag_ids,
+              score_threshold,
+              full_search,
+            )
+            .await
+            .map(|docs| (vector_store.retriever_name(), docs))
+        }
+      })
+      .collect::<Vec<_>>();
+
+    let search_results = join_all(search_futures).await;
+    let mut results = Vec::new();
+    for result in search_results {
+      match result {
+        Ok((retriever_name, docs)) => {
+          trace!(
+            "[VectorStore] {} found {} results, scores: {:?}",
+            retriever_name,
+            docs.len(),
+            docs.iter().map(|doc| doc.score).collect::<Vec<_>>()
+          );
+          results.extend(docs);
+        },
+        Err(err) => {
+          error!("[VectorStore] Failed to retrieve documents: {}", err);
+        },
+      }
+    }
+
+    Ok(results)
+  }
+}
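Wiring this up mirrors what `create_retriever` in llm_chat.rs does: collect the live stores and fan the query out. A construction sketch; the helper name is illustrative:

```rust
use std::sync::Arc;
use uuid::Uuid;

fn build_retriever(
  workspace_id: Uuid,
  stores: Vec<Arc<dyn MultipleSourceRetrieverStore>>,
  rag_ids: Vec<String>,
) -> MultipleSourceRetriever {
  // 5 documents per source and a 0.2 score floor match the values passed
  // from llm_chat.rs in this change.
  MultipleSourceRetriever::new(workspace_id, stores, rag_ids, 5, 0.2)
}
```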
+// A retriever backed by a single vector store.
+pub type RetrieverOption = VecStoreOptions<Value>;
+
+pub struct SqliteVecRetriever {
+  vector_store: Option<Box<dyn VectorStore<Options = RetrieverOption>>>,
+  num_docs: usize,
+  options: RetrieverOption,
+}
+
+impl SqliteVecRetriever {
+  pub fn new<V: Into<Box<dyn VectorStore<Options = RetrieverOption>>>>(
+    vector_store: Option<V>,
+    num_docs: usize,
+    options: RetrieverOption,
+  ) -> Self {
+    SqliteVecRetriever {
+      vector_store: vector_store.map(Into::into),
+      num_docs,
+      options,
+    }
+  }
+}
+
+#[async_trait]
+impl AFRetriever for SqliteVecRetriever {
+  fn get_rag_ids(&self) -> Vec<&str> {
+    self
+      .options
+      .filters
+      .as_ref()
+      .and_then(|filters| filters.get(RAG_IDS).and_then(|rag_ids| rag_ids.as_array()))
+      .map(|rag_ids| rag_ids.iter().filter_map(|id| id.as_str()).collect())
+      .unwrap_or_default()
+  }
+
+  fn set_rag_ids(&mut self, new_rag_ids: Vec<String>) {
+    let filters = self.options.filters.get_or_insert_with(|| json!({}));
+    filters[RAG_IDS] = json!(new_rag_ids);
+  }
+
+  async fn retrieve_documents(&self, query: &str) -> Result<Vec<Document>, Box<dyn Error>> {
+    trace!(
+      "[VectorStore] filters: {:?}, retrieving documents for query: {}",
+      self.options.filters,
+      query,
+    );
+
+    match self.vector_store.as_ref() {
+      None => Ok(vec![]),
+      Some(vector_store) => {
+        vector_store
+          .similarity_search(query, self.num_docs, &self.options)
+          .await
+      },
+    }
+  }
+}
+
+#[async_trait]
+impl Retriever for SqliteVecRetriever {
+  async fn get_relevant_documents(&self, query: &str) -> Result<Vec<Document>, Box<dyn Error>> {
+    self.retrieve_documents(query).await
+  }
+}
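Unlike `MultipleSourceRetriever`, which keeps rag ids in a plain field, this retriever stores them inside the JSON filters of its `VecStoreOptions`, so `set_rag_ids`/`get_rag_ids` are just JSON edits. A sketch of building such options, mirroring the filter layout the removed `AFRetriever` used (values are illustrative):

```rust
use flowy_ai_pub::entities::RAG_IDS;
use langchain_rust::vectorstore::VecStoreOptions;
use serde_json::json;

fn options_with_rag_ids(rag_ids: &[&str], workspace_id: &str) -> RetrieverOption {
  // rag ids plus the workspace id both live under the JSON `filters` object.
  VecStoreOptions::default()
    .with_score_threshold(0.2)
    .with_filters(json!({ RAG_IDS: rag_ids, "workspace_id": workspace_id }))
}
```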
diff --git a/frontend/rust-lib/flowy-ai/src/local_ai/controller.rs b/frontend/rust-lib/flowy-ai/src/local_ai/controller.rs
index 1df2d68bef..600acc8759 100644
--- a/frontend/rust-lib/flowy-ai/src/local_ai/controller.rs
+++ b/frontend/rust-lib/flowy-ai/src/local_ai/controller.rs
@@ -26,7 +26,7 @@ use serde::{Deserialize, Serialize};
 use std::ops::Deref;
 use std::path::PathBuf;
 use std::sync::{Arc, Weak};
-use tracing::{debug, error, info, instrument, trace};
+use tracing::{debug, error, info, instrument, trace, warn};
 use uuid::Uuid;
 
 #[derive(Clone, Debug, Serialize, Deserialize)]
@@ -189,7 +189,7 @@ impl LocalAIController {
       return;
     }
 
-    self.llm_controller.set_rag_ids(chat_id, rag_ids);
+    self.llm_controller.set_rag_ids(chat_id, rag_ids).await;
   }
 
   pub async fn open_chat(
@@ -201,14 +201,17 @@ impl LocalAIController {
     summary: String,
   ) -> FlowyResult<()> {
     if !self.is_enabled() {
+      warn!("[Chat] local ai is disabled, skip opening chat");
       return Ok(());
     }
 
     // Only keep one chat open at a time, since loading multiple models at once
     // causes memory issues.
     if let Some(current_chat_id) = self.current_chat_id.load().as_ref() {
-      debug!("[Local AI] close previous chat: {}", current_chat_id);
-      self.close_chat(current_chat_id);
+      if current_chat_id.as_ref() != chat_id {
+        debug!("[Chat] close previous chat: {}", current_chat_id);
+        self.close_chat(current_chat_id);
+      }
     }
 
     let info = LLMChatInfo {
@@ -219,6 +222,7 @@ impl LocalAIController {
       summary,
     };
     self.current_chat_id.store(Some(Arc::new(*chat_id)));
+    trace!("[Chat] open chat: {}", chat_id);
     self.llm_controller.open_chat(info).await?;
     Ok(())
   }
diff --git a/frontend/rust-lib/flowy-ai/tests/chat_test/qa_test.rs b/frontend/rust-lib/flowy-ai/tests/chat_test/qa_test.rs
index 44db0a7d1b..93d73c2102 100644
--- a/frontend/rust-lib/flowy-ai/tests/chat_test/qa_test.rs
+++ b/frontend/rust-lib/flowy-ai/tests/chat_test/qa_test.rs
@@ -133,7 +133,7 @@ async fn local_ollama_test_chat_related_question() {
   let ollama = LLMOllama::default().with_model("llama3.1");
   let chain = RelatedQuestionChain::new(ollama);
   let resp = chain
-    .related_question("Compare rust with JS")
+    .generate_related_question("Compare rust with JS")
     .await
     .unwrap();
diff --git a/frontend/rust-lib/flowy-ai/tests/main.rs b/frontend/rust-lib/flowy-ai/tests/main.rs
index e5a91d93ef..2589aa1ea0 100644
--- a/frontend/rust-lib/flowy-ai/tests/main.rs
+++ b/frontend/rust-lib/flowy-ai/tests/main.rs
@@ -76,7 +76,14 @@ impl TestContext {
       summary: "".to_string(),
     };
 
-    LLMChat::new(info, self.ollama.clone(), Some(self.store.clone()), None).unwrap()
+    LLMChat::new(
+      info,
+      self.ollama.clone(),
+      Some(self.store.clone()),
+      None,
+      vec![],
+    )
+    .unwrap()
   }
 }
diff --git a/frontend/rust-lib/flowy-core/src/deps_resolve/chat_deps.rs b/frontend/rust-lib/flowy-core/src/deps_resolve/chat_deps.rs
index 5a954504a2..34c3e3b155 100644
--- a/frontend/rust-lib/flowy-core/src/deps_resolve/chat_deps.rs
+++ b/frontend/rust-lib/flowy-core/src/deps_resolve/chat_deps.rs
@@ -5,14 +5,18 @@ use collab::preclude::{Collab, StateVector};
 use collab::util::is_change_since_sv;
 use collab_entity::CollabType;
 use flowy_ai::ai_manager::{AIExternalService, AIManager};
+use flowy_ai::local_ai::chat::retriever::{LangchainDocument, MultipleSourceRetrieverStore};
 use flowy_ai::local_ai::controller::LocalAIController;
 use flowy_ai_pub::cloud::ChatCloudService;
+use flowy_ai_pub::entities::{SOURCE, SOURCE_ID, SOURCE_NAME};
 use flowy_ai_pub::persistence::AFCollabMetadata;
 use flowy_ai_pub::user_service::AIUserService;
 use flowy_error::{FlowyError, FlowyResult};
 use flowy_folder::ViewLayout;
 use flowy_folder_pub::cloud::{FolderCloudService, FullSyncCollabParams};
 use flowy_folder_pub::query::FolderService;
+use flowy_search_pub::tantivy_state::DocumentTantivyState;
+use flowy_server::util::tanvity_local_search;
 use flowy_sqlite::kv::KVStorePreferences;
 use flowy_sqlite::DBConnection;
 use flowy_storage_pub::storage::StorageService;
@@ -20,9 +24,11 @@ use flowy_user::services::authenticate_user::AuthenticateUser;
 use flowy_user_pub::entities::WorkspaceType;
 use lib_infra::async_trait::async_trait;
 use lib_infra::util::timestamp;
+use serde_json::json;
 use std::collections::HashMap;
 use std::path::PathBuf;
 use std::sync::{Arc, Weak};
+use tokio::sync::RwLock;
 use tracing::{debug, error, info};
 use uuid::Uuid;
 
@@ -205,3 +211,64 @@ impl AIUserService for ChatUserServiceImpl {
     ))
   }
 }
+
+#[derive(Clone)]
+pub struct MultiSourceVSTanvityImpl {
+  state: Option<Weak<RwLock<DocumentTantivyState>>>,
+}
+
+impl MultiSourceVSTanvityImpl {
+  pub fn new(state: Option<Weak<RwLock<DocumentTantivyState>>>) -> Self {
+    Self { state }
+  }
+}
+
+#[async_trait]
+impl MultipleSourceRetrieverStore for MultiSourceVSTanvityImpl {
+  fn retriever_name(&self) -> &'static str {
+    "Tanvity Multiple Source Retriever"
+  }
+
+  async fn read_documents(
+    &self,
+    workspace_id: &Uuid,
+    query: &str,
+    limit: usize,
+    rag_ids: &[String],
+    score_threshold: f32,
+    _full_search: bool,
+  ) -> FlowyResult<Vec<LangchainDocument>> {
+    let docs = tanvity_local_search(
+      &self.state,
+      workspace_id,
+      query,
+      Some(rag_ids.to_vec()),
+      limit,
+      score_threshold,
+    )
+    .await;
+
+    match docs {
+      None => Ok(vec![]),
+      Some(docs) => Ok(
+        docs
+          .into_iter()
+          .map(|v| LangchainDocument {
+            page_content: v.content,
+            metadata: json!({
+              SOURCE_ID: v.object_id,
+              SOURCE: "appflowy",
+              SOURCE_NAME: "document",
+            })
+            .as_object()
+            .unwrap()
+            .clone()
+            .into_iter()
+            .collect(),
+            score: v.score,
+          })
+          .collect(),
+      ),
+    }
+  }
+}
diff --git a/frontend/rust-lib/flowy-core/src/server_layer.rs b/frontend/rust-lib/flowy-core/src/server_layer.rs
index 03c33b644b..ab3ef8126e 100644
--- a/frontend/rust-lib/flowy-core/src/server_layer.rs
+++ b/frontend/rust-lib/flowy-core/src/server_layer.rs
@@ -1,3 +1,4 @@
+use crate::deps_resolve::MultiSourceVSTanvityImpl;
 use crate::AppFlowyCoreConfig;
 use arc_swap::{ArcSwap, ArcSwapOption};
 use collab::entity::EncodedCollab;
@@ -73,6 +74,13 @@ impl ServerProvider {
   }
 
   async fn set_tanvity_state(&self, tanvity_state: Option<Weak<RwLock<DocumentTantivyState>>>) {
+    let tanvity_store = Arc::new(MultiSourceVSTanvityImpl::new(tanvity_state.clone()));
+
+    self
+      .local_ai
+      .set_retriever_sources(vec![tanvity_store])
+      .await;
+
     match self.providers.try_get(self.auth_type.load().as_ref()) {
       TryResult::Present(r) => {
         r.set_tanvity_state(tanvity_state).await;
diff --git a/frontend/rust-lib/flowy-search-pub/src/entities.rs b/frontend/rust-lib/flowy-search-pub/src/entities.rs
index d4e2cc81fe..914064015f 100644
--- a/frontend/rust-lib/flowy-search-pub/src/entities.rs
+++ b/frontend/rust-lib/flowy-search-pub/src/entities.rs
@@ -37,6 +37,7 @@ pub struct TanvitySearchResponseItem {
   pub icon: Option,
   pub workspace_id: String,
   pub content: String,
+  pub score: f32,
 }
 
 #[derive(Default, Debug, Clone, PartialEq, Eq)]
diff --git a/frontend/rust-lib/flowy-search-pub/src/tantivy_state.rs b/frontend/rust-lib/flowy-search-pub/src/tantivy_state.rs
index f1f0382461..d27f4a2bc4 100644
--- a/frontend/rust-lib/flowy-search-pub/src/tantivy_state.rs
+++ b/frontend/rust-lib/flowy-search-pub/src/tantivy_state.rs
@@ -305,6 +305,8 @@ impl DocumentTantivyState {
     workspace_id: &Uuid,
     query: &str,
     object_ids: Option<Vec<String>>,
+    limit: usize,
+    score_threshold: f32,
   ) -> FlowyResult<Vec<TanvitySearchResponseItem>> {
     let workspace_id = workspace_id.to_string();
     let reader = self.reader.clone();
@@ -319,7 +321,7 @@ impl DocumentTantivyState {
     qp.set_field_fuzzy(self.field_name, true, 2, true);
     let query = qp.parse_query(query)?;
 
-    let top_docs = searcher.search(&query, &tantivy::collector::TopDocs::with_limit(10))?;
+    let top_docs = searcher.search(&query, &tantivy::collector::TopDocs::with_limit(limit))?;
 
     let mut results = Vec::with_capacity(top_docs.len());
     let mut seen_ids = std::collections::HashSet::new();
@@ -333,7 +335,12 @@ impl DocumentTantivyState {
       }
     });
 
-    for (_score, doc_address) in top_docs {
+    for (score, doc_address) in top_docs {
+      // Skip results that don't meet the score threshold
+      if score < score_threshold {
+        continue;
+      }
+
       let retrieved: TantivyDocument = searcher.doc(doc_address)?;
       // Pull out each stored field using cached field references
       let workspace_id_str = retrieved
@@ -416,6 +423,7 @@ impl DocumentTantivyState {
           icon,
           workspace_id: workspace_id_str,
           content,
+          score,
         });
       }
diff --git a/frontend/rust-lib/flowy-search/src/document/local_search_handler.rs b/frontend/rust-lib/flowy-search/src/document/local_search_handler.rs
index 3b339a4460..33f883c38f 100644
--- a/frontend/rust-lib/flowy-search/src/document/local_search_handler.rs
+++ b/frontend/rust-lib/flowy-search/src/document/local_search_handler.rs
@@ -50,7 +50,7 @@ impl SearchHandler for DocumentLocalSearchHandler {
         );
       },
       Some(state) => {
-        match state.read().await.search(&workspace_id, &query, None) {
+        match state.read().await.search(&workspace_id, &query, None, 10, 0.4) {
           Ok(items) => {
             trace!("[Tanvity] local document search result: {:?}", items);
             if items.is_empty() {
diff --git a/frontend/rust-lib/flowy-server/src/af_cloud/impls/search.rs b/frontend/rust-lib/flowy-server/src/af_cloud/impls/search.rs
index db90024e91..58cc311abc 100644
--- a/frontend/rust-lib/flowy-server/src/af_cloud/impls/search.rs
+++ b/frontend/rust-lib/flowy-server/src/af_cloud/impls/search.rs
@@ -39,7 +39,7 @@ where
     }
 
     trace!("[Search] Local AI search returned no results, falling back to local search");
-    let items = tanvity_local_search(&self.state, workspace_id, &query)
+    let items = tanvity_local_search(&self.state, workspace_id, &query, None, 10, 0.4)
       .await
       .unwrap_or_default();
     Ok(items)
diff --git a/frontend/rust-lib/flowy-server/src/local_server/impls/search.rs b/frontend/rust-lib/flowy-server/src/local_server/impls/search.rs
index 23f3531b4e..f49c84d4d9 100644
--- a/frontend/rust-lib/flowy-server/src/local_server/impls/search.rs
+++ b/frontend/rust-lib/flowy-server/src/local_server/impls/search.rs
@@ -44,7 +44,7 @@ impl SearchCloudService for LocalSearchServiceImpl {
     }
 
     trace!("[Search] Local AI search returned no results, falling back to local search");
-    let items = tanvity_local_search(&self.state, workspace_id, &query)
+    let items = tanvity_local_search(&self.state, workspace_id, &query, None, 10, 0.4)
      .await
      .unwrap_or_default();
     Ok(items)
diff --git a/frontend/rust-lib/flowy-server/src/util.rs b/frontend/rust-lib/flowy-server/src/util.rs
index bb76a90167..5907f2b05a 100644
--- a/frontend/rust-lib/flowy-server/src/util.rs
+++ b/frontend/rust-lib/flowy-server/src/util.rs
@@ -23,6 +23,9 @@ pub async fn tanvity_local_search(
   state: &Option<Weak<RwLock<DocumentTantivyState>>>,
   workspace_id: &Uuid,
   query: &str,
+  object_ids: Option<Vec<String>>,
+  limit: usize,
+  score_threshold: f32,
 ) -> Option<Vec<SearchDocumentResponseItem>> {
   match state.as_ref().and_then(|v| v.upgrade()) {
     None => {
@@ -30,7 +33,11 @@ pub async fn tanvity_local_search(
       None
     },
     Some(state) => {
-      let results = state.read().await.search(workspace_id, query, None).ok()?;
+      let results = state
+        .read()
+        .await
+        .search(workspace_id, query, object_ids, limit, score_threshold)
+        .ok()?;
       let items = results
         .into_iter()
         .flat_map(|v| tanvity_document_to_search_document(*workspace_id, v))
@@ -49,7 +56,7 @@ pub(crate) fn tanvity_document_to_search_document(
     Some(SearchDocumentResponseItem {
       object_id,
       workspace_id,
-      score: 1.0,
+      score: doc.score as f64,
       content_type: Some(SearchContentType::PlainText),
       content: doc.content,
       preview: None,