247 lines
5.7 KiB
Go
247 lines
5.7 KiB
Go
package install
|
|
|
|
import (
|
|
"context"
|
|
"embed"
|
|
"fmt"
|
|
"io/fs"
|
|
"sort"
|
|
"strings"
|
|
|
|
"git.warky.dev/wdevs/pgsql-broker/pkg/broker/adapter"
|
|
)
|
|
|
|
// sqlFS holds the SQL scripts shipped with this package, embedded at build
// time from the sql directory. The installer reads table scripts from
// sql/tables and stored-procedure scripts from sql/procedures. The "all:"
// prefix also embeds files whose names begin with '.' or '_', which a plain
// directory pattern would leave out.
//
//go:embed all:sql
var sqlFS embed.FS
|
|
|
|
// Installer handles database schema installation.
// Its methods execute the SQL scripts embedded in sqlFS against db, verify the
// resulting objects, and report progress through logger.
type Installer struct {
	// db is the adapter every SQL statement (Exec) and lookup (QueryRow) goes through.
	db adapter.DBAdapter

	// logger receives progress messages during installation and verification.
	logger adapter.Logger
}
|
|
|
|
// New creates a new installer
|
|
func New(db adapter.DBAdapter, logger adapter.Logger) *Installer {
|
|
return &Installer{
|
|
db: db,
|
|
logger: logger,
|
|
}
|
|
}
|
|
|
|
// InstallSchema installs the complete database schema
|
|
func (i *Installer) InstallSchema(ctx context.Context) error {
|
|
i.logger.Info("starting schema installation")
|
|
|
|
// Install tables first
|
|
if err := i.installTables(ctx); err != nil {
|
|
return fmt.Errorf("failed to install tables: %w", err)
|
|
}
|
|
|
|
// Then install procedures
|
|
if err := i.installProcedures(ctx); err != nil {
|
|
return fmt.Errorf("failed to install procedures: %w", err)
|
|
}
|
|
|
|
i.logger.Info("schema installation completed successfully")
|
|
return nil
|
|
}
|
|
|
|
// installTables installs all table definitions
|
|
func (i *Installer) installTables(ctx context.Context) error {
|
|
i.logger.Info("installing tables")
|
|
|
|
files, err := sqlFS.ReadDir("sql/tables")
|
|
if err != nil {
|
|
return fmt.Errorf("failed to read tables directory: %w", err)
|
|
}
|
|
|
|
// Filter and sort SQL files
|
|
sqlFiles := filterAndSortSQLFiles(files)
|
|
|
|
for _, file := range sqlFiles {
|
|
// Skip install script
|
|
if file == "00_install.sql" {
|
|
continue
|
|
}
|
|
|
|
i.logger.Info("executing table script", "file", file)
|
|
|
|
content, err := sqlFS.ReadFile("sql/tables/" + file)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to read file %s: %w", file, err)
|
|
}
|
|
|
|
if err := i.executeSQL(ctx, string(content)); err != nil {
|
|
return fmt.Errorf("failed to execute %s: %w", file, err)
|
|
}
|
|
}
|
|
|
|
i.logger.Info("tables installed successfully")
|
|
return nil
|
|
}
|
|
|
|
// installProcedures installs all stored procedures
|
|
func (i *Installer) installProcedures(ctx context.Context) error {
|
|
i.logger.Info("installing procedures")
|
|
|
|
files, err := sqlFS.ReadDir("sql/procedures")
|
|
if err != nil {
|
|
return fmt.Errorf("failed to read procedures directory: %w", err)
|
|
}
|
|
|
|
// Filter and sort SQL files
|
|
sqlFiles := filterAndSortSQLFiles(files)
|
|
|
|
for _, file := range sqlFiles {
|
|
// Skip install script
|
|
if file == "00_install.sql" {
|
|
continue
|
|
}
|
|
|
|
i.logger.Info("executing procedure script", "file", file)
|
|
|
|
content, err := sqlFS.ReadFile("sql/procedures/" + file)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to read file %s: %w", file, err)
|
|
}
|
|
|
|
if err := i.executeSQL(ctx, string(content)); err != nil {
|
|
return fmt.Errorf("failed to execute %s: %w", file, err)
|
|
}
|
|
}
|
|
|
|
i.logger.Info("procedures installed successfully")
|
|
return nil
|
|
}
|
|
|
|
// executeSQL executes SQL statements
|
|
func (i *Installer) executeSQL(ctx context.Context, sql string) error {
|
|
// Remove comments and split by statement
|
|
statements := splitSQLStatements(sql)
|
|
|
|
for _, stmt := range statements {
|
|
stmt = strings.TrimSpace(stmt)
|
|
if stmt == "" {
|
|
continue
|
|
}
|
|
|
|
// Skip psql-specific commands
|
|
if strings.HasPrefix(stmt, "\\") {
|
|
continue
|
|
}
|
|
|
|
if _, err := i.db.Exec(ctx, stmt); err != nil {
|
|
return fmt.Errorf("failed to execute statement: %w\nStatement: %s", err, stmt)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// filterAndSortSQLFiles returns the names of the non-directory entries whose
// name ends in ".sql" (case-sensitive), sorted in ascending lexical order.
// Zero-padded prefixes such as "01_" therefore control execution order.
// It returns nil when no entry qualifies.
func filterAndSortSQLFiles(files []fs.DirEntry) []string {
	var names []string
	for _, entry := range files {
		if entry.IsDir() {
			continue
		}
		if name := entry.Name(); strings.HasSuffix(name, ".sql") {
			names = append(names, name)
		}
	}
	sort.Strings(names)
	return names
}
|
|
|
|
// splitSQLStatements splits a SQL script into individual statements. Each
// statement is returned with leading and trailing whitespace trimmed and is
// terminated by ';'.
//
// A ';' ends a statement only when an even number of "$$" markers has been
// seen since the statement began. Semicolons inside $$-quoted function bodies
// therefore stay part of the enclosing CREATE FUNCTION / DO statement.
//
// Statement text is sliced verbatim from the input, so line breaks inside a
// function body are preserved. In particular, a "-- comment;" line cannot
// swallow the code on the following line.
//
// Empty statements (";;" or whitespace only) are dropped. A trailing
// statement that does not already end in ';' gets one appended. Returns nil
// when the script contains no statements.
//
// Limitations: semicolons inside single-quoted strings or in comments outside
// function bodies still split a statement, and tagged dollar quotes
// ($body$ ... $body$) are not recognised.
func splitSQLStatements(sql string) []string {
	var result []string

	start := 0    // offset where the statement being accumulated begins
	scan := 0     // offset from which to look for the next ';'
	open := false // true while inside an unterminated $$ ... $$ body

	for {
		rel := strings.IndexByte(sql[scan:], ';')
		if rel < 0 {
			break
		}
		end := scan + rel

		// Segments are separated by ';', which can never be part of a "$$"
		// marker, so per-segment counts add up to the count for the whole
		// statement. Only the parity is needed.
		if strings.Count(sql[scan:end], "$$")%2 == 1 {
			open = !open
		}
		scan = end + 1

		if open {
			continue // this ';' belongs to a function body
		}

		if stmt := strings.TrimSpace(sql[start:end]); stmt != "" {
			result = append(result, stmt+";")
		}
		start = scan
	}

	// Whatever follows the last statement-ending ';' (possibly an unterminated
	// $$ body) is emitted as a final statement.
	if tail := strings.TrimSpace(sql[start:]); tail != "" {
		if !strings.HasSuffix(tail, ";") {
			tail += ";"
		}
		result = append(result, tail)
	}

	return result
}
|
|
|
|
// VerifyInstallation checks if the schema is properly installed
|
|
func (i *Installer) VerifyInstallation(ctx context.Context) error {
|
|
i.logger.Info("verifying installation")
|
|
|
|
tables := []string{"broker_queueinstance", "broker_jobs", "broker_schedule"}
|
|
procedures := []string{
|
|
"broker_get",
|
|
"broker_run",
|
|
"broker_set",
|
|
"broker_add_job",
|
|
"broker_register_instance",
|
|
"broker_ping_instance",
|
|
"broker_shutdown_instance",
|
|
}
|
|
|
|
// Check tables
|
|
for _, table := range tables {
|
|
var exists bool
|
|
err := i.db.QueryRow(ctx,
|
|
"SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = $1)",
|
|
table,
|
|
).Scan(&exists)
|
|
|
|
if err != nil {
|
|
return fmt.Errorf("failed to check table %s: %w", table, err)
|
|
}
|
|
|
|
if !exists {
|
|
return fmt.Errorf("table %s does not exist", table)
|
|
}
|
|
|
|
i.logger.Info("table verified", "table", table)
|
|
}
|
|
|
|
// Check procedures
|
|
for _, proc := range procedures {
|
|
var exists bool
|
|
err := i.db.QueryRow(ctx,
|
|
"SELECT EXISTS (SELECT FROM pg_proc WHERE proname = $1)",
|
|
proc,
|
|
).Scan(&exists)
|
|
|
|
if err != nil {
|
|
return fmt.Errorf("failed to check procedure %s: %w", proc, err)
|
|
}
|
|
|
|
if !exists {
|
|
return fmt.Errorf("procedure %s does not exist", proc)
|
|
}
|
|
|
|
i.logger.Info("procedure verified", "procedure", proc)
|
|
}
|
|
|
|
i.logger.Info("installation verified successfully")
|
|
return nil
|
|
}
|