feat: restrict classification operators cannot be connected to Answer and other classification #918 (#1294)

### What problem does this PR solve?

feat: restrict classification operators cannot be connected to Answer
and other classification #918

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
balibabu 2024-06-27 14:57:40 +08:00 committed by GitHub
parent 0ce720a247
commit fbb8cbfc67
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 193 additions and 42 deletions

View File

@ -20,9 +20,16 @@ import {
import { RagNode } from './node';
import ChatDrawer from '../chat/drawer';
import { isValidConnection } from '../utils';
import styles from './index.less';
import { BeginNode } from './node/begin-node';
import { CategorizeNode } from './node/categorize-node';
const nodeTypes = { ragNode: RagNode };
const nodeTypes = {
ragNode: RagNode,
categorizeNode: CategorizeNode,
beginNode: BeginNode,
};
const edgeTypes = {
buttonEdge: ButtonEdge,
@ -76,6 +83,7 @@ function FlowCanvas({ chatDrawerVisible, hideChatDrawer }: IProps) {
onKeyUp={handleKeyUp}
onSelectionChange={onSelectionChange}
nodeOrigin={[0.5, 0]}
isValidConnection={isValidConnection}
onChange={(...params) => {
console.info('params:', ...params);
}}

View File

@ -0,0 +1,39 @@
import { Flex, Space } from 'antd';
import classNames from 'classnames';
import { Handle, NodeProps, Position } from 'reactflow';
import { Operator } from '../../constant';
import { NodeData } from '../../interface';
import OperatorIcon from '../../operator-icon';
import NodeDropdown from './dropdown';
import styles from './index.less';
// TODO: do not allow other nodes to connect to this node
export function BeginNode({ id, data, selected }: NodeProps<NodeData>) {
return (
<section
className={classNames(styles.ragNode, {
[styles.selectedNode]: selected,
})}
>
<Handle
type="source"
position={Position.Right}
isConnectable
className={styles.handle}
></Handle>
<Flex vertical align="center" justify="center">
<Space size={6}>
<OperatorIcon
name={data.label as Operator}
fontSize={16}
></OperatorIcon>
<NodeDropdown id={id}></NodeDropdown>
</Space>
</Flex>
<section className={styles.bottomBox}>
<div className={styles.nodeName}>{id}</div>
</section>
</section>
);
}

View File

@ -0,0 +1,51 @@
import { Flex, Space } from 'antd';
import classNames from 'classnames';
import get from 'lodash/get';
import { Handle, NodeProps, Position } from 'reactflow';
import { CategorizeAnchorPointPositions, Operator } from '../../constant';
import { NodeData } from '../../interface';
import OperatorIcon from '../../operator-icon';
import CategorizeHandle from './categorize-handle';
import NodeDropdown from './dropdown';
import styles from './index.less';
export function CategorizeNode({ id, data, selected }: NodeProps<NodeData>) {
const categoryData = get(data, 'form.category_description') ?? {};
return (
<section
className={classNames(styles.ragNode, {
[styles.selectedNode]: selected,
})}
>
<Handle
type="target"
position={Position.Left}
isConnectable
className={styles.handle}
></Handle>
{Object.keys(categoryData).map((x, idx) => (
<CategorizeHandle
top={CategorizeAnchorPointPositions[idx].top}
right={CategorizeAnchorPointPositions[idx].right}
key={idx}
text={x}
idx={idx}
></CategorizeHandle>
))}
<Flex vertical align="center" justify="center">
<Space size={6}>
<OperatorIcon
name={data.label as Operator}
fontSize={16}
></OperatorIcon>
<NodeDropdown id={id}></NodeDropdown>
</Space>
</Flex>
<section className={styles.bottomBox}>
<div className={styles.nodeName}>{id}</div>
</section>
</section>
);
}

View File

@ -0,0 +1,47 @@
import OperateDropdown from '@/components/operate-dropdown';
import { CopyOutlined } from '@ant-design/icons';
import { Flex, MenuProps } from 'antd';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import useGraphStore from '../../store';
interface IProps {
id: string;
}
const NodeDropdown = ({ id }: IProps) => {
const { t } = useTranslation();
const deleteNodeById = useGraphStore((store) => store.deleteNodeById);
const duplicateNodeById = useGraphStore((store) => store.duplicateNode);
const deleteNode = useCallback(() => {
deleteNodeById(id);
}, [id, deleteNodeById]);
const duplicateNode = useCallback(() => {
duplicateNodeById(id);
}, [id, duplicateNodeById]);
const items: MenuProps['items'] = [
{
key: '2',
onClick: duplicateNode,
label: (
<Flex justify={'space-between'}>
{t('common.copy')}
<CopyOutlined />
</Flex>
),
},
];
return (
<OperateDropdown
iconFontSize={14}
deleteItem={deleteNode}
items={items}
></OperateDropdown>
);
};
export default NodeDropdown;

View File

@ -1,17 +1,13 @@
import classNames from 'classnames';
import { Handle, NodeProps, Position } from 'reactflow';
import OperateDropdown from '@/components/operate-dropdown';
import { CopyOutlined } from '@ant-design/icons';
import { Flex, MenuProps, Space } from 'antd';
import { Flex, Space } from 'antd';
import get from 'lodash/get';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { CategorizeAnchorPointPositions, Operator } from '../../constant';
import { NodeData } from '../../interface';
import OperatorIcon from '../../operator-icon';
import useGraphStore from '../../store';
import CategorizeHandle from './categorize-handle';
import NodeDropdown from './dropdown';
import styles from './index.less';
export function RagNode({
@ -20,34 +16,9 @@ export function RagNode({
isConnectable = true,
selected,
}: NodeProps<NodeData>) {
const { t } = useTranslation();
const deleteNodeById = useGraphStore((store) => store.deleteNodeById);
const duplicateNodeById = useGraphStore((store) => store.duplicateNode);
const deleteNode = useCallback(() => {
deleteNodeById(id);
}, [id, deleteNodeById]);
const duplicateNode = useCallback(() => {
duplicateNodeById(id);
}, [id, duplicateNodeById]);
const isCategorize = data.label === Operator.Categorize;
const categoryData = get(data, 'form.category_description') ?? {};
const items: MenuProps['items'] = [
{
key: '2',
onClick: duplicateNode,
label: (
<Flex justify={'space-between'}>
{t('common.copy')}
<CopyOutlined />
</Flex>
),
},
];
return (
<section
className={classNames(styles.ragNode, {
@ -86,11 +57,7 @@ export function RagNode({
name={data.label as Operator}
fontSize={16}
></OperatorIcon>
<OperateDropdown
iconFontSize={14}
deleteItem={deleteNode}
items={items}
></OperateDropdown>
<NodeDropdown id={id}></NodeDropdown>
</Space>
</Flex>

View File

@ -97,3 +97,27 @@ export const CategorizeAnchorPointPositions = [
{ top: 91, right: 20 },
{ top: 98, right: 34 },
];
// key is the source of the edge, value is the target of the edge
// no connection lines are allowed between key and value
export const RestrictedUpstreamMap = {
[Operator.Begin]: [
Operator.Begin,
Operator.Answer,
Operator.Categorize,
Operator.Generate,
Operator.Retrieval,
],
[Operator.Categorize]: [Operator.Begin, Operator.Categorize, Operator.Answer],
[Operator.Answer]: [],
[Operator.Retrieval]: [],
[Operator.Generate]: [],
};
export const NodeMap = {
[Operator.Begin]: 'beginNode',
[Operator.Categorize]: 'categorizeNode',
[Operator.Retrieval]: 'ragNode',
[Operator.Generate]: 'ragNode',
[Operator.Answer]: 'ragNode',
};

View File

@ -25,7 +25,7 @@ import { useDebounceEffect } from 'ahooks';
import { FormInstance } from 'antd';
import { humanId } from 'human-id';
import { useParams } from 'umi';
import { Operator } from './constant';
import { NodeMap, Operator } from './constant';
import useGraphStore, { RFState } from './store';
import { buildDslComponentsByGraph } from './utils';
@ -87,7 +87,7 @@ export const useHandleDrop = () => {
});
const newNode = {
id: `${type}:${humanId()}`,
type: 'ragNode',
type: NodeMap[type as Operator] || 'ragNode',
position: position || {
x: 0,
y: 0,

View File

@ -38,7 +38,7 @@ export const dsl = {
nodes: [
{
id: 'begin',
type: 'ragNode',
type: 'beginNode',
position: {
x: 50,
y: 200,

View File

@ -3,9 +3,13 @@ import { removeUselessFieldsFromValues } from '@/utils/form';
import dagre from 'dagre';
import { curry, isEmpty } from 'lodash';
import pipe from 'lodash/fp/pipe';
import { Edge, MarkerType, Node, Position } from 'reactflow';
import { Connection, Edge, MarkerType, Node, Position } from 'reactflow';
import { v4 as uuidv4 } from 'uuid';
import { Operator, initialFormValuesMap } from './constant';
import {
Operator,
RestrictedUpstreamMap,
initialFormValuesMap,
} from './constant';
import { NodeData } from './interface';
const buildEdges = (
@ -162,3 +166,14 @@ export const buildDslComponentsByGraph = (
return components;
};
export const getOperatorType = (id: string | null) => {
return id?.split(':')[0] as Operator | undefined;
};
// restricted lines cannot be connected successfully.
export const isValidConnection = (connection: Connection) => {
return RestrictedUpstreamMap[
getOperatorType(connection.source) as Operator
]?.every((x) => x !== getOperatorType(connection.target));
};