chore: support switching AI model in chat or AI writer

Nathan 2025-03-23 21:53:05 +08:00
parent ad695e43b9
commit 05949d2f87
32 changed files with 938 additions and 235 deletions

View File

@@ -0,0 +1,138 @@
import 'package:appflowy/ai/service/ai_prompt_input_bloc.dart';
import 'package:appflowy/generated/locale_keys.g.dart';
import 'package:appflowy/plugins/ai_chat/application/ai_model_switch_listener.dart';
import 'package:appflowy/workspace/application/settings/ai/local_llm_listener.dart';
import 'package:appflowy_backend/dispatch/dispatch.dart';
import 'package:appflowy_backend/log.dart';
import 'package:appflowy_backend/protobuf/flowy-ai/entities.pb.dart';
import 'package:appflowy_result/appflowy_result.dart';
import 'package:easy_localization/easy_localization.dart';
import 'package:protobuf/protobuf.dart';
import 'package:universal_platform/universal_platform.dart';
class AIModelStateNotifier {
AIModelStateNotifier({required this.objectId})
: _isDesktop = UniversalPlatform.isDesktop,
_localAIListener =
UniversalPlatform.isDesktop ? LocalAIStateListener() : null,
_aiModelSwitchListener = AIModelSwitchListener(chatId: objectId);
final String objectId;
final bool _isDesktop;
final LocalAIStateListener? _localAIListener;
final AIModelSwitchListener _aiModelSwitchListener;
LocalAIPB? _localAIState;
AvailableModelsPB? _availableModels;
// callbacks
void Function(AiType, bool, String)? onChanged;
void Function(AvailableModelsPB)? onAvailableModelsChanged;
String hintText() {
final aiType = getCurrentAiType();
if (aiType.isLocal) {
return isEditable()
? LocaleKeys.chat_inputLocalAIMessageHint.tr()
: LocaleKeys.settings_aiPage_keys_localAIInitializing.tr();
}
return LocaleKeys.chat_inputMessageHint.tr();
}
AiType getCurrentAiType() {
// On non-desktop platforms, always return cloud type
if (!_isDesktop) {
return AiType.cloud;
}
return _availableModels?.selectedModel.isLocal == true
? AiType.local
: AiType.cloud;
}
bool isEditable() {
// On non-desktop platforms, always editable (cloud-only)
if (!_isDesktop) {
return true;
}
return getCurrentAiType().isLocal
? _localAIState?.state == RunningStatePB.Running
: true;
}
void _notifyStateChanged() {
onChanged?.call(getCurrentAiType(), isEditable(), hintText());
}
Future<void> init() async {
await _loadAvailableModels();
}
Future<void> _loadAvailableModels() async {
final payload = AvailableModelsQueryPB(source: objectId);
final result = await AIEventGetAvailableModels(payload).send();
result.fold(
(models) {
_availableModels = models;
onAvailableModelsChanged?.call(models);
_notifyStateChanged();
},
(err) {
Log.error("Failed to get available models: $err");
},
);
}
void startListening({
void Function(AiType, bool, String)? onChanged,
void Function(AvailableModelsPB)? onAvailableModelsChanged,
}) {
this.onChanged = onChanged;
this.onAvailableModelsChanged = onAvailableModelsChanged;
// Only start local AI listener on desktop platforms
if (_isDesktop) {
_localAIListener?.start(
stateCallback: (state) {
_localAIState = state;
if (state.state == RunningStatePB.Running ||
state.state == RunningStatePB.Stopped) {
_loadAvailableModels();
}
},
);
}
_aiModelSwitchListener.start(
onUpdateSelectedModel: (model) {
if (_availableModels != null) {
final updatedModels = _availableModels!.deepCopy()
..selectedModel = model;
_availableModels = updatedModels;
onAvailableModelsChanged?.call(updatedModels);
}
if (model.isLocal && _isDesktop) {
AIEventGetLocalAIState().send().fold(
(localAIState) {
_localAIState = localAIState;
_notifyStateChanged();
},
(error) {
Log.error("Failed to get local AI state: $error");
_notifyStateChanged();
},
);
} else {
_notifyStateChanged();
}
},
);
}
Future<void> stop() async {
onChanged = null;
await _localAIListener?.stop();
await _aiModelSwitchListener.stop();
}
}
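
For reference, a minimal usage sketch of the notifier above. The `chatId` variable and the callback bodies are hypothetical; `init`, `startListening`, and `stop` are the API defined in this file.

// Hypothetical consumer of AIModelStateNotifier.
final notifier = AIModelStateNotifier(objectId: chatId);
notifier.startListening(
  onChanged: (aiType, editable, hintText) {
    // React to model/editability changes, e.g. rebuild the prompt input.
  },
  onAvailableModelsChanged: (models) {
    // Refresh the model picker with the latest AvailableModelsPB.
  },
);
await notifier.init(); // Fetches available models and fires the callbacks.
// ... when the owner is disposed:
await notifier.stop();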

View File

@@ -1,14 +1,8 @@
import 'dart:async';
import 'package:appflowy/generated/locale_keys.g.dart';
import 'package:appflowy/ai/service/ai_input_control.dart';
import 'package:appflowy/plugins/ai_chat/application/chat_entity.dart';
import 'package:appflowy/workspace/application/settings/ai/local_llm_listener.dart';
import 'package:appflowy_backend/dispatch/dispatch.dart';
import 'package:appflowy_backend/log.dart';
import 'package:appflowy_backend/protobuf/flowy-ai/entities.pb.dart';
import 'package:appflowy_backend/protobuf/flowy-folder/protobuf.dart';
import 'package:appflowy_result/appflowy_result.dart';
import 'package:easy_localization/easy_localization.dart';
import 'package:flutter_bloc/flutter_bloc.dart';
import 'package:freezed_annotation/freezed_annotation.dart';
@@ -18,19 +12,19 @@ part 'ai_prompt_input_bloc.freezed.dart';
class AIPromptInputBloc extends Bloc<AIPromptInputEvent, AIPromptInputState> {
AIPromptInputBloc({
required String objectId,
required PredefinedFormat? predefinedFormat,
}) : _listener = LocalAIStateListener(),
super(AIPromptInputState.initial(predefinedFormat)) {
}) : _aiModelStateNotifier = AIModelStateNotifier(objectId: objectId),
super(AIPromptInputState.initial(objectId, predefinedFormat)) {
_dispatch();
_startListening();
_init();
}
final LocalAIStateListener _listener;
final AIModelStateNotifier _aiModelStateNotifier;
@override
Future<void> close() async {
await _listener.stop();
await _aiModelStateNotifier.stop();
return super.close();
}
@@ -38,29 +32,11 @@ class AIPromptInputBloc extends Bloc<AIPromptInputEvent, AIPromptInputState> {
on<AIPromptInputEvent>(
(event, emit) {
event.when(
updateAIState: (localAIState) {
final aiType = localAIState.enabled ? AiType.local : AiType.cloud;
// final supportChatWithFile =
// aiType.isLocal && localAIState.state == RunningStatePB.Running;
// If local ai is enabled, user can only send messages when the AI is running
final editable = localAIState.enabled
? localAIState.state == RunningStatePB.Running
: true;
var hintText = aiType.isLocal
? LocaleKeys.chat_inputLocalAIMessageHint.tr()
: LocaleKeys.chat_inputMessageHint.tr();
if (editable == false && aiType.isLocal) {
hintText =
LocaleKeys.settings_aiPage_keys_localAIInitializing.tr();
}
updateAIState: (aiType, editable, hintText) {
emit(
state.copyWith(
aiType: aiType,
supportChatWithFile: false,
localAIState: localAIState,
editable: editable,
hintText: hintText,
),
@@ -127,25 +103,16 @@ class AIPromptInputBloc extends Bloc<AIPromptInputEvent, AIPromptInputState> {
);
}
void _startListening() {
_listener.start(
stateCallback: (pluginState) {
if (!isClosed) {
add(AIPromptInputEvent.updateAIState(pluginState));
}
},
);
}
void _init() {
AIEventGetLocalAIState().send().fold(
(localAIState) {
_aiModelStateNotifier.startListening(
onChanged: (aiType, editable, hintText) {
if (!isClosed) {
add(AIPromptInputEvent.updateAIState(localAIState));
add(AIPromptInputEvent.updateAIState(aiType, editable, hintText));
}
},
Log.error,
);
_aiModelStateNotifier.init();
}
Map<String, dynamic> consumeMetadata() {
@@ -164,8 +131,12 @@ class AIPromptInputBloc extends Bloc<AIPromptInputEvent, AIPromptInputState> {
@freezed
class AIPromptInputEvent with _$AIPromptInputEvent {
const factory AIPromptInputEvent.updateAIState(LocalAIPB localAIState) =
_UpdateAIState;
const factory AIPromptInputEvent.updateAIState(
AiType aiType,
bool editable,
String hintText,
) = _UpdateAIState;
const factory AIPromptInputEvent.toggleShowPredefinedFormat() =
_ToggleShowPredefinedFormat;
const factory AIPromptInputEvent.updatePredefinedFormat(
@@ -184,24 +155,27 @@ class AIPromptInputEvent with _$AIPromptInputEvent {
@freezed
class AIPromptInputState with _$AIPromptInputState {
const factory AIPromptInputState({
required String objectId,
required AiType aiType,
required bool supportChatWithFile,
required bool showPredefinedFormats,
required PredefinedFormat? predefinedFormat,
required LocalAIPB? localAIState,
required List<ChatFile> attachedFiles,
required List<ViewPB> mentionedPages,
required bool editable,
required String hintText,
}) = _AIPromptInputState;
factory AIPromptInputState.initial(PredefinedFormat? format) =>
factory AIPromptInputState.initial(
String objectId,
PredefinedFormat? format,
) =>
AIPromptInputState(
objectId: objectId,
aiType: AiType.cloud,
supportChatWithFile: false,
showPredefinedFormats: format != null,
predefinedFormat: format,
localAIState: null,
attachedFiles: [],
mentionedPages: [],
editable: true,

View File

@@ -0,0 +1,82 @@
import 'dart:async';
import 'package:appflowy/ai/service/ai_input_control.dart';
import 'package:appflowy_backend/dispatch/dispatch.dart';
import 'package:appflowy_backend/protobuf/flowy-ai/entities.pbserver.dart';
import 'package:flutter_bloc/flutter_bloc.dart';
import 'package:freezed_annotation/freezed_annotation.dart';
import 'package:protobuf/protobuf.dart';
part 'select_model_bloc.freezed.dart';
class SelectModelBloc extends Bloc<SelectModelEvent, SelectModelState> {
SelectModelBloc({
required this.objectId,
}) : _aiModelStateNotifier = AIModelStateNotifier(objectId: objectId),
super(const SelectModelState()) {
_aiModelStateNotifier.init();
_aiModelStateNotifier.startListening(
onAvailableModelsChanged: (models) {
if (!isClosed) {
add(SelectModelEvent.didLoadModels(models));
}
},
);
on<SelectModelEvent>(
(event, emit) async {
await event.when(
selectModel: (AIModelPB model) async {
await AIEventUpdateSelectedModel(
UpdateSelectedModelPB(
source: objectId,
selectedModel: model,
),
).send();
state.availableModels?.freeze();
final newAvailableModels = state.availableModels?.rebuild((m) {
m.selectedModel = model;
});
emit(
state.copyWith(
availableModels: newAvailableModels,
),
);
},
didLoadModels: (AvailableModelsPB models) {
emit(state.copyWith(availableModels: models));
},
);
},
);
}
final String objectId;
final AIModelStateNotifier _aiModelStateNotifier;
@override
Future<void> close() async {
await _aiModelStateNotifier.stop();
await super.close();
}
}
@freezed
class SelectModelEvent with _$SelectModelEvent {
const factory SelectModelEvent.selectModel(
AIModelPB model,
) = _SelectModel;
const factory SelectModelEvent.didLoadModels(
AvailableModelsPB models,
) = _DidLoadModels;
}
@freezed
class SelectModelState with _$SelectModelState {
const factory SelectModelState({
AvailableModelsPB? availableModels,
}) = _SelectModelState;
}
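
A short sketch of how this bloc is meant to be driven, assuming one instance per chat or document; the flow mirrors the `SelectModelButton` widget later in this commit. The `objectId` and `model` values here are illustrative.

// Hypothetical caller; objectId and model come from the surrounding UI.
final bloc = SelectModelBloc(objectId: objectId);
// The bloc loads models via AIModelStateNotifier and emits didLoadModels.
// When the user picks an entry from the list:
bloc.add(SelectModelEvent.selectModel(model));
// On dispose; this also stops the underlying AIModelStateNotifier:
await bloc.close();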

View File

@@ -1,14 +1,17 @@
import 'package:appflowy/ai/ai.dart';
import 'package:appflowy/ai/service/select_model_bloc.dart';
import 'package:appflowy/generated/locale_keys.g.dart';
import 'package:appflowy/plugins/ai_chat/application/chat_input_control_cubit.dart';
import 'package:appflowy/plugins/ai_chat/presentation/layout_define.dart';
import 'package:appflowy/startup/startup.dart';
import 'package:appflowy/util/theme_extension.dart';
import 'package:appflowy_backend/protobuf/flowy-ai/entities.pb.dart';
import 'package:appflowy_backend/protobuf/flowy-folder/protobuf.dart';
import 'package:easy_localization/easy_localization.dart';
import 'package:extended_text_field/extended_text_field.dart';
import 'package:flowy_infra/file_picker/file_picker_service.dart';
import 'package:flowy_infra_ui/flowy_infra_ui.dart';
import 'package:flowy_infra_ui/style_widget/hover.dart';
import 'package:flutter/material.dart';
import 'package:flutter/services.dart';
import 'package:flutter_bloc/flutter_bloc.dart';
@@ -165,6 +168,7 @@ class _DesktopPromptInputState extends State<DesktopPromptInput> {
top: null,
child: TextFieldTapRegion(
child: _PromptBottomActions(
objectId: state.objectId,
showPredefinedFormats:
state.showPredefinedFormats,
onTogglePredefinedFormatSection: () =>
@@ -561,6 +565,7 @@ class PromptInputTextField extends StatelessWidget {
class _PromptBottomActions extends StatelessWidget {
const _PromptBottomActions({
required this.objectId,
required this.sendButtonState,
required this.showPredefinedFormats,
required this.onTogglePredefinedFormatSection,
@@ -572,6 +577,7 @@ class _PromptBottomActions extends StatelessWidget {
this.extraBottomActionButton,
});
final String objectId;
final bool showPredefinedFormats;
final void Function() onTogglePredefinedFormatSection;
final void Function() onStartMention;
@@ -589,15 +595,10 @@ class _PromptBottomActions extends StatelessWidget {
margin: DesktopAIChatSizes.inputActionBarMargin,
child: BlocBuilder<AIPromptInputBloc, AIPromptInputState>(
builder: (context, state) {
if (state.localAIState == null) {
return Align(
alignment: AlignmentDirectional.centerEnd,
child: _sendButton(),
);
}
return Row(
children: [
_predefinedFormatButton(),
SelectModelButton(objectId: objectId),
const Spacer(),
if (state.aiType.isCloud) ...[
_selectSourcesButton(context),
@@ -683,3 +684,160 @@
);
}
}
class SelectModelButton extends StatefulWidget {
const SelectModelButton({
super.key,
required this.objectId,
});
final String objectId;
@override
State<SelectModelButton> createState() => _SelectModelButtonState();
}
class _SelectModelButtonState extends State<SelectModelButton> {
final popoverController = PopoverController();
late SelectModelBloc bloc;
@override
void initState() {
super.initState();
bloc = SelectModelBloc(objectId: widget.objectId);
}
@override
void dispose() {
popoverController.close();
bloc.close();
super.dispose();
}
@override
Widget build(BuildContext context) {
return BlocProvider.value(
value: bloc,
child: BlocBuilder<SelectModelBloc, SelectModelState>(
builder: (context, state) {
return AppFlowyPopover(
// constraints: BoxConstraints.loose(const Size(250, 200)),
offset: const Offset(0.0, -10.0),
direction: PopoverDirection.topWithLeftAligned,
margin: EdgeInsets.zero,
controller: popoverController,
onOpen: () {},
onClose: () {},
popupBuilder: (_) {
return BlocProvider.value(
value: bloc,
child: _PopoverSelectModel(
onClose: () => popoverController.close(),
),
);
},
child: _CurrentModelButton(
modelName: state.availableModels?.selectedModel.name ?? "",
onTap: () => popoverController.show(),
),
);
},
),
);
}
}
class _PopoverSelectModel extends StatelessWidget {
const _PopoverSelectModel({
required this.onClose,
});
final VoidCallback onClose;
@override
Widget build(BuildContext context) {
return BlocBuilder<SelectModelBloc, SelectModelState>(
builder: (context, state) {
return ListView.builder(
shrinkWrap: true,
itemCount: state.availableModels?.models.length ?? 0,
padding: const EdgeInsets.fromLTRB(8, 4, 8, 12),
itemBuilder: (context, index) {
return _ModelItem(
model: state.availableModels!.models[index],
onTap: () {
context.read<SelectModelBloc>().add(
SelectModelEvent.selectModel(
state.availableModels!.models[index],
),
);
onClose();
},
);
},
);
},
);
}
}
class _ModelItem extends StatelessWidget {
const _ModelItem({
required this.model,
required this.onTap,
});
final AIModelPB model;
final VoidCallback onTap;
@override
Widget build(BuildContext context) {
var modelName = model.name;
if (model.isLocal) {
modelName += " (${LocaleKeys.chat_changeFormat_localModel.tr()})";
}
return FlowyTextButton(
modelName,
fillColor: Colors.transparent,
onPressed: onTap,
);
}
}
class _CurrentModelButton extends StatelessWidget {
const _CurrentModelButton({
required this.modelName,
required this.onTap,
});
final String modelName;
final VoidCallback onTap;
@override
Widget build(BuildContext context) {
return FlowyTooltip(
message: LocaleKeys.chat_changeFormat_switchModel.tr(),
child: GestureDetector(
onTap: onTap,
behavior: HitTestBehavior.opaque,
child: SizedBox(
height: DesktopAIPromptSizes.actionBarButtonSize,
child: FlowyHover(
style: const HoverStyle(
borderRadius: BorderRadius.all(Radius.circular(8)),
),
child: Padding(
padding: const EdgeInsetsDirectional.fromSTEB(6, 6, 4, 6),
child: FlowyText(
modelName,
fontSize: 12,
figmaLineHeight: 16,
color: Theme.of(context).hintColor,
),
),
),
),
),
);
}
}

View File

@@ -0,0 +1,53 @@
import 'dart:async';
import 'dart:typed_data';
import 'package:appflowy/plugins/ai_chat/application/chat_notification.dart';
import 'package:appflowy_backend/protobuf/flowy-ai/entities.pb.dart';
import 'package:appflowy_backend/protobuf/flowy-ai/notification.pb.dart';
import 'package:appflowy_backend/protobuf/flowy-error/errors.pb.dart';
import 'package:appflowy_backend/protobuf/flowy-notification/subject.pb.dart';
import 'package:appflowy_backend/rust_stream.dart';
import 'package:appflowy_result/appflowy_result.dart';
typedef OnUpdateSelectedModel = void Function(AIModelPB model);
class AIModelSwitchListener {
AIModelSwitchListener({required this.chatId}) {
_parser = ChatNotificationParser(id: chatId, callback: _callback);
_subscription = RustStreamReceiver.listen(
(observable) => _parser?.parse(observable),
);
}
final String chatId;
StreamSubscription<SubscribeObject>? _subscription;
ChatNotificationParser? _parser;
void start({
OnUpdateSelectedModel? onUpdateSelectedModel,
}) {
this.onUpdateSelectedModel = onUpdateSelectedModel;
}
OnUpdateSelectedModel? onUpdateSelectedModel;
void _callback(
ChatNotification ty,
FlowyResult<Uint8List, FlowyError> result,
) {
result.map((r) {
switch (ty) {
case ChatNotification.DidUpdateSelectedModel:
onUpdateSelectedModel?.call(AIModelPB.fromBuffer(r));
break;
default:
break;
}
});
}
Future<void> stop() async {
await _subscription?.cancel();
_subscription = null;
}
}
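
A lifecycle sketch for the listener above, assuming it is owned by a long-lived object such as the AIModelStateNotifier earlier in this commit; `chatId` is illustrative, while start/stop are the API defined in this file.

// Hypothetical owner of AIModelSwitchListener.
final listener = AIModelSwitchListener(chatId: chatId);
listener.start(
  onUpdateSelectedModel: (model) {
    // Fired when the backend broadcasts ChatNotification.DidUpdateSelectedModel.
  },
);
// ... when no longer needed:
await listener.stop(); // Cancels the RustStreamReceiver subscription.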

View File

@@ -73,6 +73,7 @@ class AIChatPage extends StatelessWidget {
/// [AIPromptInputBloc] is used to handle the user prompt
BlocProvider(
create: (_) => AIPromptInputBloc(
objectId: view.id,
predefinedFormat: PredefinedFormat(
imageFormat: ImageFormat.text,
textFormat: TextFormat.bulletList,

View File

@@ -2,6 +2,7 @@ import 'package:appflowy/ai/ai.dart';
import 'package:appflowy/generated/flowy_svgs.g.dart';
import 'package:appflowy/generated/locale_keys.g.dart';
import 'package:appflowy/plugins/ai_chat/presentation/message/ai_markdown_text.dart';
import 'package:appflowy/plugins/document/application/document_bloc.dart';
import 'package:appflowy/util/theme_extension.dart';
import 'package:appflowy/workspace/application/view/view_bloc.dart';
import 'package:appflowy_editor/appflowy_editor.dart';
@@ -124,9 +125,12 @@ class _AIWriterBlockComponentState extends State<AiWriterBlockComponent> {
return const SizedBox.shrink();
}
final documentId = context.read<DocumentBloc?>()?.documentId;
return BlocProvider(
create: (_) => AIPromptInputBloc(
predefinedFormat: null,
objectId: documentId ?? editorState.document.root.id,
),
child: LayoutBuilder(
builder: (context, constraints) {

View File

@@ -173,7 +173,7 @@ class SettingsAIBloc extends Bloc<SettingsAIEvent, SettingsAIState> {
}
void _loadModelList() {
AIEventGetAvailableModels().send().then((result) {
AIEventGetServerAvailableModels().send().then((result) {
result.fold((config) {
if (!isClosed) {
add(SettingsAIEvent.didLoadAvailableModels(config.models));

View File

@@ -247,6 +247,8 @@
"table": "Table",
"blankDescription": "Format response",
"defaultDescription": "Auto mode",
"localModel": "Local Model",
"switchModel": "Switch model",
"textWithImageDescription": "@:chat.changeFormat.text with image",
"numberWithImageDescription": "@:chat.changeFormat.number with image",
"bulletWithImageDescription": "@:chat.changeFormat.bullet with image",
@@ -3191,4 +3193,4 @@
"rewrite": "Rewrite",
"insertBelow": "Insert below"
}
}
}

View File

@@ -163,7 +163,7 @@ checksum = "c1fd03a028ef38ba2276dce7e33fcd6369c158a1bca17946c4b1b701891c1ff7"
[[package]]
name = "app-error"
version = "0.1.0"
source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=1d06321789482dd31a44d515eefa06e00871d8ad#1d06321789482dd31a44d515eefa06e00871d8ad"
source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=d77dacae32bc25440ca61f675900b80fba6cc9a2#d77dacae32bc25440ca61f675900b80fba6cc9a2"
dependencies = [
"anyhow",
"bincode",
@@ -183,7 +183,7 @@ dependencies = [
[[package]]
name = "appflowy-ai-client"
version = "0.1.0"
source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=1d06321789482dd31a44d515eefa06e00871d8ad#1d06321789482dd31a44d515eefa06e00871d8ad"
source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=d77dacae32bc25440ca61f675900b80fba6cc9a2#d77dacae32bc25440ca61f675900b80fba6cc9a2"
dependencies = [
"anyhow",
"bytes",
@@ -788,7 +788,7 @@ dependencies = [
[[package]]
name = "client-api"
version = "0.2.0"
source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=1d06321789482dd31a44d515eefa06e00871d8ad#1d06321789482dd31a44d515eefa06e00871d8ad"
source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=d77dacae32bc25440ca61f675900b80fba6cc9a2#d77dacae32bc25440ca61f675900b80fba6cc9a2"
dependencies = [
"again",
"anyhow",
@@ -843,7 +843,7 @@ dependencies = [
[[package]]
name = "client-api-entity"
version = "0.1.0"
source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=1d06321789482dd31a44d515eefa06e00871d8ad#1d06321789482dd31a44d515eefa06e00871d8ad"
source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=d77dacae32bc25440ca61f675900b80fba6cc9a2#d77dacae32bc25440ca61f675900b80fba6cc9a2"
dependencies = [
"collab-entity",
"collab-rt-entity",
@@ -856,7 +856,7 @@ dependencies = [
[[package]]
name = "client-websocket"
version = "0.1.0"
source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=1d06321789482dd31a44d515eefa06e00871d8ad#1d06321789482dd31a44d515eefa06e00871d8ad"
source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=d77dacae32bc25440ca61f675900b80fba6cc9a2#d77dacae32bc25440ca61f675900b80fba6cc9a2"
dependencies = [
"futures-channel",
"futures-util",
@@ -1129,7 +1129,7 @@ dependencies = [
[[package]]
name = "collab-rt-entity"
version = "0.1.0"
source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=1d06321789482dd31a44d515eefa06e00871d8ad#1d06321789482dd31a44d515eefa06e00871d8ad"
source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=d77dacae32bc25440ca61f675900b80fba6cc9a2#d77dacae32bc25440ca61f675900b80fba6cc9a2"
dependencies = [
"anyhow",
"bincode",
@@ -1151,7 +1151,7 @@ dependencies = [
[[package]]
name = "collab-rt-protocol"
version = "0.1.0"
source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=1d06321789482dd31a44d515eefa06e00871d8ad#1d06321789482dd31a44d515eefa06e00871d8ad"
source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=d77dacae32bc25440ca61f675900b80fba6cc9a2#d77dacae32bc25440ca61f675900b80fba6cc9a2"
dependencies = [
"anyhow",
"async-trait",
@@ -1398,7 +1398,7 @@ dependencies = [
"cssparser-macros",
"dtoa-short",
"itoa",
"phf 0.11.2",
"phf 0.8.0",
"smallvec",
]
@@ -1546,7 +1546,7 @@ checksum = "c2e66c9d817f1720209181c316d28635c050fa304f9c79e47a520882661b7308"
[[package]]
name = "database-entity"
version = "0.1.0"
source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=1d06321789482dd31a44d515eefa06e00871d8ad#1d06321789482dd31a44d515eefa06e00871d8ad"
source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=d77dacae32bc25440ca61f675900b80fba6cc9a2#d77dacae32bc25440ca61f675900b80fba6cc9a2"
dependencies = [
"bincode",
"bytes",
@@ -2072,6 +2072,7 @@ dependencies = [
"flowy-error",
"futures",
"lib-infra",
"serde",
"serde_json",
]
@@ -2979,7 +2980,7 @@ dependencies = [
[[package]]
name = "gotrue"
version = "0.1.0"
source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=1d06321789482dd31a44d515eefa06e00871d8ad#1d06321789482dd31a44d515eefa06e00871d8ad"
source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=d77dacae32bc25440ca61f675900b80fba6cc9a2#d77dacae32bc25440ca61f675900b80fba6cc9a2"
dependencies = [
"anyhow",
"getrandom 0.2.10",
@@ -2994,7 +2995,7 @@ dependencies = [
[[package]]
name = "gotrue-entity"
version = "0.1.0"
source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=1d06321789482dd31a44d515eefa06e00871d8ad#1d06321789482dd31a44d515eefa06e00871d8ad"
source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=d77dacae32bc25440ca61f675900b80fba6cc9a2#d77dacae32bc25440ca61f675900b80fba6cc9a2"
dependencies = [
"app-error",
"jsonwebtoken",
@@ -3609,7 +3610,7 @@ dependencies = [
[[package]]
name = "infra"
version = "0.1.0"
source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=1d06321789482dd31a44d515eefa06e00871d8ad#1d06321789482dd31a44d515eefa06e00871d8ad"
source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=d77dacae32bc25440ca61f675900b80fba6cc9a2#d77dacae32bc25440ca61f675900b80fba6cc9a2"
dependencies = [
"anyhow",
"bytes",
@@ -4646,7 +4647,7 @@ version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3dfb61232e34fcb633f43d12c58f83c1df82962dcdfa565a4e866ffc17dafe12"
dependencies = [
"phf_macros 0.8.0",
"phf_macros",
"phf_shared 0.8.0",
"proc-macro-hack",
]
@@ -4666,7 +4667,6 @@ version = "0.11.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ade2d8b8f33c7333b51bcf0428d37e217e9f32192ae4772156f65063b8ce03dc"
dependencies = [
"phf_macros 0.11.3",
"phf_shared 0.11.2",
]
@@ -4734,19 +4734,6 @@ dependencies = [
"syn 1.0.109",
]
[[package]]
name = "phf_macros"
version = "0.11.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f84ac04429c13a7ff43785d75ad27569f2951ce0ffd30a3321230db2fc727216"
dependencies = [
"phf_generator 0.11.2",
"phf_shared 0.11.2",
"proc-macro2",
"quote",
"syn 2.0.94",
]
[[package]]
name = "phf_shared"
version = "0.8.0"
@@ -6176,7 +6163,7 @@ dependencies = [
[[package]]
name = "shared-entity"
version = "0.1.0"
source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=1d06321789482dd31a44d515eefa06e00871d8ad#1d06321789482dd31a44d515eefa06e00871d8ad"
source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=d77dacae32bc25440ca61f675900b80fba6cc9a2#d77dacae32bc25440ca61f675900b80fba6cc9a2"
dependencies = [
"anyhow",
"app-error",

View File

@@ -103,8 +103,8 @@ dashmap = "6.0.1"
# Run the script.add_workspace_members:
# scripts/tool/update_client_api_rev.sh new_rev_id
# ⚠️⚠️⚠️️
client-api = { git = "https://github.com/AppFlowy-IO/AppFlowy-Cloud", rev = "1d06321789482dd31a44d515eefa06e00871d8ad" }
client-api-entity = { git = "https://github.com/AppFlowy-IO/AppFlowy-Cloud", rev = "1d06321789482dd31a44d515eefa06e00871d8ad" }
client-api = { git = "https://github.com/AppFlowy-IO/AppFlowy-Cloud", rev = "d77dacae32bc25440ca61f675900b80fba6cc9a2" }
client-api-entity = { git = "https://github.com/AppFlowy-IO/AppFlowy-Cloud", rev = "d77dacae32bc25440ca61f675900b80fba6cc9a2" }
[profile.dev]
opt-level = 0

View File

@@ -153,8 +153,8 @@ pub fn parse_event_crate(event_crate: &TsEventCrate) -> Vec<EventASTContext> {
attrs
.iter()
.filter(|attr| !attr.attrs.event_attrs.ignore)
.enumerate()
.map(|(_index, variant)| EventASTContext::from(&variant.attrs))
.into_iter()
.map(|variant| EventASTContext::from(&variant.attrs))
.collect::<Vec<_>>()
},
_ => vec![],

View File

@@ -1,8 +1,8 @@
use crate::event_builder::EventBuilder;
use crate::EventIntegrationTest;
use flowy_ai::entities::{
ChatMessageListPB, ChatMessageTypePB, CompleteTextPB, CompleteTextTaskPB, CompletionTypePB,
LoadNextChatMessagePB, LoadPrevChatMessagePB, SendChatPayloadPB,
ChatMessageListPB, ChatMessageTypePB, LoadNextChatMessagePB, LoadPrevChatMessagePB,
SendChatPayloadPB,
};
use flowy_ai::event_map::AIEvent;
use flowy_folder::entities::{CreateViewPayloadPB, ViewLayoutPB, ViewPB};
@@ -87,27 +87,4 @@ impl EventIntegrationTest {
.await
.parse::<ChatMessageListPB>()
}
pub async fn complete_text(
&self,
text: &str,
completion_type: CompletionTypePB,
) -> CompleteTextTaskPB {
let payload = CompleteTextPB {
text: text.to_string(),
completion_type,
stream_port: 0,
object_id: "".to_string(),
rag_ids: vec![],
format: None,
history: vec![],
custom_prompt: None,
};
EventBuilder::new(self.clone())
.event(AIEvent::CompleteText)
.payload(payload)
.async_send()
.await
.parse::<CompleteTextTaskPB>()
}
}

View File

@@ -1,19 +0,0 @@
use event_integration_test::user_event::use_localhost_af_cloud;
use event_integration_test::EventIntegrationTest;
use flowy_ai::entities::CompletionTypePB;
use std::time::Duration;
#[tokio::test]
async fn af_cloud_complete_text_test() {
use_localhost_af_cloud().await;
let test = EventIntegrationTest::new().await;
test.af_cloud_sign_up().await;
let _workspace_id = test.get_current_workspace().await.id;
let _task = test
.complete_text("hello world", CompletionTypePB::MakeLonger)
.await;
tokio::time::sleep(Duration::from_secs(6)).await;
}

View File

@@ -1,2 +1 @@
mod ai_tool_test;
mod chat_message_test;

View File

@@ -12,3 +12,4 @@ client-api = { workspace = true }
bytes.workspace = true
futures.workspace = true
serde_json.workspace = true
serde.workspace = true

View File

@@ -14,6 +14,7 @@ pub use client_api::error::{AppResponseError, ErrorCode as AppErrorCode};
use flowy_error::FlowyError;
use futures::stream::BoxStream;
use lib_infra::async_trait::async_trait;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use std::path::Path;
@@ -21,6 +22,22 @@ use std::path::Path;
pub type ChatMessageStream = BoxStream<'static, Result<ChatMessage, AppResponseError>>;
pub type StreamAnswer = BoxStream<'static, Result<QuestionStreamValue, FlowyError>>;
pub type StreamComplete = BoxStream<'static, Result<CompletionStreamValue, FlowyError>>;
#[derive(Debug, Eq, PartialEq, Serialize, Deserialize, Clone)]
pub struct AIModel {
pub name: String,
pub is_local: bool,
}
impl Default for AIModel {
fn default() -> Self {
Self {
name: "default".to_string(),
is_local: false,
}
}
}
#[async_trait]
pub trait ChatCloudService: Send + Sync + 'static {
async fn create_chat(
@@ -55,6 +72,7 @@ pub trait ChatCloudService: Send + Sync + 'static {
chat_id: &str,
message_id: i64,
format: ResponseFormat,
ai_model: Option<AIModel>,
) -> Result<StreamAnswer, FlowyError>;
async fn get_answer(
@@ -90,6 +108,7 @@ pub trait ChatCloudService: Send + Sync + 'static {
&self,
workspace_id: &str,
params: CompleteTextParams,
ai_model: Option<AIModel>,
) -> Result<StreamComplete, FlowyError>;
async fn embed_file(
@@ -121,4 +140,5 @@
) -> Result<(), FlowyError>;
async fn get_available_models(&self, workspace_id: &str) -> Result<ModelList, FlowyError>;
async fn get_workspace_default_model(&self, workspace_id: &str) -> Result<String, FlowyError>;
}

View File

@@ -1,7 +1,7 @@
use crate::chat::Chat;
use crate::entities::{
ChatInfoPB, ChatMessageListPB, ChatMessagePB, ChatSettingsPB, FilePB, PredefinedFormatPB,
RepeatedRelatedQuestionPB, StreamMessageParams,
AIModelPB, AvailableModelsPB, ChatInfoPB, ChatMessageListPB, ChatMessagePB, ChatSettingsPB,
FilePB, PredefinedFormatPB, RepeatedRelatedQuestionPB, StreamMessageParams,
};
use crate::local_ai::controller::LocalAIController;
use crate::middleware::chat_service_mw::AICloudServiceMiddleware;
@@ -10,12 +10,13 @@ use std::collections::HashMap;
use appflowy_plugin::manager::PluginManager;
use dashmap::DashMap;
use flowy_ai_pub::cloud::{ChatCloudService, ChatSettings, ModelList, UpdateChatParams};
use flowy_ai_pub::cloud::{AIModel, ChatCloudService, ChatSettings, UpdateChatParams};
use flowy_error::{FlowyError, FlowyResult};
use flowy_sqlite::kv::KVStorePreferences;
use flowy_sqlite::DBConnection;
use crate::notification::{chat_notification_builder, ChatNotification};
use crate::util::ai_available_models_key;
use collab_integrate::persistence::collab_metadata_sql::{
batch_insert_collab_metadata, batch_select_collab_metadata, AFCollabMetadata,
};
@@ -24,6 +25,7 @@ use lib_infra::async_trait::async_trait;
use lib_infra::util::timestamp;
use std::path::PathBuf;
use std::sync::{Arc, Weak};
use tokio::sync::RwLock;
use tracing::{error, info, trace};
pub trait AIUserService: Send + Sync + 'static {
@@ -59,7 +61,8 @@ pub struct AIManager {
pub external_service: Arc<dyn AIExternalService>,
chats: Arc<DashMap<String, Arc<Chat>>>,
pub local_ai: Arc<LocalAIController>,
store_preferences: Arc<KVStorePreferences>,
pub store_preferences: Arc<KVStorePreferences>,
server_models: Arc<RwLock<Vec<String>>>,
}
impl AIManager {
@@ -99,6 +102,7 @@ impl AIManager {
local_ai,
external_service,
store_preferences,
server_models: Arc::new(Default::default()),
}
}
@@ -114,6 +118,7 @@ impl AIManager {
chat_id.to_string(),
self.user_service.clone(),
self.cloud_service_wm.clone(),
self.store_preferences.clone(),
))
});
if self.local_ai.is_running() {
@@ -205,24 +210,25 @@ impl AIManager {
save_chat(self.user_service.sqlite_connection(*uid)?, chat_id)?;
let chat = Arc::new(Chat::new(
self.user_service.user_id().unwrap(),
self.user_service.user_id()?,
chat_id.to_string(),
self.user_service.clone(),
self.cloud_service_wm.clone(),
self.store_preferences.clone(),
));
self.chats.insert(chat_id.to_string(), chat.clone());
Ok(chat)
}
pub async fn stream_chat_message<'a>(
&'a self,
params: &'a StreamMessageParams<'a>,
pub async fn stream_chat_message(
&self,
params: StreamMessageParams,
) -> Result<ChatMessagePB, FlowyError> {
let chat = self.get_or_create_chat_instance(params.chat_id).await?;
let question = chat.stream_chat_message(params).await?;
let chat = self.get_or_create_chat_instance(&params.chat_id).await?;
let question = chat.stream_chat_message(&params).await?;
let _ = self
.external_service
.notify_did_send_message(params.chat_id, params.message)
.notify_did_send_message(&params.chat_id, &params.message)
.await;
Ok(question)
}
@@ -238,19 +244,149 @@ impl AIManager {
let question_message_id = chat
.get_question_id_from_answer_id(answer_message_id)
.await?;
let preferred_model = self
.store_preferences
.get_object::<AIModel>(&ai_available_models_key(chat_id));
chat
.stream_regenerate_response(question_message_id, answer_stream_port, format)
.stream_regenerate_response(
question_message_id,
answer_stream_port,
format,
preferred_model,
)
.await?;
Ok(())
}
pub async fn get_available_models(&self) -> FlowyResult<ModelList> {
pub async fn get_workspace_select_model(&self) -> FlowyResult<String> {
let workspace_id = self.user_service.workspace_id()?;
let model = self
.cloud_service_wm
.get_workspace_default_model(&workspace_id)
.await?;
Ok(model)
}
pub async fn get_server_available_models(&self) -> FlowyResult<Vec<String>> {
let workspace_id = self.user_service.workspace_id()?;
// First, try reading from the cache.
{
let cached_models = self.server_models.read().await;
if !cached_models.is_empty() {
return Ok(cached_models.clone());
}
}
// Cache miss: fetch from the cloud.
let list = self
.cloud_service_wm
.get_available_models(&workspace_id)
.await?;
Ok(list)
let models = list
.models
.into_iter()
.map(|m| m.name)
.collect::<Vec<String>>();
// Update the cache.
*self.server_models.write().await = models.clone();
Ok(models)
}
pub async fn update_selected_model(&self, source: String, model: AIModelPB) -> FlowyResult<()> {
let source_key = ai_available_models_key(&source);
self
.store_preferences
.set_object::<AIModel>(&source_key, &model.clone().into())?;
chat_notification_builder(&source, ChatNotification::DidUpdateSelectedModel)
.payload(model)
.send();
Ok(())
}
pub async fn get_available_models(&self, source: String) -> FlowyResult<AvailableModelsPB> {
// Build the models list from server models and mark them as non-local.
let mut models: Vec<AIModelPB> = self
.get_server_available_models()
.await?
.into_iter()
.map(|name| AIModelPB {
name,
is_local: false,
})
.collect();
// Optionally add the local plugin model.
if let Some(local_model) = self.local_ai.get_plugin_chat_model() {
models.push(AIModelPB {
name: local_model,
is_local: true,
});
}
if models.is_empty() {
return Ok(AvailableModelsPB {
models,
selected_model: None,
});
}
let source_key = ai_available_models_key(&source);
// Retrieve stored selected model, if any.
let stored_selected = self.store_preferences.get_object::<AIModel>(&source_key);
// Get workspace default model once.
let workspace_default = self.get_workspace_select_model().await.ok();
// Determine the effective selected model.
let effective_selected = stored_selected.unwrap_or_else(|| {
if let Some(ws_name) = workspace_default.clone() {
let model = AIModel {
name: ws_name,
is_local: false,
};
// Store the default if not present.
let _ = self.store_preferences.set_object(&source_key, &model);
model
} else {
AIModel::default()
}
});
// Find a matching model in the available list.
let used_model = models
.iter()
.find(|m| m.name == effective_selected.name)
.cloned()
.or_else(|| {
// If no match, try to use the workspace default if available.
if let Some(ws_name) = workspace_default {
Some(AIModelPB {
name: ws_name,
is_local: false,
})
} else {
models.first().cloned()
}
});
// Update the stored preference if a different model is used.
if let Some(ref used) = used_model {
if used.name != effective_selected.name {
self
.store_preferences
.set_object::<AIModel>(&source_key, &AIModel::from(used.clone()))?;
}
}
Ok(AvailableModelsPB {
models,
selected_model: used_model,
})
}
pub async fn get_or_create_chat_instance(&self, chat_id: &str) -> Result<Arc<Chat>, FlowyError> {
@@ -258,10 +394,11 @@ impl AIManager {
match chat {
None => {
let chat = Arc::new(Chat::new(
self.user_service.user_id().unwrap(),
self.user_service.user_id()?,
chat_id.to_string(),
self.user_service.clone(),
self.cloud_service_wm.clone(),
self.store_preferences.clone(),
));
self.chats.insert(chat_id.to_string(), chat.clone());
Ok(chat)
@@ -363,7 +500,6 @@ impl AIManager {
pub async fn update_rag_ids(&self, chat_id: &str, rag_ids: Vec<String>) -> FlowyResult<()> {
info!("[Chat] update chat:{} rag ids: {:?}", chat_id, rag_ids);
let workspace_id = self.user_service.workspace_id()?;
let update_setting = UpdateChatParams {
name: None,

View File

@@ -10,11 +10,13 @@ use crate::persistence::{
ChatMessageTable,
};
use crate::stream_message::StreamMessage;
use crate::util::ai_available_models_key;
use allo_isolate::Isolate;
use flowy_ai_pub::cloud::{
ChatCloudService, ChatMessage, MessageCursor, QuestionStreamValue, ResponseFormat,
AIModel, ChatCloudService, ChatMessage, MessageCursor, QuestionStreamValue, ResponseFormat,
};
use flowy_error::{ErrorCode, FlowyError, FlowyResult};
use flowy_sqlite::kv::KVStorePreferences;
use flowy_sqlite::DBConnection;
use futures::{SinkExt, StreamExt};
use lib_infra::isolate_stream::IsolateSink;
@@ -39,6 +41,7 @@ pub struct Chat {
latest_message_id: Arc<AtomicI64>,
stop_stream: Arc<AtomicBool>,
stream_buffer: Arc<Mutex<StringBuffer>>,
store_preferences: Arc<KVStorePreferences>,
}
impl Chat {
@@ -47,6 +50,7 @@ impl Chat {
chat_id: String,
user_service: Arc<dyn AIUserService>,
chat_service: Arc<AICloudServiceMiddleware>,
store_preferences: Arc<KVStorePreferences>,
) -> Chat {
Chat {
uid,
@@ -57,6 +61,7 @@ impl Chat {
latest_message_id: Default::default(),
stop_stream: Arc::new(AtomicBool::new(false)),
stream_buffer: Arc::new(Mutex::new(StringBuffer::default())),
store_preferences,
}
}
@@ -81,9 +86,9 @@
}
#[instrument(level = "info", skip_all, err)]
pub async fn stream_chat_message<'a>(
&'a self,
params: &'a StreamMessageParams<'a>,
pub async fn stream_chat_message(
&self,
params: &StreamMessageParams,
) -> Result<ChatMessagePB, FlowyError> {
trace!(
"[Chat] stream chat message: chat_id={}, message={}, message_type={:?}, metadata={:?}, format={:?}",
@@ -113,7 +118,7 @@
.create_question(
&workspace_id,
&self.chat_id,
params.message,
&params.message,
params.message_type.clone(),
&[],
)
@@ -138,7 +143,9 @@
// Save message to disk
save_and_notify_message(uid, &self.chat_id, &self.user_service, question.clone())?;
let format = params.format.clone().map(Into::into).unwrap_or_default();
let preferred_ai_model = self
.store_preferences
.get_object::<AIModel>(&ai_available_models_key(&self.chat_id));
self.stream_response(
params.answer_stream_port,
answer_stream_buffer,
@ -146,6 +153,7 @@ impl Chat {
workspace_id,
question.message_id,
format,
preferred_ai_model,
);
let question_pb = ChatMessagePB::from(question);
@@ -158,6 +166,7 @@
question_id: i64,
answer_stream_port: i64,
format: Option<PredefinedFormatPB>,
ai_model: Option<AIModel>,
) -> FlowyResult<()> {
trace!(
"[Chat] regenerate and stream chat message: chat_id={}",
@@ -183,11 +192,13 @@
workspace_id,
question_id,
format,
ai_model,
);
Ok(())
}
#[allow(clippy::too_many_arguments)]
fn stream_response(
&self,
answer_stream_port: i64,
@@ -196,6 +207,7 @@
workspace_id: String,
question_id: i64,
format: ResponseFormat,
ai_model: Option<AIModel>,
) {
let stop_stream = self.stop_stream.clone();
let chat_id = self.chat_id.clone();
@@ -204,7 +216,7 @@
tokio::spawn(async move {
let mut answer_sink = IsolateSink::new(Isolate::new(answer_stream_port));
match cloud_service
.stream_answer(&workspace_id, &chat_id, question_id, format)
.stream_answer(&workspace_id, &chat_id, question_id, format, ai_model)
.await
{
Ok(mut stream) => {

View File

@@ -4,14 +4,16 @@ use allo_isolate::Isolate;
use dashmap::DashMap;
use flowy_ai_pub::cloud::{
ChatCloudService, CompleteTextParams, CompletionMetadata, CompletionStreamValue, CompletionType,
CustomPrompt,
AIModel, ChatCloudService, CompleteTextParams, CompletionMetadata, CompletionStreamValue,
CompletionType, CustomPrompt,
};
use flowy_error::{FlowyError, FlowyResult};
use futures::{SinkExt, StreamExt};
use lib_infra::isolate_stream::IsolateSink;
use crate::util::ai_available_models_key;
use flowy_sqlite::kv::KVStorePreferences;
use std::sync::{Arc, Weak};
use tokio::select;
use tracing::info;
@@ -20,17 +22,20 @@ pub struct AICompletion {
tasks: Arc<DashMap<String, tokio::sync::mpsc::Sender<()>>>,
cloud_service: Weak<dyn ChatCloudService>,
user_service: Weak<dyn AIUserService>,
store_preferences: Arc<KVStorePreferences>,
}
impl AICompletion {
pub fn new(
cloud_service: Weak<dyn ChatCloudService>,
user_service: Weak<dyn AIUserService>,
store_preferences: Arc<KVStorePreferences>,
) -> Self {
Self {
tasks: Arc::new(DashMap::new()),
cloud_service,
user_service,
store_preferences,
}
}
@@ -53,7 +58,17 @@
.ok_or_else(FlowyError::internal)?
.workspace_id()?;
let (tx, rx) = tokio::sync::mpsc::channel(1);
let task = CompletionTask::new(workspace_id, complete, self.cloud_service.clone(), rx);
let preferred_model = self
.store_preferences
.get_object::<AIModel>(&ai_available_models_key(&complete.object_id));
let task = CompletionTask::new(
workspace_id,
complete,
preferred_model,
self.cloud_service.clone(),
rx,
);
let task_id = task.task_id.clone();
self.tasks.insert(task_id.clone(), tx);
@@ -74,12 +89,14 @@ pub struct CompletionTask {
stop_rx: tokio::sync::mpsc::Receiver<()>,
context: CompleteTextPB,
cloud_service: Weak<dyn ChatCloudService>,
preferred_model: Option<AIModel>,
}
impl CompletionTask {
pub fn new(
workspace_id: String,
context: CompleteTextPB,
preferred_model: Option<AIModel>,
cloud_service: Weak<dyn ChatCloudService>,
stop_rx: tokio::sync::mpsc::Receiver<()>,
) -> Self {
@@ -89,6 +106,7 @@ impl CompletionTask {
context,
cloud_service,
stop_rx,
preferred_model,
}
}
@@ -129,7 +147,7 @@
info!("start completion: {:?}", params);
match cloud_service
.stream_complete(&self.workspace_id, params)
.stream_complete(&self.workspace_id, params, self.preferred_model)
.await
{
Ok(mut stream) => loop {

View File

@@ -4,8 +4,9 @@ use std::collections::HashMap;
use crate::local_ai::controller::LocalAISetting;
use crate::local_ai::resource::PendingResource;
use flowy_ai_pub::cloud::{
ChatMessage, ChatMessageMetadata, ChatMessageType, CompletionMessage, LLMModel, OutputContent,
OutputLayout, RelatedQuestion, RepeatedChatMessage, RepeatedRelatedQuestion, ResponseFormat,
AIModel, ChatMessage, ChatMessageMetadata, ChatMessageType, CompletionMessage, LLMModel,
OutputContent, OutputLayout, RelatedQuestion, RepeatedChatMessage, RepeatedRelatedQuestion,
ResponseFormat,
};
use flowy_derive::{ProtoBuf, ProtoBuf_Enum};
use lib_infra::validator_fn::required_not_empty_str;
@@ -76,9 +77,9 @@ pub struct StreamChatPayloadPB {
}
#[derive(Default, Debug)]
pub struct StreamMessageParams<'a> {
pub chat_id: &'a str,
pub message: &'a str,
pub struct StreamMessageParams {
pub chat_id: String,
pub message: String,
pub message_type: ChatMessageType,
pub answer_stream_port: i64,
pub question_stream_port: i64,
@@ -182,12 +183,65 @@ pub struct ChatMessageListPB {
pub total: i64,
}
#[derive(Default, ProtoBuf, Validate, Clone, Debug)]
pub struct ModelConfigPB {
#[derive(Default, ProtoBuf, Clone, Debug)]
pub struct ServerAvailableModelsPB {
#[pb(index = 1)]
pub models: String,
}
#[derive(Default, ProtoBuf, Validate, Clone, Debug)]
pub struct AvailableModelsQueryPB {
#[pb(index = 1)]
#[validate(custom(function = "required_not_empty_str"))]
pub source: String,
}
#[derive(Default, ProtoBuf, Validate, Clone, Debug)]
pub struct UpdateSelectedModelPB {
#[pb(index = 1)]
#[validate(custom(function = "required_not_empty_str"))]
pub source: String,
#[pb(index = 2)]
pub selected_model: AIModelPB,
}
#[derive(Default, ProtoBuf, Clone, Debug)]
pub struct AvailableModelsPB {
#[pb(index = 1)]
pub models: Vec<AIModelPB>,
#[pb(index = 2, one_of)]
pub selected_model: Option<AIModelPB>,
}
#[derive(Default, ProtoBuf, Clone, Debug)]
pub struct AIModelPB {
#[pb(index = 1)]
pub name: String,
#[pb(index = 2)]
pub is_local: bool,
}
impl From<AIModel> for AIModelPB {
fn from(model: AIModel) -> Self {
Self {
name: model.name,
is_local: model.is_local,
}
}
}
impl From<AIModelPB> for AIModel {
fn from(value: AIModelPB) -> Self {
AIModel {
name: value.name,
is_local: value.is_local,
}
}
}
impl From<RepeatedChatMessage> for ChatMessageListPB {
fn from(repeated_chat_message: RepeatedChatMessage) -> Self {
let messages = repeated_chat_message
@@ -471,14 +525,14 @@ pub struct PendingResourcePB {
pub enum PendingResourceTypePB {
#[default]
LocalAIAppRes = 0,
AIModel = 1,
ModelRes = 1,
}
impl From<PendingResource> for PendingResourceTypePB {
fn from(value: PendingResource) -> Self {
match value {
PendingResource::PluginExecutableNotReady { .. } => PendingResourceTypePB::LocalAIAppRes,
_ => PendingResourceTypePB::AIModel,
_ => PendingResourceTypePB::ModelRes,
}
}
}
@@ -559,6 +613,9 @@ pub struct UpdateChatSettingsPB {
#[pb(index = 2)]
pub rag_ids: Vec<String>,
#[pb(index = 3)]
pub chat_model: String,
}
#[derive(Debug, Default, Clone, ProtoBuf)]

View File

@@ -69,8 +69,8 @@ pub(crate) async fn stream_chat_message_handler(
trace!("Stream chat message with metadata: {:?}", metadata);
let params = StreamMessageParams {
chat_id: &chat_id,
message: &message,
chat_id,
message,
message_type,
answer_stream_port,
question_stream_port,
@ -79,7 +79,7 @@ pub(crate) async fn stream_chat_message_handler(
};
let ai_manager = upgrade_ai_manager(ai_manager)?;
let result = ai_manager.stream_chat_message(&params).await?;
let result = ai_manager.stream_chat_message(params).await?;
data_result_ok(result)
}
@@ -103,19 +103,36 @@
}
#[tracing::instrument(level = "debug", skip_all, err)]
pub(crate) async fn get_model_list_handler(
pub(crate) async fn get_server_model_list_handler(
ai_manager: AFPluginState<Weak<AIManager>>,
) -> DataResult<ModelConfigPB, FlowyError> {
) -> DataResult<ServerAvailableModelsPB, FlowyError> {
let ai_manager = upgrade_ai_manager(ai_manager)?;
let available_models = ai_manager.get_available_models().await?;
let models = available_models
.models
.into_iter()
.map(|m| m.name)
.collect::<Vec<String>>();
let models = ai_manager.get_server_available_models().await?;
let models = serde_json::to_string(&json!({"models": models}))?;
data_result_ok(ModelConfigPB { models })
data_result_ok(ServerAvailableModelsPB { models })
}
#[tracing::instrument(level = "debug", skip_all, err)]
pub(crate) async fn get_chat_models_handler(
data: AFPluginData<AvailableModelsQueryPB>,
ai_manager: AFPluginState<Weak<AIManager>>,
) -> DataResult<AvailableModelsPB, FlowyError> {
let data = data.try_into_inner()?;
let ai_manager = upgrade_ai_manager(ai_manager)?;
let models = ai_manager.get_available_models(data.source).await?;
data_result_ok(models)
}
pub(crate) async fn update_selected_model_handler(
data: AFPluginData<UpdateSelectedModelPB>,
ai_manager: AFPluginState<Weak<AIManager>>,
) -> Result<(), FlowyError> {
let data = data.try_into_inner()?;
let ai_manager = upgrade_ai_manager(ai_manager)?;
ai_manager
.update_selected_model(data.source, data.selected_model)
.await?;
Ok(())
}
#[tracing::instrument(level = "debug", skip_all, err)]

View File

@@ -10,9 +10,14 @@ use crate::ai_manager::AIManager;
use crate::event_handler::*;
pub fn init(ai_manager: Weak<AIManager>) -> AFPlugin {
let user_service = Arc::downgrade(&ai_manager.upgrade().unwrap().user_service);
let cloud_service = Arc::downgrade(&ai_manager.upgrade().unwrap().cloud_service_wm);
let ai_tools = Arc::new(AICompletion::new(cloud_service, user_service));
let strong_ai_manager = ai_manager.upgrade().unwrap();
let user_service = Arc::downgrade(&strong_ai_manager.user_service);
let cloud_service = Arc::downgrade(&strong_ai_manager.cloud_service_wm);
let ai_tools = Arc::new(AICompletion::new(
cloud_service,
user_service,
strong_ai_manager.store_preferences.clone(),
));
AFPlugin::new()
.name("flowy-ai")
.state(ai_manager)
@@ -35,12 +40,17 @@ pub fn init(ai_manager: Weak<AIManager>) -> AFPlugin {
AIEvent::UpdateLocalAISetting,
update_local_ai_setting_handler,
)
.event(AIEvent::GetAvailableModels, get_model_list_handler)
.event(
AIEvent::GetServerAvailableModels,
get_server_model_list_handler,
)
.event(AIEvent::CreateChatContext, create_chat_context_handler)
.event(AIEvent::GetChatInfo, create_chat_context_handler)
.event(AIEvent::GetChatSettings, get_chat_settings_handler)
.event(AIEvent::UpdateChatSettings, update_chat_settings_handler)
.event(AIEvent::RegenerateResponse, regenerate_response_handler)
.event(AIEvent::GetAvailableModels, get_chat_models_handler)
.event(AIEvent::UpdateSelectedModel, update_selected_model_handler)
}
#[derive(Clone, Copy, PartialEq, Eq, Debug, Display, Hash, ProtoBuf_Enum, Flowy_Event)]
@@ -105,12 +115,18 @@ pub enum AIEvent {
#[event(input = "RegenerateResponsePB")]
RegenerateResponse = 27,
#[event(output = "ModelConfigPB")]
GetAvailableModels = 28,
#[event(output = "ServerAvailableModelsPB")]
GetServerAvailableModels = 28,
#[event(output = "LocalAISettingPB")]
GetLocalAISetting = 29,
#[event(input = "LocalAISettingPB")]
UpdateLocalAISetting = 30,
#[event(input = "AvailableModelsQueryPB", output = "AvailableModelsPB")]
GetAvailableModels = 31,
#[event(input = "UpdateSelectedModelPB")]
UpdateSelectedModel = 32,
}
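On the Dart side, the two new events are reached through the generated dispatch wrappers used earlier in this commit. A condensed sketch; `objectId` stands for the chat or document id and `model` for an AIModelPB taken from the returned list.

// Query the models available to a given chat/document.
final availableModels = await AIEventGetAvailableModels(
  AvailableModelsQueryPB(source: objectId),
).send();

// Persist a new selection; the backend then broadcasts
// ChatNotification.DidUpdateSelectedModel to listeners.
await AIEventUpdateSelectedModel(
  UpdateSelectedModelPB(source: objectId, selectedModel: model),
).send();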

View File

@@ -11,3 +12,4 @@ pub mod notification;
mod persistence;
mod protobuf;
mod stream_message;
mod util;

View File

@@ -193,6 +193,10 @@ impl LocalAIController {
/// AppFlowy store the value in local storage isolated by workspace id. Each workspace can have
/// different settings.
pub fn is_enabled(&self) -> bool {
if !get_operating_system().is_desktop() {
return false;
}
if let Ok(key) = self
.user_service
.workspace_id()
@@ -204,6 +208,13 @@
}
}
pub fn get_plugin_chat_model(&self) -> Option<String> {
if !self.is_enabled() {
return None;
}
Some(self.resource.get_llm_setting().chat_model_name)
}
pub fn open_chat(&self, chat_id: &str) {
if !self.is_enabled() {
return;

View File

@@ -9,7 +9,7 @@ use appflowy_plugin::error::PluginError;
use std::collections::HashMap;
use flowy_ai_pub::cloud::{
AppErrorCode, AppResponseError, ChatCloudService, ChatMessage, ChatMessageMetadata,
AIModel, AppErrorCode, AppResponseError, ChatCloudService, ChatMessage, ChatMessageMetadata,
ChatMessageType, ChatSettings, CompleteTextParams, CompletionStream, LocalAIConfig,
MessageCursor, ModelList, RelatedQuestion, RepeatedChatMessage, RepeatedRelatedQuestion,
ResponseFormat, StreamAnswer, StreamComplete, SubscriptionPlan, UpdateChatParams,
@@ -25,7 +25,7 @@ use futures_util::SinkExt;
use serde_json::{json, Value};
use std::path::Path;
use std::sync::{Arc, Weak};
use tracing::trace;
use tracing::{info, trace};
pub struct AICloudServiceMiddleware {
cloud_service: Arc<dyn ChatCloudService>,
@@ -156,12 +156,19 @@ impl ChatCloudService for AICloudServiceMiddleware {
&self,
workspace_id: &str,
chat_id: &str,
question_id: i64,
message_id: i64,
format: ResponseFormat,
ai_model: Option<AIModel>,
) -> Result<StreamAnswer, FlowyError> {
if self.local_ai.is_enabled() {
let use_local_ai = match &ai_model {
None => false,
Some(model) => model.is_local,
};
info!("stream_answer use model: {:?}", ai_model);
if use_local_ai {
if self.local_ai.is_running() {
let row = self.get_message_record(question_id)?;
let row = self.get_message_record(message_id)?;
match self
.local_ai
.stream_question(chat_id, &row.content, Some(json!(format)), json!({}))
@@ -179,7 +186,7 @@
} else {
self
.cloud_service
.stream_answer(workspace_id, chat_id, question_id, format)
.stream_answer(workspace_id, chat_id, message_id, format, ai_model)
.await
}
}
@@ -273,34 +280,45 @@
&self,
workspace_id: &str,
params: CompleteTextParams,
ai_model: Option<AIModel>,
) -> Result<StreamComplete, FlowyError> {
if self.local_ai.is_running() {
match self
.local_ai
.complete_text_v2(
&params.text,
params.completion_type.unwrap() as u8,
Some(json!(params.format)),
Some(json!(params.metadata)),
)
.await
{
Ok(stream) => Ok(
CompletionStream::new(
stream.map_err(|err| AppResponseError::new(AppErrorCode::Internal, err.to_string())),
let use_local_ai = match &ai_model {
None => false,
Some(model) => model.is_local,
};
info!("stream_complete use model: {:?}", ai_model);
if use_local_ai {
if self.local_ai.is_running() {
match self
.local_ai
.complete_text_v2(
&params.text,
params.completion_type.unwrap() as u8,
Some(json!(params.format)),
Some(json!(params.metadata)),
)
.map_err(FlowyError::from)
.boxed(),
),
Err(err) => {
self.handle_plugin_error(err);
Ok(stream::once(async { Err(FlowyError::local_ai_unavailable()) }).boxed())
},
.await
{
Ok(stream) => Ok(
CompletionStream::new(
stream.map_err(|err| AppResponseError::new(AppErrorCode::Internal, err.to_string())),
)
.map_err(FlowyError::from)
.boxed(),
),
Err(err) => {
self.handle_plugin_error(err);
Ok(stream::once(async { Err(FlowyError::local_ai_unavailable()) }).boxed())
},
}
} else {
Err(FlowyError::local_ai_not_ready())
}
} else {
self
.cloud_service
.stream_complete(workspace_id, params)
.stream_complete(workspace_id, params, ai_model)
.await
}
}
@@ -364,4 +382,11 @@
async fn get_available_models(&self, workspace_id: &str) -> Result<ModelList, FlowyError> {
self.cloud_service.get_available_models(workspace_id).await
}
async fn get_workspace_default_model(&self, workspace_id: &str) -> Result<String, FlowyError> {
self
.cloud_service
.get_workspace_default_model(workspace_id)
.await
}
}

View File

@@ -15,6 +15,7 @@ pub enum ChatNotification {
UpdateLocalAIState = 6,
DidUpdateChatSettings = 7,
LocalAIResourceUpdated = 8,
DidUpdateSelectedModel = 9,
}
impl std::convert::From<ChatNotification> for i32 {

View File

@@ -0,0 +1,3 @@
pub fn ai_available_models_key(object_id: &str) -> String {
format!("ai_available_models_{}", object_id)
}

View File

@@ -74,14 +74,14 @@ dart = [
"flowy-ai/dart",
"flowy-storage/dart",
]
ts = [
"flowy-user/tauri_ts",
"flowy-folder/tauri_ts",
"flowy-search/tauri_ts",
"flowy-database2/ts",
"flowy-ai/tauri_ts",
"flowy-storage/tauri_ts",
]
#ts = [
# "flowy-user/tauri_ts",
# "flowy-folder/tauri_ts",
# "flowy-search/tauri_ts",
# "flowy-database2/ts",
# "flowy-ai/tauri_ts",
# "flowy-storage/tauri_ts",
#]
openssl_vendored = ["flowy-sqlite/openssl_vendored"]
# Enable/Disable AppFlowy Verbose Log Configuration

View File

@@ -21,7 +21,7 @@ use collab_integrate::collab_builder::{
CollabCloudPluginProvider, CollabPluginProviderContext, CollabPluginProviderType,
};
use flowy_ai_pub::cloud::{
ChatCloudService, ChatMessage, ChatMessageMetadata, ChatMessageType, ChatSettings,
AIModel, ChatCloudService, ChatMessage, ChatMessageMetadata, ChatMessageType, ChatSettings,
CompleteTextParams, LocalAIConfig, MessageCursor, ModelList, RepeatedChatMessage, ResponseFormat,
StreamAnswer, StreamComplete, SubscriptionPlan, UpdateChatParams,
};
@@ -705,13 +705,14 @@ impl ChatCloudService for ServerProvider {
chat_id: &str,
message_id: i64,
format: ResponseFormat,
ai_model: Option<AIModel>,
) -> Result<StreamAnswer, FlowyError> {
let workspace_id = workspace_id.to_string();
let chat_id = chat_id.to_string();
let server = self.get_server()?;
server
.chat_service()
.stream_answer(&workspace_id, &chat_id, message_id, format)
.stream_answer(&workspace_id, &chat_id, message_id, format, ai_model)
.await
}
@@ -772,12 +773,13 @@
&self,
workspace_id: &str,
params: CompleteTextParams,
ai_model: Option<AIModel>,
) -> Result<StreamComplete, FlowyError> {
let workspace_id = workspace_id.to_string();
let server = self.get_server()?;
server
.chat_service()
.stream_complete(&workspace_id, params)
.stream_complete(&workspace_id, params, ai_model)
.await
}
@@ -846,6 +848,14 @@
.get_available_models(workspace_id)
.await
}
async fn get_workspace_default_model(&self, workspace_id: &str) -> Result<String, FlowyError> {
self
.get_server()?
.chat_service()
.get_workspace_default_model(workspace_id)
.await
}
}
#[async_trait]

View File

@@ -7,8 +7,8 @@ use client_api::entity::chat_dto::{
RepeatedChatMessage,
};
use flowy_ai_pub::cloud::{
ChatCloudService, ChatMessage, ChatMessageMetadata, ChatMessageType, ChatSettings, LocalAIConfig,
ModelList, StreamAnswer, StreamComplete, SubscriptionPlan, UpdateChatParams,
AIModel, ChatCloudService, ChatMessage, ChatMessageMetadata, ChatMessageType, ChatSettings,
LocalAIConfig, ModelList, StreamAnswer, StreamComplete, SubscriptionPlan, UpdateChatParams,
};
use flowy_error::FlowyError;
use futures_util::{StreamExt, TryStreamExt};
@@ -101,12 +101,14 @@
chat_id: &str,
message_id: i64,
format: ResponseFormat,
ai_model: Option<AIModel>,
) -> Result<StreamAnswer, FlowyError> {
trace!(
"stream_answer: workspace_id={}, chat_id={}, format={:?}",
"stream_answer: workspace_id={}, chat_id={}, format={:?}, model: {:?}",
workspace_id,
chat_id,
format
format,
ai_model,
);
let try_get_client = self.inner.try_get_client();
let result = try_get_client?
@@ -117,6 +119,7 @@
question_id: message_id,
format,
},
ai_model.map(|v| v.name),
)
.await;
@@ -189,11 +192,12 @@
&self,
workspace_id: &str,
params: CompleteTextParams,
ai_model: Option<AIModel>,
) -> Result<StreamComplete, FlowyError> {
let stream = self
.inner
.try_get_client()?
.stream_completion_v2(workspace_id, params)
.stream_completion_v2(workspace_id, params, ai_model.map(|v| v.name))
.await
.map_err(FlowyError::from)?
.map_err(FlowyError::from);
@@ -280,4 +284,13 @@
.await?;
Ok(list)
}
async fn get_workspace_default_model(&self, workspace_id: &str) -> Result<String, FlowyError> {
let setting = self
.inner
.try_get_client()?
.get_workspace_settings(workspace_id)
.await?;
Ok(setting.ai_model)
}
}

View File

@@ -1,6 +1,6 @@
use client_api::entity::ai_dto::{LocalAIConfig, RepeatedRelatedQuestion};
use flowy_ai_pub::cloud::{
ChatCloudService, ChatMessage, ChatMessageMetadata, ChatMessageType, ChatSettings,
AIModel, ChatCloudService, ChatMessage, ChatMessageMetadata, ChatMessageType, ChatSettings,
CompleteTextParams, MessageCursor, ModelList, RepeatedChatMessage, ResponseFormat, StreamAnswer,
StreamComplete, SubscriptionPlan, UpdateChatParams,
};
@@ -52,6 +52,7 @@ impl ChatCloudService for DefaultChatCloudServiceImpl {
_chat_id: &str,
_message_id: i64,
_format: ResponseFormat,
_ai_model: Option<AIModel>,
) -> Result<StreamAnswer, FlowyError> {
Err(FlowyError::not_support().with_context("Chat is not supported in local server."))
}
@@ -97,6 +98,7 @@
&self,
_workspace_id: &str,
_params: CompleteTextParams,
_ai_model: Option<AIModel>,
) -> Result<StreamComplete, FlowyError> {
Err(FlowyError::not_support().with_context("complete text is not supported in local server."))
}
@@ -148,4 +150,8 @@
async fn get_available_models(&self, _workspace_id: &str) -> Result<ModelList, FlowyError> {
Err(FlowyError::not_support().with_context("Chat is not supported in local server."))
}
async fn get_workspace_default_model(&self, _workspace_id: &str) -> Result<String, FlowyError> {
Err(FlowyError::not_support().with_context("Chat is not supported in local server."))
}
}