chore: display all local models

Nathan 2025-04-25 16:53:32 +08:00
parent 2c6253576f
commit 7dd8d06c85
16 changed files with 418 additions and 200 deletions

View File

@ -90,38 +90,40 @@ class SelectModelPopoverContent extends StatelessWidget {
return Padding(
padding: const EdgeInsets.all(8.0),
child: Column(
mainAxisSize: MainAxisSize.min,
crossAxisAlignment: CrossAxisAlignment.start,
children: [
if (localModels.isNotEmpty) ...[
_ModelSectionHeader(
title: LocaleKeys.chat_switchModel_localModel.tr(),
child: SingleChildScrollView(
child: Column(
mainAxisSize: MainAxisSize.min,
crossAxisAlignment: CrossAxisAlignment.start,
children: [
if (localModels.isNotEmpty) ...[
_ModelSectionHeader(
title: LocaleKeys.chat_switchModel_localModel.tr(),
),
const VSpace(4.0),
],
...localModels.map(
(model) => _ModelItem(
model: model,
isSelected: model == selectedModel,
onTap: () => onSelectModel?.call(model),
),
),
if (cloudModels.isNotEmpty && localModels.isNotEmpty) ...[
const VSpace(8.0),
_ModelSectionHeader(
title: LocaleKeys.chat_switchModel_cloudModel.tr(),
),
const VSpace(4.0),
],
...cloudModels.map(
(model) => _ModelItem(
model: model,
isSelected: model == selectedModel,
onTap: () => onSelectModel?.call(model),
),
),
const VSpace(4.0),
],
...localModels.map(
(model) => _ModelItem(
model: model,
isSelected: model == selectedModel,
onTap: () => onSelectModel?.call(model),
),
),
if (cloudModels.isNotEmpty && localModels.isNotEmpty) ...[
const VSpace(8.0),
_ModelSectionHeader(
title: LocaleKeys.chat_switchModel_cloudModel.tr(),
),
const VSpace(4.0),
],
...cloudModels.map(
(model) => _ModelItem(
model: model,
isSelected: model == selectedModel,
onTap: () => onSelectModel?.call(model),
),
),
],
),
),
);
}

View File

@ -11,6 +11,9 @@ import 'package:equatable/equatable.dart';
part 'ollama_setting_bloc.freezed.dart';
const kDefaultChatModel = 'llama3.1:latest';
const kDefaultEmbeddingModel = 'nomic-embed-text:latest';
class OllamaSettingBloc extends Bloc<OllamaSettingEvent, OllamaSettingState> {
OllamaSettingBloc() : super(const OllamaSettingState()) {
on<OllamaSettingEvent>(_handleEvent);
@ -70,7 +73,7 @@ class OllamaSettingBloc extends Bloc<OllamaSettingEvent, OllamaSettingState> {
final setting = LocalAISettingPB();
final settingUpdaters = <SettingType, void Function(String)>{
SettingType.serverUrl: (value) => setting.serverUrl = value,
SettingType.chatModel: (value) => setting.chatModelName = value,
SettingType.chatModel: (value) => setting.defaultModel = value,
SettingType.embeddingModel: (value) =>
setting.embeddingModelName = value,
};
@ -108,13 +111,13 @@ class OllamaSettingBloc extends Bloc<OllamaSettingEvent, OllamaSettingState> {
settingType: SettingType.serverUrl,
),
SettingItem(
content: setting.chatModelName,
hintText: 'llama3.1',
content: setting.defaultModel,
hintText: kDefaultChatModel,
settingType: SettingType.chatModel,
),
SettingItem(
content: setting.embeddingModelName,
hintText: 'nomic-embed-text',
hintText: kDefaultEmbeddingModel,
settingType: SettingType.embeddingModel,
),
];
@ -125,7 +128,7 @@ class OllamaSettingBloc extends Bloc<OllamaSettingEvent, OllamaSettingState> {
settingType: SettingType.serverUrl,
),
SubmittedItem(
content: setting.chatModelName,
content: setting.defaultModel,
settingType: SettingType.chatModel,
),
SubmittedItem(
@ -203,13 +206,13 @@ class OllamaSettingState with _$OllamaSettingState {
settingType: SettingType.serverUrl,
),
SettingItem(
content: 'llama3.1',
hintText: 'llama3.1',
content: kDefaultChatModel,
hintText: kDefaultChatModel,
settingType: SettingType.chatModel,
),
SettingItem(
content: 'nomic-embed-text',
hintText: 'nomic-embed-text',
content: kDefaultEmbeddingModel,
hintText: kDefaultEmbeddingModel,
settingType: SettingType.embeddingModel,
),
])

View File

@ -2210,12 +2210,6 @@ dependencies = [
"litrs",
]
[[package]]
name = "dotenv"
version = "0.15.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "77c90badedccf4105eca100756a0b1289e191f6fcbdadd3cee1d2f614f97da8f"
[[package]]
name = "downcast-rs"
version = "2.0.1"
@ -2506,7 +2500,6 @@ dependencies = [
"bytes",
"collab-integrate",
"dashmap 6.0.1",
"dotenv",
"flowy-ai-pub",
"flowy-codegen",
"flowy-derive",
@ -2520,19 +2513,18 @@ dependencies = [
"lib-infra",
"log",
"notify",
"ollama-rs",
"pin-project",
"protobuf",
"reqwest 0.11.27",
"serde",
"serde_json",
"sha2",
"simsimd",
"strum_macros 0.21.1",
"tokio",
"tokio-stream",
"tokio-util",
"tracing",
"tracing-subscriber",
"uuid",
"validator 0.18.1",
]
@ -2798,6 +2790,7 @@ dependencies = [
"flowy-derive",
"flowy-sqlite",
"lib-dispatch",
"ollama-rs",
"protobuf",
"r2d2",
"reqwest 0.11.27",
@ -4044,6 +4037,7 @@ checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99"
dependencies = [
"autocfg",
"hashbrown 0.12.3",
"serde",
]
[[package]]
@ -4894,6 +4888,23 @@ dependencies = [
"memchr",
]
[[package]]
name = "ollama-rs"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7a4b4750770584c8b4a643d0329e7bedacc4ecf68b7c7ac3e1fec2bafd6312f7"
dependencies = [
"async-stream",
"log",
"reqwest 0.12.15",
"schemars",
"serde",
"serde_json",
"static_assertions",
"thiserror 2.0.12",
"url",
]
[[package]]
name = "once_cell"
version = "1.19.0"
@ -5696,7 +5707,7 @@ dependencies = [
"rustc-hash 2.1.0",
"rustls 0.23.20",
"socket2 0.5.5",
"thiserror 2.0.9",
"thiserror 2.0.12",
"tokio",
"tracing",
]
@ -5715,7 +5726,7 @@ dependencies = [
"rustls 0.23.20",
"rustls-pki-types",
"slab",
"thiserror 2.0.9",
"thiserror 2.0.12",
"tinyvec",
"tracing",
"web-time",
@ -6407,6 +6418,31 @@ dependencies = [
"parking_lot 0.12.1",
]
[[package]]
name = "schemars"
version = "0.8.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3fbf2ae1b8bc8e02df939598064d22402220cd5bbcca1c76f7d6a310974d5615"
dependencies = [
"dyn-clone",
"indexmap 1.9.3",
"schemars_derive",
"serde",
"serde_json",
]
[[package]]
name = "schemars_derive"
version = "0.8.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "32e265784ad618884abaea0600a9adf15393368d840e0222d101a072f3f7534d"
dependencies = [
"proc-macro2",
"quote",
"serde_derive_internals 0.29.1",
"syn 2.0.94",
]
[[package]]
name = "scopeguard"
version = "1.2.0"
@ -6554,6 +6590,17 @@ dependencies = [
"syn 2.0.94",
]
[[package]]
name = "serde_derive_internals"
version = "0.29.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "18d26a20a969b9e3fdf2fc2d9f21eda6c40e2de84c9408bb5d3b05d499aae711"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.94",
]
[[package]]
name = "serde_html_form"
version = "0.2.7"
@ -6746,15 +6793,6 @@ dependencies = [
"time",
]
[[package]]
name = "simsimd"
version = "4.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "efc843bc8f12d9c8e6b734a0fe8918fc497b42f6ae0f347dbfdad5b5138ab9b4"
dependencies = [
"cc",
]
[[package]]
name = "siphasher"
version = "0.3.11"
@ -6841,6 +6879,12 @@ version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3"
[[package]]
name = "static_assertions"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f"
[[package]]
name = "string_cache"
version = "0.8.7"
@ -7098,7 +7142,7 @@ dependencies = [
"tantivy-stacker",
"tantivy-tokenizer-api",
"tempfile",
"thiserror 2.0.9",
"thiserror 2.0.12",
"time",
"uuid",
"winapi",
@ -7280,11 +7324,11 @@ dependencies = [
[[package]]
name = "thiserror"
version = "2.0.9"
version = "2.0.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f072643fd0190df67a8bab670c20ef5d8737177d6ac6b2e9a236cb096206b2cc"
checksum = "567b8a2dae586314f7be2a752ec7474332959c6460e02bde30d702a66d488708"
dependencies = [
"thiserror-impl 2.0.9",
"thiserror-impl 2.0.12",
]
[[package]]
@ -7300,9 +7344,9 @@ dependencies = [
[[package]]
name = "thiserror-impl"
version = "2.0.9"
version = "2.0.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7b50fa271071aae2e6ee85f842e2e28ba8cd2c5fb67f11fcb1fd70b276f9e7d4"
checksum = "7f7cf42b4507d8ea322120659672cf1b9dbb93f8f2d4ecfd6e51350ff5b17a1d"
dependencies = [
"proc-macro2",
"quote",
@ -7823,7 +7867,7 @@ checksum = "7a94b0f0954b3e59bfc2c246b4c8574390d94a4ad4ad246aaf2fb07d7dfd3b47"
dependencies = [
"proc-macro2",
"quote",
"serde_derive_internals",
"serde_derive_internals 0.28.0",
"syn 2.0.94",
]

View File

@ -0,0 +1,54 @@
use diesel::sqlite::SqliteConnection;
use flowy_error::FlowyResult;
use flowy_sqlite::upsert::excluded;
use flowy_sqlite::{
diesel,
query_dsl::*,
schema::{local_ai_model_table, local_ai_model_table::dsl},
ExpressionMethods, Identifiable, Insertable, Queryable,
};
#[derive(Clone, Default, Queryable, Insertable, Identifiable)]
#[diesel(table_name = local_ai_model_table)]
#[diesel(primary_key(name))]
pub struct LocalAIModelTable {
pub name: String,
pub model_type: i16,
}
#[derive(Clone, Debug, Copy)]
pub enum ModelType {
Embedding = 0,
Chat = 1,
}
impl From<i16> for ModelType {
fn from(value: i16) -> Self {
match value {
0 => ModelType::Embedding,
1 => ModelType::Chat,
_ => ModelType::Embedding,
}
}
}
pub fn select_local_ai_model(conn: &mut SqliteConnection, name: &str) -> Option<LocalAIModelTable> {
local_ai_model_table::table
.filter(dsl::name.eq(name))
.first::<LocalAIModelTable>(conn)
.ok()
}
pub fn upsert_local_ai_model(
conn: &mut SqliteConnection,
row: &LocalAIModelTable,
) -> FlowyResult<()> {
diesel::insert_into(local_ai_model_table::table)
.values(row)
.on_conflict(local_ai_model_table::name)
.do_update()
.set((local_ai_model_table::model_type.eq(excluded(local_ai_model_table::model_type)),))
.execute(conn)?;
Ok(())
}
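For context, a minimal usage sketch of these helpers (illustrative only: cache_model_type is a hypothetical caller, and conn is assumed to be a &mut SqliteConnection obtained from the existing flowy_sqlite pool; everything else comes from this module):
// Hypothetical caller that reuses a cached classification when present,
// and otherwise persists one via the ON CONFLICT(name) upsert above.
fn cache_model_type(conn: &mut SqliteConnection, name: &str) -> FlowyResult<()> {
  if let Some(row) = select_local_ai_model(conn, name) {
    // Row already cached; convert the stored discriminant back into ModelType.
    let _cached = ModelType::from(row.model_type);
    return Ok(());
  }
  upsert_local_ai_model(
    conn,
    &LocalAIModelTable {
      name: name.to_string(),
      // Assumed default for this sketch; the real classification is decided elsewhere.
      model_type: ModelType::Chat as i16,
    },
  )
}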

View File

@ -1,5 +1,7 @@
mod chat_message_sql;
mod chat_sql;
mod local_model_sql;
pub use chat_message_sql::*;
pub use chat_sql::*;
pub use local_model_sql::*;

View File

@ -48,16 +48,16 @@ collab-integrate.workspace = true
[target.'cfg(any(target_os = "macos", target_os = "linux", target_os = "windows"))'.dependencies]
notify = "6.1.1"
ollama-rs = "0.3.0"
#faiss = { version = "0.12.1" }
af-mcp = { version = "0.1.0" }
[dev-dependencies]
dotenv = "0.15.0"
uuid.workspace = true
tracing-subscriber = { version = "0.3.17", features = ["registry", "env-filter", "ansi", "json"] }
simsimd = "4.4.0"
[build-dependencies]
flowy-codegen.workspace = true
[features]
dart = ["flowy-codegen/dart", "flowy-notification/dart"]
local_ai = []

View File

@ -330,14 +330,10 @@ impl AIManager {
.get_question_id_from_answer_id(chat_id, answer_message_id)
.await?;
let model = model.map_or_else(
|| {
self
.store_preferences
.get_object::<AIModel>(&ai_available_models_key(&chat_id.to_string()))
},
|model| Some(model.into()),
);
let model = match model {
None => self.get_active_model(&chat_id.to_string()).await,
Some(model) => Some(model.into()),
};
chat
.stream_regenerate_response(question_message_id, answer_stream_port, format, model)
.await?;
@ -354,9 +350,10 @@ impl AIManager {
"[AI Plugin] update global active model, previous: {}, current: {}",
previous_model, current_model
);
let source_key = ai_available_models_key(GLOBAL_ACTIVE_MODEL_KEY);
let model = AIModel::local(current_model, "".to_string());
self.update_selected_model(source_key, model).await?;
self
.update_selected_model(GLOBAL_ACTIVE_MODEL_KEY.to_string(), model)
.await?;
}
Ok(())
@ -440,11 +437,11 @@ impl AIManager {
}
pub async fn update_selected_model(&self, source: String, model: AIModel) -> FlowyResult<()> {
info!(
"[Model Selection] update {} selected model: {:?}",
source, model
);
let source_key = ai_available_models_key(&source);
info!(
"[Model Selection] update {} selected model: {:?} for key:{}",
source, model, source_key
);
self
.store_preferences
.set_object::<AIModel>(&source_key, &model)?;
@ -458,12 +455,13 @@ impl AIManager {
#[instrument(skip_all, level = "debug")]
pub async fn toggle_local_ai(&self) -> FlowyResult<()> {
let enabled = self.local_ai.toggle_local_ai().await?;
let source_key = ai_available_models_key(GLOBAL_ACTIVE_MODEL_KEY);
if enabled {
if let Some(name) = self.local_ai.get_plugin_chat_model() {
info!("Set global active model to local ai: {}", name);
let model = AIModel::local(name, "".to_string());
self.update_selected_model(source_key, model).await?;
self
.update_selected_model(GLOBAL_ACTIVE_MODEL_KEY.to_string(), model)
.await?;
}
} else {
info!("Set global active model to default");
@ -471,7 +469,7 @@ impl AIManager {
let models = self.get_server_available_models().await?;
if let Some(model) = models.into_iter().find(|m| m.name == global_active_model) {
self
.update_selected_model(source_key, AIModel::from(model))
.update_selected_model(GLOBAL_ACTIVE_MODEL_KEY.to_string(), AIModel::from(model))
.await?;
}
}
@ -484,21 +482,31 @@ impl AIManager {
.store_preferences
.get_object::<AIModel>(&ai_available_models_key(source));
if model.is_none() {
if let Some(local_model) = self.local_ai.get_plugin_chat_model() {
model = Some(AIModel::local(local_model, "".to_string()));
}
match model {
None => {
if let Some(local_model) = self.local_ai.get_plugin_chat_model() {
model = Some(AIModel::local(local_model, "".to_string()));
}
model
},
Some(mut model) => {
let models = self.local_ai.get_all_chat_local_models().await;
if !models.contains(&model) {
if let Some(local_model) = self.local_ai.get_plugin_chat_model() {
model = AIModel::local(local_model, "".to_string());
}
}
Some(model)
},
}
model
}
pub async fn get_available_models(&self, source: String) -> FlowyResult<AvailableModelsPB> {
let is_local_mode = self.user_service.is_local_model().await?;
if is_local_mode {
let setting = self.local_ai.get_local_ai_setting();
let models = self.local_ai.get_all_chat_local_models().await;
let selected_model = AIModel::local(setting.chat_model_name, "".to_string());
let models = vec![selected_model.clone()];
Ok(AvailableModelsPB {
models: models.into_iter().map(|m| m.into()).collect(),
@ -506,27 +514,24 @@ impl AIManager {
})
} else {
// Build the models list from server models and mark them as non-local.
let mut models: Vec<AIModel> = self
let mut all_models: Vec<AIModel> = self
.get_server_available_models()
.await?
.into_iter()
.map(AIModel::from)
.collect();
trace!("[Model Selection]: Available models: {:?}", models);
let mut current_active_local_ai_model = None;
trace!("[Model Selection]: Available models: {:?}", all_models);
// If local AI is enabled, add the local chat models to the list.
if let Some(local_model) = self.local_ai.get_plugin_chat_model() {
let model = AIModel::local(local_model, "".to_string());
current_active_local_ai_model = Some(model.clone());
trace!("[Model Selection] current local ai model: {}", model.name);
models.push(model);
if self.local_ai.is_enabled() {
let local_models = self.local_ai.get_all_chat_local_models().await;
all_models.extend(local_models.into_iter().map(|m| m));
}
if models.is_empty() {
if all_models.is_empty() {
return Ok(AvailableModelsPB {
models: models.into_iter().map(|m| m.into()).collect(),
models: all_models.into_iter().map(|m| m.into()).collect(),
selected_model: AIModelPB::default(),
});
}
@ -545,37 +550,29 @@ impl AIManager {
let mut user_selected_model = server_active_model.clone();
// when the currently selected model is deprecated, reset it to the default
if !models.iter().any(|m| m.name == server_active_model.name) {
if !all_models
.iter()
.any(|m| m.name == server_active_model.name)
{
server_active_model = AIModel::default();
}
let source_key = ai_available_models_key(&source);
// We use source to identify the user-selected model; it can be a document id or a chat id.
match self.store_preferences.get_object::<AIModel>(&source_key) {
match self.get_active_model(&source).await {
None => {
// when there is no previously selected model and local AI is active, use the local AI model
if let Some(local_ai_model) = models.iter().find(|m| m.is_local) {
if let Some(local_ai_model) = all_models.iter().find(|m| m.is_local) {
user_selected_model = local_ai_model.clone();
}
},
Some(mut model) => {
Some(model) => {
trace!("[Model Selection] user previous select model: {:?}", model);
// If a source is provided, try to get the user-selected model from the store. The user-selected
// model is used as the active model if it exists.
if model.is_local {
if let Some(local_ai_model) = &current_active_local_ai_model {
if local_ai_model.name != model.name {
model = local_ai_model.clone();
}
}
}
user_selected_model = model;
},
}
// If the user-selected model is not available in the list, fall back to the global active model.
let active_model = models
let active_model = all_models
.iter()
.find(|m| m.name == user_selected_model.name)
.cloned()
@ -585,15 +582,15 @@ impl AIManager {
if let Some(ref active_model) = active_model {
if active_model.name != user_selected_model.name {
self
.store_preferences
.set_object::<AIModel>(&source_key, &active_model.clone())?;
.update_selected_model(source, active_model.clone())
.await?;
}
}
trace!("[Model Selection] final active model: {:?}", active_model);
let selected_model = AIModelPB::from(active_model.unwrap_or_default());
Ok(AvailableModelsPB {
models: models.into_iter().map(|m| m.into()).collect(),
models: all_models.into_iter().map(|m| m.into()).collect(),
selected_model,
})
}
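Because the hunk above interleaves old and new lines, here is a simplified standalone sketch of the fallback order the patched get_active_model implements (illustrative only: resolve_active_model and its parameters are stand-ins for the stored preference lookup, get_all_chat_local_models, and get_plugin_chat_model):
// Sketch of the selection fallback, not the literal method body.
fn resolve_active_model(
  stored: Option<AIModel>,            // value previously saved for this source, if any
  installed_local_models: &[AIModel], // result of get_all_chat_local_models()
  plugin_chat_model: Option<String>,  // result of get_plugin_chat_model()
) -> Option<AIModel> {
  match stored {
    // Nothing stored yet: fall back to the plugin's chat model, if any.
    None => plugin_chat_model.map(|name| AIModel::local(name, String::new())),
    // Something stored: keep it only while it is still installed locally,
    // otherwise fall back to the plugin's chat model.
    Some(model) => {
      if installed_local_models.contains(&model) {
        Some(model)
      } else if let Some(name) = plugin_chat_model {
        Some(AIModel::local(name, String::new()))
      } else {
        Some(model)
      }
    },
  }
}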

View File

@ -686,7 +686,7 @@ pub struct LocalAISettingPB {
#[pb(index = 2)]
#[validate(custom(function = "required_not_empty_str"))]
pub chat_model_name: String,
pub default_model: String,
#[pb(index = 3)]
#[validate(custom(function = "required_not_empty_str"))]
@ -697,7 +697,7 @@ impl From<LocalAISetting> for LocalAISettingPB {
fn from(value: LocalAISetting) -> Self {
LocalAISettingPB {
server_url: value.ollama_server_url,
chat_model_name: value.chat_model_name,
default_model: value.chat_model_name,
embedding_model_name: value.embedding_model_name,
}
}
@ -707,7 +707,7 @@ impl From<LocalAISettingPB> for LocalAISetting {
fn from(value: LocalAISettingPB) -> Self {
LocalAISetting {
ollama_server_url: value.server_url,
chat_model_name: value.chat_model_name,
chat_model_name: value.default_model,
embedding_model_name: value.embedding_model_name,
}
}

View File

@ -1,7 +1,6 @@
use crate::ai_manager::{AIManager, GLOBAL_ACTIVE_MODEL_KEY};
use crate::completion::AICompletion;
use crate::entities::*;
use crate::util::ai_available_models_key;
use flowy_ai_pub::cloud::{AIModel, ChatMessageType};
use flowy_error::{ErrorCode, FlowyError, FlowyResult};
use lib_dispatch::prelude::{data_result_ok, AFPluginData, AFPluginState, DataResult};
@ -82,8 +81,9 @@ pub(crate) async fn get_server_model_list_handler(
ai_manager: AFPluginState<Weak<AIManager>>,
) -> DataResult<AvailableModelsPB, FlowyError> {
let ai_manager = upgrade_ai_manager(ai_manager)?;
let source_key = ai_available_models_key(GLOBAL_ACTIVE_MODEL_KEY);
let models = ai_manager.get_available_models(source_key).await?;
let models = ai_manager
.get_available_models(GLOBAL_ACTIVE_MODEL_KEY.to_string())
.await?;
data_result_ok(models)
}

View File

@ -16,9 +16,15 @@ use af_local_ai::ollama_plugin::OllamaAIPlugin;
use af_plugin::core::path::is_plugin_ready;
use af_plugin::core::plugin::RunningState;
use arc_swap::ArcSwapOption;
use flowy_ai_pub::cloud::AIModel;
use flowy_ai_pub::persistence::{
select_local_ai_model, upsert_local_ai_model, LocalAIModelTable, ModelType,
};
use flowy_ai_pub::user_service::AIUserService;
use futures_util::SinkExt;
use lib_infra::util::get_operating_system;
use ollama_rs::generation::embeddings::request::{EmbeddingsInput, GenerateEmbeddingsRequest};
use ollama_rs::Ollama;
use serde::{Deserialize, Serialize};
use std::ops::Deref;
use std::path::PathBuf;
@ -39,8 +45,8 @@ impl Default for LocalAISetting {
fn default() -> Self {
Self {
ollama_server_url: "http://localhost:11434".to_string(),
chat_model_name: "llama3.1".to_string(),
embedding_model_name: "nomic-embed-text".to_string(),
chat_model_name: "llama3.1:latest".to_string(),
embedding_model_name: "nomic-embed-text:latest".to_string(),
}
}
}
@ -53,6 +59,7 @@ pub struct LocalAIController {
current_chat_id: ArcSwapOption<Uuid>,
store_preferences: Weak<KVStorePreferences>,
user_service: Arc<dyn AIUserService>,
ollama: ArcSwapOption<Ollama>,
}
impl Deref for LocalAIController {
@ -83,69 +90,80 @@ impl LocalAIController {
user_service.clone(),
res_impl,
));
// Subscribe to state changes
let mut running_state_rx = local_ai.subscribe_running_state();
let cloned_llm_res = Arc::clone(&local_ai_resource);
let cloned_store_preferences = store_preferences.clone();
let cloned_local_ai = Arc::clone(&local_ai);
let cloned_user_service = Arc::clone(&user_service);
let ollama = ArcSwapOption::default();
let sys = get_operating_system();
if sys.is_desktop() {
let setting = local_ai_resource.get_llm_setting();
ollama.store(
Ollama::try_new(&setting.ollama_server_url)
.map(Arc::new)
.ok(),
);
// Spawn a background task to listen for plugin state changes
tokio::spawn(async move {
while let Some(state) = running_state_rx.next().await {
// Skip if we can't get workspace_id
let Ok(workspace_id) = cloned_user_service.workspace_id() else {
continue;
};
// Subscribe to state changes
let mut running_state_rx = local_ai.subscribe_running_state();
let cloned_llm_res = Arc::clone(&local_ai_resource);
let cloned_store_preferences = store_preferences.clone();
let cloned_local_ai = Arc::clone(&local_ai);
let cloned_user_service = Arc::clone(&user_service);
let key = local_ai_enabled_key(&workspace_id.to_string());
info!("[AI Plugin] state: {:?}", state);
// Read whether plugin is enabled from store; default to true
if let Some(store_preferences) = cloned_store_preferences.upgrade() {
let enabled = store_preferences.get_bool(&key).unwrap_or(true);
// Only check resource status if the plugin isn't in "UnexpectedStop" and is enabled
let (plugin_downloaded, lack_of_resource) =
if !matches!(state, RunningState::UnexpectedStop { .. }) && enabled {
// Possibly check plugin readiness and resource concurrency in parallel,
// but here we do it sequentially for clarity.
let downloaded = is_plugin_ready();
let resource_lack = cloned_llm_res.get_lack_of_resource().await;
(downloaded, resource_lack)
} else {
(false, None)
};
// If plugin is running, retrieve version
let plugin_version = if matches!(state, RunningState::Running { .. }) {
match cloned_local_ai.plugin_info().await {
Ok(info) => Some(info.version),
Err(_) => None,
}
} else {
None
// Spawn a background task to listen for plugin state changes
tokio::spawn(async move {
while let Some(state) = running_state_rx.next().await {
// Skip if we can't get workspace_id
let Ok(workspace_id) = cloned_user_service.workspace_id() else {
continue;
};
// Broadcast the new local AI state
let new_state = RunningStatePB::from(state);
chat_notification_builder(
APPFLOWY_AI_NOTIFICATION_KEY,
ChatNotification::UpdateLocalAIState,
)
.payload(LocalAIPB {
enabled,
plugin_downloaded,
lack_of_resource,
state: new_state,
plugin_version,
})
.send();
} else {
warn!("[AI Plugin] store preferences is dropped");
let key = crate::local_ai::controller::local_ai_enabled_key(&workspace_id.to_string());
info!("[AI Plugin] state: {:?}", state);
// Read whether plugin is enabled from store; default to true
if let Some(store_preferences) = cloned_store_preferences.upgrade() {
let enabled = store_preferences.get_bool(&key).unwrap_or(true);
// Only check resource status if the plugin isn't in "UnexpectedStop" and is enabled
let (plugin_downloaded, lack_of_resource) =
if !matches!(state, RunningState::UnexpectedStop { .. }) && enabled {
// Possibly check plugin readiness and resource concurrency in parallel,
// but here we do it sequentially for clarity.
let downloaded = is_plugin_ready();
let resource_lack = cloned_llm_res.get_lack_of_resource().await;
(downloaded, resource_lack)
} else {
(false, None)
};
// If plugin is running, retrieve version
let plugin_version = if matches!(state, RunningState::Running { .. }) {
match cloned_local_ai.plugin_info().await {
Ok(info) => Some(info.version),
Err(_) => None,
}
} else {
None
};
// Broadcast the new local AI state
let new_state = RunningStatePB::from(state);
chat_notification_builder(
APPFLOWY_AI_NOTIFICATION_KEY,
ChatNotification::UpdateLocalAIState,
)
.payload(LocalAIPB {
enabled,
plugin_downloaded,
lack_of_resource,
state: new_state,
plugin_version,
})
.send();
} else {
warn!("[AI Plugin] store preferences is dropped");
}
}
}
});
});
}
Self {
ai_plugin: local_ai,
@ -153,6 +171,7 @@ impl LocalAIController {
current_chat_id: ArcSwapOption::default(),
store_preferences,
user_service,
ollama,
}
}
#[instrument(level = "debug", skip_all)]
@ -287,6 +306,78 @@ impl LocalAIController {
self.resource.get_llm_setting()
}
pub async fn get_all_chat_local_models(&self) -> Vec<AIModel> {
self
.get_filtered_local_models(|name| !name.contains("embed"))
.await
}
pub async fn get_all_embedded_local_models(&self) -> Vec<AIModel> {
self
.get_filtered_local_models(|name| name.contains("embed"))
.await
}
// Helper function to avoid code duplication in model retrieval
async fn get_filtered_local_models<F>(&self, filter_fn: F) -> Vec<AIModel>
where
F: Fn(&str) -> bool,
{
match self.ollama.load_full() {
None => vec![],
Some(ollama) => ollama
.list_local_models()
.await
.map(|models| {
models
.into_iter()
.filter(|m| filter_fn(&m.name.to_lowercase()))
.map(|m| AIModel::local(m.name, String::new()))
.collect()
})
.unwrap_or_default(),
}
}
pub async fn check_model_type(&self, model_name: &str) -> FlowyResult<ModelType> {
let uid = self.user_service.user_id()?;
let mut conn = self.user_service.sqlite_connection(uid)?;
match select_local_ai_model(&mut conn, model_name) {
None => {
let ollama = self
.ollama
.load_full()
.ok_or_else(|| FlowyError::local_ai().with_context("ollama is not initialized"))?;
let request = GenerateEmbeddingsRequest::new(
model_name.to_string(),
EmbeddingsInput::Single("Hello".to_string()),
);
let model_type = match ollama.generate_embeddings(request).await {
Ok(value) => {
if value.embeddings.is_empty() {
ModelType::Chat
} else {
ModelType::Embedding
}
},
Err(_) => ModelType::Chat,
};
upsert_local_ai_model(
&mut conn,
&LocalAIModelTable {
name: model_name.to_string(),
model_type: model_type as i16,
},
)?;
Ok(model_type)
},
Some(r) => Ok(ModelType::from(r.model_type)),
}
}
pub async fn update_local_ai_setting(&self, setting: LocalAISetting) -> FlowyResult<()> {
info!(
"[AI Plugin] update local ai setting: {:?}, thread: {:?}",

View File

@ -161,7 +161,6 @@ impl LocalAIResourceController {
let setting = self.get_llm_setting();
let client = Client::builder().timeout(Duration::from_secs(5)).build()?;
match client.get(&setting.ollama_server_url).send().await {
Ok(resp) if resp.status().is_success() => {
info!(

View File

@ -36,6 +36,10 @@ client-api = { workspace = true, optional = true }
tantivy = { workspace = true, optional = true }
uuid.workspace = true
[target.'cfg(any(target_os = "macos", target_os = "linux", target_os = "windows"))'.dependencies]
ollama-rs = "0.3.0"
[features]
default = ["impl_from_dispatch_error", "impl_from_serde", "impl_from_reqwest", "impl_from_sqlite"]
impl_from_dispatch_error = ["lib-dispatch"]

View File

@ -264,3 +264,10 @@ impl From<uuid::Error> for FlowyError {
FlowyError::internal().with_context(value)
}
}
#[cfg(any(target_os = "windows", target_os = "macos", target_os = "linux"))]
impl From<ollama_rs::error::OllamaError> for FlowyError {
fn from(value: ollama_rs::error::OllamaError) -> Self {
FlowyError::local_ai().with_context(value)
}
}
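A minimal sketch of what this conversion enables (illustrative only: count_local_models is a hypothetical helper and ollama is assumed to be an already constructed ollama_rs::Ollama client):
#[cfg(any(target_os = "windows", target_os = "macos", target_os = "linux"))]
async fn count_local_models(ollama: &ollama_rs::Ollama) -> Result<usize, FlowyError> {
  // `?` converts ollama_rs::error::OllamaError into FlowyError via the impl above.
  let models = ollama.list_local_models().await?;
  Ok(models.len())
}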

View File

@ -0,0 +1 @@
-- This file should undo anything in `up.sql`

View File

@ -0,0 +1,6 @@
-- Your SQL goes here
CREATE TABLE local_ai_model_table
(
name TEXT PRIMARY KEY NOT NULL,
model_type SMALLINT NOT NULL
);

View File

@ -54,6 +54,13 @@ diesel::table! {
}
}
diesel::table! {
local_ai_model_table (name) {
name -> Text,
model_type -> SmallInt,
}
}
diesel::table! {
upload_file_part (upload_id, e_tag) {
upload_id -> Text,
@ -133,16 +140,17 @@ diesel::table! {
}
diesel::allow_tables_to_appear_in_same_query!(
af_collab_metadata,
chat_local_setting_table,
chat_message_table,
chat_table,
collab_snapshot,
upload_file_part,
upload_file_table,
user_data_migration_records,
user_table,
user_workspace_table,
workspace_members_table,
workspace_setting_table,
af_collab_metadata,
chat_local_setting_table,
chat_message_table,
chat_table,
collab_snapshot,
local_ai_model_table,
upload_file_part,
upload_file_table,
user_data_migration_records,
user_table,
user_workspace_table,
workspace_members_table,
workspace_setting_table,
);