use crate::local_ai::controller::LocalAIController; use arc_swap::ArcSwapOption; use flowy_ai_pub::cloud::{AIModel, ChatCloudService}; use flowy_error::{FlowyError, FlowyResult}; use flowy_sqlite::kv::KVStorePreferences; use lib_infra::async_trait::async_trait; use lib_infra::util::timestamp; use std::collections::HashSet; use std::sync::Arc; use tokio::sync::RwLock; use tracing::{error, info, trace}; use uuid::Uuid; type Model = AIModel; pub const GLOBAL_ACTIVE_MODEL_KEY: &str = "global_active_model"; /// Manages multiple sources and provides operations for model selection pub struct ModelSelectionControl { sources: Vec>, default_model: Model, local_storage: ArcSwapOption>, server_storage: ArcSwapOption>, unset_sources: RwLock>, } impl ModelSelectionControl { /// Create a new manager with the given storage backends pub fn new() -> Self { let default_model = Model::default(); Self { sources: Vec::new(), default_model, local_storage: ArcSwapOption::new(None), server_storage: ArcSwapOption::new(None), unset_sources: Default::default(), } } /// Replace the local storage backend at runtime pub fn set_local_storage(&self, storage: impl UserModelStorage + 'static) { self.local_storage.store(Some(Arc::new(Box::new(storage)))); } /// Replace the server storage backend at runtime pub fn set_server_storage(&self, storage: impl UserModelStorage + 'static) { self.server_storage.store(Some(Arc::new(Box::new(storage)))); } /// Add a new model source at runtime pub fn add_source(&mut self, source: Box) { info!("[Model Selection] Adding source: {}", source.source_name()); // remove existing source with the same name self .sources .retain(|s| s.source_name() != source.source_name()); self.sources.push(source); } /// Remove all sources matching the given name pub fn remove_local_source(&mut self) { info!("[Model Selection] Removing local source"); self .sources .retain(|source| source.source_name() != "local"); } /// Asynchronously aggregate models from all sources, or return the default if none found pub async fn get_models(&self, workspace_id: &Uuid) -> Vec { let mut models = Vec::new(); for source in &self.sources { let mut list = source.list_chat_models(workspace_id).await; models.append(&mut list); } if models.is_empty() { vec![self.default_model.clone()] } else { models } } /// Fetches all server‐side models and, if specified, a single local model by name. /// /// First collects models from any source named `"server"`. Then it fetches all local models /// (from the `"local"` source) and: /// - If `local_model_name` is `Some(name)`, it will append exactly that local model /// if it exists. /// - If `local_model_name` is `None`, it will append *all* local models. /// pub async fn get_models_with_specific_local_model( &self, workspace_id: &Uuid, local_model_name: Option, ) -> Vec { let mut models = Vec::new(); // add server models for source in &self.sources { if source.source_name() == "server" { let mut list = source.list_chat_models(workspace_id).await; models.append(&mut list); } } // check input local model present in local models let local_models = self.get_local_models(workspace_id).await; match local_model_name { Some(name) => { local_models.into_iter().for_each(|model| { if model.name == name { models.push(model); } }); }, None => { models.extend(local_models); }, } models } pub async fn get_local_models(&self, workspace_id: &Uuid) -> Vec { for source in &self.sources { if source.source_name() == "local" { return source.list_chat_models(workspace_id).await; } } vec![] } pub async fn get_all_unset_sources(&self) -> Vec { let unset_sources = self.unset_sources.read().await; unset_sources.iter().cloned().collect() } pub async fn get_global_active_model(&self, workspace_id: &Uuid) -> Model { self .get_active_model( workspace_id, &SourceKey::new(GLOBAL_ACTIVE_MODEL_KEY.to_string()), ) .await } /// Retrieves the active model: first tries local storage, then server storage. Ensures validity in the model list. /// If neither storage yields a valid model, falls back to default. pub async fn get_active_model(&self, workspace_id: &Uuid, source_key: &SourceKey) -> Model { let available = self.get_models(workspace_id).await; // Try local storage if let Some(storage) = self.local_storage.load_full() { trace!("[Model Selection] Checking local storage"); if let Some(local_model) = storage.get_selected_model(workspace_id, source_key).await { trace!("[Model Selection] Found local model: {}", local_model.name); if available.iter().any(|m| m.name == local_model.name) { return local_model; } else { trace!( "[Model Selection] Local {} not found in available list, available: {:?}", local_model.name, available.iter().map(|m| &m.name).collect::>() ); } } else { self .unset_sources .write() .await .insert(source_key.key.clone()); } } // use local model if user doesn't set the model for given source if self .sources .iter() .any(|source| source.source_name() == "local") { trace!("[Model Selection] Checking global active model"); let global_source = SourceKey::new(GLOBAL_ACTIVE_MODEL_KEY.to_string()); if let Some(storage) = self.local_storage.load_full() { if let Some(local_model) = storage .get_selected_model(workspace_id, &global_source) .await { trace!( "[Model Selection] Found global active model: {}", local_model.name ); if available.iter().any(|m| m.name == local_model.name) { return local_model; } } } } // Try server storage if let Some(storage) = self.server_storage.load_full() { trace!("[Model Selection] Checking server storage"); if let Some(server_model) = storage.get_selected_model(workspace_id, source_key).await { trace!( "[Model Selection] Found server model: {}", server_model.name ); if available.iter().any(|m| m.name == server_model.name) { return server_model; } else { trace!( "[Model Selection] Server {} not found in available list, available: {:?}", server_model.name, available.iter().map(|m| &m.name).collect::>() ); } } } // Fallback: default info!( "[Model Selection] No active model found, using default: {}", self.default_model.name ); self.default_model.clone() } /// Sets the active model in both local and server storage pub async fn set_active_model( &self, workspace_id: &Uuid, source_key: &SourceKey, model: Model, ) -> Result<(), FlowyError> { info!( "[Model Selection] active model: {} for source: {}", model.name, source_key.key ); self.unset_sources.write().await.remove(&source_key.key); let available = self.get_models(workspace_id).await; if available.contains(&model) { // Update local storage if let Some(storage) = self.local_storage.load_full() { storage .set_selected_model(workspace_id, source_key, model.clone()) .await?; } // Update server storage if let Some(storage) = self.server_storage.load_full() { storage .set_selected_model(workspace_id, source_key, model) .await?; } Ok(()) } else { Err( FlowyError::internal() .with_context(format!("Model '{:?}' not found in available list", model)), ) } } } /// Namespaced key for model selection storage #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct SourceKey { key: String, } impl SourceKey { /// Create a new SourceKey pub fn new(key: String) -> Self { Self { key } } /// Combine the UUID key with a model's is_local flag and name to produce a storage identifier pub fn storage_id(&self) -> String { format!("ai_models_{}", self.key) } } /// A trait that defines an asynchronous source of AI models #[async_trait] pub trait ModelSource: Send + Sync { /// Identifier for this source (e.g., "local" or "server") fn source_name(&self) -> &'static str; /// Asynchronously returns a list of available models from this source async fn list_chat_models(&self, workspace_id: &Uuid) -> Vec; } pub struct LocalAiSource { controller: Arc, } impl LocalAiSource { pub fn new(controller: Arc) -> Self { Self { controller } } } #[async_trait] impl ModelSource for LocalAiSource { fn source_name(&self) -> &'static str { "local" } async fn list_chat_models(&self, _workspace_id: &Uuid) -> Vec { match self.controller.ollama.load_full() { None => vec![], Some(ollama) => ollama .list_local_models() .await .map(|models| { models .into_iter() .filter(|m| !m.name.contains("embed")) .map(|m| AIModel::local(m.name, String::new())) .collect() }) .unwrap_or_default(), } } } /// A server-side AI source (e.g., cloud API) #[derive(Debug, Default)] struct ServerModelsCache { models: Vec, timestamp: Option, } pub struct ServerAiSource { cached_models: Arc>, cloud_service: Arc, } impl ServerAiSource { pub fn new(cloud_service: Arc) -> Self { Self { cached_models: Arc::new(Default::default()), cloud_service, } } async fn update_models_cache(&self, models: &[Model], timestamp: i64) -> FlowyResult<()> { match self.cached_models.try_write() { Ok(mut cache) => { cache.models = models.to_vec(); cache.timestamp = Some(timestamp); Ok(()) }, Err(_) => { Err(FlowyError::internal().with_context("Failed to acquire write lock for models cache")) }, } } } #[async_trait] impl ModelSource for ServerAiSource { fn source_name(&self) -> &'static str { "server" } async fn list_chat_models(&self, workspace_id: &Uuid) -> Vec { let now = timestamp(); let should_fetch = { let cached = self.cached_models.read().await; cached.models.is_empty() || cached.timestamp.map_or(true, |ts| now - ts >= 300) }; if !should_fetch { return self.cached_models.read().await.models.clone(); } match self.cloud_service.get_available_models(workspace_id).await { Ok(resp) => { let models = resp .models .into_iter() .map(AIModel::from) .collect::>(); if let Err(e) = self.update_models_cache(&models, now).await { error!("Failed to update cache: {}", e); } models }, Err(err) => { error!("Failed to fetch models: {}", err); let cached = self.cached_models.read().await; if !cached.models.is_empty() { info!("Returning expired cache due to error"); return cached.models.clone(); } Vec::new() }, } } } #[async_trait] pub trait UserModelStorage: Send + Sync { async fn get_selected_model(&self, workspace_id: &Uuid, source_key: &SourceKey) -> Option; async fn set_selected_model( &self, workspace_id: &Uuid, source_key: &SourceKey, model: Model, ) -> Result<(), FlowyError>; } pub struct ServerModelStorageImpl(pub Arc); #[async_trait] impl UserModelStorage for ServerModelStorageImpl { async fn get_selected_model( &self, workspace_id: &Uuid, _source_key: &SourceKey, ) -> Option { let name = self .0 .get_workspace_default_model(workspace_id) .await .ok()?; Some(Model::server(name, String::new())) } async fn set_selected_model( &self, workspace_id: &Uuid, source_key: &SourceKey, model: Model, ) -> Result<(), FlowyError> { if model.is_local { // local model does not need to be set return Ok(()); } if source_key.key != GLOBAL_ACTIVE_MODEL_KEY { return Ok(()); } self .0 .set_workspace_default_model(workspace_id, &model.name) .await?; Ok(()) } } pub struct LocalModelStorageImpl(pub Arc); #[async_trait] impl UserModelStorage for LocalModelStorageImpl { async fn get_selected_model( &self, _workspace_id: &Uuid, source_key: &SourceKey, ) -> Option { self.0.get_object::(&source_key.storage_id()) } async fn set_selected_model( &self, _workspace_id: &Uuid, source_key: &SourceKey, model: Model, ) -> Result<(), FlowyError> { self .0 .set_object::(&source_key.storage_id(), &model)?; Ok(()) } }