mirror of
https://github.com/HKUDS/LightRAG.git
synced 2025-12-04 10:58:53 +00:00
Merge pull request #1834 from danielaskdd/postgres-ssl
feat: Add SSL support for PostgreSQL database connections
This commit is contained in:
commit
ae582f63a3
@ -189,6 +189,13 @@ POSTGRES_DATABASE=your_database
|
||||
POSTGRES_MAX_CONNECTIONS=12
|
||||
# POSTGRES_WORKSPACE=forced_workspace_name
|
||||
|
||||
### PostgreSQL SSL Configuration (Optional)
|
||||
# POSTGRES_SSL_MODE=require
|
||||
# POSTGRES_SSL_CERT=/path/to/client-cert.pem
|
||||
# POSTGRES_SSL_KEY=/path/to/client-key.pem
|
||||
# POSTGRES_SSL_ROOT_CERT=/path/to/ca-cert.pem
|
||||
# POSTGRES_SSL_CRL=/path/to/crl.pem
|
||||
|
||||
### Neo4j Configuration
|
||||
NEO4J_URI=neo4j+s://xxxxxxxx.databases.neo4j.io
|
||||
NEO4J_USERNAME=neo4j
|
||||
|
||||
@ -8,6 +8,7 @@ from dataclasses import dataclass, field
|
||||
from typing import Any, Union, final
|
||||
import numpy as np
|
||||
import configparser
|
||||
import ssl
|
||||
|
||||
from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
|
||||
|
||||
@ -58,27 +59,121 @@ class PostgreSQLDB:
|
||||
self.increment = 1
|
||||
self.pool: Pool | None = None
|
||||
|
||||
# SSL configuration
|
||||
self.ssl_mode = config.get("ssl_mode")
|
||||
self.ssl_cert = config.get("ssl_cert")
|
||||
self.ssl_key = config.get("ssl_key")
|
||||
self.ssl_root_cert = config.get("ssl_root_cert")
|
||||
self.ssl_crl = config.get("ssl_crl")
|
||||
|
||||
if self.user is None or self.password is None or self.database is None:
|
||||
raise ValueError("Missing database user, password, or database")
|
||||
|
||||
def _create_ssl_context(self) -> ssl.SSLContext | None:
|
||||
"""Create SSL context based on configuration parameters."""
|
||||
if not self.ssl_mode:
|
||||
return None
|
||||
|
||||
ssl_mode = self.ssl_mode.lower()
|
||||
|
||||
# For simple modes that don't require custom context
|
||||
if ssl_mode in ["disable", "allow", "prefer", "require"]:
|
||||
if ssl_mode == "disable":
|
||||
return None
|
||||
elif ssl_mode in ["require", "prefer"]:
|
||||
# Return None for simple SSL requirement, handled in initdb
|
||||
return None
|
||||
|
||||
# For modes that require certificate verification
|
||||
if ssl_mode in ["verify-ca", "verify-full"]:
|
||||
try:
|
||||
context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
|
||||
|
||||
# Configure certificate verification
|
||||
if ssl_mode == "verify-ca":
|
||||
context.check_hostname = False
|
||||
elif ssl_mode == "verify-full":
|
||||
context.check_hostname = True
|
||||
|
||||
# Load root certificate if provided
|
||||
if self.ssl_root_cert:
|
||||
if os.path.exists(self.ssl_root_cert):
|
||||
context.load_verify_locations(cafile=self.ssl_root_cert)
|
||||
logger.info(
|
||||
f"PostgreSQL, Loaded SSL root certificate: {self.ssl_root_cert}"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"PostgreSQL, SSL root certificate file not found: {self.ssl_root_cert}"
|
||||
)
|
||||
|
||||
# Load client certificate and key if provided
|
||||
if self.ssl_cert and self.ssl_key:
|
||||
if os.path.exists(self.ssl_cert) and os.path.exists(self.ssl_key):
|
||||
context.load_cert_chain(self.ssl_cert, self.ssl_key)
|
||||
logger.info(
|
||||
f"PostgreSQL, Loaded SSL client certificate: {self.ssl_cert}"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"PostgreSQL, SSL client certificate or key file not found"
|
||||
)
|
||||
|
||||
# Load certificate revocation list if provided
|
||||
if self.ssl_crl:
|
||||
if os.path.exists(self.ssl_crl):
|
||||
context.load_verify_locations(crlfile=self.ssl_crl)
|
||||
logger.info(f"PostgreSQL, Loaded SSL CRL: {self.ssl_crl}")
|
||||
else:
|
||||
logger.warning(
|
||||
f"PostgreSQL, SSL CRL file not found: {self.ssl_crl}"
|
||||
)
|
||||
|
||||
return context
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"PostgreSQL, Failed to create SSL context: {e}")
|
||||
raise ValueError(f"SSL configuration error: {e}")
|
||||
|
||||
# Unknown SSL mode
|
||||
logger.warning(f"PostgreSQL, Unknown SSL mode: {ssl_mode}, SSL disabled")
|
||||
return None
|
||||
|
||||
async def initdb(self):
|
||||
try:
|
||||
self.pool = await asyncpg.create_pool( # type: ignore
|
||||
user=self.user,
|
||||
password=self.password,
|
||||
database=self.database,
|
||||
host=self.host,
|
||||
port=self.port,
|
||||
min_size=1,
|
||||
max_size=self.max,
|
||||
)
|
||||
# Prepare connection parameters
|
||||
connection_params = {
|
||||
"user": self.user,
|
||||
"password": self.password,
|
||||
"database": self.database,
|
||||
"host": self.host,
|
||||
"port": self.port,
|
||||
"min_size": 1,
|
||||
"max_size": self.max,
|
||||
}
|
||||
|
||||
# Add SSL configuration if provided
|
||||
ssl_context = self._create_ssl_context()
|
||||
if ssl_context is not None:
|
||||
connection_params["ssl"] = ssl_context
|
||||
logger.info("PostgreSQL, SSL configuration applied")
|
||||
elif self.ssl_mode:
|
||||
# Handle simple SSL modes without custom context
|
||||
if self.ssl_mode.lower() in ["require", "prefer"]:
|
||||
connection_params["ssl"] = True
|
||||
elif self.ssl_mode.lower() == "disable":
|
||||
connection_params["ssl"] = False
|
||||
logger.info(f"PostgreSQL, SSL mode set to: {self.ssl_mode}")
|
||||
|
||||
self.pool = await asyncpg.create_pool(**connection_params) # type: ignore
|
||||
|
||||
# Ensure VECTOR extension is available
|
||||
async with self.pool.acquire() as connection:
|
||||
await self.configure_vector_extension(connection)
|
||||
|
||||
ssl_status = "with SSL" if connection_params.get("ssl") else "without SSL"
|
||||
logger.info(
|
||||
f"PostgreSQL, Connected to database at {self.host}:{self.port}/{self.database}"
|
||||
f"PostgreSQL, Connected to database at {self.host}:{self.port}/{self.database} {ssl_status}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
@ -809,6 +904,27 @@ class ClientManager:
|
||||
"POSTGRES_MAX_CONNECTIONS",
|
||||
config.get("postgres", "max_connections", fallback=20),
|
||||
),
|
||||
# SSL configuration
|
||||
"ssl_mode": os.environ.get(
|
||||
"POSTGRES_SSL_MODE",
|
||||
config.get("postgres", "ssl_mode", fallback=None),
|
||||
),
|
||||
"ssl_cert": os.environ.get(
|
||||
"POSTGRES_SSL_CERT",
|
||||
config.get("postgres", "ssl_cert", fallback=None),
|
||||
),
|
||||
"ssl_key": os.environ.get(
|
||||
"POSTGRES_SSL_KEY",
|
||||
config.get("postgres", "ssl_key", fallback=None),
|
||||
),
|
||||
"ssl_root_cert": os.environ.get(
|
||||
"POSTGRES_SSL_ROOT_CERT",
|
||||
config.get("postgres", "ssl_root_cert", fallback=None),
|
||||
),
|
||||
"ssl_crl": os.environ.get(
|
||||
"POSTGRES_SSL_CRL",
|
||||
config.get("postgres", "ssl_crl", fallback=None),
|
||||
),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user