mirror of
https://github.com/run-llama/llama-hub.git
synced 2025-08-15 04:01:32 +00:00
96 lines
3.1 KiB
Python
96 lines
3.1 KiB
Python
![]() |
"""Database Reader."""
|
||
|
|
||
|
from typing import Any, List, Optional, Union
from urllib.parse import quote
|
||
|
|
||
|
from sqlalchemy import text
|
||
|
from sqlalchemy.engine import Engine
|
||
|
|
||
|
from gpt_index.langchain_helpers.sql_wrapper import SQLDatabase
|
||
|
from gpt_index.readers.base import BaseReader
|
||
|
from gpt_index.readers.schema.base import Document
|
||
|
|
||
|
|
||
|
class DatabaseReader(BaseReader):
    """Simple Database reader.

    Runs a SQL query and turns every returned row into one Document used by
    GPT Index. The Document text is the row's column values, each converted
    with ``str`` and joined with ``", "``.

    Exactly one way of locating the database is used, checked in this order:

    Args:
        sql_database (Optional[SQLDatabase]): SQL database to use,
            including table names to specify.
            See :ref:`Ref-Struct-Store` for more details.

        OR

        engine (Optional[Engine]): SQLAlchemy Engine object of the database
            connection.

        OR

        uri (Optional[str]): uri of the database connection.

        OR

        scheme (Optional[str]): scheme of the database connection.
        host (Optional[str]): host of the database connection.
        port (Optional[Union[int, str]]): port of the database connection.
        user (Optional[str]): user of the database connection.
        password (Optional[str]): password of the database connection.
        dbname (Optional[str]): dbname of the database connection.

        Any extra ``*args`` / ``**kwargs`` are forwarded to ``SQLDatabase``
        (or ``SQLDatabase.from_uri``) when this reader constructs it.

    Returns:
        DatabaseReader: A DatabaseReader object.

    Raises:
        ValueError: If none of the connection options above is supplied.
    """

    def __init__(
        self,
        sql_database: Optional[SQLDatabase] = None,
        engine: Optional[Engine] = None,
        uri: Optional[str] = None,
        scheme: Optional[str] = None,
        host: Optional[str] = None,
        port: Optional[Union[int, str]] = None,
        user: Optional[str] = None,
        password: Optional[str] = None,
        dbname: Optional[str] = None,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """Initialize with parameters."""
        if sql_database:
            self.sql_database = sql_database
        elif engine:
            self.sql_database = SQLDatabase(engine, *args, **kwargs)
        elif uri:
            self.uri = uri
            self.sql_database = SQLDatabase.from_uri(uri, *args, **kwargs)
        elif scheme and host and port and user and password and dbname:
            # Percent-encode the credentials. Otherwise characters such as
            # "@", ":" or "/" in a password would be parsed as URI delimiters
            # and corrupt the URL. SQLAlchemy decodes them when it parses the
            # URL.
            safe_user = quote(user, safe="")
            safe_password = quote(password, safe="")
            uri = f"{scheme}://{safe_user}:{safe_password}@{host}:{port}/{dbname}"
            self.uri = uri
            self.sql_database = SQLDatabase.from_uri(uri, *args, **kwargs)
        else:
            raise ValueError(
                "You must provide either a SQLDatabase, "
                "a SQL Alchemy Engine, a valid connection URI, or a valid "
                "set of credentials."
            )

    def load_data(self, query: str) -> List[Document]:
        """Query and load data from the Database, returning a list of Documents.

        Args:
            query (str): SQL query to execute. Each row it returns becomes one
                Document.

        Returns:
            List[Document]: One Document per result row. Its text is every
            column value of that row converted with ``str`` and joined with
            ``", "``.

        Raises:
            ValueError: If ``query`` is None.
        """
        # Validate the argument before opening a connection to the database.
        if query is None:
            raise ValueError("A query parameter is necessary to filter the data")

        with self.sql_database.engine.connect() as connection:
            result = connection.execute(text(query))
            # Materialize the rows while the connection is still open.
            rows = result.fetchall()

        # Use every column, not only the first, so multi-column queries keep
        # all of their data (as the class docstring promises).
        return [Document(", ".join(str(value) for value in row)) for row in rows]