{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# SQL Agent for Spider text-to-SQL benchmark" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This notebook demonstrates a basic SQL agent that translates natural language questions into SQL queries." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Environment\n", "\n", "For this demo, we use a SQLite database environment based on a standard text-to-SQL benchmark called [Spider](https://yale-lily.github.io/spider). The environment provides a gym-like interface and can be used as follows." ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Loading cached Spider dataset from /home/wangdazhang/.cache/spider\n", "Schema file not found for /home/wangdazhang/.cache/spider/spider/database/flight_4\n", "Schema file not found for /home/wangdazhang/.cache/spider/spider/database/small_bank_1\n", "Schema file not found for /home/wangdazhang/.cache/spider/spider/database/icfp_1\n", "Schema file not found for /home/wangdazhang/.cache/spider/spider/database/twitter_1\n", "Schema file not found for /home/wangdazhang/.cache/spider/spider/database/epinions_1\n", "Schema file not found for /home/wangdazhang/.cache/spider/spider/database/chinook_1\n", "Schema file not found for /home/wangdazhang/.cache/spider/spider/database/company_1\n" ] } ], "source": [ "# %pip install spider-env\n", "import json\n", "import os\n", "from typing import Annotated, Dict\n", "\n", "from spider_env import SpiderEnv\n", "\n", "from autogen import ConversableAgent, UserProxyAgent, config_list_from_json\n", "\n", "gym = SpiderEnv()\n", "\n", "# Randomly select a question from Spider\n", "observation, info = gym.reset()" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Find the famous titles of artists that do not have any volume.\n" ] } ], 
"source": [ "# The natural language question\n", "question = observation[\"instruction\"]\n", "print(question)" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CREATE TABLE \"artist\" (\n", "\"Artist_ID\" int,\n", "\"Artist\" text,\n", "\"Age\" int,\n", "\"Famous_Title\" text,\n", "\"Famous_Release_date\" text,\n", "PRIMARY KEY (\"Artist_ID\")\n", ");\n", "CREATE TABLE \"volume\" (\n", "\"Volume_ID\" int,\n", "\"Volume_Issue\" text,\n", "\"Issue_Date\" text,\n", "\"Weeks_on_Top\" real,\n", "\"Song\" text,\n", "\"Artist_ID\" int,\n", "PRIMARY KEY (\"Volume_ID\"),\n", "FOREIGN KEY (\"Artist_ID\") REFERENCES \"artist\"(\"Artist_ID\")\n", ");\n", "CREATE TABLE \"music_festival\" (\n", "\"ID\" int,\n", "\"Music_Festival\" text,\n", "\"Date_of_ceremony\" text,\n", "\"Category\" text,\n", "\"Volume\" int,\n", "\"Result\" text,\n", "PRIMARY KEY (\"ID\"),\n", "FOREIGN KEY (\"Volume\") REFERENCES \"volume\"(\"Volume_ID\")\n", ");\n", "\n" ] } ], "source": [ "# The schema of the corresponding database\n", "schema = info[\"schema\"]\n", "print(schema)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Agent Implementation\n", "\n", "Using AutoGen, a SQL agent can be implemented with a `ConversableAgent`. The gym environment executes each generated SQL query, and the agent uses the execution result as feedback to refine the query over multiple conversation rounds." 
] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "os.environ[\"AUTOGEN_USE_DOCKER\"] = \"False\"\n", "config_list = config_list_from_json(env_or_file=\"OAI_CONFIG_LIST\")\n", "\n", "\n", "def check_termination(msg: Dict) -> bool:\n", "    # Only tool responses can end the chat; any other message keeps it going.\n", "    if \"tool_responses\" not in msg:\n", "        return False\n", "    json_str = msg[\"tool_responses\"][0][\"content\"]\n", "    obj = json.loads(json_str)\n", "    # execute_sql() sets \"error\" only when the query fails or returns a wrong result.\n", "    return obj.get(\"error\") is None\n", "\n", "\n", "sql_writer = ConversableAgent(\n", "    \"sql_writer\",\n", "    llm_config={\"config_list\": config_list},\n", "    system_message=\"You are good at writing SQL queries. Always respond with a function call to execute_sql().\",\n", "    is_termination_msg=check_termination,\n", ")\n", "user_proxy = UserProxyAgent(\"user_proxy\", human_input_mode=\"NEVER\", max_consecutive_auto_reply=5)\n", "\n", "\n", "@sql_writer.register_for_llm(description=\"Function for executing SQL query and returning a response\")\n", "@user_proxy.register_for_execution()\n", "def execute_sql(\n", "    reflection: Annotated[str, \"Think about what to do\"], sql: Annotated[str, \"SQL query\"]\n", ") -> Annotated[Dict[str, str], \"Dictionary with key 'result' on success, or 'error', 'wrong_result' and 'correct_result' on failure\"]:\n", "    # Run the query in the Spider environment; reward is 0 when the result differs from the gold result.\n", "    observation, reward, _, _, info = gym.step(sql)\n", "    error = observation[\"feedback\"][\"error\"]\n", "    if not error and reward == 0:\n", "        error = \"The SQL query returned an incorrect result\"\n", "    if error:\n", "        return {\n", "            \"error\": error,\n", "            \"wrong_result\": observation[\"feedback\"][\"result\"],\n", "            \"correct_result\": info[\"gold_result\"],\n", "        }\n", "    else:\n", "        return {\n", "            \"result\": observation[\"feedback\"][\"result\"],\n", "        }" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The agent then takes the schema and the natural language question as input and generates the SQL query." 
] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[33muser_proxy\u001b[0m (to sql_writer):\n", "\n", "Below is the schema for a SQL database:\n", "CREATE TABLE \"artist\" (\n", "\"Artist_ID\" int,\n", "\"Artist\" text,\n", "\"Age\" int,\n", "\"Famous_Title\" text,\n", "\"Famous_Release_date\" text,\n", "PRIMARY KEY (\"Artist_ID\")\n", ");\n", "CREATE TABLE \"volume\" (\n", "\"Volume_ID\" int,\n", "\"Volume_Issue\" text,\n", "\"Issue_Date\" text,\n", "\"Weeks_on_Top\" real,\n", "\"Song\" text,\n", "\"Artist_ID\" int,\n", "PRIMARY KEY (\"Volume_ID\"),\n", "FOREIGN KEY (\"Artist_ID\") REFERENCES \"artist\"(\"Artist_ID\")\n", ");\n", "CREATE TABLE \"music_festival\" (\n", "\"ID\" int,\n", "\"Music_Festival\" text,\n", "\"Date_of_ceremony\" text,\n", "\"Category\" text,\n", "\"Volume\" int,\n", "\"Result\" text,\n", "PRIMARY KEY (\"ID\"),\n", "FOREIGN KEY (\"Volume\") REFERENCES \"volume\"(\"Volume_ID\")\n", ");\n", "\n", "Generate a SQL query to answer the following question:\n", "Find the famous titles of artists that do not have any volume.\n", "\n", "\n", "--------------------------------------------------------------------------------\n", "\u001b[31m\n", ">>>>>>>> USING AUTO REPLY...\u001b[0m\n", "\u001b[33msql_writer\u001b[0m (to user_proxy):\n", "\n", "\u001b[32m***** Suggested tool Call (call_eAu0OEzS8l3QvN3jQSn4w0hJ): execute_sql *****\u001b[0m\n", "Arguments: \n", "{\"reflection\":\"Generating SQL to find famous titles of artists without any volume\",\"sql\":\"SELECT a.Artist, a.Famous_Title FROM artist a WHERE NOT EXISTS (SELECT 1 FROM volume v WHERE v.Artist_ID = a.Artist_ID)\"}\n", "\u001b[32m****************************************************************************\u001b[0m\n", "\n", "--------------------------------------------------------------------------------\n", "\u001b[35m\n", ">>>>>>>> EXECUTING FUNCTION execute_sql...\u001b[0m\n", 
"\u001b[33muser_proxy\u001b[0m (to sql_writer):\n", "\n", "\u001b[33muser_proxy\u001b[0m (to sql_writer):\n", "\n", "\u001b[32m***** Response from calling tool \"call_eAu0OEzS8l3QvN3jQSn4w0hJ\" *****\u001b[0m\n", "{\"error\": \"The SQL query returned an incorrect result\", \"wrong_result\": [[\"Ophiolatry\", \"Antievangelistical Process (re-release)\"], [\"Triumfall\", \"Antithesis of All Flesh\"]], \"correct_result\": [[\"Antievangelistical Process (re-release)\"], [\"Antithesis of All Flesh\"]]}\n", "\u001b[32m**********************************************************************\u001b[0m\n", "\n", "--------------------------------------------------------------------------------\n", "\u001b[31m\n", ">>>>>>>> USING AUTO REPLY...\u001b[0m\n", "\u001b[33msql_writer\u001b[0m (to user_proxy):\n", "\n", "\u001b[32m***** Suggested tool Call (call_5LXoKqdZ17kPCOHJbbpSz2yk): execute_sql *****\u001b[0m\n", "Arguments: \n", "{\"reflection\":\"Adjusting SQL to only select famous titles and exclude artist names for artists without any volume.\",\"sql\":\"SELECT a.Famous_Title FROM artist a WHERE NOT EXISTS (SELECT 1 FROM volume v WHERE v.Artist_ID = a.Artist_ID)\"}\n", "\u001b[32m****************************************************************************\u001b[0m\n", "\n", "--------------------------------------------------------------------------------\n", "\u001b[35m\n", ">>>>>>>> EXECUTING FUNCTION execute_sql...\u001b[0m\n", "\u001b[33muser_proxy\u001b[0m (to sql_writer):\n", "\n", "\u001b[33muser_proxy\u001b[0m (to sql_writer):\n", "\n", "\u001b[32m***** Response from calling tool \"call_5LXoKqdZ17kPCOHJbbpSz2yk\" *****\u001b[0m\n", "{\"result\": [[\"Antievangelistical Process (re-release)\"], [\"Antithesis of All Flesh\"]]}\n", "\u001b[32m**********************************************************************\u001b[0m\n", "\n", "--------------------------------------------------------------------------------\n", "\u001b[31m\n", ">>>>>>>> NO HUMAN INPUT 
RECEIVED.\u001b[0m\n" ] } ], "source": [ "message = f\"\"\"Below is the schema for a SQL database:\n", "{schema}\n", "Generate a SQL query to answer the following question:\n", "{question}\n", "\"\"\"\n", "\n", "user_proxy.initiate_chat(sql_writer, message=message)" ] } ], "metadata": { "kernelspec": { "display_name": ".venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.18" } }, "nbformat": 4, "nbformat_minor": 2 }