diff --git a/frontend/appflowy_flutter/lib/ai/widgets/prompt_input/select_model_menu.dart b/frontend/appflowy_flutter/lib/ai/widgets/prompt_input/select_model_menu.dart index a611d84310..36b389f8c4 100644 --- a/frontend/appflowy_flutter/lib/ai/widgets/prompt_input/select_model_menu.dart +++ b/frontend/appflowy_flutter/lib/ai/widgets/prompt_input/select_model_menu.dart @@ -90,38 +90,40 @@ class SelectModelPopoverContent extends StatelessWidget { return Padding( padding: const EdgeInsets.all(8.0), - child: Column( - mainAxisSize: MainAxisSize.min, - crossAxisAlignment: CrossAxisAlignment.start, - children: [ - if (localModels.isNotEmpty) ...[ - _ModelSectionHeader( - title: LocaleKeys.chat_switchModel_localModel.tr(), + child: SingleChildScrollView( + child: Column( + mainAxisSize: MainAxisSize.min, + crossAxisAlignment: CrossAxisAlignment.start, + children: [ + if (localModels.isNotEmpty) ...[ + _ModelSectionHeader( + title: LocaleKeys.chat_switchModel_localModel.tr(), + ), + const VSpace(4.0), + ], + ...localModels.map( + (model) => _ModelItem( + model: model, + isSelected: model == selectedModel, + onTap: () => onSelectModel?.call(model), + ), + ), + if (cloudModels.isNotEmpty && localModels.isNotEmpty) ...[ + const VSpace(8.0), + _ModelSectionHeader( + title: LocaleKeys.chat_switchModel_cloudModel.tr(), + ), + const VSpace(4.0), + ], + ...cloudModels.map( + (model) => _ModelItem( + model: model, + isSelected: model == selectedModel, + onTap: () => onSelectModel?.call(model), + ), ), - const VSpace(4.0), ], - ...localModels.map( - (model) => _ModelItem( - model: model, - isSelected: model == selectedModel, - onTap: () => onSelectModel?.call(model), - ), - ), - if (cloudModels.isNotEmpty && localModels.isNotEmpty) ...[ - const VSpace(8.0), - _ModelSectionHeader( - title: LocaleKeys.chat_switchModel_cloudModel.tr(), - ), - const VSpace(4.0), - ], - ...cloudModels.map( - (model) => _ModelItem( - model: model, - isSelected: model == selectedModel, - onTap: () => onSelectModel?.call(model), - ), - ), - ], + ), ), ); } diff --git a/frontend/appflowy_flutter/lib/workspace/application/settings/ai/ollama_setting_bloc.dart b/frontend/appflowy_flutter/lib/workspace/application/settings/ai/ollama_setting_bloc.dart index f5c4209028..2659292b11 100644 --- a/frontend/appflowy_flutter/lib/workspace/application/settings/ai/ollama_setting_bloc.dart +++ b/frontend/appflowy_flutter/lib/workspace/application/settings/ai/ollama_setting_bloc.dart @@ -11,6 +11,9 @@ import 'package:equatable/equatable.dart'; part 'ollama_setting_bloc.freezed.dart'; +const kDefaultChatModel = 'llama3.1:latest'; +const kDefaultEmbeddingModel = 'nomic-embed-text:latest'; + class OllamaSettingBloc extends Bloc { OllamaSettingBloc() : super(const OllamaSettingState()) { on(_handleEvent); @@ -70,7 +73,7 @@ class OllamaSettingBloc extends Bloc { final setting = LocalAISettingPB(); final settingUpdaters = { SettingType.serverUrl: (value) => setting.serverUrl = value, - SettingType.chatModel: (value) => setting.chatModelName = value, + SettingType.chatModel: (value) => setting.defaultModel = value, SettingType.embeddingModel: (value) => setting.embeddingModelName = value, }; @@ -108,13 +111,13 @@ class OllamaSettingBloc extends Bloc { settingType: SettingType.serverUrl, ), SettingItem( - content: setting.chatModelName, - hintText: 'llama3.1', + content: setting.defaultModel, + hintText: kDefaultChatModel, settingType: SettingType.chatModel, ), SettingItem( content: setting.embeddingModelName, - hintText: 'nomic-embed-text', + hintText: 
kDefaultEmbeddingModel, settingType: SettingType.embeddingModel, ), ]; @@ -125,7 +128,7 @@ class OllamaSettingBloc extends Bloc { settingType: SettingType.serverUrl, ), SubmittedItem( - content: setting.chatModelName, + content: setting.defaultModel, settingType: SettingType.chatModel, ), SubmittedItem( @@ -203,13 +206,13 @@ class OllamaSettingState with _$OllamaSettingState { settingType: SettingType.serverUrl, ), SettingItem( - content: 'llama3.1', - hintText: 'llama3.1', + content: kDefaultChatModel, + hintText: kDefaultChatModel, settingType: SettingType.chatModel, ), SettingItem( - content: 'nomic-embed-text', - hintText: 'nomic-embed-text', + content: kDefaultEmbeddingModel, + hintText: kDefaultEmbeddingModel, settingType: SettingType.embeddingModel, ), ]) diff --git a/frontend/rust-lib/Cargo.lock b/frontend/rust-lib/Cargo.lock index b7d7fbd7f2..4689ed5c4f 100644 --- a/frontend/rust-lib/Cargo.lock +++ b/frontend/rust-lib/Cargo.lock @@ -2210,12 +2210,6 @@ dependencies = [ "litrs", ] -[[package]] -name = "dotenv" -version = "0.15.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77c90badedccf4105eca100756a0b1289e191f6fcbdadd3cee1d2f614f97da8f" - [[package]] name = "downcast-rs" version = "2.0.1" @@ -2506,7 +2500,6 @@ dependencies = [ "bytes", "collab-integrate", "dashmap 6.0.1", - "dotenv", "flowy-ai-pub", "flowy-codegen", "flowy-derive", @@ -2520,19 +2513,18 @@ dependencies = [ "lib-infra", "log", "notify", + "ollama-rs", "pin-project", "protobuf", "reqwest 0.11.27", "serde", "serde_json", "sha2", - "simsimd", "strum_macros 0.21.1", "tokio", "tokio-stream", "tokio-util", "tracing", - "tracing-subscriber", "uuid", "validator 0.18.1", ] @@ -2798,6 +2790,7 @@ dependencies = [ "flowy-derive", "flowy-sqlite", "lib-dispatch", + "ollama-rs", "protobuf", "r2d2", "reqwest 0.11.27", @@ -4044,6 +4037,7 @@ checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99" dependencies = [ "autocfg", "hashbrown 0.12.3", + "serde", ] [[package]] @@ -4894,6 +4888,23 @@ dependencies = [ "memchr", ] +[[package]] +name = "ollama-rs" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a4b4750770584c8b4a643d0329e7bedacc4ecf68b7c7ac3e1fec2bafd6312f7" +dependencies = [ + "async-stream", + "log", + "reqwest 0.12.15", + "schemars", + "serde", + "serde_json", + "static_assertions", + "thiserror 2.0.12", + "url", +] + [[package]] name = "once_cell" version = "1.19.0" @@ -5696,7 +5707,7 @@ dependencies = [ "rustc-hash 2.1.0", "rustls 0.23.20", "socket2 0.5.5", - "thiserror 2.0.9", + "thiserror 2.0.12", "tokio", "tracing", ] @@ -5715,7 +5726,7 @@ dependencies = [ "rustls 0.23.20", "rustls-pki-types", "slab", - "thiserror 2.0.9", + "thiserror 2.0.12", "tinyvec", "tracing", "web-time", @@ -6407,6 +6418,31 @@ dependencies = [ "parking_lot 0.12.1", ] +[[package]] +name = "schemars" +version = "0.8.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3fbf2ae1b8bc8e02df939598064d22402220cd5bbcca1c76f7d6a310974d5615" +dependencies = [ + "dyn-clone", + "indexmap 1.9.3", + "schemars_derive", + "serde", + "serde_json", +] + +[[package]] +name = "schemars_derive" +version = "0.8.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32e265784ad618884abaea0600a9adf15393368d840e0222d101a072f3f7534d" +dependencies = [ + "proc-macro2", + "quote", + "serde_derive_internals 0.29.1", + "syn 2.0.94", +] + [[package]] name = "scopeguard" version = "1.2.0" @@ -6554,6 +6590,17 @@ 
dependencies = [ "syn 2.0.94", ] +[[package]] +name = "serde_derive_internals" +version = "0.29.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "18d26a20a969b9e3fdf2fc2d9f21eda6c40e2de84c9408bb5d3b05d499aae711" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.94", +] + [[package]] name = "serde_html_form" version = "0.2.7" @@ -6746,15 +6793,6 @@ dependencies = [ "time", ] -[[package]] -name = "simsimd" -version = "4.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "efc843bc8f12d9c8e6b734a0fe8918fc497b42f6ae0f347dbfdad5b5138ab9b4" -dependencies = [ - "cc", -] - [[package]] name = "siphasher" version = "0.3.11" @@ -6841,6 +6879,12 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" +[[package]] +name = "static_assertions" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" + [[package]] name = "string_cache" version = "0.8.7" @@ -7098,7 +7142,7 @@ dependencies = [ "tantivy-stacker", "tantivy-tokenizer-api", "tempfile", - "thiserror 2.0.9", + "thiserror 2.0.12", "time", "uuid", "winapi", @@ -7280,11 +7324,11 @@ dependencies = [ [[package]] name = "thiserror" -version = "2.0.9" +version = "2.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f072643fd0190df67a8bab670c20ef5d8737177d6ac6b2e9a236cb096206b2cc" +checksum = "567b8a2dae586314f7be2a752ec7474332959c6460e02bde30d702a66d488708" dependencies = [ - "thiserror-impl 2.0.9", + "thiserror-impl 2.0.12", ] [[package]] @@ -7300,9 +7344,9 @@ dependencies = [ [[package]] name = "thiserror-impl" -version = "2.0.9" +version = "2.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b50fa271071aae2e6ee85f842e2e28ba8cd2c5fb67f11fcb1fd70b276f9e7d4" +checksum = "7f7cf42b4507d8ea322120659672cf1b9dbb93f8f2d4ecfd6e51350ff5b17a1d" dependencies = [ "proc-macro2", "quote", @@ -7823,7 +7867,7 @@ checksum = "7a94b0f0954b3e59bfc2c246b4c8574390d94a4ad4ad246aaf2fb07d7dfd3b47" dependencies = [ "proc-macro2", "quote", - "serde_derive_internals", + "serde_derive_internals 0.28.0", "syn 2.0.94", ] diff --git a/frontend/rust-lib/flowy-ai-pub/src/persistence/local_model_sql.rs b/frontend/rust-lib/flowy-ai-pub/src/persistence/local_model_sql.rs new file mode 100644 index 0000000000..1e1d49b79e --- /dev/null +++ b/frontend/rust-lib/flowy-ai-pub/src/persistence/local_model_sql.rs @@ -0,0 +1,54 @@ +use diesel::sqlite::SqliteConnection; +use flowy_error::FlowyResult; +use flowy_sqlite::upsert::excluded; +use flowy_sqlite::{ + diesel, + query_dsl::*, + schema::{local_ai_model_table, local_ai_model_table::dsl}, + ExpressionMethods, Identifiable, Insertable, Queryable, +}; + +#[derive(Clone, Default, Queryable, Insertable, Identifiable)] +#[diesel(table_name = local_ai_model_table)] +#[diesel(primary_key(name))] pub struct LocalAIModelTable { + pub name: String, + pub model_type: i16, +} + +#[derive(Clone, Debug, Copy)] +pub enum ModelType { + Embedding = 0, + Chat = 1, +} + +impl From<i16> for ModelType { + fn from(value: i16) -> Self { + match value { + 0 => ModelType::Embedding, + 1 => ModelType::Chat, + _ => ModelType::Embedding, + } + } +} + +pub fn select_local_ai_model(conn: &mut SqliteConnection, name: &str) -> Option<LocalAIModelTable> { + local_ai_model_table::table + .filter(dsl::name.eq(name)) + .first::<LocalAIModelTable>(conn) + .ok() +} + +pub fn 
upsert_local_ai_model( + conn: &mut SqliteConnection, + row: &LocalAIModelTable, +) -> FlowyResult<()> { + diesel::insert_into(local_ai_model_table::table) + .values(row) + .on_conflict(local_ai_model_table::name) + .do_update() + .set((local_ai_model_table::model_type.eq(excluded(local_ai_model_table::model_type)),)) + .execute(conn)?; + + Ok(()) +} diff --git a/frontend/rust-lib/flowy-ai-pub/src/persistence/mod.rs b/frontend/rust-lib/flowy-ai-pub/src/persistence/mod.rs index b21eb507ae..7ae97148ce 100644 --- a/frontend/rust-lib/flowy-ai-pub/src/persistence/mod.rs +++ b/frontend/rust-lib/flowy-ai-pub/src/persistence/mod.rs @@ -1,5 +1,7 @@ mod chat_message_sql; mod chat_sql; +mod local_model_sql; pub use chat_message_sql::*; pub use chat_sql::*; +pub use local_model_sql::*; diff --git a/frontend/rust-lib/flowy-ai/Cargo.toml b/frontend/rust-lib/flowy-ai/Cargo.toml index 3a6aaf5898..fb90714d00 100644 --- a/frontend/rust-lib/flowy-ai/Cargo.toml +++ b/frontend/rust-lib/flowy-ai/Cargo.toml @@ -48,16 +48,16 @@ collab-integrate.workspace = true [target.'cfg(any(target_os = "macos", target_os = "linux", target_os = "windows"))'.dependencies] notify = "6.1.1" +ollama-rs = "0.3.0" +#faiss = { version = "0.12.1" } af-mcp = { version = "0.1.0" } [dev-dependencies] -dotenv = "0.15.0" uuid.workspace = true -tracing-subscriber = { version = "0.3.17", features = ["registry", "env-filter", "ansi", "json"] } -simsimd = "4.4.0" [build-dependencies] flowy-codegen.workspace = true [features] dart = ["flowy-codegen/dart", "flowy-notification/dart"] +local_ai = [] \ No newline at end of file diff --git a/frontend/rust-lib/flowy-ai/src/ai_manager.rs b/frontend/rust-lib/flowy-ai/src/ai_manager.rs index 06b3adaeea..d18cb91355 100644 --- a/frontend/rust-lib/flowy-ai/src/ai_manager.rs +++ b/frontend/rust-lib/flowy-ai/src/ai_manager.rs @@ -330,14 +330,10 @@ impl AIManager { .get_question_id_from_answer_id(chat_id, answer_message_id) .await?; - let model = model.map_or_else( - || { - self - .store_preferences - .get_object::<AIModel>(&ai_available_models_key(&chat_id.to_string())) - }, - |model| Some(model.into()), - ); + let model = match model { + None => self.get_active_model(&chat_id.to_string()).await, + Some(model) => Some(model.into()), + }; chat .stream_regenerate_response(question_message_id, answer_stream_port, format, model) .await?; @@ -354,9 +350,10 @@ impl AIManager { "[AI Plugin] update global active model, previous: {}, current: {}", previous_model, current_model ); - let source_key = ai_available_models_key(GLOBAL_ACTIVE_MODEL_KEY); let model = AIModel::local(current_model, "".to_string()); - self.update_selected_model(source_key, model).await?; + self + .update_selected_model(GLOBAL_ACTIVE_MODEL_KEY.to_string(), model) + .await?; } Ok(()) @@ -440,11 +437,11 @@ impl AIManager { } pub async fn update_selected_model(&self, source: String, model: AIModel) -> FlowyResult<()> { - info!( - "[Model Selection] update {} selected model: {:?}", - source, model - ); let source_key = ai_available_models_key(&source); + info!( + "[Model Selection] update {} selected model: {:?} for key:{}", + source, model, source_key + ); self .store_preferences .set_object::<AIModel>(&source_key, &model)?; @@ -458,12 +455,13 @@ impl AIManager { #[instrument(skip_all, level = "debug")] pub async fn toggle_local_ai(&self) -> FlowyResult<()> { let enabled = self.local_ai.toggle_local_ai().await?; - let source_key = ai_available_models_key(GLOBAL_ACTIVE_MODEL_KEY); if enabled { if let Some(name) = self.local_ai.get_plugin_chat_model() { info!("Set 
global active model to local ai: {}", name); let model = AIModel::local(name, "".to_string()); - self.update_selected_model(source_key, model).await?; + self + .update_selected_model(GLOBAL_ACTIVE_MODEL_KEY.to_string(), model) + .await?; } } else { info!("Set global active model to default"); @@ -471,7 +469,7 @@ let models = self.get_server_available_models().await?; if let Some(model) = models.into_iter().find(|m| m.name == global_active_model) { self - .update_selected_model(source_key, AIModel::from(model)) + .update_selected_model(GLOBAL_ACTIVE_MODEL_KEY.to_string(), AIModel::from(model)) .await?; } } @@ -484,21 +482,31 @@ .store_preferences .get_object::<AIModel>(&ai_available_models_key(source)); - if model.is_none() { - if let Some(local_model) = self.local_ai.get_plugin_chat_model() { - model = Some(AIModel::local(local_model, "".to_string())); - } + match model { + None => { + if let Some(local_model) = self.local_ai.get_plugin_chat_model() { + model = Some(AIModel::local(local_model, "".to_string())); + } + model + }, + Some(mut model) => { + let models = self.local_ai.get_all_chat_local_models().await; + if !models.contains(&model) { + if let Some(local_model) = self.local_ai.get_plugin_chat_model() { + model = AIModel::local(local_model, "".to_string()); + } + } + Some(model) + }, } - - model } pub async fn get_available_models(&self, source: String) -> FlowyResult<AvailableModelsPB> { let is_local_mode = self.user_service.is_local_model().await?; if is_local_mode { let setting = self.local_ai.get_local_ai_setting(); + let models = self.local_ai.get_all_chat_local_models().await; let selected_model = AIModel::local(setting.chat_model_name, "".to_string()); - let models = vec![selected_model.clone()]; Ok(AvailableModelsPB { models: models.into_iter().map(|m| m.into()).collect(), @@ -506,27 +514,24 @@ }) } else { // Build the models list from server models and mark them as non-local. - let mut models: Vec<AIModel> = self + let mut all_models: Vec<AIModel> = self .get_server_available_models() .await? .into_iter() .map(AIModel::from) .collect(); - trace!("[Model Selection]: Available models: {:?}", models); - let mut current_active_local_ai_model = None; + trace!("[Model Selection]: Available models: {:?}", all_models); // If user enable local ai, then add local ai model to the list. - if let Some(local_model) = self.local_ai.get_plugin_chat_model() { - let model = AIModel::local(local_model, "".to_string()); - current_active_local_ai_model = Some(model.clone()); - trace!("[Model Selection] current local ai model: {}", model.name); - models.push(model); + if self.local_ai.is_enabled() { + let local_models = self.local_ai.get_all_chat_local_models().await; + all_models.extend(local_models.into_iter().map(|m| m)); } - if models.is_empty() { + if all_models.is_empty() { return Ok(AvailableModelsPB { - models: models.into_iter().map(|m| m.into()).collect(), + models: all_models.into_iter().map(|m| m.into()).collect(), selected_model: AIModelPB::default(), }); } @@ -545,37 +550,29 @@ let mut user_selected_model = server_active_model.clone(); // when current select model is deprecated, reset the model to default - if !models.iter().any(|m| m.name == server_active_model.name) { + if !all_models + .iter() + .any(|m| m.name == server_active_model.name) + { server_active_model = AIModel::default(); } - let source_key = ai_available_models_key(&source); // We use source to identify user selected model. source can be document id or chat id. 
- match self.store_preferences.get_object::<AIModel>(&source_key) { + match self.get_active_model(&source).await { None => { // when there is selected model and current local ai is active, then use local ai - if let Some(local_ai_model) = models.iter().find(|m| m.is_local) { + if let Some(local_ai_model) = all_models.iter().find(|m| m.is_local) { user_selected_model = local_ai_model.clone(); } }, - Some(mut model) => { + Some(model) => { trace!("[Model Selection] user previous select model: {:?}", model); - // If source is provided, try to get the user-selected model from the store. User selected - // model will be used as the active model if it exists. - if model.is_local { - if let Some(local_ai_model) = &current_active_local_ai_model { - if local_ai_model.name != model.name { - model = local_ai_model.clone(); - } - } - } - user_selected_model = model; + user_selected_model = model; }, } // If user selected model is not available in the list, use the global active model. - let active_model = models + let active_model = all_models .iter() .find(|m| m.name == user_selected_model.name) .cloned() @@ -585,15 +582,15 @@ if let Some(ref active_model) = active_model { if active_model.name != user_selected_model.name { self - .store_preferences - .set_object::<AIModel>(&source_key, &active_model.clone())?; + .update_selected_model(source, active_model.clone()) + .await?; } } trace!("[Model Selection] final active model: {:?}", active_model); let selected_model = AIModelPB::from(active_model.unwrap_or_default()); Ok(AvailableModelsPB { - models: models.into_iter().map(|m| m.into()).collect(), + models: all_models.into_iter().map(|m| m.into()).collect(), selected_model, }) } diff --git a/frontend/rust-lib/flowy-ai/src/entities.rs b/frontend/rust-lib/flowy-ai/src/entities.rs index 5a4aecbbd7..796664a18f 100644 --- a/frontend/rust-lib/flowy-ai/src/entities.rs +++ b/frontend/rust-lib/flowy-ai/src/entities.rs @@ -686,7 +686,7 @@ pub struct LocalAISettingPB { #[pb(index = 2)] #[validate(custom(function = "required_not_empty_str"))] - pub chat_model_name: String, + pub default_model: String, #[pb(index = 3)] #[validate(custom(function = "required_not_empty_str"))] @@ -697,7 +697,7 @@ impl From<LocalAISetting> for LocalAISettingPB { fn from(value: LocalAISetting) -> Self { LocalAISettingPB { server_url: value.ollama_server_url, - chat_model_name: value.chat_model_name, + default_model: value.chat_model_name, embedding_model_name: value.embedding_model_name, } } @@ -707,7 +707,7 @@ impl From<LocalAISettingPB> for LocalAISetting { fn from(value: LocalAISettingPB) -> Self { LocalAISetting { ollama_server_url: value.server_url, - chat_model_name: value.chat_model_name, + chat_model_name: value.default_model, embedding_model_name: value.embedding_model_name, } } diff --git a/frontend/rust-lib/flowy-ai/src/event_handler.rs b/frontend/rust-lib/flowy-ai/src/event_handler.rs index f85858b1c2..f778063309 100644 --- a/frontend/rust-lib/flowy-ai/src/event_handler.rs +++ b/frontend/rust-lib/flowy-ai/src/event_handler.rs @@ -1,7 +1,6 @@ use crate::ai_manager::{AIManager, GLOBAL_ACTIVE_MODEL_KEY}; use crate::completion::AICompletion; use crate::entities::*; -use crate::util::ai_available_models_key; use flowy_ai_pub::cloud::{AIModel, ChatMessageType}; use flowy_error::{ErrorCode, FlowyError, FlowyResult}; use lib_dispatch::prelude::{data_result_ok, AFPluginData, AFPluginState, DataResult}; @@ -82,8 +81,9 @@ pub(crate) async fn get_server_model_list_handler( ai_manager: AFPluginState<Weak<AIManager>>, ) -> DataResult<AvailableModelsPB, FlowyError> { let ai_manager = upgrade_ai_manager(ai_manager)?; - let source_key = 
ai_available_models_key(GLOBAL_ACTIVE_MODEL_KEY); - let models = ai_manager.get_available_models(source_key).await?; + let models = ai_manager + .get_available_models(GLOBAL_ACTIVE_MODEL_KEY.to_string()) + .await?; data_result_ok(models) } diff --git a/frontend/rust-lib/flowy-ai/src/local_ai/controller.rs b/frontend/rust-lib/flowy-ai/src/local_ai/controller.rs index 1ec08854e0..d384ddfb75 100644 --- a/frontend/rust-lib/flowy-ai/src/local_ai/controller.rs +++ b/frontend/rust-lib/flowy-ai/src/local_ai/controller.rs @@ -16,9 +16,15 @@ use af_local_ai::ollama_plugin::OllamaAIPlugin; use af_plugin::core::path::is_plugin_ready; use af_plugin::core::plugin::RunningState; use arc_swap::ArcSwapOption; +use flowy_ai_pub::cloud::AIModel; +use flowy_ai_pub::persistence::{ + select_local_ai_model, upsert_local_ai_model, LocalAIModelTable, ModelType, +}; use flowy_ai_pub::user_service::AIUserService; use futures_util::SinkExt; use lib_infra::util::get_operating_system; +use ollama_rs::generation::embeddings::request::{EmbeddingsInput, GenerateEmbeddingsRequest}; +use ollama_rs::Ollama; use serde::{Deserialize, Serialize}; use std::ops::Deref; use std::path::PathBuf; @@ -39,8 +45,8 @@ impl Default for LocalAISetting { fn default() -> Self { Self { ollama_server_url: "http://localhost:11434".to_string(), - chat_model_name: "llama3.1".to_string(), - embedding_model_name: "nomic-embed-text".to_string(), + chat_model_name: "llama3.1:latest".to_string(), + embedding_model_name: "nomic-embed-text:latest".to_string(), } } } @@ -53,6 +59,7 @@ pub struct LocalAIController { current_chat_id: ArcSwapOption<Uuid>, store_preferences: Weak<KVStorePreferences>, user_service: Arc<dyn AIUserService>, + ollama: ArcSwapOption<Ollama>, } impl Deref for LocalAIController { @@ -83,69 +90,80 @@ impl LocalAIController { user_service.clone(), res_impl, )); - // Subscribe to state changes - let mut running_state_rx = local_ai.subscribe_running_state(); - let cloned_llm_res = Arc::clone(&local_ai_resource); - let cloned_store_preferences = store_preferences.clone(); - let cloned_local_ai = Arc::clone(&local_ai); - let cloned_user_service = Arc::clone(&user_service); + let ollama = ArcSwapOption::default(); + let sys = get_operating_system(); + if sys.is_desktop() { + let setting = local_ai_resource.get_llm_setting(); + ollama.store( + Ollama::try_new(&setting.ollama_server_url) + .map(Arc::new) + .ok(), + ); - // Spawn a background task to listen for plugin state changes - tokio::spawn(async move { - while let Some(state) = running_state_rx.next().await { - // Skip if we can’t get workspace_id - let Ok(workspace_id) = cloned_user_service.workspace_id() else { - continue; - }; + // Subscribe to state changes + let mut running_state_rx = local_ai.subscribe_running_state(); + let cloned_llm_res = Arc::clone(&local_ai_resource); + let cloned_store_preferences = store_preferences.clone(); + let cloned_local_ai = Arc::clone(&local_ai); + let cloned_user_service = Arc::clone(&user_service); - let key = local_ai_enabled_key(&workspace_id.to_string()); - info!("[AI Plugin] state: {:?}", state); - - // Read whether plugin is enabled from store; default to true - if let Some(store_preferences) = cloned_store_preferences.upgrade() { - let enabled = store_preferences.get_bool(&key).unwrap_or(true); - // Only check resource status if the plugin isn’t in "UnexpectedStop" and is enabled - let (plugin_downloaded, lack_of_resource) = - if !matches!(state, RunningState::UnexpectedStop { .. 
}) && enabled { - // Possibly check plugin readiness and resource concurrency in parallel, - // but here we do it sequentially for clarity. - let downloaded = is_plugin_ready(); - let resource_lack = cloned_llm_res.get_lack_of_resource().await; - (downloaded, resource_lack) - } else { - (false, None) - }; - - // If plugin is running, retrieve version - let plugin_version = if matches!(state, RunningState::Running { .. }) { - match cloned_local_ai.plugin_info().await { - Ok(info) => Some(info.version), - Err(_) => None, - } - } else { - None + // Spawn a background task to listen for plugin state changes + tokio::spawn(async move { + while let Some(state) = running_state_rx.next().await { + // Skip if we can't get workspace_id + let Ok(workspace_id) = cloned_user_service.workspace_id() else { + continue; }; - // Broadcast the new local AI state - let new_state = RunningStatePB::from(state); - chat_notification_builder( - APPFLOWY_AI_NOTIFICATION_KEY, - ChatNotification::UpdateLocalAIState, - ) - .payload(LocalAIPB { - enabled, - plugin_downloaded, - lack_of_resource, - state: new_state, - plugin_version, - }) - .send(); - } else { - warn!("[AI Plugin] store preferences is dropped"); + let key = crate::local_ai::controller::local_ai_enabled_key(&workspace_id.to_string()); + info!("[AI Plugin] state: {:?}", state); + + // Read whether plugin is enabled from store; default to true + if let Some(store_preferences) = cloned_store_preferences.upgrade() { + let enabled = store_preferences.get_bool(&key).unwrap_or(true); + // Only check resource status if the plugin isn't in "UnexpectedStop" and is enabled + let (plugin_downloaded, lack_of_resource) = + if !matches!(state, RunningState::UnexpectedStop { .. }) && enabled { + // Possibly check plugin readiness and resource concurrency in parallel, + // but here we do it sequentially for clarity. + let downloaded = is_plugin_ready(); + let resource_lack = cloned_llm_res.get_lack_of_resource().await; + (downloaded, resource_lack) + } else { + (false, None) + }; + + // If plugin is running, retrieve version + let plugin_version = if matches!(state, RunningState::Running { .. 
}) { + match cloned_local_ai.plugin_info().await { + Ok(info) => Some(info.version), + Err(_) => None, + } + } else { + None + }; + + // Broadcast the new local AI state + let new_state = RunningStatePB::from(state); + chat_notification_builder( + APPFLOWY_AI_NOTIFICATION_KEY, + ChatNotification::UpdateLocalAIState, + ) + .payload(LocalAIPB { + enabled, + plugin_downloaded, + lack_of_resource, + state: new_state, + plugin_version, + }) + .send(); + } else { + warn!("[AI Plugin] store preferences is dropped"); + } } - } - }); + }); + } Self { ai_plugin: local_ai, @@ -153,6 +171,7 @@ impl LocalAIController { current_chat_id: ArcSwapOption::default(), store_preferences, user_service, + ollama, } } #[instrument(level = "debug", skip_all)] @@ -287,6 +306,78 @@ impl LocalAIController { self.resource.get_llm_setting() } + pub async fn get_all_chat_local_models(&self) -> Vec<AIModel> { + self + .get_filtered_local_models(|name| !name.contains("embed")) + .await + } + + pub async fn get_all_embedded_local_models(&self) -> Vec<AIModel> { + self + .get_filtered_local_models(|name| name.contains("embed")) + .await + } + + // Helper function to avoid code duplication in model retrieval + async fn get_filtered_local_models<F>(&self, filter_fn: F) -> Vec<AIModel> + where + F: Fn(&str) -> bool, + { + match self.ollama.load_full() { + None => vec![], + Some(ollama) => ollama + .list_local_models() + .await + .map(|models| { + models + .into_iter() + .filter(|m| filter_fn(&m.name.to_lowercase())) + .map(|m| AIModel::local(m.name, String::new())) + .collect() + }) + .unwrap_or_default(), + } + } + + pub async fn check_model_type(&self, model_name: &str) -> FlowyResult<ModelType> { + let uid = self.user_service.user_id()?; + let mut conn = self.user_service.sqlite_connection(uid)?; + match select_local_ai_model(&mut conn, model_name) { + None => { + let ollama = self + .ollama + .load_full() + .ok_or_else(|| FlowyError::local_ai().with_context("ollama is not initialized"))?; + + let request = GenerateEmbeddingsRequest::new( + model_name.to_string(), + EmbeddingsInput::Single("Hello".to_string()), + ); + + let model_type = match ollama.generate_embeddings(request).await { + Ok(value) => { + if value.embeddings.is_empty() { + ModelType::Chat + } else { + ModelType::Embedding + } + }, + Err(_) => ModelType::Chat, + }; + + upsert_local_ai_model( + &mut conn, + &LocalAIModelTable { + name: model_name.to_string(), + model_type: model_type as i16, + }, + )?; + Ok(model_type) + }, + Some(r) => Ok(ModelType::from(r.model_type)), + } + } + pub async fn update_local_ai_setting(&self, setting: LocalAISetting) -> FlowyResult<()> { info!( "[AI Plugin] update local ai setting: {:?}, thread: {:?}", diff --git a/frontend/rust-lib/flowy-ai/src/local_ai/resource.rs b/frontend/rust-lib/flowy-ai/src/local_ai/resource.rs index 36a56e171d..352778f28f 100644 --- a/frontend/rust-lib/flowy-ai/src/local_ai/resource.rs +++ b/frontend/rust-lib/flowy-ai/src/local_ai/resource.rs @@ -161,7 +161,6 @@ impl LocalAIResourceController { let setting = self.get_llm_setting(); let client = Client::builder().timeout(Duration::from_secs(5)).build()?; - match client.get(&setting.ollama_server_url).send().await { Ok(resp) if resp.status().is_success() => { info!( diff --git a/frontend/rust-lib/flowy-error/Cargo.toml b/frontend/rust-lib/flowy-error/Cargo.toml index 61a7422f17..8bc67ee46c 100644 --- a/frontend/rust-lib/flowy-error/Cargo.toml +++ b/frontend/rust-lib/flowy-error/Cargo.toml @@ -36,6 +36,10 @@ client-api = { workspace = true, optional = true } tantivy = { workspace = true, optional 
= true } uuid.workspace = true +[target.'cfg(any(target_os = "macos", target_os = "linux", target_os = "windows"))'.dependencies] +ollama-rs = "0.3.0" + + [features] default = ["impl_from_dispatch_error", "impl_from_serde", "impl_from_reqwest", "impl_from_sqlite"] impl_from_dispatch_error = ["lib-dispatch"] diff --git a/frontend/rust-lib/flowy-error/src/errors.rs b/frontend/rust-lib/flowy-error/src/errors.rs index f76a7d4dda..36240cd08d 100644 --- a/frontend/rust-lib/flowy-error/src/errors.rs +++ b/frontend/rust-lib/flowy-error/src/errors.rs @@ -264,3 +264,10 @@ impl From for FlowyError { FlowyError::internal().with_context(value) } } + +#[cfg(any(target_os = "windows", target_os = "macos", target_os = "linux"))] +impl From<ollama_rs::error::OllamaError> for FlowyError { + fn from(value: ollama_rs::error::OllamaError) -> Self { + FlowyError::local_ai().with_context(value) + } +} diff --git a/frontend/rust-lib/flowy-sqlite/migrations/2025-04-25-071459_local_ai_model/down.sql b/frontend/rust-lib/flowy-sqlite/migrations/2025-04-25-071459_local_ai_model/down.sql new file mode 100644 index 0000000000..d9a93fe9a1 --- /dev/null +++ b/frontend/rust-lib/flowy-sqlite/migrations/2025-04-25-071459_local_ai_model/down.sql @@ -0,0 +1 @@ +-- This file should undo anything in `up.sql` diff --git a/frontend/rust-lib/flowy-sqlite/migrations/2025-04-25-071459_local_ai_model/up.sql b/frontend/rust-lib/flowy-sqlite/migrations/2025-04-25-071459_local_ai_model/up.sql new file mode 100644 index 0000000000..243fe61193 --- /dev/null +++ b/frontend/rust-lib/flowy-sqlite/migrations/2025-04-25-071459_local_ai_model/up.sql @@ -0,0 +1,6 @@ +-- Your SQL goes here +CREATE TABLE local_ai_model_table +( + name TEXT PRIMARY KEY NOT NULL, + model_type SMALLINT NOT NULL ); \ No newline at end of file diff --git a/frontend/rust-lib/flowy-sqlite/src/schema.rs b/frontend/rust-lib/flowy-sqlite/src/schema.rs index 0236cbf467..bf7f431682 100644 --- a/frontend/rust-lib/flowy-sqlite/src/schema.rs +++ b/frontend/rust-lib/flowy-sqlite/src/schema.rs @@ -54,6 +54,13 @@ diesel::table! { } } +diesel::table! { + local_ai_model_table (name) { + name -> Text, + model_type -> SmallInt, + } +} + diesel::table! { upload_file_part (upload_id, e_tag) { upload_id -> Text, @@ -133,16 +140,17 @@ diesel::table! { } diesel::allow_tables_to_appear_in_same_query!( - af_collab_metadata, - chat_local_setting_table, - chat_message_table, - chat_table, - collab_snapshot, - upload_file_part, - upload_file_table, - user_data_migration_records, - user_table, - user_workspace_table, - workspace_members_table, - workspace_setting_table, + af_collab_metadata, + chat_local_setting_table, + chat_message_table, + chat_table, + collab_snapshot, + local_ai_model_table, + upload_file_part, + upload_file_table, + user_data_migration_records, + user_table, + user_workspace_table, + workspace_members_table, + workspace_setting_table, );