Mirror of https://github.com/AppFlowy-IO/AppFlowy.git (synced 2025-12-27 15:13:46 +00:00)
Merge pull request #7834 from AppFlowy-IO/support_switch_local_models
Support switch local models
This commit is contained in: commit ec5eb4e337
@@ -7,7 +7,6 @@ import 'package:appflowy_backend/log.dart';
import 'package:appflowy_backend/protobuf/flowy-ai/entities.pb.dart';
import 'package:appflowy_result/appflowy_result.dart';
import 'package:easy_localization/easy_localization.dart';
import 'package:protobuf/protobuf.dart';
import 'package:universal_platform/universal_platform.dart';

typedef OnModelStateChangedCallback = void Function(AIModelState state);
@@ -52,25 +51,29 @@ class AIModelStateNotifier {
final String objectId;
final LocalAIStateListener? _localAIListener;
final AIModelSwitchListener _aiModelSwitchListener;
LocalAIPB? _localAIState;
AvailableModelsPB? _availableModels;

// callbacks
LocalAIPB? _localAIState;
ModelSelectionPB? _modelSelection;

AIModelState _currentState = _defaultState();
List<AIModelPB> _availableModels = [];
AIModelPB? _selectedModel;

final List<OnModelStateChangedCallback> _stateChangedCallbacks = [];
final List<OnAvailableModelsChangedCallback>
_availableModelsChangedCallbacks = [];

/// Starts platform-specific listeners
void _startListening() {
if (UniversalPlatform.isDesktop) {
_localAIListener?.start(
stateCallback: (state) async {
_localAIState = state;
_notifyStateChanged();

_updateAll();
if (state.state == RunningStatePB.Running ||
state.state == RunningStatePB.Stopped) {
await _loadAvailableModels();
_notifyAvailableModelsChanged();
await _loadModelSelection();
_updateAll();
}
},
);
@@ -78,25 +81,25 @@ class AIModelStateNotifier {

_aiModelSwitchListener.start(
onUpdateSelectedModel: (model) async {
final updatedModels = _availableModels?.deepCopy()
?..selectedModel = model;
_availableModels = updatedModels;
_notifyAvailableModelsChanged();

_selectedModel = model;
_updateAll();
if (model.isLocal && UniversalPlatform.isDesktop) {
await _loadLocalAiState();
await _loadLocalState();
_updateAll();
}
_notifyStateChanged();
},
);
}

void _init() async {
await Future.wait([_loadLocalAiState(), _loadAvailableModels()]);
_notifyStateChanged();
_notifyAvailableModelsChanged();
Future<void> _init() async {
await Future.wait([
if (UniversalPlatform.isDesktop) _loadLocalState(),
_loadModelSelection(),
]);
_updateAll();
}

/// Register callbacks for state or available-models changes
void addListener({
OnModelStateChangedCallback? onStateChanged,
OnAvailableModelsChangedCallback? onAvailableModelsChanged,
@@ -109,6 +112,7 @@ class AIModelStateNotifier {
}
}

/// Remove previously registered callbacks
void removeListener({
OnModelStateChangedCallback? onStateChanged,
OnAvailableModelsChangedCallback? onAvailableModelsChanged,
@@ -128,116 +132,88 @@ class AIModelStateNotifier {
await _aiModelSwitchListener.stop();
}

AIModelState getState() {
if (UniversalPlatform.isMobile) {
return AIModelState(
type: AiType.cloud,
hintText: LocaleKeys.chat_inputMessageHint.tr(),
tooltip: null,
isEditable: true,
localAIEnabled: false,
);
/// Returns current AIModelState
AIModelState getState() => _currentState;

/// Returns available models and the selected model
(List<AIModelPB>, AIModelPB?) getModelSelection() =>
(_availableModels, _selectedModel);

void _updateAll() {
_currentState = _computeState();
for (final cb in _stateChangedCallbacks) {
cb(_currentState);
}

final availableModels = _availableModels;
final localAiState = _localAIState;

if (availableModels == null) {
return AIModelState(
type: AiType.cloud,
hintText: LocaleKeys.chat_inputMessageHint.tr(),
isEditable: true,
tooltip: null,
localAIEnabled: false,
);
}
if (localAiState == null) {
return AIModelState(
type: AiType.cloud,
hintText: LocaleKeys.chat_inputMessageHint.tr(),
tooltip: null,
isEditable: true,
localAIEnabled: false,
);
for (final cb in _availableModelsChangedCallbacks) {
cb(_availableModels, _selectedModel);
}
}

if (!availableModels.selectedModel.isLocal) {
return AIModelState(
type: AiType.cloud,
hintText: LocaleKeys.chat_inputMessageHint.tr(),
tooltip: null,
isEditable: true,
localAIEnabled: false,
);
}

final editable = localAiState.state == RunningStatePB.Running;
final tooltip = localAiState.enabled
? (editable
? null
: LocaleKeys.settings_aiPage_keys_localAINotReadyTextFieldPrompt
.tr())
: LocaleKeys.settings_aiPage_keys_localAIDisabledTextFieldPrompt.tr();

final hintText = localAiState.enabled
? (editable
? LocaleKeys.chat_inputLocalAIMessageHint.tr()
: LocaleKeys.settings_aiPage_keys_localAIInitializing.tr())
: LocaleKeys.settings_aiPage_keys_localAIDisabled.tr();

return AIModelState(
type: AiType.local,
hintText: hintText,
tooltip: tooltip,
isEditable: editable,
localAIEnabled: localAiState.enabled,
Future<void> _loadModelSelection() async {
await AIEventGetSourceModelSelection(
ModelSourcePB(source: objectId),
).send().fold(
(ms) {
_modelSelection = ms;
_availableModels = ms.models;
_selectedModel = ms.selectedModel;
},
(e) => Log.error("Failed to fetch models: $e"),
);
}

(List<AIModelPB>, AIModelPB?) getAvailableModels() {
final availableModels = _availableModels;
if (availableModels == null) {
return ([], null);
}
return (availableModels.models, availableModels.selectedModel);
}

void _notifyAvailableModelsChanged() {
final (models, selectedModel) = getAvailableModels();
for (final callback in _availableModelsChangedCallbacks) {
callback(models, selectedModel);
}
}

void _notifyStateChanged() {
final state = getState();
for (final callback in _stateChangedCallbacks) {
callback(state);
}
}

Future<void> _loadAvailableModels() {
final payload = AvailableModelsQueryPB(source: objectId);
return AIEventGetAvailableModels(payload).send().fold(
(models) => _availableModels = models,
(err) => Log.error("Failed to get available models: $err"),
Future<void> _loadLocalState() async {
await AIEventGetLocalAIState().send().fold(
(s) => _localAIState = s,
(e) => Log.error("Failed to fetch local AI state: $e"),
);
}

Future<void> _loadLocalAiState() {
return AIEventGetLocalAIState().send().fold(
(localAIState) => _localAIState = localAIState,
(error) => Log.error("Failed to get local AI state: $error"),
);
static AIModelState _defaultState() => AIModelState(
type: AiType.cloud,
hintText: LocaleKeys.chat_inputMessageHint.tr(),
tooltip: null,
isEditable: true,
localAIEnabled: false,
);

/// Core logic computing the state from local and selection data
AIModelState _computeState() {
if (UniversalPlatform.isMobile) return _defaultState();

if (_modelSelection == null || _localAIState == null) {
return _defaultState();
}

if (!_selectedModel!.isLocal) {
return _defaultState();
}

final enabled = _localAIState!.enabled;
final running = _localAIState!.state == RunningStatePB.Running;
final hintKey = enabled
? (running
? LocaleKeys.chat_inputLocalAIMessageHint
: LocaleKeys.settings_aiPage_keys_localAIInitializing)
: LocaleKeys.settings_aiPage_keys_localAIDisabled;
final tooltipKey = enabled
? (running
? null
: LocaleKeys.settings_aiPage_keys_localAINotReadyTextFieldPrompt)
: LocaleKeys.settings_aiPage_keys_localAIDisabledTextFieldPrompt;

return AIModelState(
type: AiType.local,
hintText: hintKey.tr(),
tooltip: tooltipKey?.tr(),
isEditable: running,
localAIEnabled: enabled,
);
}
}

extension AiModelExtension on AIModelPB {
bool get isDefault {
return name == "Auto";
}

String get i18n {
return isDefault ? LocaleKeys.chat_switchModel_autoModel.tr() : name;
}
extension AIModelPBExtension on AIModelPB {
bool get isDefault => name == 'Auto';
String get i18n =>
isDefault ? LocaleKeys.chat_switchModel_autoModel.tr() : name;
}
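
A minimal illustrative sketch of how a consumer might use the new notifier API above. The surrounding function is hypothetical; only addListener, removeListener, getState and getModelSelection (and the callback shapes) come from the diff.

void attachToNotifier(AIModelStateNotifier notifier) {
  void onState(AIModelState state) {
    // e.g. enable or disable the prompt input based on state.isEditable.
  }

  void onModels(List<AIModelPB> models, AIModelPB? selected) {
    // e.g. rebuild a model picker with the refreshed list.
  }

  notifier.addListener(
    onStateChanged: onState,
    onAvailableModelsChanged: onModels,
  );

  // The cached values can also be read synchronously for the first build.
  final state = notifier.getState();
  final (models, selected) = notifier.getModelSelection();

  // When the consumer is disposed, the same callbacks are removed again.
  notifier.removeListener(
    onStateChanged: onState,
    onAvailableModelsChanged: onModels,
  );
}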
@@ -83,7 +83,7 @@ class SelectModelState with _$SelectModelState {
}) = _SelectModelState;

factory SelectModelState.initial(AIModelStateNotifier notifier) {
final (models, selectedModel) = notifier.getAvailableModels();
final (models, selectedModel) = notifier.getModelSelection();
return SelectModelState(
models: models,
selectedModel: selectedModel,
@@ -90,38 +90,40 @@ class SelectModelPopoverContent extends StatelessWidget {

return Padding(
padding: const EdgeInsets.all(8.0),
child: Column(
mainAxisSize: MainAxisSize.min,
crossAxisAlignment: CrossAxisAlignment.start,
children: [
if (localModels.isNotEmpty) ...[
_ModelSectionHeader(
title: LocaleKeys.chat_switchModel_localModel.tr(),
child: SingleChildScrollView(
child: Column(
mainAxisSize: MainAxisSize.min,
crossAxisAlignment: CrossAxisAlignment.start,
children: [
if (localModels.isNotEmpty) ...[
_ModelSectionHeader(
title: LocaleKeys.chat_switchModel_localModel.tr(),
),
const VSpace(4.0),
],
...localModels.map(
(model) => _ModelItem(
model: model,
isSelected: model == selectedModel,
onTap: () => onSelectModel?.call(model),
),
),
if (cloudModels.isNotEmpty && localModels.isNotEmpty) ...[
const VSpace(8.0),
_ModelSectionHeader(
title: LocaleKeys.chat_switchModel_cloudModel.tr(),
),
const VSpace(4.0),
],
...cloudModels.map(
(model) => _ModelItem(
model: model,
isSelected: model == selectedModel,
onTap: () => onSelectModel?.call(model),
),
),
const VSpace(4.0),
],
...localModels.map(
(model) => _ModelItem(
model: model,
isSelected: model == selectedModel,
onTap: () => onSelectModel?.call(model),
),
),
if (cloudModels.isNotEmpty && localModels.isNotEmpty) ...[
const VSpace(8.0),
_ModelSectionHeader(
title: LocaleKeys.chat_switchModel_cloudModel.tr(),
),
const VSpace(4.0),
],
...cloudModels.map(
(model) => _ModelItem(
model: model,
isSelected: model == selectedModel,
onTap: () => onSelectModel?.call(model),
),
),
],
),
),
);
}
@@ -215,45 +217,41 @@ class _CurrentModelButton extends StatelessWidget {
behavior: HitTestBehavior.opaque,
child: SizedBox(
height: DesktopAIPromptSizes.actionBarButtonSize,
child: AnimatedSize(
duration: const Duration(milliseconds: 50),
curve: Curves.easeInOut,
alignment: AlignmentDirectional.centerStart,
child: FlowyHover(
style: const HoverStyle(
borderRadius: BorderRadius.all(Radius.circular(8)),
),
child: Padding(
padding: const EdgeInsetsDirectional.all(4.0),
child: Row(
children: [
Padding(
// TODO: remove this after change icon to 20px
padding: EdgeInsets.all(2),
child: FlowySvg(
FlowySvgs.ai_sparks_s,
color: Theme.of(context).hintColor,
size: Size.square(16),
),
),
if (model != null && !model!.isDefault)
Padding(
padding: EdgeInsetsDirectional.only(end: 2.0),
child: FlowyText(
model!.i18n,
fontSize: 12,
figmaLineHeight: 16,
color: Theme.of(context).hintColor,
overflow: TextOverflow.ellipsis,
),
),
FlowySvg(
FlowySvgs.ai_source_drop_down_s,
child: FlowyHover(
style: const HoverStyle(
borderRadius: BorderRadius.all(Radius.circular(8)),
),
child: Padding(
padding: const EdgeInsetsDirectional.all(4.0),
child: Row(
mainAxisSize: MainAxisSize.min,
children: [
Padding(
// TODO: remove this after change icon to 20px
padding: EdgeInsets.all(2),
child: FlowySvg(
FlowySvgs.ai_sparks_s,
color: Theme.of(context).hintColor,
size: const Size.square(8),
size: Size.square(16),
),
],
),
),
if (model != null && !model!.isDefault)
Padding(
padding: EdgeInsetsDirectional.only(end: 2.0),
child: FlowyText(
model!.i18n,
fontSize: 12,
figmaLineHeight: 16,
color: Theme.of(context).hintColor,
overflow: TextOverflow.ellipsis,
),
),
FlowySvg(
FlowySvgs.ai_source_drop_down_s,
color: Theme.of(context).hintColor,
size: const Size.square(8),
),
],
),
),
),
@@ -448,7 +448,7 @@ class _ChangeModelButtonState extends State<ChangeModelButton> {
child: buildButton(context),
popupBuilder: (_) {
final bloc = context.read<AIPromptInputBloc>();
final (models, _) = bloc.aiModelStateNotifier.getAvailableModels();
final (models, _) = bloc.aiModelStateNotifier.getModelSelection();
return SelectModelPopoverContent(
models: models,
selectedModel: null,

@@ -407,7 +407,7 @@ class ChatAIMessagePopup extends StatelessWidget {
return MobileQuickActionButton(
onTap: () async {
final bloc = context.read<AIPromptInputBloc>();
final (models, _) = bloc.aiModelStateNotifier.getAvailableModels();
final (models, _) = bloc.aiModelStateNotifier.getModelSelection();
final result = await showChangeModelBottomSheet(context, models);
if (result != null) {
onChangeModel?.call(result);
@@ -6,174 +6,181 @@ import 'package:appflowy_backend/protobuf/flowy-ai/entities.pb.dart';
import 'package:appflowy_result/appflowy_result.dart';
import 'package:bloc/bloc.dart';
import 'package:collection/collection.dart';
import 'package:freezed_annotation/freezed_annotation.dart';
import 'package:equatable/equatable.dart';
import 'package:freezed_annotation/freezed_annotation.dart';

part 'ollama_setting_bloc.freezed.dart';

const kDefaultChatModel = 'llama3.1:latest';
const kDefaultEmbeddingModel = 'nomic-embed-text:latest';

/// Extension methods to map between PB and UI models
class OllamaSettingBloc extends Bloc<OllamaSettingEvent, OllamaSettingState> {
OllamaSettingBloc() : super(const OllamaSettingState()) {
on<OllamaSettingEvent>(_handleEvent);
on<_Started>(_handleStarted);
on<_DidLoadLocalModels>(_onLoadLocalModels);
on<_DidLoadSetting>(_onLoadSetting);
on<_UpdateSetting>(_onLoadSetting);
on<_OnEdit>(_onEdit);
on<_OnSubmit>(_onSubmit);
on<_SetDefaultModel>(_onSetDefaultModel);
}

Future<void> _handleEvent(
OllamaSettingEvent event,
Future<void> _handleStarted(
_Started event,
Emitter<OllamaSettingState> emit,
) async {
event.when(
started: () {
AIEventGetLocalAISetting().send().fold(
(setting) {
if (!isClosed) {
add(OllamaSettingEvent.didLoadSetting(setting));
}
},
Log.error,
);
},
didLoadSetting: (setting) => _updateSetting(setting, emit),
updateSetting: (setting) => _updateSetting(setting, emit),
onEdit: (content, settingType) {
final updatedSubmittedItems = state.submittedItems
.map(
(item) => item.settingType == settingType
? SubmittedItem(
content: content,
settingType: item.settingType,
)
: item,
)
.toList();
try {
final results = await Future.wait([
AIEventGetLocalModelSelection().send().then((r) => r.getOrThrow()),
AIEventGetLocalAISetting().send().then((r) => r.getOrThrow()),
]);

// Convert both lists to maps: {settingType: content}
final updatedMap = {
for (final item in updatedSubmittedItems)
item.settingType: item.content,
};
final models = results[0] as ModelSelectionPB;
final setting = results[1] as LocalAISettingPB;

final inputMap = {
for (final item in state.inputItems) item.settingType: item.content,
};

// Compare maps instead of lists
final isEdited = !const MapEquality<SettingType, String>()
.equals(updatedMap, inputMap);

emit(
state.copyWith(
submittedItems: updatedSubmittedItems,
isEdited: isEdited,
),
);
},
submit: () {
final setting = LocalAISettingPB();
final settingUpdaters = <SettingType, void Function(String)>{
SettingType.serverUrl: (value) => setting.serverUrl = value,
SettingType.chatModel: (value) => setting.chatModelName = value,
SettingType.embeddingModel: (value) =>
setting.embeddingModelName = value,
};

for (final item in state.submittedItems) {
settingUpdaters[item.settingType]?.call(item.content);
}
add(OllamaSettingEvent.updateSetting(setting));
AIEventUpdateLocalAISetting(setting).send().fold(
(_) => Log.info('AI setting updated successfully'),
(err) => Log.error("update ai setting failed: $err"),
);
},
);
if (!isClosed) {
add(OllamaSettingEvent.didLoadLocalModels(models));
add(OllamaSettingEvent.didLoadSetting(setting));
}
} catch (e, st) {
Log.error('Failed to load initial AI data: $e\n$st');
}
}

void _updateSetting(
LocalAISettingPB setting,
void _onLoadLocalModels(
_DidLoadLocalModels event,
Emitter<OllamaSettingState> emit,
) {
emit(state.copyWith(localModels: event.models));
}

void _onLoadSetting(
dynamic event,
Emitter<OllamaSettingState> emit,
) {
final setting = (event as dynamic).setting as LocalAISettingPB;
final submitted = setting.toSubmittedItems();
emit(
state.copyWith(
setting: setting,
inputItems: _createInputItems(setting),
submittedItems: _createSubmittedItems(setting),
isEdited: false, // Reset to false when the setting is loaded/updated.
inputItems: setting.toInputItems(),
submittedItems: submitted,
originalMap: {
for (final item in submitted) item.settingType: item.content,
},
isEdited: false,
),
);
}

List<SettingItem> _createInputItems(LocalAISettingPB setting) => [
SettingItem(
content: setting.serverUrl,
hintText: 'http://localhost:11434',
settingType: SettingType.serverUrl,
),
SettingItem(
content: setting.chatModelName,
hintText: 'llama3.1',
settingType: SettingType.chatModel,
),
SettingItem(
content: setting.embeddingModelName,
hintText: 'nomic-embed-text',
settingType: SettingType.embeddingModel,
),
];
void _onEdit(
_OnEdit event,
Emitter<OllamaSettingState> emit,
) {
final updated = state.submittedItems
.map(
(item) => item.settingType == event.settingType
? item.copyWith(content: event.content)
: item,
)
.toList();

List<SubmittedItem> _createSubmittedItems(LocalAISettingPB setting) => [
SubmittedItem(
content: setting.serverUrl,
settingType: SettingType.serverUrl,
),
SubmittedItem(
content: setting.chatModelName,
settingType: SettingType.chatModel,
),
SubmittedItem(
content: setting.embeddingModelName,
settingType: SettingType.embeddingModel,
),
];
final currentMap = {for (final i in updated) i.settingType: i.content};
final isEdited = !const MapEquality<SettingType, String>()
.equals(state.originalMap, currentMap);

emit(state.copyWith(submittedItems: updated, isEdited: isEdited));
}

void _onSubmit(
_OnSubmit event,
Emitter<OllamaSettingState> emit,
) {
final pb = LocalAISettingPB();
for (final item in state.submittedItems) {
switch (item.settingType) {
case SettingType.serverUrl:
pb.serverUrl = item.content;
break;
case SettingType.chatModel:
pb.globalChatModel = state.selectedModel?.name ?? item.content;
break;
case SettingType.embeddingModel:
pb.embeddingModelName = item.content;
break;
}
}
add(OllamaSettingEvent.updateSetting(pb));
AIEventUpdateLocalAISetting(pb).send().fold(
(_) => Log.info('AI setting updated successfully'),
(err) => Log.error('Update AI setting failed: $err'),
);
}

void _onSetDefaultModel(
_SetDefaultModel event,
Emitter<OllamaSettingState> emit,
) {
emit(state.copyWith(selectedModel: event.model, isEdited: true));
}
}

// Create an enum for setting type.
/// Setting types for mapping
enum SettingType {
serverUrl,
chatModel,
embeddingModel; // semicolon needed after the enum values
embeddingModel;

String get title {
switch (this) {
case SettingType.serverUrl:
return 'Ollama server url';
case SettingType.chatModel:
return 'Chat model name';
return 'Default model name';
case SettingType.embeddingModel:
return 'Embedding model name';
}
}
}

/// Input field representation
class SettingItem extends Equatable {
const SettingItem({
required this.content,
required this.hintText,
required this.settingType,
});

final String content;
final String hintText;
final SettingType settingType;

@override
List<Object?> get props => [content, settingType];
}

/// Items pending submission
class SubmittedItem extends Equatable {
const SubmittedItem({
required this.content,
required this.settingType,
});

final String content;
final SettingType settingType;

/// Returns a copy of this SubmittedItem with given fields updated.
SubmittedItem copyWith({
String? content,
SettingType? settingType,
}) {
return SubmittedItem(
content: content ?? this.content,
settingType: settingType ?? this.settingType,
);
}

@override
List<Object?> get props => [content, settingType];
}
@@ -181,10 +188,18 @@ class SubmittedItem extends Equatable {
@freezed
class OllamaSettingEvent with _$OllamaSettingEvent {
const factory OllamaSettingEvent.started() = _Started;
const factory OllamaSettingEvent.didLoadSetting(LocalAISettingPB setting) =
_DidLoadSetting;
const factory OllamaSettingEvent.updateSetting(LocalAISettingPB setting) =
_UpdateSetting;
const factory OllamaSettingEvent.didLoadLocalModels(
ModelSelectionPB models,
) = _DidLoadLocalModels;
const factory OllamaSettingEvent.didLoadSetting(
LocalAISettingPB setting,
) = _DidLoadSetting;
const factory OllamaSettingEvent.updateSetting(
LocalAISettingPB setting,
) = _UpdateSetting;
const factory OllamaSettingEvent.setDefaultModel(
AIModelPB model,
) = _SetDefaultModel;
const factory OllamaSettingEvent.onEdit(
String content,
SettingType settingType,
@@ -196,25 +211,42 @@ class OllamaSettingEvent with _$OllamaSettingEvent {
class OllamaSettingState with _$OllamaSettingState {
const factory OllamaSettingState({
LocalAISettingPB? setting,
@Default([
SettingItem(
content: 'http://localhost:11434',
hintText: 'http://localhost:11434',
settingType: SettingType.serverUrl,
),
SettingItem(
content: 'llama3.1',
hintText: 'llama3.1',
settingType: SettingType.chatModel,
),
SettingItem(
content: 'nomic-embed-text',
hintText: 'nomic-embed-text',
settingType: SettingType.embeddingModel,
),
])
List<SettingItem> inputItems,
@Default([]) List<SettingItem> inputItems,
AIModelPB? selectedModel,
ModelSelectionPB? localModels,
AIModelPB? defaultModel,
@Default([]) List<SubmittedItem> submittedItems,
@Default(false) bool isEdited,
}) = _PluginStateState;
@Default({}) Map<SettingType, String> originalMap,
}) = _OllamaSettingState;
}

extension on LocalAISettingPB {
List<SettingItem> toInputItems() => [
SettingItem(
content: serverUrl,
hintText: 'http://localhost:11434',
settingType: SettingType.serverUrl,
),
SettingItem(
content: embeddingModelName,
hintText: kDefaultEmbeddingModel,
settingType: SettingType.embeddingModel,
),
];

List<SubmittedItem> toSubmittedItems() => [
SubmittedItem(
content: serverUrl,
settingType: SettingType.serverUrl,
),
SubmittedItem(
content: globalChatModel,
settingType: SettingType.chatModel,
),
SubmittedItem(
content: embeddingModelName,
settingType: SettingType.embeddingModel,
),
];
}
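
The bloc above decides whether the form is dirty by comparing two maps with MapEquality from package:collection rather than comparing item lists. A small self-contained sketch of that comparison (the isEdited helper and the enum re-declaration are only here to make the snippet runnable on its own):

import 'package:collection/collection.dart';

enum SettingType { serverUrl, chatModel, embeddingModel }

bool isEdited(
  Map<SettingType, String> originalMap,
  Map<SettingType, String> currentMap,
) {
  // MapEquality compares keys and values regardless of insertion order,
  // so reordering the submitted items never marks the form as edited.
  return !const MapEquality<SettingType, String>()
      .equals(originalMap, currentMap);
}

void main() {
  final original = {SettingType.serverUrl: 'http://localhost:11434'};
  final changed = {SettingType.serverUrl: 'http://192.168.1.10:11434'};
  print(isEdited(original, original)); // false
  print(isEdited(original, changed)); // true
}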
@@ -93,7 +93,7 @@ class SettingsAIBloc extends Bloc<SettingsAIEvent, SettingsAIState> {
),
);
},
didLoadAvailableModels: (AvailableModelsPB models) {
didLoadAvailableModels: (ModelSelectionPB models) {
emit(
state.copyWith(
availableModels: models,
@@ -134,7 +134,8 @@ class SettingsAIBloc extends Bloc<SettingsAIEvent, SettingsAIState> {
);

void _loadModelList() {
AIEventGetServerAvailableModels().send().then((result) {
final payload = ModelSourcePB(source: aiModelsGlobalActiveModel);
AIEventGetSettingModelSelection(payload).send().then((result) {
result.fold((models) {
if (!isClosed) {
add(SettingsAIEvent.didLoadAvailableModels(models));
@@ -175,7 +176,7 @@ class SettingsAIEvent with _$SettingsAIEvent {
) = _DidReceiveUserProfile;

const factory SettingsAIEvent.didLoadAvailableModels(
AvailableModelsPB models,
ModelSelectionPB models,
) = _DidLoadAvailableModels;
}

@@ -184,7 +185,7 @@ class SettingsAIState with _$SettingsAIState {
const factory SettingsAIState({
required UserProfilePB userProfile,
WorkspaceSettingsPB? aiSettings,
AvailableModelsPB? availableModels,
ModelSelectionPB? availableModels,
@Default(true) bool enableSearchIndexing,
}) = _SettingsAIState;
}
@@ -31,7 +31,10 @@ HotKeyItem openSettingsHotKey(
),
keyDownHandler: (_) {
if (_settingsDialogKey.currentContext == null) {
showSettingsDialog(context);
showSettingsDialog(
context,
userWorkspaceBloc: context.read<UserWorkspaceBloc>(),
);
} else {
Navigator.of(context, rootNavigator: true)
.popUntil((route) => route.isFirst);
@@ -110,7 +113,7 @@ class _UserSettingButtonState extends State<UserSettingButton> {

void showSettingsDialog(
BuildContext context, {
UserWorkspaceBloc? userWorkspaceBloc,
required UserWorkspaceBloc userWorkspaceBloc,
PasswordBloc? passwordBloc,
SettingsPage? initPage,
}) {
@@ -134,7 +137,7 @@ void showSettingsDialog(
value: BlocProvider.of<DocumentAppearanceCubit>(dialogContext),
),
BlocProvider.value(
value: userWorkspaceBloc ?? context.read<UserWorkspaceBloc>(),
value: userWorkspaceBloc,
),
],
child: SettingsDialog(
@@ -7,6 +7,14 @@ import 'package:flowy_infra_ui/widget/spacing.dart';
import 'package:flutter/material.dart';
import 'package:flutter_bloc/flutter_bloc.dart';

import 'package:appflowy/ai/ai.dart';
import 'package:appflowy_backend/protobuf/flowy-ai/entities.pb.dart';

import 'package:appflowy/generated/locale_keys.g.dart';
import 'package:appflowy/workspace/presentation/settings/shared/af_dropdown_menu_entry.dart';
import 'package:appflowy/workspace/presentation/settings/shared/settings_dropdown.dart';
import 'package:easy_localization/easy_localization.dart';

class OllamaSettingPage extends StatelessWidget {
const OllamaSettingPage({super.key});

@@ -32,6 +40,7 @@ class OllamaSettingPage extends StatelessWidget {
children: [
for (final item in state.inputItems)
_SettingItemWidget(item: item),
const LocalAIModelSelection(),
_SaveButton(isEdited: state.isEdited),
],
),
@@ -113,3 +122,59 @@ class _SaveButton extends StatelessWidget {
);
}
}

class LocalAIModelSelection extends StatelessWidget {
const LocalAIModelSelection({super.key});
static const double height = 49;

@override
Widget build(BuildContext context) {
return BlocBuilder<OllamaSettingBloc, OllamaSettingState>(
buildWhen: (previous, current) =>
previous.localModels != current.localModels,
builder: (context, state) {
final models = state.localModels;
if (models == null) {
return const SizedBox(
// Using same height as SettingsDropdown to avoid layout shift
height: height,
);
}

return Column(
crossAxisAlignment: CrossAxisAlignment.start,
children: [
FlowyText.medium(
LocaleKeys.settings_aiPage_keys_globalLLMModel.tr(),
fontSize: 12,
figmaLineHeight: 16,
),
const VSpace(4),
SizedBox(
height: 40,
child: SettingsDropdown<AIModelPB>(
key: const Key('_AIModelSelection'),
onChanged: (model) => context
.read<OllamaSettingBloc>()
.add(OllamaSettingEvent.setDefaultModel(model)),
selectedOption: models.selectedModel,
selectOptionCompare: (left, right) => left?.name == right?.name,
options: models.models
.map(
(model) => buildDropdownMenuEntry<AIModelPB>(
context,
value: model,
label: model.i18n,
subLabel: model.desc,
maximumHeight: height,
),
)
.toList(),
),
),
],
);
},
);
}
}
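
A minimal sketch of how the settings page above might be mounted, assuming the bloc is provided directly above OllamaSettingPage. The wrapping widget is hypothetical; OllamaSettingBloc, OllamaSettingEvent.started and OllamaSettingPage come from this diff, and their import paths are omitted because they depend on the app layout.

import 'package:flutter/material.dart';
import 'package:flutter_bloc/flutter_bloc.dart';

class LocalAISettingsSection extends StatelessWidget {
  const LocalAISettingsSection({super.key});

  @override
  Widget build(BuildContext context) {
    return BlocProvider(
      // `started` loads the local model list and the stored Ollama setting
      // in parallel, as implemented in OllamaSettingBloc above.
      create: (_) =>
          OllamaSettingBloc()..add(const OllamaSettingEvent.started()),
      child: const OllamaSettingPage(),
    );
  }
}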
@@ -866,6 +866,7 @@
"aiSettingsDescription": "Choose your preferred model to power AppFlowy AI. Now includes GPT-4o, GPT-o3-mini, DeepSeek R1, Claude 3.5 Sonnet, and models available in Ollama",
"loginToEnableAIFeature": "AI features are only enabled after logging in with @:appName Cloud. If you don't have an @:appName account, go to 'My Account' to sign up",
"llmModel": "Language Model",
"globalLLMModel": "Global Language Model",
"llmModelType": "Language Model Type",
"downloadLLMPrompt": "Download {}",
"downloadAppFlowyOfflineAI": "Downloading AI offline package will enable AI to run on your device. Do you want to continue?",
frontend/rust-lib/Cargo.lock (generated, 98 lines changed)
@ -2210,12 +2210,6 @@ dependencies = [
|
||||
"litrs",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "dotenv"
|
||||
version = "0.15.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "77c90badedccf4105eca100756a0b1289e191f6fcbdadd3cee1d2f614f97da8f"
|
||||
|
||||
[[package]]
|
||||
name = "downcast-rs"
|
||||
version = "2.0.1"
|
||||
@ -2506,7 +2500,6 @@ dependencies = [
|
||||
"bytes",
|
||||
"collab-integrate",
|
||||
"dashmap 6.0.1",
|
||||
"dotenv",
|
||||
"flowy-ai-pub",
|
||||
"flowy-codegen",
|
||||
"flowy-derive",
|
||||
@ -2520,19 +2513,18 @@ dependencies = [
|
||||
"lib-infra",
|
||||
"log",
|
||||
"notify",
|
||||
"ollama-rs",
|
||||
"pin-project",
|
||||
"protobuf",
|
||||
"reqwest 0.11.27",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"sha2",
|
||||
"simsimd",
|
||||
"strum_macros 0.21.1",
|
||||
"tokio",
|
||||
"tokio-stream",
|
||||
"tokio-util",
|
||||
"tracing",
|
||||
"tracing-subscriber",
|
||||
"uuid",
|
||||
"validator 0.18.1",
|
||||
]
|
||||
@ -2798,6 +2790,7 @@ dependencies = [
|
||||
"flowy-derive",
|
||||
"flowy-sqlite",
|
||||
"lib-dispatch",
|
||||
"ollama-rs",
|
||||
"protobuf",
|
||||
"r2d2",
|
||||
"reqwest 0.11.27",
|
||||
@ -4044,6 +4037,7 @@ checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
"hashbrown 0.12.3",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -4894,6 +4888,23 @@ dependencies = [
|
||||
"memchr",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ollama-rs"
|
||||
version = "0.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7a4b4750770584c8b4a643d0329e7bedacc4ecf68b7c7ac3e1fec2bafd6312f7"
|
||||
dependencies = [
|
||||
"async-stream",
|
||||
"log",
|
||||
"reqwest 0.12.15",
|
||||
"schemars",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"static_assertions",
|
||||
"thiserror 2.0.12",
|
||||
"url",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "once_cell"
|
||||
version = "1.19.0"
|
||||
@ -5696,7 +5707,7 @@ dependencies = [
|
||||
"rustc-hash 2.1.0",
|
||||
"rustls 0.23.20",
|
||||
"socket2 0.5.5",
|
||||
"thiserror 2.0.9",
|
||||
"thiserror 2.0.12",
|
||||
"tokio",
|
||||
"tracing",
|
||||
]
|
||||
@ -5715,7 +5726,7 @@ dependencies = [
|
||||
"rustls 0.23.20",
|
||||
"rustls-pki-types",
|
||||
"slab",
|
||||
"thiserror 2.0.9",
|
||||
"thiserror 2.0.12",
|
||||
"tinyvec",
|
||||
"tracing",
|
||||
"web-time",
|
||||
@ -6407,6 +6418,31 @@ dependencies = [
|
||||
"parking_lot 0.12.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "schemars"
|
||||
version = "0.8.22"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3fbf2ae1b8bc8e02df939598064d22402220cd5bbcca1c76f7d6a310974d5615"
|
||||
dependencies = [
|
||||
"dyn-clone",
|
||||
"indexmap 1.9.3",
|
||||
"schemars_derive",
|
||||
"serde",
|
||||
"serde_json",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "schemars_derive"
|
||||
version = "0.8.22"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "32e265784ad618884abaea0600a9adf15393368d840e0222d101a072f3f7534d"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"serde_derive_internals 0.29.1",
|
||||
"syn 2.0.94",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "scopeguard"
|
||||
version = "1.2.0"
|
||||
@ -6554,6 +6590,17 @@ dependencies = [
|
||||
"syn 2.0.94",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_derive_internals"
|
||||
version = "0.29.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "18d26a20a969b9e3fdf2fc2d9f21eda6c40e2de84c9408bb5d3b05d499aae711"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.94",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_html_form"
|
||||
version = "0.2.7"
|
||||
@ -6746,15 +6793,6 @@ dependencies = [
|
||||
"time",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "simsimd"
|
||||
version = "4.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "efc843bc8f12d9c8e6b734a0fe8918fc497b42f6ae0f347dbfdad5b5138ab9b4"
|
||||
dependencies = [
|
||||
"cc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "siphasher"
|
||||
version = "0.3.11"
|
||||
@ -6841,6 +6879,12 @@ version = "1.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3"
|
||||
|
||||
[[package]]
|
||||
name = "static_assertions"
|
||||
version = "1.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f"
|
||||
|
||||
[[package]]
|
||||
name = "string_cache"
|
||||
version = "0.8.7"
|
||||
@ -7098,7 +7142,7 @@ dependencies = [
|
||||
"tantivy-stacker",
|
||||
"tantivy-tokenizer-api",
|
||||
"tempfile",
|
||||
"thiserror 2.0.9",
|
||||
"thiserror 2.0.12",
|
||||
"time",
|
||||
"uuid",
|
||||
"winapi",
|
||||
@ -7280,11 +7324,11 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "thiserror"
|
||||
version = "2.0.9"
|
||||
version = "2.0.12"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f072643fd0190df67a8bab670c20ef5d8737177d6ac6b2e9a236cb096206b2cc"
|
||||
checksum = "567b8a2dae586314f7be2a752ec7474332959c6460e02bde30d702a66d488708"
|
||||
dependencies = [
|
||||
"thiserror-impl 2.0.9",
|
||||
"thiserror-impl 2.0.12",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -7300,9 +7344,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "thiserror-impl"
|
||||
version = "2.0.9"
|
||||
version = "2.0.12"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7b50fa271071aae2e6ee85f842e2e28ba8cd2c5fb67f11fcb1fd70b276f9e7d4"
|
||||
checksum = "7f7cf42b4507d8ea322120659672cf1b9dbb93f8f2d4ecfd6e51350ff5b17a1d"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
@ -7823,7 +7867,7 @@ checksum = "7a94b0f0954b3e59bfc2c246b4c8574390d94a4ad4ad246aaf2fb07d7dfd3b47"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"serde_derive_internals",
|
||||
"serde_derive_internals 0.28.0",
|
||||
"syn 2.0.94",
|
||||
]
|
||||
|
||||
|
||||
@@ -0,0 +1,54 @@
use diesel::sqlite::SqliteConnection;
use flowy_error::FlowyResult;
use flowy_sqlite::upsert::excluded;
use flowy_sqlite::{
diesel,
query_dsl::*,
schema::{local_ai_model_table, local_ai_model_table::dsl},
ExpressionMethods, Identifiable, Insertable, Queryable,
};

#[derive(Clone, Default, Queryable, Insertable, Identifiable)]
#[diesel(table_name = local_ai_model_table)]
#[diesel(primary_key(name))]
pub struct LocalAIModelTable {
pub name: String,
pub model_type: i16,
}

#[derive(Clone, Debug, Copy)]
pub enum ModelType {
Embedding = 0,
Chat = 1,
}

impl From<i16> for ModelType {
fn from(value: i16) -> Self {
match value {
0 => ModelType::Embedding,
1 => ModelType::Chat,
_ => ModelType::Embedding,
}
}
}

pub fn select_local_ai_model(conn: &mut SqliteConnection, name: &str) -> Option<LocalAIModelTable> {
local_ai_model_table::table
.filter(dsl::name.eq(name))
.first::<LocalAIModelTable>(conn)
.ok()
}

pub fn upsert_local_ai_model(
conn: &mut SqliteConnection,
row: &LocalAIModelTable,
) -> FlowyResult<()> {
diesel::insert_into(local_ai_model_table::table)
.values(row)
.on_conflict(local_ai_model_table::name)
.do_update()
.set((local_ai_model_table::model_type.eq(excluded(local_ai_model_table::model_type)),))
.execute(conn)?;

Ok(())
}
@@ -1,5 +1,7 @@
mod chat_message_sql;
mod chat_sql;
mod local_model_sql;

pub use chat_message_sql::*;
pub use chat_sql::*;
pub use local_model_sql::*;
@@ -48,16 +48,16 @@ collab-integrate.workspace = true

[target.'cfg(any(target_os = "macos", target_os = "linux", target_os = "windows"))'.dependencies]
notify = "6.1.1"
ollama-rs = "0.3.0"
#faiss = { version = "0.12.1" }
af-mcp = { version = "0.1.0" }

[dev-dependencies]
dotenv = "0.15.0"
uuid.workspace = true
tracing-subscriber = { version = "0.3.17", features = ["registry", "env-filter", "ansi", "json"] }
simsimd = "4.4.0"

[build-dependencies]
flowy-codegen.workspace = true

[features]
dart = ["flowy-codegen/dart", "flowy-notification/dart"]
local_ai = []
@@ -1,7 +1,7 @@
use crate::chat::Chat;
use crate::entities::{
AIModelPB, AvailableModelsPB, ChatInfoPB, ChatMessageListPB, ChatMessagePB, ChatSettingsPB,
FilePB, PredefinedFormatPB, RepeatedRelatedQuestionPB, StreamMessageParams,
AIModelPB, ChatInfoPB, ChatMessageListPB, ChatMessagePB, ChatSettingsPB, FilePB,
ModelSelectionPB, PredefinedFormatPB, RepeatedRelatedQuestionPB, StreamMessageParams,
};
use crate::local_ai::controller::{LocalAIController, LocalAISetting};
use crate::middleware::chat_service_mw::ChatServiceMiddleware;
@@ -330,14 +330,10 @@ impl AIManager {
.get_question_id_from_answer_id(chat_id, answer_message_id)
.await?;

let model = model.map_or_else(
|| {
self
.store_preferences
.get_object::<AIModel>(&ai_available_models_key(&chat_id.to_string()))
},
|model| Some(model.into()),
);
let model = match model {
None => self.get_active_model(&chat_id.to_string()).await,
Some(model) => Some(model.into()),
};
chat
.stream_regenerate_response(question_message_id, answer_stream_port, format, model)
.await?;
@@ -345,18 +341,32 @@ impl AIManager {
}

pub async fn update_local_ai_setting(&self, setting: LocalAISetting) -> FlowyResult<()> {
let previous_model = self.local_ai.get_local_ai_setting().chat_model_name;
self.local_ai.update_local_ai_setting(setting).await?;
let current_model = self.local_ai.get_local_ai_setting().chat_model_name;
let old_settings = self.local_ai.get_local_ai_setting();
// Only restart if the server URL has changed and local AI is not running
let need_restart =
old_settings.ollama_server_url != setting.ollama_server_url && !self.local_ai.is_running();

if previous_model != current_model {
// Update settings first
self
.local_ai
.update_local_ai_setting(setting.clone())
.await?;

// Handle model change if needed
let model_changed = old_settings.chat_model_name != setting.chat_model_name;
if model_changed {
info!(
"[AI Plugin] update global active model, previous: {}, current: {}",
previous_model, current_model
old_settings.chat_model_name, setting.chat_model_name
);
let source_key = ai_available_models_key(GLOBAL_ACTIVE_MODEL_KEY);
let model = AIModel::local(current_model, "".to_string());
self.update_selected_model(source_key, model).await?;
let model = AIModel::local(setting.chat_model_name, "".to_string());
self
.update_selected_model(GLOBAL_ACTIVE_MODEL_KEY.to_string(), model)
.await?;
}

if need_restart {
self.local_ai.restart_plugin().await;
}

Ok(())
@@ -440,16 +450,16 @@ impl AIManager {
}

pub async fn update_selected_model(&self, source: String, model: AIModel) -> FlowyResult<()> {
info!(
"[Model Selection] update {} selected model: {:?}",
source, model
);
let source_key = ai_available_models_key(&source);
info!(
"[Model Selection] update {} selected model: {:?} for key:{}",
source, model, source_key
);
self
.store_preferences
.set_object::<AIModel>(&source_key, &model)?;

chat_notification_builder(&source, ChatNotification::DidUpdateSelectedModel)
chat_notification_builder(&source_key, ChatNotification::DidUpdateSelectedModel)
.payload(AIModelPB::from(model))
.send();
Ok(())
@@ -458,12 +468,13 @@ impl AIManager {
#[instrument(skip_all, level = "debug")]
pub async fn toggle_local_ai(&self) -> FlowyResult<()> {
let enabled = self.local_ai.toggle_local_ai().await?;
let source_key = ai_available_models_key(GLOBAL_ACTIVE_MODEL_KEY);
if enabled {
if let Some(name) = self.local_ai.get_plugin_chat_model() {
info!("Set global active model to local ai: {}", name);
let model = AIModel::local(name, "".to_string());
self.update_selected_model(source_key, model).await?;
self
.update_selected_model(GLOBAL_ACTIVE_MODEL_KEY.to_string(), model)
.await?;
}
} else {
info!("Set global active model to default");
@@ -471,7 +482,7 @@ impl AIManager {
let models = self.get_server_available_models().await?;
if let Some(model) = models.into_iter().find(|m| m.name == global_active_model) {
self
.update_selected_model(source_key, AIModel::from(model))
.update_selected_model(GLOBAL_ACTIVE_MODEL_KEY.to_string(), AIModel::from(model))
.await?;
}
}
@@ -484,119 +495,140 @@ impl AIManager {
.store_preferences
.get_object::<AIModel>(&ai_available_models_key(source));

if model.is_none() {
if let Some(local_model) = self.local_ai.get_plugin_chat_model() {
model = Some(AIModel::local(local_model, "".to_string()));
}
match model {
None => {
if let Some(local_model) = self.local_ai.get_plugin_chat_model() {
model = Some(AIModel::local(local_model, "".to_string()));
}
model
},
Some(mut model) => {
let models = self.local_ai.get_all_chat_local_models().await;
if !models.contains(&model) {
if let Some(local_model) = self.local_ai.get_plugin_chat_model() {
model = AIModel::local(local_model, "".to_string());
}
}
Some(model)
},
}

model
}

pub async fn get_available_models(&self, source: String) -> FlowyResult<AvailableModelsPB> {
pub async fn get_local_available_models(&self) -> FlowyResult<ModelSelectionPB> {
let setting = self.local_ai.get_local_ai_setting();
let mut models = self.local_ai.get_all_chat_local_models().await;
let selected_model = AIModel::local(setting.chat_model_name, "".to_string());

if models.is_empty() {
models.push(selected_model.clone());
}

Ok(ModelSelectionPB {
models: models.into_iter().map(AIModelPB::from).collect(),
selected_model: AIModelPB::from(selected_model),
})
}

pub async fn get_available_models(
&self,
source: String,
setting_only: bool,
) -> FlowyResult<ModelSelectionPB> {
let is_local_mode = self.user_service.is_local_model().await?;
if is_local_mode {
let setting = self.local_ai.get_local_ai_setting();
let selected_model = AIModel::local(setting.chat_model_name, "".to_string());
let models = vec![selected_model.clone()];

Ok(AvailableModelsPB {
models: models.into_iter().map(|m| m.into()).collect(),
selected_model: AIModelPB::from(selected_model),
})
} else {
// Build the models list from server models and mark them as non-local.
let mut models: Vec<AIModel> = self
.get_server_available_models()
.await?
.into_iter()
.map(AIModel::from)
.collect();

trace!("[Model Selection]: Available models: {:?}", models);
let mut current_active_local_ai_model = None;

// If user enable local ai, then add local ai model to the list.
if let Some(local_model) = self.local_ai.get_plugin_chat_model() {
let model = AIModel::local(local_model, "".to_string());
current_active_local_ai_model = Some(model.clone());
trace!("[Model Selection] current local ai model: {}", model.name);
models.push(model);
}

if models.is_empty() {
return Ok(AvailableModelsPB {
models: models.into_iter().map(|m| m.into()).collect(),
selected_model: AIModelPB::default(),
});
}

// Global active model is the model selected by the user in the workspace settings.
let mut server_active_model = self
.get_workspace_select_model()
.await
.map(|m| AIModel::server(m, "".to_string()))
.unwrap_or_else(|_| AIModel::default());

trace!(
"[Model Selection] server active model: {:?}",
server_active_model
);

let mut user_selected_model = server_active_model.clone();
// when current select model is deprecated, reset the model to default
if !models.iter().any(|m| m.name == server_active_model.name) {
server_active_model = AIModel::default();
}

let source_key = ai_available_models_key(&source);
// We use source to identify user selected model. source can be document id or chat id.
match self.store_preferences.get_object::<AIModel>(&source_key) {
None => {
// when there is selected model and current local ai is active, then use local ai
if let Some(local_ai_model) = models.iter().find(|m| m.is_local) {
user_selected_model = local_ai_model.clone();
}
},
Some(mut model) => {
trace!("[Model Selection] user previous select model: {:?}", model);
// If source is provided, try to get the user-selected model from the store. User selected
// model will be used as the active model if it exists.
if model.is_local {
if let Some(local_ai_model) = &current_active_local_ai_model {
if local_ai_model.name != model.name {
model = local_ai_model.clone();
}
}
}

user_selected_model = model;
},
}

// If user selected model is not available in the list, use the global active model.
let active_model = models
.iter()
.find(|m| m.name == user_selected_model.name)
.cloned()
.or(Some(server_active_model.clone()));

// Update the stored preference if a different model is used.
if let Some(ref active_model) = active_model {
if active_model.name != user_selected_model.name {
self
.store_preferences
.set_object::<AIModel>(&source_key, &active_model.clone())?;
}
}

trace!("[Model Selection] final active model: {:?}", active_model);
let selected_model = AIModelPB::from(active_model.unwrap_or_default());
Ok(AvailableModelsPB {
models: models.into_iter().map(|m| m.into()).collect(),
selected_model,
})
return self.get_local_available_models().await;
}

// Fetch server models
let mut all_models: Vec<AIModel> = self
.get_server_available_models()
.await?
.into_iter()
.map(AIModel::from)
.collect();

trace!("[Model Selection]: Available models: {:?}", all_models);

// Add local models if enabled
if self.local_ai.is_enabled() {
if setting_only {
let setting = self.local_ai.get_local_ai_setting();
all_models.push(AIModel::local(setting.chat_model_name, "".to_string()));
} else {
all_models.extend(self.local_ai.get_all_chat_local_models().await);
}
}

// Return early if no models available
if all_models.is_empty() {
return Ok(ModelSelectionPB {
models: Vec::new(),
selected_model: AIModelPB::default(),
});
}

// Get server active model (only once)
let server_active_model = self
.get_workspace_select_model()
.await
.map(|m| AIModel::server(m, "".to_string()))
.unwrap_or_else(|_| AIModel::default());

trace!(
"[Model Selection] server active model: {:?}",
server_active_model
);

// Use server model as default if it exists in available models
let default_model = if all_models
.iter()
.any(|m| m.name == server_active_model.name)
{
server_active_model.clone()
} else {
AIModel::default()
};

// Get user's previously selected model
let user_selected_model = match self.get_active_model(&source).await {
Some(model) => {
trace!("[Model Selection] user previous select model: {:?}", model);
model
},
None => {
// When no selected model and local AI is active, use local AI model
all_models
.iter()
.find(|m| m.is_local)
.cloned()
.unwrap_or_else(|| default_model.clone())
},
};

// Determine final active model - use user's selection if available, otherwise default
let active_model = all_models
.iter()
.find(|m| m.name == user_selected_model.name)
.cloned()
.unwrap_or(default_model.clone());

// Update stored preference if changed
if active_model.name != user_selected_model.name {
if let Err(err) = self
.update_selected_model(source, active_model.clone())
.await
{
error!("[Model Selection] failed to update selected model: {}", err);
}
}

trace!("[Model Selection] final active model: {:?}", active_model);

// Create response with one transformation pass
Ok(ModelSelectionPB {
models: all_models.into_iter().map(AIModelPB::from).collect(),
selected_model: AIModelPB::from(active_model),
})
}

pub async fn get_or_create_chat_instance(&self, chat_id: &Uuid) -> Result<Arc<Chat>, FlowyError> {
@ -182,7 +182,7 @@ pub struct ChatMessageListPB {
}

#[derive(Default, ProtoBuf, Clone, Debug)]
pub struct ServerAvailableModelsPB {
pub struct ServerModelSelectionPB {
  #[pb(index = 1)]
  pub models: Vec<AvailableModelPB>,
}
@ -200,7 +200,7 @@ pub struct AvailableModelPB {
}

#[derive(Default, ProtoBuf, Validate, Clone, Debug)]
pub struct AvailableModelsQueryPB {
pub struct ModelSourcePB {
  #[pb(index = 1)]
  #[validate(custom(function = "required_not_empty_str"))]
  pub source: String,
@ -217,7 +217,7 @@ pub struct UpdateSelectedModelPB {
}

#[derive(Default, ProtoBuf, Clone, Debug)]
pub struct AvailableModelsPB {
pub struct ModelSelectionPB {
  #[pb(index = 1)]
  pub models: Vec<AIModelPB>,

@ -225,6 +225,12 @@ pub struct AvailableModelsPB {
  pub selected_model: AIModelPB,
}

#[derive(Default, ProtoBuf, Clone, Debug)]
pub struct RepeatedAIModelPB {
  #[pb(index = 1)]
  pub items: Vec<AIModelPB>,
}

#[derive(Default, ProtoBuf, Clone, Debug)]
pub struct AIModelPB {
  #[pb(index = 1)]
@ -686,7 +692,7 @@ pub struct LocalAISettingPB {

  #[pb(index = 2)]
  #[validate(custom(function = "required_not_empty_str"))]
  pub chat_model_name: String,
  pub global_chat_model: String,

  #[pb(index = 3)]
  #[validate(custom(function = "required_not_empty_str"))]
@ -697,7 +703,7 @@ impl From<LocalAISetting> for LocalAISettingPB {
  fn from(value: LocalAISetting) -> Self {
    LocalAISettingPB {
      server_url: value.ollama_server_url,
      chat_model_name: value.chat_model_name,
      global_chat_model: value.chat_model_name,
      embedding_model_name: value.embedding_model_name,
    }
  }
@ -707,7 +713,7 @@ impl From<LocalAISettingPB> for LocalAISetting {
  fn from(value: LocalAISettingPB) -> Self {
    LocalAISetting {
      ollama_server_url: value.server_url,
      chat_model_name: value.chat_model_name,
      chat_model_name: value.global_chat_model,
      embedding_model_name: value.embedding_model_name,
    }
  }

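The two From impls above pin down the rename: the persisted LocalAISetting keeps chat_model_name while the protobuf type now exposes it as global_chat_model. A quick round-trip check, assuming the fields shown in this diff are the complete set (a sketch, not shipped test code):

let setting = LocalAISetting {
  ollama_server_url: "http://localhost:11434".to_string(),
  chat_model_name: "llama3.1:latest".to_string(),
  embedding_model_name: "nomic-embed-text:latest".to_string(),
};
let pb = LocalAISettingPB::from(setting);
assert_eq!(pb.global_chat_model, "llama3.1:latest");
let restored = LocalAISetting::from(pb);
assert_eq!(restored.chat_model_name, "llama3.1:latest");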
@ -1,7 +1,6 @@
use crate::ai_manager::{AIManager, GLOBAL_ACTIVE_MODEL_KEY};
use crate::ai_manager::AIManager;
use crate::completion::AICompletion;
use crate::entities::*;
use crate::util::ai_available_models_key;
use flowy_ai_pub::cloud::{AIModel, ChatMessageType};
use flowy_error::{ErrorCode, FlowyError, FlowyResult};
use lib_dispatch::prelude::{data_result_ok, AFPluginData, AFPluginState, DataResult};
@ -78,23 +77,24 @@ pub(crate) async fn regenerate_response_handler(
}

#[tracing::instrument(level = "debug", skip_all, err)]
pub(crate) async fn get_server_model_list_handler(
pub(crate) async fn get_setting_model_selection_handler(
  data: AFPluginData<ModelSourcePB>,
  ai_manager: AFPluginState<Weak<AIManager>>,
) -> DataResult<AvailableModelsPB, FlowyError> {
) -> DataResult<ModelSelectionPB, FlowyError> {
  let data = data.try_into_inner()?;
  let ai_manager = upgrade_ai_manager(ai_manager)?;
  let source_key = ai_available_models_key(GLOBAL_ACTIVE_MODEL_KEY);
  let models = ai_manager.get_available_models(source_key).await?;
  let models = ai_manager.get_available_models(data.source, true).await?;
  data_result_ok(models)
}

#[tracing::instrument(level = "debug", skip_all, err)]
pub(crate) async fn get_chat_models_handler(
  data: AFPluginData<AvailableModelsQueryPB>,
pub(crate) async fn get_source_model_selection_handler(
  data: AFPluginData<ModelSourcePB>,
  ai_manager: AFPluginState<Weak<AIManager>>,
) -> DataResult<AvailableModelsPB, FlowyError> {
) -> DataResult<ModelSelectionPB, FlowyError> {
  let data = data.try_into_inner()?;
  let ai_manager = upgrade_ai_manager(ai_manager)?;
  let models = ai_manager.get_available_models(data.source).await?;
  let models = ai_manager.get_available_models(data.source, false).await?;
  data_result_ok(models)
}

@ -340,6 +340,15 @@ pub(crate) async fn get_local_ai_setting_handler(
  data_result_ok(pb)
}

#[tracing::instrument(level = "debug", skip_all)]
pub(crate) async fn get_local_ai_models_handler(
  ai_manager: AFPluginState<Weak<AIManager>>,
) -> DataResult<ModelSelectionPB, FlowyError> {
  let ai_manager = upgrade_ai_manager(ai_manager)?;
  let data = ai_manager.get_local_available_models().await?;
  data_result_ok(data)
}

#[tracing::instrument(level = "debug", skip_all, err)]
pub(crate) async fn update_local_ai_setting_handler(
  ai_manager: AFPluginState<Weak<AIManager>>,

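Both selection handlers now take the same ModelSourcePB input; the only difference is the boolean passed to get_available_models (true in the settings handler, false in the per-source handler), which, judging by the handler names, appears to distinguish the global settings view from a per-source query. Building the payload is straightforward, assuming source is its only field as the hunk suggests (value here is illustrative):

let payload = ModelSourcePB {
  source: "chat_object_id".to_string(), // caller-defined source key; illustrative value
};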
@ -31,20 +31,24 @@ pub fn init(ai_manager: Weak<AIManager>) -> AFPlugin {
    .event(AIEvent::ToggleLocalAI, toggle_local_ai_handler)
    .event(AIEvent::GetLocalAIState, get_local_ai_state_handler)
    .event(AIEvent::GetLocalAISetting, get_local_ai_setting_handler)
    .event(AIEvent::GetLocalModelSelection, get_local_ai_models_handler)
    .event(
      AIEvent::GetSourceModelSelection,
      get_source_model_selection_handler,
    )
    .event(
      AIEvent::UpdateLocalAISetting,
      update_local_ai_setting_handler,
    )
    .event(
      AIEvent::GetServerAvailableModels,
      get_server_model_list_handler,
    )
    .event(AIEvent::CreateChatContext, create_chat_context_handler)
    .event(AIEvent::GetChatInfo, create_chat_context_handler)
    .event(AIEvent::GetChatSettings, get_chat_settings_handler)
    .event(AIEvent::UpdateChatSettings, update_chat_settings_handler)
    .event(AIEvent::RegenerateResponse, regenerate_response_handler)
    .event(AIEvent::GetAvailableModels, get_chat_models_handler)
    .event(
      AIEvent::GetSettingModelSelection,
      get_setting_model_selection_handler,
    )
    .event(AIEvent::UpdateSelectedModel, update_selected_model_handler)
}

@ -107,18 +111,21 @@ pub enum AIEvent {
  #[event(input = "RegenerateResponsePB")]
  RegenerateResponse = 27,

  #[event(output = "AvailableModelsPB")]
  GetServerAvailableModels = 28,

  #[event(output = "LocalAISettingPB")]
  GetLocalAISetting = 29,

  #[event(input = "LocalAISettingPB")]
  UpdateLocalAISetting = 30,

  #[event(input = "AvailableModelsQueryPB", output = "AvailableModelsPB")]
  GetAvailableModels = 31,
  #[event(input = "ModelSourcePB", output = "ModelSelectionPB")]
  GetSettingModelSelection = 31,

  #[event(input = "UpdateSelectedModelPB")]
  UpdateSelectedModel = 32,

  #[event(output = "ModelSelectionPB")]
  GetLocalModelSelection = 33,

  #[event(input = "ModelSourcePB", output = "ModelSelectionPB")]
  GetSourceModelSelection = 34,
}

@ -16,9 +16,15 @@ use af_local_ai::ollama_plugin::OllamaAIPlugin;
use af_plugin::core::path::is_plugin_ready;
use af_plugin::core::plugin::RunningState;
use arc_swap::ArcSwapOption;
use flowy_ai_pub::cloud::AIModel;
use flowy_ai_pub::persistence::{
  select_local_ai_model, upsert_local_ai_model, LocalAIModelTable, ModelType,
};
use flowy_ai_pub::user_service::AIUserService;
use futures_util::SinkExt;
use lib_infra::util::get_operating_system;
use ollama_rs::generation::embeddings::request::{EmbeddingsInput, GenerateEmbeddingsRequest};
use ollama_rs::Ollama;
use serde::{Deserialize, Serialize};
use std::ops::Deref;
use std::path::PathBuf;
@ -39,8 +45,8 @@ impl Default for LocalAISetting {
  fn default() -> Self {
    Self {
      ollama_server_url: "http://localhost:11434".to_string(),
      chat_model_name: "llama3.1".to_string(),
      embedding_model_name: "nomic-embed-text".to_string(),
      chat_model_name: "llama3.1:latest".to_string(),
      embedding_model_name: "nomic-embed-text:latest".to_string(),
    }
  }
}
@ -53,6 +59,7 @@ pub struct LocalAIController {
  current_chat_id: ArcSwapOption<Uuid>,
  store_preferences: Weak<KVStorePreferences>,
  user_service: Arc<dyn AIUserService>,
  ollama: ArcSwapOption<Ollama>,
}

impl Deref for LocalAIController {
@ -83,69 +90,80 @@ impl LocalAIController {
user_service.clone(),
res_impl,
));
// Subscribe to state changes
let mut running_state_rx = local_ai.subscribe_running_state();

let cloned_llm_res = Arc::clone(&local_ai_resource);
let cloned_store_preferences = store_preferences.clone();
let cloned_local_ai = Arc::clone(&local_ai);
let cloned_user_service = Arc::clone(&user_service);
let ollama = ArcSwapOption::default();
let sys = get_operating_system();
if sys.is_desktop() {
let setting = local_ai_resource.get_llm_setting();
ollama.store(
Ollama::try_new(&setting.ollama_server_url)
.map(Arc::new)
.ok(),
);

// Spawn a background task to listen for plugin state changes
tokio::spawn(async move {
while let Some(state) = running_state_rx.next().await {
// Skip if we can’t get workspace_id
let Ok(workspace_id) = cloned_user_service.workspace_id() else {
continue;
};
// Subscribe to state changes
let mut running_state_rx = local_ai.subscribe_running_state();
let cloned_llm_res = Arc::clone(&local_ai_resource);
let cloned_store_preferences = store_preferences.clone();
let cloned_local_ai = Arc::clone(&local_ai);
let cloned_user_service = Arc::clone(&user_service);

let key = local_ai_enabled_key(&workspace_id.to_string());
info!("[AI Plugin] state: {:?}", state);

// Read whether plugin is enabled from store; default to true
if let Some(store_preferences) = cloned_store_preferences.upgrade() {
let enabled = store_preferences.get_bool(&key).unwrap_or(true);
// Only check resource status if the plugin isn’t in "UnexpectedStop" and is enabled
let (plugin_downloaded, lack_of_resource) =
if !matches!(state, RunningState::UnexpectedStop { .. }) && enabled {
// Possibly check plugin readiness and resource concurrency in parallel,
// but here we do it sequentially for clarity.
let downloaded = is_plugin_ready();
let resource_lack = cloned_llm_res.get_lack_of_resource().await;
(downloaded, resource_lack)
} else {
(false, None)
};

// If plugin is running, retrieve version
let plugin_version = if matches!(state, RunningState::Running { .. }) {
match cloned_local_ai.plugin_info().await {
Ok(info) => Some(info.version),
Err(_) => None,
}
} else {
None
// Spawn a background task to listen for plugin state changes
tokio::spawn(async move {
while let Some(state) = running_state_rx.next().await {
// Skip if we can't get workspace_id
let Ok(workspace_id) = cloned_user_service.workspace_id() else {
continue;
};

// Broadcast the new local AI state
let new_state = RunningStatePB::from(state);
chat_notification_builder(
APPFLOWY_AI_NOTIFICATION_KEY,
ChatNotification::UpdateLocalAIState,
)
.payload(LocalAIPB {
enabled,
plugin_downloaded,
lack_of_resource,
state: new_state,
plugin_version,
})
.send();
} else {
warn!("[AI Plugin] store preferences is dropped");
let key = crate::local_ai::controller::local_ai_enabled_key(&workspace_id.to_string());
info!("[AI Plugin] state: {:?}", state);

// Read whether plugin is enabled from store; default to true
if let Some(store_preferences) = cloned_store_preferences.upgrade() {
let enabled = store_preferences.get_bool(&key).unwrap_or(true);
// Only check resource status if the plugin isn't in "UnexpectedStop" and is enabled
let (plugin_downloaded, lack_of_resource) =
if !matches!(state, RunningState::UnexpectedStop { .. }) && enabled {
// Possibly check plugin readiness and resource concurrency in parallel,
// but here we do it sequentially for clarity.
let downloaded = is_plugin_ready();
let resource_lack = cloned_llm_res.get_lack_of_resource().await;
(downloaded, resource_lack)
} else {
(false, None)
};

// If plugin is running, retrieve version
let plugin_version = if matches!(state, RunningState::Running { .. }) {
match cloned_local_ai.plugin_info().await {
Ok(info) => Some(info.version),
Err(_) => None,
}
} else {
None
};

// Broadcast the new local AI state
let new_state = RunningStatePB::from(state);
chat_notification_builder(
APPFLOWY_AI_NOTIFICATION_KEY,
ChatNotification::UpdateLocalAIState,
)
.payload(LocalAIPB {
enabled,
plugin_downloaded,
lack_of_resource,
state: new_state,
plugin_version,
})
.send();
} else {
warn!("[AI Plugin] store preferences is dropped");
}
}
}
});
});
}

Self {
ai_plugin: local_ai,
@ -153,6 +171,7 @@ impl LocalAIController {
current_chat_id: ArcSwapOption::default(),
store_preferences,
user_service,
ollama,
}
}
#[instrument(level = "debug", skip_all)]
@ -287,17 +306,85 @@ impl LocalAIController {
    self.resource.get_llm_setting()
  }

  pub async fn get_all_chat_local_models(&self) -> Vec<AIModel> {
    self
      .get_filtered_local_models(|name| !name.contains("embed"))
      .await
  }

  pub async fn get_all_embedded_local_models(&self) -> Vec<AIModel> {
    self
      .get_filtered_local_models(|name| name.contains("embed"))
      .await
  }

  // Helper function to avoid code duplication in model retrieval
  async fn get_filtered_local_models<F>(&self, filter_fn: F) -> Vec<AIModel>
  where
    F: Fn(&str) -> bool,
  {
    match self.ollama.load_full() {
      None => vec![],
      Some(ollama) => ollama
        .list_local_models()
        .await
        .map(|models| {
          models
            .into_iter()
            .filter(|m| filter_fn(&m.name.to_lowercase()))
            .map(|m| AIModel::local(m.name, String::new()))
            .collect()
        })
        .unwrap_or_default(),
    }
  }

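get_all_chat_local_models and get_all_embedded_local_models differ only in the predicate they hand to the helper above: the split between chat and embedding models is purely name-based, so anything whose lowercased name contains "embed" is treated as an embedding model. The same rule in isolation:

// Mirrors the closures passed above; purely name-based.
fn is_embedding_model(name: &str) -> bool {
  name.to_lowercase().contains("embed")
}
// is_embedding_model("nomic-embed-text:latest") == true
// is_embedding_model("llama3.1:latest") == false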
  pub async fn check_model_type(&self, model_name: &str) -> FlowyResult<ModelType> {
    let uid = self.user_service.user_id()?;
    let mut conn = self.user_service.sqlite_connection(uid)?;
    match select_local_ai_model(&mut conn, model_name) {
      None => {
        let ollama = self
          .ollama
          .load_full()
          .ok_or_else(|| FlowyError::local_ai().with_context("ollama is not initialized"))?;

        let request = GenerateEmbeddingsRequest::new(
          model_name.to_string(),
          EmbeddingsInput::Single("Hello".to_string()),
        );

        let model_type = match ollama.generate_embeddings(request).await {
          Ok(value) => {
            if value.embeddings.is_empty() {
              ModelType::Chat
            } else {
              ModelType::Embedding
            }
          },
          Err(_) => ModelType::Chat,
        };

        upsert_local_ai_model(
          &mut conn,
          &LocalAIModelTable {
            name: model_name.to_string(),
            model_type: model_type as i16,
          },
        )?;
        Ok(model_type)
      },
      Some(r) => Ok(ModelType::from(r.model_type)),
    }
  }

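check_model_type classifies an Ollama model by firing a one-word embeddings request and caches the verdict in local_ai_model_table, so the probe only runs the first time a given name is seen. A hypothetical call site (the variable name is illustrative, not from this diff):

// Sketch: reject an embedding-only model for the chat slot.
let model_type = local_ai_controller.check_model_type("nomic-embed-text:latest").await?;
if matches!(model_type, ModelType::Embedding) {
  return Err(FlowyError::local_ai().with_context("this model cannot be used for chat"));
}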
  pub async fn update_local_ai_setting(&self, setting: LocalAISetting) -> FlowyResult<()> {
    info!(
      "[AI Plugin] update local ai setting: {:?}, thread: {:?}",
      setting,
      std::thread::current().id()
    );

    if self.resource.set_llm_setting(setting).await.is_ok() {
      let is_enabled = self.is_enabled();
      self.toggle_plugin(is_enabled).await?;
    }
    self.resource.set_llm_setting(setting).await?;
    Ok(())
  }

@ -161,7 +161,6 @@ impl LocalAIResourceController {

    let setting = self.get_llm_setting();
    let client = Client::builder().timeout(Duration::from_secs(5)).build()?;

    match client.get(&setting.ollama_server_url).send().await {
      Ok(resp) if resp.status().is_success() => {
        info!(

@ -36,6 +36,10 @@ client-api = { workspace = true, optional = true }
tantivy = { workspace = true, optional = true }
uuid.workspace = true

[target.'cfg(any(target_os = "macos", target_os = "linux", target_os = "windows"))'.dependencies]
ollama-rs = "0.3.0"


[features]
default = ["impl_from_dispatch_error", "impl_from_serde", "impl_from_reqwest", "impl_from_sqlite"]
impl_from_dispatch_error = ["lib-dispatch"]

@ -264,3 +264,10 @@ impl From<uuid::Error> for FlowyError {
    FlowyError::internal().with_context(value)
  }
}

#[cfg(any(target_os = "windows", target_os = "macos", target_os = "linux"))]
impl From<ollama_rs::error::OllamaError> for FlowyError {
  fn from(value: ollama_rs::error::OllamaError) -> Self {
    FlowyError::local_ai().with_context(value)
  }
}

@ -0,0 +1 @@
-- This file should undo anything in `up.sql`
@ -0,0 +1,6 @@
-- Your SQL goes here
CREATE TABLE local_ai_model_table
(
  name TEXT PRIMARY KEY NOT NULL,
  model_type SMALLINT NOT NULL
);
@ -54,6 +54,13 @@ diesel::table! {
  }
}

diesel::table! {
  local_ai_model_table (name) {
    name -> Text,
    model_type -> SmallInt,
  }
}

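The new local_ai_model_table is just a name-keyed cache of the probe result from check_model_type. The select_local_ai_model and upsert_local_ai_model helpers live in flowy_ai_pub::persistence and their bodies are not part of this diff; a minimal Diesel sketch of the kind of lookup they presumably perform (illustrative, and it assumes the generated schema module above is in scope):

use diesel::prelude::*;

// Sketch: read a cached model type by name from local_ai_model_table.
fn cached_model_type(conn: &mut SqliteConnection, model_name: &str) -> Option<i16> {
  local_ai_model_table::table
    .select(local_ai_model_table::model_type)
    .filter(local_ai_model_table::name.eq(model_name))
    .first::<i16>(conn)
    .ok()
}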
diesel::table! {
  upload_file_part (upload_id, e_tag) {
    upload_id -> Text,
@ -138,6 +145,7 @@ diesel::allow_tables_to_appear_in_same_query!(
  chat_message_table,
  chat_table,
  collab_snapshot,
  local_ai_model_table,
  upload_file_part,
  upload_file_table,
  user_data_migration_records,