diff --git a/frontend/appflowy_flutter/lib/ai/ai.dart b/frontend/appflowy_flutter/lib/ai/ai.dart index e3f52a8168..9bfeeb4e00 100644 --- a/frontend/appflowy_flutter/lib/ai/ai.dart +++ b/frontend/appflowy_flutter/lib/ai/ai.dart @@ -2,6 +2,8 @@ export 'service/ai_entities.dart'; export 'service/ai_prompt_input_bloc.dart'; export 'service/appflowy_ai_service.dart'; export 'service/error.dart'; +export 'service/ai_model_state_notifier.dart'; +export 'service/select_model_bloc.dart'; export 'widgets/loading_indicator.dart'; export 'widgets/prompt_input/action_buttons.dart'; export 'widgets/prompt_input/desktop_prompt_text_field.dart'; @@ -13,4 +15,5 @@ export 'widgets/prompt_input/mentioned_page_text_span.dart'; export 'widgets/prompt_input/predefined_format_buttons.dart'; export 'widgets/prompt_input/select_sources_bottom_sheet.dart'; export 'widgets/prompt_input/select_sources_menu.dart'; +export 'widgets/prompt_input/select_model_menu.dart'; export 'widgets/prompt_input/send_button.dart'; diff --git a/frontend/appflowy_flutter/lib/ai/service/ai_entities.dart b/frontend/appflowy_flutter/lib/ai/service/ai_entities.dart index b8592bc32b..249a92019a 100644 --- a/frontend/appflowy_flutter/lib/ai/service/ai_entities.dart +++ b/frontend/appflowy_flutter/lib/ai/service/ai_entities.dart @@ -4,6 +4,14 @@ import 'package:appflowy_backend/protobuf/flowy-ai/protobuf.dart'; import 'package:easy_localization/easy_localization.dart'; import 'package:equatable/equatable.dart'; +enum AiType { + cloud, + local; + + bool get isCloud => this == cloud; + bool get isLocal => this == local; +} + class PredefinedFormat extends Equatable { const PredefinedFormat({ required this.imageFormat, diff --git a/frontend/appflowy_flutter/lib/ai/service/ai_input_control.dart b/frontend/appflowy_flutter/lib/ai/service/ai_input_control.dart deleted file mode 100644 index c468fdd6e9..0000000000 --- a/frontend/appflowy_flutter/lib/ai/service/ai_input_control.dart +++ /dev/null @@ -1,138 +0,0 @@ -import 'package:appflowy/ai/service/ai_prompt_input_bloc.dart'; -import 'package:appflowy/generated/locale_keys.g.dart'; -import 'package:appflowy/plugins/ai_chat/application/ai_model_switch_listener.dart'; -import 'package:appflowy/workspace/application/settings/ai/local_llm_listener.dart'; -import 'package:appflowy_backend/dispatch/dispatch.dart'; -import 'package:appflowy_backend/log.dart'; -import 'package:appflowy_backend/protobuf/flowy-ai/entities.pb.dart'; -import 'package:easy_localization/easy_localization.dart'; -import 'package:protobuf/protobuf.dart'; -import 'package:universal_platform/universal_platform.dart'; - -class AIModelStateNotifier { - AIModelStateNotifier({required this.objectId}) - : _isDesktop = UniversalPlatform.isDesktop, - _localAIListener = - UniversalPlatform.isDesktop ? LocalAIStateListener() : null, - _aiModelSwitchListener = AIModelSwitchListener(objectId: objectId); - - final String objectId; - final bool _isDesktop; - final LocalAIStateListener? _localAIListener; - final AIModelSwitchListener _aiModelSwitchListener; - - LocalAIPB? _localAIState; - AvailableModelsPB? _availableModels; - - // Callbacks - void Function(AiType, bool, String)? onChanged; - void Function(AvailableModelsPB)? onAvailableModelsChanged; - - String hintText() { - final aiType = getCurrentAiType(); - if (aiType.isLocal) { - return isEditable() - ? LocaleKeys.chat_inputLocalAIMessageHint.tr() - : LocaleKeys.settings_aiPage_keys_localAIInitializing.tr(); - } - return LocaleKeys.chat_inputMessageHint.tr(); - } - - AiType getCurrentAiType() { - // On non-desktop platforms, always return cloud type. - if (!_isDesktop) return AiType.cloud; - return (_availableModels?.selectedModel.isLocal ?? false) - ? AiType.local - : AiType.cloud; - } - - bool isEditable() { - // On non-desktop platforms, always editable. - if (!_isDesktop) return true; - return getCurrentAiType().isLocal - ? _localAIState?.state == RunningStatePB.Running - : true; - } - - void _notifyStateChanged() { - onChanged?.call(getCurrentAiType(), isEditable(), hintText()); - } - - Future init() async { - // Load both available models and local state concurrently. - await Future.wait([ - _loadAvailableModels(), - _loadLocalAIState(), - ]); - } - - Future _loadAvailableModels() async { - final payload = AvailableModelsQueryPB(source: objectId); - final result = await AIEventGetAvailableModels(payload).send(); - result.fold( - (models) { - _availableModels = models; - onAvailableModelsChanged?.call(models); - _notifyStateChanged(); - }, - (err) => Log.error("Failed to get available models: $err"), - ); - } - - Future _loadLocalAIState() async { - final result = await AIEventGetLocalAIState().send(); - result.fold( - (state) { - _localAIState = state; - _notifyStateChanged(); - }, - (error) { - Log.error("Failed to get local AI state: $error"); - _notifyStateChanged(); - }, - ); - } - - void startListening({ - void Function(AiType, bool, String)? onChanged, - void Function(AvailableModelsPB)? onAvailableModelsChanged, - }) { - this.onChanged = onChanged; - this.onAvailableModelsChanged = onAvailableModelsChanged; - - // Only start local AI listener on desktop platforms. - if (_isDesktop) { - _localAIListener?.start( - stateCallback: (state) { - _localAIState = state; - if (state.state == RunningStatePB.Running || - state.state == RunningStatePB.Stopped) { - _loadAvailableModels(); - } - }, - ); - } - - _aiModelSwitchListener.start( - onUpdateSelectedModel: (model) { - if (_availableModels != null) { - final updatedModels = _availableModels!.deepCopy() - ..selectedModel = model; - _availableModels = updatedModels; - onAvailableModelsChanged?.call(updatedModels); - } - if (model.isLocal && _isDesktop) { - _loadLocalAIState(); - } else { - _notifyStateChanged(); - } - }, - ); - } - - Future stop() async { - onChanged = null; - await _localAIListener?.stop(); - await _aiModelSwitchListener.stop(); - } -} diff --git a/frontend/appflowy_flutter/lib/ai/service/ai_model_state_notifier.dart b/frontend/appflowy_flutter/lib/ai/service/ai_model_state_notifier.dart new file mode 100644 index 0000000000..c43bd01c6f --- /dev/null +++ b/frontend/appflowy_flutter/lib/ai/service/ai_model_state_notifier.dart @@ -0,0 +1,172 @@ +import 'package:appflowy/ai/ai.dart'; +import 'package:appflowy/generated/locale_keys.g.dart'; +import 'package:appflowy/plugins/ai_chat/application/ai_model_switch_listener.dart'; +import 'package:appflowy/workspace/application/settings/ai/local_llm_listener.dart'; +import 'package:appflowy_backend/dispatch/dispatch.dart'; +import 'package:appflowy_backend/log.dart'; +import 'package:appflowy_backend/protobuf/flowy-ai/entities.pb.dart'; +import 'package:appflowy_result/appflowy_result.dart'; +import 'package:easy_localization/easy_localization.dart'; +import 'package:protobuf/protobuf.dart'; +import 'package:universal_platform/universal_platform.dart'; + +typedef OnModelStateChangedCallback = void Function(AiType, bool, String); +typedef OnAvailableModelsChangedCallback = void Function( + List, + AIModelPB?, +); + +class AIModelStateNotifier { + AIModelStateNotifier({required this.objectId}) + : _localAIListener = + UniversalPlatform.isDesktop ? LocalAIStateListener() : null, + _aiModelSwitchListener = AIModelSwitchListener(objectId: objectId) { + _startListening(); + _init(); + } + + final String objectId; + final LocalAIStateListener? _localAIListener; + final AIModelSwitchListener _aiModelSwitchListener; + LocalAIPB? _localAIState; + AvailableModelsPB? _availableModels; + + // callbacks + final List _stateChangedCallbacks = []; + final List + _availableModelsChangedCallbacks = []; + + void _startListening() { + if (UniversalPlatform.isDesktop) { + _localAIListener?.start( + stateCallback: (state) async { + _localAIState = state; + _notifyStateChanged(); + + if (state.state == RunningStatePB.Running || + state.state == RunningStatePB.Stopped) { + await _loadAvailableModels(); + _notifyAvailableModelsChanged(); + } + }, + ); + } + + _aiModelSwitchListener.start( + onUpdateSelectedModel: (model) async { + final updatedModels = _availableModels?.deepCopy() + ?..selectedModel = model; + _availableModels = updatedModels; + _notifyAvailableModelsChanged(); + + if (model.isLocal && UniversalPlatform.isDesktop) { + await _loadLocalAiState(); + } + _notifyStateChanged(); + }, + ); + } + + void _init() async { + await Future.wait([_loadLocalAiState(), _loadAvailableModels()]); + _notifyStateChanged(); + _notifyAvailableModelsChanged(); + } + + void addListener({ + OnModelStateChangedCallback? onStateChanged, + OnAvailableModelsChangedCallback? onAvailableModelsChanged, + }) { + if (onStateChanged != null) { + _stateChangedCallbacks.add(onStateChanged); + } + if (onAvailableModelsChanged != null) { + _availableModelsChangedCallbacks.add(onAvailableModelsChanged); + } + } + + void removeListener({ + OnModelStateChangedCallback? onStateChanged, + OnAvailableModelsChangedCallback? onAvailableModelsChanged, + }) { + if (onStateChanged != null) { + _stateChangedCallbacks.remove(onStateChanged); + } + if (onAvailableModelsChanged != null) { + _availableModelsChangedCallbacks.remove(onAvailableModelsChanged); + } + } + + Future dispose() async { + _stateChangedCallbacks.clear(); + _availableModelsChangedCallbacks.clear(); + await _localAIListener?.stop(); + await _aiModelSwitchListener.stop(); + } + + (AiType, String, bool) getState() { + if (UniversalPlatform.isMobile) { + return (AiType.cloud, LocaleKeys.chat_inputMessageHint.tr(), true); + } + + final availableModels = _availableModels; + final localAiState = _localAIState; + + if (availableModels == null) { + Log.warn("No available models"); + return (AiType.cloud, LocaleKeys.chat_inputMessageHint.tr(), true); + } + if (localAiState == null) { + Log.warn("Cannot get local AI state"); + return (AiType.cloud, LocaleKeys.chat_inputMessageHint.tr(), true); + } + + if (!availableModels.selectedModel.isLocal) { + return (AiType.cloud, LocaleKeys.chat_inputMessageHint.tr(), true); + } + + final editable = localAiState.state == RunningStatePB.Running; + final hintText = editable + ? LocaleKeys.chat_inputLocalAIMessageHint.tr() + : LocaleKeys.settings_aiPage_keys_localAIInitializing.tr(); + + return (AiType.local, hintText, editable); + } + + (List, AIModelPB?) getAvailableModels() { + final availableModels = _availableModels; + if (availableModels == null) { + return ([], null); + } + return (availableModels.models, availableModels.selectedModel); + } + + void _notifyAvailableModelsChanged() { + final (models, selectedModel) = getAvailableModels(); + for (final callback in _availableModelsChangedCallbacks) { + callback(models, selectedModel); + } + } + + void _notifyStateChanged() { + final (type, hintText, isEditable) = getState(); + for (final callback in _stateChangedCallbacks) { + callback(type, isEditable, hintText); + } + } + + Future _loadAvailableModels() { + final payload = AvailableModelsQueryPB(source: objectId); + return AIEventGetAvailableModels(payload).send().fold( + (models) => _availableModels = models, + (err) => Log.error("Failed to get available models: $err"), + ); + } + + Future _loadLocalAiState() { + return AIEventGetLocalAIState().send().fold( + (localAIState) => _localAIState = localAIState, + (error) => Log.error("Failed to get local AI state: $error"), + ); + } +} diff --git a/frontend/appflowy_flutter/lib/ai/service/ai_prompt_input_bloc.dart b/frontend/appflowy_flutter/lib/ai/service/ai_prompt_input_bloc.dart index e2acda9a1e..95854ab047 100644 --- a/frontend/appflowy_flutter/lib/ai/service/ai_prompt_input_bloc.dart +++ b/frontend/appflowy_flutter/lib/ai/service/ai_prompt_input_bloc.dart @@ -1,6 +1,6 @@ import 'dart:async'; -import 'package:appflowy/ai/service/ai_input_control.dart'; +import 'package:appflowy/ai/service/ai_model_state_notifier.dart'; import 'package:appflowy/plugins/ai_chat/application/chat_entity.dart'; import 'package:appflowy_backend/protobuf/flowy-folder/protobuf.dart'; import 'package:flutter_bloc/flutter_bloc.dart'; @@ -14,17 +14,18 @@ class AIPromptInputBloc extends Bloc { AIPromptInputBloc({ required String objectId, required PredefinedFormat? predefinedFormat, - }) : _aiModelStateNotifier = AIModelStateNotifier(objectId: objectId), - super(AIPromptInputState.initial(objectId, predefinedFormat)) { + }) : aiModelStateNotifier = AIModelStateNotifier(objectId: objectId), + super(AIPromptInputState.initial(predefinedFormat)) { _dispatch(); + _startListening(); _init(); } - final AIModelStateNotifier _aiModelStateNotifier; + final AIModelStateNotifier aiModelStateNotifier; @override Future close() async { - await _aiModelStateNotifier.stop(); + await aiModelStateNotifier.dispose(); return super.close(); } @@ -36,7 +37,6 @@ class AIPromptInputBloc extends Bloc { emit( state.copyWith( aiType: aiType, - supportChatWithFile: false, editable: editable, hintText: hintText, ), @@ -103,16 +103,17 @@ class AIPromptInputBloc extends Bloc { ); } - void _init() { - _aiModelStateNotifier.startListening( - onChanged: (aiType, editable, hintText) { - if (!isClosed) { - add(AIPromptInputEvent.updateAIState(aiType, editable, hintText)); - } + void _startListening() { + aiModelStateNotifier.addListener( + onStateChanged: (aiType, editable, hintText) { + add(AIPromptInputEvent.updateAIState(aiType, editable, hintText)); }, ); + } - _aiModelStateNotifier.init(); + void _init() { + final (aiType, hintText, isEditable) = aiModelStateNotifier.getState(); + add(AIPromptInputEvent.updateAIState(aiType, isEditable, hintText)); } Map consumeMetadata() { @@ -155,7 +156,6 @@ class AIPromptInputEvent with _$AIPromptInputEvent { @freezed class AIPromptInputState with _$AIPromptInputState { const factory AIPromptInputState({ - required String objectId, required AiType aiType, required bool supportChatWithFile, required bool showPredefinedFormats, @@ -166,12 +166,8 @@ class AIPromptInputState with _$AIPromptInputState { required String hintText, }) = _AIPromptInputState; - factory AIPromptInputState.initial( - String objectId, - PredefinedFormat? format, - ) => + factory AIPromptInputState.initial(PredefinedFormat? format) => AIPromptInputState( - objectId: objectId, aiType: AiType.cloud, supportChatWithFile: false, showPredefinedFormats: format != null, @@ -182,11 +178,3 @@ class AIPromptInputState with _$AIPromptInputState { hintText: '', ); } - -enum AiType { - cloud, - local; - - bool get isCloud => this == cloud; - bool get isLocal => this == local; -} diff --git a/frontend/appflowy_flutter/lib/ai/service/select_model_bloc.dart b/frontend/appflowy_flutter/lib/ai/service/select_model_bloc.dart index 665533bd40..7ad52b9ec4 100644 --- a/frontend/appflowy_flutter/lib/ai/service/select_model_bloc.dart +++ b/frontend/appflowy_flutter/lib/ai/service/select_model_bloc.dart @@ -1,66 +1,66 @@ import 'dart:async'; -import 'package:appflowy/ai/service/ai_input_control.dart'; +import 'package:appflowy/ai/service/ai_model_state_notifier.dart'; import 'package:appflowy_backend/dispatch/dispatch.dart'; import 'package:appflowy_backend/protobuf/flowy-ai/entities.pbserver.dart'; import 'package:flutter_bloc/flutter_bloc.dart'; import 'package:freezed_annotation/freezed_annotation.dart'; -import 'package:protobuf/protobuf.dart'; part 'select_model_bloc.freezed.dart'; class SelectModelBloc extends Bloc { SelectModelBloc({ - required this.objectId, - }) : _aiModelStateNotifier = AIModelStateNotifier(objectId: objectId), - super(const SelectModelState()) { - _aiModelStateNotifier.init(); - _aiModelStateNotifier.startListening( - onAvailableModelsChanged: (models) { - if (!isClosed) { - add(SelectModelEvent.didLoadModels(models)); - } - }, - ); - + required AIModelStateNotifier aiModelStateNotifier, + }) : _aiModelStateNotifier = aiModelStateNotifier, + super(SelectModelState.initial(aiModelStateNotifier)) { on( - (event, emit) async { - await event.when( - selectModel: (AIModelPB model) async { - await AIEventUpdateSelectedModel( + (event, emit) { + event.when( + selectModel: (model) { + AIEventUpdateSelectedModel( UpdateSelectedModelPB( - source: objectId, + source: _aiModelStateNotifier.objectId, selectedModel: model, ), ).send(); - state.availableModels?.freeze(); - final newAvailableModels = state.availableModels?.rebuild((m) { - m.selectedModel = model; - }); - + emit(state.copyWith(selectedModel: model)); + }, + didLoadModels: (models, selectedModel) { emit( - state.copyWith( - availableModels: newAvailableModels, + SelectModelState( + models: models, + selectedModel: selectedModel, ), ); }, - didLoadModels: (AvailableModelsPB models) { - emit(state.copyWith(availableModels: models)); - }, ); }, ); + + _aiModelStateNotifier.addListener( + onAvailableModelsChanged: _onAvailableModelsChanged, + ); } - final String objectId; final AIModelStateNotifier _aiModelStateNotifier; @override Future close() async { - await _aiModelStateNotifier.stop(); + _aiModelStateNotifier.removeListener( + onAvailableModelsChanged: _onAvailableModelsChanged, + ); await super.close(); } + + void _onAvailableModelsChanged( + List models, + AIModelPB? selectedModel, + ) { + if (!isClosed) { + add(SelectModelEvent.didLoadModels(models, selectedModel)); + } + } } @freezed @@ -70,13 +70,23 @@ class SelectModelEvent with _$SelectModelEvent { ) = _SelectModel; const factory SelectModelEvent.didLoadModels( - AvailableModelsPB models, + List models, + AIModelPB? selectedModel, ) = _DidLoadModels; } @freezed class SelectModelState with _$SelectModelState { const factory SelectModelState({ - AvailableModelsPB? availableModels, + required List models, + required AIModelPB? selectedModel, }) = _SelectModelState; + + factory SelectModelState.initial(AIModelStateNotifier notifier) { + final (models, selectedModel) = notifier.getAvailableModels(); + return SelectModelState( + models: models, + selectedModel: selectedModel, + ); + } } diff --git a/frontend/appflowy_flutter/lib/ai/widgets/prompt_input/desktop_prompt_text_field.dart b/frontend/appflowy_flutter/lib/ai/widgets/prompt_input/desktop_prompt_text_field.dart index bf21a59ad0..fcf487da94 100644 --- a/frontend/appflowy_flutter/lib/ai/widgets/prompt_input/desktop_prompt_text_field.dart +++ b/frontend/appflowy_flutter/lib/ai/widgets/prompt_input/desktop_prompt_text_field.dart @@ -1,17 +1,14 @@ import 'package:appflowy/ai/ai.dart'; -import 'package:appflowy/ai/service/select_model_bloc.dart'; import 'package:appflowy/generated/locale_keys.g.dart'; import 'package:appflowy/plugins/ai_chat/application/chat_input_control_cubit.dart'; import 'package:appflowy/plugins/ai_chat/presentation/layout_define.dart'; import 'package:appflowy/startup/startup.dart'; import 'package:appflowy/util/theme_extension.dart'; -import 'package:appflowy_backend/protobuf/flowy-ai/entities.pb.dart'; import 'package:appflowy_backend/protobuf/flowy-folder/protobuf.dart'; import 'package:easy_localization/easy_localization.dart'; import 'package:extended_text_field/extended_text_field.dart'; import 'package:flowy_infra/file_picker/file_picker_service.dart'; import 'package:flowy_infra_ui/flowy_infra_ui.dart'; -import 'package:flowy_infra_ui/style_widget/hover.dart'; import 'package:flutter/material.dart'; import 'package:flutter/services.dart'; import 'package:flutter_bloc/flutter_bloc.dart'; @@ -169,7 +166,6 @@ class _DesktopPromptInputState extends State { top: null, child: TextFieldTapRegion( child: _PromptBottomActions( - objectId: state.objectId, showPredefinedFormats: state.showPredefinedFormats, onTogglePredefinedFormatSection: () => @@ -567,7 +563,6 @@ class PromptInputTextField extends StatelessWidget { class _PromptBottomActions extends StatelessWidget { const _PromptBottomActions({ - required this.objectId, required this.sendButtonState, required this.showPredefinedFormats, required this.onTogglePredefinedFormatSection, @@ -579,7 +574,6 @@ class _PromptBottomActions extends StatelessWidget { this.extraBottomActionButton, }); - final String objectId; final bool showPredefinedFormats; final void Function() onTogglePredefinedFormatSection; final void Function() onStartMention; @@ -600,10 +594,16 @@ class _PromptBottomActions extends StatelessWidget { return Row( children: [ _predefinedFormatButton(), - SelectModelButton(objectId: objectId), + const HSpace( + DesktopAIChatSizes.inputActionBarButtonSpacing, + ), + SelectModelMenu( + aiModelStateNotifier: + context.read().aiModelStateNotifier, + ), const Spacer(), if (state.aiType.isCloud) ...[ - _selectSourcesButton(context), + _selectSourcesButton(), const HSpace( DesktopAIChatSizes.inputActionBarButtonSpacing, ), @@ -639,7 +639,7 @@ class _PromptBottomActions extends StatelessWidget { ); } - Widget _selectSourcesButton(BuildContext context) { + Widget _selectSourcesButton() { return PromptInputDesktopSelectSourcesButton( onUpdateSelectedSources: onUpdateSelectedSources, selectedSourcesNotifier: selectedSourcesNotifier, @@ -686,225 +686,3 @@ class _PromptBottomActions extends StatelessWidget { ); } } - -class SelectModelButton extends StatefulWidget { - const SelectModelButton({ - super.key, - required this.objectId, - }); - - final String objectId; - - @override - State createState() => _SelectModelButtonState(); -} - -class _SelectModelButtonState extends State { - final popoverController = PopoverController(); - late SelectModelBloc bloc; - - @override - void initState() { - super.initState(); - bloc = SelectModelBloc(objectId: widget.objectId); - } - - @override - void dispose() { - popoverController.close(); - bloc.close(); - super.dispose(); - } - - @override - Widget build(BuildContext context) { - return BlocProvider.value( - value: bloc, - child: BlocBuilder( - builder: (context, state) { - return AppFlowyPopover( - // constraints: BoxConstraints.loose(const Size(250, 200)), - offset: const Offset(0.0, -10.0), - direction: PopoverDirection.topWithLeftAligned, - margin: EdgeInsets.zero, - controller: popoverController, - onOpen: () {}, - onClose: () {}, - popupBuilder: (_) { - return BlocProvider.value( - value: bloc, - child: _PopoverSelectModel( - onClose: () => popoverController.close(), - ), - ); - }, - child: _CurrentModelButton( - key: ValueKey(state.availableModels?.selectedModel.name), - modelName: state.availableModels?.selectedModel.name ?? "", - onTap: () => popoverController.show(), - ), - ); - }, - ), - ); - } -} - -class _PopoverSelectModel extends StatelessWidget { - const _PopoverSelectModel({ - required this.onClose, - }); - - final VoidCallback onClose; - - @override - Widget build(BuildContext context) { - return BlocBuilder( - builder: (context, state) { - if (state.availableModels == null || - state.availableModels!.models.isEmpty) { - return const SizedBox.shrink(); - } - - // Separate models into local and cloud models - final localModels = state.availableModels!.models - .where((model) => model.isLocal) - .toList(); - - final cloudModels = state.availableModels!.models - .where((model) => !model.isLocal) - .toList(); - - return Padding( - padding: const EdgeInsets.fromLTRB(8, 4, 8, 12), - child: Column( - mainAxisSize: MainAxisSize.min, - crossAxisAlignment: CrossAxisAlignment.start, - children: [ - // Local AI Models Section - if (localModels.isNotEmpty) ...[ - _ModelSectionHeader( - title: LocaleKeys.chat_changeFormat_localModel.tr(), - ), - const SizedBox(height: 4), - ...localModels.map( - (model) => _ModelItem( - model: model, - onTap: () { - context.read().add( - SelectModelEvent.selectModel(model), - ); - onClose(); - }, - ), - ), - const SizedBox(height: 8), - ], - - // Cloud AI Models Section - if (cloudModels.isNotEmpty) ...[ - if (localModels.isNotEmpty) - _ModelSectionHeader( - title: LocaleKeys.chat_changeFormat_cloudModel.tr(), - ), - const VSpace(4), - ...cloudModels.map( - (model) => _ModelItem( - model: model, - onTap: () { - context.read().add( - SelectModelEvent.selectModel(model), - ); - onClose(); - }, - ), - ), - ], - ], - ), - ); - }, - ); - } -} - -class _ModelSectionHeader extends StatelessWidget { - const _ModelSectionHeader({ - required this.title, - }); - - final String title; - - @override - Widget build(BuildContext context) { - return Padding( - padding: const EdgeInsets.only(top: 4, bottom: 2), - child: FlowyText( - title, - fontSize: 12, - color: Theme.of(context).hintColor, - fontWeight: FontWeight.w500, - ), - ); - } -} - -class _ModelItem extends StatelessWidget { - const _ModelItem({ - required this.model, - required this.onTap, - }); - - final AIModelPB model; - final VoidCallback onTap; - - @override - Widget build(BuildContext context) { - final modelName = model.name; - - return FlowyTextButton( - modelName, - fillColor: Colors.transparent, - onPressed: onTap, - ); - } -} - -class _CurrentModelButton extends StatelessWidget { - const _CurrentModelButton({ - required this.modelName, - required this.onTap, - super.key, - }); - - final String modelName; - final VoidCallback onTap; - - @override - Widget build(BuildContext context) { - return FlowyTooltip( - message: LocaleKeys.chat_changeFormat_switchModel.tr(), - child: GestureDetector( - onTap: onTap, - behavior: HitTestBehavior.opaque, - child: SizedBox( - height: DesktopAIPromptSizes.actionBarButtonSize, - child: FlowyHover( - style: const HoverStyle( - borderRadius: BorderRadius.all(Radius.circular(8)), - ), - child: Padding( - padding: const EdgeInsetsDirectional.fromSTEB(6, 6, 4, 6), - child: FlowyText( - modelName, - fontSize: 12, - figmaLineHeight: 16, - color: Theme.of(context).hintColor, - ), - ), - ), - ), - ), - ); - } -} diff --git a/frontend/appflowy_flutter/lib/ai/widgets/prompt_input/predefined_format_buttons.dart b/frontend/appflowy_flutter/lib/ai/widgets/prompt_input/predefined_format_buttons.dart index 6d6fc8de31..403b978905 100644 --- a/frontend/appflowy_flutter/lib/ai/widgets/prompt_input/predefined_format_buttons.dart +++ b/frontend/appflowy_flutter/lib/ai/widgets/prompt_input/predefined_format_buttons.dart @@ -104,6 +104,7 @@ class ChangeFormatBar extends StatelessWidget { }, child: FlowyTooltip( message: format.i18n, + preferBelow: false, child: SizedBox.square( dimension: _buttonSize, child: FlowyHover( @@ -150,6 +151,7 @@ class ChangeFormatBar extends StatelessWidget { }, child: FlowyTooltip( message: format.i18n, + preferBelow: false, child: SizedBox.square( dimension: _buttonSize, child: FlowyHover( diff --git a/frontend/appflowy_flutter/lib/ai/widgets/prompt_input/select_model_menu.dart b/frontend/appflowy_flutter/lib/ai/widgets/prompt_input/select_model_menu.dart new file mode 100644 index 0000000000..b9e3daada9 --- /dev/null +++ b/frontend/appflowy_flutter/lib/ai/widgets/prompt_input/select_model_menu.dart @@ -0,0 +1,222 @@ +import 'package:appflowy/ai/ai.dart'; +import 'package:appflowy/generated/flowy_svgs.g.dart'; +import 'package:appflowy/generated/locale_keys.g.dart'; +import 'package:appflowy_backend/protobuf/flowy-ai/protobuf.dart'; +import 'package:easy_localization/easy_localization.dart'; +import 'package:flowy_infra_ui/flowy_infra_ui.dart'; +import 'package:flowy_infra_ui/style_widget/hover.dart'; +import 'package:flutter/material.dart'; +import 'package:flutter_bloc/flutter_bloc.dart'; + +class SelectModelMenu extends StatefulWidget { + const SelectModelMenu({ + super.key, + required this.aiModelStateNotifier, + }); + + final AIModelStateNotifier aiModelStateNotifier; + + @override + State createState() => _SelectModelMenuState(); +} + +class _SelectModelMenuState extends State { + final popoverController = PopoverController(); + + @override + Widget build(BuildContext context) { + return BlocProvider( + create: (context) => SelectModelBloc( + aiModelStateNotifier: widget.aiModelStateNotifier, + ), + child: BlocBuilder( + builder: (context, state) { + if (state.selectedModel == null) { + return const SizedBox.shrink(); + } + return AppFlowyPopover( + offset: Offset(-12.0, 0.0), + constraints: BoxConstraints(maxWidth: 250, maxHeight: 600), + direction: PopoverDirection.topWithLeftAligned, + margin: EdgeInsets.zero, + controller: popoverController, + popupBuilder: (popoverContext) { + return SelectModelPopoverContent( + models: state.models, + selectedModel: state.selectedModel, + onSelectModel: (model) { + if (model != state.selectedModel) { + context + .read() + .add(SelectModelEvent.selectModel(model)); + } + popoverController.close(); + }, + ); + }, + child: _CurrentModelButton( + modelName: state.selectedModel!.name, + onTap: () => popoverController.show(), + ), + ); + }, + ), + ); + } +} + +class SelectModelPopoverContent extends StatelessWidget { + const SelectModelPopoverContent({ + super.key, + required this.models, + required this.selectedModel, + this.onSelectModel, + }); + + final List models; + final AIModelPB? selectedModel; + final void Function(AIModelPB)? onSelectModel; + + @override + Widget build(BuildContext context) { + if (models.isEmpty) { + return const SizedBox.shrink(); + } + + // separate models into local and cloud models + final localModels = models.where((model) => model.isLocal).toList(); + final cloudModels = models.where((model) => !model.isLocal).toList(); + + return Padding( + padding: const EdgeInsets.all(8.0), + child: Column( + mainAxisSize: MainAxisSize.min, + crossAxisAlignment: CrossAxisAlignment.start, + children: [ + if (localModels.isNotEmpty) ...[ + _ModelSectionHeader( + title: LocaleKeys.chat_switchModel_localModel.tr(), + ), + const VSpace(4.0), + ], + ...localModels.map( + (model) => _ModelItem( + model: model, + isSelected: model == selectedModel, + onTap: () => onSelectModel?.call(model), + ), + ), + if (cloudModels.isNotEmpty && localModels.isNotEmpty) ...[ + const VSpace(8.0), + _ModelSectionHeader( + title: LocaleKeys.chat_switchModel_cloudModel.tr(), + ), + const VSpace(4.0), + ], + ...cloudModels.map( + (model) => _ModelItem( + model: model, + isSelected: model == selectedModel, + onTap: () => onSelectModel?.call(model), + ), + ), + ], + ), + ); + } +} + +class _ModelSectionHeader extends StatelessWidget { + const _ModelSectionHeader({ + required this.title, + }); + + final String title; + + @override + Widget build(BuildContext context) { + return Padding( + padding: const EdgeInsets.only(top: 4, bottom: 2), + child: FlowyText( + title, + fontSize: 12, + figmaLineHeight: 16, + color: Theme.of(context).hintColor, + fontWeight: FontWeight.w500, + ), + ); + } +} + +class _ModelItem extends StatelessWidget { + const _ModelItem({ + required this.model, + required this.isSelected, + required this.onTap, + }); + + final AIModelPB model; + final bool isSelected; + final VoidCallback onTap; + + @override + Widget build(BuildContext context) { + return SizedBox( + height: 32, + child: FlowyButton( + onTap: onTap, + margin: EdgeInsets.symmetric(horizontal: 8.0, vertical: 6.0), + text: FlowyText(model.name), + rightIcon: isSelected ? FlowySvg(FlowySvgs.check_s) : null, + ), + ); + } +} + +class _CurrentModelButton extends StatelessWidget { + const _CurrentModelButton({ + required this.modelName, + required this.onTap, + }); + + final String modelName; + final VoidCallback onTap; + + @override + Widget build(BuildContext context) { + return FlowyTooltip( + message: LocaleKeys.chat_switchModel_label.tr(), + child: GestureDetector( + onTap: onTap, + behavior: HitTestBehavior.opaque, + child: SizedBox( + height: DesktopAIPromptSizes.actionBarButtonSize, + child: FlowyHover( + style: const HoverStyle( + borderRadius: BorderRadius.all(Radius.circular(8)), + ), + child: Padding( + padding: const EdgeInsetsDirectional.all(4.0), + child: Row( + children: [ + FlowyText( + modelName, + fontSize: 12, + figmaLineHeight: 16, + color: Theme.of(context).hintColor, + ), + HSpace(2.0), + FlowySvg( + FlowySvgs.ai_source_drop_down_s, + color: Theme.of(context).hintColor, + size: const Size.square(8), + ), + ], + ), + ), + ), + ), + ), + ); + } +} diff --git a/frontend/appflowy_flutter/lib/plugins/ai_chat/application/chat_bloc.dart b/frontend/appflowy_flutter/lib/plugins/ai_chat/application/chat_bloc.dart index e7aca346e0..4924a42c0d 100644 --- a/frontend/appflowy_flutter/lib/plugins/ai_chat/application/chat_bloc.dart +++ b/frontend/appflowy_flutter/lib/plugins/ai_chat/application/chat_bloc.dart @@ -239,9 +239,9 @@ class ChatBloc extends Bloc { ), ); }, - regenerateAnswer: (id, format) { + regenerateAnswer: (id, format, model) { _clearRelatedQuestions(); - _regenerateAnswer(id, format); + _regenerateAnswer(id, format, model); lastSentMessage = null; isFetchingRelatedQuestions = false; @@ -483,6 +483,7 @@ class ChatBloc extends Bloc { void _regenerateAnswer( String answerMessageIdString, PredefinedFormat? format, + AIModelPB? model, ) async { final id = temporaryMessageIDMap.entries .firstWhereOrNull((e) => e.value == answerMessageIdString) @@ -505,6 +506,9 @@ class ChatBloc extends Bloc { if (format != null) { payload.format = format.toPB(); } + if (model != null) { + payload.model = model; + } await AIEventRegenerateResponse(payload).send().fold( (success) { @@ -637,6 +641,7 @@ class ChatEvent with _$ChatEvent { const factory ChatEvent.regenerateAnswer( String id, PredefinedFormat? format, + AIModelPB? model, ) = _RegenerateAnswer; // streaming answer diff --git a/frontend/appflowy_flutter/lib/plugins/ai_chat/chat_page.dart b/frontend/appflowy_flutter/lib/plugins/ai_chat/chat_page.dart index cbc4929f56..4f843d447b 100644 --- a/frontend/appflowy_flutter/lib/plugins/ai_chat/chat_page.dart +++ b/frontend/appflowy_flutter/lib/plugins/ai_chat/chat_page.dart @@ -265,10 +265,13 @@ class _ChatContentPage extends StatelessWidget { _onSelectMetadata(context, metadata), onRegenerate: () => context .read() - .add(ChatEvent.regenerateAnswer(message.id, null)), + .add(ChatEvent.regenerateAnswer(message.id, null, null)), onChangeFormat: (format) => context .read() - .add(ChatEvent.regenerateAnswer(message.id, format)), + .add(ChatEvent.regenerateAnswer(message.id, format, null)), + onChangeModel: (model) => context + .read() + .add(ChatEvent.regenerateAnswer(message.id, null, model)), onStopStream: () => context.read().add( const ChatEvent.stopStream(), ), diff --git a/frontend/appflowy_flutter/lib/plugins/ai_chat/presentation/message/ai_change_model_bottom_sheet.dart b/frontend/appflowy_flutter/lib/plugins/ai_chat/presentation/message/ai_change_model_bottom_sheet.dart new file mode 100644 index 0000000000..aa0d840574 --- /dev/null +++ b/frontend/appflowy_flutter/lib/plugins/ai_chat/presentation/message/ai_change_model_bottom_sheet.dart @@ -0,0 +1,145 @@ +import 'package:appflowy/generated/locale_keys.g.dart'; +import 'package:appflowy/mobile/presentation/base/app_bar/app_bar_actions.dart'; +import 'package:appflowy/mobile/presentation/bottom_sheet/bottom_sheet.dart'; +import 'package:appflowy/mobile/presentation/widgets/widgets.dart'; +import 'package:appflowy_backend/protobuf/flowy-ai/protobuf.dart'; +import 'package:collection/collection.dart'; +import 'package:easy_localization/easy_localization.dart'; +import 'package:flowy_infra_ui/flowy_infra_ui.dart'; +import 'package:flutter/material.dart'; + +Future showChangeModelBottomSheet( + BuildContext context, + List models, +) { + return showMobileBottomSheet( + context, + showDragHandle: true, + builder: (context) => _ChangeModelBottomSheetContent(models: models), + ); +} + +class _ChangeModelBottomSheetContent extends StatefulWidget { + const _ChangeModelBottomSheetContent({ + required this.models, + }); + + final List models; + + @override + State<_ChangeModelBottomSheetContent> createState() => + _ChangeModelBottomSheetContentState(); +} + +class _ChangeModelBottomSheetContentState + extends State<_ChangeModelBottomSheetContent> { + AIModelPB? model; + + @override + Widget build(BuildContext context) { + return Column( + mainAxisSize: MainAxisSize.min, + children: [ + _Header( + onCancel: () => Navigator.of(context).pop(), + onDone: () => Navigator.of(context).pop(model), + ), + const VSpace(4.0), + _Body( + models: widget.models, + selectedModel: model, + onSelectModel: (format) { + setState(() => model = format); + }, + ), + const VSpace(16.0), + ], + ); + } +} + +class _Header extends StatelessWidget { + const _Header({ + required this.onCancel, + required this.onDone, + }); + + final VoidCallback onCancel; + final VoidCallback onDone; + + @override + Widget build(BuildContext context) { + return SizedBox( + height: 44.0, + child: Stack( + children: [ + Align( + alignment: Alignment.centerLeft, + child: AppBarBackButton( + padding: const EdgeInsets.symmetric( + vertical: 12, + horizontal: 16, + ), + onTap: onCancel, + ), + ), + Align( + child: Container( + constraints: const BoxConstraints(maxWidth: 250), + child: FlowyText( + LocaleKeys.chat_switchModel_label.tr(), + fontSize: 17.0, + fontWeight: FontWeight.w500, + overflow: TextOverflow.ellipsis, + ), + ), + ), + Align( + alignment: Alignment.centerRight, + child: AppBarDoneButton( + onTap: onDone, + ), + ), + ], + ), + ); + } +} + +class _Body extends StatelessWidget { + const _Body({ + required this.models, + required this.selectedModel, + required this.onSelectModel, + }); + + final List models; + final AIModelPB? selectedModel; + final void Function(AIModelPB) onSelectModel; + + @override + Widget build(BuildContext context) { + return Column( + mainAxisSize: MainAxisSize.min, + children: models + .mapIndexed( + (index, model) => _buildModelButton(model, index == 0), + ) + .toList(), + ); + } + + Widget _buildModelButton( + AIModelPB model, [ + bool isFirst = false, + ]) { + return FlowyOptionTile.checkbox( + text: model.name, + isSelected: model == selectedModel, + showTopBorder: isFirst, + onTap: () { + onSelectModel(model); + }, + ); + } +} diff --git a/frontend/appflowy_flutter/lib/plugins/ai_chat/presentation/message/ai_message_action_bar.dart b/frontend/appflowy_flutter/lib/plugins/ai_chat/presentation/message/ai_message_action_bar.dart index 6b1d428d04..150ce20192 100644 --- a/frontend/appflowy_flutter/lib/plugins/ai_chat/presentation/message/ai_message_action_bar.dart +++ b/frontend/appflowy_flutter/lib/plugins/ai_chat/presentation/message/ai_message_action_bar.dart @@ -21,6 +21,7 @@ import 'package:appflowy/workspace/application/view/view_ext.dart'; import 'package:appflowy/workspace/presentation/home/menu/sidebar/space/shared_widget.dart'; import 'package:appflowy/workspace/presentation/home/menu/view/view_item.dart'; import 'package:appflowy/workspace/presentation/widgets/dialogs.dart'; +import 'package:appflowy_backend/protobuf/flowy-ai/protobuf.dart'; import 'package:appflowy_backend/protobuf/flowy-folder/protobuf.dart'; import 'package:appflowy_editor/appflowy_editor.dart'; import 'package:appflowy_result/appflowy_result.dart'; @@ -41,6 +42,7 @@ class AIMessageActionBar extends StatefulWidget { required this.showDecoration, this.onRegenerate, this.onChangeFormat, + this.onChangeModel, this.onOverrideVisibility, }); @@ -48,6 +50,7 @@ class AIMessageActionBar extends StatefulWidget { final bool showDecoration; final void Function()? onRegenerate; final void Function(PredefinedFormat)? onChangeFormat; + final void Function(AIModelPB)? onChangeModel; final void Function(bool)? onOverrideVisibility; @override @@ -126,6 +129,12 @@ class _AIMessageActionBarState extends State { popoverMutex: popoverMutex, onOverrideVisibility: widget.onOverrideVisibility, ), + ChangeModelButton( + isInHoverBar: widget.showDecoration, + onRegenerate: widget.onChangeModel, + popoverMutex: popoverMutex, + onOverrideVisibility: widget.onOverrideVisibility, + ), SaveToPageButton( textMessage: widget.message as TextMessage, isInHoverBar: widget.showDecoration, @@ -405,6 +414,85 @@ class _ChangeFormatPopoverContentState } } +class ChangeModelButton extends StatefulWidget { + const ChangeModelButton({ + super.key, + required this.isInHoverBar, + this.popoverMutex, + this.onRegenerate, + this.onOverrideVisibility, + }); + + final bool isInHoverBar; + final PopoverMutex? popoverMutex; + final void Function(AIModelPB)? onRegenerate; + final void Function(bool)? onOverrideVisibility; + + @override + State createState() => _ChangeModelButtonState(); +} + +class _ChangeModelButtonState extends State { + final popoverController = PopoverController(); + + @override + Widget build(BuildContext context) { + return AppFlowyPopover( + controller: popoverController, + mutex: widget.popoverMutex, + triggerActions: PopoverTriggerFlags.none, + margin: EdgeInsets.zero, + offset: Offset(8, 0), + direction: PopoverDirection.rightWithBottomAligned, + constraints: BoxConstraints(maxWidth: 250, maxHeight: 600), + onClose: () => widget.onOverrideVisibility?.call(false), + child: buildButton(context), + popupBuilder: (_) { + final bloc = context.read(); + final (models, _) = bloc.aiModelStateNotifier.getAvailableModels(); + return SelectModelPopoverContent( + models: models, + selectedModel: null, + onSelectModel: widget.onRegenerate, + ); + }, + ); + } + + Widget buildButton(BuildContext context) { + return FlowyTooltip( + message: LocaleKeys.chat_switchModel_label.tr(), + child: FlowyIconButton( + width: 32.0, + height: DesktopAIChatSizes.messageActionBarIconSize, + hoverColor: AFThemeExtension.of(context).lightGreyHover, + radius: widget.isInHoverBar + ? DesktopAIChatSizes.messageHoverActionBarIconRadius + : DesktopAIChatSizes.messageActionBarIconRadius, + icon: Row( + mainAxisSize: MainAxisSize.min, + children: [ + FlowySvg( + FlowySvgs.ai_sparks_s, + color: Theme.of(context).hintColor, + size: const Size.square(16), + ), + FlowySvg( + FlowySvgs.ai_source_drop_down_s, + color: Theme.of(context).hintColor, + size: const Size.square(8), + ), + ], + ), + onPressed: () { + widget.onOverrideVisibility?.call(true); + popoverController.show(); + }, + ), + ); + } +} + class SaveToPageButton extends StatefulWidget { const SaveToPageButton({ super.key, diff --git a/frontend/appflowy_flutter/lib/plugins/ai_chat/presentation/message/ai_message_bubble.dart b/frontend/appflowy_flutter/lib/plugins/ai_chat/presentation/message/ai_message_bubble.dart index 770fb990b1..eed5f0a520 100644 --- a/frontend/appflowy_flutter/lib/plugins/ai_chat/presentation/message/ai_message_bubble.dart +++ b/frontend/appflowy_flutter/lib/plugins/ai_chat/presentation/message/ai_message_bubble.dart @@ -12,6 +12,7 @@ import 'package:appflowy/shared/markdown_to_document.dart'; import 'package:appflowy/startup/startup.dart'; import 'package:appflowy/workspace/application/view/view_ext.dart'; import 'package:appflowy/workspace/presentation/widgets/dialogs.dart'; +import 'package:appflowy_backend/protobuf/flowy-ai/protobuf.dart'; import 'package:easy_localization/easy_localization.dart'; import 'package:flowy_infra/theme_extension.dart'; import 'package:flowy_infra_ui/flowy_infra_ui.dart'; @@ -23,6 +24,7 @@ import 'package:universal_platform/universal_platform.dart'; import '../chat_avatar.dart'; import '../layout_define.dart'; +import 'ai_change_model_bottom_sheet.dart'; import 'ai_message_action_bar.dart'; import 'ai_change_format_bottom_sheet.dart'; import 'message_util.dart'; @@ -41,6 +43,7 @@ class ChatAIMessageBubble extends StatelessWidget { this.isSelectingMessages = false, this.onRegenerate, this.onChangeFormat, + this.onChangeModel, }); final Message message; @@ -50,6 +53,7 @@ class ChatAIMessageBubble extends StatelessWidget { final bool isSelectingMessages; final void Function()? onRegenerate; final void Function(PredefinedFormat)? onChangeFormat; + final void Function(AIModelPB)? onChangeModel; @override Widget build(BuildContext context) { @@ -73,6 +77,7 @@ class ChatAIMessageBubble extends StatelessWidget { message: message, onRegenerate: onRegenerate, onChangeFormat: onChangeFormat, + onChangeModel: onChangeModel, child: child, ); } @@ -82,6 +87,7 @@ class ChatAIMessageBubble extends StatelessWidget { message: message, onRegenerate: onRegenerate, onChangeFormat: onChangeFormat, + onChangeModel: onChangeModel, child: child, ); } @@ -91,6 +97,7 @@ class ChatAIMessageBubble extends StatelessWidget { message: message, onRegenerate: onRegenerate, onChangeFormat: onChangeFormat, + onChangeModel: onChangeModel, child: child, ); } @@ -103,12 +110,14 @@ class ChatAIBottomInlineActions extends StatelessWidget { required this.message, this.onRegenerate, this.onChangeFormat, + this.onChangeModel, }); final Widget child; final Message message; final void Function()? onRegenerate; final void Function(PredefinedFormat)? onChangeFormat; + final void Function(AIModelPB)? onChangeModel; @override Widget build(BuildContext context) { @@ -127,6 +136,7 @@ class ChatAIBottomInlineActions extends StatelessWidget { showDecoration: false, onRegenerate: onRegenerate, onChangeFormat: onChangeFormat, + onChangeModel: onChangeModel, ), ), const VSpace(32.0), @@ -142,12 +152,14 @@ class ChatAIMessageHover extends StatefulWidget { required this.message, this.onRegenerate, this.onChangeFormat, + this.onChangeModel, }); final Widget child; final Message message; final void Function()? onRegenerate; final void Function(PredefinedFormat)? onChangeFormat; + final void Function(AIModelPB)? onChangeModel; @override State createState() => _ChatAIMessageHoverState(); @@ -229,6 +241,7 @@ class _ChatAIMessageHoverState extends State { showDecoration: true, onRegenerate: widget.onRegenerate, onChangeFormat: widget.onChangeFormat, + onChangeModel: widget.onChangeModel, onOverrideVisibility: (visibility) { overrideVisibility = visibility; }, @@ -302,12 +315,14 @@ class ChatAIMessagePopup extends StatelessWidget { required this.message, this.onRegenerate, this.onChangeFormat, + this.onChangeModel, }); final Widget child; final Message message; final void Function()? onRegenerate; final void Function(PredefinedFormat)? onChangeFormat; + final void Function(AIModelPB)? onChangeModel; @override Widget build(BuildContext context) { @@ -328,6 +343,8 @@ class ChatAIMessagePopup extends StatelessWidget { _divider(), _changeFormatButton(context), _divider(), + _changeModelButton(context), + _divider(), _saveToPageButton(context), ], ); @@ -399,6 +416,25 @@ class ChatAIMessagePopup extends StatelessWidget { ); } + Widget _changeModelButton(BuildContext context) { + return MobileQuickActionButton( + onTap: () async { + final bloc = context.read(); + final (models, _) = bloc.aiModelStateNotifier.getAvailableModels(); + final result = await showChangeModelBottomSheet(context, models); + if (result != null) { + onChangeModel?.call(result); + if (context.mounted) { + Navigator.of(context).pop(); + } + } + }, + icon: FlowySvgs.ai_sparks_s, + iconSize: const Size.square(20), + text: LocaleKeys.chat_switchModel_label.tr(), + ); + } + Widget _saveToPageButton(BuildContext context) { return MobileQuickActionButton( onTap: () async { diff --git a/frontend/appflowy_flutter/lib/plugins/ai_chat/presentation/message/ai_text_message.dart b/frontend/appflowy_flutter/lib/plugins/ai_chat/presentation/message/ai_text_message.dart index 5a55072c17..380767105f 100644 --- a/frontend/appflowy_flutter/lib/plugins/ai_chat/presentation/message/ai_text_message.dart +++ b/frontend/appflowy_flutter/lib/plugins/ai_chat/presentation/message/ai_text_message.dart @@ -4,6 +4,7 @@ import 'package:appflowy/plugins/ai_chat/application/chat_ai_message_bloc.dart'; import 'package:appflowy/plugins/ai_chat/application/chat_bloc.dart'; import 'package:appflowy/plugins/ai_chat/application/chat_entity.dart'; import 'package:appflowy/plugins/ai_chat/application/chat_message_stream.dart'; +import 'package:appflowy_backend/protobuf/flowy-ai/protobuf.dart'; import 'package:easy_localization/easy_localization.dart'; import 'package:fixnum/fixnum.dart'; import 'package:flowy_infra_ui/flowy_infra_ui.dart'; @@ -36,6 +37,7 @@ class ChatAIMessageWidget extends StatelessWidget { this.onSelectedMetadata, this.onRegenerate, this.onChangeFormat, + this.onChangeModel, this.isLastMessage = false, this.isStreaming = false, this.isSelectingMessages = false, @@ -53,6 +55,7 @@ class ChatAIMessageWidget extends StatelessWidget { final void Function()? onRegenerate; final void Function() onStopStream; final void Function(PredefinedFormat)? onChangeFormat; + final void Function(AIModelPB)? onChangeModel; final bool isStreaming; final bool isLastMessage; final bool isSelectingMessages; @@ -110,6 +113,7 @@ class ChatAIMessageWidget extends StatelessWidget { isSelectingMessages: isSelectingMessages, onRegenerate: onRegenerate, onChangeFormat: onChangeFormat, + onChangeModel: onChangeModel, child: Column( crossAxisAlignment: CrossAxisAlignment.start, children: [ diff --git a/frontend/resources/translations/en.json b/frontend/resources/translations/en.json index 62de8be000..24f734af5d 100644 --- a/frontend/resources/translations/en.json +++ b/frontend/resources/translations/en.json @@ -247,14 +247,16 @@ "table": "Table", "blankDescription": "Format response", "defaultDescription": "Auto mode", - "localModel": "Local Model", - "cloudModel": "Cloud Model", - "switchModel": "Switch model", "textWithImageDescription": "@:chat.changeFormat.text with image", "numberWithImageDescription": "@:chat.changeFormat.number with image", "bulletWithImageDescription": "@:chat.changeFormat.bullet with image", "tableWithImageDescription": "@:chat.changeFormat.table with image" }, + "switchModel": { + "label": "Switch model", + "localModel": "Local Model", + "cloudModel": "Cloud Model" + }, "selectBanner": { "saveButton": "Add to …", "selectMessages": "Select messages", @@ -3199,4 +3201,4 @@ "rewrite": "Rewrite", "insertBelow": "Insert below" } -} \ No newline at end of file +} diff --git a/frontend/rust-lib/flowy-ai/src/ai_manager.rs b/frontend/rust-lib/flowy-ai/src/ai_manager.rs index 03a18dcdec..0a63d8ece0 100644 --- a/frontend/rust-lib/flowy-ai/src/ai_manager.rs +++ b/frontend/rust-lib/flowy-ai/src/ai_manager.rs @@ -248,22 +248,23 @@ impl AIManager { answer_message_id: i64, answer_stream_port: i64, format: Option, + model: Option, ) -> FlowyResult<()> { let chat = self.get_or_create_chat_instance(chat_id).await?; let question_message_id = chat .get_question_id_from_answer_id(answer_message_id) .await?; - let preferred_model = self - .store_preferences - .get_object::(&ai_available_models_key(chat_id)); + let model = model.map_or_else( + || { + self + .store_preferences + .get_object::(&ai_available_models_key(chat_id)) + }, + |model| Some(model.into()), + ); chat - .stream_regenerate_response( - question_message_id, - answer_stream_port, - format, - preferred_model, - ) + .stream_regenerate_response(question_message_id, answer_stream_port, format, model) .await?; Ok(()) } diff --git a/frontend/rust-lib/flowy-ai/src/entities.rs b/frontend/rust-lib/flowy-ai/src/entities.rs index 0e996228c6..d10aa950e3 100644 --- a/frontend/rust-lib/flowy-ai/src/entities.rs +++ b/frontend/rust-lib/flowy-ai/src/entities.rs @@ -102,6 +102,9 @@ pub struct RegenerateResponsePB { #[pb(index = 4, one_of)] pub format: Option, + + #[pb(index = 5, one_of)] + pub model: Option, } #[derive(Default, ProtoBuf, Validate, Clone, Debug)] diff --git a/frontend/rust-lib/flowy-ai/src/event_handler.rs b/frontend/rust-lib/flowy-ai/src/event_handler.rs index 6bedebd261..ec8b7b4964 100644 --- a/frontend/rust-lib/flowy-ai/src/event_handler.rs +++ b/frontend/rust-lib/flowy-ai/src/event_handler.rs @@ -99,6 +99,7 @@ pub(crate) async fn regenerate_response_handler( data.answer_message_id, data.answer_stream_port, data.format, + data.model, ) .await?; Ok(())