Merge pull request #1834 from danielaskdd/postgres-ssl

feat: Add SSL support for PostgreSQL database connections
This commit is contained in:
Daniel.y 2025-07-21 02:07:49 +08:00 committed by GitHub
commit ae582f63a3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 133 additions and 10 deletions

View File

@ -189,6 +189,13 @@ POSTGRES_DATABASE=your_database
POSTGRES_MAX_CONNECTIONS=12 POSTGRES_MAX_CONNECTIONS=12
# POSTGRES_WORKSPACE=forced_workspace_name # POSTGRES_WORKSPACE=forced_workspace_name
### PostgreSQL SSL Configuration (Optional)
# POSTGRES_SSL_MODE=require
# POSTGRES_SSL_CERT=/path/to/client-cert.pem
# POSTGRES_SSL_KEY=/path/to/client-key.pem
# POSTGRES_SSL_ROOT_CERT=/path/to/ca-cert.pem
# POSTGRES_SSL_CRL=/path/to/crl.pem
### Neo4j Configuration ### Neo4j Configuration
NEO4J_URI=neo4j+s://xxxxxxxx.databases.neo4j.io NEO4J_URI=neo4j+s://xxxxxxxx.databases.neo4j.io
NEO4J_USERNAME=neo4j NEO4J_USERNAME=neo4j

View File

@ -8,6 +8,7 @@ from dataclasses import dataclass, field
from typing import Any, Union, final from typing import Any, Union, final
import numpy as np import numpy as np
import configparser import configparser
import ssl
from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
@ -58,27 +59,121 @@ class PostgreSQLDB:
self.increment = 1 self.increment = 1
self.pool: Pool | None = None self.pool: Pool | None = None
# SSL configuration
self.ssl_mode = config.get("ssl_mode")
self.ssl_cert = config.get("ssl_cert")
self.ssl_key = config.get("ssl_key")
self.ssl_root_cert = config.get("ssl_root_cert")
self.ssl_crl = config.get("ssl_crl")
if self.user is None or self.password is None or self.database is None: if self.user is None or self.password is None or self.database is None:
raise ValueError("Missing database user, password, or database") raise ValueError("Missing database user, password, or database")
def _create_ssl_context(self) -> ssl.SSLContext | None:
"""Create SSL context based on configuration parameters."""
if not self.ssl_mode:
return None
ssl_mode = self.ssl_mode.lower()
# For simple modes that don't require custom context
if ssl_mode in ["disable", "allow", "prefer", "require"]:
if ssl_mode == "disable":
return None
elif ssl_mode in ["require", "prefer"]:
# Return None for simple SSL requirement, handled in initdb
return None
# For modes that require certificate verification
if ssl_mode in ["verify-ca", "verify-full"]:
try:
context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
# Configure certificate verification
if ssl_mode == "verify-ca":
context.check_hostname = False
elif ssl_mode == "verify-full":
context.check_hostname = True
# Load root certificate if provided
if self.ssl_root_cert:
if os.path.exists(self.ssl_root_cert):
context.load_verify_locations(cafile=self.ssl_root_cert)
logger.info(
f"PostgreSQL, Loaded SSL root certificate: {self.ssl_root_cert}"
)
else:
logger.warning(
f"PostgreSQL, SSL root certificate file not found: {self.ssl_root_cert}"
)
# Load client certificate and key if provided
if self.ssl_cert and self.ssl_key:
if os.path.exists(self.ssl_cert) and os.path.exists(self.ssl_key):
context.load_cert_chain(self.ssl_cert, self.ssl_key)
logger.info(
f"PostgreSQL, Loaded SSL client certificate: {self.ssl_cert}"
)
else:
logger.warning(
"PostgreSQL, SSL client certificate or key file not found"
)
# Load certificate revocation list if provided
if self.ssl_crl:
if os.path.exists(self.ssl_crl):
context.load_verify_locations(crlfile=self.ssl_crl)
logger.info(f"PostgreSQL, Loaded SSL CRL: {self.ssl_crl}")
else:
logger.warning(
f"PostgreSQL, SSL CRL file not found: {self.ssl_crl}"
)
return context
except Exception as e:
logger.error(f"PostgreSQL, Failed to create SSL context: {e}")
raise ValueError(f"SSL configuration error: {e}")
# Unknown SSL mode
logger.warning(f"PostgreSQL, Unknown SSL mode: {ssl_mode}, SSL disabled")
return None
async def initdb(self): async def initdb(self):
try: try:
self.pool = await asyncpg.create_pool( # type: ignore # Prepare connection parameters
user=self.user, connection_params = {
password=self.password, "user": self.user,
database=self.database, "password": self.password,
host=self.host, "database": self.database,
port=self.port, "host": self.host,
min_size=1, "port": self.port,
max_size=self.max, "min_size": 1,
) "max_size": self.max,
}
# Add SSL configuration if provided
ssl_context = self._create_ssl_context()
if ssl_context is not None:
connection_params["ssl"] = ssl_context
logger.info("PostgreSQL, SSL configuration applied")
elif self.ssl_mode:
# Handle simple SSL modes without custom context
if self.ssl_mode.lower() in ["require", "prefer"]:
connection_params["ssl"] = True
elif self.ssl_mode.lower() == "disable":
connection_params["ssl"] = False
logger.info(f"PostgreSQL, SSL mode set to: {self.ssl_mode}")
self.pool = await asyncpg.create_pool(**connection_params) # type: ignore
# Ensure VECTOR extension is available # Ensure VECTOR extension is available
async with self.pool.acquire() as connection: async with self.pool.acquire() as connection:
await self.configure_vector_extension(connection) await self.configure_vector_extension(connection)
ssl_status = "with SSL" if connection_params.get("ssl") else "without SSL"
logger.info( logger.info(
f"PostgreSQL, Connected to database at {self.host}:{self.port}/{self.database}" f"PostgreSQL, Connected to database at {self.host}:{self.port}/{self.database} {ssl_status}"
) )
except Exception as e: except Exception as e:
logger.error( logger.error(
@ -809,6 +904,27 @@ class ClientManager:
"POSTGRES_MAX_CONNECTIONS", "POSTGRES_MAX_CONNECTIONS",
config.get("postgres", "max_connections", fallback=20), config.get("postgres", "max_connections", fallback=20),
), ),
# SSL configuration
"ssl_mode": os.environ.get(
"POSTGRES_SSL_MODE",
config.get("postgres", "ssl_mode", fallback=None),
),
"ssl_cert": os.environ.get(
"POSTGRES_SSL_CERT",
config.get("postgres", "ssl_cert", fallback=None),
),
"ssl_key": os.environ.get(
"POSTGRES_SSL_KEY",
config.get("postgres", "ssl_key", fallback=None),
),
"ssl_root_cert": os.environ.get(
"POSTGRES_SSL_ROOT_CERT",
config.get("postgres", "ssl_root_cert", fallback=None),
),
"ssl_crl": os.environ.get(
"POSTGRES_SSL_CRL",
config.get("postgres", "ssl_crl", fallback=None),
),
} }
@classmethod @classmethod