Merge pull request #7834 from AppFlowy-IO/support_switch_local_models

Support switch local models
This commit is contained in:
Nathan.fooo 2025-04-26 10:43:59 +08:00 committed by GitHub
commit ec5eb4e337
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
25 changed files with 921 additions and 579 deletions

View File

@ -7,7 +7,6 @@ import 'package:appflowy_backend/log.dart';
import 'package:appflowy_backend/protobuf/flowy-ai/entities.pb.dart';
import 'package:appflowy_result/appflowy_result.dart';
import 'package:easy_localization/easy_localization.dart';
import 'package:protobuf/protobuf.dart';
import 'package:universal_platform/universal_platform.dart';
typedef OnModelStateChangedCallback = void Function(AIModelState state);
@ -52,25 +51,29 @@ class AIModelStateNotifier {
final String objectId;
final LocalAIStateListener? _localAIListener;
final AIModelSwitchListener _aiModelSwitchListener;
LocalAIPB? _localAIState;
AvailableModelsPB? _availableModels;
// callbacks
LocalAIPB? _localAIState;
ModelSelectionPB? _modelSelection;
AIModelState _currentState = _defaultState();
List<AIModelPB> _availableModels = [];
AIModelPB? _selectedModel;
final List<OnModelStateChangedCallback> _stateChangedCallbacks = [];
final List<OnAvailableModelsChangedCallback>
_availableModelsChangedCallbacks = [];
/// Starts platform-specific listeners
void _startListening() {
if (UniversalPlatform.isDesktop) {
_localAIListener?.start(
stateCallback: (state) async {
_localAIState = state;
_notifyStateChanged();
_updateAll();
if (state.state == RunningStatePB.Running ||
state.state == RunningStatePB.Stopped) {
await _loadAvailableModels();
_notifyAvailableModelsChanged();
await _loadModelSelection();
_updateAll();
}
},
);
@ -78,25 +81,25 @@ class AIModelStateNotifier {
_aiModelSwitchListener.start(
onUpdateSelectedModel: (model) async {
final updatedModels = _availableModels?.deepCopy()
?..selectedModel = model;
_availableModels = updatedModels;
_notifyAvailableModelsChanged();
_selectedModel = model;
_updateAll();
if (model.isLocal && UniversalPlatform.isDesktop) {
await _loadLocalAiState();
await _loadLocalState();
_updateAll();
}
_notifyStateChanged();
},
);
}
void _init() async {
await Future.wait([_loadLocalAiState(), _loadAvailableModels()]);
_notifyStateChanged();
_notifyAvailableModelsChanged();
Future<void> _init() async {
await Future.wait([
if (UniversalPlatform.isDesktop) _loadLocalState(),
_loadModelSelection(),
]);
_updateAll();
}
/// Register callbacks for state or available-models changes
void addListener({
OnModelStateChangedCallback? onStateChanged,
OnAvailableModelsChangedCallback? onAvailableModelsChanged,
@ -109,6 +112,7 @@ class AIModelStateNotifier {
}
}
/// Remove previously registered callbacks
void removeListener({
OnModelStateChangedCallback? onStateChanged,
OnAvailableModelsChangedCallback? onAvailableModelsChanged,
@ -128,116 +132,88 @@ class AIModelStateNotifier {
await _aiModelSwitchListener.stop();
}
AIModelState getState() {
if (UniversalPlatform.isMobile) {
return AIModelState(
type: AiType.cloud,
hintText: LocaleKeys.chat_inputMessageHint.tr(),
tooltip: null,
isEditable: true,
localAIEnabled: false,
);
/// Returns current AIModelState
AIModelState getState() => _currentState;
/// Returns available models and the selected model
(List<AIModelPB>, AIModelPB?) getModelSelection() =>
(_availableModels, _selectedModel);
void _updateAll() {
_currentState = _computeState();
for (final cb in _stateChangedCallbacks) {
cb(_currentState);
}
final availableModels = _availableModels;
final localAiState = _localAIState;
if (availableModels == null) {
return AIModelState(
type: AiType.cloud,
hintText: LocaleKeys.chat_inputMessageHint.tr(),
isEditable: true,
tooltip: null,
localAIEnabled: false,
);
}
if (localAiState == null) {
return AIModelState(
type: AiType.cloud,
hintText: LocaleKeys.chat_inputMessageHint.tr(),
tooltip: null,
isEditable: true,
localAIEnabled: false,
);
for (final cb in _availableModelsChangedCallbacks) {
cb(_availableModels, _selectedModel);
}
}
if (!availableModels.selectedModel.isLocal) {
return AIModelState(
type: AiType.cloud,
hintText: LocaleKeys.chat_inputMessageHint.tr(),
tooltip: null,
isEditable: true,
localAIEnabled: false,
);
}
final editable = localAiState.state == RunningStatePB.Running;
final tooltip = localAiState.enabled
? (editable
? null
: LocaleKeys.settings_aiPage_keys_localAINotReadyTextFieldPrompt
.tr())
: LocaleKeys.settings_aiPage_keys_localAIDisabledTextFieldPrompt.tr();
final hintText = localAiState.enabled
? (editable
? LocaleKeys.chat_inputLocalAIMessageHint.tr()
: LocaleKeys.settings_aiPage_keys_localAIInitializing.tr())
: LocaleKeys.settings_aiPage_keys_localAIDisabled.tr();
return AIModelState(
type: AiType.local,
hintText: hintText,
tooltip: tooltip,
isEditable: editable,
localAIEnabled: localAiState.enabled,
Future<void> _loadModelSelection() async {
await AIEventGetSourceModelSelection(
ModelSourcePB(source: objectId),
).send().fold(
(ms) {
_modelSelection = ms;
_availableModels = ms.models;
_selectedModel = ms.selectedModel;
},
(e) => Log.error("Failed to fetch models: \$e"),
);
}
(List<AIModelPB>, AIModelPB?) getAvailableModels() {
final availableModels = _availableModels;
if (availableModels == null) {
return ([], null);
}
return (availableModels.models, availableModels.selectedModel);
}
void _notifyAvailableModelsChanged() {
final (models, selectedModel) = getAvailableModels();
for (final callback in _availableModelsChangedCallbacks) {
callback(models, selectedModel);
}
}
void _notifyStateChanged() {
final state = getState();
for (final callback in _stateChangedCallbacks) {
callback(state);
}
}
Future<void> _loadAvailableModels() {
final payload = AvailableModelsQueryPB(source: objectId);
return AIEventGetAvailableModels(payload).send().fold(
(models) => _availableModels = models,
(err) => Log.error("Failed to get available models: $err"),
Future<void> _loadLocalState() async {
await AIEventGetLocalAIState().send().fold(
(s) => _localAIState = s,
(e) => Log.error("Failed to fetch local AI state: \$e"),
);
}
Future<void> _loadLocalAiState() {
return AIEventGetLocalAIState().send().fold(
(localAIState) => _localAIState = localAIState,
(error) => Log.error("Failed to get local AI state: $error"),
);
static AIModelState _defaultState() => AIModelState(
type: AiType.cloud,
hintText: LocaleKeys.chat_inputMessageHint.tr(),
tooltip: null,
isEditable: true,
localAIEnabled: false,
);
/// Core logic computing the state from local and selection data
AIModelState _computeState() {
if (UniversalPlatform.isMobile) return _defaultState();
if (_modelSelection == null || _localAIState == null) {
return _defaultState();
}
if (!_selectedModel!.isLocal) {
return _defaultState();
}
final enabled = _localAIState!.enabled;
final running = _localAIState!.state == RunningStatePB.Running;
final hintKey = enabled
? (running
? LocaleKeys.chat_inputLocalAIMessageHint
: LocaleKeys.settings_aiPage_keys_localAIInitializing)
: LocaleKeys.settings_aiPage_keys_localAIDisabled;
final tooltipKey = enabled
? (running
? null
: LocaleKeys.settings_aiPage_keys_localAINotReadyTextFieldPrompt)
: LocaleKeys.settings_aiPage_keys_localAIDisabledTextFieldPrompt;
return AIModelState(
type: AiType.local,
hintText: hintKey.tr(),
tooltip: tooltipKey?.tr(),
isEditable: running,
localAIEnabled: enabled,
);
}
}
extension AiModelExtension on AIModelPB {
bool get isDefault {
return name == "Auto";
}
String get i18n {
return isDefault ? LocaleKeys.chat_switchModel_autoModel.tr() : name;
}
extension AIModelPBExtension on AIModelPB {
bool get isDefault => name == 'Auto';
String get i18n =>
isDefault ? LocaleKeys.chat_switchModel_autoModel.tr() : name;
}

View File

@ -83,7 +83,7 @@ class SelectModelState with _$SelectModelState {
}) = _SelectModelState;
factory SelectModelState.initial(AIModelStateNotifier notifier) {
final (models, selectedModel) = notifier.getAvailableModels();
final (models, selectedModel) = notifier.getModelSelection();
return SelectModelState(
models: models,
selectedModel: selectedModel,

View File

@ -90,38 +90,40 @@ class SelectModelPopoverContent extends StatelessWidget {
return Padding(
padding: const EdgeInsets.all(8.0),
child: Column(
mainAxisSize: MainAxisSize.min,
crossAxisAlignment: CrossAxisAlignment.start,
children: [
if (localModels.isNotEmpty) ...[
_ModelSectionHeader(
title: LocaleKeys.chat_switchModel_localModel.tr(),
child: SingleChildScrollView(
child: Column(
mainAxisSize: MainAxisSize.min,
crossAxisAlignment: CrossAxisAlignment.start,
children: [
if (localModels.isNotEmpty) ...[
_ModelSectionHeader(
title: LocaleKeys.chat_switchModel_localModel.tr(),
),
const VSpace(4.0),
],
...localModels.map(
(model) => _ModelItem(
model: model,
isSelected: model == selectedModel,
onTap: () => onSelectModel?.call(model),
),
),
if (cloudModels.isNotEmpty && localModels.isNotEmpty) ...[
const VSpace(8.0),
_ModelSectionHeader(
title: LocaleKeys.chat_switchModel_cloudModel.tr(),
),
const VSpace(4.0),
],
...cloudModels.map(
(model) => _ModelItem(
model: model,
isSelected: model == selectedModel,
onTap: () => onSelectModel?.call(model),
),
),
const VSpace(4.0),
],
...localModels.map(
(model) => _ModelItem(
model: model,
isSelected: model == selectedModel,
onTap: () => onSelectModel?.call(model),
),
),
if (cloudModels.isNotEmpty && localModels.isNotEmpty) ...[
const VSpace(8.0),
_ModelSectionHeader(
title: LocaleKeys.chat_switchModel_cloudModel.tr(),
),
const VSpace(4.0),
],
...cloudModels.map(
(model) => _ModelItem(
model: model,
isSelected: model == selectedModel,
onTap: () => onSelectModel?.call(model),
),
),
],
),
),
);
}
@ -215,45 +217,41 @@ class _CurrentModelButton extends StatelessWidget {
behavior: HitTestBehavior.opaque,
child: SizedBox(
height: DesktopAIPromptSizes.actionBarButtonSize,
child: AnimatedSize(
duration: const Duration(milliseconds: 50),
curve: Curves.easeInOut,
alignment: AlignmentDirectional.centerStart,
child: FlowyHover(
style: const HoverStyle(
borderRadius: BorderRadius.all(Radius.circular(8)),
),
child: Padding(
padding: const EdgeInsetsDirectional.all(4.0),
child: Row(
children: [
Padding(
// TODO: remove this after change icon to 20px
padding: EdgeInsets.all(2),
child: FlowySvg(
FlowySvgs.ai_sparks_s,
color: Theme.of(context).hintColor,
size: Size.square(16),
),
),
if (model != null && !model!.isDefault)
Padding(
padding: EdgeInsetsDirectional.only(end: 2.0),
child: FlowyText(
model!.i18n,
fontSize: 12,
figmaLineHeight: 16,
color: Theme.of(context).hintColor,
overflow: TextOverflow.ellipsis,
),
),
FlowySvg(
FlowySvgs.ai_source_drop_down_s,
child: FlowyHover(
style: const HoverStyle(
borderRadius: BorderRadius.all(Radius.circular(8)),
),
child: Padding(
padding: const EdgeInsetsDirectional.all(4.0),
child: Row(
mainAxisSize: MainAxisSize.min,
children: [
Padding(
// TODO: remove this after change icon to 20px
padding: EdgeInsets.all(2),
child: FlowySvg(
FlowySvgs.ai_sparks_s,
color: Theme.of(context).hintColor,
size: const Size.square(8),
size: Size.square(16),
),
],
),
),
if (model != null && !model!.isDefault)
Padding(
padding: EdgeInsetsDirectional.only(end: 2.0),
child: FlowyText(
model!.i18n,
fontSize: 12,
figmaLineHeight: 16,
color: Theme.of(context).hintColor,
overflow: TextOverflow.ellipsis,
),
),
FlowySvg(
FlowySvgs.ai_source_drop_down_s,
color: Theme.of(context).hintColor,
size: const Size.square(8),
),
],
),
),
),

View File

@ -448,7 +448,7 @@ class _ChangeModelButtonState extends State<ChangeModelButton> {
child: buildButton(context),
popupBuilder: (_) {
final bloc = context.read<AIPromptInputBloc>();
final (models, _) = bloc.aiModelStateNotifier.getAvailableModels();
final (models, _) = bloc.aiModelStateNotifier.getModelSelection();
return SelectModelPopoverContent(
models: models,
selectedModel: null,

View File

@ -407,7 +407,7 @@ class ChatAIMessagePopup extends StatelessWidget {
return MobileQuickActionButton(
onTap: () async {
final bloc = context.read<AIPromptInputBloc>();
final (models, _) = bloc.aiModelStateNotifier.getAvailableModels();
final (models, _) = bloc.aiModelStateNotifier.getModelSelection();
final result = await showChangeModelBottomSheet(context, models);
if (result != null) {
onChangeModel?.call(result);

View File

@ -6,174 +6,181 @@ import 'package:appflowy_backend/protobuf/flowy-ai/entities.pb.dart';
import 'package:appflowy_result/appflowy_result.dart';
import 'package:bloc/bloc.dart';
import 'package:collection/collection.dart';
import 'package:freezed_annotation/freezed_annotation.dart';
import 'package:equatable/equatable.dart';
import 'package:freezed_annotation/freezed_annotation.dart';
part 'ollama_setting_bloc.freezed.dart';
const kDefaultChatModel = 'llama3.1:latest';
const kDefaultEmbeddingModel = 'nomic-embed-text:latest';
/// Extension methods to map between PB and UI models
class OllamaSettingBloc extends Bloc<OllamaSettingEvent, OllamaSettingState> {
OllamaSettingBloc() : super(const OllamaSettingState()) {
on<OllamaSettingEvent>(_handleEvent);
on<_Started>(_handleStarted);
on<_DidLoadLocalModels>(_onLoadLocalModels);
on<_DidLoadSetting>(_onLoadSetting);
on<_UpdateSetting>(_onLoadSetting);
on<_OnEdit>(_onEdit);
on<_OnSubmit>(_onSubmit);
on<_SetDefaultModel>(_onSetDefaultModel);
}
Future<void> _handleEvent(
OllamaSettingEvent event,
Future<void> _handleStarted(
_Started event,
Emitter<OllamaSettingState> emit,
) async {
event.when(
started: () {
AIEventGetLocalAISetting().send().fold(
(setting) {
if (!isClosed) {
add(OllamaSettingEvent.didLoadSetting(setting));
}
},
Log.error,
);
},
didLoadSetting: (setting) => _updateSetting(setting, emit),
updateSetting: (setting) => _updateSetting(setting, emit),
onEdit: (content, settingType) {
final updatedSubmittedItems = state.submittedItems
.map(
(item) => item.settingType == settingType
? SubmittedItem(
content: content,
settingType: item.settingType,
)
: item,
)
.toList();
try {
final results = await Future.wait([
AIEventGetLocalModelSelection().send().then((r) => r.getOrThrow()),
AIEventGetLocalAISetting().send().then((r) => r.getOrThrow()),
]);
// Convert both lists to maps: {settingType: content}
final updatedMap = {
for (final item in updatedSubmittedItems)
item.settingType: item.content,
};
final models = results[0] as ModelSelectionPB;
final setting = results[1] as LocalAISettingPB;
final inputMap = {
for (final item in state.inputItems) item.settingType: item.content,
};
// Compare maps instead of lists
final isEdited = !const MapEquality<SettingType, String>()
.equals(updatedMap, inputMap);
emit(
state.copyWith(
submittedItems: updatedSubmittedItems,
isEdited: isEdited,
),
);
},
submit: () {
final setting = LocalAISettingPB();
final settingUpdaters = <SettingType, void Function(String)>{
SettingType.serverUrl: (value) => setting.serverUrl = value,
SettingType.chatModel: (value) => setting.chatModelName = value,
SettingType.embeddingModel: (value) =>
setting.embeddingModelName = value,
};
for (final item in state.submittedItems) {
settingUpdaters[item.settingType]?.call(item.content);
}
add(OllamaSettingEvent.updateSetting(setting));
AIEventUpdateLocalAISetting(setting).send().fold(
(_) => Log.info('AI setting updated successfully'),
(err) => Log.error("update ai setting failed: $err"),
);
},
);
if (!isClosed) {
add(OllamaSettingEvent.didLoadLocalModels(models));
add(OllamaSettingEvent.didLoadSetting(setting));
}
} catch (e, st) {
Log.error('Failed to load initial AI data: $e\n$st');
}
}
void _updateSetting(
LocalAISettingPB setting,
void _onLoadLocalModels(
_DidLoadLocalModels event,
Emitter<OllamaSettingState> emit,
) {
emit(state.copyWith(localModels: event.models));
}
void _onLoadSetting(
dynamic event,
Emitter<OllamaSettingState> emit,
) {
final setting = (event as dynamic).setting as LocalAISettingPB;
final submitted = setting.toSubmittedItems();
emit(
state.copyWith(
setting: setting,
inputItems: _createInputItems(setting),
submittedItems: _createSubmittedItems(setting),
isEdited: false, // Reset to false when the setting is loaded/updated.
inputItems: setting.toInputItems(),
submittedItems: submitted,
originalMap: {
for (final item in submitted) item.settingType: item.content,
},
isEdited: false,
),
);
}
List<SettingItem> _createInputItems(LocalAISettingPB setting) => [
SettingItem(
content: setting.serverUrl,
hintText: 'http://localhost:11434',
settingType: SettingType.serverUrl,
),
SettingItem(
content: setting.chatModelName,
hintText: 'llama3.1',
settingType: SettingType.chatModel,
),
SettingItem(
content: setting.embeddingModelName,
hintText: 'nomic-embed-text',
settingType: SettingType.embeddingModel,
),
];
void _onEdit(
_OnEdit event,
Emitter<OllamaSettingState> emit,
) {
final updated = state.submittedItems
.map(
(item) => item.settingType == event.settingType
? item.copyWith(content: event.content)
: item,
)
.toList();
List<SubmittedItem> _createSubmittedItems(LocalAISettingPB setting) => [
SubmittedItem(
content: setting.serverUrl,
settingType: SettingType.serverUrl,
),
SubmittedItem(
content: setting.chatModelName,
settingType: SettingType.chatModel,
),
SubmittedItem(
content: setting.embeddingModelName,
settingType: SettingType.embeddingModel,
),
];
final currentMap = {for (final i in updated) i.settingType: i.content};
final isEdited = !const MapEquality<SettingType, String>()
.equals(state.originalMap, currentMap);
emit(state.copyWith(submittedItems: updated, isEdited: isEdited));
}
void _onSubmit(
_OnSubmit event,
Emitter<OllamaSettingState> emit,
) {
final pb = LocalAISettingPB();
for (final item in state.submittedItems) {
switch (item.settingType) {
case SettingType.serverUrl:
pb.serverUrl = item.content;
break;
case SettingType.chatModel:
pb.globalChatModel = state.selectedModel?.name ?? item.content;
break;
case SettingType.embeddingModel:
pb.embeddingModelName = item.content;
break;
}
}
add(OllamaSettingEvent.updateSetting(pb));
AIEventUpdateLocalAISetting(pb).send().fold(
(_) => Log.info('AI setting updated successfully'),
(err) => Log.error('Update AI setting failed: $err'),
);
}
void _onSetDefaultModel(
_SetDefaultModel event,
Emitter<OllamaSettingState> emit,
) {
emit(state.copyWith(selectedModel: event.model, isEdited: true));
}
}
// Create an enum for setting type.
/// Setting types for mapping
enum SettingType {
serverUrl,
chatModel,
embeddingModel; // semicolon needed after the enum values
embeddingModel;
String get title {
switch (this) {
case SettingType.serverUrl:
return 'Ollama server url';
case SettingType.chatModel:
return 'Chat model name';
return 'Default model name';
case SettingType.embeddingModel:
return 'Embedding model name';
}
}
}
/// Input field representation
class SettingItem extends Equatable {
const SettingItem({
required this.content,
required this.hintText,
required this.settingType,
});
final String content;
final String hintText;
final SettingType settingType;
@override
List<Object?> get props => [content, settingType];
}
/// Items pending submission
class SubmittedItem extends Equatable {
const SubmittedItem({
required this.content,
required this.settingType,
});
final String content;
final SettingType settingType;
/// Returns a copy of this SubmittedItem with given fields updated.
SubmittedItem copyWith({
String? content,
SettingType? settingType,
}) {
return SubmittedItem(
content: content ?? this.content,
settingType: settingType ?? this.settingType,
);
}
@override
List<Object?> get props => [content, settingType];
}
@ -181,10 +188,18 @@ class SubmittedItem extends Equatable {
@freezed
class OllamaSettingEvent with _$OllamaSettingEvent {
const factory OllamaSettingEvent.started() = _Started;
const factory OllamaSettingEvent.didLoadSetting(LocalAISettingPB setting) =
_DidLoadSetting;
const factory OllamaSettingEvent.updateSetting(LocalAISettingPB setting) =
_UpdateSetting;
const factory OllamaSettingEvent.didLoadLocalModels(
ModelSelectionPB models,
) = _DidLoadLocalModels;
const factory OllamaSettingEvent.didLoadSetting(
LocalAISettingPB setting,
) = _DidLoadSetting;
const factory OllamaSettingEvent.updateSetting(
LocalAISettingPB setting,
) = _UpdateSetting;
const factory OllamaSettingEvent.setDefaultModel(
AIModelPB model,
) = _SetDefaultModel;
const factory OllamaSettingEvent.onEdit(
String content,
SettingType settingType,
@ -196,25 +211,42 @@ class OllamaSettingEvent with _$OllamaSettingEvent {
class OllamaSettingState with _$OllamaSettingState {
const factory OllamaSettingState({
LocalAISettingPB? setting,
@Default([
SettingItem(
content: 'http://localhost:11434',
hintText: 'http://localhost:11434',
settingType: SettingType.serverUrl,
),
SettingItem(
content: 'llama3.1',
hintText: 'llama3.1',
settingType: SettingType.chatModel,
),
SettingItem(
content: 'nomic-embed-text',
hintText: 'nomic-embed-text',
settingType: SettingType.embeddingModel,
),
])
List<SettingItem> inputItems,
@Default([]) List<SettingItem> inputItems,
AIModelPB? selectedModel,
ModelSelectionPB? localModels,
AIModelPB? defaultModel,
@Default([]) List<SubmittedItem> submittedItems,
@Default(false) bool isEdited,
}) = _PluginStateState;
@Default({}) Map<SettingType, String> originalMap,
}) = _OllamaSettingState;
}
extension on LocalAISettingPB {
List<SettingItem> toInputItems() => [
SettingItem(
content: serverUrl,
hintText: 'http://localhost:11434',
settingType: SettingType.serverUrl,
),
SettingItem(
content: embeddingModelName,
hintText: kDefaultEmbeddingModel,
settingType: SettingType.embeddingModel,
),
];
List<SubmittedItem> toSubmittedItems() => [
SubmittedItem(
content: serverUrl,
settingType: SettingType.serverUrl,
),
SubmittedItem(
content: globalChatModel,
settingType: SettingType.chatModel,
),
SubmittedItem(
content: embeddingModelName,
settingType: SettingType.embeddingModel,
),
];
}

View File

@ -93,7 +93,7 @@ class SettingsAIBloc extends Bloc<SettingsAIEvent, SettingsAIState> {
),
);
},
didLoadAvailableModels: (AvailableModelsPB models) {
didLoadAvailableModels: (ModelSelectionPB models) {
emit(
state.copyWith(
availableModels: models,
@ -134,7 +134,8 @@ class SettingsAIBloc extends Bloc<SettingsAIEvent, SettingsAIState> {
);
void _loadModelList() {
AIEventGetServerAvailableModels().send().then((result) {
final payload = ModelSourcePB(source: aiModelsGlobalActiveModel);
AIEventGetSettingModelSelection(payload).send().then((result) {
result.fold((models) {
if (!isClosed) {
add(SettingsAIEvent.didLoadAvailableModels(models));
@ -175,7 +176,7 @@ class SettingsAIEvent with _$SettingsAIEvent {
) = _DidReceiveUserProfile;
const factory SettingsAIEvent.didLoadAvailableModels(
AvailableModelsPB models,
ModelSelectionPB models,
) = _DidLoadAvailableModels;
}
@ -184,7 +185,7 @@ class SettingsAIState with _$SettingsAIState {
const factory SettingsAIState({
required UserProfilePB userProfile,
WorkspaceSettingsPB? aiSettings,
AvailableModelsPB? availableModels,
ModelSelectionPB? availableModels,
@Default(true) bool enableSearchIndexing,
}) = _SettingsAIState;
}

View File

@ -31,7 +31,10 @@ HotKeyItem openSettingsHotKey(
),
keyDownHandler: (_) {
if (_settingsDialogKey.currentContext == null) {
showSettingsDialog(context);
showSettingsDialog(
context,
userWorkspaceBloc: context.read<UserWorkspaceBloc>(),
);
} else {
Navigator.of(context, rootNavigator: true)
.popUntil((route) => route.isFirst);
@ -110,7 +113,7 @@ class _UserSettingButtonState extends State<UserSettingButton> {
void showSettingsDialog(
BuildContext context, {
UserWorkspaceBloc? userWorkspaceBloc,
required UserWorkspaceBloc userWorkspaceBloc,
PasswordBloc? passwordBloc,
SettingsPage? initPage,
}) {
@ -134,7 +137,7 @@ void showSettingsDialog(
value: BlocProvider.of<DocumentAppearanceCubit>(dialogContext),
),
BlocProvider.value(
value: userWorkspaceBloc ?? context.read<UserWorkspaceBloc>(),
value: userWorkspaceBloc,
),
],
child: SettingsDialog(

View File

@ -7,6 +7,14 @@ import 'package:flowy_infra_ui/widget/spacing.dart';
import 'package:flutter/material.dart';
import 'package:flutter_bloc/flutter_bloc.dart';
import 'package:appflowy/ai/ai.dart';
import 'package:appflowy_backend/protobuf/flowy-ai/entities.pb.dart';
import 'package:appflowy/generated/locale_keys.g.dart';
import 'package:appflowy/workspace/presentation/settings/shared/af_dropdown_menu_entry.dart';
import 'package:appflowy/workspace/presentation/settings/shared/settings_dropdown.dart';
import 'package:easy_localization/easy_localization.dart';
class OllamaSettingPage extends StatelessWidget {
const OllamaSettingPage({super.key});
@ -32,6 +40,7 @@ class OllamaSettingPage extends StatelessWidget {
children: [
for (final item in state.inputItems)
_SettingItemWidget(item: item),
const LocalAIModelSelection(),
_SaveButton(isEdited: state.isEdited),
],
),
@ -113,3 +122,59 @@ class _SaveButton extends StatelessWidget {
);
}
}
class LocalAIModelSelection extends StatelessWidget {
const LocalAIModelSelection({super.key});
static const double height = 49;
@override
Widget build(BuildContext context) {
return BlocBuilder<OllamaSettingBloc, OllamaSettingState>(
buildWhen: (previous, current) =>
previous.localModels != current.localModels,
builder: (context, state) {
final models = state.localModels;
if (models == null) {
return const SizedBox(
// Using same height as SettingsDropdown to avoid layout shift
height: height,
);
}
return Column(
crossAxisAlignment: CrossAxisAlignment.start,
children: [
FlowyText.medium(
LocaleKeys.settings_aiPage_keys_globalLLMModel.tr(),
fontSize: 12,
figmaLineHeight: 16,
),
const VSpace(4),
SizedBox(
height: 40,
child: SettingsDropdown<AIModelPB>(
key: const Key('_AIModelSelection'),
onChanged: (model) => context
.read<OllamaSettingBloc>()
.add(OllamaSettingEvent.setDefaultModel(model)),
selectedOption: models.selectedModel,
selectOptionCompare: (left, right) => left?.name == right?.name,
options: models.models
.map(
(model) => buildDropdownMenuEntry<AIModelPB>(
context,
value: model,
label: model.i18n,
subLabel: model.desc,
maximumHeight: height,
),
)
.toList(),
),
),
],
);
},
);
}
}

View File

@ -866,6 +866,7 @@
"aiSettingsDescription": "Choose your preferred model to power AppFlowy AI. Now includes GPT-4o, GPT-o3-mini, DeepSeek R1, Claude 3.5 Sonnet, and models available in Ollama",
"loginToEnableAIFeature": "AI features are only enabled after logging in with @:appName Cloud. If you don't have an @:appName account, go to 'My Account' to sign up",
"llmModel": "Language Model",
"globalLLMModel": "Global Language Model",
"llmModelType": "Language Model Type",
"downloadLLMPrompt": "Download {}",
"downloadAppFlowyOfflineAI": "Downloading AI offline package will enable AI to run on your device. Do you want to continue?",

View File

@ -2210,12 +2210,6 @@ dependencies = [
"litrs",
]
[[package]]
name = "dotenv"
version = "0.15.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "77c90badedccf4105eca100756a0b1289e191f6fcbdadd3cee1d2f614f97da8f"
[[package]]
name = "downcast-rs"
version = "2.0.1"
@ -2506,7 +2500,6 @@ dependencies = [
"bytes",
"collab-integrate",
"dashmap 6.0.1",
"dotenv",
"flowy-ai-pub",
"flowy-codegen",
"flowy-derive",
@ -2520,19 +2513,18 @@ dependencies = [
"lib-infra",
"log",
"notify",
"ollama-rs",
"pin-project",
"protobuf",
"reqwest 0.11.27",
"serde",
"serde_json",
"sha2",
"simsimd",
"strum_macros 0.21.1",
"tokio",
"tokio-stream",
"tokio-util",
"tracing",
"tracing-subscriber",
"uuid",
"validator 0.18.1",
]
@ -2798,6 +2790,7 @@ dependencies = [
"flowy-derive",
"flowy-sqlite",
"lib-dispatch",
"ollama-rs",
"protobuf",
"r2d2",
"reqwest 0.11.27",
@ -4044,6 +4037,7 @@ checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99"
dependencies = [
"autocfg",
"hashbrown 0.12.3",
"serde",
]
[[package]]
@ -4894,6 +4888,23 @@ dependencies = [
"memchr",
]
[[package]]
name = "ollama-rs"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7a4b4750770584c8b4a643d0329e7bedacc4ecf68b7c7ac3e1fec2bafd6312f7"
dependencies = [
"async-stream",
"log",
"reqwest 0.12.15",
"schemars",
"serde",
"serde_json",
"static_assertions",
"thiserror 2.0.12",
"url",
]
[[package]]
name = "once_cell"
version = "1.19.0"
@ -5696,7 +5707,7 @@ dependencies = [
"rustc-hash 2.1.0",
"rustls 0.23.20",
"socket2 0.5.5",
"thiserror 2.0.9",
"thiserror 2.0.12",
"tokio",
"tracing",
]
@ -5715,7 +5726,7 @@ dependencies = [
"rustls 0.23.20",
"rustls-pki-types",
"slab",
"thiserror 2.0.9",
"thiserror 2.0.12",
"tinyvec",
"tracing",
"web-time",
@ -6407,6 +6418,31 @@ dependencies = [
"parking_lot 0.12.1",
]
[[package]]
name = "schemars"
version = "0.8.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3fbf2ae1b8bc8e02df939598064d22402220cd5bbcca1c76f7d6a310974d5615"
dependencies = [
"dyn-clone",
"indexmap 1.9.3",
"schemars_derive",
"serde",
"serde_json",
]
[[package]]
name = "schemars_derive"
version = "0.8.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "32e265784ad618884abaea0600a9adf15393368d840e0222d101a072f3f7534d"
dependencies = [
"proc-macro2",
"quote",
"serde_derive_internals 0.29.1",
"syn 2.0.94",
]
[[package]]
name = "scopeguard"
version = "1.2.0"
@ -6554,6 +6590,17 @@ dependencies = [
"syn 2.0.94",
]
[[package]]
name = "serde_derive_internals"
version = "0.29.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "18d26a20a969b9e3fdf2fc2d9f21eda6c40e2de84c9408bb5d3b05d499aae711"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.94",
]
[[package]]
name = "serde_html_form"
version = "0.2.7"
@ -6746,15 +6793,6 @@ dependencies = [
"time",
]
[[package]]
name = "simsimd"
version = "4.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "efc843bc8f12d9c8e6b734a0fe8918fc497b42f6ae0f347dbfdad5b5138ab9b4"
dependencies = [
"cc",
]
[[package]]
name = "siphasher"
version = "0.3.11"
@ -6841,6 +6879,12 @@ version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3"
[[package]]
name = "static_assertions"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f"
[[package]]
name = "string_cache"
version = "0.8.7"
@ -7098,7 +7142,7 @@ dependencies = [
"tantivy-stacker",
"tantivy-tokenizer-api",
"tempfile",
"thiserror 2.0.9",
"thiserror 2.0.12",
"time",
"uuid",
"winapi",
@ -7280,11 +7324,11 @@ dependencies = [
[[package]]
name = "thiserror"
version = "2.0.9"
version = "2.0.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f072643fd0190df67a8bab670c20ef5d8737177d6ac6b2e9a236cb096206b2cc"
checksum = "567b8a2dae586314f7be2a752ec7474332959c6460e02bde30d702a66d488708"
dependencies = [
"thiserror-impl 2.0.9",
"thiserror-impl 2.0.12",
]
[[package]]
@ -7300,9 +7344,9 @@ dependencies = [
[[package]]
name = "thiserror-impl"
version = "2.0.9"
version = "2.0.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7b50fa271071aae2e6ee85f842e2e28ba8cd2c5fb67f11fcb1fd70b276f9e7d4"
checksum = "7f7cf42b4507d8ea322120659672cf1b9dbb93f8f2d4ecfd6e51350ff5b17a1d"
dependencies = [
"proc-macro2",
"quote",
@ -7823,7 +7867,7 @@ checksum = "7a94b0f0954b3e59bfc2c246b4c8574390d94a4ad4ad246aaf2fb07d7dfd3b47"
dependencies = [
"proc-macro2",
"quote",
"serde_derive_internals",
"serde_derive_internals 0.28.0",
"syn 2.0.94",
]

View File

@ -0,0 +1,54 @@
use diesel::sqlite::SqliteConnection;
use flowy_error::FlowyResult;
use flowy_sqlite::upsert::excluded;
use flowy_sqlite::{
diesel,
query_dsl::*,
schema::{local_ai_model_table, local_ai_model_table::dsl},
ExpressionMethods, Identifiable, Insertable, Queryable,
};
#[derive(Clone, Default, Queryable, Insertable, Identifiable)]
#[diesel(table_name = local_ai_model_table)]
#[diesel(primary_key(name))]
pub struct LocalAIModelTable {
pub name: String,
pub model_type: i16,
}
#[derive(Clone, Debug, Copy)]
pub enum ModelType {
Embedding = 0,
Chat = 1,
}
impl From<i16> for ModelType {
fn from(value: i16) -> Self {
match value {
0 => ModelType::Embedding,
1 => ModelType::Chat,
_ => ModelType::Embedding,
}
}
}
pub fn select_local_ai_model(conn: &mut SqliteConnection, name: &str) -> Option<LocalAIModelTable> {
local_ai_model_table::table
.filter(dsl::name.eq(name))
.first::<LocalAIModelTable>(conn)
.ok()
}
pub fn upsert_local_ai_model(
conn: &mut SqliteConnection,
row: &LocalAIModelTable,
) -> FlowyResult<()> {
diesel::insert_into(local_ai_model_table::table)
.values(row)
.on_conflict(local_ai_model_table::name)
.do_update()
.set((local_ai_model_table::model_type.eq(excluded(local_ai_model_table::model_type)),))
.execute(conn)?;
Ok(())
}

View File

@ -1,5 +1,7 @@
mod chat_message_sql;
mod chat_sql;
mod local_model_sql;
pub use chat_message_sql::*;
pub use chat_sql::*;
pub use local_model_sql::*;

View File

@ -48,16 +48,16 @@ collab-integrate.workspace = true
[target.'cfg(any(target_os = "macos", target_os = "linux", target_os = "windows"))'.dependencies]
notify = "6.1.1"
ollama-rs = "0.3.0"
#faiss = { version = "0.12.1" }
af-mcp = { version = "0.1.0" }
[dev-dependencies]
dotenv = "0.15.0"
uuid.workspace = true
tracing-subscriber = { version = "0.3.17", features = ["registry", "env-filter", "ansi", "json"] }
simsimd = "4.4.0"
[build-dependencies]
flowy-codegen.workspace = true
[features]
dart = ["flowy-codegen/dart", "flowy-notification/dart"]
local_ai = []

View File

@ -1,7 +1,7 @@
use crate::chat::Chat;
use crate::entities::{
AIModelPB, AvailableModelsPB, ChatInfoPB, ChatMessageListPB, ChatMessagePB, ChatSettingsPB,
FilePB, PredefinedFormatPB, RepeatedRelatedQuestionPB, StreamMessageParams,
AIModelPB, ChatInfoPB, ChatMessageListPB, ChatMessagePB, ChatSettingsPB, FilePB,
ModelSelectionPB, PredefinedFormatPB, RepeatedRelatedQuestionPB, StreamMessageParams,
};
use crate::local_ai::controller::{LocalAIController, LocalAISetting};
use crate::middleware::chat_service_mw::ChatServiceMiddleware;
@ -330,14 +330,10 @@ impl AIManager {
.get_question_id_from_answer_id(chat_id, answer_message_id)
.await?;
let model = model.map_or_else(
|| {
self
.store_preferences
.get_object::<AIModel>(&ai_available_models_key(&chat_id.to_string()))
},
|model| Some(model.into()),
);
let model = match model {
None => self.get_active_model(&chat_id.to_string()).await,
Some(model) => Some(model.into()),
};
chat
.stream_regenerate_response(question_message_id, answer_stream_port, format, model)
.await?;
@ -345,18 +341,32 @@ impl AIManager {
}
pub async fn update_local_ai_setting(&self, setting: LocalAISetting) -> FlowyResult<()> {
let previous_model = self.local_ai.get_local_ai_setting().chat_model_name;
self.local_ai.update_local_ai_setting(setting).await?;
let current_model = self.local_ai.get_local_ai_setting().chat_model_name;
let old_settings = self.local_ai.get_local_ai_setting();
// Only restart if the server URL has changed and local AI is not running
let need_restart =
old_settings.ollama_server_url != setting.ollama_server_url && !self.local_ai.is_running();
if previous_model != current_model {
// Update settings first
self
.local_ai
.update_local_ai_setting(setting.clone())
.await?;
// Handle model change if needed
let model_changed = old_settings.chat_model_name != setting.chat_model_name;
if model_changed {
info!(
"[AI Plugin] update global active model, previous: {}, current: {}",
previous_model, current_model
old_settings.chat_model_name, setting.chat_model_name
);
let source_key = ai_available_models_key(GLOBAL_ACTIVE_MODEL_KEY);
let model = AIModel::local(current_model, "".to_string());
self.update_selected_model(source_key, model).await?;
let model = AIModel::local(setting.chat_model_name, "".to_string());
self
.update_selected_model(GLOBAL_ACTIVE_MODEL_KEY.to_string(), model)
.await?;
}
if need_restart {
self.local_ai.restart_plugin().await;
}
Ok(())
@ -440,16 +450,16 @@ impl AIManager {
}
pub async fn update_selected_model(&self, source: String, model: AIModel) -> FlowyResult<()> {
info!(
"[Model Selection] update {} selected model: {:?}",
source, model
);
let source_key = ai_available_models_key(&source);
info!(
"[Model Selection] update {} selected model: {:?} for key:{}",
source, model, source_key
);
self
.store_preferences
.set_object::<AIModel>(&source_key, &model)?;
chat_notification_builder(&source, ChatNotification::DidUpdateSelectedModel)
chat_notification_builder(&source_key, ChatNotification::DidUpdateSelectedModel)
.payload(AIModelPB::from(model))
.send();
Ok(())
@ -458,12 +468,13 @@ impl AIManager {
#[instrument(skip_all, level = "debug")]
pub async fn toggle_local_ai(&self) -> FlowyResult<()> {
let enabled = self.local_ai.toggle_local_ai().await?;
let source_key = ai_available_models_key(GLOBAL_ACTIVE_MODEL_KEY);
if enabled {
if let Some(name) = self.local_ai.get_plugin_chat_model() {
info!("Set global active model to local ai: {}", name);
let model = AIModel::local(name, "".to_string());
self.update_selected_model(source_key, model).await?;
self
.update_selected_model(GLOBAL_ACTIVE_MODEL_KEY.to_string(), model)
.await?;
}
} else {
info!("Set global active model to default");
@ -471,7 +482,7 @@ impl AIManager {
let models = self.get_server_available_models().await?;
if let Some(model) = models.into_iter().find(|m| m.name == global_active_model) {
self
.update_selected_model(source_key, AIModel::from(model))
.update_selected_model(GLOBAL_ACTIVE_MODEL_KEY.to_string(), AIModel::from(model))
.await?;
}
}
@ -484,119 +495,140 @@ impl AIManager {
.store_preferences
.get_object::<AIModel>(&ai_available_models_key(source));
if model.is_none() {
if let Some(local_model) = self.local_ai.get_plugin_chat_model() {
model = Some(AIModel::local(local_model, "".to_string()));
}
match model {
None => {
if let Some(local_model) = self.local_ai.get_plugin_chat_model() {
model = Some(AIModel::local(local_model, "".to_string()));
}
model
},
Some(mut model) => {
let models = self.local_ai.get_all_chat_local_models().await;
if !models.contains(&model) {
if let Some(local_model) = self.local_ai.get_plugin_chat_model() {
model = AIModel::local(local_model, "".to_string());
}
}
Some(model)
},
}
model
}
pub async fn get_available_models(&self, source: String) -> FlowyResult<AvailableModelsPB> {
pub async fn get_local_available_models(&self) -> FlowyResult<ModelSelectionPB> {
let setting = self.local_ai.get_local_ai_setting();
let mut models = self.local_ai.get_all_chat_local_models().await;
let selected_model = AIModel::local(setting.chat_model_name, "".to_string());
if models.is_empty() {
models.push(selected_model.clone());
}
Ok(ModelSelectionPB {
models: models.into_iter().map(AIModelPB::from).collect(),
selected_model: AIModelPB::from(selected_model),
})
}
pub async fn get_available_models(
&self,
source: String,
setting_only: bool,
) -> FlowyResult<ModelSelectionPB> {
let is_local_mode = self.user_service.is_local_model().await?;
if is_local_mode {
let setting = self.local_ai.get_local_ai_setting();
let selected_model = AIModel::local(setting.chat_model_name, "".to_string());
let models = vec![selected_model.clone()];
Ok(AvailableModelsPB {
models: models.into_iter().map(|m| m.into()).collect(),
selected_model: AIModelPB::from(selected_model),
})
} else {
// Build the models list from server models and mark them as non-local.
let mut models: Vec<AIModel> = self
.get_server_available_models()
.await?
.into_iter()
.map(AIModel::from)
.collect();
trace!("[Model Selection]: Available models: {:?}", models);
let mut current_active_local_ai_model = None;
// If user enable local ai, then add local ai model to the list.
if let Some(local_model) = self.local_ai.get_plugin_chat_model() {
let model = AIModel::local(local_model, "".to_string());
current_active_local_ai_model = Some(model.clone());
trace!("[Model Selection] current local ai model: {}", model.name);
models.push(model);
}
if models.is_empty() {
return Ok(AvailableModelsPB {
models: models.into_iter().map(|m| m.into()).collect(),
selected_model: AIModelPB::default(),
});
}
// Global active model is the model selected by the user in the workspace settings.
let mut server_active_model = self
.get_workspace_select_model()
.await
.map(|m| AIModel::server(m, "".to_string()))
.unwrap_or_else(|_| AIModel::default());
trace!(
"[Model Selection] server active model: {:?}",
server_active_model
);
let mut user_selected_model = server_active_model.clone();
// when current select model is deprecated, reset the model to default
if !models.iter().any(|m| m.name == server_active_model.name) {
server_active_model = AIModel::default();
}
let source_key = ai_available_models_key(&source);
// We use source to identify user selected model. source can be document id or chat id.
match self.store_preferences.get_object::<AIModel>(&source_key) {
None => {
// when there is selected model and current local ai is active, then use local ai
if let Some(local_ai_model) = models.iter().find(|m| m.is_local) {
user_selected_model = local_ai_model.clone();
}
},
Some(mut model) => {
trace!("[Model Selection] user previous select model: {:?}", model);
// If source is provided, try to get the user-selected model from the store. User selected
// model will be used as the active model if it exists.
if model.is_local {
if let Some(local_ai_model) = &current_active_local_ai_model {
if local_ai_model.name != model.name {
model = local_ai_model.clone();
}
}
}
user_selected_model = model;
},
}
// If user selected model is not available in the list, use the global active model.
let active_model = models
.iter()
.find(|m| m.name == user_selected_model.name)
.cloned()
.or(Some(server_active_model.clone()));
// Update the stored preference if a different model is used.
if let Some(ref active_model) = active_model {
if active_model.name != user_selected_model.name {
self
.store_preferences
.set_object::<AIModel>(&source_key, &active_model.clone())?;
}
}
trace!("[Model Selection] final active model: {:?}", active_model);
let selected_model = AIModelPB::from(active_model.unwrap_or_default());
Ok(AvailableModelsPB {
models: models.into_iter().map(|m| m.into()).collect(),
selected_model,
})
return self.get_local_available_models().await;
}
// Fetch server models
let mut all_models: Vec<AIModel> = self
.get_server_available_models()
.await?
.into_iter()
.map(AIModel::from)
.collect();
trace!("[Model Selection]: Available models: {:?}", all_models);
// Add local models if enabled
if self.local_ai.is_enabled() {
if setting_only {
let setting = self.local_ai.get_local_ai_setting();
all_models.push(AIModel::local(setting.chat_model_name, "".to_string()));
} else {
all_models.extend(self.local_ai.get_all_chat_local_models().await);
}
}
// Return early if no models available
if all_models.is_empty() {
return Ok(ModelSelectionPB {
models: Vec::new(),
selected_model: AIModelPB::default(),
});
}
// Get server active model (only once)
let server_active_model = self
.get_workspace_select_model()
.await
.map(|m| AIModel::server(m, "".to_string()))
.unwrap_or_else(|_| AIModel::default());
trace!(
"[Model Selection] server active model: {:?}",
server_active_model
);
// Use server model as default if it exists in available models
let default_model = if all_models
.iter()
.any(|m| m.name == server_active_model.name)
{
server_active_model.clone()
} else {
AIModel::default()
};
// Get user's previously selected model
let user_selected_model = match self.get_active_model(&source).await {
Some(model) => {
trace!("[Model Selection] user previous select model: {:?}", model);
model
},
None => {
// When no selected model and local AI is active, use local AI model
all_models
.iter()
.find(|m| m.is_local)
.cloned()
.unwrap_or_else(|| default_model.clone())
},
};
// Determine final active model - use user's selection if available, otherwise default
let active_model = all_models
.iter()
.find(|m| m.name == user_selected_model.name)
.cloned()
.unwrap_or(default_model.clone());
// Update stored preference if changed
if active_model.name != user_selected_model.name {
if let Err(err) = self
.update_selected_model(source, active_model.clone())
.await
{
error!("[Model Selection] failed to update selected model: {}", err);
}
}
trace!("[Model Selection] final active model: {:?}", active_model);
// Create response with one transformation pass
Ok(ModelSelectionPB {
models: all_models.into_iter().map(AIModelPB::from).collect(),
selected_model: AIModelPB::from(active_model),
})
}
pub async fn get_or_create_chat_instance(&self, chat_id: &Uuid) -> Result<Arc<Chat>, FlowyError> {

View File

@ -182,7 +182,7 @@ pub struct ChatMessageListPB {
}
#[derive(Default, ProtoBuf, Clone, Debug)]
pub struct ServerAvailableModelsPB {
pub struct ServerModelSelectionPB {
#[pb(index = 1)]
pub models: Vec<AvailableModelPB>,
}
@ -200,7 +200,7 @@ pub struct AvailableModelPB {
}
#[derive(Default, ProtoBuf, Validate, Clone, Debug)]
pub struct AvailableModelsQueryPB {
pub struct ModelSourcePB {
#[pb(index = 1)]
#[validate(custom(function = "required_not_empty_str"))]
pub source: String,
@ -217,7 +217,7 @@ pub struct UpdateSelectedModelPB {
}
#[derive(Default, ProtoBuf, Clone, Debug)]
pub struct AvailableModelsPB {
pub struct ModelSelectionPB {
#[pb(index = 1)]
pub models: Vec<AIModelPB>,
@ -225,6 +225,12 @@ pub struct AvailableModelsPB {
pub selected_model: AIModelPB,
}
#[derive(Default, ProtoBuf, Clone, Debug)]
pub struct RepeatedAIModelPB {
#[pb(index = 1)]
pub items: Vec<AIModelPB>,
}
#[derive(Default, ProtoBuf, Clone, Debug)]
pub struct AIModelPB {
#[pb(index = 1)]
@ -686,7 +692,7 @@ pub struct LocalAISettingPB {
#[pb(index = 2)]
#[validate(custom(function = "required_not_empty_str"))]
pub chat_model_name: String,
pub global_chat_model: String,
#[pb(index = 3)]
#[validate(custom(function = "required_not_empty_str"))]
@ -697,7 +703,7 @@ impl From<LocalAISetting> for LocalAISettingPB {
fn from(value: LocalAISetting) -> Self {
LocalAISettingPB {
server_url: value.ollama_server_url,
chat_model_name: value.chat_model_name,
global_chat_model: value.chat_model_name,
embedding_model_name: value.embedding_model_name,
}
}
@ -707,7 +713,7 @@ impl From<LocalAISettingPB> for LocalAISetting {
fn from(value: LocalAISettingPB) -> Self {
LocalAISetting {
ollama_server_url: value.server_url,
chat_model_name: value.chat_model_name,
chat_model_name: value.global_chat_model,
embedding_model_name: value.embedding_model_name,
}
}

View File

@ -1,7 +1,6 @@
use crate::ai_manager::{AIManager, GLOBAL_ACTIVE_MODEL_KEY};
use crate::ai_manager::AIManager;
use crate::completion::AICompletion;
use crate::entities::*;
use crate::util::ai_available_models_key;
use flowy_ai_pub::cloud::{AIModel, ChatMessageType};
use flowy_error::{ErrorCode, FlowyError, FlowyResult};
use lib_dispatch::prelude::{data_result_ok, AFPluginData, AFPluginState, DataResult};
@ -78,23 +77,24 @@ pub(crate) async fn regenerate_response_handler(
}
#[tracing::instrument(level = "debug", skip_all, err)]
pub(crate) async fn get_server_model_list_handler(
pub(crate) async fn get_setting_model_selection_handler(
data: AFPluginData<ModelSourcePB>,
ai_manager: AFPluginState<Weak<AIManager>>,
) -> DataResult<AvailableModelsPB, FlowyError> {
) -> DataResult<ModelSelectionPB, FlowyError> {
let data = data.try_into_inner()?;
let ai_manager = upgrade_ai_manager(ai_manager)?;
let source_key = ai_available_models_key(GLOBAL_ACTIVE_MODEL_KEY);
let models = ai_manager.get_available_models(source_key).await?;
let models = ai_manager.get_available_models(data.source, true).await?;
data_result_ok(models)
}
#[tracing::instrument(level = "debug", skip_all, err)]
pub(crate) async fn get_chat_models_handler(
data: AFPluginData<AvailableModelsQueryPB>,
pub(crate) async fn get_source_model_selection_handler(
data: AFPluginData<ModelSourcePB>,
ai_manager: AFPluginState<Weak<AIManager>>,
) -> DataResult<AvailableModelsPB, FlowyError> {
) -> DataResult<ModelSelectionPB, FlowyError> {
let data = data.try_into_inner()?;
let ai_manager = upgrade_ai_manager(ai_manager)?;
let models = ai_manager.get_available_models(data.source).await?;
let models = ai_manager.get_available_models(data.source, false).await?;
data_result_ok(models)
}
@ -340,6 +340,15 @@ pub(crate) async fn get_local_ai_setting_handler(
data_result_ok(pb)
}
#[tracing::instrument(level = "debug", skip_all)]
pub(crate) async fn get_local_ai_models_handler(
ai_manager: AFPluginState<Weak<AIManager>>,
) -> DataResult<ModelSelectionPB, FlowyError> {
let ai_manager = upgrade_ai_manager(ai_manager)?;
let data = ai_manager.get_local_available_models().await?;
data_result_ok(data)
}
#[tracing::instrument(level = "debug", skip_all, err)]
pub(crate) async fn update_local_ai_setting_handler(
ai_manager: AFPluginState<Weak<AIManager>>,

View File

@ -31,20 +31,24 @@ pub fn init(ai_manager: Weak<AIManager>) -> AFPlugin {
.event(AIEvent::ToggleLocalAI, toggle_local_ai_handler)
.event(AIEvent::GetLocalAIState, get_local_ai_state_handler)
.event(AIEvent::GetLocalAISetting, get_local_ai_setting_handler)
.event(AIEvent::GetLocalModelSelection, get_local_ai_models_handler)
.event(
AIEvent::GetSourceModelSelection,
get_source_model_selection_handler,
)
.event(
AIEvent::UpdateLocalAISetting,
update_local_ai_setting_handler,
)
.event(
AIEvent::GetServerAvailableModels,
get_server_model_list_handler,
)
.event(AIEvent::CreateChatContext, create_chat_context_handler)
.event(AIEvent::GetChatInfo, create_chat_context_handler)
.event(AIEvent::GetChatSettings, get_chat_settings_handler)
.event(AIEvent::UpdateChatSettings, update_chat_settings_handler)
.event(AIEvent::RegenerateResponse, regenerate_response_handler)
.event(AIEvent::GetAvailableModels, get_chat_models_handler)
.event(
AIEvent::GetSettingModelSelection,
get_setting_model_selection_handler,
)
.event(AIEvent::UpdateSelectedModel, update_selected_model_handler)
}
@ -107,18 +111,21 @@ pub enum AIEvent {
#[event(input = "RegenerateResponsePB")]
RegenerateResponse = 27,
#[event(output = "AvailableModelsPB")]
GetServerAvailableModels = 28,
#[event(output = "LocalAISettingPB")]
GetLocalAISetting = 29,
#[event(input = "LocalAISettingPB")]
UpdateLocalAISetting = 30,
#[event(input = "AvailableModelsQueryPB", output = "AvailableModelsPB")]
GetAvailableModels = 31,
#[event(input = "ModelSourcePB", output = "ModelSelectionPB")]
GetSettingModelSelection = 31,
#[event(input = "UpdateSelectedModelPB")]
UpdateSelectedModel = 32,
#[event(output = "ModelSelectionPB")]
GetLocalModelSelection = 33,
#[event(input = "ModelSourcePB", output = "ModelSelectionPB")]
GetSourceModelSelection = 34,
}

View File

@ -16,9 +16,15 @@ use af_local_ai::ollama_plugin::OllamaAIPlugin;
use af_plugin::core::path::is_plugin_ready;
use af_plugin::core::plugin::RunningState;
use arc_swap::ArcSwapOption;
use flowy_ai_pub::cloud::AIModel;
use flowy_ai_pub::persistence::{
select_local_ai_model, upsert_local_ai_model, LocalAIModelTable, ModelType,
};
use flowy_ai_pub::user_service::AIUserService;
use futures_util::SinkExt;
use lib_infra::util::get_operating_system;
use ollama_rs::generation::embeddings::request::{EmbeddingsInput, GenerateEmbeddingsRequest};
use ollama_rs::Ollama;
use serde::{Deserialize, Serialize};
use std::ops::Deref;
use std::path::PathBuf;
@ -39,8 +45,8 @@ impl Default for LocalAISetting {
fn default() -> Self {
Self {
ollama_server_url: "http://localhost:11434".to_string(),
chat_model_name: "llama3.1".to_string(),
embedding_model_name: "nomic-embed-text".to_string(),
chat_model_name: "llama3.1:latest".to_string(),
embedding_model_name: "nomic-embed-text:latest".to_string(),
}
}
}
@ -53,6 +59,7 @@ pub struct LocalAIController {
current_chat_id: ArcSwapOption<Uuid>,
store_preferences: Weak<KVStorePreferences>,
user_service: Arc<dyn AIUserService>,
ollama: ArcSwapOption<Ollama>,
}
impl Deref for LocalAIController {
@ -83,69 +90,80 @@ impl LocalAIController {
user_service.clone(),
res_impl,
));
// Subscribe to state changes
let mut running_state_rx = local_ai.subscribe_running_state();
let cloned_llm_res = Arc::clone(&local_ai_resource);
let cloned_store_preferences = store_preferences.clone();
let cloned_local_ai = Arc::clone(&local_ai);
let cloned_user_service = Arc::clone(&user_service);
let ollama = ArcSwapOption::default();
let sys = get_operating_system();
if sys.is_desktop() {
let setting = local_ai_resource.get_llm_setting();
ollama.store(
Ollama::try_new(&setting.ollama_server_url)
.map(Arc::new)
.ok(),
);
// Spawn a background task to listen for plugin state changes
tokio::spawn(async move {
while let Some(state) = running_state_rx.next().await {
// Skip if we cant get workspace_id
let Ok(workspace_id) = cloned_user_service.workspace_id() else {
continue;
};
// Subscribe to state changes
let mut running_state_rx = local_ai.subscribe_running_state();
let cloned_llm_res = Arc::clone(&local_ai_resource);
let cloned_store_preferences = store_preferences.clone();
let cloned_local_ai = Arc::clone(&local_ai);
let cloned_user_service = Arc::clone(&user_service);
let key = local_ai_enabled_key(&workspace_id.to_string());
info!("[AI Plugin] state: {:?}", state);
// Read whether plugin is enabled from store; default to true
if let Some(store_preferences) = cloned_store_preferences.upgrade() {
let enabled = store_preferences.get_bool(&key).unwrap_or(true);
// Only check resource status if the plugin isnt in "UnexpectedStop" and is enabled
let (plugin_downloaded, lack_of_resource) =
if !matches!(state, RunningState::UnexpectedStop { .. }) && enabled {
// Possibly check plugin readiness and resource concurrency in parallel,
// but here we do it sequentially for clarity.
let downloaded = is_plugin_ready();
let resource_lack = cloned_llm_res.get_lack_of_resource().await;
(downloaded, resource_lack)
} else {
(false, None)
};
// If plugin is running, retrieve version
let plugin_version = if matches!(state, RunningState::Running { .. }) {
match cloned_local_ai.plugin_info().await {
Ok(info) => Some(info.version),
Err(_) => None,
}
} else {
None
// Spawn a background task to listen for plugin state changes
tokio::spawn(async move {
while let Some(state) = running_state_rx.next().await {
// Skip if we can't get workspace_id
let Ok(workspace_id) = cloned_user_service.workspace_id() else {
continue;
};
// Broadcast the new local AI state
let new_state = RunningStatePB::from(state);
chat_notification_builder(
APPFLOWY_AI_NOTIFICATION_KEY,
ChatNotification::UpdateLocalAIState,
)
.payload(LocalAIPB {
enabled,
plugin_downloaded,
lack_of_resource,
state: new_state,
plugin_version,
})
.send();
} else {
warn!("[AI Plugin] store preferences is dropped");
let key = crate::local_ai::controller::local_ai_enabled_key(&workspace_id.to_string());
info!("[AI Plugin] state: {:?}", state);
// Read whether plugin is enabled from store; default to true
if let Some(store_preferences) = cloned_store_preferences.upgrade() {
let enabled = store_preferences.get_bool(&key).unwrap_or(true);
// Only check resource status if the plugin isn't in "UnexpectedStop" and is enabled
let (plugin_downloaded, lack_of_resource) =
if !matches!(state, RunningState::UnexpectedStop { .. }) && enabled {
// Possibly check plugin readiness and resource concurrency in parallel,
// but here we do it sequentially for clarity.
let downloaded = is_plugin_ready();
let resource_lack = cloned_llm_res.get_lack_of_resource().await;
(downloaded, resource_lack)
} else {
(false, None)
};
// If plugin is running, retrieve version
let plugin_version = if matches!(state, RunningState::Running { .. }) {
match cloned_local_ai.plugin_info().await {
Ok(info) => Some(info.version),
Err(_) => None,
}
} else {
None
};
// Broadcast the new local AI state
let new_state = RunningStatePB::from(state);
chat_notification_builder(
APPFLOWY_AI_NOTIFICATION_KEY,
ChatNotification::UpdateLocalAIState,
)
.payload(LocalAIPB {
enabled,
plugin_downloaded,
lack_of_resource,
state: new_state,
plugin_version,
})
.send();
} else {
warn!("[AI Plugin] store preferences is dropped");
}
}
}
});
});
}
Self {
ai_plugin: local_ai,
@ -153,6 +171,7 @@ impl LocalAIController {
current_chat_id: ArcSwapOption::default(),
store_preferences,
user_service,
ollama,
}
}
#[instrument(level = "debug", skip_all)]
@ -287,17 +306,85 @@ impl LocalAIController {
self.resource.get_llm_setting()
}
pub async fn get_all_chat_local_models(&self) -> Vec<AIModel> {
self
.get_filtered_local_models(|name| !name.contains("embed"))
.await
}
pub async fn get_all_embedded_local_models(&self) -> Vec<AIModel> {
self
.get_filtered_local_models(|name| name.contains("embed"))
.await
}
// Helper function to avoid code duplication in model retrieval
async fn get_filtered_local_models<F>(&self, filter_fn: F) -> Vec<AIModel>
where
F: Fn(&str) -> bool,
{
match self.ollama.load_full() {
None => vec![],
Some(ollama) => ollama
.list_local_models()
.await
.map(|models| {
models
.into_iter()
.filter(|m| filter_fn(&m.name.to_lowercase()))
.map(|m| AIModel::local(m.name, String::new()))
.collect()
})
.unwrap_or_default(),
}
}
pub async fn check_model_type(&self, model_name: &str) -> FlowyResult<ModelType> {
let uid = self.user_service.user_id()?;
let mut conn = self.user_service.sqlite_connection(uid)?;
match select_local_ai_model(&mut conn, model_name) {
None => {
let ollama = self
.ollama
.load_full()
.ok_or_else(|| FlowyError::local_ai().with_context("ollama is not initialized"))?;
let request = GenerateEmbeddingsRequest::new(
model_name.to_string(),
EmbeddingsInput::Single("Hello".to_string()),
);
let model_type = match ollama.generate_embeddings(request).await {
Ok(value) => {
if value.embeddings.is_empty() {
ModelType::Chat
} else {
ModelType::Embedding
}
},
Err(_) => ModelType::Chat,
};
upsert_local_ai_model(
&mut conn,
&LocalAIModelTable {
name: model_name.to_string(),
model_type: model_type as i16,
},
)?;
Ok(model_type)
},
Some(r) => Ok(ModelType::from(r.model_type)),
}
}
pub async fn update_local_ai_setting(&self, setting: LocalAISetting) -> FlowyResult<()> {
info!(
"[AI Plugin] update local ai setting: {:?}, thread: {:?}",
setting,
std::thread::current().id()
);
if self.resource.set_llm_setting(setting).await.is_ok() {
let is_enabled = self.is_enabled();
self.toggle_plugin(is_enabled).await?;
}
self.resource.set_llm_setting(setting).await?;
Ok(())
}

View File

@ -161,7 +161,6 @@ impl LocalAIResourceController {
let setting = self.get_llm_setting();
let client = Client::builder().timeout(Duration::from_secs(5)).build()?;
match client.get(&setting.ollama_server_url).send().await {
Ok(resp) if resp.status().is_success() => {
info!(

View File

@ -36,6 +36,10 @@ client-api = { workspace = true, optional = true }
tantivy = { workspace = true, optional = true }
uuid.workspace = true
[target.'cfg(any(target_os = "macos", target_os = "linux", target_os = "windows"))'.dependencies]
ollama-rs = "0.3.0"
[features]
default = ["impl_from_dispatch_error", "impl_from_serde", "impl_from_reqwest", "impl_from_sqlite"]
impl_from_dispatch_error = ["lib-dispatch"]

View File

@ -264,3 +264,10 @@ impl From<uuid::Error> for FlowyError {
FlowyError::internal().with_context(value)
}
}
#[cfg(any(target_os = "windows", target_os = "macos", target_os = "linux"))]
impl From<ollama_rs::error::OllamaError> for FlowyError {
fn from(value: ollama_rs::error::OllamaError) -> Self {
FlowyError::local_ai().with_context(value)
}
}

View File

@ -0,0 +1 @@
-- This file should undo anything in `up.sql`

View File

@ -0,0 +1,6 @@
-- Your SQL goes here
CREATE TABLE local_ai_model_table
(
name TEXT PRIMARY KEY NOT NULL,
model_type SMALLINT NOT NULL
);

View File

@ -54,6 +54,13 @@ diesel::table! {
}
}
diesel::table! {
local_ai_model_table (name) {
name -> Text,
model_type -> SmallInt,
}
}
diesel::table! {
upload_file_part (upload_id, e_tag) {
upload_id -> Text,
@ -138,6 +145,7 @@ diesel::allow_tables_to_appear_in_same_query!(
chat_message_table,
chat_table,
collab_snapshot,
local_ai_model_table,
upload_file_part,
upload_file_table,
user_data_migration_records,