mirror of
https://github.com/AppFlowy-IO/AppFlowy.git
synced 2025-07-31 21:11:35 +00:00

* refactor: model select * refactor: add test * fix: add source * fix: add source * chore: notify all unset source * chore: fix test
472 lines
13 KiB
Rust
472 lines
13 KiB
Rust
use crate::local_ai::controller::LocalAIController;
|
||
use arc_swap::ArcSwapOption;
|
||
use flowy_ai_pub::cloud::{AIModel, ChatCloudService};
|
||
use flowy_error::{FlowyError, FlowyResult};
|
||
use flowy_sqlite::kv::KVStorePreferences;
|
||
use lib_infra::async_trait::async_trait;
|
||
use lib_infra::util::timestamp;
|
||
use std::collections::HashSet;
|
||
use std::sync::Arc;
|
||
use tokio::sync::RwLock;
|
||
use tracing::{error, info, trace};
|
||
use uuid::Uuid;
|
||
|
||
type Model = AIModel;
|
||
pub const GLOBAL_ACTIVE_MODEL_KEY: &str = "global_active_model";
|
||
|
||
/// Manages multiple sources and provides operations for model selection
|
||
pub struct ModelSelectionControl {
|
||
sources: Vec<Box<dyn ModelSource>>,
|
||
default_model: Model,
|
||
local_storage: ArcSwapOption<Box<dyn UserModelStorage>>,
|
||
server_storage: ArcSwapOption<Box<dyn UserModelStorage>>,
|
||
unset_sources: RwLock<HashSet<String>>,
|
||
}
|
||
|
||
impl ModelSelectionControl {
|
||
/// Create a new manager with the given storage backends
|
||
pub fn new() -> Self {
|
||
let default_model = Model::default();
|
||
Self {
|
||
sources: Vec::new(),
|
||
default_model,
|
||
local_storage: ArcSwapOption::new(None),
|
||
server_storage: ArcSwapOption::new(None),
|
||
unset_sources: Default::default(),
|
||
}
|
||
}
|
||
|
||
/// Replace the local storage backend at runtime
|
||
pub fn set_local_storage(&self, storage: impl UserModelStorage + 'static) {
|
||
self.local_storage.store(Some(Arc::new(Box::new(storage))));
|
||
}
|
||
|
||
/// Replace the server storage backend at runtime
|
||
pub fn set_server_storage(&self, storage: impl UserModelStorage + 'static) {
|
||
self.server_storage.store(Some(Arc::new(Box::new(storage))));
|
||
}
|
||
|
||
/// Add a new model source at runtime
|
||
pub fn add_source(&mut self, source: Box<dyn ModelSource>) {
|
||
info!("[Model Selection] Adding source: {}", source.source_name());
|
||
// remove existing source with the same name
|
||
self
|
||
.sources
|
||
.retain(|s| s.source_name() != source.source_name());
|
||
|
||
self.sources.push(source);
|
||
}
|
||
|
||
/// Remove all sources matching the given name
|
||
pub fn remove_local_source(&mut self) {
|
||
info!("[Model Selection] Removing local source");
|
||
self
|
||
.sources
|
||
.retain(|source| source.source_name() != "local");
|
||
}
|
||
|
||
/// Asynchronously aggregate models from all sources, or return the default if none found
|
||
pub async fn get_models(&self, workspace_id: &Uuid) -> Vec<Model> {
|
||
let mut models = Vec::new();
|
||
for source in &self.sources {
|
||
let mut list = source.list_chat_models(workspace_id).await;
|
||
models.append(&mut list);
|
||
}
|
||
if models.is_empty() {
|
||
vec![self.default_model.clone()]
|
||
} else {
|
||
models
|
||
}
|
||
}
|
||
|
||
/// Fetches all server‐side models and, if specified, a single local model by name.
|
||
///
|
||
/// First collects models from any source named `"server"`. Then it fetches all local models
|
||
/// (from the `"local"` source) and:
|
||
/// - If `local_model_name` is `Some(name)`, it will append exactly that local model
|
||
/// if it exists.
|
||
/// - If `local_model_name` is `None`, it will append *all* local models.
|
||
///
|
||
pub async fn get_models_with_specific_local_model(
|
||
&self,
|
||
workspace_id: &Uuid,
|
||
local_model_name: Option<String>,
|
||
) -> Vec<Model> {
|
||
let mut models = Vec::new();
|
||
// add server models
|
||
for source in &self.sources {
|
||
if source.source_name() == "server" {
|
||
let mut list = source.list_chat_models(workspace_id).await;
|
||
models.append(&mut list);
|
||
}
|
||
}
|
||
|
||
// check input local model present in local models
|
||
let local_models = self.get_local_models(workspace_id).await;
|
||
match local_model_name {
|
||
Some(name) => {
|
||
local_models.into_iter().for_each(|model| {
|
||
if model.name == name {
|
||
models.push(model);
|
||
}
|
||
});
|
||
},
|
||
None => {
|
||
models.extend(local_models);
|
||
},
|
||
}
|
||
|
||
models
|
||
}
|
||
|
||
pub async fn get_local_models(&self, workspace_id: &Uuid) -> Vec<Model> {
|
||
for source in &self.sources {
|
||
if source.source_name() == "local" {
|
||
return source.list_chat_models(workspace_id).await;
|
||
}
|
||
}
|
||
vec![]
|
||
}
|
||
|
||
pub async fn get_all_unset_sources(&self) -> Vec<String> {
|
||
let unset_sources = self.unset_sources.read().await;
|
||
unset_sources.iter().cloned().collect()
|
||
}
|
||
|
||
pub async fn get_global_active_model(&self, workspace_id: &Uuid) -> Model {
|
||
self
|
||
.get_active_model(
|
||
workspace_id,
|
||
&SourceKey::new(GLOBAL_ACTIVE_MODEL_KEY.to_string()),
|
||
)
|
||
.await
|
||
}
|
||
|
||
/// Retrieves the active model: first tries local storage, then server storage. Ensures validity in the model list.
|
||
/// If neither storage yields a valid model, falls back to default.
|
||
pub async fn get_active_model(&self, workspace_id: &Uuid, source_key: &SourceKey) -> Model {
|
||
let available = self.get_models(workspace_id).await;
|
||
// Try local storage
|
||
if let Some(storage) = self.local_storage.load_full() {
|
||
trace!("[Model Selection] Checking local storage");
|
||
if let Some(local_model) = storage.get_selected_model(workspace_id, source_key).await {
|
||
trace!("[Model Selection] Found local model: {}", local_model.name);
|
||
if available.iter().any(|m| m.name == local_model.name) {
|
||
return local_model;
|
||
} else {
|
||
trace!(
|
||
"[Model Selection] Local {} not found in available list, available: {:?}",
|
||
local_model.name,
|
||
available.iter().map(|m| &m.name).collect::<Vec<_>>()
|
||
);
|
||
}
|
||
} else {
|
||
self
|
||
.unset_sources
|
||
.write()
|
||
.await
|
||
.insert(source_key.key.clone());
|
||
}
|
||
}
|
||
|
||
// use local model if user doesn't set the model for given source
|
||
if self
|
||
.sources
|
||
.iter()
|
||
.any(|source| source.source_name() == "local")
|
||
{
|
||
trace!("[Model Selection] Checking global active model");
|
||
let global_source = SourceKey::new(GLOBAL_ACTIVE_MODEL_KEY.to_string());
|
||
if let Some(storage) = self.local_storage.load_full() {
|
||
if let Some(local_model) = storage
|
||
.get_selected_model(workspace_id, &global_source)
|
||
.await
|
||
{
|
||
trace!(
|
||
"[Model Selection] Found global active model: {}",
|
||
local_model.name
|
||
);
|
||
if available.iter().any(|m| m.name == local_model.name) {
|
||
return local_model;
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// Try server storage
|
||
if let Some(storage) = self.server_storage.load_full() {
|
||
trace!("[Model Selection] Checking server storage");
|
||
if let Some(server_model) = storage.get_selected_model(workspace_id, source_key).await {
|
||
trace!(
|
||
"[Model Selection] Found server model: {}",
|
||
server_model.name
|
||
);
|
||
if available.iter().any(|m| m.name == server_model.name) {
|
||
return server_model;
|
||
} else {
|
||
trace!(
|
||
"[Model Selection] Server {} not found in available list, available: {:?}",
|
||
server_model.name,
|
||
available.iter().map(|m| &m.name).collect::<Vec<_>>()
|
||
);
|
||
}
|
||
}
|
||
}
|
||
// Fallback: default
|
||
info!(
|
||
"[Model Selection] No active model found, using default: {}",
|
||
self.default_model.name
|
||
);
|
||
self.default_model.clone()
|
||
}
|
||
|
||
/// Sets the active model in both local and server storage
|
||
pub async fn set_active_model(
|
||
&self,
|
||
workspace_id: &Uuid,
|
||
source_key: &SourceKey,
|
||
model: Model,
|
||
) -> Result<(), FlowyError> {
|
||
info!(
|
||
"[Model Selection] active model: {} for source: {}",
|
||
model.name, source_key.key
|
||
);
|
||
self.unset_sources.write().await.remove(&source_key.key);
|
||
|
||
let available = self.get_models(workspace_id).await;
|
||
if available.contains(&model) {
|
||
// Update local storage
|
||
if let Some(storage) = self.local_storage.load_full() {
|
||
storage
|
||
.set_selected_model(workspace_id, source_key, model.clone())
|
||
.await?;
|
||
}
|
||
|
||
// Update server storage
|
||
if let Some(storage) = self.server_storage.load_full() {
|
||
storage
|
||
.set_selected_model(workspace_id, source_key, model)
|
||
.await?;
|
||
}
|
||
Ok(())
|
||
} else {
|
||
Err(
|
||
FlowyError::internal()
|
||
.with_context(format!("Model '{:?}' not found in available list", model)),
|
||
)
|
||
}
|
||
}
|
||
}
|
||
|
||
/// Namespaced key for model selection storage
|
||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||
pub struct SourceKey {
|
||
key: String,
|
||
}
|
||
|
||
impl SourceKey {
|
||
/// Create a new SourceKey
|
||
pub fn new(key: String) -> Self {
|
||
Self { key }
|
||
}
|
||
|
||
/// Combine the UUID key with a model's is_local flag and name to produce a storage identifier
|
||
pub fn storage_id(&self) -> String {
|
||
format!("ai_models_{}", self.key)
|
||
}
|
||
}
|
||
|
||
/// A trait that defines an asynchronous source of AI models
|
||
#[async_trait]
|
||
pub trait ModelSource: Send + Sync {
|
||
/// Identifier for this source (e.g., "local" or "server")
|
||
fn source_name(&self) -> &'static str;
|
||
|
||
/// Asynchronously returns a list of available models from this source
|
||
async fn list_chat_models(&self, workspace_id: &Uuid) -> Vec<Model>;
|
||
}
|
||
|
||
pub struct LocalAiSource {
|
||
controller: Arc<LocalAIController>,
|
||
}
|
||
|
||
impl LocalAiSource {
|
||
pub fn new(controller: Arc<LocalAIController>) -> Self {
|
||
Self { controller }
|
||
}
|
||
}
|
||
|
||
#[async_trait]
|
||
impl ModelSource for LocalAiSource {
|
||
fn source_name(&self) -> &'static str {
|
||
"local"
|
||
}
|
||
|
||
async fn list_chat_models(&self, _workspace_id: &Uuid) -> Vec<Model> {
|
||
match self.controller.ollama.load_full() {
|
||
None => vec![],
|
||
Some(ollama) => ollama
|
||
.list_local_models()
|
||
.await
|
||
.map(|models| {
|
||
models
|
||
.into_iter()
|
||
.filter(|m| !m.name.contains("embed"))
|
||
.map(|m| AIModel::local(m.name, String::new()))
|
||
.collect()
|
||
})
|
||
.unwrap_or_default(),
|
||
}
|
||
}
|
||
}
|
||
|
||
/// A server-side AI source (e.g., cloud API)
|
||
#[derive(Debug, Default)]
|
||
struct ServerModelsCache {
|
||
models: Vec<Model>,
|
||
timestamp: Option<i64>,
|
||
}
|
||
|
||
pub struct ServerAiSource {
|
||
cached_models: Arc<RwLock<ServerModelsCache>>,
|
||
cloud_service: Arc<dyn ChatCloudService>,
|
||
}
|
||
|
||
impl ServerAiSource {
|
||
pub fn new(cloud_service: Arc<dyn ChatCloudService>) -> Self {
|
||
Self {
|
||
cached_models: Arc::new(Default::default()),
|
||
cloud_service,
|
||
}
|
||
}
|
||
|
||
async fn update_models_cache(&self, models: &[Model], timestamp: i64) -> FlowyResult<()> {
|
||
match self.cached_models.try_write() {
|
||
Ok(mut cache) => {
|
||
cache.models = models.to_vec();
|
||
cache.timestamp = Some(timestamp);
|
||
Ok(())
|
||
},
|
||
Err(_) => {
|
||
Err(FlowyError::internal().with_context("Failed to acquire write lock for models cache"))
|
||
},
|
||
}
|
||
}
|
||
}
|
||
|
||
#[async_trait]
|
||
impl ModelSource for ServerAiSource {
|
||
fn source_name(&self) -> &'static str {
|
||
"server"
|
||
}
|
||
|
||
async fn list_chat_models(&self, workspace_id: &Uuid) -> Vec<Model> {
|
||
let now = timestamp();
|
||
let should_fetch = {
|
||
let cached = self.cached_models.read().await;
|
||
cached.models.is_empty() || cached.timestamp.map_or(true, |ts| now - ts >= 300)
|
||
};
|
||
if !should_fetch {
|
||
return self.cached_models.read().await.models.clone();
|
||
}
|
||
match self.cloud_service.get_available_models(workspace_id).await {
|
||
Ok(resp) => {
|
||
let models = resp
|
||
.models
|
||
.into_iter()
|
||
.map(AIModel::from)
|
||
.collect::<Vec<_>>();
|
||
if let Err(e) = self.update_models_cache(&models, now).await {
|
||
error!("Failed to update cache: {}", e);
|
||
}
|
||
models
|
||
},
|
||
Err(err) => {
|
||
error!("Failed to fetch models: {}", err);
|
||
let cached = self.cached_models.read().await;
|
||
if !cached.models.is_empty() {
|
||
info!("Returning expired cache due to error");
|
||
return cached.models.clone();
|
||
}
|
||
Vec::new()
|
||
},
|
||
}
|
||
}
|
||
}
|
||
|
||
#[async_trait]
|
||
pub trait UserModelStorage: Send + Sync {
|
||
async fn get_selected_model(&self, workspace_id: &Uuid, source_key: &SourceKey) -> Option<Model>;
|
||
async fn set_selected_model(
|
||
&self,
|
||
workspace_id: &Uuid,
|
||
source_key: &SourceKey,
|
||
model: Model,
|
||
) -> Result<(), FlowyError>;
|
||
}
|
||
|
||
pub struct ServerModelStorageImpl(pub Arc<dyn ChatCloudService>);
|
||
|
||
#[async_trait]
|
||
impl UserModelStorage for ServerModelStorageImpl {
|
||
async fn get_selected_model(
|
||
&self,
|
||
workspace_id: &Uuid,
|
||
_source_key: &SourceKey,
|
||
) -> Option<Model> {
|
||
let name = self
|
||
.0
|
||
.get_workspace_default_model(workspace_id)
|
||
.await
|
||
.ok()?;
|
||
Some(Model::server(name, String::new()))
|
||
}
|
||
|
||
async fn set_selected_model(
|
||
&self,
|
||
workspace_id: &Uuid,
|
||
source_key: &SourceKey,
|
||
model: Model,
|
||
) -> Result<(), FlowyError> {
|
||
if model.is_local {
|
||
// local model does not need to be set
|
||
return Ok(());
|
||
}
|
||
|
||
if source_key.key != GLOBAL_ACTIVE_MODEL_KEY {
|
||
return Ok(());
|
||
}
|
||
|
||
self
|
||
.0
|
||
.set_workspace_default_model(workspace_id, &model.name)
|
||
.await?;
|
||
Ok(())
|
||
}
|
||
}
|
||
|
||
pub struct LocalModelStorageImpl(pub Arc<KVStorePreferences>);
|
||
|
||
#[async_trait]
|
||
impl UserModelStorage for LocalModelStorageImpl {
|
||
async fn get_selected_model(
|
||
&self,
|
||
_workspace_id: &Uuid,
|
||
source_key: &SourceKey,
|
||
) -> Option<Model> {
|
||
self.0.get_object::<AIModel>(&source_key.storage_id())
|
||
}
|
||
|
||
async fn set_selected_model(
|
||
&self,
|
||
_workspace_id: &Uuid,
|
||
source_key: &SourceKey,
|
||
model: Model,
|
||
) -> Result<(), FlowyError> {
|
||
self
|
||
.0
|
||
.set_object::<AIModel>(&source_key.storage_id(), &model)?;
|
||
Ok(())
|
||
}
|
||
}
|