mirror of
https://github.com/HKUDS/LightRAG.git
synced 2025-12-04 19:06:51 +00:00
Merge pull request #1834 from danielaskdd/postgres-ssl
feat: Add SSL support for PostgreSQL database connections
This commit is contained in:
commit
ae582f63a3
@ -189,6 +189,13 @@ POSTGRES_DATABASE=your_database
|
|||||||
POSTGRES_MAX_CONNECTIONS=12
|
POSTGRES_MAX_CONNECTIONS=12
|
||||||
# POSTGRES_WORKSPACE=forced_workspace_name
|
# POSTGRES_WORKSPACE=forced_workspace_name
|
||||||
|
|
||||||
|
### PostgreSQL SSL Configuration (Optional)
|
||||||
|
# POSTGRES_SSL_MODE=require
|
||||||
|
# POSTGRES_SSL_CERT=/path/to/client-cert.pem
|
||||||
|
# POSTGRES_SSL_KEY=/path/to/client-key.pem
|
||||||
|
# POSTGRES_SSL_ROOT_CERT=/path/to/ca-cert.pem
|
||||||
|
# POSTGRES_SSL_CRL=/path/to/crl.pem
|
||||||
|
|
||||||
### Neo4j Configuration
|
### Neo4j Configuration
|
||||||
NEO4J_URI=neo4j+s://xxxxxxxx.databases.neo4j.io
|
NEO4J_URI=neo4j+s://xxxxxxxx.databases.neo4j.io
|
||||||
NEO4J_USERNAME=neo4j
|
NEO4J_USERNAME=neo4j
|
||||||
|
|||||||
@ -8,6 +8,7 @@ from dataclasses import dataclass, field
|
|||||||
from typing import Any, Union, final
|
from typing import Any, Union, final
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import configparser
|
import configparser
|
||||||
|
import ssl
|
||||||
|
|
||||||
from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
|
from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
|
||||||
|
|
||||||
@ -58,27 +59,121 @@ class PostgreSQLDB:
|
|||||||
self.increment = 1
|
self.increment = 1
|
||||||
self.pool: Pool | None = None
|
self.pool: Pool | None = None
|
||||||
|
|
||||||
|
# SSL configuration
|
||||||
|
self.ssl_mode = config.get("ssl_mode")
|
||||||
|
self.ssl_cert = config.get("ssl_cert")
|
||||||
|
self.ssl_key = config.get("ssl_key")
|
||||||
|
self.ssl_root_cert = config.get("ssl_root_cert")
|
||||||
|
self.ssl_crl = config.get("ssl_crl")
|
||||||
|
|
||||||
if self.user is None or self.password is None or self.database is None:
|
if self.user is None or self.password is None or self.database is None:
|
||||||
raise ValueError("Missing database user, password, or database")
|
raise ValueError("Missing database user, password, or database")
|
||||||
|
|
||||||
|
def _create_ssl_context(self) -> ssl.SSLContext | None:
|
||||||
|
"""Create SSL context based on configuration parameters."""
|
||||||
|
if not self.ssl_mode:
|
||||||
|
return None
|
||||||
|
|
||||||
|
ssl_mode = self.ssl_mode.lower()
|
||||||
|
|
||||||
|
# For simple modes that don't require custom context
|
||||||
|
if ssl_mode in ["disable", "allow", "prefer", "require"]:
|
||||||
|
if ssl_mode == "disable":
|
||||||
|
return None
|
||||||
|
elif ssl_mode in ["require", "prefer"]:
|
||||||
|
# Return None for simple SSL requirement, handled in initdb
|
||||||
|
return None
|
||||||
|
|
||||||
|
# For modes that require certificate verification
|
||||||
|
if ssl_mode in ["verify-ca", "verify-full"]:
|
||||||
|
try:
|
||||||
|
context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
|
||||||
|
|
||||||
|
# Configure certificate verification
|
||||||
|
if ssl_mode == "verify-ca":
|
||||||
|
context.check_hostname = False
|
||||||
|
elif ssl_mode == "verify-full":
|
||||||
|
context.check_hostname = True
|
||||||
|
|
||||||
|
# Load root certificate if provided
|
||||||
|
if self.ssl_root_cert:
|
||||||
|
if os.path.exists(self.ssl_root_cert):
|
||||||
|
context.load_verify_locations(cafile=self.ssl_root_cert)
|
||||||
|
logger.info(
|
||||||
|
f"PostgreSQL, Loaded SSL root certificate: {self.ssl_root_cert}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
f"PostgreSQL, SSL root certificate file not found: {self.ssl_root_cert}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load client certificate and key if provided
|
||||||
|
if self.ssl_cert and self.ssl_key:
|
||||||
|
if os.path.exists(self.ssl_cert) and os.path.exists(self.ssl_key):
|
||||||
|
context.load_cert_chain(self.ssl_cert, self.ssl_key)
|
||||||
|
logger.info(
|
||||||
|
f"PostgreSQL, Loaded SSL client certificate: {self.ssl_cert}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"PostgreSQL, SSL client certificate or key file not found"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load certificate revocation list if provided
|
||||||
|
if self.ssl_crl:
|
||||||
|
if os.path.exists(self.ssl_crl):
|
||||||
|
context.load_verify_locations(crlfile=self.ssl_crl)
|
||||||
|
logger.info(f"PostgreSQL, Loaded SSL CRL: {self.ssl_crl}")
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
f"PostgreSQL, SSL CRL file not found: {self.ssl_crl}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return context
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"PostgreSQL, Failed to create SSL context: {e}")
|
||||||
|
raise ValueError(f"SSL configuration error: {e}")
|
||||||
|
|
||||||
|
# Unknown SSL mode
|
||||||
|
logger.warning(f"PostgreSQL, Unknown SSL mode: {ssl_mode}, SSL disabled")
|
||||||
|
return None
|
||||||
|
|
||||||
async def initdb(self):
|
async def initdb(self):
|
||||||
try:
|
try:
|
||||||
self.pool = await asyncpg.create_pool( # type: ignore
|
# Prepare connection parameters
|
||||||
user=self.user,
|
connection_params = {
|
||||||
password=self.password,
|
"user": self.user,
|
||||||
database=self.database,
|
"password": self.password,
|
||||||
host=self.host,
|
"database": self.database,
|
||||||
port=self.port,
|
"host": self.host,
|
||||||
min_size=1,
|
"port": self.port,
|
||||||
max_size=self.max,
|
"min_size": 1,
|
||||||
)
|
"max_size": self.max,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add SSL configuration if provided
|
||||||
|
ssl_context = self._create_ssl_context()
|
||||||
|
if ssl_context is not None:
|
||||||
|
connection_params["ssl"] = ssl_context
|
||||||
|
logger.info("PostgreSQL, SSL configuration applied")
|
||||||
|
elif self.ssl_mode:
|
||||||
|
# Handle simple SSL modes without custom context
|
||||||
|
if self.ssl_mode.lower() in ["require", "prefer"]:
|
||||||
|
connection_params["ssl"] = True
|
||||||
|
elif self.ssl_mode.lower() == "disable":
|
||||||
|
connection_params["ssl"] = False
|
||||||
|
logger.info(f"PostgreSQL, SSL mode set to: {self.ssl_mode}")
|
||||||
|
|
||||||
|
self.pool = await asyncpg.create_pool(**connection_params) # type: ignore
|
||||||
|
|
||||||
# Ensure VECTOR extension is available
|
# Ensure VECTOR extension is available
|
||||||
async with self.pool.acquire() as connection:
|
async with self.pool.acquire() as connection:
|
||||||
await self.configure_vector_extension(connection)
|
await self.configure_vector_extension(connection)
|
||||||
|
|
||||||
|
ssl_status = "with SSL" if connection_params.get("ssl") else "without SSL"
|
||||||
logger.info(
|
logger.info(
|
||||||
f"PostgreSQL, Connected to database at {self.host}:{self.port}/{self.database}"
|
f"PostgreSQL, Connected to database at {self.host}:{self.port}/{self.database} {ssl_status}"
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
@ -809,6 +904,27 @@ class ClientManager:
|
|||||||
"POSTGRES_MAX_CONNECTIONS",
|
"POSTGRES_MAX_CONNECTIONS",
|
||||||
config.get("postgres", "max_connections", fallback=20),
|
config.get("postgres", "max_connections", fallback=20),
|
||||||
),
|
),
|
||||||
|
# SSL configuration
|
||||||
|
"ssl_mode": os.environ.get(
|
||||||
|
"POSTGRES_SSL_MODE",
|
||||||
|
config.get("postgres", "ssl_mode", fallback=None),
|
||||||
|
),
|
||||||
|
"ssl_cert": os.environ.get(
|
||||||
|
"POSTGRES_SSL_CERT",
|
||||||
|
config.get("postgres", "ssl_cert", fallback=None),
|
||||||
|
),
|
||||||
|
"ssl_key": os.environ.get(
|
||||||
|
"POSTGRES_SSL_KEY",
|
||||||
|
config.get("postgres", "ssl_key", fallback=None),
|
||||||
|
),
|
||||||
|
"ssl_root_cert": os.environ.get(
|
||||||
|
"POSTGRES_SSL_ROOT_CERT",
|
||||||
|
config.get("postgres", "ssl_root_cert", fallback=None),
|
||||||
|
),
|
||||||
|
"ssl_crl": os.environ.get(
|
||||||
|
"POSTGRES_SSL_CRL",
|
||||||
|
config.get("postgres", "ssl_crl", fallback=None),
|
||||||
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user