diff --git a/frontend/appflowy_flutter/integration_test/desktop/cloud/document/document_ai_writer_test.dart b/frontend/appflowy_flutter/integration_test/desktop/cloud/document/document_ai_writer_test.dart index f163608ccb..e5548b0d97 100644 --- a/frontend/appflowy_flutter/integration_test/desktop/cloud/document/document_ai_writer_test.dart +++ b/frontend/appflowy_flutter/integration_test/desktop/cloud/document/document_ai_writer_test.dart @@ -1,12 +1,19 @@ +import 'package:appflowy/ai/service/ai_entities.dart'; import 'package:appflowy/env/cloud_env.dart'; import 'package:appflowy/generated/locale_keys.g.dart'; +import 'package:appflowy/plugins/document/presentation/editor_plugins/ai/operations/ai_writer_entities.dart'; import 'package:appflowy/plugins/document/presentation/editor_plugins/plugins.dart'; +import 'package:appflowy_backend/protobuf/flowy-ai/entities.pbenum.dart'; import 'package:appflowy_backend/protobuf/flowy-folder/view.pb.dart'; +import 'package:appflowy_editor/appflowy_editor.dart'; import 'package:easy_localization/easy_localization.dart'; +import 'package:flutter/services.dart'; import 'package:flutter_test/flutter_test.dart'; import 'package:integration_test/integration_test.dart'; +import '../../../shared/ai_test_op.dart'; import '../../../shared/constants.dart'; +import '../../../shared/mock/mock_ai.dart'; import '../../../shared/util.dart'; void main() { @@ -17,6 +24,7 @@ void main() { (tester) async { await tester.initializeAppFlowy( cloudType: AuthenticatorType.appflowyCloudSelfHost, + aiRepositoryBuilder: () => MockAIRepository(), ); await tester.tapGoogleLoginInButton(); await tester.expectToSeeHomePageWithGetStartedPage(); @@ -43,5 +51,159 @@ void main() { // expect the ai writer block is not in the document expect(find.byType(AiWriterBlockComponent), findsNothing); }); + + testWidgets('Improve writing', (tester) async { + await tester.initializeAppFlowy( + cloudType: AuthenticatorType.appflowyCloudSelfHost, + ); + await tester.tapGoogleLoginInButton(); + await tester.expectToSeeHomePageWithGetStartedPage(); + + const pageName = 'Document'; + await tester.createNewPageInSpace( + spaceName: Constants.generalSpaceName, + layout: ViewLayoutPB.Document, + pageName: pageName, + ); + + await tester.editor.tapLineOfEditorAt(0); + + // insert a paragraph + final text = 'I have an apple'; + await tester.editor.tapLineOfEditorAt(0); + await tester.ime.insertText(text); + await tester.editor.updateSelection( + Selection( + start: Position(path: [0]), + end: Position(path: [0], offset: text.length), + ), + ); + + await tester.pumpAndSettle(); + await tester.tapButton(find.byType(ImproveWritingButton)); + + final editorState = tester.editor.getCurrentEditorState(); + final document = editorState.document; + + expect(document.root.children.length, 3); + expect(document.root.children[1].type, ParagraphBlockKeys.type); + expect( + document.root.children[1].delta!.toPlainText(), + 'I have an apple and a banana', + ); + }); + + testWidgets('fix grammar', (tester) async { + await tester.initializeAppFlowy( + cloudType: AuthenticatorType.appflowyCloudSelfHost, + ); + await tester.tapGoogleLoginInButton(); + await tester.expectToSeeHomePageWithGetStartedPage(); + + const pageName = 'Document'; + await tester.createNewPageInSpace( + spaceName: Constants.generalSpaceName, + layout: ViewLayoutPB.Document, + pageName: pageName, + ); + + await tester.editor.tapLineOfEditorAt(0); + + // insert a paragraph + final text = 'We didn’t had enough money'; + await tester.editor.tapLineOfEditorAt(0); + await tester.ime.insertText(text); + await tester.editor.updateSelection( + Selection( + start: Position(path: [0]), + end: Position(path: [0], offset: text.length), + ), + ); + + await tester.pumpAndSettle(); + await tester.tapButton(find.byType(AiWriterToolbarActionList)); + await tester.tapButton( + find.text(AiWriterCommand.fixSpellingAndGrammar.i18n), + ); + await tester.pumpAndSettle(); + + final editorState = tester.editor.getCurrentEditorState(); + final document = editorState.document; + + expect(document.root.children.length, 3); + expect(document.root.children[1].type, ParagraphBlockKeys.type); + expect( + document.root.children[1].delta!.toPlainText(), + 'We didn’t have enough money', + ); + }); + + testWidgets('ask ai', (tester) async { + await tester.initializeAppFlowy( + cloudType: AuthenticatorType.appflowyCloudSelfHost, + aiRepositoryBuilder: () => MockAIRepository( + validator: _CompletionHistoryValidator(), + ), + ); + await tester.tapGoogleLoginInButton(); + await tester.expectToSeeHomePageWithGetStartedPage(); + + const pageName = 'Document'; + await tester.createNewPageInSpace( + spaceName: Constants.generalSpaceName, + layout: ViewLayoutPB.Document, + pageName: pageName, + ); + + await tester.editor.tapLineOfEditorAt(0); + + // insert a paragraph + final text = 'What is TPU?'; + await tester.editor.tapLineOfEditorAt(0); + await tester.ime.insertText(text); + await tester.editor.updateSelection( + Selection( + start: Position(path: [0]), + end: Position(path: [0], offset: text.length), + ), + ); + + await tester.pumpAndSettle(); + await tester.tapButton(find.byType(AiWriterToolbarActionList)); + await tester.tapButton( + find.text(AiWriterCommand.userQuestion.i18n), + ); + await tester.pumpAndSettle(); + + await tester.enterTextInPromptTextField("Explain the concept of TPU"); + + // click enter button + await tester.simulateKeyEvent(LogicalKeyboardKey.enter); + await tester.pumpAndSettle(Duration(seconds: 10)); + }); }); } + +class _CompletionHistoryValidator extends StreamCompletionValidator { + @override + bool validate( + String text, + String? objectId, + CompletionTypePB completionType, + PredefinedFormat? format, + List sourceIds, + List history, + ) { + assert(completionType == CompletionTypePB.UserQuestion); + assert( + history.length == 1, + "expect history length is 1, but got ${history.length}", + ); + assert( + history[0].content.trim() == "What is TPU?", + "expect history[0].content is 'What is TPU?', but got '${history[0].content.trim()}'", + ); + + return true; + } +} diff --git a/frontend/appflowy_flutter/integration_test/desktop/document/document_copy_and_paste_test.dart b/frontend/appflowy_flutter/integration_test/desktop/document/document_copy_and_paste_test.dart index c18b42939c..ec61034d40 100644 --- a/frontend/appflowy_flutter/integration_test/desktop/document/document_copy_and_paste_test.dart +++ b/frontend/appflowy_flutter/integration_test/desktop/document/document_copy_and_paste_test.dart @@ -1,5 +1,6 @@ import 'dart:io'; +import 'package:appflowy/env/cloud_env.dart'; import 'package:appflowy/generated/locale_keys.g.dart'; import 'package:appflowy/plugins/document/presentation/editor_plugins/block_menu/block_menu_button.dart'; import 'package:appflowy/plugins/document/presentation/editor_plugins/copy_and_paste/clipboard_service.dart'; @@ -530,6 +531,7 @@ extension on WidgetTester { (String, Uint8List?)? image, }) async { await initializeAppFlowy(); + await useAppFlowyCloudDevelop("http://localhost"); await tapAnonymousSignInButton(); // create a new document diff --git a/frontend/appflowy_flutter/integration_test/shared/ai_test_op.dart b/frontend/appflowy_flutter/integration_test/shared/ai_test_op.dart new file mode 100644 index 0000000000..6d5b34ba0c --- /dev/null +++ b/frontend/appflowy_flutter/integration_test/shared/ai_test_op.dart @@ -0,0 +1,23 @@ +import 'package:appflowy/ai/ai.dart'; +import 'package:flutter_test/flutter_test.dart'; +import 'package:extended_text_field/extended_text_field.dart'; + +extension AppFlowyAITest on WidgetTester { + Future enterTextInPromptTextField(String text) async { + // Wait for the text field to be visible + await pumpAndSettle(); + + // Find the ExtendedTextField widget + final textField = find.descendant( + of: find.byType(PromptInputTextField), + matching: find.byType(ExtendedTextField), + ); + expect(textField, findsOneWidget, reason: 'ExtendedTextField not found'); + + final widget = element(textField).widget as ExtendedTextField; + expect(widget.enabled, isTrue, reason: 'TextField is not enabled'); + + testTextInput.enterText(text); + await pumpAndSettle(const Duration(milliseconds: 300)); + } +} diff --git a/frontend/appflowy_flutter/integration_test/shared/base.dart b/frontend/appflowy_flutter/integration_test/shared/base.dart index 493cb4c1f0..bb0489a926 100644 --- a/frontend/appflowy_flutter/integration_test/shared/base.dart +++ b/frontend/appflowy_flutter/integration_test/shared/base.dart @@ -1,6 +1,7 @@ import 'dart:async'; import 'dart:io'; +import 'package:appflowy/ai/service/appflowy_ai_service.dart'; import 'package:appflowy/env/cloud_env.dart'; import 'package:appflowy/env/cloud_env_test.dart'; import 'package:appflowy/startup/entry_point.dart'; @@ -20,6 +21,8 @@ import 'package:path/path.dart' as p; import 'package:path_provider/path_provider.dart'; import 'package:universal_platform/universal_platform.dart'; +import 'mock/mock_ai.dart'; + class FlowyTestContext { FlowyTestContext({required this.applicationDataDirectory}); @@ -33,8 +36,9 @@ extension AppFlowyTestBase on WidgetTester { // use to specify the application data directory, if not specified, a temporary directory will be used. String? dataDirectory, Size windowSize = const Size(1600, 1200), - AuthenticatorType? cloudType, String? email, + AuthenticatorType? cloudType, + AIRepository Function()? aiRepositoryBuilder, }) async { if (Platform.isLinux || Platform.isWindows || Platform.isMacOS) { // Set the window size @@ -60,6 +64,10 @@ extension AppFlowyTestBase on WidgetTester { rustEnvs["GOTRUE_ADMIN_EMAIL"] = "admin@example.com"; rustEnvs["GOTRUE_ADMIN_PASSWORD"] = "password"; break; + case AuthenticatorType.appflowyCloudDevelop: + rustEnvs["GOTRUE_ADMIN_EMAIL"] = "admin@example.com"; + rustEnvs["GOTRUE_ADMIN_PASSWORD"] = "password"; + break; default: throw Exception("not supported"); } @@ -75,11 +83,32 @@ extension AppFlowyTestBase on WidgetTester { await useLocalServer(); break; case AuthenticatorType.appflowyCloudSelfHost: - await useTestSelfHostedAppFlowyCloud(); + await useSelfHostedAppFlowyCloud(TestEnv.afCloudUrl); getIt.unregister(); + getIt.unregister(); + getIt.registerFactory( () => AppFlowyCloudMockAuthService(email: email), ); + getIt.registerFactory( + aiRepositoryBuilder ?? () => MockAIRepository(), + ); + case AuthenticatorType.appflowyCloudDevelop: + if (integrationMode().isDevelop) { + await useAppFlowyCloudDevelop("http://localhost"); + } else { + await useSelfHostedAppFlowyCloud(TestEnv.afCloudUrl); + } + getIt.unregister(); + getIt.unregister(); + + getIt.registerFactory( + () => AppFlowyCloudMockAuthService(email: email), + ); + getIt.registerFactory( + aiRepositoryBuilder ?? () => MockAIRepository(), + ); + break; default: throw Exception("not supported"); } @@ -275,10 +304,6 @@ extension AppFlowyFinderTestBase on CommonFinders { } } -Future useTestSelfHostedAppFlowyCloud() async { - await useSelfHostedAppFlowyCloudWithURL(TestEnv.afCloudUrl); -} - Future mockApplicationDataStorage({ // use to append after the application data directory String? pathExtension, diff --git a/frontend/appflowy_flutter/integration_test/shared/mock/mock_ai.dart b/frontend/appflowy_flutter/integration_test/shared/mock/mock_ai.dart new file mode 100644 index 0000000000..8f6d6757b4 --- /dev/null +++ b/frontend/appflowy_flutter/integration_test/shared/mock/mock_ai.dart @@ -0,0 +1,112 @@ +import 'dart:async'; + +import 'package:appflowy/ai/service/ai_entities.dart'; +import 'package:appflowy/ai/service/appflowy_ai_service.dart'; +import 'package:appflowy/ai/service/error.dart'; +import 'package:appflowy/plugins/document/presentation/editor_plugins/ai/operations/ai_writer_entities.dart'; +import 'package:appflowy_backend/protobuf/flowy-ai/entities.pbenum.dart'; +import 'package:mocktail/mocktail.dart'; + +final _mockAiMap = >>{ + CompletionTypePB.ImproveWriting: { + "I have an apple": [ + "I", + "have", + "an", + "apple", + "and", + "a", + "banana", + ], + }, + CompletionTypePB.SpellingAndGrammar: { + "We didn’t had enough money": [ + "We", + "didn’t", + "have", + "enough", + "money", + ], + }, + CompletionTypePB.UserQuestion: { + "Explain the concept of TPU": [ + "TPU", + "is", + "a", + "tensor", + "processing", + "unit", + "that", + "is", + "designed", + "to", + "accelerate", + "machine", + ], + }, +}; + +abstract class StreamCompletionValidator { + bool validate( + String text, + String? objectId, + CompletionTypePB completionType, + PredefinedFormat? format, + List sourceIds, + List history, + ); +} + +class MockCompletionStream extends Mock implements CompletionStream {} + +class MockAIRepository extends Mock implements AppFlowyAIService { + MockAIRepository({this.validator}); + StreamCompletionValidator? validator; + + @override + Future<(String, CompletionStream)?> streamCompletion({ + String? objectId, + required String text, + PredefinedFormat? format, + List sourceIds = const [], + List history = const [], + required CompletionTypePB completionType, + required Future Function() onStart, + required Future Function(String text) processMessage, + required Future Function(String text) processAssistMessage, + required Future Function() onEnd, + required void Function(AIError error) onError, + required void Function(LocalAIStreamingState state) + onLocalAIStreamingStateChange, + }) async { + if (validator != null) { + if (!validator!.validate( + text, + objectId, + completionType, + format, + sourceIds, + history, + )) { + throw Exception('Invalid completion'); + } + } + final stream = MockCompletionStream(); + unawaited( + Future(() async { + await onStart(); + final lines = _mockAiMap[completionType]?[text.trim()]; + + if (lines == null) { + throw Exception('No mock ai found for $text and $completionType'); + } + + for (final line in lines) { + await processMessage('$line '); + } + await onEnd(); + }), + ); + return ('mock_id', stream); + } +} diff --git a/frontend/appflowy_flutter/lib/ai/service/appflowy_ai_service.dart b/frontend/appflowy_flutter/lib/ai/service/appflowy_ai_service.dart index 39487652f8..b55f09e1d6 100644 --- a/frontend/appflowy_flutter/lib/ai/service/appflowy_ai_service.dart +++ b/frontend/appflowy_flutter/lib/ai/service/appflowy_ai_service.dart @@ -21,7 +21,7 @@ enum LocalAIStreamingState { } abstract class AIRepository { - Future streamCompletion({ + Future<(String, CompletionStream)?> streamCompletion({ String? objectId, required String text, PredefinedFormat? format, diff --git a/frontend/appflowy_flutter/lib/env/cloud_env.dart b/frontend/appflowy_flutter/lib/env/cloud_env.dart index 15f3ada42e..9e24e929b1 100644 --- a/frontend/appflowy_flutter/lib/env/cloud_env.dart +++ b/frontend/appflowy_flutter/lib/env/cloud_env.dart @@ -167,11 +167,16 @@ Future useBaseWebDomain(String? url) async { ); } -Future useSelfHostedAppFlowyCloudWithURL(String url) async { +Future useSelfHostedAppFlowyCloud(String url) async { await _setAuthenticatorType(AuthenticatorType.appflowyCloudSelfHost); await _setAppFlowyCloudUrl(url); } +Future useAppFlowyCloudDevelop(String url) async { + await _setAuthenticatorType(AuthenticatorType.appflowyCloudDevelop); + await _setAppFlowyCloudUrl(url); +} + Future useAppFlowyBetaCloudWithURL( String url, AuthenticatorType authenticatorType, diff --git a/frontend/appflowy_flutter/lib/mobile/presentation/setting/self_host/self_host_bottom_sheet.dart b/frontend/appflowy_flutter/lib/mobile/presentation/setting/self_host/self_host_bottom_sheet.dart index dd19c2489d..5e19667c68 100644 --- a/frontend/appflowy_flutter/lib/mobile/presentation/setting/self_host/self_host_bottom_sheet.dart +++ b/frontend/appflowy_flutter/lib/mobile/presentation/setting/self_host/self_host_bottom_sheet.dart @@ -87,7 +87,7 @@ class _SelfHostUrlBottomSheetState extends State { case SelfHostUrlBottomSheetType.shareDomain: await useBaseWebDomain(url); case SelfHostUrlBottomSheetType.cloudURL: - await useSelfHostedAppFlowyCloudWithURL(url); + await useSelfHostedAppFlowyCloud(url); } await runAppFlowy(); }, diff --git a/frontend/appflowy_flutter/lib/plugins/document/presentation/editor_plugins/ai/operations/ai_writer_cubit.dart b/frontend/appflowy_flutter/lib/plugins/document/presentation/editor_plugins/ai/operations/ai_writer_cubit.dart index 7fc93e1c07..e5ec888244 100644 --- a/frontend/appflowy_flutter/lib/plugins/document/presentation/editor_plugins/ai/operations/ai_writer_cubit.dart +++ b/frontend/appflowy_flutter/lib/plugins/document/presentation/editor_plugins/ai/operations/ai_writer_cubit.dart @@ -1,6 +1,7 @@ import 'dart:async'; import 'package:appflowy/ai/ai.dart'; +import 'package:appflowy/startup/startup.dart'; import 'package:appflowy/workspace/application/view/view_service.dart'; import 'package:appflowy_backend/dispatch/dispatch.dart'; import 'package:appflowy_backend/log.dart'; @@ -23,15 +24,14 @@ class AiWriterCubit extends Cubit { this.onCreateNode, this.onRemoveNode, this.onAppendToDocument, - AppFlowyAIService? aiService, - }) : _aiService = aiService ?? AppFlowyAIService(), + }) : _aiService = getIt(), _textRobot = MarkdownTextRobot(editorState: editorState), selectedSourcesNotifier = ValueNotifier([documentId]), super(IdleAiWriterState()); final String documentId; final EditorState editorState; - final AppFlowyAIService _aiService; + final AIRepository _aiService; final MarkdownTextRobot _textRobot; final void Function()? onCreateNode; final void Function()? onRemoveNode; @@ -295,7 +295,6 @@ class AiWriterCubit extends Cubit { } final selectionText = await editorState.getMarkdownInSelection(selection); - Log.warn('[AI writer] Selection is null'); if (command == AiWriterCommand.userQuestion) { records.add( diff --git a/frontend/appflowy_flutter/lib/startup/deps_resolver.dart b/frontend/appflowy_flutter/lib/startup/deps_resolver.dart index 621ba988cf..fffec2fd0b 100644 --- a/frontend/appflowy_flutter/lib/startup/deps_resolver.dart +++ b/frontend/appflowy_flutter/lib/startup/deps_resolver.dart @@ -1,3 +1,4 @@ +import 'package:appflowy/ai/service/appflowy_ai_service.dart'; import 'package:appflowy/core/config/kv.dart'; import 'package:appflowy/core/network_monitor.dart'; import 'package:appflowy/env/cloud_env.dart'; @@ -59,6 +60,7 @@ Future _resolveCloudDeps(GetIt getIt) async { final env = await AppFlowyCloudSharedEnv.fromEnv(); Log.info("cloud setting: $env"); getIt.registerFactory(() => env); + getIt.registerFactory(() => AppFlowyAIService()); if (isAppFlowyCloudEnabled) { getIt.registerSingleton( diff --git a/frontend/appflowy_flutter/lib/workspace/application/settings/appflowy_cloud_urls_bloc.dart b/frontend/appflowy_flutter/lib/workspace/application/settings/appflowy_cloud_urls_bloc.dart index 5652904180..371fd75583 100644 --- a/frontend/appflowy_flutter/lib/workspace/application/settings/appflowy_cloud_urls_bloc.dart +++ b/frontend/appflowy_flutter/lib/workspace/application/settings/appflowy_cloud_urls_bloc.dart @@ -48,7 +48,7 @@ class AppFlowyCloudURLsBloc await validateUrl(state.updatedServerUrl).fold( (url) async { - await useSelfHostedAppFlowyCloudWithURL(url); + await useSelfHostedAppFlowyCloud(url); isSuccess = true; }, (err) async => emit(state.copyWith(urlError: err)), diff --git a/frontend/appflowy_flutter/test/bloc_test/ai_writer_test/ai_writer_bloc_test.dart b/frontend/appflowy_flutter/test/bloc_test/ai_writer_test/ai_writer_bloc_test.dart index bcd8b13d39..d1873cbe8b 100644 --- a/frontend/appflowy_flutter/test/bloc_test/ai_writer_test/ai_writer_bloc_test.dart +++ b/frontend/appflowy_flutter/test/bloc_test/ai_writer_test/ai_writer_bloc_test.dart @@ -4,6 +4,7 @@ import 'package:appflowy/ai/ai.dart'; import 'package:appflowy/plugins/document/presentation/editor_plugins/ai/operations/ai_writer_cubit.dart'; import 'package:appflowy/plugins/document/presentation/editor_plugins/ai/operations/ai_writer_entities.dart'; import 'package:appflowy/plugins/document/presentation/editor_plugins/plugins.dart'; +import 'package:appflowy/startup/startup.dart'; import 'package:appflowy_backend/protobuf/flowy-ai/entities.pb.dart'; import 'package:appflowy_editor/appflowy_editor.dart'; import 'package:bloc_test/bloc_test.dart'; @@ -145,6 +146,13 @@ class _MockErrorRepository extends Mock implements AppFlowyAIService { } } +void registerMockRepository(AppFlowyAIService mock) { + if (getIt.isRegistered()) { + getIt.unregister(); + } + getIt.registerFactory(() => mock); +} + void main() { group('AIWriterCubit:', () { const text1 = '1. Select text to style using the toolbar menu.'; @@ -174,10 +182,10 @@ void main() { ); final editorState = EditorState(document: document) ..selection = selection; + registerMockRepository(_MockAIRepository()); return AiWriterCubit( documentId: '', editorState: editorState, - aiService: _MockAIRepository(), ); }, act: (bloc) => bloc.register( @@ -230,10 +238,10 @@ void main() { ); final editorState = EditorState(document: document) ..selection = selection; + registerMockRepository(_MockErrorRepository()); return AiWriterCubit( documentId: '', editorState: editorState, - aiService: _MockErrorRepository(), ); }, act: (bloc) => bloc.register( @@ -279,10 +287,10 @@ void main() { final editorState = EditorState(document: document) ..selection = selection; final aiNode = editorState.getNodeAtPath([3])!; + registerMockRepository(_MockAIRepository()); final bloc = AiWriterCubit( documentId: '', editorState: editorState, - aiService: _MockAIRepository(), ); bloc.register(aiNode); await blocResponseFuture(); @@ -327,10 +335,10 @@ void main() { final editorState = EditorState(document: document) ..selection = selection; final aiNode = editorState.getNodeAtPath([3])!; + registerMockRepository(_MockAIRepository()); final bloc = AiWriterCubit( documentId: '', editorState: editorState, - aiService: _MockAIRepository(), ); bloc.register(aiNode); await blocResponseFuture(); @@ -366,10 +374,10 @@ void main() { final editorState = EditorState(document: document) ..selection = selection; final aiNode = editorState.getNodeAtPath([3])!; + registerMockRepository(_MockAIRepositoryLess()); final bloc = AiWriterCubit( documentId: '', editorState: editorState, - aiService: _MockAIRepositoryLess(), ); bloc.register(aiNode); await blocResponseFuture(); @@ -403,10 +411,10 @@ void main() { final editorState = EditorState(document: document) ..selection = selection; final aiNode = editorState.getNodeAtPath([3])!; + registerMockRepository(_MockAIRepositoryMore()); final bloc = AiWriterCubit( documentId: '', editorState: editorState, - aiService: _MockAIRepositoryMore(), ); bloc.register(aiNode); await blocResponseFuture();