Merge pull request #7617 from richardshiue/chore/improve-model-selection-ui

feat: regenerate message with different model
This commit is contained in:
Nathan.fooo 2025-03-27 12:54:26 +08:00 committed by GitHub
commit 584f762e11
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
19 changed files with 779 additions and 446 deletions

View File

@ -2,6 +2,8 @@ export 'service/ai_entities.dart';
export 'service/ai_prompt_input_bloc.dart';
export 'service/appflowy_ai_service.dart';
export 'service/error.dart';
export 'service/ai_model_state_notifier.dart';
export 'service/select_model_bloc.dart';
export 'widgets/loading_indicator.dart';
export 'widgets/prompt_input/action_buttons.dart';
export 'widgets/prompt_input/desktop_prompt_text_field.dart';
@ -13,4 +15,5 @@ export 'widgets/prompt_input/mentioned_page_text_span.dart';
export 'widgets/prompt_input/predefined_format_buttons.dart';
export 'widgets/prompt_input/select_sources_bottom_sheet.dart';
export 'widgets/prompt_input/select_sources_menu.dart';
export 'widgets/prompt_input/select_model_menu.dart';
export 'widgets/prompt_input/send_button.dart';

View File

@ -4,6 +4,14 @@ import 'package:appflowy_backend/protobuf/flowy-ai/protobuf.dart';
import 'package:easy_localization/easy_localization.dart';
import 'package:equatable/equatable.dart';
enum AiType {
cloud,
local;
bool get isCloud => this == cloud;
bool get isLocal => this == local;
}
class PredefinedFormat extends Equatable {
const PredefinedFormat({
required this.imageFormat,

View File

@ -1,138 +0,0 @@
import 'package:appflowy/ai/service/ai_prompt_input_bloc.dart';
import 'package:appflowy/generated/locale_keys.g.dart';
import 'package:appflowy/plugins/ai_chat/application/ai_model_switch_listener.dart';
import 'package:appflowy/workspace/application/settings/ai/local_llm_listener.dart';
import 'package:appflowy_backend/dispatch/dispatch.dart';
import 'package:appflowy_backend/log.dart';
import 'package:appflowy_backend/protobuf/flowy-ai/entities.pb.dart';
import 'package:easy_localization/easy_localization.dart';
import 'package:protobuf/protobuf.dart';
import 'package:universal_platform/universal_platform.dart';
class AIModelStateNotifier {
AIModelStateNotifier({required this.objectId})
: _isDesktop = UniversalPlatform.isDesktop,
_localAIListener =
UniversalPlatform.isDesktop ? LocalAIStateListener() : null,
_aiModelSwitchListener = AIModelSwitchListener(objectId: objectId);
final String objectId;
final bool _isDesktop;
final LocalAIStateListener? _localAIListener;
final AIModelSwitchListener _aiModelSwitchListener;
LocalAIPB? _localAIState;
AvailableModelsPB? _availableModels;
// Callbacks
void Function(AiType, bool, String)? onChanged;
void Function(AvailableModelsPB)? onAvailableModelsChanged;
String hintText() {
final aiType = getCurrentAiType();
if (aiType.isLocal) {
return isEditable()
? LocaleKeys.chat_inputLocalAIMessageHint.tr()
: LocaleKeys.settings_aiPage_keys_localAIInitializing.tr();
}
return LocaleKeys.chat_inputMessageHint.tr();
}
AiType getCurrentAiType() {
// On non-desktop platforms, always return cloud type.
if (!_isDesktop) return AiType.cloud;
return (_availableModels?.selectedModel.isLocal ?? false)
? AiType.local
: AiType.cloud;
}
bool isEditable() {
// On non-desktop platforms, always editable.
if (!_isDesktop) return true;
return getCurrentAiType().isLocal
? _localAIState?.state == RunningStatePB.Running
: true;
}
void _notifyStateChanged() {
onChanged?.call(getCurrentAiType(), isEditable(), hintText());
}
Future<void> init() async {
// Load both available models and local state concurrently.
await Future.wait([
_loadAvailableModels(),
_loadLocalAIState(),
]);
}
Future<void> _loadAvailableModels() async {
final payload = AvailableModelsQueryPB(source: objectId);
final result = await AIEventGetAvailableModels(payload).send();
result.fold(
(models) {
_availableModels = models;
onAvailableModelsChanged?.call(models);
_notifyStateChanged();
},
(err) => Log.error("Failed to get available models: $err"),
);
}
Future<void> _loadLocalAIState() async {
final result = await AIEventGetLocalAIState().send();
result.fold(
(state) {
_localAIState = state;
_notifyStateChanged();
},
(error) {
Log.error("Failed to get local AI state: $error");
_notifyStateChanged();
},
);
}
void startListening({
void Function(AiType, bool, String)? onChanged,
void Function(AvailableModelsPB)? onAvailableModelsChanged,
}) {
this.onChanged = onChanged;
this.onAvailableModelsChanged = onAvailableModelsChanged;
// Only start local AI listener on desktop platforms.
if (_isDesktop) {
_localAIListener?.start(
stateCallback: (state) {
_localAIState = state;
if (state.state == RunningStatePB.Running ||
state.state == RunningStatePB.Stopped) {
_loadAvailableModels();
}
},
);
}
_aiModelSwitchListener.start(
onUpdateSelectedModel: (model) {
if (_availableModels != null) {
final updatedModels = _availableModels!.deepCopy()
..selectedModel = model;
_availableModels = updatedModels;
onAvailableModelsChanged?.call(updatedModels);
}
if (model.isLocal && _isDesktop) {
_loadLocalAIState();
} else {
_notifyStateChanged();
}
},
);
}
Future<void> stop() async {
onChanged = null;
await _localAIListener?.stop();
await _aiModelSwitchListener.stop();
}
}

View File

@ -0,0 +1,172 @@
import 'package:appflowy/ai/ai.dart';
import 'package:appflowy/generated/locale_keys.g.dart';
import 'package:appflowy/plugins/ai_chat/application/ai_model_switch_listener.dart';
import 'package:appflowy/workspace/application/settings/ai/local_llm_listener.dart';
import 'package:appflowy_backend/dispatch/dispatch.dart';
import 'package:appflowy_backend/log.dart';
import 'package:appflowy_backend/protobuf/flowy-ai/entities.pb.dart';
import 'package:appflowy_result/appflowy_result.dart';
import 'package:easy_localization/easy_localization.dart';
import 'package:protobuf/protobuf.dart';
import 'package:universal_platform/universal_platform.dart';
typedef OnModelStateChangedCallback = void Function(AiType, bool, String);
typedef OnAvailableModelsChangedCallback = void Function(
List<AIModelPB>,
AIModelPB?,
);
class AIModelStateNotifier {
AIModelStateNotifier({required this.objectId})
: _localAIListener =
UniversalPlatform.isDesktop ? LocalAIStateListener() : null,
_aiModelSwitchListener = AIModelSwitchListener(objectId: objectId) {
_startListening();
_init();
}
final String objectId;
final LocalAIStateListener? _localAIListener;
final AIModelSwitchListener _aiModelSwitchListener;
LocalAIPB? _localAIState;
AvailableModelsPB? _availableModels;
// callbacks
final List<OnModelStateChangedCallback> _stateChangedCallbacks = [];
final List<OnAvailableModelsChangedCallback>
_availableModelsChangedCallbacks = [];
void _startListening() {
if (UniversalPlatform.isDesktop) {
_localAIListener?.start(
stateCallback: (state) async {
_localAIState = state;
_notifyStateChanged();
if (state.state == RunningStatePB.Running ||
state.state == RunningStatePB.Stopped) {
await _loadAvailableModels();
_notifyAvailableModelsChanged();
}
},
);
}
_aiModelSwitchListener.start(
onUpdateSelectedModel: (model) async {
final updatedModels = _availableModels?.deepCopy()
?..selectedModel = model;
_availableModels = updatedModels;
_notifyAvailableModelsChanged();
if (model.isLocal && UniversalPlatform.isDesktop) {
await _loadLocalAiState();
}
_notifyStateChanged();
},
);
}
void _init() async {
await Future.wait([_loadLocalAiState(), _loadAvailableModels()]);
_notifyStateChanged();
_notifyAvailableModelsChanged();
}
void addListener({
OnModelStateChangedCallback? onStateChanged,
OnAvailableModelsChangedCallback? onAvailableModelsChanged,
}) {
if (onStateChanged != null) {
_stateChangedCallbacks.add(onStateChanged);
}
if (onAvailableModelsChanged != null) {
_availableModelsChangedCallbacks.add(onAvailableModelsChanged);
}
}
void removeListener({
OnModelStateChangedCallback? onStateChanged,
OnAvailableModelsChangedCallback? onAvailableModelsChanged,
}) {
if (onStateChanged != null) {
_stateChangedCallbacks.remove(onStateChanged);
}
if (onAvailableModelsChanged != null) {
_availableModelsChangedCallbacks.remove(onAvailableModelsChanged);
}
}
Future<void> dispose() async {
_stateChangedCallbacks.clear();
_availableModelsChangedCallbacks.clear();
await _localAIListener?.stop();
await _aiModelSwitchListener.stop();
}
(AiType, String, bool) getState() {
if (UniversalPlatform.isMobile) {
return (AiType.cloud, LocaleKeys.chat_inputMessageHint.tr(), true);
}
final availableModels = _availableModels;
final localAiState = _localAIState;
if (availableModels == null) {
Log.warn("No available models");
return (AiType.cloud, LocaleKeys.chat_inputMessageHint.tr(), true);
}
if (localAiState == null) {
Log.warn("Cannot get local AI state");
return (AiType.cloud, LocaleKeys.chat_inputMessageHint.tr(), true);
}
if (!availableModels.selectedModel.isLocal) {
return (AiType.cloud, LocaleKeys.chat_inputMessageHint.tr(), true);
}
final editable = localAiState.state == RunningStatePB.Running;
final hintText = editable
? LocaleKeys.chat_inputLocalAIMessageHint.tr()
: LocaleKeys.settings_aiPage_keys_localAIInitializing.tr();
return (AiType.local, hintText, editable);
}
(List<AIModelPB>, AIModelPB?) getAvailableModels() {
final availableModels = _availableModels;
if (availableModels == null) {
return ([], null);
}
return (availableModels.models, availableModels.selectedModel);
}
void _notifyAvailableModelsChanged() {
final (models, selectedModel) = getAvailableModels();
for (final callback in _availableModelsChangedCallbacks) {
callback(models, selectedModel);
}
}
void _notifyStateChanged() {
final (type, hintText, isEditable) = getState();
for (final callback in _stateChangedCallbacks) {
callback(type, isEditable, hintText);
}
}
Future<void> _loadAvailableModels() {
final payload = AvailableModelsQueryPB(source: objectId);
return AIEventGetAvailableModels(payload).send().fold(
(models) => _availableModels = models,
(err) => Log.error("Failed to get available models: $err"),
);
}
Future<void> _loadLocalAiState() {
return AIEventGetLocalAIState().send().fold(
(localAIState) => _localAIState = localAIState,
(error) => Log.error("Failed to get local AI state: $error"),
);
}
}

View File

@ -1,6 +1,6 @@
import 'dart:async';
import 'package:appflowy/ai/service/ai_input_control.dart';
import 'package:appflowy/ai/service/ai_model_state_notifier.dart';
import 'package:appflowy/plugins/ai_chat/application/chat_entity.dart';
import 'package:appflowy_backend/protobuf/flowy-folder/protobuf.dart';
import 'package:flutter_bloc/flutter_bloc.dart';
@ -14,17 +14,18 @@ class AIPromptInputBloc extends Bloc<AIPromptInputEvent, AIPromptInputState> {
AIPromptInputBloc({
required String objectId,
required PredefinedFormat? predefinedFormat,
}) : _aiModelStateNotifier = AIModelStateNotifier(objectId: objectId),
super(AIPromptInputState.initial(objectId, predefinedFormat)) {
}) : aiModelStateNotifier = AIModelStateNotifier(objectId: objectId),
super(AIPromptInputState.initial(predefinedFormat)) {
_dispatch();
_startListening();
_init();
}
final AIModelStateNotifier _aiModelStateNotifier;
final AIModelStateNotifier aiModelStateNotifier;
@override
Future<void> close() async {
await _aiModelStateNotifier.stop();
await aiModelStateNotifier.dispose();
return super.close();
}
@ -36,7 +37,6 @@ class AIPromptInputBloc extends Bloc<AIPromptInputEvent, AIPromptInputState> {
emit(
state.copyWith(
aiType: aiType,
supportChatWithFile: false,
editable: editable,
hintText: hintText,
),
@ -103,16 +103,17 @@ class AIPromptInputBloc extends Bloc<AIPromptInputEvent, AIPromptInputState> {
);
}
void _init() {
_aiModelStateNotifier.startListening(
onChanged: (aiType, editable, hintText) {
if (!isClosed) {
add(AIPromptInputEvent.updateAIState(aiType, editable, hintText));
}
void _startListening() {
aiModelStateNotifier.addListener(
onStateChanged: (aiType, editable, hintText) {
add(AIPromptInputEvent.updateAIState(aiType, editable, hintText));
},
);
}
_aiModelStateNotifier.init();
void _init() {
final (aiType, hintText, isEditable) = aiModelStateNotifier.getState();
add(AIPromptInputEvent.updateAIState(aiType, isEditable, hintText));
}
Map<String, dynamic> consumeMetadata() {
@ -155,7 +156,6 @@ class AIPromptInputEvent with _$AIPromptInputEvent {
@freezed
class AIPromptInputState with _$AIPromptInputState {
const factory AIPromptInputState({
required String objectId,
required AiType aiType,
required bool supportChatWithFile,
required bool showPredefinedFormats,
@ -166,12 +166,8 @@ class AIPromptInputState with _$AIPromptInputState {
required String hintText,
}) = _AIPromptInputState;
factory AIPromptInputState.initial(
String objectId,
PredefinedFormat? format,
) =>
factory AIPromptInputState.initial(PredefinedFormat? format) =>
AIPromptInputState(
objectId: objectId,
aiType: AiType.cloud,
supportChatWithFile: false,
showPredefinedFormats: format != null,
@ -182,11 +178,3 @@ class AIPromptInputState with _$AIPromptInputState {
hintText: '',
);
}
enum AiType {
cloud,
local;
bool get isCloud => this == cloud;
bool get isLocal => this == local;
}

View File

@ -1,66 +1,66 @@
import 'dart:async';
import 'package:appflowy/ai/service/ai_input_control.dart';
import 'package:appflowy/ai/service/ai_model_state_notifier.dart';
import 'package:appflowy_backend/dispatch/dispatch.dart';
import 'package:appflowy_backend/protobuf/flowy-ai/entities.pbserver.dart';
import 'package:flutter_bloc/flutter_bloc.dart';
import 'package:freezed_annotation/freezed_annotation.dart';
import 'package:protobuf/protobuf.dart';
part 'select_model_bloc.freezed.dart';
class SelectModelBloc extends Bloc<SelectModelEvent, SelectModelState> {
SelectModelBloc({
required this.objectId,
}) : _aiModelStateNotifier = AIModelStateNotifier(objectId: objectId),
super(const SelectModelState()) {
_aiModelStateNotifier.init();
_aiModelStateNotifier.startListening(
onAvailableModelsChanged: (models) {
if (!isClosed) {
add(SelectModelEvent.didLoadModels(models));
}
},
);
required AIModelStateNotifier aiModelStateNotifier,
}) : _aiModelStateNotifier = aiModelStateNotifier,
super(SelectModelState.initial(aiModelStateNotifier)) {
on<SelectModelEvent>(
(event, emit) async {
await event.when(
selectModel: (AIModelPB model) async {
await AIEventUpdateSelectedModel(
(event, emit) {
event.when(
selectModel: (model) {
AIEventUpdateSelectedModel(
UpdateSelectedModelPB(
source: objectId,
source: _aiModelStateNotifier.objectId,
selectedModel: model,
),
).send();
state.availableModels?.freeze();
final newAvailableModels = state.availableModels?.rebuild((m) {
m.selectedModel = model;
});
emit(state.copyWith(selectedModel: model));
},
didLoadModels: (models, selectedModel) {
emit(
state.copyWith(
availableModels: newAvailableModels,
SelectModelState(
models: models,
selectedModel: selectedModel,
),
);
},
didLoadModels: (AvailableModelsPB models) {
emit(state.copyWith(availableModels: models));
},
);
},
);
_aiModelStateNotifier.addListener(
onAvailableModelsChanged: _onAvailableModelsChanged,
);
}
final String objectId;
final AIModelStateNotifier _aiModelStateNotifier;
@override
Future<void> close() async {
await _aiModelStateNotifier.stop();
_aiModelStateNotifier.removeListener(
onAvailableModelsChanged: _onAvailableModelsChanged,
);
await super.close();
}
void _onAvailableModelsChanged(
List<AIModelPB> models,
AIModelPB? selectedModel,
) {
if (!isClosed) {
add(SelectModelEvent.didLoadModels(models, selectedModel));
}
}
}
@freezed
@ -70,13 +70,23 @@ class SelectModelEvent with _$SelectModelEvent {
) = _SelectModel;
const factory SelectModelEvent.didLoadModels(
AvailableModelsPB models,
List<AIModelPB> models,
AIModelPB? selectedModel,
) = _DidLoadModels;
}
@freezed
class SelectModelState with _$SelectModelState {
const factory SelectModelState({
AvailableModelsPB? availableModels,
required List<AIModelPB> models,
required AIModelPB? selectedModel,
}) = _SelectModelState;
factory SelectModelState.initial(AIModelStateNotifier notifier) {
final (models, selectedModel) = notifier.getAvailableModels();
return SelectModelState(
models: models,
selectedModel: selectedModel,
);
}
}

View File

@ -1,17 +1,14 @@
import 'package:appflowy/ai/ai.dart';
import 'package:appflowy/ai/service/select_model_bloc.dart';
import 'package:appflowy/generated/locale_keys.g.dart';
import 'package:appflowy/plugins/ai_chat/application/chat_input_control_cubit.dart';
import 'package:appflowy/plugins/ai_chat/presentation/layout_define.dart';
import 'package:appflowy/startup/startup.dart';
import 'package:appflowy/util/theme_extension.dart';
import 'package:appflowy_backend/protobuf/flowy-ai/entities.pb.dart';
import 'package:appflowy_backend/protobuf/flowy-folder/protobuf.dart';
import 'package:easy_localization/easy_localization.dart';
import 'package:extended_text_field/extended_text_field.dart';
import 'package:flowy_infra/file_picker/file_picker_service.dart';
import 'package:flowy_infra_ui/flowy_infra_ui.dart';
import 'package:flowy_infra_ui/style_widget/hover.dart';
import 'package:flutter/material.dart';
import 'package:flutter/services.dart';
import 'package:flutter_bloc/flutter_bloc.dart';
@ -169,7 +166,6 @@ class _DesktopPromptInputState extends State<DesktopPromptInput> {
top: null,
child: TextFieldTapRegion(
child: _PromptBottomActions(
objectId: state.objectId,
showPredefinedFormats:
state.showPredefinedFormats,
onTogglePredefinedFormatSection: () =>
@ -567,7 +563,6 @@ class PromptInputTextField extends StatelessWidget {
class _PromptBottomActions extends StatelessWidget {
const _PromptBottomActions({
required this.objectId,
required this.sendButtonState,
required this.showPredefinedFormats,
required this.onTogglePredefinedFormatSection,
@ -579,7 +574,6 @@ class _PromptBottomActions extends StatelessWidget {
this.extraBottomActionButton,
});
final String objectId;
final bool showPredefinedFormats;
final void Function() onTogglePredefinedFormatSection;
final void Function() onStartMention;
@ -600,10 +594,16 @@ class _PromptBottomActions extends StatelessWidget {
return Row(
children: [
_predefinedFormatButton(),
SelectModelButton(objectId: objectId),
const HSpace(
DesktopAIChatSizes.inputActionBarButtonSpacing,
),
SelectModelMenu(
aiModelStateNotifier:
context.read<AIPromptInputBloc>().aiModelStateNotifier,
),
const Spacer(),
if (state.aiType.isCloud) ...[
_selectSourcesButton(context),
_selectSourcesButton(),
const HSpace(
DesktopAIChatSizes.inputActionBarButtonSpacing,
),
@ -639,7 +639,7 @@ class _PromptBottomActions extends StatelessWidget {
);
}
Widget _selectSourcesButton(BuildContext context) {
Widget _selectSourcesButton() {
return PromptInputDesktopSelectSourcesButton(
onUpdateSelectedSources: onUpdateSelectedSources,
selectedSourcesNotifier: selectedSourcesNotifier,
@ -686,225 +686,3 @@ class _PromptBottomActions extends StatelessWidget {
);
}
}
class SelectModelButton extends StatefulWidget {
const SelectModelButton({
super.key,
required this.objectId,
});
final String objectId;
@override
State<SelectModelButton> createState() => _SelectModelButtonState();
}
class _SelectModelButtonState extends State<SelectModelButton> {
final popoverController = PopoverController();
late SelectModelBloc bloc;
@override
void initState() {
super.initState();
bloc = SelectModelBloc(objectId: widget.objectId);
}
@override
void dispose() {
popoverController.close();
bloc.close();
super.dispose();
}
@override
Widget build(BuildContext context) {
return BlocProvider.value(
value: bloc,
child: BlocBuilder<SelectModelBloc, SelectModelState>(
builder: (context, state) {
return AppFlowyPopover(
// constraints: BoxConstraints.loose(const Size(250, 200)),
offset: const Offset(0.0, -10.0),
direction: PopoverDirection.topWithLeftAligned,
margin: EdgeInsets.zero,
controller: popoverController,
onOpen: () {},
onClose: () {},
popupBuilder: (_) {
return BlocProvider.value(
value: bloc,
child: _PopoverSelectModel(
onClose: () => popoverController.close(),
),
);
},
child: _CurrentModelButton(
key: ValueKey(state.availableModels?.selectedModel.name),
modelName: state.availableModels?.selectedModel.name ?? "",
onTap: () => popoverController.show(),
),
);
},
),
);
}
}
class _PopoverSelectModel extends StatelessWidget {
const _PopoverSelectModel({
required this.onClose,
});
final VoidCallback onClose;
@override
Widget build(BuildContext context) {
return BlocBuilder<SelectModelBloc, SelectModelState>(
builder: (context, state) {
if (state.availableModels == null ||
state.availableModels!.models.isEmpty) {
return const SizedBox.shrink();
}
// Separate models into local and cloud models
final localModels = state.availableModels!.models
.where((model) => model.isLocal)
.toList();
final cloudModels = state.availableModels!.models
.where((model) => !model.isLocal)
.toList();
return Padding(
padding: const EdgeInsets.fromLTRB(8, 4, 8, 12),
child: Column(
mainAxisSize: MainAxisSize.min,
crossAxisAlignment: CrossAxisAlignment.start,
children: [
// Local AI Models Section
if (localModels.isNotEmpty) ...[
_ModelSectionHeader(
title: LocaleKeys.chat_changeFormat_localModel.tr(),
),
const SizedBox(height: 4),
...localModels.map(
(model) => _ModelItem(
model: model,
onTap: () {
context.read<SelectModelBloc>().add(
SelectModelEvent.selectModel(model),
);
onClose();
},
),
),
const SizedBox(height: 8),
],
// Cloud AI Models Section
if (cloudModels.isNotEmpty) ...[
if (localModels.isNotEmpty)
_ModelSectionHeader(
title: LocaleKeys.chat_changeFormat_cloudModel.tr(),
),
const VSpace(4),
...cloudModels.map(
(model) => _ModelItem(
model: model,
onTap: () {
context.read<SelectModelBloc>().add(
SelectModelEvent.selectModel(model),
);
onClose();
},
),
),
],
],
),
);
},
);
}
}
class _ModelSectionHeader extends StatelessWidget {
const _ModelSectionHeader({
required this.title,
});
final String title;
@override
Widget build(BuildContext context) {
return Padding(
padding: const EdgeInsets.only(top: 4, bottom: 2),
child: FlowyText(
title,
fontSize: 12,
color: Theme.of(context).hintColor,
fontWeight: FontWeight.w500,
),
);
}
}
class _ModelItem extends StatelessWidget {
const _ModelItem({
required this.model,
required this.onTap,
});
final AIModelPB model;
final VoidCallback onTap;
@override
Widget build(BuildContext context) {
final modelName = model.name;
return FlowyTextButton(
modelName,
fillColor: Colors.transparent,
onPressed: onTap,
);
}
}
class _CurrentModelButton extends StatelessWidget {
const _CurrentModelButton({
required this.modelName,
required this.onTap,
super.key,
});
final String modelName;
final VoidCallback onTap;
@override
Widget build(BuildContext context) {
return FlowyTooltip(
message: LocaleKeys.chat_changeFormat_switchModel.tr(),
child: GestureDetector(
onTap: onTap,
behavior: HitTestBehavior.opaque,
child: SizedBox(
height: DesktopAIPromptSizes.actionBarButtonSize,
child: FlowyHover(
style: const HoverStyle(
borderRadius: BorderRadius.all(Radius.circular(8)),
),
child: Padding(
padding: const EdgeInsetsDirectional.fromSTEB(6, 6, 4, 6),
child: FlowyText(
modelName,
fontSize: 12,
figmaLineHeight: 16,
color: Theme.of(context).hintColor,
),
),
),
),
),
);
}
}

View File

@ -104,6 +104,7 @@ class ChangeFormatBar extends StatelessWidget {
},
child: FlowyTooltip(
message: format.i18n,
preferBelow: false,
child: SizedBox.square(
dimension: _buttonSize,
child: FlowyHover(
@ -150,6 +151,7 @@ class ChangeFormatBar extends StatelessWidget {
},
child: FlowyTooltip(
message: format.i18n,
preferBelow: false,
child: SizedBox.square(
dimension: _buttonSize,
child: FlowyHover(

View File

@ -0,0 +1,222 @@
import 'package:appflowy/ai/ai.dart';
import 'package:appflowy/generated/flowy_svgs.g.dart';
import 'package:appflowy/generated/locale_keys.g.dart';
import 'package:appflowy_backend/protobuf/flowy-ai/protobuf.dart';
import 'package:easy_localization/easy_localization.dart';
import 'package:flowy_infra_ui/flowy_infra_ui.dart';
import 'package:flowy_infra_ui/style_widget/hover.dart';
import 'package:flutter/material.dart';
import 'package:flutter_bloc/flutter_bloc.dart';
class SelectModelMenu extends StatefulWidget {
const SelectModelMenu({
super.key,
required this.aiModelStateNotifier,
});
final AIModelStateNotifier aiModelStateNotifier;
@override
State<SelectModelMenu> createState() => _SelectModelMenuState();
}
class _SelectModelMenuState extends State<SelectModelMenu> {
final popoverController = PopoverController();
@override
Widget build(BuildContext context) {
return BlocProvider(
create: (context) => SelectModelBloc(
aiModelStateNotifier: widget.aiModelStateNotifier,
),
child: BlocBuilder<SelectModelBloc, SelectModelState>(
builder: (context, state) {
if (state.selectedModel == null) {
return const SizedBox.shrink();
}
return AppFlowyPopover(
offset: Offset(-12.0, 0.0),
constraints: BoxConstraints(maxWidth: 250, maxHeight: 600),
direction: PopoverDirection.topWithLeftAligned,
margin: EdgeInsets.zero,
controller: popoverController,
popupBuilder: (popoverContext) {
return SelectModelPopoverContent(
models: state.models,
selectedModel: state.selectedModel,
onSelectModel: (model) {
if (model != state.selectedModel) {
context
.read<SelectModelBloc>()
.add(SelectModelEvent.selectModel(model));
}
popoverController.close();
},
);
},
child: _CurrentModelButton(
modelName: state.selectedModel!.name,
onTap: () => popoverController.show(),
),
);
},
),
);
}
}
class SelectModelPopoverContent extends StatelessWidget {
const SelectModelPopoverContent({
super.key,
required this.models,
required this.selectedModel,
this.onSelectModel,
});
final List<AIModelPB> models;
final AIModelPB? selectedModel;
final void Function(AIModelPB)? onSelectModel;
@override
Widget build(BuildContext context) {
if (models.isEmpty) {
return const SizedBox.shrink();
}
// separate models into local and cloud models
final localModels = models.where((model) => model.isLocal).toList();
final cloudModels = models.where((model) => !model.isLocal).toList();
return Padding(
padding: const EdgeInsets.all(8.0),
child: Column(
mainAxisSize: MainAxisSize.min,
crossAxisAlignment: CrossAxisAlignment.start,
children: [
if (localModels.isNotEmpty) ...[
_ModelSectionHeader(
title: LocaleKeys.chat_switchModel_localModel.tr(),
),
const VSpace(4.0),
],
...localModels.map(
(model) => _ModelItem(
model: model,
isSelected: model == selectedModel,
onTap: () => onSelectModel?.call(model),
),
),
if (cloudModels.isNotEmpty && localModels.isNotEmpty) ...[
const VSpace(8.0),
_ModelSectionHeader(
title: LocaleKeys.chat_switchModel_cloudModel.tr(),
),
const VSpace(4.0),
],
...cloudModels.map(
(model) => _ModelItem(
model: model,
isSelected: model == selectedModel,
onTap: () => onSelectModel?.call(model),
),
),
],
),
);
}
}
class _ModelSectionHeader extends StatelessWidget {
const _ModelSectionHeader({
required this.title,
});
final String title;
@override
Widget build(BuildContext context) {
return Padding(
padding: const EdgeInsets.only(top: 4, bottom: 2),
child: FlowyText(
title,
fontSize: 12,
figmaLineHeight: 16,
color: Theme.of(context).hintColor,
fontWeight: FontWeight.w500,
),
);
}
}
class _ModelItem extends StatelessWidget {
const _ModelItem({
required this.model,
required this.isSelected,
required this.onTap,
});
final AIModelPB model;
final bool isSelected;
final VoidCallback onTap;
@override
Widget build(BuildContext context) {
return SizedBox(
height: 32,
child: FlowyButton(
onTap: onTap,
margin: EdgeInsets.symmetric(horizontal: 8.0, vertical: 6.0),
text: FlowyText(model.name),
rightIcon: isSelected ? FlowySvg(FlowySvgs.check_s) : null,
),
);
}
}
class _CurrentModelButton extends StatelessWidget {
const _CurrentModelButton({
required this.modelName,
required this.onTap,
});
final String modelName;
final VoidCallback onTap;
@override
Widget build(BuildContext context) {
return FlowyTooltip(
message: LocaleKeys.chat_switchModel_label.tr(),
child: GestureDetector(
onTap: onTap,
behavior: HitTestBehavior.opaque,
child: SizedBox(
height: DesktopAIPromptSizes.actionBarButtonSize,
child: FlowyHover(
style: const HoverStyle(
borderRadius: BorderRadius.all(Radius.circular(8)),
),
child: Padding(
padding: const EdgeInsetsDirectional.all(4.0),
child: Row(
children: [
FlowyText(
modelName,
fontSize: 12,
figmaLineHeight: 16,
color: Theme.of(context).hintColor,
),
HSpace(2.0),
FlowySvg(
FlowySvgs.ai_source_drop_down_s,
color: Theme.of(context).hintColor,
size: const Size.square(8),
),
],
),
),
),
),
),
);
}
}

View File

@ -239,9 +239,9 @@ class ChatBloc extends Bloc<ChatEvent, ChatState> {
),
);
},
regenerateAnswer: (id, format) {
regenerateAnswer: (id, format, model) {
_clearRelatedQuestions();
_regenerateAnswer(id, format);
_regenerateAnswer(id, format, model);
lastSentMessage = null;
isFetchingRelatedQuestions = false;
@ -483,6 +483,7 @@ class ChatBloc extends Bloc<ChatEvent, ChatState> {
void _regenerateAnswer(
String answerMessageIdString,
PredefinedFormat? format,
AIModelPB? model,
) async {
final id = temporaryMessageIDMap.entries
.firstWhereOrNull((e) => e.value == answerMessageIdString)
@ -505,6 +506,9 @@ class ChatBloc extends Bloc<ChatEvent, ChatState> {
if (format != null) {
payload.format = format.toPB();
}
if (model != null) {
payload.model = model;
}
await AIEventRegenerateResponse(payload).send().fold(
(success) {
@ -637,6 +641,7 @@ class ChatEvent with _$ChatEvent {
const factory ChatEvent.regenerateAnswer(
String id,
PredefinedFormat? format,
AIModelPB? model,
) = _RegenerateAnswer;
// streaming answer

View File

@ -265,10 +265,13 @@ class _ChatContentPage extends StatelessWidget {
_onSelectMetadata(context, metadata),
onRegenerate: () => context
.read<ChatBloc>()
.add(ChatEvent.regenerateAnswer(message.id, null)),
.add(ChatEvent.regenerateAnswer(message.id, null, null)),
onChangeFormat: (format) => context
.read<ChatBloc>()
.add(ChatEvent.regenerateAnswer(message.id, format)),
.add(ChatEvent.regenerateAnswer(message.id, format, null)),
onChangeModel: (model) => context
.read<ChatBloc>()
.add(ChatEvent.regenerateAnswer(message.id, null, model)),
onStopStream: () => context.read<ChatBloc>().add(
const ChatEvent.stopStream(),
),

View File

@ -0,0 +1,145 @@
import 'package:appflowy/generated/locale_keys.g.dart';
import 'package:appflowy/mobile/presentation/base/app_bar/app_bar_actions.dart';
import 'package:appflowy/mobile/presentation/bottom_sheet/bottom_sheet.dart';
import 'package:appflowy/mobile/presentation/widgets/widgets.dart';
import 'package:appflowy_backend/protobuf/flowy-ai/protobuf.dart';
import 'package:collection/collection.dart';
import 'package:easy_localization/easy_localization.dart';
import 'package:flowy_infra_ui/flowy_infra_ui.dart';
import 'package:flutter/material.dart';
Future<AIModelPB?> showChangeModelBottomSheet(
BuildContext context,
List<AIModelPB> models,
) {
return showMobileBottomSheet<AIModelPB?>(
context,
showDragHandle: true,
builder: (context) => _ChangeModelBottomSheetContent(models: models),
);
}
class _ChangeModelBottomSheetContent extends StatefulWidget {
const _ChangeModelBottomSheetContent({
required this.models,
});
final List<AIModelPB> models;
@override
State<_ChangeModelBottomSheetContent> createState() =>
_ChangeModelBottomSheetContentState();
}
class _ChangeModelBottomSheetContentState
extends State<_ChangeModelBottomSheetContent> {
AIModelPB? model;
@override
Widget build(BuildContext context) {
return Column(
mainAxisSize: MainAxisSize.min,
children: [
_Header(
onCancel: () => Navigator.of(context).pop(),
onDone: () => Navigator.of(context).pop(model),
),
const VSpace(4.0),
_Body(
models: widget.models,
selectedModel: model,
onSelectModel: (format) {
setState(() => model = format);
},
),
const VSpace(16.0),
],
);
}
}
class _Header extends StatelessWidget {
const _Header({
required this.onCancel,
required this.onDone,
});
final VoidCallback onCancel;
final VoidCallback onDone;
@override
Widget build(BuildContext context) {
return SizedBox(
height: 44.0,
child: Stack(
children: [
Align(
alignment: Alignment.centerLeft,
child: AppBarBackButton(
padding: const EdgeInsets.symmetric(
vertical: 12,
horizontal: 16,
),
onTap: onCancel,
),
),
Align(
child: Container(
constraints: const BoxConstraints(maxWidth: 250),
child: FlowyText(
LocaleKeys.chat_switchModel_label.tr(),
fontSize: 17.0,
fontWeight: FontWeight.w500,
overflow: TextOverflow.ellipsis,
),
),
),
Align(
alignment: Alignment.centerRight,
child: AppBarDoneButton(
onTap: onDone,
),
),
],
),
);
}
}
class _Body extends StatelessWidget {
const _Body({
required this.models,
required this.selectedModel,
required this.onSelectModel,
});
final List<AIModelPB> models;
final AIModelPB? selectedModel;
final void Function(AIModelPB) onSelectModel;
@override
Widget build(BuildContext context) {
return Column(
mainAxisSize: MainAxisSize.min,
children: models
.mapIndexed(
(index, model) => _buildModelButton(model, index == 0),
)
.toList(),
);
}
Widget _buildModelButton(
AIModelPB model, [
bool isFirst = false,
]) {
return FlowyOptionTile.checkbox(
text: model.name,
isSelected: model == selectedModel,
showTopBorder: isFirst,
onTap: () {
onSelectModel(model);
},
);
}
}

View File

@ -21,6 +21,7 @@ import 'package:appflowy/workspace/application/view/view_ext.dart';
import 'package:appflowy/workspace/presentation/home/menu/sidebar/space/shared_widget.dart';
import 'package:appflowy/workspace/presentation/home/menu/view/view_item.dart';
import 'package:appflowy/workspace/presentation/widgets/dialogs.dart';
import 'package:appflowy_backend/protobuf/flowy-ai/protobuf.dart';
import 'package:appflowy_backend/protobuf/flowy-folder/protobuf.dart';
import 'package:appflowy_editor/appflowy_editor.dart';
import 'package:appflowy_result/appflowy_result.dart';
@ -41,6 +42,7 @@ class AIMessageActionBar extends StatefulWidget {
required this.showDecoration,
this.onRegenerate,
this.onChangeFormat,
this.onChangeModel,
this.onOverrideVisibility,
});
@ -48,6 +50,7 @@ class AIMessageActionBar extends StatefulWidget {
final bool showDecoration;
final void Function()? onRegenerate;
final void Function(PredefinedFormat)? onChangeFormat;
final void Function(AIModelPB)? onChangeModel;
final void Function(bool)? onOverrideVisibility;
@override
@ -126,6 +129,12 @@ class _AIMessageActionBarState extends State<AIMessageActionBar> {
popoverMutex: popoverMutex,
onOverrideVisibility: widget.onOverrideVisibility,
),
ChangeModelButton(
isInHoverBar: widget.showDecoration,
onRegenerate: widget.onChangeModel,
popoverMutex: popoverMutex,
onOverrideVisibility: widget.onOverrideVisibility,
),
SaveToPageButton(
textMessage: widget.message as TextMessage,
isInHoverBar: widget.showDecoration,
@ -405,6 +414,85 @@ class _ChangeFormatPopoverContentState
}
}
class ChangeModelButton extends StatefulWidget {
const ChangeModelButton({
super.key,
required this.isInHoverBar,
this.popoverMutex,
this.onRegenerate,
this.onOverrideVisibility,
});
final bool isInHoverBar;
final PopoverMutex? popoverMutex;
final void Function(AIModelPB)? onRegenerate;
final void Function(bool)? onOverrideVisibility;
@override
State<ChangeModelButton> createState() => _ChangeModelButtonState();
}
class _ChangeModelButtonState extends State<ChangeModelButton> {
final popoverController = PopoverController();
@override
Widget build(BuildContext context) {
return AppFlowyPopover(
controller: popoverController,
mutex: widget.popoverMutex,
triggerActions: PopoverTriggerFlags.none,
margin: EdgeInsets.zero,
offset: Offset(8, 0),
direction: PopoverDirection.rightWithBottomAligned,
constraints: BoxConstraints(maxWidth: 250, maxHeight: 600),
onClose: () => widget.onOverrideVisibility?.call(false),
child: buildButton(context),
popupBuilder: (_) {
final bloc = context.read<AIPromptInputBloc>();
final (models, _) = bloc.aiModelStateNotifier.getAvailableModels();
return SelectModelPopoverContent(
models: models,
selectedModel: null,
onSelectModel: widget.onRegenerate,
);
},
);
}
Widget buildButton(BuildContext context) {
return FlowyTooltip(
message: LocaleKeys.chat_switchModel_label.tr(),
child: FlowyIconButton(
width: 32.0,
height: DesktopAIChatSizes.messageActionBarIconSize,
hoverColor: AFThemeExtension.of(context).lightGreyHover,
radius: widget.isInHoverBar
? DesktopAIChatSizes.messageHoverActionBarIconRadius
: DesktopAIChatSizes.messageActionBarIconRadius,
icon: Row(
mainAxisSize: MainAxisSize.min,
children: [
FlowySvg(
FlowySvgs.ai_sparks_s,
color: Theme.of(context).hintColor,
size: const Size.square(16),
),
FlowySvg(
FlowySvgs.ai_source_drop_down_s,
color: Theme.of(context).hintColor,
size: const Size.square(8),
),
],
),
onPressed: () {
widget.onOverrideVisibility?.call(true);
popoverController.show();
},
),
);
}
}
class SaveToPageButton extends StatefulWidget {
const SaveToPageButton({
super.key,

View File

@ -12,6 +12,7 @@ import 'package:appflowy/shared/markdown_to_document.dart';
import 'package:appflowy/startup/startup.dart';
import 'package:appflowy/workspace/application/view/view_ext.dart';
import 'package:appflowy/workspace/presentation/widgets/dialogs.dart';
import 'package:appflowy_backend/protobuf/flowy-ai/protobuf.dart';
import 'package:easy_localization/easy_localization.dart';
import 'package:flowy_infra/theme_extension.dart';
import 'package:flowy_infra_ui/flowy_infra_ui.dart';
@ -23,6 +24,7 @@ import 'package:universal_platform/universal_platform.dart';
import '../chat_avatar.dart';
import '../layout_define.dart';
import 'ai_change_model_bottom_sheet.dart';
import 'ai_message_action_bar.dart';
import 'ai_change_format_bottom_sheet.dart';
import 'message_util.dart';
@ -41,6 +43,7 @@ class ChatAIMessageBubble extends StatelessWidget {
this.isSelectingMessages = false,
this.onRegenerate,
this.onChangeFormat,
this.onChangeModel,
});
final Message message;
@ -50,6 +53,7 @@ class ChatAIMessageBubble extends StatelessWidget {
final bool isSelectingMessages;
final void Function()? onRegenerate;
final void Function(PredefinedFormat)? onChangeFormat;
final void Function(AIModelPB)? onChangeModel;
@override
Widget build(BuildContext context) {
@ -73,6 +77,7 @@ class ChatAIMessageBubble extends StatelessWidget {
message: message,
onRegenerate: onRegenerate,
onChangeFormat: onChangeFormat,
onChangeModel: onChangeModel,
child: child,
);
}
@ -82,6 +87,7 @@ class ChatAIMessageBubble extends StatelessWidget {
message: message,
onRegenerate: onRegenerate,
onChangeFormat: onChangeFormat,
onChangeModel: onChangeModel,
child: child,
);
}
@ -91,6 +97,7 @@ class ChatAIMessageBubble extends StatelessWidget {
message: message,
onRegenerate: onRegenerate,
onChangeFormat: onChangeFormat,
onChangeModel: onChangeModel,
child: child,
);
}
@ -103,12 +110,14 @@ class ChatAIBottomInlineActions extends StatelessWidget {
required this.message,
this.onRegenerate,
this.onChangeFormat,
this.onChangeModel,
});
final Widget child;
final Message message;
final void Function()? onRegenerate;
final void Function(PredefinedFormat)? onChangeFormat;
final void Function(AIModelPB)? onChangeModel;
@override
Widget build(BuildContext context) {
@ -127,6 +136,7 @@ class ChatAIBottomInlineActions extends StatelessWidget {
showDecoration: false,
onRegenerate: onRegenerate,
onChangeFormat: onChangeFormat,
onChangeModel: onChangeModel,
),
),
const VSpace(32.0),
@ -142,12 +152,14 @@ class ChatAIMessageHover extends StatefulWidget {
required this.message,
this.onRegenerate,
this.onChangeFormat,
this.onChangeModel,
});
final Widget child;
final Message message;
final void Function()? onRegenerate;
final void Function(PredefinedFormat)? onChangeFormat;
final void Function(AIModelPB)? onChangeModel;
@override
State<ChatAIMessageHover> createState() => _ChatAIMessageHoverState();
@ -229,6 +241,7 @@ class _ChatAIMessageHoverState extends State<ChatAIMessageHover> {
showDecoration: true,
onRegenerate: widget.onRegenerate,
onChangeFormat: widget.onChangeFormat,
onChangeModel: widget.onChangeModel,
onOverrideVisibility: (visibility) {
overrideVisibility = visibility;
},
@ -302,12 +315,14 @@ class ChatAIMessagePopup extends StatelessWidget {
required this.message,
this.onRegenerate,
this.onChangeFormat,
this.onChangeModel,
});
final Widget child;
final Message message;
final void Function()? onRegenerate;
final void Function(PredefinedFormat)? onChangeFormat;
final void Function(AIModelPB)? onChangeModel;
@override
Widget build(BuildContext context) {
@ -328,6 +343,8 @@ class ChatAIMessagePopup extends StatelessWidget {
_divider(),
_changeFormatButton(context),
_divider(),
_changeModelButton(context),
_divider(),
_saveToPageButton(context),
],
);
@ -399,6 +416,25 @@ class ChatAIMessagePopup extends StatelessWidget {
);
}
Widget _changeModelButton(BuildContext context) {
return MobileQuickActionButton(
onTap: () async {
final bloc = context.read<AIPromptInputBloc>();
final (models, _) = bloc.aiModelStateNotifier.getAvailableModels();
final result = await showChangeModelBottomSheet(context, models);
if (result != null) {
onChangeModel?.call(result);
if (context.mounted) {
Navigator.of(context).pop();
}
}
},
icon: FlowySvgs.ai_sparks_s,
iconSize: const Size.square(20),
text: LocaleKeys.chat_switchModel_label.tr(),
);
}
Widget _saveToPageButton(BuildContext context) {
return MobileQuickActionButton(
onTap: () async {

View File

@ -4,6 +4,7 @@ import 'package:appflowy/plugins/ai_chat/application/chat_ai_message_bloc.dart';
import 'package:appflowy/plugins/ai_chat/application/chat_bloc.dart';
import 'package:appflowy/plugins/ai_chat/application/chat_entity.dart';
import 'package:appflowy/plugins/ai_chat/application/chat_message_stream.dart';
import 'package:appflowy_backend/protobuf/flowy-ai/protobuf.dart';
import 'package:easy_localization/easy_localization.dart';
import 'package:fixnum/fixnum.dart';
import 'package:flowy_infra_ui/flowy_infra_ui.dart';
@ -36,6 +37,7 @@ class ChatAIMessageWidget extends StatelessWidget {
this.onSelectedMetadata,
this.onRegenerate,
this.onChangeFormat,
this.onChangeModel,
this.isLastMessage = false,
this.isStreaming = false,
this.isSelectingMessages = false,
@ -53,6 +55,7 @@ class ChatAIMessageWidget extends StatelessWidget {
final void Function()? onRegenerate;
final void Function() onStopStream;
final void Function(PredefinedFormat)? onChangeFormat;
final void Function(AIModelPB)? onChangeModel;
final bool isStreaming;
final bool isLastMessage;
final bool isSelectingMessages;
@ -110,6 +113,7 @@ class ChatAIMessageWidget extends StatelessWidget {
isSelectingMessages: isSelectingMessages,
onRegenerate: onRegenerate,
onChangeFormat: onChangeFormat,
onChangeModel: onChangeModel,
child: Column(
crossAxisAlignment: CrossAxisAlignment.start,
children: [

View File

@ -247,14 +247,16 @@
"table": "Table",
"blankDescription": "Format response",
"defaultDescription": "Auto mode",
"localModel": "Local Model",
"cloudModel": "Cloud Model",
"switchModel": "Switch model",
"textWithImageDescription": "@:chat.changeFormat.text with image",
"numberWithImageDescription": "@:chat.changeFormat.number with image",
"bulletWithImageDescription": "@:chat.changeFormat.bullet with image",
"tableWithImageDescription": "@:chat.changeFormat.table with image"
},
"switchModel": {
"label": "Switch model",
"localModel": "Local Model",
"cloudModel": "Cloud Model"
},
"selectBanner": {
"saveButton": "Add to …",
"selectMessages": "Select messages",
@ -3199,4 +3201,4 @@
"rewrite": "Rewrite",
"insertBelow": "Insert below"
}
}
}

View File

@ -248,22 +248,23 @@ impl AIManager {
answer_message_id: i64,
answer_stream_port: i64,
format: Option<PredefinedFormatPB>,
model: Option<AIModelPB>,
) -> FlowyResult<()> {
let chat = self.get_or_create_chat_instance(chat_id).await?;
let question_message_id = chat
.get_question_id_from_answer_id(answer_message_id)
.await?;
let preferred_model = self
.store_preferences
.get_object::<AIModel>(&ai_available_models_key(chat_id));
let model = model.map_or_else(
|| {
self
.store_preferences
.get_object::<AIModel>(&ai_available_models_key(chat_id))
},
|model| Some(model.into()),
);
chat
.stream_regenerate_response(
question_message_id,
answer_stream_port,
format,
preferred_model,
)
.stream_regenerate_response(question_message_id, answer_stream_port, format, model)
.await?;
Ok(())
}

View File

@ -102,6 +102,9 @@ pub struct RegenerateResponsePB {
#[pb(index = 4, one_of)]
pub format: Option<PredefinedFormatPB>,
#[pb(index = 5, one_of)]
pub model: Option<AIModelPB>,
}
#[derive(Default, ProtoBuf, Validate, Clone, Debug)]

View File

@ -99,6 +99,7 @@ pub(crate) async fn regenerate_response_handler(
data.answer_message_id,
data.answer_stream_port,
data.format,
data.model,
)
.await?;
Ok(())