mirror of
https://github.com/AppFlowy-IO/AppFlowy.git
synced 2025-12-24 21:56:47 +00:00
chore: support switch ai model in chat or ai writer
This commit is contained in:
parent
ad695e43b9
commit
05949d2f87
138
frontend/appflowy_flutter/lib/ai/service/ai_input_control.dart
Normal file
138
frontend/appflowy_flutter/lib/ai/service/ai_input_control.dart
Normal file
@ -0,0 +1,138 @@
|
||||
import 'package:appflowy/ai/service/ai_prompt_input_bloc.dart';
|
||||
import 'package:appflowy/generated/locale_keys.g.dart';
|
||||
import 'package:appflowy/plugins/ai_chat/application/ai_model_switch_listener.dart';
|
||||
import 'package:appflowy/workspace/application/settings/ai/local_llm_listener.dart';
|
||||
import 'package:appflowy_backend/dispatch/dispatch.dart';
|
||||
import 'package:appflowy_backend/log.dart';
|
||||
import 'package:appflowy_backend/protobuf/flowy-ai/entities.pb.dart';
|
||||
import 'package:appflowy_result/appflowy_result.dart';
|
||||
import 'package:easy_localization/easy_localization.dart';
|
||||
import 'package:protobuf/protobuf.dart';
|
||||
import 'package:universal_platform/universal_platform.dart';
|
||||
|
||||
class AIModelStateNotifier {
|
||||
AIModelStateNotifier({required this.objectId})
|
||||
: _isDesktop = UniversalPlatform.isDesktop,
|
||||
_localAIListener =
|
||||
UniversalPlatform.isDesktop ? LocalAIStateListener() : null,
|
||||
_aiModelSwitchListener = AIModelSwitchListener(chatId: objectId);
|
||||
|
||||
final String objectId;
|
||||
final bool _isDesktop;
|
||||
final LocalAIStateListener? _localAIListener;
|
||||
final AIModelSwitchListener _aiModelSwitchListener;
|
||||
LocalAIPB? _localAIState;
|
||||
AvailableModelsPB? _availableModels;
|
||||
|
||||
// callbacks
|
||||
void Function(AiType, bool, String)? onChanged;
|
||||
void Function(AvailableModelsPB)? onAvailableModelsChanged;
|
||||
String hintText() {
|
||||
final aiType = getCurrentAiType();
|
||||
if (aiType.isLocal) {
|
||||
return isEditable()
|
||||
? LocaleKeys.chat_inputLocalAIMessageHint.tr()
|
||||
: LocaleKeys.settings_aiPage_keys_localAIInitializing.tr();
|
||||
}
|
||||
return LocaleKeys.chat_inputMessageHint.tr();
|
||||
}
|
||||
|
||||
AiType getCurrentAiType() {
|
||||
// On non-desktop platforms, always return cloud type
|
||||
if (!_isDesktop) {
|
||||
return AiType.cloud;
|
||||
}
|
||||
|
||||
return _availableModels?.selectedModel.isLocal == true
|
||||
? AiType.local
|
||||
: AiType.cloud;
|
||||
}
|
||||
|
||||
bool isEditable() {
|
||||
// On non-desktop platforms, always editable (cloud-only)
|
||||
if (!_isDesktop) {
|
||||
return true;
|
||||
}
|
||||
|
||||
return getCurrentAiType().isLocal
|
||||
? _localAIState?.state == RunningStatePB.Running
|
||||
: true;
|
||||
}
|
||||
|
||||
void _notifyStateChanged() {
|
||||
onChanged?.call(getCurrentAiType(), isEditable(), hintText());
|
||||
}
|
||||
|
||||
Future<void> init() async {
|
||||
await _loadAvailableModels();
|
||||
}
|
||||
|
||||
Future<void> _loadAvailableModels() async {
|
||||
final payload = AvailableModelsQueryPB(source: objectId);
|
||||
final result = await AIEventGetAvailableModels(payload).send();
|
||||
result.fold(
|
||||
(models) {
|
||||
_availableModels = models;
|
||||
onAvailableModelsChanged?.call(models);
|
||||
_notifyStateChanged();
|
||||
},
|
||||
(err) {
|
||||
Log.error("Failed to get available models: $err");
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
void startListening({
|
||||
void Function(AiType, bool, String)? onChanged,
|
||||
void Function(AvailableModelsPB)? onAvailableModelsChanged,
|
||||
}) {
|
||||
this.onChanged = onChanged;
|
||||
this.onAvailableModelsChanged = onAvailableModelsChanged;
|
||||
|
||||
// Only start local AI listener on desktop platforms
|
||||
if (_isDesktop) {
|
||||
_localAIListener?.start(
|
||||
stateCallback: (state) {
|
||||
_localAIState = state;
|
||||
|
||||
if (state.state == RunningStatePB.Running ||
|
||||
state.state == RunningStatePB.Stopped) {
|
||||
_loadAvailableModels();
|
||||
}
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
_aiModelSwitchListener.start(
|
||||
onUpdateSelectedModel: (model) {
|
||||
if (_availableModels != null) {
|
||||
final updatedModels = _availableModels!.deepCopy()
|
||||
..selectedModel = model;
|
||||
_availableModels = updatedModels;
|
||||
onAvailableModelsChanged?.call(updatedModels);
|
||||
}
|
||||
|
||||
if (model.isLocal && _isDesktop) {
|
||||
AIEventGetLocalAIState().send().fold(
|
||||
(localAIState) {
|
||||
_localAIState = localAIState;
|
||||
_notifyStateChanged();
|
||||
},
|
||||
(error) {
|
||||
Log.error("Failed to get local AI state: $error");
|
||||
_notifyStateChanged();
|
||||
},
|
||||
);
|
||||
} else {
|
||||
_notifyStateChanged();
|
||||
}
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
Future<void> stop() async {
|
||||
onChanged = null;
|
||||
await _localAIListener?.stop();
|
||||
await _aiModelSwitchListener.stop();
|
||||
}
|
||||
}
|
||||
@ -1,14 +1,8 @@
|
||||
import 'dart:async';
|
||||
|
||||
import 'package:appflowy/generated/locale_keys.g.dart';
|
||||
import 'package:appflowy/ai/service/ai_input_control.dart';
|
||||
import 'package:appflowy/plugins/ai_chat/application/chat_entity.dart';
|
||||
import 'package:appflowy/workspace/application/settings/ai/local_llm_listener.dart';
|
||||
import 'package:appflowy_backend/dispatch/dispatch.dart';
|
||||
import 'package:appflowy_backend/log.dart';
|
||||
import 'package:appflowy_backend/protobuf/flowy-ai/entities.pb.dart';
|
||||
import 'package:appflowy_backend/protobuf/flowy-folder/protobuf.dart';
|
||||
import 'package:appflowy_result/appflowy_result.dart';
|
||||
import 'package:easy_localization/easy_localization.dart';
|
||||
import 'package:flutter_bloc/flutter_bloc.dart';
|
||||
import 'package:freezed_annotation/freezed_annotation.dart';
|
||||
|
||||
@ -18,19 +12,19 @@ part 'ai_prompt_input_bloc.freezed.dart';
|
||||
|
||||
class AIPromptInputBloc extends Bloc<AIPromptInputEvent, AIPromptInputState> {
|
||||
AIPromptInputBloc({
|
||||
required String objectId,
|
||||
required PredefinedFormat? predefinedFormat,
|
||||
}) : _listener = LocalAIStateListener(),
|
||||
super(AIPromptInputState.initial(predefinedFormat)) {
|
||||
}) : _aiModelStateNotifier = AIModelStateNotifier(objectId: objectId),
|
||||
super(AIPromptInputState.initial(objectId, predefinedFormat)) {
|
||||
_dispatch();
|
||||
_startListening();
|
||||
_init();
|
||||
}
|
||||
|
||||
final LocalAIStateListener _listener;
|
||||
final AIModelStateNotifier _aiModelStateNotifier;
|
||||
|
||||
@override
|
||||
Future<void> close() async {
|
||||
await _listener.stop();
|
||||
await _aiModelStateNotifier.stop();
|
||||
return super.close();
|
||||
}
|
||||
|
||||
@ -38,29 +32,11 @@ class AIPromptInputBloc extends Bloc<AIPromptInputEvent, AIPromptInputState> {
|
||||
on<AIPromptInputEvent>(
|
||||
(event, emit) {
|
||||
event.when(
|
||||
updateAIState: (localAIState) {
|
||||
final aiType = localAIState.enabled ? AiType.local : AiType.cloud;
|
||||
// final supportChatWithFile =
|
||||
// aiType.isLocal && localAIState.state == RunningStatePB.Running;
|
||||
// If local ai is enabled, user can only send messages when the AI is running
|
||||
final editable = localAIState.enabled
|
||||
? localAIState.state == RunningStatePB.Running
|
||||
: true;
|
||||
|
||||
var hintText = aiType.isLocal
|
||||
? LocaleKeys.chat_inputLocalAIMessageHint.tr()
|
||||
: LocaleKeys.chat_inputMessageHint.tr();
|
||||
|
||||
if (editable == false && aiType.isLocal) {
|
||||
hintText =
|
||||
LocaleKeys.settings_aiPage_keys_localAIInitializing.tr();
|
||||
}
|
||||
|
||||
updateAIState: (aiType, editable, hintText) {
|
||||
emit(
|
||||
state.copyWith(
|
||||
aiType: aiType,
|
||||
supportChatWithFile: false,
|
||||
localAIState: localAIState,
|
||||
editable: editable,
|
||||
hintText: hintText,
|
||||
),
|
||||
@ -127,25 +103,16 @@ class AIPromptInputBloc extends Bloc<AIPromptInputEvent, AIPromptInputState> {
|
||||
);
|
||||
}
|
||||
|
||||
void _startListening() {
|
||||
_listener.start(
|
||||
stateCallback: (pluginState) {
|
||||
if (!isClosed) {
|
||||
add(AIPromptInputEvent.updateAIState(pluginState));
|
||||
}
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
void _init() {
|
||||
AIEventGetLocalAIState().send().fold(
|
||||
(localAIState) {
|
||||
_aiModelStateNotifier.startListening(
|
||||
onChanged: (aiType, editable, hintText) {
|
||||
if (!isClosed) {
|
||||
add(AIPromptInputEvent.updateAIState(localAIState));
|
||||
add(AIPromptInputEvent.updateAIState(aiType, editable, hintText));
|
||||
}
|
||||
},
|
||||
Log.error,
|
||||
);
|
||||
|
||||
_aiModelStateNotifier.init();
|
||||
}
|
||||
|
||||
Map<String, dynamic> consumeMetadata() {
|
||||
@ -164,8 +131,12 @@ class AIPromptInputBloc extends Bloc<AIPromptInputEvent, AIPromptInputState> {
|
||||
|
||||
@freezed
|
||||
class AIPromptInputEvent with _$AIPromptInputEvent {
|
||||
const factory AIPromptInputEvent.updateAIState(LocalAIPB localAIState) =
|
||||
_UpdateAIState;
|
||||
const factory AIPromptInputEvent.updateAIState(
|
||||
AiType aiType,
|
||||
bool editable,
|
||||
String hintText,
|
||||
) = _UpdateAIState;
|
||||
|
||||
const factory AIPromptInputEvent.toggleShowPredefinedFormat() =
|
||||
_ToggleShowPredefinedFormat;
|
||||
const factory AIPromptInputEvent.updatePredefinedFormat(
|
||||
@ -184,24 +155,27 @@ class AIPromptInputEvent with _$AIPromptInputEvent {
|
||||
@freezed
|
||||
class AIPromptInputState with _$AIPromptInputState {
|
||||
const factory AIPromptInputState({
|
||||
required String objectId,
|
||||
required AiType aiType,
|
||||
required bool supportChatWithFile,
|
||||
required bool showPredefinedFormats,
|
||||
required PredefinedFormat? predefinedFormat,
|
||||
required LocalAIPB? localAIState,
|
||||
required List<ChatFile> attachedFiles,
|
||||
required List<ViewPB> mentionedPages,
|
||||
required bool editable,
|
||||
required String hintText,
|
||||
}) = _AIPromptInputState;
|
||||
|
||||
factory AIPromptInputState.initial(PredefinedFormat? format) =>
|
||||
factory AIPromptInputState.initial(
|
||||
String objectId,
|
||||
PredefinedFormat? format,
|
||||
) =>
|
||||
AIPromptInputState(
|
||||
objectId: objectId,
|
||||
aiType: AiType.cloud,
|
||||
supportChatWithFile: false,
|
||||
showPredefinedFormats: format != null,
|
||||
predefinedFormat: format,
|
||||
localAIState: null,
|
||||
attachedFiles: [],
|
||||
mentionedPages: [],
|
||||
editable: true,
|
||||
|
||||
@ -0,0 +1,82 @@
|
||||
import 'dart:async';
|
||||
|
||||
import 'package:appflowy/ai/service/ai_input_control.dart';
|
||||
import 'package:appflowy_backend/dispatch/dispatch.dart';
|
||||
import 'package:appflowy_backend/protobuf/flowy-ai/entities.pbserver.dart';
|
||||
import 'package:flutter_bloc/flutter_bloc.dart';
|
||||
import 'package:freezed_annotation/freezed_annotation.dart';
|
||||
import 'package:protobuf/protobuf.dart';
|
||||
|
||||
part 'select_model_bloc.freezed.dart';
|
||||
|
||||
class SelectModelBloc extends Bloc<SelectModelEvent, SelectModelState> {
|
||||
SelectModelBloc({
|
||||
required this.objectId,
|
||||
}) : _aiModelStateNotifier = AIModelStateNotifier(objectId: objectId),
|
||||
super(const SelectModelState()) {
|
||||
_aiModelStateNotifier.init();
|
||||
_aiModelStateNotifier.startListening(
|
||||
onAvailableModelsChanged: (models) {
|
||||
if (!isClosed) {
|
||||
add(SelectModelEvent.didLoadModels(models));
|
||||
}
|
||||
},
|
||||
);
|
||||
|
||||
on<SelectModelEvent>(
|
||||
(event, emit) async {
|
||||
await event.when(
|
||||
selectModel: (AIModelPB model) async {
|
||||
await AIEventUpdateSelectedModel(
|
||||
UpdateSelectedModelPB(
|
||||
source: objectId,
|
||||
selectedModel: model,
|
||||
),
|
||||
).send();
|
||||
|
||||
state.availableModels?.freeze();
|
||||
final newAvailableModels = state.availableModels?.rebuild((m) {
|
||||
m.selectedModel = model;
|
||||
});
|
||||
|
||||
emit(
|
||||
state.copyWith(
|
||||
availableModels: newAvailableModels,
|
||||
),
|
||||
);
|
||||
},
|
||||
didLoadModels: (AvailableModelsPB models) {
|
||||
emit(state.copyWith(availableModels: models));
|
||||
},
|
||||
);
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
final String objectId;
|
||||
final AIModelStateNotifier _aiModelStateNotifier;
|
||||
|
||||
@override
|
||||
Future<void> close() async {
|
||||
await _aiModelStateNotifier.stop();
|
||||
await super.close();
|
||||
}
|
||||
}
|
||||
|
||||
@freezed
|
||||
class SelectModelEvent with _$SelectModelEvent {
|
||||
const factory SelectModelEvent.selectModel(
|
||||
AIModelPB model,
|
||||
) = _SelectModel;
|
||||
|
||||
const factory SelectModelEvent.didLoadModels(
|
||||
AvailableModelsPB models,
|
||||
) = _DidLoadModels;
|
||||
}
|
||||
|
||||
@freezed
|
||||
class SelectModelState with _$SelectModelState {
|
||||
const factory SelectModelState({
|
||||
AvailableModelsPB? availableModels,
|
||||
}) = _SelectModelState;
|
||||
}
|
||||
@ -1,14 +1,17 @@
|
||||
import 'package:appflowy/ai/ai.dart';
|
||||
import 'package:appflowy/ai/service/select_model_bloc.dart';
|
||||
import 'package:appflowy/generated/locale_keys.g.dart';
|
||||
import 'package:appflowy/plugins/ai_chat/application/chat_input_control_cubit.dart';
|
||||
import 'package:appflowy/plugins/ai_chat/presentation/layout_define.dart';
|
||||
import 'package:appflowy/startup/startup.dart';
|
||||
import 'package:appflowy/util/theme_extension.dart';
|
||||
import 'package:appflowy_backend/protobuf/flowy-ai/entities.pb.dart';
|
||||
import 'package:appflowy_backend/protobuf/flowy-folder/protobuf.dart';
|
||||
import 'package:easy_localization/easy_localization.dart';
|
||||
import 'package:extended_text_field/extended_text_field.dart';
|
||||
import 'package:flowy_infra/file_picker/file_picker_service.dart';
|
||||
import 'package:flowy_infra_ui/flowy_infra_ui.dart';
|
||||
import 'package:flowy_infra_ui/style_widget/hover.dart';
|
||||
import 'package:flutter/material.dart';
|
||||
import 'package:flutter/services.dart';
|
||||
import 'package:flutter_bloc/flutter_bloc.dart';
|
||||
@ -165,6 +168,7 @@ class _DesktopPromptInputState extends State<DesktopPromptInput> {
|
||||
top: null,
|
||||
child: TextFieldTapRegion(
|
||||
child: _PromptBottomActions(
|
||||
objectId: state.objectId,
|
||||
showPredefinedFormats:
|
||||
state.showPredefinedFormats,
|
||||
onTogglePredefinedFormatSection: () =>
|
||||
@ -561,6 +565,7 @@ class PromptInputTextField extends StatelessWidget {
|
||||
|
||||
class _PromptBottomActions extends StatelessWidget {
|
||||
const _PromptBottomActions({
|
||||
required this.objectId,
|
||||
required this.sendButtonState,
|
||||
required this.showPredefinedFormats,
|
||||
required this.onTogglePredefinedFormatSection,
|
||||
@ -572,6 +577,7 @@ class _PromptBottomActions extends StatelessWidget {
|
||||
this.extraBottomActionButton,
|
||||
});
|
||||
|
||||
final String objectId;
|
||||
final bool showPredefinedFormats;
|
||||
final void Function() onTogglePredefinedFormatSection;
|
||||
final void Function() onStartMention;
|
||||
@ -589,15 +595,10 @@ class _PromptBottomActions extends StatelessWidget {
|
||||
margin: DesktopAIChatSizes.inputActionBarMargin,
|
||||
child: BlocBuilder<AIPromptInputBloc, AIPromptInputState>(
|
||||
builder: (context, state) {
|
||||
if (state.localAIState == null) {
|
||||
return Align(
|
||||
alignment: AlignmentDirectional.centerEnd,
|
||||
child: _sendButton(),
|
||||
);
|
||||
}
|
||||
return Row(
|
||||
children: [
|
||||
_predefinedFormatButton(),
|
||||
SelectModelButton(objectId: objectId),
|
||||
const Spacer(),
|
||||
if (state.aiType.isCloud) ...[
|
||||
_selectSourcesButton(context),
|
||||
@ -683,3 +684,160 @@ class _PromptBottomActions extends StatelessWidget {
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
class SelectModelButton extends StatefulWidget {
|
||||
const SelectModelButton({
|
||||
super.key,
|
||||
required this.objectId,
|
||||
});
|
||||
|
||||
final String objectId;
|
||||
|
||||
@override
|
||||
State<SelectModelButton> createState() => _SelectModelButtonState();
|
||||
}
|
||||
|
||||
class _SelectModelButtonState extends State<SelectModelButton> {
|
||||
final popoverController = PopoverController();
|
||||
late SelectModelBloc bloc;
|
||||
|
||||
@override
|
||||
void initState() {
|
||||
super.initState();
|
||||
bloc = SelectModelBloc(objectId: widget.objectId);
|
||||
}
|
||||
|
||||
@override
|
||||
void dispose() {
|
||||
popoverController.close();
|
||||
bloc.close();
|
||||
super.dispose();
|
||||
}
|
||||
|
||||
@override
|
||||
Widget build(BuildContext context) {
|
||||
return BlocProvider.value(
|
||||
value: bloc,
|
||||
child: BlocBuilder<SelectModelBloc, SelectModelState>(
|
||||
builder: (context, state) {
|
||||
return AppFlowyPopover(
|
||||
// constraints: BoxConstraints.loose(const Size(250, 200)),
|
||||
offset: const Offset(0.0, -10.0),
|
||||
direction: PopoverDirection.topWithLeftAligned,
|
||||
margin: EdgeInsets.zero,
|
||||
controller: popoverController,
|
||||
onOpen: () {},
|
||||
onClose: () {},
|
||||
popupBuilder: (_) {
|
||||
return BlocProvider.value(
|
||||
value: bloc,
|
||||
child: _PopoverSelectModel(
|
||||
onClose: () => popoverController.close(),
|
||||
),
|
||||
);
|
||||
},
|
||||
child: _CurrentModelButton(
|
||||
modelName: state.availableModels?.selectedModel.name ?? "",
|
||||
onTap: () => popoverController.show(),
|
||||
),
|
||||
);
|
||||
},
|
||||
),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
class _PopoverSelectModel extends StatelessWidget {
|
||||
const _PopoverSelectModel({
|
||||
required this.onClose,
|
||||
});
|
||||
|
||||
final VoidCallback onClose;
|
||||
|
||||
@override
|
||||
Widget build(BuildContext context) {
|
||||
return BlocBuilder<SelectModelBloc, SelectModelState>(
|
||||
builder: (context, state) {
|
||||
return ListView.builder(
|
||||
shrinkWrap: true,
|
||||
itemCount: state.availableModels?.models.length ?? 0,
|
||||
padding: const EdgeInsets.fromLTRB(8, 4, 8, 12),
|
||||
itemBuilder: (context, index) {
|
||||
return _ModelItem(
|
||||
model: state.availableModels!.models[index],
|
||||
onTap: () {
|
||||
context.read<SelectModelBloc>().add(
|
||||
SelectModelEvent.selectModel(
|
||||
state.availableModels!.models[index],
|
||||
),
|
||||
);
|
||||
onClose();
|
||||
},
|
||||
);
|
||||
},
|
||||
);
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
class _ModelItem extends StatelessWidget {
|
||||
const _ModelItem({
|
||||
required this.model,
|
||||
required this.onTap,
|
||||
});
|
||||
|
||||
final AIModelPB model;
|
||||
final VoidCallback onTap;
|
||||
|
||||
@override
|
||||
Widget build(BuildContext context) {
|
||||
var modelName = model.name;
|
||||
if (model.isLocal) {
|
||||
modelName += " (${LocaleKeys.chat_changeFormat_localModel.tr()})";
|
||||
}
|
||||
return FlowyTextButton(
|
||||
modelName,
|
||||
fillColor: Colors.transparent,
|
||||
onPressed: onTap,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
class _CurrentModelButton extends StatelessWidget {
|
||||
const _CurrentModelButton({
|
||||
required this.modelName,
|
||||
required this.onTap,
|
||||
});
|
||||
|
||||
final String modelName;
|
||||
final VoidCallback onTap;
|
||||
|
||||
@override
|
||||
Widget build(BuildContext context) {
|
||||
return FlowyTooltip(
|
||||
message: LocaleKeys.chat_changeFormat_switchModel.tr(),
|
||||
child: GestureDetector(
|
||||
onTap: onTap,
|
||||
behavior: HitTestBehavior.opaque,
|
||||
child: SizedBox(
|
||||
height: DesktopAIPromptSizes.actionBarButtonSize,
|
||||
child: FlowyHover(
|
||||
style: const HoverStyle(
|
||||
borderRadius: BorderRadius.all(Radius.circular(8)),
|
||||
),
|
||||
child: Padding(
|
||||
padding: const EdgeInsetsDirectional.fromSTEB(6, 6, 4, 6),
|
||||
child: FlowyText(
|
||||
modelName,
|
||||
fontSize: 12,
|
||||
figmaLineHeight: 16,
|
||||
color: Theme.of(context).hintColor,
|
||||
),
|
||||
),
|
||||
),
|
||||
),
|
||||
),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@ -0,0 +1,53 @@
|
||||
import 'dart:async';
|
||||
import 'dart:typed_data';
|
||||
|
||||
import 'package:appflowy/plugins/ai_chat/application/chat_notification.dart';
|
||||
import 'package:appflowy_backend/protobuf/flowy-ai/entities.pb.dart';
|
||||
import 'package:appflowy_backend/protobuf/flowy-ai/notification.pb.dart';
|
||||
import 'package:appflowy_backend/protobuf/flowy-error/errors.pb.dart';
|
||||
import 'package:appflowy_backend/protobuf/flowy-notification/subject.pb.dart';
|
||||
import 'package:appflowy_backend/rust_stream.dart';
|
||||
import 'package:appflowy_result/appflowy_result.dart';
|
||||
|
||||
typedef OnUpdateSelectedModel = void Function(AIModelPB model);
|
||||
|
||||
class AIModelSwitchListener {
|
||||
AIModelSwitchListener({required this.chatId}) {
|
||||
_parser = ChatNotificationParser(id: chatId, callback: _callback);
|
||||
_subscription = RustStreamReceiver.listen(
|
||||
(observable) => _parser?.parse(observable),
|
||||
);
|
||||
}
|
||||
|
||||
final String chatId;
|
||||
StreamSubscription<SubscribeObject>? _subscription;
|
||||
ChatNotificationParser? _parser;
|
||||
|
||||
void start({
|
||||
OnUpdateSelectedModel? onUpdateSelectedModel,
|
||||
}) {
|
||||
this.onUpdateSelectedModel = onUpdateSelectedModel;
|
||||
}
|
||||
|
||||
OnUpdateSelectedModel? onUpdateSelectedModel;
|
||||
|
||||
void _callback(
|
||||
ChatNotification ty,
|
||||
FlowyResult<Uint8List, FlowyError> result,
|
||||
) {
|
||||
result.map((r) {
|
||||
switch (ty) {
|
||||
case ChatNotification.DidUpdateSelectedModel:
|
||||
onUpdateSelectedModel?.call(AIModelPB.fromBuffer(r));
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
Future<void> stop() async {
|
||||
await _subscription?.cancel();
|
||||
_subscription = null;
|
||||
}
|
||||
}
|
||||
@ -73,6 +73,7 @@ class AIChatPage extends StatelessWidget {
|
||||
/// [AIPromptInputBloc] is used to handle the user prompt
|
||||
BlocProvider(
|
||||
create: (_) => AIPromptInputBloc(
|
||||
objectId: view.id,
|
||||
predefinedFormat: PredefinedFormat(
|
||||
imageFormat: ImageFormat.text,
|
||||
textFormat: TextFormat.bulletList,
|
||||
|
||||
@ -2,6 +2,7 @@ import 'package:appflowy/ai/ai.dart';
|
||||
import 'package:appflowy/generated/flowy_svgs.g.dart';
|
||||
import 'package:appflowy/generated/locale_keys.g.dart';
|
||||
import 'package:appflowy/plugins/ai_chat/presentation/message/ai_markdown_text.dart';
|
||||
import 'package:appflowy/plugins/document/application/document_bloc.dart';
|
||||
import 'package:appflowy/util/theme_extension.dart';
|
||||
import 'package:appflowy/workspace/application/view/view_bloc.dart';
|
||||
import 'package:appflowy_editor/appflowy_editor.dart';
|
||||
@ -124,9 +125,12 @@ class _AIWriterBlockComponentState extends State<AiWriterBlockComponent> {
|
||||
return const SizedBox.shrink();
|
||||
}
|
||||
|
||||
final documentId = context.read<DocumentBloc?>()?.documentId;
|
||||
|
||||
return BlocProvider(
|
||||
create: (_) => AIPromptInputBloc(
|
||||
predefinedFormat: null,
|
||||
objectId: documentId ?? editorState.document.root.id,
|
||||
),
|
||||
child: LayoutBuilder(
|
||||
builder: (context, constraints) {
|
||||
|
||||
@ -173,7 +173,7 @@ class SettingsAIBloc extends Bloc<SettingsAIEvent, SettingsAIState> {
|
||||
}
|
||||
|
||||
void _loadModelList() {
|
||||
AIEventGetAvailableModels().send().then((result) {
|
||||
AIEventGetServerAvailableModels().send().then((result) {
|
||||
result.fold((config) {
|
||||
if (!isClosed) {
|
||||
add(SettingsAIEvent.didLoadAvailableModels(config.models));
|
||||
|
||||
@ -247,6 +247,8 @@
|
||||
"table": "Table",
|
||||
"blankDescription": "Format response",
|
||||
"defaultDescription": "Auto mode",
|
||||
"localModel": "Local Model",
|
||||
"switchModel": "Switch model",
|
||||
"textWithImageDescription": "@:chat.changeFormat.text with image",
|
||||
"numberWithImageDescription": "@:chat.changeFormat.number with image",
|
||||
"bulletWithImageDescription": "@:chat.changeFormat.bullet with image",
|
||||
@ -3191,4 +3193,4 @@
|
||||
"rewrite": "Rewrite",
|
||||
"insertBelow": "Insert below"
|
||||
}
|
||||
}
|
||||
}
|
||||
43
frontend/rust-lib/Cargo.lock
generated
43
frontend/rust-lib/Cargo.lock
generated
@ -163,7 +163,7 @@ checksum = "c1fd03a028ef38ba2276dce7e33fcd6369c158a1bca17946c4b1b701891c1ff7"
|
||||
[[package]]
|
||||
name = "app-error"
|
||||
version = "0.1.0"
|
||||
source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=1d06321789482dd31a44d515eefa06e00871d8ad#1d06321789482dd31a44d515eefa06e00871d8ad"
|
||||
source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=d77dacae32bc25440ca61f675900b80fba6cc9a2#d77dacae32bc25440ca61f675900b80fba6cc9a2"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"bincode",
|
||||
@ -183,7 +183,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "appflowy-ai-client"
|
||||
version = "0.1.0"
|
||||
source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=1d06321789482dd31a44d515eefa06e00871d8ad#1d06321789482dd31a44d515eefa06e00871d8ad"
|
||||
source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=d77dacae32bc25440ca61f675900b80fba6cc9a2#d77dacae32bc25440ca61f675900b80fba6cc9a2"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"bytes",
|
||||
@ -788,7 +788,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "client-api"
|
||||
version = "0.2.0"
|
||||
source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=1d06321789482dd31a44d515eefa06e00871d8ad#1d06321789482dd31a44d515eefa06e00871d8ad"
|
||||
source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=d77dacae32bc25440ca61f675900b80fba6cc9a2#d77dacae32bc25440ca61f675900b80fba6cc9a2"
|
||||
dependencies = [
|
||||
"again",
|
||||
"anyhow",
|
||||
@ -843,7 +843,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "client-api-entity"
|
||||
version = "0.1.0"
|
||||
source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=1d06321789482dd31a44d515eefa06e00871d8ad#1d06321789482dd31a44d515eefa06e00871d8ad"
|
||||
source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=d77dacae32bc25440ca61f675900b80fba6cc9a2#d77dacae32bc25440ca61f675900b80fba6cc9a2"
|
||||
dependencies = [
|
||||
"collab-entity",
|
||||
"collab-rt-entity",
|
||||
@ -856,7 +856,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "client-websocket"
|
||||
version = "0.1.0"
|
||||
source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=1d06321789482dd31a44d515eefa06e00871d8ad#1d06321789482dd31a44d515eefa06e00871d8ad"
|
||||
source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=d77dacae32bc25440ca61f675900b80fba6cc9a2#d77dacae32bc25440ca61f675900b80fba6cc9a2"
|
||||
dependencies = [
|
||||
"futures-channel",
|
||||
"futures-util",
|
||||
@ -1129,7 +1129,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "collab-rt-entity"
|
||||
version = "0.1.0"
|
||||
source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=1d06321789482dd31a44d515eefa06e00871d8ad#1d06321789482dd31a44d515eefa06e00871d8ad"
|
||||
source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=d77dacae32bc25440ca61f675900b80fba6cc9a2#d77dacae32bc25440ca61f675900b80fba6cc9a2"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"bincode",
|
||||
@ -1151,7 +1151,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "collab-rt-protocol"
|
||||
version = "0.1.0"
|
||||
source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=1d06321789482dd31a44d515eefa06e00871d8ad#1d06321789482dd31a44d515eefa06e00871d8ad"
|
||||
source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=d77dacae32bc25440ca61f675900b80fba6cc9a2#d77dacae32bc25440ca61f675900b80fba6cc9a2"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"async-trait",
|
||||
@ -1398,7 +1398,7 @@ dependencies = [
|
||||
"cssparser-macros",
|
||||
"dtoa-short",
|
||||
"itoa",
|
||||
"phf 0.11.2",
|
||||
"phf 0.8.0",
|
||||
"smallvec",
|
||||
]
|
||||
|
||||
@ -1546,7 +1546,7 @@ checksum = "c2e66c9d817f1720209181c316d28635c050fa304f9c79e47a520882661b7308"
|
||||
[[package]]
|
||||
name = "database-entity"
|
||||
version = "0.1.0"
|
||||
source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=1d06321789482dd31a44d515eefa06e00871d8ad#1d06321789482dd31a44d515eefa06e00871d8ad"
|
||||
source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=d77dacae32bc25440ca61f675900b80fba6cc9a2#d77dacae32bc25440ca61f675900b80fba6cc9a2"
|
||||
dependencies = [
|
||||
"bincode",
|
||||
"bytes",
|
||||
@ -2072,6 +2072,7 @@ dependencies = [
|
||||
"flowy-error",
|
||||
"futures",
|
||||
"lib-infra",
|
||||
"serde",
|
||||
"serde_json",
|
||||
]
|
||||
|
||||
@ -2979,7 +2980,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "gotrue"
|
||||
version = "0.1.0"
|
||||
source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=1d06321789482dd31a44d515eefa06e00871d8ad#1d06321789482dd31a44d515eefa06e00871d8ad"
|
||||
source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=d77dacae32bc25440ca61f675900b80fba6cc9a2#d77dacae32bc25440ca61f675900b80fba6cc9a2"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"getrandom 0.2.10",
|
||||
@ -2994,7 +2995,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "gotrue-entity"
|
||||
version = "0.1.0"
|
||||
source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=1d06321789482dd31a44d515eefa06e00871d8ad#1d06321789482dd31a44d515eefa06e00871d8ad"
|
||||
source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=d77dacae32bc25440ca61f675900b80fba6cc9a2#d77dacae32bc25440ca61f675900b80fba6cc9a2"
|
||||
dependencies = [
|
||||
"app-error",
|
||||
"jsonwebtoken",
|
||||
@ -3609,7 +3610,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "infra"
|
||||
version = "0.1.0"
|
||||
source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=1d06321789482dd31a44d515eefa06e00871d8ad#1d06321789482dd31a44d515eefa06e00871d8ad"
|
||||
source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=d77dacae32bc25440ca61f675900b80fba6cc9a2#d77dacae32bc25440ca61f675900b80fba6cc9a2"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"bytes",
|
||||
@ -4646,7 +4647,7 @@ version = "0.8.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3dfb61232e34fcb633f43d12c58f83c1df82962dcdfa565a4e866ffc17dafe12"
|
||||
dependencies = [
|
||||
"phf_macros 0.8.0",
|
||||
"phf_macros",
|
||||
"phf_shared 0.8.0",
|
||||
"proc-macro-hack",
|
||||
]
|
||||
@ -4666,7 +4667,6 @@ version = "0.11.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ade2d8b8f33c7333b51bcf0428d37e217e9f32192ae4772156f65063b8ce03dc"
|
||||
dependencies = [
|
||||
"phf_macros 0.11.3",
|
||||
"phf_shared 0.11.2",
|
||||
]
|
||||
|
||||
@ -4734,19 +4734,6 @@ dependencies = [
|
||||
"syn 1.0.109",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "phf_macros"
|
||||
version = "0.11.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f84ac04429c13a7ff43785d75ad27569f2951ce0ffd30a3321230db2fc727216"
|
||||
dependencies = [
|
||||
"phf_generator 0.11.2",
|
||||
"phf_shared 0.11.2",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.94",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "phf_shared"
|
||||
version = "0.8.0"
|
||||
@ -6176,7 +6163,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "shared-entity"
|
||||
version = "0.1.0"
|
||||
source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=1d06321789482dd31a44d515eefa06e00871d8ad#1d06321789482dd31a44d515eefa06e00871d8ad"
|
||||
source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=d77dacae32bc25440ca61f675900b80fba6cc9a2#d77dacae32bc25440ca61f675900b80fba6cc9a2"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"app-error",
|
||||
|
||||
@ -103,8 +103,8 @@ dashmap = "6.0.1"
|
||||
# Run the script.add_workspace_members:
|
||||
# scripts/tool/update_client_api_rev.sh new_rev_id
|
||||
# ⚠️⚠️⚠️️
|
||||
client-api = { git = "https://github.com/AppFlowy-IO/AppFlowy-Cloud", rev = "1d06321789482dd31a44d515eefa06e00871d8ad" }
|
||||
client-api-entity = { git = "https://github.com/AppFlowy-IO/AppFlowy-Cloud", rev = "1d06321789482dd31a44d515eefa06e00871d8ad" }
|
||||
client-api = { git = "https://github.com/AppFlowy-IO/AppFlowy-Cloud", rev = "d77dacae32bc25440ca61f675900b80fba6cc9a2" }
|
||||
client-api-entity = { git = "https://github.com/AppFlowy-IO/AppFlowy-Cloud", rev = "d77dacae32bc25440ca61f675900b80fba6cc9a2" }
|
||||
|
||||
[profile.dev]
|
||||
opt-level = 0
|
||||
|
||||
@ -153,8 +153,8 @@ pub fn parse_event_crate(event_crate: &TsEventCrate) -> Vec<EventASTContext> {
|
||||
attrs
|
||||
.iter()
|
||||
.filter(|attr| !attr.attrs.event_attrs.ignore)
|
||||
.enumerate()
|
||||
.map(|(_index, variant)| EventASTContext::from(&variant.attrs))
|
||||
.into_iter()
|
||||
.map(|variant| EventASTContext::from(&variant.attrs))
|
||||
.collect::<Vec<_>>()
|
||||
},
|
||||
_ => vec![],
|
||||
|
||||
@ -1,8 +1,8 @@
|
||||
use crate::event_builder::EventBuilder;
|
||||
use crate::EventIntegrationTest;
|
||||
use flowy_ai::entities::{
|
||||
ChatMessageListPB, ChatMessageTypePB, CompleteTextPB, CompleteTextTaskPB, CompletionTypePB,
|
||||
LoadNextChatMessagePB, LoadPrevChatMessagePB, SendChatPayloadPB,
|
||||
ChatMessageListPB, ChatMessageTypePB, LoadNextChatMessagePB, LoadPrevChatMessagePB,
|
||||
SendChatPayloadPB,
|
||||
};
|
||||
use flowy_ai::event_map::AIEvent;
|
||||
use flowy_folder::entities::{CreateViewPayloadPB, ViewLayoutPB, ViewPB};
|
||||
@ -87,27 +87,4 @@ impl EventIntegrationTest {
|
||||
.await
|
||||
.parse::<ChatMessageListPB>()
|
||||
}
|
||||
|
||||
pub async fn complete_text(
|
||||
&self,
|
||||
text: &str,
|
||||
completion_type: CompletionTypePB,
|
||||
) -> CompleteTextTaskPB {
|
||||
let payload = CompleteTextPB {
|
||||
text: text.to_string(),
|
||||
completion_type,
|
||||
stream_port: 0,
|
||||
object_id: "".to_string(),
|
||||
rag_ids: vec![],
|
||||
format: None,
|
||||
history: vec![],
|
||||
custom_prompt: None,
|
||||
};
|
||||
EventBuilder::new(self.clone())
|
||||
.event(AIEvent::CompleteText)
|
||||
.payload(payload)
|
||||
.async_send()
|
||||
.await
|
||||
.parse::<CompleteTextTaskPB>()
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,19 +0,0 @@
|
||||
use event_integration_test::user_event::use_localhost_af_cloud;
|
||||
use event_integration_test::EventIntegrationTest;
|
||||
use flowy_ai::entities::CompletionTypePB;
|
||||
|
||||
use std::time::Duration;
|
||||
|
||||
#[tokio::test]
|
||||
async fn af_cloud_complete_text_test() {
|
||||
use_localhost_af_cloud().await;
|
||||
let test = EventIntegrationTest::new().await;
|
||||
test.af_cloud_sign_up().await;
|
||||
|
||||
let _workspace_id = test.get_current_workspace().await.id;
|
||||
let _task = test
|
||||
.complete_text("hello world", CompletionTypePB::MakeLonger)
|
||||
.await;
|
||||
|
||||
tokio::time::sleep(Duration::from_secs(6)).await;
|
||||
}
|
||||
@ -1,2 +1 @@
|
||||
mod ai_tool_test;
|
||||
mod chat_message_test;
|
||||
|
||||
@ -12,3 +12,4 @@ client-api = { workspace = true }
|
||||
bytes.workspace = true
|
||||
futures.workspace = true
|
||||
serde_json.workspace = true
|
||||
serde.workspace = true
|
||||
@ -14,6 +14,7 @@ pub use client_api::error::{AppResponseError, ErrorCode as AppErrorCode};
|
||||
use flowy_error::FlowyError;
|
||||
use futures::stream::BoxStream;
|
||||
use lib_infra::async_trait::async_trait;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
use std::path::Path;
|
||||
@ -21,6 +22,22 @@ use std::path::Path;
|
||||
pub type ChatMessageStream = BoxStream<'static, Result<ChatMessage, AppResponseError>>;
|
||||
pub type StreamAnswer = BoxStream<'static, Result<QuestionStreamValue, FlowyError>>;
|
||||
pub type StreamComplete = BoxStream<'static, Result<CompletionStreamValue, FlowyError>>;
|
||||
|
||||
#[derive(Debug, Eq, PartialEq, Serialize, Deserialize, Clone)]
|
||||
pub struct AIModel {
|
||||
pub name: String,
|
||||
pub is_local: bool,
|
||||
}
|
||||
|
||||
impl Default for AIModel {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
name: "default".to_string(),
|
||||
is_local: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait ChatCloudService: Send + Sync + 'static {
|
||||
async fn create_chat(
|
||||
@ -55,6 +72,7 @@ pub trait ChatCloudService: Send + Sync + 'static {
|
||||
chat_id: &str,
|
||||
message_id: i64,
|
||||
format: ResponseFormat,
|
||||
ai_model: Option<AIModel>,
|
||||
) -> Result<StreamAnswer, FlowyError>;
|
||||
|
||||
async fn get_answer(
|
||||
@ -90,6 +108,7 @@ pub trait ChatCloudService: Send + Sync + 'static {
|
||||
&self,
|
||||
workspace_id: &str,
|
||||
params: CompleteTextParams,
|
||||
ai_model: Option<AIModel>,
|
||||
) -> Result<StreamComplete, FlowyError>;
|
||||
|
||||
async fn embed_file(
|
||||
@ -121,4 +140,5 @@ pub trait ChatCloudService: Send + Sync + 'static {
|
||||
) -> Result<(), FlowyError>;
|
||||
|
||||
async fn get_available_models(&self, workspace_id: &str) -> Result<ModelList, FlowyError>;
|
||||
async fn get_workspace_default_model(&self, workspace_id: &str) -> Result<String, FlowyError>;
|
||||
}
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
use crate::chat::Chat;
|
||||
use crate::entities::{
|
||||
ChatInfoPB, ChatMessageListPB, ChatMessagePB, ChatSettingsPB, FilePB, PredefinedFormatPB,
|
||||
RepeatedRelatedQuestionPB, StreamMessageParams,
|
||||
AIModelPB, AvailableModelsPB, ChatInfoPB, ChatMessageListPB, ChatMessagePB, ChatSettingsPB,
|
||||
FilePB, PredefinedFormatPB, RepeatedRelatedQuestionPB, StreamMessageParams,
|
||||
};
|
||||
use crate::local_ai::controller::LocalAIController;
|
||||
use crate::middleware::chat_service_mw::AICloudServiceMiddleware;
|
||||
@ -10,12 +10,13 @@ use std::collections::HashMap;
|
||||
|
||||
use appflowy_plugin::manager::PluginManager;
|
||||
use dashmap::DashMap;
|
||||
use flowy_ai_pub::cloud::{ChatCloudService, ChatSettings, ModelList, UpdateChatParams};
|
||||
use flowy_ai_pub::cloud::{AIModel, ChatCloudService, ChatSettings, UpdateChatParams};
|
||||
use flowy_error::{FlowyError, FlowyResult};
|
||||
use flowy_sqlite::kv::KVStorePreferences;
|
||||
use flowy_sqlite::DBConnection;
|
||||
|
||||
use crate::notification::{chat_notification_builder, ChatNotification};
|
||||
use crate::util::ai_available_models_key;
|
||||
use collab_integrate::persistence::collab_metadata_sql::{
|
||||
batch_insert_collab_metadata, batch_select_collab_metadata, AFCollabMetadata,
|
||||
};
|
||||
@ -24,6 +25,7 @@ use lib_infra::async_trait::async_trait;
|
||||
use lib_infra::util::timestamp;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::{Arc, Weak};
|
||||
use tokio::sync::RwLock;
|
||||
use tracing::{error, info, trace};
|
||||
|
||||
pub trait AIUserService: Send + Sync + 'static {
|
||||
@ -59,7 +61,8 @@ pub struct AIManager {
|
||||
pub external_service: Arc<dyn AIExternalService>,
|
||||
chats: Arc<DashMap<String, Arc<Chat>>>,
|
||||
pub local_ai: Arc<LocalAIController>,
|
||||
store_preferences: Arc<KVStorePreferences>,
|
||||
pub store_preferences: Arc<KVStorePreferences>,
|
||||
server_models: Arc<RwLock<Vec<String>>>,
|
||||
}
|
||||
|
||||
impl AIManager {
|
||||
@ -99,6 +102,7 @@ impl AIManager {
|
||||
local_ai,
|
||||
external_service,
|
||||
store_preferences,
|
||||
server_models: Arc::new(Default::default()),
|
||||
}
|
||||
}
|
||||
|
||||
@ -114,6 +118,7 @@ impl AIManager {
|
||||
chat_id.to_string(),
|
||||
self.user_service.clone(),
|
||||
self.cloud_service_wm.clone(),
|
||||
self.store_preferences.clone(),
|
||||
))
|
||||
});
|
||||
if self.local_ai.is_running() {
|
||||
@ -205,24 +210,25 @@ impl AIManager {
|
||||
save_chat(self.user_service.sqlite_connection(*uid)?, chat_id)?;
|
||||
|
||||
let chat = Arc::new(Chat::new(
|
||||
self.user_service.user_id().unwrap(),
|
||||
self.user_service.user_id()?,
|
||||
chat_id.to_string(),
|
||||
self.user_service.clone(),
|
||||
self.cloud_service_wm.clone(),
|
||||
self.store_preferences.clone(),
|
||||
));
|
||||
self.chats.insert(chat_id.to_string(), chat.clone());
|
||||
Ok(chat)
|
||||
}
|
||||
|
||||
pub async fn stream_chat_message<'a>(
|
||||
&'a self,
|
||||
params: &'a StreamMessageParams<'a>,
|
||||
pub async fn stream_chat_message(
|
||||
&self,
|
||||
params: StreamMessageParams,
|
||||
) -> Result<ChatMessagePB, FlowyError> {
|
||||
let chat = self.get_or_create_chat_instance(params.chat_id).await?;
|
||||
let question = chat.stream_chat_message(params).await?;
|
||||
let chat = self.get_or_create_chat_instance(¶ms.chat_id).await?;
|
||||
let question = chat.stream_chat_message(¶ms).await?;
|
||||
let _ = self
|
||||
.external_service
|
||||
.notify_did_send_message(params.chat_id, params.message)
|
||||
.notify_did_send_message(¶ms.chat_id, ¶ms.message)
|
||||
.await;
|
||||
Ok(question)
|
||||
}
|
||||
@ -238,19 +244,149 @@ impl AIManager {
|
||||
let question_message_id = chat
|
||||
.get_question_id_from_answer_id(answer_message_id)
|
||||
.await?;
|
||||
|
||||
let preferred_model = self
|
||||
.store_preferences
|
||||
.get_object::<AIModel>(&ai_available_models_key(chat_id));
|
||||
chat
|
||||
.stream_regenerate_response(question_message_id, answer_stream_port, format)
|
||||
.stream_regenerate_response(
|
||||
question_message_id,
|
||||
answer_stream_port,
|
||||
format,
|
||||
preferred_model,
|
||||
)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn get_available_models(&self) -> FlowyResult<ModelList> {
|
||||
pub async fn get_workspace_select_model(&self) -> FlowyResult<String> {
|
||||
let workspace_id = self.user_service.workspace_id()?;
|
||||
let model = self
|
||||
.cloud_service_wm
|
||||
.get_workspace_default_model(&workspace_id)
|
||||
.await?;
|
||||
Ok(model)
|
||||
}
|
||||
|
||||
pub async fn get_server_available_models(&self) -> FlowyResult<Vec<String>> {
|
||||
let workspace_id = self.user_service.workspace_id()?;
|
||||
|
||||
// First, try reading from the cache.
|
||||
{
|
||||
let cached_models = self.server_models.read().await;
|
||||
if !cached_models.is_empty() {
|
||||
return Ok(cached_models.clone());
|
||||
}
|
||||
}
|
||||
|
||||
// Cache miss: fetch from the cloud.
|
||||
let list = self
|
||||
.cloud_service_wm
|
||||
.get_available_models(&workspace_id)
|
||||
.await?;
|
||||
Ok(list)
|
||||
let models = list
|
||||
.models
|
||||
.into_iter()
|
||||
.map(|m| m.name)
|
||||
.collect::<Vec<String>>();
|
||||
|
||||
// Update the cache.
|
||||
*self.server_models.write().await = models.clone();
|
||||
Ok(models)
|
||||
}
|
||||
|
||||
pub async fn update_selected_model(&self, source: String, model: AIModelPB) -> FlowyResult<()> {
|
||||
let source_key = ai_available_models_key(&source);
|
||||
self
|
||||
.store_preferences
|
||||
.set_object::<AIModel>(&source_key, &model.clone().into())?;
|
||||
|
||||
chat_notification_builder(&source, ChatNotification::DidUpdateSelectedModel)
|
||||
.payload(model)
|
||||
.send();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn get_available_models(&self, source: String) -> FlowyResult<AvailableModelsPB> {
|
||||
// Build the models list from server models and mark them as non-local.
|
||||
let mut models: Vec<AIModelPB> = self
|
||||
.get_server_available_models()
|
||||
.await?
|
||||
.into_iter()
|
||||
.map(|name| AIModelPB {
|
||||
name,
|
||||
is_local: false,
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Optionally add the local plugin model.
|
||||
if let Some(local_model) = self.local_ai.get_plugin_chat_model() {
|
||||
models.push(AIModelPB {
|
||||
name: local_model,
|
||||
is_local: true,
|
||||
});
|
||||
}
|
||||
|
||||
if models.is_empty() {
|
||||
return Ok(AvailableModelsPB {
|
||||
models,
|
||||
selected_model: None,
|
||||
});
|
||||
}
|
||||
|
||||
let source_key = ai_available_models_key(&source);
|
||||
|
||||
// Retrieve stored selected model, if any.
|
||||
let stored_selected = self.store_preferences.get_object::<AIModel>(&source_key);
|
||||
|
||||
// Get workspace default model once.
|
||||
let workspace_default = self.get_workspace_select_model().await.ok();
|
||||
|
||||
// Determine the effective selected model.
|
||||
let effective_selected = stored_selected.unwrap_or_else(|| {
|
||||
if let Some(ws_name) = workspace_default.clone() {
|
||||
let model = AIModel {
|
||||
name: ws_name,
|
||||
is_local: false,
|
||||
};
|
||||
// Store the default if not present.
|
||||
let _ = self.store_preferences.set_object(&source_key, &model);
|
||||
model
|
||||
} else {
|
||||
AIModel::default()
|
||||
}
|
||||
});
|
||||
|
||||
// Find a matching model in the available list.
|
||||
let used_model = models
|
||||
.iter()
|
||||
.find(|m| m.name == effective_selected.name)
|
||||
.cloned()
|
||||
.or_else(|| {
|
||||
// If no match, try to use the workspace default if available.
|
||||
if let Some(ws_name) = workspace_default {
|
||||
Some(AIModelPB {
|
||||
name: ws_name,
|
||||
is_local: false,
|
||||
})
|
||||
} else {
|
||||
models.first().cloned()
|
||||
}
|
||||
});
|
||||
|
||||
// Update the stored preference if a different model is used.
|
||||
if let Some(ref used) = used_model {
|
||||
if used.name != effective_selected.name {
|
||||
self
|
||||
.store_preferences
|
||||
.set_object::<AIModel>(&source_key, &AIModel::from(used.clone()))?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(AvailableModelsPB {
|
||||
models,
|
||||
selected_model: used_model,
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn get_or_create_chat_instance(&self, chat_id: &str) -> Result<Arc<Chat>, FlowyError> {
|
||||
@ -258,10 +394,11 @@ impl AIManager {
|
||||
match chat {
|
||||
None => {
|
||||
let chat = Arc::new(Chat::new(
|
||||
self.user_service.user_id().unwrap(),
|
||||
self.user_service.user_id()?,
|
||||
chat_id.to_string(),
|
||||
self.user_service.clone(),
|
||||
self.cloud_service_wm.clone(),
|
||||
self.store_preferences.clone(),
|
||||
));
|
||||
self.chats.insert(chat_id.to_string(), chat.clone());
|
||||
Ok(chat)
|
||||
@ -363,7 +500,6 @@ impl AIManager {
|
||||
|
||||
pub async fn update_rag_ids(&self, chat_id: &str, rag_ids: Vec<String>) -> FlowyResult<()> {
|
||||
info!("[Chat] update chat:{} rag ids: {:?}", chat_id, rag_ids);
|
||||
|
||||
let workspace_id = self.user_service.workspace_id()?;
|
||||
let update_setting = UpdateChatParams {
|
||||
name: None,
|
||||
|
||||
@ -10,11 +10,13 @@ use crate::persistence::{
|
||||
ChatMessageTable,
|
||||
};
|
||||
use crate::stream_message::StreamMessage;
|
||||
use crate::util::ai_available_models_key;
|
||||
use allo_isolate::Isolate;
|
||||
use flowy_ai_pub::cloud::{
|
||||
ChatCloudService, ChatMessage, MessageCursor, QuestionStreamValue, ResponseFormat,
|
||||
AIModel, ChatCloudService, ChatMessage, MessageCursor, QuestionStreamValue, ResponseFormat,
|
||||
};
|
||||
use flowy_error::{ErrorCode, FlowyError, FlowyResult};
|
||||
use flowy_sqlite::kv::KVStorePreferences;
|
||||
use flowy_sqlite::DBConnection;
|
||||
use futures::{SinkExt, StreamExt};
|
||||
use lib_infra::isolate_stream::IsolateSink;
|
||||
@ -39,6 +41,7 @@ pub struct Chat {
|
||||
latest_message_id: Arc<AtomicI64>,
|
||||
stop_stream: Arc<AtomicBool>,
|
||||
stream_buffer: Arc<Mutex<StringBuffer>>,
|
||||
store_preferences: Arc<KVStorePreferences>,
|
||||
}
|
||||
|
||||
impl Chat {
|
||||
@ -47,6 +50,7 @@ impl Chat {
|
||||
chat_id: String,
|
||||
user_service: Arc<dyn AIUserService>,
|
||||
chat_service: Arc<AICloudServiceMiddleware>,
|
||||
store_preferences: Arc<KVStorePreferences>,
|
||||
) -> Chat {
|
||||
Chat {
|
||||
uid,
|
||||
@ -57,6 +61,7 @@ impl Chat {
|
||||
latest_message_id: Default::default(),
|
||||
stop_stream: Arc::new(AtomicBool::new(false)),
|
||||
stream_buffer: Arc::new(Mutex::new(StringBuffer::default())),
|
||||
store_preferences,
|
||||
}
|
||||
}
|
||||
|
||||
@ -81,9 +86,9 @@ impl Chat {
|
||||
}
|
||||
|
||||
#[instrument(level = "info", skip_all, err)]
|
||||
pub async fn stream_chat_message<'a>(
|
||||
&'a self,
|
||||
params: &'a StreamMessageParams<'a>,
|
||||
pub async fn stream_chat_message(
|
||||
&self,
|
||||
params: &StreamMessageParams,
|
||||
) -> Result<ChatMessagePB, FlowyError> {
|
||||
trace!(
|
||||
"[Chat] stream chat message: chat_id={}, message={}, message_type={:?}, metadata={:?}, format={:?}",
|
||||
@ -113,7 +118,7 @@ impl Chat {
|
||||
.create_question(
|
||||
&workspace_id,
|
||||
&self.chat_id,
|
||||
params.message,
|
||||
¶ms.message,
|
||||
params.message_type.clone(),
|
||||
&[],
|
||||
)
|
||||
@ -138,7 +143,9 @@ impl Chat {
|
||||
// Save message to disk
|
||||
save_and_notify_message(uid, &self.chat_id, &self.user_service, question.clone())?;
|
||||
let format = params.format.clone().map(Into::into).unwrap_or_default();
|
||||
|
||||
let preferred_ai_model = self
|
||||
.store_preferences
|
||||
.get_object::<AIModel>(&ai_available_models_key(&self.chat_id));
|
||||
self.stream_response(
|
||||
params.answer_stream_port,
|
||||
answer_stream_buffer,
|
||||
@ -146,6 +153,7 @@ impl Chat {
|
||||
workspace_id,
|
||||
question.message_id,
|
||||
format,
|
||||
preferred_ai_model,
|
||||
);
|
||||
|
||||
let question_pb = ChatMessagePB::from(question);
|
||||
@ -158,6 +166,7 @@ impl Chat {
|
||||
question_id: i64,
|
||||
answer_stream_port: i64,
|
||||
format: Option<PredefinedFormatPB>,
|
||||
ai_model: Option<AIModel>,
|
||||
) -> FlowyResult<()> {
|
||||
trace!(
|
||||
"[Chat] regenerate and stream chat message: chat_id={}",
|
||||
@ -183,11 +192,13 @@ impl Chat {
|
||||
workspace_id,
|
||||
question_id,
|
||||
format,
|
||||
ai_model,
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn stream_response(
|
||||
&self,
|
||||
answer_stream_port: i64,
|
||||
@ -196,6 +207,7 @@ impl Chat {
|
||||
workspace_id: String,
|
||||
question_id: i64,
|
||||
format: ResponseFormat,
|
||||
ai_model: Option<AIModel>,
|
||||
) {
|
||||
let stop_stream = self.stop_stream.clone();
|
||||
let chat_id = self.chat_id.clone();
|
||||
@ -204,7 +216,7 @@ impl Chat {
|
||||
tokio::spawn(async move {
|
||||
let mut answer_sink = IsolateSink::new(Isolate::new(answer_stream_port));
|
||||
match cloud_service
|
||||
.stream_answer(&workspace_id, &chat_id, question_id, format)
|
||||
.stream_answer(&workspace_id, &chat_id, question_id, format, ai_model)
|
||||
.await
|
||||
{
|
||||
Ok(mut stream) => {
|
||||
|
||||
@ -4,14 +4,16 @@ use allo_isolate::Isolate;
|
||||
|
||||
use dashmap::DashMap;
|
||||
use flowy_ai_pub::cloud::{
|
||||
ChatCloudService, CompleteTextParams, CompletionMetadata, CompletionStreamValue, CompletionType,
|
||||
CustomPrompt,
|
||||
AIModel, ChatCloudService, CompleteTextParams, CompletionMetadata, CompletionStreamValue,
|
||||
CompletionType, CustomPrompt,
|
||||
};
|
||||
use flowy_error::{FlowyError, FlowyResult};
|
||||
|
||||
use futures::{SinkExt, StreamExt};
|
||||
use lib_infra::isolate_stream::IsolateSink;
|
||||
|
||||
use crate::util::ai_available_models_key;
|
||||
use flowy_sqlite::kv::KVStorePreferences;
|
||||
use std::sync::{Arc, Weak};
|
||||
use tokio::select;
|
||||
use tracing::info;
|
||||
@ -20,17 +22,20 @@ pub struct AICompletion {
|
||||
tasks: Arc<DashMap<String, tokio::sync::mpsc::Sender<()>>>,
|
||||
cloud_service: Weak<dyn ChatCloudService>,
|
||||
user_service: Weak<dyn AIUserService>,
|
||||
store_preferences: Arc<KVStorePreferences>,
|
||||
}
|
||||
|
||||
impl AICompletion {
|
||||
pub fn new(
|
||||
cloud_service: Weak<dyn ChatCloudService>,
|
||||
user_service: Weak<dyn AIUserService>,
|
||||
store_preferences: Arc<KVStorePreferences>,
|
||||
) -> Self {
|
||||
Self {
|
||||
tasks: Arc::new(DashMap::new()),
|
||||
cloud_service,
|
||||
user_service,
|
||||
store_preferences,
|
||||
}
|
||||
}
|
||||
|
||||
@ -53,7 +58,17 @@ impl AICompletion {
|
||||
.ok_or_else(FlowyError::internal)?
|
||||
.workspace_id()?;
|
||||
let (tx, rx) = tokio::sync::mpsc::channel(1);
|
||||
let task = CompletionTask::new(workspace_id, complete, self.cloud_service.clone(), rx);
|
||||
let preferred_model = self
|
||||
.store_preferences
|
||||
.get_object::<AIModel>(&ai_available_models_key(&complete.object_id));
|
||||
|
||||
let task = CompletionTask::new(
|
||||
workspace_id,
|
||||
complete,
|
||||
preferred_model,
|
||||
self.cloud_service.clone(),
|
||||
rx,
|
||||
);
|
||||
let task_id = task.task_id.clone();
|
||||
self.tasks.insert(task_id.clone(), tx);
|
||||
|
||||
@ -74,12 +89,14 @@ pub struct CompletionTask {
|
||||
stop_rx: tokio::sync::mpsc::Receiver<()>,
|
||||
context: CompleteTextPB,
|
||||
cloud_service: Weak<dyn ChatCloudService>,
|
||||
preferred_model: Option<AIModel>,
|
||||
}
|
||||
|
||||
impl CompletionTask {
|
||||
pub fn new(
|
||||
workspace_id: String,
|
||||
context: CompleteTextPB,
|
||||
preferred_model: Option<AIModel>,
|
||||
cloud_service: Weak<dyn ChatCloudService>,
|
||||
stop_rx: tokio::sync::mpsc::Receiver<()>,
|
||||
) -> Self {
|
||||
@ -89,6 +106,7 @@ impl CompletionTask {
|
||||
context,
|
||||
cloud_service,
|
||||
stop_rx,
|
||||
preferred_model,
|
||||
}
|
||||
}
|
||||
|
||||
@ -129,7 +147,7 @@ impl CompletionTask {
|
||||
|
||||
info!("start completion: {:?}", params);
|
||||
match cloud_service
|
||||
.stream_complete(&self.workspace_id, params)
|
||||
.stream_complete(&self.workspace_id, params, self.preferred_model)
|
||||
.await
|
||||
{
|
||||
Ok(mut stream) => loop {
|
||||
|
||||
@ -4,8 +4,9 @@ use std::collections::HashMap;
|
||||
use crate::local_ai::controller::LocalAISetting;
|
||||
use crate::local_ai::resource::PendingResource;
|
||||
use flowy_ai_pub::cloud::{
|
||||
ChatMessage, ChatMessageMetadata, ChatMessageType, CompletionMessage, LLMModel, OutputContent,
|
||||
OutputLayout, RelatedQuestion, RepeatedChatMessage, RepeatedRelatedQuestion, ResponseFormat,
|
||||
AIModel, ChatMessage, ChatMessageMetadata, ChatMessageType, CompletionMessage, LLMModel,
|
||||
OutputContent, OutputLayout, RelatedQuestion, RepeatedChatMessage, RepeatedRelatedQuestion,
|
||||
ResponseFormat,
|
||||
};
|
||||
use flowy_derive::{ProtoBuf, ProtoBuf_Enum};
|
||||
use lib_infra::validator_fn::required_not_empty_str;
|
||||
@ -76,9 +77,9 @@ pub struct StreamChatPayloadPB {
|
||||
}
|
||||
|
||||
#[derive(Default, Debug)]
|
||||
pub struct StreamMessageParams<'a> {
|
||||
pub chat_id: &'a str,
|
||||
pub message: &'a str,
|
||||
pub struct StreamMessageParams {
|
||||
pub chat_id: String,
|
||||
pub message: String,
|
||||
pub message_type: ChatMessageType,
|
||||
pub answer_stream_port: i64,
|
||||
pub question_stream_port: i64,
|
||||
@ -182,12 +183,65 @@ pub struct ChatMessageListPB {
|
||||
pub total: i64,
|
||||
}
|
||||
|
||||
#[derive(Default, ProtoBuf, Validate, Clone, Debug)]
|
||||
pub struct ModelConfigPB {
|
||||
#[derive(Default, ProtoBuf, Clone, Debug)]
|
||||
pub struct ServerAvailableModelsPB {
|
||||
#[pb(index = 1)]
|
||||
pub models: String,
|
||||
}
|
||||
|
||||
#[derive(Default, ProtoBuf, Validate, Clone, Debug)]
|
||||
pub struct AvailableModelsQueryPB {
|
||||
#[pb(index = 1)]
|
||||
#[validate(custom(function = "required_not_empty_str"))]
|
||||
pub source: String,
|
||||
}
|
||||
|
||||
#[derive(Default, ProtoBuf, Validate, Clone, Debug)]
|
||||
pub struct UpdateSelectedModelPB {
|
||||
#[pb(index = 1)]
|
||||
#[validate(custom(function = "required_not_empty_str"))]
|
||||
pub source: String,
|
||||
|
||||
#[pb(index = 2)]
|
||||
pub selected_model: AIModelPB,
|
||||
}
|
||||
|
||||
#[derive(Default, ProtoBuf, Clone, Debug)]
|
||||
pub struct AvailableModelsPB {
|
||||
#[pb(index = 1)]
|
||||
pub models: Vec<AIModelPB>,
|
||||
|
||||
#[pb(index = 2, one_of)]
|
||||
pub selected_model: Option<AIModelPB>,
|
||||
}
|
||||
|
||||
#[derive(Default, ProtoBuf, Clone, Debug)]
|
||||
pub struct AIModelPB {
|
||||
#[pb(index = 1)]
|
||||
pub name: String,
|
||||
|
||||
#[pb(index = 2)]
|
||||
pub is_local: bool,
|
||||
}
|
||||
|
||||
impl From<AIModel> for AIModelPB {
|
||||
fn from(model: AIModel) -> Self {
|
||||
Self {
|
||||
name: model.name,
|
||||
is_local: model.is_local,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<AIModelPB> for AIModel {
|
||||
fn from(value: AIModelPB) -> Self {
|
||||
AIModel {
|
||||
name: value.name,
|
||||
is_local: value.is_local,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<RepeatedChatMessage> for ChatMessageListPB {
|
||||
fn from(repeated_chat_message: RepeatedChatMessage) -> Self {
|
||||
let messages = repeated_chat_message
|
||||
@ -471,14 +525,14 @@ pub struct PendingResourcePB {
|
||||
pub enum PendingResourceTypePB {
|
||||
#[default]
|
||||
LocalAIAppRes = 0,
|
||||
AIModel = 1,
|
||||
ModelRes = 1,
|
||||
}
|
||||
|
||||
impl From<PendingResource> for PendingResourceTypePB {
|
||||
fn from(value: PendingResource) -> Self {
|
||||
match value {
|
||||
PendingResource::PluginExecutableNotReady { .. } => PendingResourceTypePB::LocalAIAppRes,
|
||||
_ => PendingResourceTypePB::AIModel,
|
||||
_ => PendingResourceTypePB::ModelRes,
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -559,6 +613,9 @@ pub struct UpdateChatSettingsPB {
|
||||
|
||||
#[pb(index = 2)]
|
||||
pub rag_ids: Vec<String>,
|
||||
|
||||
#[pb(index = 3)]
|
||||
pub chat_model: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Clone, ProtoBuf)]
|
||||
|
||||
@ -69,8 +69,8 @@ pub(crate) async fn stream_chat_message_handler(
|
||||
trace!("Stream chat message with metadata: {:?}", metadata);
|
||||
|
||||
let params = StreamMessageParams {
|
||||
chat_id: &chat_id,
|
||||
message: &message,
|
||||
chat_id,
|
||||
message,
|
||||
message_type,
|
||||
answer_stream_port,
|
||||
question_stream_port,
|
||||
@ -79,7 +79,7 @@ pub(crate) async fn stream_chat_message_handler(
|
||||
};
|
||||
|
||||
let ai_manager = upgrade_ai_manager(ai_manager)?;
|
||||
let result = ai_manager.stream_chat_message(¶ms).await?;
|
||||
let result = ai_manager.stream_chat_message(params).await?;
|
||||
data_result_ok(result)
|
||||
}
|
||||
|
||||
@ -103,19 +103,36 @@ pub(crate) async fn regenerate_response_handler(
|
||||
}
|
||||
|
||||
#[tracing::instrument(level = "debug", skip_all, err)]
|
||||
pub(crate) async fn get_model_list_handler(
|
||||
pub(crate) async fn get_server_model_list_handler(
|
||||
ai_manager: AFPluginState<Weak<AIManager>>,
|
||||
) -> DataResult<ModelConfigPB, FlowyError> {
|
||||
) -> DataResult<ServerAvailableModelsPB, FlowyError> {
|
||||
let ai_manager = upgrade_ai_manager(ai_manager)?;
|
||||
let available_models = ai_manager.get_available_models().await?;
|
||||
let models = available_models
|
||||
.models
|
||||
.into_iter()
|
||||
.map(|m| m.name)
|
||||
.collect::<Vec<String>>();
|
||||
|
||||
let models = ai_manager.get_server_available_models().await?;
|
||||
let models = serde_json::to_string(&json!({"models": models}))?;
|
||||
data_result_ok(ModelConfigPB { models })
|
||||
data_result_ok(ServerAvailableModelsPB { models })
|
||||
}
|
||||
|
||||
#[tracing::instrument(level = "debug", skip_all, err)]
|
||||
pub(crate) async fn get_chat_models_handler(
|
||||
data: AFPluginData<AvailableModelsQueryPB>,
|
||||
ai_manager: AFPluginState<Weak<AIManager>>,
|
||||
) -> DataResult<AvailableModelsPB, FlowyError> {
|
||||
let data = data.try_into_inner()?;
|
||||
let ai_manager = upgrade_ai_manager(ai_manager)?;
|
||||
let models = ai_manager.get_available_models(data.source).await?;
|
||||
data_result_ok(models)
|
||||
}
|
||||
|
||||
pub(crate) async fn update_selected_model_handler(
|
||||
data: AFPluginData<UpdateSelectedModelPB>,
|
||||
ai_manager: AFPluginState<Weak<AIManager>>,
|
||||
) -> Result<(), FlowyError> {
|
||||
let data = data.try_into_inner()?;
|
||||
let ai_manager = upgrade_ai_manager(ai_manager)?;
|
||||
ai_manager
|
||||
.update_selected_model(data.source, data.selected_model)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tracing::instrument(level = "debug", skip_all, err)]
|
||||
|
||||
@ -10,9 +10,14 @@ use crate::ai_manager::AIManager;
|
||||
use crate::event_handler::*;
|
||||
|
||||
pub fn init(ai_manager: Weak<AIManager>) -> AFPlugin {
|
||||
let user_service = Arc::downgrade(&ai_manager.upgrade().unwrap().user_service);
|
||||
let cloud_service = Arc::downgrade(&ai_manager.upgrade().unwrap().cloud_service_wm);
|
||||
let ai_tools = Arc::new(AICompletion::new(cloud_service, user_service));
|
||||
let strong_ai_manager = ai_manager.upgrade().unwrap();
|
||||
let user_service = Arc::downgrade(&strong_ai_manager.user_service);
|
||||
let cloud_service = Arc::downgrade(&strong_ai_manager.cloud_service_wm);
|
||||
let ai_tools = Arc::new(AICompletion::new(
|
||||
cloud_service,
|
||||
user_service,
|
||||
strong_ai_manager.store_preferences.clone(),
|
||||
));
|
||||
AFPlugin::new()
|
||||
.name("flowy-ai")
|
||||
.state(ai_manager)
|
||||
@ -35,12 +40,17 @@ pub fn init(ai_manager: Weak<AIManager>) -> AFPlugin {
|
||||
AIEvent::UpdateLocalAISetting,
|
||||
update_local_ai_setting_handler,
|
||||
)
|
||||
.event(AIEvent::GetAvailableModels, get_model_list_handler)
|
||||
.event(
|
||||
AIEvent::GetServerAvailableModels,
|
||||
get_server_model_list_handler,
|
||||
)
|
||||
.event(AIEvent::CreateChatContext, create_chat_context_handler)
|
||||
.event(AIEvent::GetChatInfo, create_chat_context_handler)
|
||||
.event(AIEvent::GetChatSettings, get_chat_settings_handler)
|
||||
.event(AIEvent::UpdateChatSettings, update_chat_settings_handler)
|
||||
.event(AIEvent::RegenerateResponse, regenerate_response_handler)
|
||||
.event(AIEvent::GetAvailableModels, get_chat_models_handler)
|
||||
.event(AIEvent::UpdateSelectedModel, update_selected_model_handler)
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, PartialEq, Eq, Debug, Display, Hash, ProtoBuf_Enum, Flowy_Event)]
|
||||
@ -105,12 +115,18 @@ pub enum AIEvent {
|
||||
#[event(input = "RegenerateResponsePB")]
|
||||
RegenerateResponse = 27,
|
||||
|
||||
#[event(output = "ModelConfigPB")]
|
||||
GetAvailableModels = 28,
|
||||
#[event(output = "ServerAvailableModelsPB")]
|
||||
GetServerAvailableModels = 28,
|
||||
|
||||
#[event(output = "LocalAISettingPB")]
|
||||
GetLocalAISetting = 29,
|
||||
|
||||
#[event(input = "LocalAISettingPB")]
|
||||
UpdateLocalAISetting = 30,
|
||||
|
||||
#[event(input = "AvailableModelsQueryPB", output = "AvailableModelsPB")]
|
||||
GetAvailableModels = 31,
|
||||
|
||||
#[event(input = "UpdateSelectedModelPB")]
|
||||
UpdateSelectedModel = 32,
|
||||
}
|
||||
|
||||
@ -11,3 +11,4 @@ pub mod notification;
|
||||
mod persistence;
|
||||
mod protobuf;
|
||||
mod stream_message;
|
||||
mod util;
|
||||
|
||||
@ -193,6 +193,10 @@ impl LocalAIController {
|
||||
/// AppFlowy store the value in local storage isolated by workspace id. Each workspace can have
|
||||
/// different settings.
|
||||
pub fn is_enabled(&self) -> bool {
|
||||
if !get_operating_system().is_desktop() {
|
||||
return false;
|
||||
}
|
||||
|
||||
if let Ok(key) = self
|
||||
.user_service
|
||||
.workspace_id()
|
||||
@ -204,6 +208,13 @@ impl LocalAIController {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_plugin_chat_model(&self) -> Option<String> {
|
||||
if !self.is_enabled() {
|
||||
return None;
|
||||
}
|
||||
Some(self.resource.get_llm_setting().chat_model_name)
|
||||
}
|
||||
|
||||
pub fn open_chat(&self, chat_id: &str) {
|
||||
if !self.is_enabled() {
|
||||
return;
|
||||
|
||||
@ -9,7 +9,7 @@ use appflowy_plugin::error::PluginError;
|
||||
use std::collections::HashMap;
|
||||
|
||||
use flowy_ai_pub::cloud::{
|
||||
AppErrorCode, AppResponseError, ChatCloudService, ChatMessage, ChatMessageMetadata,
|
||||
AIModel, AppErrorCode, AppResponseError, ChatCloudService, ChatMessage, ChatMessageMetadata,
|
||||
ChatMessageType, ChatSettings, CompleteTextParams, CompletionStream, LocalAIConfig,
|
||||
MessageCursor, ModelList, RelatedQuestion, RepeatedChatMessage, RepeatedRelatedQuestion,
|
||||
ResponseFormat, StreamAnswer, StreamComplete, SubscriptionPlan, UpdateChatParams,
|
||||
@ -25,7 +25,7 @@ use futures_util::SinkExt;
|
||||
use serde_json::{json, Value};
|
||||
use std::path::Path;
|
||||
use std::sync::{Arc, Weak};
|
||||
use tracing::trace;
|
||||
use tracing::{info, trace};
|
||||
|
||||
pub struct AICloudServiceMiddleware {
|
||||
cloud_service: Arc<dyn ChatCloudService>,
|
||||
@ -156,12 +156,19 @@ impl ChatCloudService for AICloudServiceMiddleware {
|
||||
&self,
|
||||
workspace_id: &str,
|
||||
chat_id: &str,
|
||||
question_id: i64,
|
||||
message_id: i64,
|
||||
format: ResponseFormat,
|
||||
ai_model: Option<AIModel>,
|
||||
) -> Result<StreamAnswer, FlowyError> {
|
||||
if self.local_ai.is_enabled() {
|
||||
let use_local_ai = match &ai_model {
|
||||
None => false,
|
||||
Some(model) => model.is_local,
|
||||
};
|
||||
|
||||
info!("stream_answer use model: {:?}", ai_model);
|
||||
if use_local_ai {
|
||||
if self.local_ai.is_running() {
|
||||
let row = self.get_message_record(question_id)?;
|
||||
let row = self.get_message_record(message_id)?;
|
||||
match self
|
||||
.local_ai
|
||||
.stream_question(chat_id, &row.content, Some(json!(format)), json!({}))
|
||||
@ -179,7 +186,7 @@ impl ChatCloudService for AICloudServiceMiddleware {
|
||||
} else {
|
||||
self
|
||||
.cloud_service
|
||||
.stream_answer(workspace_id, chat_id, question_id, format)
|
||||
.stream_answer(workspace_id, chat_id, message_id, format, ai_model)
|
||||
.await
|
||||
}
|
||||
}
|
||||
@ -273,34 +280,45 @@ impl ChatCloudService for AICloudServiceMiddleware {
|
||||
&self,
|
||||
workspace_id: &str,
|
||||
params: CompleteTextParams,
|
||||
ai_model: Option<AIModel>,
|
||||
) -> Result<StreamComplete, FlowyError> {
|
||||
if self.local_ai.is_running() {
|
||||
match self
|
||||
.local_ai
|
||||
.complete_text_v2(
|
||||
¶ms.text,
|
||||
params.completion_type.unwrap() as u8,
|
||||
Some(json!(params.format)),
|
||||
Some(json!(params.metadata)),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(stream) => Ok(
|
||||
CompletionStream::new(
|
||||
stream.map_err(|err| AppResponseError::new(AppErrorCode::Internal, err.to_string())),
|
||||
let use_local_ai = match &ai_model {
|
||||
None => false,
|
||||
Some(model) => model.is_local,
|
||||
};
|
||||
|
||||
info!("stream_complete use model: {:?}", ai_model);
|
||||
if use_local_ai {
|
||||
if self.local_ai.is_running() {
|
||||
match self
|
||||
.local_ai
|
||||
.complete_text_v2(
|
||||
¶ms.text,
|
||||
params.completion_type.unwrap() as u8,
|
||||
Some(json!(params.format)),
|
||||
Some(json!(params.metadata)),
|
||||
)
|
||||
.map_err(FlowyError::from)
|
||||
.boxed(),
|
||||
),
|
||||
Err(err) => {
|
||||
self.handle_plugin_error(err);
|
||||
Ok(stream::once(async { Err(FlowyError::local_ai_unavailable()) }).boxed())
|
||||
},
|
||||
.await
|
||||
{
|
||||
Ok(stream) => Ok(
|
||||
CompletionStream::new(
|
||||
stream.map_err(|err| AppResponseError::new(AppErrorCode::Internal, err.to_string())),
|
||||
)
|
||||
.map_err(FlowyError::from)
|
||||
.boxed(),
|
||||
),
|
||||
Err(err) => {
|
||||
self.handle_plugin_error(err);
|
||||
Ok(stream::once(async { Err(FlowyError::local_ai_unavailable()) }).boxed())
|
||||
},
|
||||
}
|
||||
} else {
|
||||
Err(FlowyError::local_ai_not_ready())
|
||||
}
|
||||
} else {
|
||||
self
|
||||
.cloud_service
|
||||
.stream_complete(workspace_id, params)
|
||||
.stream_complete(workspace_id, params, ai_model)
|
||||
.await
|
||||
}
|
||||
}
|
||||
@ -364,4 +382,11 @@ impl ChatCloudService for AICloudServiceMiddleware {
|
||||
async fn get_available_models(&self, workspace_id: &str) -> Result<ModelList, FlowyError> {
|
||||
self.cloud_service.get_available_models(workspace_id).await
|
||||
}
|
||||
|
||||
async fn get_workspace_default_model(&self, workspace_id: &str) -> Result<String, FlowyError> {
|
||||
self
|
||||
.cloud_service
|
||||
.get_workspace_default_model(workspace_id)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
@ -15,6 +15,7 @@ pub enum ChatNotification {
|
||||
UpdateLocalAIState = 6,
|
||||
DidUpdateChatSettings = 7,
|
||||
LocalAIResourceUpdated = 8,
|
||||
DidUpdateSelectedModel = 9,
|
||||
}
|
||||
|
||||
impl std::convert::From<ChatNotification> for i32 {
|
||||
|
||||
3
frontend/rust-lib/flowy-ai/src/util.rs
Normal file
3
frontend/rust-lib/flowy-ai/src/util.rs
Normal file
@ -0,0 +1,3 @@
|
||||
pub fn ai_available_models_key(object_id: &str) -> String {
|
||||
format!("ai_available_models_{}", object_id)
|
||||
}
|
||||
@ -74,14 +74,14 @@ dart = [
|
||||
"flowy-ai/dart",
|
||||
"flowy-storage/dart",
|
||||
]
|
||||
ts = [
|
||||
"flowy-user/tauri_ts",
|
||||
"flowy-folder/tauri_ts",
|
||||
"flowy-search/tauri_ts",
|
||||
"flowy-database2/ts",
|
||||
"flowy-ai/tauri_ts",
|
||||
"flowy-storage/tauri_ts",
|
||||
]
|
||||
#ts = [
|
||||
# "flowy-user/tauri_ts",
|
||||
# "flowy-folder/tauri_ts",
|
||||
# "flowy-search/tauri_ts",
|
||||
# "flowy-database2/ts",
|
||||
# "flowy-ai/tauri_ts",
|
||||
# "flowy-storage/tauri_ts",
|
||||
#]
|
||||
openssl_vendored = ["flowy-sqlite/openssl_vendored"]
|
||||
|
||||
# Enable/Disable AppFlowy Verbose Log Configuration
|
||||
|
||||
@ -21,7 +21,7 @@ use collab_integrate::collab_builder::{
|
||||
CollabCloudPluginProvider, CollabPluginProviderContext, CollabPluginProviderType,
|
||||
};
|
||||
use flowy_ai_pub::cloud::{
|
||||
ChatCloudService, ChatMessage, ChatMessageMetadata, ChatMessageType, ChatSettings,
|
||||
AIModel, ChatCloudService, ChatMessage, ChatMessageMetadata, ChatMessageType, ChatSettings,
|
||||
CompleteTextParams, LocalAIConfig, MessageCursor, ModelList, RepeatedChatMessage, ResponseFormat,
|
||||
StreamAnswer, StreamComplete, SubscriptionPlan, UpdateChatParams,
|
||||
};
|
||||
@ -705,13 +705,14 @@ impl ChatCloudService for ServerProvider {
|
||||
chat_id: &str,
|
||||
message_id: i64,
|
||||
format: ResponseFormat,
|
||||
ai_model: Option<AIModel>,
|
||||
) -> Result<StreamAnswer, FlowyError> {
|
||||
let workspace_id = workspace_id.to_string();
|
||||
let chat_id = chat_id.to_string();
|
||||
let server = self.get_server()?;
|
||||
server
|
||||
.chat_service()
|
||||
.stream_answer(&workspace_id, &chat_id, message_id, format)
|
||||
.stream_answer(&workspace_id, &chat_id, message_id, format, ai_model)
|
||||
.await
|
||||
}
|
||||
|
||||
@ -772,12 +773,13 @@ impl ChatCloudService for ServerProvider {
|
||||
&self,
|
||||
workspace_id: &str,
|
||||
params: CompleteTextParams,
|
||||
ai_model: Option<AIModel>,
|
||||
) -> Result<StreamComplete, FlowyError> {
|
||||
let workspace_id = workspace_id.to_string();
|
||||
let server = self.get_server()?;
|
||||
server
|
||||
.chat_service()
|
||||
.stream_complete(&workspace_id, params)
|
||||
.stream_complete(&workspace_id, params, ai_model)
|
||||
.await
|
||||
}
|
||||
|
||||
@ -846,6 +848,14 @@ impl ChatCloudService for ServerProvider {
|
||||
.get_available_models(workspace_id)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn get_workspace_default_model(&self, workspace_id: &str) -> Result<String, FlowyError> {
|
||||
self
|
||||
.get_server()?
|
||||
.chat_service()
|
||||
.get_workspace_default_model(workspace_id)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
|
||||
@ -7,8 +7,8 @@ use client_api::entity::chat_dto::{
|
||||
RepeatedChatMessage,
|
||||
};
|
||||
use flowy_ai_pub::cloud::{
|
||||
ChatCloudService, ChatMessage, ChatMessageMetadata, ChatMessageType, ChatSettings, LocalAIConfig,
|
||||
ModelList, StreamAnswer, StreamComplete, SubscriptionPlan, UpdateChatParams,
|
||||
AIModel, ChatCloudService, ChatMessage, ChatMessageMetadata, ChatMessageType, ChatSettings,
|
||||
LocalAIConfig, ModelList, StreamAnswer, StreamComplete, SubscriptionPlan, UpdateChatParams,
|
||||
};
|
||||
use flowy_error::FlowyError;
|
||||
use futures_util::{StreamExt, TryStreamExt};
|
||||
@ -101,12 +101,14 @@ where
|
||||
chat_id: &str,
|
||||
message_id: i64,
|
||||
format: ResponseFormat,
|
||||
ai_model: Option<AIModel>,
|
||||
) -> Result<StreamAnswer, FlowyError> {
|
||||
trace!(
|
||||
"stream_answer: workspace_id={}, chat_id={}, format={:?}",
|
||||
"stream_answer: workspace_id={}, chat_id={}, format={:?}, model: {:?}",
|
||||
workspace_id,
|
||||
chat_id,
|
||||
format
|
||||
format,
|
||||
ai_model,
|
||||
);
|
||||
let try_get_client = self.inner.try_get_client();
|
||||
let result = try_get_client?
|
||||
@ -117,6 +119,7 @@ where
|
||||
question_id: message_id,
|
||||
format,
|
||||
},
|
||||
ai_model.map(|v| v.name),
|
||||
)
|
||||
.await;
|
||||
|
||||
@ -189,11 +192,12 @@ where
|
||||
&self,
|
||||
workspace_id: &str,
|
||||
params: CompleteTextParams,
|
||||
ai_model: Option<AIModel>,
|
||||
) -> Result<StreamComplete, FlowyError> {
|
||||
let stream = self
|
||||
.inner
|
||||
.try_get_client()?
|
||||
.stream_completion_v2(workspace_id, params)
|
||||
.stream_completion_v2(workspace_id, params, ai_model.map(|v| v.name))
|
||||
.await
|
||||
.map_err(FlowyError::from)?
|
||||
.map_err(FlowyError::from);
|
||||
@ -280,4 +284,13 @@ where
|
||||
.await?;
|
||||
Ok(list)
|
||||
}
|
||||
|
||||
async fn get_workspace_default_model(&self, workspace_id: &str) -> Result<String, FlowyError> {
|
||||
let setting = self
|
||||
.inner
|
||||
.try_get_client()?
|
||||
.get_workspace_settings(workspace_id)
|
||||
.await?;
|
||||
Ok(setting.ai_model)
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
use client_api::entity::ai_dto::{LocalAIConfig, RepeatedRelatedQuestion};
|
||||
use flowy_ai_pub::cloud::{
|
||||
ChatCloudService, ChatMessage, ChatMessageMetadata, ChatMessageType, ChatSettings,
|
||||
AIModel, ChatCloudService, ChatMessage, ChatMessageMetadata, ChatMessageType, ChatSettings,
|
||||
CompleteTextParams, MessageCursor, ModelList, RepeatedChatMessage, ResponseFormat, StreamAnswer,
|
||||
StreamComplete, SubscriptionPlan, UpdateChatParams,
|
||||
};
|
||||
@ -52,6 +52,7 @@ impl ChatCloudService for DefaultChatCloudServiceImpl {
|
||||
_chat_id: &str,
|
||||
_message_id: i64,
|
||||
_format: ResponseFormat,
|
||||
_ai_model: Option<AIModel>,
|
||||
) -> Result<StreamAnswer, FlowyError> {
|
||||
Err(FlowyError::not_support().with_context("Chat is not supported in local server."))
|
||||
}
|
||||
@ -97,6 +98,7 @@ impl ChatCloudService for DefaultChatCloudServiceImpl {
|
||||
&self,
|
||||
_workspace_id: &str,
|
||||
_params: CompleteTextParams,
|
||||
_ai_model: Option<AIModel>,
|
||||
) -> Result<StreamComplete, FlowyError> {
|
||||
Err(FlowyError::not_support().with_context("complete text is not supported in local server."))
|
||||
}
|
||||
@ -148,4 +150,8 @@ impl ChatCloudService for DefaultChatCloudServiceImpl {
|
||||
async fn get_available_models(&self, _workspace_id: &str) -> Result<ModelList, FlowyError> {
|
||||
Err(FlowyError::not_support().with_context("Chat is not supported in local server."))
|
||||
}
|
||||
|
||||
async fn get_workspace_default_model(&self, _workspace_id: &str) -> Result<String, FlowyError> {
|
||||
Err(FlowyError::not_support().with_context("Chat is not supported in local server."))
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user