diff --git a/web/src/pages/flow/canvas/index.tsx b/web/src/pages/flow/canvas/index.tsx index b3010995c..b1f4864e9 100644 --- a/web/src/pages/flow/canvas/index.tsx +++ b/web/src/pages/flow/canvas/index.tsx @@ -20,9 +20,16 @@ import { import { RagNode } from './node'; import ChatDrawer from '../chat/drawer'; +import { isValidConnection } from '../utils'; import styles from './index.less'; +import { BeginNode } from './node/begin-node'; +import { CategorizeNode } from './node/categorize-node'; -const nodeTypes = { ragNode: RagNode }; +const nodeTypes = { + ragNode: RagNode, + categorizeNode: CategorizeNode, + beginNode: BeginNode, +}; const edgeTypes = { buttonEdge: ButtonEdge, @@ -76,6 +83,7 @@ function FlowCanvas({ chatDrawerVisible, hideChatDrawer }: IProps) { onKeyUp={handleKeyUp} onSelectionChange={onSelectionChange} nodeOrigin={[0.5, 0]} + isValidConnection={isValidConnection} onChange={(...params) => { console.info('params:', ...params); }} diff --git a/web/src/pages/flow/canvas/node/begin-node.tsx b/web/src/pages/flow/canvas/node/begin-node.tsx new file mode 100644 index 000000000..badcfcd56 --- /dev/null +++ b/web/src/pages/flow/canvas/node/begin-node.tsx @@ -0,0 +1,39 @@ +import { Flex, Space } from 'antd'; +import classNames from 'classnames'; +import { Handle, NodeProps, Position } from 'reactflow'; +import { Operator } from '../../constant'; +import { NodeData } from '../../interface'; +import OperatorIcon from '../../operator-icon'; +import NodeDropdown from './dropdown'; + +import styles from './index.less'; + +// TODO: do not allow other nodes to connect to this node +export function BeginNode({ id, data, selected }: NodeProps) { + return ( +
+ + + + + + + +
+
{id}
+
+
+ ); +} diff --git a/web/src/pages/flow/canvas/node/categorize-node.tsx b/web/src/pages/flow/canvas/node/categorize-node.tsx new file mode 100644 index 000000000..397adfe75 --- /dev/null +++ b/web/src/pages/flow/canvas/node/categorize-node.tsx @@ -0,0 +1,51 @@ +import { Flex, Space } from 'antd'; +import classNames from 'classnames'; +import get from 'lodash/get'; +import { Handle, NodeProps, Position } from 'reactflow'; +import { CategorizeAnchorPointPositions, Operator } from '../../constant'; +import { NodeData } from '../../interface'; +import OperatorIcon from '../../operator-icon'; +import CategorizeHandle from './categorize-handle'; +import NodeDropdown from './dropdown'; + +import styles from './index.less'; + +export function CategorizeNode({ id, data, selected }: NodeProps) { + const categoryData = get(data, 'form.category_description') ?? {}; + + return ( +
+ + {Object.keys(categoryData).map((x, idx) => ( + + ))} + + + + + + +
+
{id}
+
+
+ ); +} diff --git a/web/src/pages/flow/canvas/node/dropdown.tsx b/web/src/pages/flow/canvas/node/dropdown.tsx new file mode 100644 index 000000000..632a60583 --- /dev/null +++ b/web/src/pages/flow/canvas/node/dropdown.tsx @@ -0,0 +1,47 @@ +import OperateDropdown from '@/components/operate-dropdown'; +import { CopyOutlined } from '@ant-design/icons'; +import { Flex, MenuProps } from 'antd'; +import { useCallback } from 'react'; +import { useTranslation } from 'react-i18next'; +import useGraphStore from '../../store'; + +interface IProps { + id: string; +} + +const NodeDropdown = ({ id }: IProps) => { + const { t } = useTranslation(); + const deleteNodeById = useGraphStore((store) => store.deleteNodeById); + const duplicateNodeById = useGraphStore((store) => store.duplicateNode); + + const deleteNode = useCallback(() => { + deleteNodeById(id); + }, [id, deleteNodeById]); + + const duplicateNode = useCallback(() => { + duplicateNodeById(id); + }, [id, duplicateNodeById]); + + const items: MenuProps['items'] = [ + { + key: '2', + onClick: duplicateNode, + label: ( + + {t('common.copy')} + + + ), + }, + ]; + + return ( + + ); +}; + +export default NodeDropdown; diff --git a/web/src/pages/flow/canvas/node/index.tsx b/web/src/pages/flow/canvas/node/index.tsx index 5d0fc1bda..67336c37c 100644 --- a/web/src/pages/flow/canvas/node/index.tsx +++ b/web/src/pages/flow/canvas/node/index.tsx @@ -1,17 +1,13 @@ import classNames from 'classnames'; import { Handle, NodeProps, Position } from 'reactflow'; -import OperateDropdown from '@/components/operate-dropdown'; -import { CopyOutlined } from '@ant-design/icons'; -import { Flex, MenuProps, Space } from 'antd'; +import { Flex, Space } from 'antd'; import get from 'lodash/get'; -import { useCallback } from 'react'; -import { useTranslation } from 'react-i18next'; import { CategorizeAnchorPointPositions, Operator } from '../../constant'; import { NodeData } from '../../interface'; import OperatorIcon from '../../operator-icon'; -import useGraphStore from '../../store'; import CategorizeHandle from './categorize-handle'; +import NodeDropdown from './dropdown'; import styles from './index.less'; export function RagNode({ @@ -20,34 +16,9 @@ export function RagNode({ isConnectable = true, selected, }: NodeProps) { - const { t } = useTranslation(); - const deleteNodeById = useGraphStore((store) => store.deleteNodeById); - const duplicateNodeById = useGraphStore((store) => store.duplicateNode); - - const deleteNode = useCallback(() => { - deleteNodeById(id); - }, [id, deleteNodeById]); - - const duplicateNode = useCallback(() => { - duplicateNodeById(id); - }, [id, duplicateNodeById]); - const isCategorize = data.label === Operator.Categorize; const categoryData = get(data, 'form.category_description') ?? {}; - const items: MenuProps['items'] = [ - { - key: '2', - onClick: duplicateNode, - label: ( - - {t('common.copy')} - - - ), - }, - ]; - return (
- + diff --git a/web/src/pages/flow/constant.tsx b/web/src/pages/flow/constant.tsx index fc10b1345..cfe8a0ce0 100644 --- a/web/src/pages/flow/constant.tsx +++ b/web/src/pages/flow/constant.tsx @@ -97,3 +97,27 @@ export const CategorizeAnchorPointPositions = [ { top: 91, right: 20 }, { top: 98, right: 34 }, ]; + +// key is the source of the edge, value is the target of the edge +// no connection lines are allowed between key and value +export const RestrictedUpstreamMap = { + [Operator.Begin]: [ + Operator.Begin, + Operator.Answer, + Operator.Categorize, + Operator.Generate, + Operator.Retrieval, + ], + [Operator.Categorize]: [Operator.Begin, Operator.Categorize, Operator.Answer], + [Operator.Answer]: [], + [Operator.Retrieval]: [], + [Operator.Generate]: [], +}; + +export const NodeMap = { + [Operator.Begin]: 'beginNode', + [Operator.Categorize]: 'categorizeNode', + [Operator.Retrieval]: 'ragNode', + [Operator.Generate]: 'ragNode', + [Operator.Answer]: 'ragNode', +}; diff --git a/web/src/pages/flow/hooks.ts b/web/src/pages/flow/hooks.ts index bddac12ef..ed41e7b1d 100644 --- a/web/src/pages/flow/hooks.ts +++ b/web/src/pages/flow/hooks.ts @@ -25,7 +25,7 @@ import { useDebounceEffect } from 'ahooks'; import { FormInstance } from 'antd'; import { humanId } from 'human-id'; import { useParams } from 'umi'; -import { Operator } from './constant'; +import { NodeMap, Operator } from './constant'; import useGraphStore, { RFState } from './store'; import { buildDslComponentsByGraph } from './utils'; @@ -87,7 +87,7 @@ export const useHandleDrop = () => { }); const newNode = { id: `${type}:${humanId()}`, - type: 'ragNode', + type: NodeMap[type as Operator] || 'ragNode', position: position || { x: 0, y: 0, diff --git a/web/src/pages/flow/mock.tsx b/web/src/pages/flow/mock.tsx index 5b5fda716..af0129ba5 100644 --- a/web/src/pages/flow/mock.tsx +++ b/web/src/pages/flow/mock.tsx @@ -38,7 +38,7 @@ export const dsl = { nodes: [ { id: 'begin', - type: 'ragNode', + type: 'beginNode', position: { x: 50, y: 200, diff --git a/web/src/pages/flow/utils.ts b/web/src/pages/flow/utils.ts index d84360bd5..f6f50c140 100644 --- a/web/src/pages/flow/utils.ts +++ b/web/src/pages/flow/utils.ts @@ -3,9 +3,13 @@ import { removeUselessFieldsFromValues } from '@/utils/form'; import dagre from 'dagre'; import { curry, isEmpty } from 'lodash'; import pipe from 'lodash/fp/pipe'; -import { Edge, MarkerType, Node, Position } from 'reactflow'; +import { Connection, Edge, MarkerType, Node, Position } from 'reactflow'; import { v4 as uuidv4 } from 'uuid'; -import { Operator, initialFormValuesMap } from './constant'; +import { + Operator, + RestrictedUpstreamMap, + initialFormValuesMap, +} from './constant'; import { NodeData } from './interface'; const buildEdges = ( @@ -162,3 +166,14 @@ export const buildDslComponentsByGraph = ( return components; }; + +export const getOperatorType = (id: string | null) => { + return id?.split(':')[0] as Operator | undefined; +}; + +// restricted lines cannot be connected successfully. +export const isValidConnection = (connection: Connection) => { + return RestrictedUpstreamMap[ + getOperatorType(connection.source) as Operator + ]?.every((x) => x !== getOperatorType(connection.target)); +};