use crate::local_ai::controller::LocalAIController;
use arc_swap::ArcSwapOption;
use flowy_ai_pub::cloud::{AIModel, ChatCloudService};
use flowy_error::{FlowyError, FlowyResult};
use flowy_sqlite::kv::KVStorePreferences;
use lib_infra::async_trait::async_trait;
use lib_infra::util::timestamp;
use std::collections::HashSet;
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::{error, info, trace};
use uuid::Uuid;

type Model = AIModel;

pub const GLOBAL_ACTIVE_MODEL_KEY: &str = "global_active_model";

/// Manages multiple sources and provides operations for model selection
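///
/// A minimal usage sketch (`server_source`, `local_source`, and `workspace_id`
/// are illustrative; any [`ModelSource`] implementations work):
///
/// ```ignore
/// let mut control = ModelSelectionControl::new();
/// control.add_source(Box::new(server_source));
/// control.add_source(Box::new(local_source));
/// let models = control.get_models(&workspace_id).await;
/// ```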
pub struct ModelSelectionControl {
sources: Vec<Box<dyn ModelSource>>,
default_model: Model,
local_storage: ArcSwapOption<Box<dyn UserModelStorage>>,
server_storage: ArcSwapOption<Box<dyn UserModelStorage>>,
unset_sources: RwLock<HashSet<String>>,
}
impl ModelSelectionControl {
/// Create a new controller with no sources or storage backends configured
pub fn new() -> Self {
let default_model = Model::default();
Self {
sources: Vec::new(),
default_model,
local_storage: ArcSwapOption::new(None),
server_storage: ArcSwapOption::new(None),
unset_sources: Default::default(),
}
}

/// Replace the local storage backend at runtime
pub fn set_local_storage(&self, storage: impl UserModelStorage + 'static) {
self.local_storage.store(Some(Arc::new(Box::new(storage))));
}

/// Replace the server storage backend at runtime
pub fn set_server_storage(&self, storage: impl UserModelStorage + 'static) {
self.server_storage.store(Some(Arc::new(Box::new(storage))));
}

/// Add a new model source at runtime
pub fn add_source(&mut self, source: Box<dyn ModelSource>) {
info!("[Model Selection] Adding source: {}", source.source_name());
// remove existing source with the same name
self
.sources
.retain(|s| s.source_name() != source.source_name());
self.sources.push(source);
}

/// Remove the `"local"` source, if one is registered
pub fn remove_local_source(&mut self) {
info!("[Model Selection] Removing local source");
self
.sources
.retain(|source| source.source_name() != "local");
}

/// Asynchronously aggregate models from all sources, or return the default if none found
pub async fn get_models(&self, workspace_id: &Uuid) -> Vec<Model> {
let mut models = Vec::new();
for source in &self.sources {
let mut list = source.list_chat_models(workspace_id).await;
models.append(&mut list);
}
if models.is_empty() {
vec![self.default_model.clone()]
} else {
models
}
}

/// Fetches all server-side models and, if specified, a single local model by name.
///
/// It first collects models from any source named `"server"`, then fetches all local
/// models (from the `"local"` source) and:
/// - If `local_model_name` is `Some(name)`, appends exactly that local model,
///   if it exists.
/// - If `local_model_name` is `None`, appends *all* local models.
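///
/// A sketch of the two cases (the `"llama3"` name is illustrative):
///
/// ```ignore
/// // server models plus only the "llama3" local model, if it exists
/// let one = control
///   .get_models_with_specific_local_model(&workspace_id, Some("llama3".to_string()))
///   .await;
/// // server models plus every local model
/// let all = control
///   .get_models_with_specific_local_model(&workspace_id, None)
///   .await;
/// ```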
pub async fn get_models_with_specific_local_model(
&self,
workspace_id: &Uuid,
local_model_name: Option<String>,
) -> Vec<Model> {
let mut models = Vec::new();
// add server models
for source in &self.sources {
if source.source_name() == "server" {
let mut list = source.list_chat_models(workspace_id).await;
models.append(&mut list);
}
}
// append the requested local model if present, or all local models when no name is given
let local_models = self.get_local_models(workspace_id).await;
match local_model_name {
Some(name) => {
local_models.into_iter().for_each(|model| {
if model.name == name {
models.push(model);
}
});
},
None => {
models.extend(local_models);
},
}
models
}

/// Lists models from the `"local"` source, or an empty list if none is registered
pub async fn get_local_models(&self, workspace_id: &Uuid) -> Vec<Model> {
for source in &self.sources {
if source.source_name() == "local" {
return source.list_chat_models(workspace_id).await;
}
}
vec![]
}

/// Returns the source keys for which no local selection has been stored
pub async fn get_all_unset_sources(&self) -> Vec<String> {
let unset_sources = self.unset_sources.read().await;
unset_sources.iter().cloned().collect()
}

/// Retrieves the active model stored under the global key
pub async fn get_global_active_model(&self, workspace_id: &Uuid) -> Model {
self
.get_active_model(
workspace_id,
&SourceKey::new(GLOBAL_ACTIVE_MODEL_KEY.to_string()),
)
.await
}

/// Retrieves the active model for the given source key. Lookup order: local storage,
/// then (when a `"local"` source is registered) the globally selected model, then
/// server storage. A candidate is only returned when it appears in the aggregated
/// model list; otherwise the default model is used. Source keys with no local
/// selection are recorded and can be read back via [`Self::get_all_unset_sources`].
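///
/// A lookup sketch (assuming storages were installed via [`Self::set_local_storage`]
/// and [`Self::set_server_storage`]; `chat_id` is illustrative):
///
/// ```ignore
/// let key = SourceKey::new(chat_id.to_string());
/// let model = control.get_active_model(&workspace_id, &key).await;
/// ```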
pub async fn get_active_model(&self, workspace_id: &Uuid, source_key: &SourceKey) -> Model {
let available = self.get_models(workspace_id).await;
// Try local storage
if let Some(storage) = self.local_storage.load_full() {
trace!("[Model Selection] Checking local storage");
if let Some(local_model) = storage.get_selected_model(workspace_id, source_key).await {
trace!("[Model Selection] Found local model: {}", local_model.name);
if available.iter().any(|m| m.name == local_model.name) {
return local_model;
} else {
trace!(
"[Model Selection] Local {} not found in available list, available: {:?}",
local_model.name,
available.iter().map(|m| &m.name).collect::<Vec<_>>()
);
}
} else {
self
.unset_sources
.write()
.await
.insert(source_key.key.clone());
}
}
// fall back to the globally selected model when a local source is registered
// and the user hasn't picked a model for this source
if self
.sources
.iter()
.any(|source| source.source_name() == "local")
{
trace!("[Model Selection] Checking global active model");
let global_source = SourceKey::new(GLOBAL_ACTIVE_MODEL_KEY.to_string());
if let Some(storage) = self.local_storage.load_full() {
if let Some(local_model) = storage
.get_selected_model(workspace_id, &global_source)
.await
{
trace!(
"[Model Selection] Found global active model: {}",
local_model.name
);
if available.iter().any(|m| m.name == local_model.name) {
return local_model;
}
}
}
}
// Try server storage
if let Some(storage) = self.server_storage.load_full() {
trace!("[Model Selection] Checking server storage");
if let Some(server_model) = storage.get_selected_model(workspace_id, source_key).await {
trace!(
"[Model Selection] Found server model: {}",
server_model.name
);
if available.iter().any(|m| m.name == server_model.name) {
return server_model;
} else {
trace!(
"[Model Selection] Server {} not found in available list, available: {:?}",
server_model.name,
available.iter().map(|m| &m.name).collect::<Vec<_>>()
);
}
}
}
// Fallback: default
info!(
"[Model Selection] No active model found, using default: {}",
self.default_model.name
);
self.default_model.clone()
}

/// Sets the active model in both local and server storage
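///
/// Fails if `model` is not present in the aggregated model list.
///
/// ```ignore
/// // a sketch: persist a model previously returned by `get_models`
/// let key = SourceKey::new(GLOBAL_ACTIVE_MODEL_KEY.to_string());
/// control.set_active_model(&workspace_id, &key, chosen_model).await?;
/// ```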
pub async fn set_active_model(
&self,
workspace_id: &Uuid,
source_key: &SourceKey,
model: Model,
) -> Result<(), FlowyError> {
info!(
"[Model Selection] active model: {} for source: {}",
model.name, source_key.key
);
self.unset_sources.write().await.remove(&source_key.key);
let available = self.get_models(workspace_id).await;
if available.contains(&model) {
// Update local storage
if let Some(storage) = self.local_storage.load_full() {
storage
.set_selected_model(workspace_id, source_key, model.clone())
.await?;
}
// Update server storage
if let Some(storage) = self.server_storage.load_full() {
storage
.set_selected_model(workspace_id, source_key, model)
.await?;
}
Ok(())
} else {
Err(
FlowyError::internal()
.with_context(format!("Model '{:?}' not found in available list", model)),
)
}
}
}

/// Namespaced key for model selection storage
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct SourceKey {
key: String,
}

impl SourceKey {
/// Create a new SourceKey
pub fn new(key: String) -> Self {
Self { key }
}

/// Prefix the key with `ai_models_` to produce the storage identifier
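///
/// ```ignore
/// // the "chat_123" key is illustrative
/// assert_eq!(
///   SourceKey::new("chat_123".to_string()).storage_id(),
///   "ai_models_chat_123",
/// );
/// ```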
pub fn storage_id(&self) -> String {
format!("ai_models_{}", self.key)
}
}

/// A trait that defines an asynchronous source of AI models
#[async_trait]
pub trait ModelSource: Send + Sync {
/// Identifier for this source (e.g., "local" or "server")
fn source_name(&self) -> &'static str;
/// Asynchronously returns a list of available models from this source
async fn list_chat_models(&self, workspace_id: &Uuid) -> Vec<Model>;
}

pub struct LocalAiSource {
controller: Arc<LocalAIController>,
}

impl LocalAiSource {
pub fn new(controller: Arc<LocalAIController>) -> Self {
Self { controller }
}
}

#[async_trait]
impl ModelSource for LocalAiSource {
fn source_name(&self) -> &'static str {
"local"
}

async fn list_chat_models(&self, _workspace_id: &Uuid) -> Vec<Model> {
match self.controller.ollama.load_full() {
None => vec![],
Some(ollama) => ollama
.list_local_models()
.await
.map(|models| {
models
.into_iter()
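// embedding-only models cannot be used for chat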
.filter(|m| !m.name.contains("embed"))
.map(|m| AIModel::local(m.name, String::new()))
.collect()
})
.unwrap_or_default(),
}
}
}

/// Cached server-side model list together with the timestamp of the last fetch
#[derive(Debug, Default)]
struct ServerModelsCache {
models: Vec<Model>,
timestamp: Option<i64>,
}

/// A server-side AI source (e.g., cloud API)
pub struct ServerAiSource {
cached_models: Arc<RwLock<ServerModelsCache>>,
cloud_service: Arc<dyn ChatCloudService>,
}

impl ServerAiSource {
pub fn new(cloud_service: Arc<dyn ChatCloudService>) -> Self {
Self {
cached_models: Arc::new(Default::default()),
cloud_service,
}
}

async fn update_models_cache(&self, models: &[Model], timestamp: i64) -> FlowyResult<()> {
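// `try_write` never blocks: if another task holds the lock, fail fast
// instead of stalling the caller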
match self.cached_models.try_write() {
Ok(mut cache) => {
cache.models = models.to_vec();
cache.timestamp = Some(timestamp);
Ok(())
},
Err(_) => {
Err(FlowyError::internal().with_context("Failed to acquire write lock for models cache"))
},
}
}
}

#[async_trait]
impl ModelSource for ServerAiSource {
fn source_name(&self) -> &'static str {
"server"
}

async fn list_chat_models(&self, workspace_id: &Uuid) -> Vec<Model> {
let now = timestamp();
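// refresh when the cache is empty or older than 300 seconds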
let should_fetch = {
let cached = self.cached_models.read().await;
cached.models.is_empty() || cached.timestamp.map_or(true, |ts| now - ts >= 300)
};
if !should_fetch {
return self.cached_models.read().await.models.clone();
}
match self.cloud_service.get_available_models(workspace_id).await {
Ok(resp) => {
let models = resp
.models
.into_iter()
.map(AIModel::from)
.collect::<Vec<_>>();
if let Err(e) = self.update_models_cache(&models, now).await {
error!("Failed to update cache: {}", e);
}
models
},
Err(err) => {
error!("Failed to fetch models: {}", err);
let cached = self.cached_models.read().await;
if !cached.models.is_empty() {
info!("Returning expired cache due to error");
return cached.models.clone();
}
Vec::new()
},
}
}
}
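
/// Persists the user's selected model for a given [`SourceKey`]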
#[async_trait]
pub trait UserModelStorage: Send + Sync {
async fn get_selected_model(&self, workspace_id: &Uuid, source_key: &SourceKey) -> Option<Model>;

async fn set_selected_model(
&self,
workspace_id: &Uuid,
source_key: &SourceKey,
model: Model,
) -> Result<(), FlowyError>;
}
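
/// Server-backed storage: persists the global selection as the workspace default model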
pub struct ServerModelStorageImpl(pub Arc<dyn ChatCloudService>);

#[async_trait]
impl UserModelStorage for ServerModelStorageImpl {
async fn get_selected_model(
&self,
workspace_id: &Uuid,
_source_key: &SourceKey,
) -> Option<Model> {
let name = self
.0
.get_workspace_default_model(workspace_id)
.await
.ok()?;
Some(Model::server(name, String::new()))
}

async fn set_selected_model(
&self,
workspace_id: &Uuid,
source_key: &SourceKey,
model: Model,
) -> Result<(), FlowyError> {
if model.is_local {
// local model does not need to be set
return Ok(());
}
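// only the global selection is persisted on the server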
if source_key.key != GLOBAL_ACTIVE_MODEL_KEY {
return Ok(());
}
self
.0
.set_workspace_default_model(workspace_id, &model.name)
.await?;
Ok(())
}
}
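
/// Local storage backed by the key-value preference store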
pub struct LocalModelStorageImpl(pub Arc<KVStorePreferences>);

#[async_trait]
impl UserModelStorage for LocalModelStorageImpl {
async fn get_selected_model(
&self,
_workspace_id: &Uuid,
source_key: &SourceKey,
) -> Option<Model> {
self.0.get_object::<AIModel>(&source_key.storage_id())
}

async fn set_selected_model(
&self,
_workspace_id: &Uuid,
source_key: &SourceKey,
model: Model,
) -> Result<(), FlowyError> {
self
.0
.set_object::<AIModel>(&source_key.storage_id(), &model)?;
Ok(())
}
}