chore: implement ai writer history (#7523)

* chore: implement ai writer history

* chore: pass hitosyr
This commit is contained in:
Richard Shiue 2025-03-18 17:14:20 +08:00 committed by GitHub
parent e3ea3fcdfa
commit 22b03eee29
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 121 additions and 7 deletions

View File

@ -3,6 +3,7 @@ import 'dart:ffi';
import 'dart:isolate';
import 'package:appflowy/generated/locale_keys.g.dart';
import 'package:appflowy/plugins/document/presentation/editor_plugins/ai/operations/ai_writer_entities.dart';
import 'package:appflowy/shared/list_extension.dart';
import 'package:appflowy_backend/dispatch/dispatch.dart';
import 'package:appflowy_backend/log.dart';
@ -19,6 +20,8 @@ abstract class AIRepository {
String? objectId,
required String text,
PredefinedFormat? format,
List<String> sourceIds = const [],
List<AiWriterRecord> history = const [],
required CompletionTypePB completionType,
required Future<void> Function() onStart,
required Future<void> Function(String text) onProcess,
@ -34,6 +37,7 @@ class AppFlowyAIService implements AIRepository {
required String text,
PredefinedFormat? format,
List<String> sourceIds = const [],
List<AiWriterRecord> history = const [],
required CompletionTypePB completionType,
required Future<void> Function() onStart,
required Future<void> Function(String text) onProcess,
@ -47,6 +51,8 @@ class AppFlowyAIService implements AIRepository {
onError: onError,
);
final records = history.map((record) => record.toPB()).toList();
final payload = CompleteTextPB(
text: text,
completionType: completionType,
@ -57,6 +63,7 @@ class AppFlowyAIService implements AIRepository {
if (objectId != null) objectId,
...sourceIds,
].unique(),
history: records,
);
return AIEventCompleteText(payload).send().fold(

View File

@ -44,6 +44,7 @@ class AiWriterCubit extends Cubit<AiWriterState> {
final AppFlowyAIService _aiService;
final MarkdownTextRobot _textRobot;
final List<AiWriterRecord> records = [];
final ValueNotifier<List<String>> selectedSourcesNotifier;
(String, PredefinedFormat?)? _previousPrompt;
bool acceptReplacesOriginal = false;
@ -66,6 +67,7 @@ class AiWriterCubit extends Cubit<AiWriterState> {
) async {
final command = AiWriterCommand.userQuestion;
final node = getAiWriterNode();
_previousPrompt = (prompt, format);
final stream = await _aiService.streamCompletion(
@ -74,6 +76,7 @@ class AiWriterCubit extends Cubit<AiWriterState> {
format: format,
sourceIds: selectedSourcesNotifier.value,
completionType: command.toCompletionType(),
history: records,
onStart: () async {
final transaction = editorState.transaction;
final position =
@ -87,6 +90,9 @@ class AiWriterCubit extends Cubit<AiWriterState> {
),
);
_textRobot.start(position: position);
records.add(
AiWriterRecord.user(content: prompt),
);
},
onProcess: (text) async {
await _textRobot.appendMarkdownText(
@ -99,9 +105,15 @@ class AiWriterCubit extends Cubit<AiWriterState> {
attributes: ApplySuggestionFormatType.replace.attributes,
);
emit(ReadyAiWriterState(command, isFirstRun: false));
records.add(
AiWriterRecord.ai(content: _textRobot.markdownText),
);
},
onError: (error) async {
emit(ErrorAiWriterState(state.command, error: error));
records.add(
AiWriterRecord.ai(content: _textRobot.markdownText),
);
},
);
@ -337,6 +349,7 @@ class AiWriterCubit extends Cubit<AiWriterState> {
objectId: documentId,
text: text,
completionType: command.toCompletionType(),
history: records,
onStart: () async {
final transaction = editorState.transaction;
final position =
@ -364,9 +377,15 @@ class AiWriterCubit extends Cubit<AiWriterState> {
);
emit(ReadyAiWriterState(command, isFirstRun: false));
}
records.add(
AiWriterRecord.ai(content: _textRobot.markdownText),
);
},
onError: (error) async {
emit(ErrorAiWriterState(command, error: error));
records.add(
AiWriterRecord.ai(content: _textRobot.markdownText),
);
},
);
if (stream != null) {
@ -392,6 +411,7 @@ class AiWriterCubit extends Cubit<AiWriterState> {
objectId: documentId,
text: await editorState.getMarkdownInSelection(selection),
completionType: command.toCompletionType(),
history: records,
onStart: () async {
final transaction = editorState.transaction;
formatSelection(
@ -429,10 +449,16 @@ class AiWriterCubit extends Cubit<AiWriterState> {
isFirstRun: false,
),
);
records.add(
AiWriterRecord.ai(content: _textRobot.markdownText),
);
}
},
onError: (error) async {
emit(ErrorAiWriterState(command, error: error));
records.add(
AiWriterRecord.ai(content: _textRobot.markdownText),
);
},
);
if (stream != null) {
@ -456,6 +482,7 @@ class AiWriterCubit extends Cubit<AiWriterState> {
objectId: documentId,
text: await editorState.getMarkdownInSelection(selection),
completionType: command.toCompletionType(),
history: records,
onStart: () async {},
onProcess: (text) async {
if (state case final GeneratingAiWriterState generatingState) {
@ -477,9 +504,17 @@ class AiWriterCubit extends Cubit<AiWriterState> {
markdownText: generatingState.markdownText,
),
);
records.add(
AiWriterRecord.ai(content: generatingState.markdownText),
);
}
},
onError: (error) async {
if (state case final GeneratingAiWriterState generatingState) {
records.add(
AiWriterRecord.ai(content: generatingState.markdownText),
);
}
emit(ErrorAiWriterState(command, error: error));
},
);

View File

@ -2,6 +2,7 @@ import 'package:appflowy/generated/flowy_svgs.g.dart';
import 'package:appflowy/generated/locale_keys.g.dart';
import 'package:appflowy_backend/protobuf/flowy-ai/protobuf.dart';
import 'package:easy_localization/easy_localization.dart';
import 'package:equatable/equatable.dart';
import 'package:flutter/material.dart';
import '../ai_writer_block_component.dart';
@ -120,3 +121,40 @@ enum ApplySuggestionFormatType {
Map<String, dynamic> get attributes => {AiWriterBlockKeys.suggestion: value};
}
enum AiRole {
user,
system,
ai,
}
class AiWriterRecord extends Equatable {
const AiWriterRecord({
required this.role,
required this.content,
});
const AiWriterRecord.user({
required this.content,
}) : role = AiRole.user;
const AiWriterRecord.ai({
required this.content,
}) : role = AiRole.ai;
final AiRole role;
final String content;
@override
List<Object> get props => [role, content];
CompletionRecordPB toPB() {
return CompletionRecordPB(
content: content,
role: switch (role) {
AiRole.user => ChatMessageTypePB.User,
AiRole.system || AiRole.ai => ChatMessageTypePB.System,
},
);
}
}

View File

@ -30,6 +30,8 @@ class MarkdownTextRobot {
bool get hasAnyResult => _markdownText.isNotEmpty;
String get markdownText => _markdownText;
Selection? getInsertedSelection() {
final position = _insertPosition;
if (position == null) {

View File

@ -23,6 +23,7 @@ class _MockAIRepository extends Mock implements AppFlowyAIService {
required String text,
PredefinedFormat? format,
List<String> sourceIds = const [],
List<AiWriterRecord> history = const [],
required CompletionTypePB completionType,
required Future<void> Function() onStart,
required Future<void> Function(String text) onProcess,
@ -53,6 +54,7 @@ class _MockAIRepositoryLess extends Mock implements AppFlowyAIService {
required String text,
PredefinedFormat? format,
List<String> sourceIds = const [],
List<AiWriterRecord> history = const [],
required CompletionTypePB completionType,
required Future<void> Function() onStart,
required Future<void> Function(String text) onProcess,
@ -79,6 +81,7 @@ class _MockAIRepositoryMore extends Mock implements AppFlowyAIService {
required String text,
PredefinedFormat? format,
List<String> sourceIds = const [],
List<AiWriterRecord> history = const [],
required CompletionTypePB completionType,
required Future<void> Function() onStart,
required Future<void> Function(String text) onProcess,
@ -107,6 +110,7 @@ class _MockErrorRepository extends Mock implements AppFlowyAIService {
required String text,
PredefinedFormat? format,
List<String> sourceIds = const [],
List<AiWriterRecord> history = const [],
required CompletionTypePB completionType,
required Future<void> Function() onStart,
required Future<void> Function(String text) onProcess,

View File

@ -100,6 +100,7 @@ impl EventIntegrationTest {
object_id: "".to_string(),
rag_ids: vec![],
format: None,
history: vec![],
};
EventBuilder::new(self.clone())
.event(AIEvent::CompleteText)

View File

@ -1,8 +1,8 @@
use bytes::Bytes;
pub use client_api::entity::ai_dto::{
AppFlowyOfflineAI, CompleteTextParams, CompletionMetadata, CompletionType, CreateChatContext,
LLMModel, LocalAIConfig, ModelInfo, ModelList, OutputContent, OutputLayout, RelatedQuestion,
RepeatedRelatedQuestion, ResponseFormat, StringOrMessage,
AppFlowyOfflineAI, CompleteTextParams, CompletionMetadata, CompletionRecord, CompletionType,
CreateChatContext, LLMModel, LocalAIConfig, ModelInfo, ModelList, OutputContent, OutputLayout,
RelatedQuestion, RepeatedRelatedQuestion, ResponseFormat, StringOrMessage,
};
pub use client_api::entity::billing_dto::SubscriptionPlan;
pub use client_api::entity::chat_dto::{

View File

@ -98,6 +98,8 @@ impl CompletionTask {
};
let _ = sink.send("start:".to_string()).await;
let completion_history = Some(self.context.history.iter().map(Into::into).collect());
let format = self.context.format.map(Into::into).unwrap_or_default();
let params = CompleteTextParams {
text: self.context.text,
completion_type: Some(complete_type),
@ -106,9 +108,9 @@ impl CompletionTask {
object_id: self.context.object_id,
workspace_id: Some(self.workspace_id.clone()),
rag_ids: Some(self.context.rag_ids),
completion_history: None,
completion_history,
}),
format: self.context.format.map(Into::into).unwrap_or_default(),
format,
};
info!("start completion: {:?}", params);

View File

@ -4,8 +4,8 @@ use std::collections::HashMap;
use crate::local_ai::controller::LocalAISetting;
use crate::local_ai::resource::PendingResource;
use flowy_ai_pub::cloud::{
ChatMessage, ChatMessageMetadata, ChatMessageType, LLMModel, OutputContent, OutputLayout,
RelatedQuestion, RepeatedChatMessage, RepeatedRelatedQuestion, ResponseFormat,
ChatMessage, ChatMessageMetadata, ChatMessageType, CompletionRecord, LLMModel, OutputContent,
OutputLayout, RelatedQuestion, RepeatedChatMessage, RepeatedRelatedQuestion, ResponseFormat,
};
use flowy_derive::{ProtoBuf, ProtoBuf_Enum};
use lib_infra::validator_fn::required_not_empty_str;
@ -358,6 +358,9 @@ pub struct CompleteTextPB {
#[pb(index = 6)]
pub rag_ids: Vec<String>,
#[pb(index = 7)]
pub history: Vec<CompletionRecordPB>,
}
#[derive(Default, ProtoBuf, Clone, Debug)]
@ -378,6 +381,28 @@ pub enum CompletionTypePB {
MakeLonger = 6,
}
#[derive(Default, ProtoBuf, Clone, Debug)]
pub struct CompletionRecordPB {
#[pb(index = 1)]
pub role: ChatMessageTypePB,
#[pb(index = 2)]
pub content: String,
}
impl From<&CompletionRecordPB> for CompletionRecord {
fn from(value: &CompletionRecordPB) -> Self {
CompletionRecord {
role: match value.role {
// Coerce ChatMessageTypePB::System to AI
ChatMessageTypePB::System => "ai".to_string(),
ChatMessageTypePB::User => "human".to_string(),
},
content: value.content.clone(),
}
}
}
#[derive(Default, ProtoBuf, Clone, Debug)]
pub struct ChatStatePB {
#[pb(index = 1)]