diff --git a/pkg/common/adapters/database/bun.go b/pkg/common/adapters/database/bun.go index 82c2456..1549d6a 100644 --- a/pkg/common/adapters/database/bun.go +++ b/pkg/common/adapters/database/bun.go @@ -168,7 +168,7 @@ func (b *BunAdapter) BeginTx(ctx context.Context) (common.Database, error) { return nil, err } // For Bun, we'll return a special wrapper that holds the transaction - return &BunTxAdapter{tx: tx}, nil + return &BunTxAdapter{tx: tx, driverName: b.DriverName()}, nil } func (b *BunAdapter) CommitTx(ctx context.Context) error { @@ -191,7 +191,7 @@ func (b *BunAdapter) RunInTransaction(ctx context.Context, fn func(common.Databa }() return b.db.RunInTx(ctx, &sql.TxOptions{}, func(ctx context.Context, tx bun.Tx) error { // Create adapter with transaction - adapter := &BunTxAdapter{tx: tx} + adapter := &BunTxAdapter{tx: tx, driverName: b.DriverName()} return fn(adapter) }) } @@ -200,6 +200,17 @@ func (b *BunAdapter) GetUnderlyingDB() interface{} { return b.db } +func (b *BunAdapter) DriverName() string { + // Normalize Bun's dialect name to match the project's canonical vocabulary. + // Bun returns "pg" for PostgreSQL; the rest of the project uses "postgres". + switch name := b.db.Dialect().Name().String(); name { + case "pg": + return "postgres" + default: + return name + } +} + // BunSelectQuery implements SelectQuery for Bun type BunSelectQuery struct { query *bun.SelectQuery @@ -1477,7 +1488,8 @@ func (b *BunResult) LastInsertId() (int64, error) { // BunTxAdapter wraps a Bun transaction to implement the Database interface type BunTxAdapter struct { - tx bun.Tx + tx bun.Tx + driverName string } func (b *BunTxAdapter) NewSelect() common.SelectQuery { @@ -1527,3 +1539,7 @@ func (b *BunTxAdapter) RunInTransaction(ctx context.Context, fn func(common.Data func (b *BunTxAdapter) GetUnderlyingDB() interface{} { return b.tx } + +func (b *BunTxAdapter) DriverName() string { + return b.driverName +} diff --git a/pkg/common/adapters/database/gorm.go b/pkg/common/adapters/database/gorm.go index 9d3b3d9..9058ae7 100644 --- a/pkg/common/adapters/database/gorm.go +++ b/pkg/common/adapters/database/gorm.go @@ -106,6 +106,20 @@ func (g *GormAdapter) GetUnderlyingDB() interface{} { return g.db } +func (g *GormAdapter) DriverName() string { + if g.db.Dialector == nil { + return "" + } + // Normalize GORM's dialector name to match the project's canonical vocabulary. + // GORM returns "sqlserver" for MSSQL; the rest of the project uses "mssql". + switch name := g.db.Name(); name { + case "sqlserver": + return "mssql" + default: + return name + } +} + // GormSelectQuery implements SelectQuery for GORM type GormSelectQuery struct { db *gorm.DB diff --git a/pkg/common/adapters/database/pgsql.go b/pkg/common/adapters/database/pgsql.go index f2486f3..caa0d8f 100644 --- a/pkg/common/adapters/database/pgsql.go +++ b/pkg/common/adapters/database/pgsql.go @@ -16,12 +16,19 @@ import ( // PgSQLAdapter adapts standard database/sql to work with our Database interface // This provides a lightweight PostgreSQL adapter without ORM overhead type PgSQLAdapter struct { - db *sql.DB + db *sql.DB + driverName string } -// NewPgSQLAdapter creates a new PostgreSQL adapter -func NewPgSQLAdapter(db *sql.DB) *PgSQLAdapter { - return &PgSQLAdapter{db: db} +// NewPgSQLAdapter creates a new adapter wrapping a standard sql.DB. +// An optional driverName (e.g. "postgres", "sqlite", "mssql") can be provided; +// it defaults to "postgres" when omitted. +func NewPgSQLAdapter(db *sql.DB, driverName ...string) *PgSQLAdapter { + name := "postgres" + if len(driverName) > 0 && driverName[0] != "" { + name = driverName[0] + } + return &PgSQLAdapter{db: db, driverName: name} } // EnableQueryDebug enables query debugging for development @@ -98,7 +105,7 @@ func (p *PgSQLAdapter) BeginTx(ctx context.Context) (common.Database, error) { if err != nil { return nil, err } - return &PgSQLTxAdapter{tx: tx}, nil + return &PgSQLTxAdapter{tx: tx, driverName: p.driverName}, nil } func (p *PgSQLAdapter) CommitTx(ctx context.Context) error { @@ -121,7 +128,7 @@ func (p *PgSQLAdapter) RunInTransaction(ctx context.Context, fn func(common.Data return err } - adapter := &PgSQLTxAdapter{tx: tx} + adapter := &PgSQLTxAdapter{tx: tx, driverName: p.driverName} defer func() { if p := recover(); p != nil { @@ -141,6 +148,10 @@ func (p *PgSQLAdapter) GetUnderlyingDB() interface{} { return p.db } +func (p *PgSQLAdapter) DriverName() string { + return p.driverName +} + // preloadConfig represents a relationship to be preloaded type preloadConfig struct { relation string @@ -835,7 +846,8 @@ func (p *PgSQLResult) LastInsertId() (int64, error) { // PgSQLTxAdapter wraps a PostgreSQL transaction type PgSQLTxAdapter struct { - tx *sql.Tx + tx *sql.Tx + driverName string } func (p *PgSQLTxAdapter) NewSelect() common.SelectQuery { @@ -912,6 +924,10 @@ func (p *PgSQLTxAdapter) GetUnderlyingDB() interface{} { return p.tx } +func (p *PgSQLTxAdapter) DriverName() string { + return p.driverName +} + // applyJoinPreloads adds JOINs for relationships that should use JOIN strategy func (p *PgSQLSelectQuery) applyJoinPreloads() { for _, preload := range p.preloads { diff --git a/pkg/common/interfaces.go b/pkg/common/interfaces.go index 03a72a0..6f72709 100644 --- a/pkg/common/interfaces.go +++ b/pkg/common/interfaces.go @@ -30,6 +30,12 @@ type Database interface { // For Bun, this returns *bun.DB // This is useful for provider-specific features like PostgreSQL NOTIFY/LISTEN GetUnderlyingDB() interface{} + + // DriverName returns the canonical name of the underlying database driver. + // Possible values: "postgres", "sqlite", "mssql", "mysql". + // All adapters normalise vendor-specific strings (e.g. Bun's "pg", GORM's + // "sqlserver") to the values above before returning. + DriverName() string } // SelectQuery interface for building SELECT queries (compatible with both GORM and Bun) diff --git a/pkg/common/recursive_crud_test.go b/pkg/common/recursive_crud_test.go index 9bda8bb..0d2366a 100644 --- a/pkg/common/recursive_crud_test.go +++ b/pkg/common/recursive_crud_test.go @@ -50,6 +50,9 @@ func (m *mockDatabase) RollbackTx(ctx context.Context) error { func (m *mockDatabase) GetUnderlyingDB() interface{} { return nil } +func (m *mockDatabase) DriverName() string { + return "postgres" +} // Mock SelectQuery type mockSelectQuery struct{} diff --git a/pkg/dbmanager/connection.go b/pkg/dbmanager/connection.go index 748e42e..23f1ea0 100644 --- a/pkg/dbmanager/connection.go +++ b/pkg/dbmanager/connection.go @@ -467,13 +467,11 @@ func (c *sqlConnection) getNativeAdapter() (common.Database, error) { // Create a native adapter based on database type switch c.dbType { case DatabaseTypePostgreSQL: - c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB) + c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB, string(c.dbType)) case DatabaseTypeSQLite: - // For SQLite, we'll use the PgSQL adapter as it works with standard sql.DB - c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB) + c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB, string(c.dbType)) case DatabaseTypeMSSQL: - // For MSSQL, we'll use the PgSQL adapter as it works with standard sql.DB - c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB) + c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB, string(c.dbType)) default: return nil, ErrUnsupportedDatabase } diff --git a/pkg/dbmanager/manager.go b/pkg/dbmanager/manager.go index 7bcab48..e5dfb5b 100644 --- a/pkg/dbmanager/manager.go +++ b/pkg/dbmanager/manager.go @@ -231,12 +231,14 @@ func (m *connectionManager) Connect(ctx context.Context) error { // Close closes all database connections func (m *connectionManager) Close() error { + // Stop the health checker before taking mu. performHealthCheck acquires + // a read lock, so waiting for the goroutine while holding the write lock + // would deadlock. + m.stopHealthChecker() + m.mu.Lock() defer m.mu.Unlock() - // Stop health checker - m.stopHealthChecker() - // Close all connections var errors []error for name, conn := range m.connections { diff --git a/pkg/funcspec/function_api_test.go b/pkg/funcspec/function_api_test.go index defc30a..c90bf90 100644 --- a/pkg/funcspec/function_api_test.go +++ b/pkg/funcspec/function_api_test.go @@ -74,6 +74,10 @@ func (m *MockDatabase) GetUnderlyingDB() interface{} { return m } +func (m *MockDatabase) DriverName() string { + return "postgres" +} + // MockResult implements common.Result interface for testing type MockResult struct { rows int64 diff --git a/pkg/mqttspec/handler.go b/pkg/mqttspec/handler.go index d90aae9..6154283 100644 --- a/pkg/mqttspec/handler.go +++ b/pkg/mqttspec/handler.go @@ -645,11 +645,14 @@ func (h *Handler) getNotifyTopic(clientID, subscriptionID string) string { // Database operation helpers (adapted from websocketspec) func (h *Handler) getTableName(schema, entity string, model interface{}) string { - // Use entity as table name tableName := entity if schema != "" { - tableName = schema + "." + tableName + if h.db.DriverName() == "sqlite" { + tableName = schema + "_" + tableName + } else { + tableName = schema + "." + tableName + } } return tableName } diff --git a/pkg/resolvespec/handler.go b/pkg/resolvespec/handler.go index 33a1cb6..c191936 100644 --- a/pkg/resolvespec/handler.go +++ b/pkg/resolvespec/handler.go @@ -1380,10 +1380,16 @@ func (h *Handler) getSchemaAndTable(defaultSchema, entity string, model interfac return schema, entity } -// getTableName returns the full table name including schema (schema.table) +// getTableName returns the full table name including schema. +// For most drivers the result is "schema.table". For SQLite, which does not +// support schema-qualified names, the schema and table are joined with an +// underscore: "schema_table". func (h *Handler) getTableName(schema, entity string, model interface{}) string { schemaName, tableName := h.getSchemaAndTable(schema, entity, model) if schemaName != "" { + if h.db.DriverName() == "sqlite" { + return fmt.Sprintf("%s_%s", schemaName, tableName) + } return fmt.Sprintf("%s.%s", schemaName, tableName) } return tableName diff --git a/pkg/restheadspec/handler.go b/pkg/restheadspec/handler.go index e2f7f09..a44bc80 100644 --- a/pkg/restheadspec/handler.go +++ b/pkg/restheadspec/handler.go @@ -2015,11 +2015,18 @@ func (h *Handler) processChildRelationsForField( return nil } -// getTableNameForRelatedModel gets the table name for a related model +// getTableNameForRelatedModel gets the table name for a related model. +// If the model's TableName() is schema-qualified (e.g. "public.users") the +// separator is adjusted for the active driver: underscore for SQLite, dot otherwise. func (h *Handler) getTableNameForRelatedModel(model interface{}, defaultName string) string { if provider, ok := model.(common.TableNameProvider); ok { tableName := provider.TableName() if tableName != "" { + if schema, table := h.parseTableName(tableName); schema != "" { + if h.db.DriverName() == "sqlite" { + return fmt.Sprintf("%s_%s", schema, table) + } + } return tableName } } @@ -2264,10 +2271,16 @@ func (h *Handler) getSchemaAndTable(defaultSchema, entity string, model interfac return schema, entity } -// getTableName returns the full table name including schema (schema.table) +// getTableName returns the full table name including schema. +// For most drivers the result is "schema.table". For SQLite, which does not +// support schema-qualified names, the schema and table are joined with an +// underscore: "schema_table". func (h *Handler) getTableName(schema, entity string, model interface{}) string { schemaName, tableName := h.getSchemaAndTable(schema, entity, model) if schemaName != "" { + if h.db.DriverName() == "sqlite" { + return fmt.Sprintf("%s_%s", schemaName, tableName) + } return fmt.Sprintf("%s.%s", schemaName, tableName) } return tableName diff --git a/pkg/websocketspec/handler.go b/pkg/websocketspec/handler.go index e7f25cd..89f0459 100644 --- a/pkg/websocketspec/handler.go +++ b/pkg/websocketspec/handler.go @@ -656,11 +656,14 @@ func (h *Handler) delete(hookCtx *HookContext) error { // Helper methods func (h *Handler) getTableName(schema, entity string, model interface{}) string { - // Use entity as table name tableName := entity if schema != "" { - tableName = schema + "." + tableName + if h.db.DriverName() == "sqlite" { + tableName = schema + "_" + tableName + } else { + tableName = schema + "." + tableName + } } return tableName } diff --git a/pkg/websocketspec/handler_test.go b/pkg/websocketspec/handler_test.go index d950914..623783c 100644 --- a/pkg/websocketspec/handler_test.go +++ b/pkg/websocketspec/handler_test.go @@ -82,6 +82,10 @@ func (m *MockDatabase) GetUnderlyingDB() interface{} { return args.Get(0) } +func (m *MockDatabase) DriverName() string { + return "postgres" +} + // MockSelectQuery is a mock implementation of common.SelectQuery type MockSelectQuery struct { mock.Mock