Initial work porting WebArena to async (#325)

* Initial work porting WebArena to async

* Possibly resolved some of the eval() issues (awaitable eval() results are now awaited).
This commit is contained in:
afourney 2024-08-06 16:25:18 -07:00 committed by GitHub
parent 027791c00b
commit 8b13d59b59
4 changed files with 34 additions and 24 deletions

View File

@ -7,12 +7,13 @@ import importlib
import json
import time
import urllib
import inspect
from pathlib import Path
from typing import Any, Tuple, Union, TypedDict, Dict
from beartype import beartype
from nltk.tokenize import word_tokenize # type: ignore
from playwright.sync_api import CDPSession, Page
from playwright.async_api import CDPSession, Page
import numpy as np
import numpy.typing as npt
@ -58,7 +59,7 @@ class Evaluator(object):
self.eval_tag = eval_tag
@beartype
def __call__(
async def __call__(
self,
trajectory: Trajectory,
config_file: Path | str,
@ -136,7 +137,7 @@ class StringEvaluator(Evaluator):
def ua_match(ref: str, pred: str, intent: str, azure_config: dict[str, Any] | None) -> float:
return llm_ua_match(pred, ref, intent, azure_config)
def __call__(
async def __call__(
self,
trajectory: Trajectory,
config_file: Path | str,
@ -192,7 +193,7 @@ class URLEvaluator(Evaluator):
"""Check URL matching"""
@beartype
def __call__(
async def __call__(
self,
trajectory: Trajectory,
config_file: Path | str,
@ -256,7 +257,7 @@ class HTMLContentEvaluator(Evaluator):
"""Check whether the contents appear in the page"""
@beartype
def __call__(
async def __call__(
self,
trajectory: Trajectory,
config_file: Path | str,
@ -276,12 +277,14 @@ class HTMLContentEvaluator(Evaluator):
func = target_url.split("func:")[1]
func = func.replace("__last_url__", page.url)
target_url = eval(func)
if inspect.isawaitable(target_url):
target_url = await target_url
locator: str = target["locator"] # js element locator
# navigate to that url
if target_url != "last":
page.goto(target_url)
await page.goto(target_url)
time.sleep(3) # TODO [shuyanzh]: fix this hard-coded sleep
# empty, use the full page
@ -292,11 +295,11 @@ class HTMLContentEvaluator(Evaluator):
if "prep_actions" in target:
try:
for prep_action in target["prep_actions"]:
page.evaluate(f"() => {prep_action}")
await page.evaluate(f"() => {prep_action}")
except Exception:
pass
try:
selected_element = str(page.evaluate(f"() => {locator}"))
selected_element = str(await page.evaluate(f"() => {locator}"))
if not selected_element:
selected_element = ""
except Exception:
@ -307,6 +310,8 @@ class HTMLContentEvaluator(Evaluator):
func = locator.split("func:")[1]
func = func.replace("__page__", "page")
selected_element = eval(func)
if inspect.isawaitable(selected_element):
selected_element = await selected_element
else:
raise ValueError(f"Unknown locator: {locator}")
@ -344,7 +349,7 @@ class EvaluatorComb:
self.evaluators = evaluators
@beartype
def __call__(
async def __call__(
self,
trajectory: Trajectory,
config_file: Path | str,
@ -354,7 +359,7 @@ class EvaluatorComb:
) -> float:
score = 1.0
for evaluator in self.evaluators:
cur_score = evaluator(trajectory, config_file, page, client, azure_config)
cur_score = await evaluator(trajectory, config_file, page, client, azure_config)
score *= cur_score
return score

View File

@ -5,7 +5,7 @@ from typing import Any
from urllib.parse import urlparse
import requests
from playwright.sync_api import CDPSession, Page
from playwright.async_api import Page
from .env_config import (
ACCOUNTS,
@ -110,10 +110,10 @@ def reddit_get_post_url(url: str) -> str:
return post_url
def gitlab_get_project_memeber_role(page: Page, account_name: str) -> str:
async def gitlab_get_project_memeber_role(page: Page, account_name: str) -> str:
# get the account index
try:
account_idx = page.evaluate(
account_idx = await page.evaluate(
f"""(() => {{
const elements = document.querySelectorAll("td[data-label='Account'] span.gl-avatar-labeled-sublabel");
let index = -1; // Default value if not found
@ -130,7 +130,7 @@ def gitlab_get_project_memeber_role(page: Page, account_name: str) -> str:
)
# get the role
role: str = page.evaluate(
role: str = await page.evaluate(
f"""(() => {{
return document.querySelectorAll("td.col-max-role span")[{account_idx}].outerText;
}})()"""

View File

@ -120,7 +120,6 @@ async def main() -> None:
runtime = SingleThreadedAgentRuntime()
# Create the AzureOpenAI client, with AAD auth
# token_provider = get_bearer_token_provider(DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default")
client = AzureOpenAIChatCompletionClient(
api_version="2024-02-15-preview",
azure_endpoint="https://aif-complex-tasks-west-us-3.openai.azure.com/",
@ -236,18 +235,17 @@ Once the user has taken the final necessary action to complete the task, and you
########## EVALUATION ##########
context = actual_surfer._context
page = actual_surfer._page
cdp_session = context.new_cdp_session(page)
cdp_session = await context.new_cdp_session(page)
config_file = "full_task.json"
evaluator = evaluation_harness.evaluator_router(config_file)
score = "N/A"
#score = evaluator(
# trajectory=evaluation_harness.make_answer_trajecotry(final_answer),
# config_file=config_file,
# page=page,
# client=cdp_session,
score = await evaluator(
trajectory=evaluation_harness.make_answer_trajecotry(final_answer),
config_file=config_file,
page=page,
client=cdp_session,
# azure_config=llm_config,
#)
)
print("FINAL SCORE: " + str(score))

View File

@ -5,7 +5,7 @@ from agnext.application.logging import EVENT_LOGGER_NAME
from agnext.components.models import AssistantMessage, LLMMessage, UserMessage
from agnext.core import AgentProxy, CancellationToken
from ..messages import BroadcastMessage, OrchestrationEvent, RequestReplyMessage
from ..messages import BroadcastMessage, OrchestrationEvent, RequestReplyMessage, ResetMessage
from ..utils import message_content_to_str
from .base_agent import TeamOneBaseAgent
@ -82,3 +82,10 @@ class BaseOrchestrator(TeamOneBaseAgent):
def get_max_rounds(self) -> int:
    """Return the value stored in ``self._max_rounds``."""
    return self._max_rounds
async def _handle_reset(self, message: ResetMessage, cancellation_token: CancellationToken) -> None:
    """Handle a reset message by delegating to ``_reset``.

    The message itself is not inspected here; only the cancellation token
    is forwarded.
    """
    await self._reset(cancellation_token)
async def _reset(self, cancellation_token: CancellationToken) -> None:
    """Do nothing; the cancellation token is accepted but unused here.

    NOTE(review): presumably a hook that subclasses override to clear their
    own state on reset — confirm against the subclasses.
    """
    pass