From d1272558811949b495ca834eba728accfc3bb5f2 Mon Sep 17 00:00:00 2001 From: Pavel Feldman Date: Fri, 15 Nov 2024 13:48:43 -0800 Subject: [PATCH] chore: add AriaSnapshot internal type (#33631) --- .../src/server/injected/ariaSnapshot.ts | 48 +++-- .../src/server/injected/injectedScript.ts | 24 ++- .../playwright-core/src/server/recorder.ts | 4 + .../src/server/recorder/chat.ts | 184 ++++++++++++++++++ .../src/server/recorder/contextRecorder.ts | 4 + .../playwright-core/src/utils/debugLogger.ts | 3 +- 6 files changed, 250 insertions(+), 17 deletions(-) create mode 100644 packages/playwright-core/src/server/recorder/chat.ts diff --git a/packages/playwright-core/src/server/injected/ariaSnapshot.ts b/packages/playwright-core/src/server/injected/ariaSnapshot.ts index b2352a1b9c..d541646f0e 100644 --- a/packages/playwright-core/src/server/injected/ariaSnapshot.ts +++ b/packages/playwright-core/src/server/injected/ariaSnapshot.ts @@ -20,15 +20,36 @@ import { escapeRegExp, longestCommonSubstring } from '@isomorphic/stringUtils'; import { yamlEscapeKeyIfNeeded, yamlEscapeValueIfNeeded } from './yaml'; import type { AriaProps, AriaRole, AriaTemplateNode, AriaTemplateRoleNode, AriaTemplateTextNode } from '@isomorphic/ariaSnapshot'; -type AriaNode = AriaProps & { +export type AriaNode = AriaProps & { role: AriaRole | 'fragment'; name: string; children: (AriaNode | string)[]; element: Element; }; -export function generateAriaTree(rootElement: Element): AriaNode { +export type AriaSnapshot = { + root: AriaNode; + elements: Map; + ids: Map; +}; + +export function generateAriaTree(rootElement: Element): AriaSnapshot { const visited = new Set(); + + const snapshot: AriaSnapshot = { + root: { role: 'fragment', name: '', children: [], element: rootElement }, + elements: new Map(), + ids: new Map(), + }; + + const addElement = (element: Element) => { + const id = snapshot.elements.size + 1; + snapshot.elements.set(id, element); + snapshot.ids.set(element, id); + }; + + addElement(rootElement); + const visit = (ariaNode: AriaNode, node: Node) => { if (visited.has(node)) return; @@ -58,6 +79,7 @@ export function generateAriaTree(rootElement: Element): AriaNode { } } + addElement(element); const childAriaNode = toAriaNode(element); if (childAriaNode) ariaNode.children.push(childAriaNode); @@ -100,15 +122,14 @@ export function generateAriaTree(rootElement: Element): AriaNode { } roleUtils.beginAriaCaches(); - const ariaRoot: AriaNode = { role: 'fragment', name: '', children: [], element: rootElement }; try { - visit(ariaRoot, rootElement); + visit(snapshot.root, rootElement); } finally { roleUtils.endAriaCaches(); } - normalizeStringChildren(ariaRoot); - return ariaRoot; + normalizeStringChildren(snapshot.root); + return snapshot; } function toAriaNode(element: Element): AriaNode | null { @@ -143,10 +164,6 @@ function toAriaNode(element: Element): AriaNode | null { return result; } -export function renderedAriaTree(rootElement: Element, options?: { mode?: 'raw' | 'regex' }): string { - return renderAriaTree(generateAriaTree(rootElement), options); -} - function normalizeStringChildren(rootA11yNode: AriaNode) { const flushChildren = (buffer: string[], normalizedChildren: (AriaNode | string)[]) => { if (!buffer.length) @@ -203,7 +220,7 @@ export type MatcherReceived = { }; export function matchesAriaTree(rootElement: Element, template: AriaTemplateNode): { matches: AriaNode[], received: MatcherReceived } { - const root = generateAriaTree(rootElement); + const root = generateAriaTree(rootElement).root; const matches = matchesNodeDeep(root, template, false); return { matches, @@ -215,7 +232,7 @@ export function matchesAriaTree(rootElement: Element, template: AriaTemplateNode } export function getAllByAria(rootElement: Element, template: AriaTemplateNode): Element[] { - const root = generateAriaTree(rootElement); + const root = generateAriaTree(rootElement).root; const matches = matchesNodeDeep(root, template, true); return matches.map(n => n.element); } @@ -285,7 +302,7 @@ function matchesNodeDeep(root: AriaNode, template: AriaTemplateNode, collectAll: return results; } -export function renderAriaTree(ariaNode: AriaNode, options?: { mode?: 'raw' | 'regex' }): string { +export function renderAriaTree(ariaNode: AriaNode, options?: { mode?: 'raw' | 'regex', ids?: Map }): string { const lines: string[] = []; const includeText = options?.mode === 'regex' ? textContributesInfo : () => true; const renderString = options?.mode === 'regex' ? convertToBestGuessRegex : (str: string) => str; @@ -324,6 +341,11 @@ export function renderAriaTree(ariaNode: AriaNode, options?: { mode?: 'raw' | 'r key += ` [pressed]`; if (ariaNode.selected === true) key += ` [selected]`; + if (options?.ids) { + const id = options?.ids.get(ariaNode.element); + if (id) + key += ` [id=${id}]`; + } const escapedKey = indent + '- ' + yamlEscapeKeyIfNeeded(key); if (!ariaNode.children.length) { diff --git a/packages/playwright-core/src/server/injected/injectedScript.ts b/packages/playwright-core/src/server/injected/injectedScript.ts index d74a1c1482..7c235b700f 100644 --- a/packages/playwright-core/src/server/injected/injectedScript.ts +++ b/packages/playwright-core/src/server/injected/injectedScript.ts @@ -34,7 +34,8 @@ import { kLayoutSelectorNames, type LayoutSelectorName, layoutSelectorScore } fr import { asLocator } from '../../utils/isomorphic/locatorGenerators'; import type { Language } from '../../utils/isomorphic/locatorGenerators'; import { cacheNormalizedWhitespaces, normalizeWhiteSpace, trimStringWithEllipsis } from '../../utils/isomorphic/stringUtils'; -import { matchesAriaTree, renderedAriaTree, getAllByAria } from './ariaSnapshot'; +import { matchesAriaTree, getAllByAria, generateAriaTree, renderAriaTree } from './ariaSnapshot'; +import type { AriaNode, AriaSnapshot } from './ariaSnapshot'; import type { AriaTemplateNode } from '@isomorphic/ariaSnapshot'; import { parseYamlTemplate } from '@isomorphic/ariaSnapshot'; @@ -215,10 +216,27 @@ export class InjectedScript { return new Set(result.map(r => r.element)); } - ariaSnapshot(node: Node, options?: { mode?: 'raw' | 'regex' }): string { + ariaSnapshot(node: Node, options?: { mode?: 'raw' | 'regex', id?: boolean }): string { if (node.nodeType !== Node.ELEMENT_NODE) throw this.createStacklessError('Can only capture aria snapshot of Element nodes.'); - return renderedAriaTree(node as Element, options); + const ariaSnapshot = generateAriaTree(node as Element); + return renderAriaTree(ariaSnapshot.root, options); + } + + ariaSnapshotAsObject(node: Node): AriaSnapshot { + return generateAriaTree(node as Element); + } + + ariaSnapshotElement(snapshot: AriaSnapshot, elementId: number): Element | null { + return snapshot.elements.get(elementId) || null; + } + + renderAriaTree(ariaNode: AriaNode, options?: { mode?: 'raw' | 'regex', id?: boolean}): string { + return renderAriaTree(ariaNode, options); + } + + renderAriaSnapshotWithIds(ariaSnapshot: AriaSnapshot): string { + return renderAriaTree(ariaSnapshot.root, { ids: ariaSnapshot.ids }); } getAllByAria(document: Document, template: AriaTemplateNode): Element[] { diff --git a/packages/playwright-core/src/server/recorder.ts b/packages/playwright-core/src/server/recorder.ts index 13dd3829b1..16f9d791e1 100644 --- a/packages/playwright-core/src/server/recorder.ts +++ b/packages/playwright-core/src/server/recorder.ts @@ -132,6 +132,10 @@ export class Recorder implements InstrumentationListener, IRecorder { this._contextRecorder.clearScript(); return; } + if (data.event === 'runTask') { + this._contextRecorder.runTask(data.params.task); + return; + } }); await Promise.all([ diff --git a/packages/playwright-core/src/server/recorder/chat.ts b/packages/playwright-core/src/server/recorder/chat.ts new file mode 100644 index 0000000000..5b3917c735 --- /dev/null +++ b/packages/playwright-core/src/server/recorder/chat.ts @@ -0,0 +1,184 @@ +/** + * Copyright (c) Microsoft Corporation. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { WebSocketTransport } from '../transport'; +import type { ConnectionTransport, ProtocolResponse } from '../transport'; + +export type ChatMessage = { + content: string; + user: 'user' | 'assistant'; +}; + +export class Chat { + private _history: ChatMessage[] = []; + private _connectionPromise: Promise | undefined; + private _chatSinks = new Map void>(); + private _wsEndpoint: string; + + constructor(wsEndpoint: string) { + this._wsEndpoint = wsEndpoint; + } + + clearHistory() { + this._history = []; + } + + async post(prompt: string): Promise { + await this._append('user', prompt); + let text = await asString(await this._post()); + if (text.startsWith('```json') && text.endsWith('```')) + text = text.substring('```json'.length, text.length - '```'.length); + for (let i = 0; i < 3; ++i) { + try { + return JSON.parse(text); + } catch (e) { + await this._append('user', String(e)); + } + } + throw new Error('Failed to parse response: ' + text); + } + + private async _append(user: ChatMessage['user'], content: string) { + this._history.push({ user, content }); + } + + private async _connection(): Promise { + if (!this._connectionPromise) { + this._connectionPromise = WebSocketTransport.connect(undefined, this._wsEndpoint).then(transport => { + return new Connection(transport, (method, params) => this._dispatchEvent(method, params), () => {}); + }); + } + return this._connectionPromise; + } + + private _dispatchEvent(method: string, params: any) { + if (method === 'chatChunk') { + const { chatId, chunk } = params; + const chunkSink = this._chatSinks.get(chatId)!; + chunkSink(chunk); + if (!chunk) + this._chatSinks.delete(chatId); + } + } + + private async _post(): Promise> { + const connection = await this._connection(); + const result = await connection.send('chat', { history: this._history }); + const { chatId } = result; + const { iterable, addChunk } = iterablePump(); + this._chatSinks.set(chatId, addChunk); + return iterable; + } +} + +export async function asString(stream: AsyncIterable): Promise { + let result = ''; + for await (const chunk of stream) + result += chunk; + return result; +} + +type ChunkIterator = { + iterable: AsyncIterable; + addChunk: (chunk: string) => void; +}; + +function iterablePump(): ChunkIterator { + let controller: ReadableStreamDefaultController; + const stream = new ReadableStream({ start: c => controller = c }); + + const iterable = (async function* () { + const reader = stream.getReader(); + while (true) { + const { done, value } = await reader.read(); + if (done) + break; + yield value!; + } + })(); + + return { + iterable, + addChunk: chunk => { + if (chunk) + controller.enqueue(chunk); + else + controller.close(); + } + }; +} + +class Connection { + private readonly _transport: ConnectionTransport; + private _lastId = 0; + private _closed = false; + private _pending = new Map void; reject: (error: any) => void; }>(); + private _onEvent: (method: string, params: any) => void; + private _onClose: () => void; + + constructor(transport: ConnectionTransport, onEvent: (method: string, params: any) => void, onClose: () => void) { + this._transport = transport; + this._onEvent = onEvent; + this._onClose = onClose; + this._transport.onmessage = this._dispatchMessage.bind(this); + this._transport.onclose = this._close.bind(this); + } + + send(method: string, params: any): Promise { + const id = this._lastId++; + const message = { id, method, params }; + this._transport.send(message); + return new Promise((resolve, reject) => { + this._pending.set(id, { resolve, reject }); + }); + } + + private _dispatchMessage(message: ProtocolResponse) { + if (message.id === undefined) { + this._onEvent(message.method!, message.params); + return; + } + + const callback = this._pending.get(message.id); + this._pending.delete(message.id); + if (!callback) + return; + + if (message.error) { + callback.reject(new Error(message.error.message)); + return; + } + callback.resolve(message.result); + } + + _close() { + this._closed = true; + this._transport.onmessage = undefined; + this._transport.onclose = undefined; + for (const { reject } of this._pending.values()) + reject(new Error('Connection closed')); + this._onClose(); + } + + isClosed() { + return this._closed; + } + + close() { + if (!this._closed) + this._transport.close(); + } +} diff --git a/packages/playwright-core/src/server/recorder/contextRecorder.ts b/packages/playwright-core/src/server/recorder/contextRecorder.ts index 933a036233..d7a3c908e8 100644 --- a/packages/playwright-core/src/server/recorder/contextRecorder.ts +++ b/packages/playwright-core/src/server/recorder/contextRecorder.ts @@ -208,6 +208,10 @@ export class ContextRecorder extends EventEmitter { } } + runTask(task: string): void { + // TODO: implement + } + private _describeMainFrame(page: Page): actions.FrameDescription { return { pageAlias: this._pageAliases.get(page)!, diff --git a/packages/playwright-core/src/utils/debugLogger.ts b/packages/playwright-core/src/utils/debugLogger.ts index a5196da896..d50180a2ed 100644 --- a/packages/playwright-core/src/utils/debugLogger.ts +++ b/packages/playwright-core/src/utils/debugLogger.ts @@ -29,7 +29,8 @@ const debugLoggerColorMap = { 'channel': 33, // blue 'server': 45, // cyan 'server:channel': 34, // green - 'server:metadata': 33, // blue + 'server:metadata': 33, // blue, + 'recorder': 45, // cyan }; export type LogName = keyof typeof debugLoggerColorMap;