From 7a48741483536ff9af41f29a3af2eb08f615ef4c Mon Sep 17 00:00:00 2001 From: Ben Irvin Date: Mon, 30 Oct 2023 17:53:25 +0100 Subject: [PATCH] fix: init dialect on every new connection in the pool (#18588) --- packages/core/database/src/connection.ts | 30 ++++++++++++++++--- .../core/database/src/dialects/dialect.ts | 5 +++- .../core/database/src/dialects/mysql/index.ts | 11 +++++-- .../database/src/dialects/postgresql/index.ts | 10 ++++++- .../database/src/dialects/sqlite/index.ts | 4 +-- packages/core/database/src/index.ts | 12 ++++++-- 6 files changed, 59 insertions(+), 13 deletions(-) diff --git a/packages/core/database/src/connection.ts b/packages/core/database/src/connection.ts index e87ad3cb77..0430bdcca7 100644 --- a/packages/core/database/src/connection.ts +++ b/packages/core/database/src/connection.ts @@ -11,12 +11,34 @@ function isClientValid(config: { client?: unknown }): config is { client: keyof return Object.keys(clientMap).includes(config.client as string); } -export const createConnection = (config: Knex.Config) => { - if (!isClientValid(config)) { - throw new Error(`Unsupported database client ${config.client}`); +export const createConnection = (userConfig: Knex.Config, strapiConfig?: Partial) => { + if (!isClientValid(userConfig)) { + throw new Error(`Unsupported database client ${userConfig.client}`); } - const knexConfig = { ...config, client: (clientMap as any)[config.client] }; + const knexConfig: Knex.Config = { ...userConfig, client: (clientMap as any)[userConfig.client] }; + + // initialization code to run upon opening a new connection + if (strapiConfig?.pool?.afterCreate) { + knexConfig.pool = knexConfig.pool || {}; + // if the user has set their own afterCreate in config, we will replace it and call it + const userAfterCreate = knexConfig.pool?.afterCreate; + const strapiAfterCreate = strapiConfig.pool.afterCreate; + knexConfig.pool.afterCreate = ( + conn: unknown, + done: (err: Error | null | undefined, connection: any) => void + ) => { + strapiAfterCreate(conn, (err: Error | null | undefined, nativeConn: any) => { + if (err) { + return done(err, nativeConn); + } + if (userAfterCreate) { + return userAfterCreate(nativeConn, done); + } + return done(null, nativeConn); + }); + }; + } return knex(knexConfig); }; diff --git a/packages/core/database/src/dialects/dialect.ts b/packages/core/database/src/dialects/dialect.ts index dc7f089551..027914b090 100644 --- a/packages/core/database/src/dialects/dialect.ts +++ b/packages/core/database/src/dialects/dialect.ts @@ -19,7 +19,10 @@ export default class Dialect { configure() {} - initialize() {} + // eslint-disable-next-line @typescript-eslint/no-unused-vars + async initialize(_nativeConnection?: unknown) { + // noop + } getSqlType(type: unknown) { return type; diff --git a/packages/core/database/src/dialects/mysql/index.ts b/packages/core/database/src/dialects/mysql/index.ts index 7ea1ca3ef2..711bbf0641 100644 --- a/packages/core/database/src/dialects/mysql/index.ts +++ b/packages/core/database/src/dialects/mysql/index.ts @@ -52,14 +52,19 @@ export default class MysqlDialect extends Dialect { }; } - async initialize() { + async initialize(nativeConnection: unknown) { try { - await this.db.connection.raw(`set session sql_require_primary_key = 0;`); + await this.db.connection + .raw(`set session sql_require_primary_key = 0;`) + .connection(nativeConnection); } catch (err) { // Ignore error due to lack of session permissions } - this.info = await this.databaseInspector.getInformation(); + // We only need to get info on the first connection in the pool + if (!this.info) { + this.info = await this.databaseInspector.getInformation(); + } } async startSchemaUpdate() { diff --git a/packages/core/database/src/dialects/postgresql/index.ts b/packages/core/database/src/dialects/postgresql/index.ts index 022d3f435b..d48dd93bf3 100644 --- a/packages/core/database/src/dialects/postgresql/index.ts +++ b/packages/core/database/src/dialects/postgresql/index.ts @@ -16,7 +16,7 @@ export default class PostgresDialect extends Dialect { return true; } - async initialize() { + async initialize(nativeConnection: unknown) { // Don't cast DATE string to Date() this.db.connection.client.driver.types.setTypeParser( this.db.connection.client.driver.types.builtins.DATE, @@ -34,6 +34,14 @@ export default class PostgresDialect extends Dialect { 'text', parseFloat ); + + // If we're using a schema, set the default path for all table names in queries to use that schema + const schemaName = this.db.getSchemaName(); + if (schemaName) { + await this.db.connection + .raw(`SET search_path TO "${schemaName}"`) + .connection(nativeConnection); + } } usesForeignKeys() { diff --git a/packages/core/database/src/dialects/sqlite/index.ts b/packages/core/database/src/dialects/sqlite/index.ts index 177ffc60ab..78533bd94b 100644 --- a/packages/core/database/src/dialects/sqlite/index.ts +++ b/packages/core/database/src/dialects/sqlite/index.ts @@ -33,8 +33,8 @@ export default class SqliteDialect extends Dialect { return true; } - async initialize() { - await this.db.connection.raw('pragma foreign_keys = on'); + async initialize(nativeConnection: unknown) { + await this.db.connection.raw('pragma foreign_keys = on').connection(nativeConnection); } canAlterConstraints() { diff --git a/packages/core/database/src/index.ts b/packages/core/database/src/index.ts index bfaa9fe722..823b2ad98b 100644 --- a/packages/core/database/src/index.ts +++ b/packages/core/database/src/index.ts @@ -69,9 +69,17 @@ class Database { this.dialect = getDialect(this); this.dialect.configure(); - this.connection = createConnection(this.config.connection); + const afterCreate = ( + nativeConnection: unknown, + done: (error: Error | null, nativeConnection: unknown) => Promise + ) => { + // run initialize for it since commands such as postgres SET and sqlite PRAGMA are per-connection + this.dialect.initialize(nativeConnection).then(() => { + return done(null, nativeConnection); + }); + }; - this.dialect.initialize(); + this.connection = createConnection(this.config.connection, { pool: { afterCreate } }); this.schema = createSchemaProvider(this);