diff --git a/cmd/testserver/main.go b/cmd/testserver/main.go index c455888..88c252e 100644 --- a/cmd/testserver/main.go +++ b/cmd/testserver/main.go @@ -8,7 +8,7 @@ import ( "time" "github.com/Warky-Devs/ResolveSpec/pkg/logger" - "github.com/Warky-Devs/ResolveSpec/pkg/models" + "github.com/Warky-Devs/ResolveSpec/pkg/modelregistry" "github.com/Warky-Devs/ResolveSpec/pkg/testmodels" "github.com/Warky-Devs/ResolveSpec/pkg/resolvespec" @@ -24,9 +24,6 @@ func main() { fmt.Println("ResolveSpec test server starting") logger.Init(true) - // Init Models - testmodels.RegisterTestModels() - // Initialize database db, err := initDB() if err != nil { @@ -37,24 +34,22 @@ func main() { // Create router r := mux.NewRouter() - // Initialize API handler - handler := resolvespec.NewAPIHandler(db) + // Initialize API handler using new API + handler := resolvespec.NewHandlerWithGORM(db) - // Setup routes - r.HandleFunc("/{schema}/{entity}", func(w http.ResponseWriter, r *http.Request) { - vars := mux.Vars(r) - handler.Handle(w, r, vars) - }).Methods("POST") + // Create a new registry instance and register models + registry := modelregistry.NewModelRegistry() + testmodels.RegisterTestModels(registry) - r.HandleFunc("/{schema}/{entity}/{id}", func(w http.ResponseWriter, r *http.Request) { - vars := mux.Vars(r) - handler.Handle(w, r, vars) - }).Methods("POST") + // Register models with handler + models := testmodels.GetTestModels() + modelNames := []string{"departments", "employees", "projects", "project_tasks", "documents", "comments"} + for i, model := range models { + handler.RegisterModel("public", modelNames[i], model) + } - r.HandleFunc("/{schema}/{entity}", func(w http.ResponseWriter, r *http.Request) { - vars := mux.Vars(r) - handler.HandleGet(w, r, vars) - }).Methods("GET") + // Setup routes using new SetupMuxRoutes function + resolvespec.SetupMuxRoutes(r, handler) // Start server logger.Info("Starting server on :8080") @@ -83,7 +78,7 @@ func initDB() (*gorm.DB, error) { return nil, err } - modelList := models.GetModels() + modelList := testmodels.GetTestModels() // Auto migrate schemas err = db.AutoMigrate(modelList...) diff --git a/pkg/resolvespec/bun_adapter.go b/pkg/common/adapters/database/bun.go similarity index 67% rename from pkg/resolvespec/bun_adapter.go rename to pkg/common/adapters/database/bun.go index 7506626..a26b675 100644 --- a/pkg/resolvespec/bun_adapter.go +++ b/pkg/common/adapters/database/bun.go @@ -1,10 +1,11 @@ -package resolvespec +package database import ( "context" "database/sql" "fmt" + "github.com/Warky-Devs/ResolveSpec/pkg/common" "github.com/uptrace/bun" ) @@ -19,23 +20,23 @@ func NewBunAdapter(db *bun.DB) *BunAdapter { return &BunAdapter{db: db} } -func (b *BunAdapter) NewSelect() SelectQuery { +func (b *BunAdapter) NewSelect() common.SelectQuery { return &BunSelectQuery{query: b.db.NewSelect()} } -func (b *BunAdapter) NewInsert() InsertQuery { +func (b *BunAdapter) NewInsert() common.InsertQuery { return &BunInsertQuery{query: b.db.NewInsert()} } -func (b *BunAdapter) NewUpdate() UpdateQuery { +func (b *BunAdapter) NewUpdate() common.UpdateQuery { return &BunUpdateQuery{query: b.db.NewUpdate()} } -func (b *BunAdapter) NewDelete() DeleteQuery { +func (b *BunAdapter) NewDelete() common.DeleteQuery { return &BunDeleteQuery{query: b.db.NewDelete()} } -func (b *BunAdapter) Exec(ctx context.Context, query string, args ...interface{}) (Result, error) { +func (b *BunAdapter) Exec(ctx context.Context, query string, args ...interface{}) (common.Result, error) { result, err := b.db.ExecContext(ctx, query, args...) return &BunResult{result: result}, err } @@ -44,7 +45,7 @@ func (b *BunAdapter) Query(ctx context.Context, dest interface{}, query string, return b.db.NewRaw(query, args...).Scan(ctx, dest) } -func (b *BunAdapter) BeginTx(ctx context.Context) (Database, error) { +func (b *BunAdapter) BeginTx(ctx context.Context) (common.Database, error) { tx, err := b.db.BeginTx(ctx, &sql.TxOptions{}) if err != nil { return nil, err @@ -60,14 +61,14 @@ func (b *BunAdapter) CommitTx(ctx context.Context) error { } func (b *BunAdapter) RollbackTx(ctx context.Context) error { - // For Bun, we need to handle this differently + // For Bun, we need to handle this differently // This is a simplified implementation return nil } -func (b *BunAdapter) RunInTransaction(ctx context.Context, fn func(Database) error) error { +func (b *BunAdapter) RunInTransaction(ctx context.Context, fn func(common.Database) error) error { return b.db.RunInTx(ctx, &sql.TxOptions{}, func(ctx context.Context, tx bun.Tx) error { - // Create adapter with transaction + // Create adapter with transaction adapter := &BunTxAdapter{tx: tx} return fn(adapter) }) @@ -78,62 +79,70 @@ type BunSelectQuery struct { query *bun.SelectQuery } -func (b *BunSelectQuery) Model(model interface{}) SelectQuery { +func (b *BunSelectQuery) Model(model interface{}) common.SelectQuery { b.query = b.query.Model(model) return b } -func (b *BunSelectQuery) Table(table string) SelectQuery { +func (b *BunSelectQuery) Table(table string) common.SelectQuery { b.query = b.query.Table(table) return b } -func (b *BunSelectQuery) Column(columns ...string) SelectQuery { +func (b *BunSelectQuery) Column(columns ...string) common.SelectQuery { b.query = b.query.Column(columns...) return b } -func (b *BunSelectQuery) Where(query string, args ...interface{}) SelectQuery { +func (b *BunSelectQuery) Where(query string, args ...interface{}) common.SelectQuery { b.query = b.query.Where(query, args...) return b } -func (b *BunSelectQuery) WhereOr(query string, args ...interface{}) SelectQuery { +func (b *BunSelectQuery) WhereOr(query string, args ...interface{}) common.SelectQuery { b.query = b.query.WhereOr(query, args...) return b } -func (b *BunSelectQuery) Join(query string, args ...interface{}) SelectQuery { +func (b *BunSelectQuery) Join(query string, args ...interface{}) common.SelectQuery { b.query = b.query.Join(query, args...) return b } -func (b *BunSelectQuery) LeftJoin(query string, args ...interface{}) SelectQuery { +func (b *BunSelectQuery) LeftJoin(query string, args ...interface{}) common.SelectQuery { b.query = b.query.Join("LEFT JOIN " + query, args...) return b } -func (b *BunSelectQuery) Order(order string) SelectQuery { +func (b *BunSelectQuery) Preload(relation string, conditions ...interface{}) common.SelectQuery { + // Bun uses Relation() method for preloading + // For now, we'll just pass the relation name without conditions + // TODO: Implement proper condition handling for Bun + b.query = b.query.Relation(relation) + return b +} + +func (b *BunSelectQuery) Order(order string) common.SelectQuery { b.query = b.query.Order(order) return b } -func (b *BunSelectQuery) Limit(n int) SelectQuery { +func (b *BunSelectQuery) Limit(n int) common.SelectQuery { b.query = b.query.Limit(n) return b } -func (b *BunSelectQuery) Offset(n int) SelectQuery { +func (b *BunSelectQuery) Offset(n int) common.SelectQuery { b.query = b.query.Offset(n) return b } -func (b *BunSelectQuery) Group(group string) SelectQuery { +func (b *BunSelectQuery) Group(group string) common.SelectQuery { b.query = b.query.Group(group) return b } -func (b *BunSelectQuery) Having(having string, args ...interface{}) SelectQuery { +func (b *BunSelectQuery) Having(having string, args ...interface{}) common.SelectQuery { b.query = b.query.Having(having, args...) return b } @@ -157,17 +166,17 @@ type BunInsertQuery struct { values map[string]interface{} } -func (b *BunInsertQuery) Model(model interface{}) InsertQuery { +func (b *BunInsertQuery) Model(model interface{}) common.InsertQuery { b.query = b.query.Model(model) return b } -func (b *BunInsertQuery) Table(table string) InsertQuery { +func (b *BunInsertQuery) Table(table string) common.InsertQuery { b.query = b.query.Table(table) return b } -func (b *BunInsertQuery) Value(column string, value interface{}) InsertQuery { +func (b *BunInsertQuery) Value(column string, value interface{}) common.InsertQuery { if b.values == nil { b.values = make(map[string]interface{}) } @@ -175,19 +184,19 @@ func (b *BunInsertQuery) Value(column string, value interface{}) InsertQuery { return b } -func (b *BunInsertQuery) OnConflict(action string) InsertQuery { +func (b *BunInsertQuery) OnConflict(action string) common.InsertQuery { b.query = b.query.On(action) return b } -func (b *BunInsertQuery) Returning(columns ...string) InsertQuery { +func (b *BunInsertQuery) Returning(columns ...string) common.InsertQuery { if len(columns) > 0 { b.query = b.query.Returning(columns[0]) } return b } -func (b *BunInsertQuery) Exec(ctx context.Context) (Result, error) { +func (b *BunInsertQuery) Exec(ctx context.Context) (common.Result, error) { if b.values != nil { // For Bun, we need to handle this differently for k, v := range b.values { @@ -203,41 +212,41 @@ type BunUpdateQuery struct { query *bun.UpdateQuery } -func (b *BunUpdateQuery) Model(model interface{}) UpdateQuery { +func (b *BunUpdateQuery) Model(model interface{}) common.UpdateQuery { b.query = b.query.Model(model) return b } -func (b *BunUpdateQuery) Table(table string) UpdateQuery { +func (b *BunUpdateQuery) Table(table string) common.UpdateQuery { b.query = b.query.Table(table) return b } -func (b *BunUpdateQuery) Set(column string, value interface{}) UpdateQuery { +func (b *BunUpdateQuery) Set(column string, value interface{}) common.UpdateQuery { b.query = b.query.Set(column+" = ?", value) return b } -func (b *BunUpdateQuery) SetMap(values map[string]interface{}) UpdateQuery { +func (b *BunUpdateQuery) SetMap(values map[string]interface{}) common.UpdateQuery { for column, value := range values { b.query = b.query.Set(column+" = ?", value) } return b } -func (b *BunUpdateQuery) Where(query string, args ...interface{}) UpdateQuery { +func (b *BunUpdateQuery) Where(query string, args ...interface{}) common.UpdateQuery { b.query = b.query.Where(query, args...) return b } -func (b *BunUpdateQuery) Returning(columns ...string) UpdateQuery { +func (b *BunUpdateQuery) Returning(columns ...string) common.UpdateQuery { if len(columns) > 0 { b.query = b.query.Returning(columns[0]) } return b } -func (b *BunUpdateQuery) Exec(ctx context.Context) (Result, error) { +func (b *BunUpdateQuery) Exec(ctx context.Context) (common.Result, error) { result, err := b.query.Exec(ctx) return &BunResult{result: result}, err } @@ -247,22 +256,22 @@ type BunDeleteQuery struct { query *bun.DeleteQuery } -func (b *BunDeleteQuery) Model(model interface{}) DeleteQuery { +func (b *BunDeleteQuery) Model(model interface{}) common.DeleteQuery { b.query = b.query.Model(model) return b } -func (b *BunDeleteQuery) Table(table string) DeleteQuery { +func (b *BunDeleteQuery) Table(table string) common.DeleteQuery { b.query = b.query.Table(table) return b } -func (b *BunDeleteQuery) Where(query string, args ...interface{}) DeleteQuery { +func (b *BunDeleteQuery) Where(query string, args ...interface{}) common.DeleteQuery { b.query = b.query.Where(query, args...) return b } -func (b *BunDeleteQuery) Exec(ctx context.Context) (Result, error) { +func (b *BunDeleteQuery) Exec(ctx context.Context) (common.Result, error) { result, err := b.query.Exec(ctx) return &BunResult{result: result}, err } @@ -292,23 +301,23 @@ type BunTxAdapter struct { tx bun.Tx } -func (b *BunTxAdapter) NewSelect() SelectQuery { +func (b *BunTxAdapter) NewSelect() common.SelectQuery { return &BunSelectQuery{query: b.tx.NewSelect()} } -func (b *BunTxAdapter) NewInsert() InsertQuery { +func (b *BunTxAdapter) NewInsert() common.InsertQuery { return &BunInsertQuery{query: b.tx.NewInsert()} } -func (b *BunTxAdapter) NewUpdate() UpdateQuery { +func (b *BunTxAdapter) NewUpdate() common.UpdateQuery { return &BunUpdateQuery{query: b.tx.NewUpdate()} } -func (b *BunTxAdapter) NewDelete() DeleteQuery { +func (b *BunTxAdapter) NewDelete() common.DeleteQuery { return &BunDeleteQuery{query: b.tx.NewDelete()} } -func (b *BunTxAdapter) Exec(ctx context.Context, query string, args ...interface{}) (Result, error) { +func (b *BunTxAdapter) Exec(ctx context.Context, query string, args ...interface{}) (common.Result, error) { result, err := b.tx.ExecContext(ctx, query, args...) return &BunResult{result: result}, err } @@ -317,7 +326,7 @@ func (b *BunTxAdapter) Query(ctx context.Context, dest interface{}, query string return b.tx.NewRaw(query, args...).Scan(ctx, dest) } -func (b *BunTxAdapter) BeginTx(ctx context.Context) (Database, error) { +func (b *BunTxAdapter) BeginTx(ctx context.Context) (common.Database, error) { return nil, fmt.Errorf("nested transactions not supported") } @@ -329,6 +338,6 @@ func (b *BunTxAdapter) RollbackTx(ctx context.Context) error { return b.tx.Rollback() } -func (b *BunTxAdapter) RunInTransaction(ctx context.Context, fn func(Database) error) error { +func (b *BunTxAdapter) RunInTransaction(ctx context.Context, fn func(common.Database) error) error { return fn(b) // Already in transaction -} \ No newline at end of file +} diff --git a/pkg/resolvespec/gorm_adapter.go b/pkg/common/adapters/database/gorm.go similarity index 69% rename from pkg/resolvespec/gorm_adapter.go rename to pkg/common/adapters/database/gorm.go index 643b2d1..26de9e9 100644 --- a/pkg/resolvespec/gorm_adapter.go +++ b/pkg/common/adapters/database/gorm.go @@ -1,7 +1,9 @@ -package resolvespec +package database import ( "context" + + "github.com/Warky-Devs/ResolveSpec/pkg/common" "gorm.io/gorm" ) @@ -15,23 +17,23 @@ func NewGormAdapter(db *gorm.DB) *GormAdapter { return &GormAdapter{db: db} } -func (g *GormAdapter) NewSelect() SelectQuery { +func (g *GormAdapter) NewSelect() common.SelectQuery { return &GormSelectQuery{db: g.db} } -func (g *GormAdapter) NewInsert() InsertQuery { +func (g *GormAdapter) NewInsert() common.InsertQuery { return &GormInsertQuery{db: g.db} } -func (g *GormAdapter) NewUpdate() UpdateQuery { +func (g *GormAdapter) NewUpdate() common.UpdateQuery { return &GormUpdateQuery{db: g.db} } -func (g *GormAdapter) NewDelete() DeleteQuery { +func (g *GormAdapter) NewDelete() common.DeleteQuery { return &GormDeleteQuery{db: g.db} } -func (g *GormAdapter) Exec(ctx context.Context, query string, args ...interface{}) (Result, error) { +func (g *GormAdapter) Exec(ctx context.Context, query string, args ...interface{}) (common.Result, error) { result := g.db.WithContext(ctx).Exec(query, args...) return &GormResult{result: result}, result.Error } @@ -40,7 +42,7 @@ func (g *GormAdapter) Query(ctx context.Context, dest interface{}, query string, return g.db.WithContext(ctx).Raw(query, args...).Find(dest).Error } -func (g *GormAdapter) BeginTx(ctx context.Context) (Database, error) { +func (g *GormAdapter) BeginTx(ctx context.Context) (common.Database, error) { tx := g.db.WithContext(ctx).Begin() if tx.Error != nil { return nil, tx.Error @@ -56,7 +58,7 @@ func (g *GormAdapter) RollbackTx(ctx context.Context) error { return g.db.WithContext(ctx).Rollback().Error } -func (g *GormAdapter) RunInTransaction(ctx context.Context, fn func(Database) error) error { +func (g *GormAdapter) RunInTransaction(ctx context.Context, fn func(common.Database) error) error { return g.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { adapter := &GormAdapter{db: tx} return fn(adapter) @@ -68,62 +70,67 @@ type GormSelectQuery struct { db *gorm.DB } -func (g *GormSelectQuery) Model(model interface{}) SelectQuery { +func (g *GormSelectQuery) Model(model interface{}) common.SelectQuery { g.db = g.db.Model(model) return g } -func (g *GormSelectQuery) Table(table string) SelectQuery { +func (g *GormSelectQuery) Table(table string) common.SelectQuery { g.db = g.db.Table(table) return g } -func (g *GormSelectQuery) Column(columns ...string) SelectQuery { +func (g *GormSelectQuery) Column(columns ...string) common.SelectQuery { g.db = g.db.Select(columns) return g } -func (g *GormSelectQuery) Where(query string, args ...interface{}) SelectQuery { +func (g *GormSelectQuery) Where(query string, args ...interface{}) common.SelectQuery { g.db = g.db.Where(query, args...) return g } -func (g *GormSelectQuery) WhereOr(query string, args ...interface{}) SelectQuery { +func (g *GormSelectQuery) WhereOr(query string, args ...interface{}) common.SelectQuery { g.db = g.db.Or(query, args...) return g } -func (g *GormSelectQuery) Join(query string, args ...interface{}) SelectQuery { +func (g *GormSelectQuery) Join(query string, args ...interface{}) common.SelectQuery { g.db = g.db.Joins(query, args...) return g } -func (g *GormSelectQuery) LeftJoin(query string, args ...interface{}) SelectQuery { +func (g *GormSelectQuery) LeftJoin(query string, args ...interface{}) common.SelectQuery { g.db = g.db.Joins("LEFT JOIN "+query, args...) return g } -func (g *GormSelectQuery) Order(order string) SelectQuery { +func (g *GormSelectQuery) Preload(relation string, conditions ...interface{}) common.SelectQuery { + g.db = g.db.Preload(relation, conditions...) + return g +} + +func (g *GormSelectQuery) Order(order string) common.SelectQuery { g.db = g.db.Order(order) return g } -func (g *GormSelectQuery) Limit(n int) SelectQuery { +func (g *GormSelectQuery) Limit(n int) common.SelectQuery { g.db = g.db.Limit(n) return g } -func (g *GormSelectQuery) Offset(n int) SelectQuery { +func (g *GormSelectQuery) Offset(n int) common.SelectQuery { g.db = g.db.Offset(n) return g } -func (g *GormSelectQuery) Group(group string) SelectQuery { +func (g *GormSelectQuery) Group(group string) common.SelectQuery { g.db = g.db.Group(group) return g } -func (g *GormSelectQuery) Having(having string, args ...interface{}) SelectQuery { +func (g *GormSelectQuery) Having(having string, args ...interface{}) common.SelectQuery { g.db = g.db.Having(having, args...) return g } @@ -146,23 +153,23 @@ func (g *GormSelectQuery) Exists(ctx context.Context) (bool, error) { // GormInsertQuery implements InsertQuery for GORM type GormInsertQuery struct { - db *gorm.DB - model interface{} + db *gorm.DB + model interface{} values map[string]interface{} } -func (g *GormInsertQuery) Model(model interface{}) InsertQuery { +func (g *GormInsertQuery) Model(model interface{}) common.InsertQuery { g.model = model g.db = g.db.Model(model) return g } -func (g *GormInsertQuery) Table(table string) InsertQuery { +func (g *GormInsertQuery) Table(table string) common.InsertQuery { g.db = g.db.Table(table) return g } -func (g *GormInsertQuery) Value(column string, value interface{}) InsertQuery { +func (g *GormInsertQuery) Value(column string, value interface{}) common.InsertQuery { if g.values == nil { g.values = make(map[string]interface{}) } @@ -170,17 +177,17 @@ func (g *GormInsertQuery) Value(column string, value interface{}) InsertQuery { return g } -func (g *GormInsertQuery) OnConflict(action string) InsertQuery { +func (g *GormInsertQuery) OnConflict(action string) common.InsertQuery { // GORM handles conflicts differently, this would need specific implementation return g } -func (g *GormInsertQuery) Returning(columns ...string) InsertQuery { +func (g *GormInsertQuery) Returning(columns ...string) common.InsertQuery { // GORM doesn't have explicit RETURNING, but updates the model return g } -func (g *GormInsertQuery) Exec(ctx context.Context) (Result, error) { +func (g *GormInsertQuery) Exec(ctx context.Context) (common.Result, error) { var result *gorm.DB if g.model != nil { result = g.db.WithContext(ctx).Create(g.model) @@ -194,23 +201,23 @@ func (g *GormInsertQuery) Exec(ctx context.Context) (Result, error) { // GormUpdateQuery implements UpdateQuery for GORM type GormUpdateQuery struct { - db *gorm.DB - model interface{} + db *gorm.DB + model interface{} updates interface{} } -func (g *GormUpdateQuery) Model(model interface{}) UpdateQuery { +func (g *GormUpdateQuery) Model(model interface{}) common.UpdateQuery { g.model = model g.db = g.db.Model(model) return g } -func (g *GormUpdateQuery) Table(table string) UpdateQuery { +func (g *GormUpdateQuery) Table(table string) common.UpdateQuery { g.db = g.db.Table(table) return g } -func (g *GormUpdateQuery) Set(column string, value interface{}) UpdateQuery { +func (g *GormUpdateQuery) Set(column string, value interface{}) common.UpdateQuery { if g.updates == nil { g.updates = make(map[string]interface{}) } @@ -220,22 +227,22 @@ func (g *GormUpdateQuery) Set(column string, value interface{}) UpdateQuery { return g } -func (g *GormUpdateQuery) SetMap(values map[string]interface{}) UpdateQuery { +func (g *GormUpdateQuery) SetMap(values map[string]interface{}) common.UpdateQuery { g.updates = values return g } -func (g *GormUpdateQuery) Where(query string, args ...interface{}) UpdateQuery { +func (g *GormUpdateQuery) Where(query string, args ...interface{}) common.UpdateQuery { g.db = g.db.Where(query, args...) return g } -func (g *GormUpdateQuery) Returning(columns ...string) UpdateQuery { +func (g *GormUpdateQuery) Returning(columns ...string) common.UpdateQuery { // GORM doesn't have explicit RETURNING return g } -func (g *GormUpdateQuery) Exec(ctx context.Context) (Result, error) { +func (g *GormUpdateQuery) Exec(ctx context.Context) (common.Result, error) { result := g.db.WithContext(ctx).Updates(g.updates) return &GormResult{result: result}, result.Error } @@ -246,23 +253,23 @@ type GormDeleteQuery struct { model interface{} } -func (g *GormDeleteQuery) Model(model interface{}) DeleteQuery { +func (g *GormDeleteQuery) Model(model interface{}) common.DeleteQuery { g.model = model g.db = g.db.Model(model) return g } -func (g *GormDeleteQuery) Table(table string) DeleteQuery { +func (g *GormDeleteQuery) Table(table string) common.DeleteQuery { g.db = g.db.Table(table) return g } -func (g *GormDeleteQuery) Where(query string, args ...interface{}) DeleteQuery { +func (g *GormDeleteQuery) Where(query string, args ...interface{}) common.DeleteQuery { g.db = g.db.Where(query, args...) return g } -func (g *GormDeleteQuery) Exec(ctx context.Context) (Result, error) { +func (g *GormDeleteQuery) Exec(ctx context.Context) (common.Result, error) { result := g.db.WithContext(ctx).Delete(g.model) return &GormResult{result: result}, result.Error } @@ -279,4 +286,4 @@ func (g *GormResult) RowsAffected() int64 { func (g *GormResult) LastInsertId() (int64, error) { // GORM doesn't directly provide last insert ID, would need specific implementation return 0, nil -} \ No newline at end of file +} diff --git a/pkg/resolvespec/bunrouter_adapter.go b/pkg/common/adapters/router/bunrouter.go similarity index 67% rename from pkg/resolvespec/bunrouter_adapter.go rename to pkg/common/adapters/router/bunrouter.go index 183c2d6..abe21e8 100644 --- a/pkg/resolvespec/bunrouter_adapter.go +++ b/pkg/common/adapters/router/bunrouter.go @@ -1,8 +1,9 @@ -package resolvespec +package router import ( "net/http" + "github.com/Warky-Devs/ResolveSpec/pkg/common" "github.com/uptrace/bunrouter" ) @@ -21,7 +22,7 @@ func NewBunRouterAdapterDefault() *BunRouterAdapter { return &BunRouterAdapter{router: bunrouter.New()} } -func (b *BunRouterAdapter) HandleFunc(pattern string, handler HTTPHandlerFunc) RouteRegistration { +func (b *BunRouterAdapter) HandleFunc(pattern string, handler common.HTTPHandlerFunc) common.RouteRegistration { route := &BunRouterRegistration{ router: b.router, pattern: pattern, @@ -30,7 +31,7 @@ func (b *BunRouterAdapter) HandleFunc(pattern string, handler HTTPHandlerFunc) R return route } -func (b *BunRouterAdapter) ServeHTTP(w ResponseWriter, r Request) { +func (b *BunRouterAdapter) ServeHTTP(w common.ResponseWriter, r common.Request) { // This method would be used when we need to serve through our interface // For now, we'll work directly with the underlying router panic("ServeHTTP not implemented - use GetBunRouter() for direct access") @@ -45,16 +46,16 @@ func (b *BunRouterAdapter) GetBunRouter() *bunrouter.Router { type BunRouterRegistration struct { router *bunrouter.Router pattern string - handler HTTPHandlerFunc + handler common.HTTPHandlerFunc } -func (b *BunRouterRegistration) Methods(methods ...string) RouteRegistration { +func (b *BunRouterRegistration) Methods(methods ...string) common.RouteRegistration { // bunrouter handles methods differently - we'll register for each method for _, method := range methods { b.router.Handle(method, b.pattern, func(w http.ResponseWriter, req bunrouter.Request) error { - // Convert bunrouter.Request to our HTTPRequest + // Convert bunrouter.Request to our BunRouterRequest reqAdapter := &BunRouterRequest{req: req} - respAdapter := NewHTTPResponseWriter(w) + respAdapter := &HTTPResponseWriter{resp: w} b.handler(respAdapter, reqAdapter) return nil }) @@ -62,7 +63,7 @@ func (b *BunRouterRegistration) Methods(methods ...string) RouteRegistration { return b } -func (b *BunRouterRegistration) PathPrefix(prefix string) RouteRegistration { +func (b *BunRouterRegistration) PathPrefix(prefix string) common.RouteRegistration { // bunrouter doesn't have PathPrefix like mux, but we can modify the pattern newPattern := prefix + b.pattern b.pattern = newPattern @@ -75,6 +76,11 @@ type BunRouterRequest struct { body []byte } +// NewBunRouterRequest creates a new BunRouterRequest adapter +func NewBunRouterRequest(req bunrouter.Request) *BunRouterRequest { + return &BunRouterRequest{req: req} +} + func (b *BunRouterRequest) Method() string { return b.req.Method } @@ -91,11 +97,11 @@ func (b *BunRouterRequest) Body() ([]byte, error) { if b.body != nil { return b.body, nil } - + if b.req.Body == nil { return nil, nil } - + // Create HTTPRequest adapter and use its Body() method httpAdapter := NewHTTPRequest(b.req.Request) body, err := httpAdapter.Body() @@ -114,6 +120,16 @@ func (b *BunRouterRequest) QueryParam(key string) string { return b.req.URL.Query().Get(key) } +func (b *BunRouterRequest) AllHeaders() map[string]string { + headers := make(map[string]string) + for key, values := range b.req.Header { + if len(values) > 0 { + headers[key] = values[0] + } + } + return headers +} + // StandardBunRouterAdapter creates routes compatible with standard bunrouter handlers type StandardBunRouterAdapter struct { *BunRouterAdapter @@ -125,16 +141,16 @@ func NewStandardBunRouterAdapter() *StandardBunRouterAdapter { } } -// RegisterRoute registers a route that works with the existing APIHandler +// RegisterRoute registers a route that works with the existing Handler func (s *StandardBunRouterAdapter) RegisterRoute(method, pattern string, handler func(http.ResponseWriter, *http.Request, map[string]string)) { s.router.Handle(method, pattern, func(w http.ResponseWriter, req bunrouter.Request) error { // Extract path parameters params := make(map[string]string) - + // bunrouter doesn't provide a direct way to get all params // You would typically access them individually with req.Param("name") // For this example, we'll create the map based on the request context - + handler(w, req.Request, params) return nil }) @@ -148,7 +164,7 @@ func (s *StandardBunRouterAdapter) RegisterRouteWithParams(method, pattern strin for _, paramName := range paramNames { params[paramName] = req.Param(paramName) } - + handler(w, req.Request, params) return nil }) @@ -156,63 +172,22 @@ func (s *StandardBunRouterAdapter) RegisterRouteWithParams(method, pattern strin // BunRouterConfig holds bunrouter-specific configuration type BunRouterConfig struct { - UseStrictSlash bool - RedirectTrailingSlash bool + UseStrictSlash bool + RedirectTrailingSlash bool HandleMethodNotAllowed bool - HandleOPTIONS bool - GlobalOPTIONS http.Handler + HandleOPTIONS bool + GlobalOPTIONS http.Handler GlobalMethodNotAllowed http.Handler - PanicHandler func(http.ResponseWriter, *http.Request, interface{}) + PanicHandler func(http.ResponseWriter, *http.Request, interface{}) } // DefaultBunRouterConfig returns default bunrouter configuration func DefaultBunRouterConfig() *BunRouterConfig { return &BunRouterConfig{ - UseStrictSlash: false, - RedirectTrailingSlash: true, + UseStrictSlash: false, + RedirectTrailingSlash: true, HandleMethodNotAllowed: true, - HandleOPTIONS: true, + HandleOPTIONS: true, } } -// SetupBunRouterWithResolveSpec sets up bunrouter routes for ResolveSpec -func SetupBunRouterWithResolveSpec(router *bunrouter.Router, handler *APIHandlerCompat) { - // Setup standard ResolveSpec routes with bunrouter - router.Handle("POST", "/:schema/:entity", func(w http.ResponseWriter, req bunrouter.Request) error { - params := map[string]string{ - "schema": req.Param("schema"), - "entity": req.Param("entity"), - } - handler.Handle(w, req.Request, params) - return nil - }) - - router.Handle("POST", "/:schema/:entity/:id", func(w http.ResponseWriter, req bunrouter.Request) error { - params := map[string]string{ - "schema": req.Param("schema"), - "entity": req.Param("entity"), - "id": req.Param("id"), - } - handler.Handle(w, req.Request, params) - return nil - }) - - router.Handle("GET", "/:schema/:entity", func(w http.ResponseWriter, req bunrouter.Request) error { - params := map[string]string{ - "schema": req.Param("schema"), - "entity": req.Param("entity"), - } - handler.HandleGet(w, req.Request, params) - return nil - }) - - router.Handle("GET", "/:schema/:entity/:id", func(w http.ResponseWriter, req bunrouter.Request) error { - params := map[string]string{ - "schema": req.Param("schema"), - "entity": req.Param("entity"), - "id": req.Param("id"), - } - handler.HandleGet(w, req.Request, params) - return nil - }) -} \ No newline at end of file diff --git a/pkg/resolvespec/router_adapters.go b/pkg/common/adapters/router/mux.go similarity index 87% rename from pkg/resolvespec/router_adapters.go rename to pkg/common/adapters/router/mux.go index b66eb4e..8d0d33f 100644 --- a/pkg/resolvespec/router_adapters.go +++ b/pkg/common/adapters/router/mux.go @@ -1,10 +1,11 @@ -package resolvespec +package router import ( "encoding/json" "io" "net/http" + "github.com/Warky-Devs/ResolveSpec/pkg/common" "github.com/gorilla/mux" ) @@ -18,7 +19,7 @@ func NewMuxAdapter(router *mux.Router) *MuxAdapter { return &MuxAdapter{router: router} } -func (m *MuxAdapter) HandleFunc(pattern string, handler HTTPHandlerFunc) RouteRegistration { +func (m *MuxAdapter) HandleFunc(pattern string, handler common.HTTPHandlerFunc) common.RouteRegistration { route := &MuxRouteRegistration{ router: m.router, pattern: pattern, @@ -27,7 +28,7 @@ func (m *MuxAdapter) HandleFunc(pattern string, handler HTTPHandlerFunc) RouteRe return route } -func (m *MuxAdapter) ServeHTTP(w ResponseWriter, r Request) { +func (m *MuxAdapter) ServeHTTP(w common.ResponseWriter, r common.Request) { // This method would be used when we need to serve through our interface // For now, we'll work directly with the underlying router panic("ServeHTTP not implemented - use GetMuxRouter() for direct access") @@ -37,11 +38,11 @@ func (m *MuxAdapter) ServeHTTP(w ResponseWriter, r Request) { type MuxRouteRegistration struct { router *mux.Router pattern string - handler HTTPHandlerFunc + handler common.HTTPHandlerFunc route *mux.Route } -func (m *MuxRouteRegistration) Methods(methods ...string) RouteRegistration { +func (m *MuxRouteRegistration) Methods(methods ...string) common.RouteRegistration { if m.route == nil { m.route = m.router.HandleFunc(m.pattern, func(w http.ResponseWriter, r *http.Request) { reqAdapter := &HTTPRequest{req: r, vars: mux.Vars(r)} @@ -53,7 +54,7 @@ func (m *MuxRouteRegistration) Methods(methods ...string) RouteRegistration { return m } -func (m *MuxRouteRegistration) PathPrefix(prefix string) RouteRegistration { +func (m *MuxRouteRegistration) PathPrefix(prefix string) common.RouteRegistration { if m.route == nil { m.route = m.router.HandleFunc(m.pattern, func(w http.ResponseWriter, r *http.Request) { reqAdapter := &HTTPRequest{req: r, vars: mux.Vars(r)} @@ -115,10 +116,20 @@ func (h *HTTPRequest) QueryParam(key string) string { return h.req.URL.Query().Get(key) } +func (h *HTTPRequest) AllHeaders() map[string]string { + headers := make(map[string]string) + for key, values := range h.req.Header { + if len(values) > 0 { + headers[key] = values[0] + } + } + return headers +} + // HTTPResponseWriter adapts our ResponseWriter interface to standard http.ResponseWriter type HTTPResponseWriter struct { resp http.ResponseWriter - w ResponseWriter + w common.ResponseWriter status int } @@ -126,7 +137,6 @@ func NewHTTPResponseWriter(w http.ResponseWriter) *HTTPResponseWriter { return &HTTPResponseWriter{resp: w} } - func (h *HTTPResponseWriter) SetHeader(key, value string) { h.resp.Header().Set(key, value) } @@ -156,7 +166,7 @@ func NewStandardMuxAdapter() *StandardMuxAdapter { } } -// RegisterRoute registers a route that works with the existing APIHandler +// RegisterRoute registers a route that works with the existing Handler func (s *StandardMuxAdapter) RegisterRoute(pattern string, handler func(http.ResponseWriter, *http.Request, map[string]string)) *mux.Route { return s.router.HandleFunc(pattern, func(w http.ResponseWriter, r *http.Request) { vars := mux.Vars(r) @@ -169,18 +179,6 @@ func (s *StandardMuxAdapter) GetMuxRouter() *mux.Router { return s.router } -// GinAdapter for future Gin support -type GinAdapter struct { - // This would be implemented when Gin support is needed - // engine *gin.Engine -} - -// EchoAdapter for future Echo support -type EchoAdapter struct { - // This would be implemented when Echo support is needed - // echo *echo.Echo -} - // PathParamExtractor extracts path parameters from different router types type PathParamExtractor interface { ExtractParams(*http.Request) map[string]string @@ -207,4 +205,4 @@ func DefaultRouterConfig() *RouterConfig { Middleware: make([]func(http.Handler) http.Handler, 0), ParamExtractor: MuxParamExtractor{}, } -} \ No newline at end of file +} diff --git a/pkg/resolvespec/database.go b/pkg/common/interfaces.go similarity index 94% rename from pkg/resolvespec/database.go rename to pkg/common/interfaces.go index 4f1525d..bc3bb1d 100644 --- a/pkg/resolvespec/database.go +++ b/pkg/common/interfaces.go @@ -1,4 +1,4 @@ -package resolvespec +package common import "context" @@ -9,11 +9,11 @@ type Database interface { NewInsert() InsertQuery NewUpdate() UpdateQuery NewDelete() DeleteQuery - + // Raw SQL execution Exec(ctx context.Context, query string, args ...interface{}) (Result, error) Query(ctx context.Context, dest interface{}, query string, args ...interface{}) error - + // Transaction support BeginTx(ctx context.Context) (Database, error) CommitTx(ctx context.Context) error @@ -30,12 +30,13 @@ type SelectQuery interface { WhereOr(query string, args ...interface{}) SelectQuery Join(query string, args ...interface{}) SelectQuery LeftJoin(query string, args ...interface{}) SelectQuery + Preload(relation string, conditions ...interface{}) SelectQuery Order(order string) SelectQuery Limit(n int) SelectQuery Offset(n int) SelectQuery Group(group string) SelectQuery Having(having string, args ...interface{}) SelectQuery - + // Execution methods Scan(ctx context.Context, dest interface{}) error Count(ctx context.Context) (int, error) @@ -49,7 +50,7 @@ type InsertQuery interface { Value(column string, value interface{}) InsertQuery OnConflict(action string) InsertQuery Returning(columns ...string) InsertQuery - + // Execution Exec(ctx context.Context) (Result, error) } @@ -62,7 +63,7 @@ type UpdateQuery interface { SetMap(values map[string]interface{}) UpdateQuery Where(query string, args ...interface{}) UpdateQuery Returning(columns ...string) UpdateQuery - + // Execution Exec(ctx context.Context) (Result, error) } @@ -72,7 +73,7 @@ type DeleteQuery interface { Model(model interface{}) DeleteQuery Table(table string) DeleteQuery Where(query string, args ...interface{}) DeleteQuery - + // Execution Exec(ctx context.Context) (Result, error) } @@ -94,7 +95,7 @@ type ModelRegistry interface { // Router interface for HTTP router abstraction type Router interface { HandleFunc(pattern string, handler HTTPHandlerFunc) RouteRegistration - ServeHTTP(w ResponseWriter, r *Request) + ServeHTTP(w ResponseWriter, r Request) } // RouteRegistration allows method chaining for route configuration @@ -108,6 +109,7 @@ type Request interface { Method() string URL() string Header(key string) string + AllHeaders() map[string]string // Get all headers as a map Body() ([]byte, error) PathParam(key string) string QueryParam(key string) string @@ -121,7 +123,7 @@ type ResponseWriter interface { WriteJSON(data interface{}) error } -// HTTPHandlerFunc type for HTTP handlers +// HTTPHandlerFunc type for HTTP handlers type HTTPHandlerFunc func(ResponseWriter, Request) // TableNameProvider interface for models that provide table names @@ -129,7 +131,7 @@ type TableNameProvider interface { TableName() string } -// SchemaProvider interface for models that provide schema names +// SchemaProvider interface for models that provide schema names type SchemaProvider interface { SchemaName() string -} \ No newline at end of file +} diff --git a/pkg/resolvespec/types.go b/pkg/common/types.go similarity index 99% rename from pkg/resolvespec/types.go rename to pkg/common/types.go index 50c4a79..56e6f60 100644 --- a/pkg/resolvespec/types.go +++ b/pkg/common/types.go @@ -1,4 +1,4 @@ -package resolvespec +package common type RequestBody struct { Operation string `json:"operation"` diff --git a/pkg/resolvespec/model_registry.go b/pkg/modelregistry/model_registry.go similarity index 54% rename from pkg/resolvespec/model_registry.go rename to pkg/modelregistry/model_registry.go index 9731c88..d30eb25 100644 --- a/pkg/resolvespec/model_registry.go +++ b/pkg/modelregistry/model_registry.go @@ -1,4 +1,4 @@ -package resolvespec +package modelregistry import ( "fmt" @@ -11,6 +11,11 @@ type DefaultModelRegistry struct { mutex sync.RWMutex } +// Global default registry instance +var defaultRegistry = &DefaultModelRegistry{ + models: make(map[string]interface{}), +} + // NewModelRegistry creates a new model registry func NewModelRegistry() *DefaultModelRegistry { return &DefaultModelRegistry{ @@ -59,7 +64,41 @@ func (r *DefaultModelRegistry) GetModelByEntity(schema, entity string) (interfac if model, err := r.GetModel(fullName); err == nil { return model, nil } - + // Fallback to entity name only return r.GetModel(entity) +} + +// Global convenience functions using the default registry + +// RegisterModel registers a model with the default global registry +func RegisterModel(model interface{}, name string) error { + return defaultRegistry.RegisterModel(name, model) +} + +// GetModelByName retrieves a model from the default global registry by name +func GetModelByName(name string) (interface{}, error) { + return defaultRegistry.GetModel(name) +} + +// IterateModels iterates over all models in the default global registry +func IterateModels(fn func(name string, model interface{})) { + defaultRegistry.mutex.RLock() + defer defaultRegistry.mutex.RUnlock() + + for name, model := range defaultRegistry.models { + fn(name, model) + } +} + +// GetModels returns a list of all models in the default global registry +func GetModels() []interface{} { + defaultRegistry.mutex.RLock() + defer defaultRegistry.mutex.RUnlock() + + models := make([]interface{}, 0, len(defaultRegistry.models)) + for _, model := range defaultRegistry.models { + models = append(models, model) + } + return models } \ No newline at end of file diff --git a/pkg/models/registry.go b/pkg/models/registry.go deleted file mode 100644 index 67927c6..0000000 --- a/pkg/models/registry.go +++ /dev/null @@ -1,71 +0,0 @@ -package models - -import ( - "fmt" - "reflect" - "sync" -) - -var ( - modelRegistry = make(map[string]interface{}) - functionRegistry = make(map[string]interface{}) - modelRegistryMutex sync.RWMutex - funcRegistryMutex sync.RWMutex -) - -// RegisterModel registers a model type with the registry -// The model must be a struct or a pointer to a struct -// e.g RegisterModel(&ModelPublicUser{},"public.user") -func RegisterModel(model interface{}, name string) error { - modelRegistryMutex.Lock() - defer modelRegistryMutex.Unlock() - - modelType := reflect.TypeOf(model) - if modelType.Kind() == reflect.Ptr { - modelType = modelType.Elem() - } - if name == "" { - name = modelType.Name() - } - modelRegistry[name] = model - return nil -} - -// RegisterFunction register a function with the registry -func RegisterFunction(fn interface{}, name string) { - funcRegistryMutex.Lock() - defer funcRegistryMutex.Unlock() - functionRegistry[name] = fn -} - -// GetModelByName retrieves a model from the registry by its type name -func GetModelByName(name string) (interface{}, error) { - modelRegistryMutex.RLock() - defer modelRegistryMutex.RUnlock() - - if modelRegistry[name] == nil { - return nil, fmt.Errorf("model not found: %s", name) - } - return modelRegistry[name], nil -} - -// IterateModels iterates over all models in the registry -func IterateModels(fn func(name string, model interface{})) { - modelRegistryMutex.RLock() - defer modelRegistryMutex.RUnlock() - - for name, model := range modelRegistry { - fn(name, model) - } -} - -// GetModels returns a list of all models in the registry -func GetModels() []interface{} { - models := make([]interface{}, 0) - modelRegistryMutex.RLock() - defer modelRegistryMutex.RUnlock() - for _, model := range modelRegistry { - models = append(models, model) - } - return models -} diff --git a/pkg/resolvespec/apiHandler.go b/pkg/resolvespec/apiHandler.go deleted file mode 100644 index cfd3941..0000000 --- a/pkg/resolvespec/apiHandler.go +++ /dev/null @@ -1,91 +0,0 @@ -package resolvespec - -import ( - "encoding/json" - "fmt" - "io" - "net/http" - - "github.com/Warky-Devs/ResolveSpec/pkg/logger" - "gorm.io/gorm" -) - -type HandlerFunc func(http.ResponseWriter, *http.Request) - -type LegacyAPIHandler struct { - db *gorm.DB -} - -// NewLegacyAPIHandler creates a new legacy API handler instance -func NewLegacyAPIHandler(db *gorm.DB) *LegacyAPIHandler { - return &LegacyAPIHandler{ - db: db, - } -} - -// Main handler method -func (h *LegacyAPIHandler) Handle(w http.ResponseWriter, r *http.Request, params map[string]string) { - var req RequestBody - - if r.Body == nil { - logger.Error("No body to decode") - h.sendError(w, http.StatusBadRequest, "invalid_request", "No body to decode", nil) - return - } else { - defer r.Body.Close() - } - if bodyContents, err := io.ReadAll(r.Body); err != nil { - logger.Error("Failed to decode read body: %v", err) - h.sendError(w, http.StatusBadRequest, "read_request", "Invalid request body", err) - return - } else { - if err := json.Unmarshal(bodyContents, &req); err != nil { - logger.Error("Failed to decode request body: %v", err) - h.sendError(w, http.StatusBadRequest, "invalid_request", "Invalid request body", err) - return - } - } - - schema := params["schema"] - entity := params["entity"] - id := params["id"] - - logger.Info("Handling %s operation for %s.%s", req.Operation, schema, entity) - - switch req.Operation { - case "read": - h.handleRead(w, r, schema, entity, id, req.Options) - case "create": - h.handleCreate(w, r, schema, entity, req.Data, req.Options) - case "update": - h.handleUpdate(w, r, schema, entity, id, req.ID, req.Data, req.Options) - case "delete": - h.handleDelete(w, r, schema, entity, id) - default: - logger.Error("Invalid operation: %s", req.Operation) - h.sendError(w, http.StatusBadRequest, "invalid_operation", "Invalid operation", nil) - } -} - -func (h *LegacyAPIHandler) sendResponse(w http.ResponseWriter, data interface{}, metadata *Metadata) { - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(Response{ - Success: true, - Data: data, - Metadata: metadata, - }) -} - -func (h *LegacyAPIHandler) sendError(w http.ResponseWriter, status int, code, message string, details interface{}) { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(status) - json.NewEncoder(w).Encode(Response{ - Success: false, - Error: &APIError{ - Code: code, - Message: message, - Details: details, - Detail: fmt.Sprintf("%v", details), - }, - }) -} diff --git a/pkg/resolvespec/compatibility.go b/pkg/resolvespec/compatibility.go deleted file mode 100644 index 0893282..0000000 --- a/pkg/resolvespec/compatibility.go +++ /dev/null @@ -1,72 +0,0 @@ -package resolvespec - -import ( - "net/http" - - "github.com/Warky-Devs/ResolveSpec/pkg/models" - "gorm.io/gorm" -) - -// NewAPIHandler creates a new APIHandler instance (backward compatibility) -// For now, this returns the legacy APIHandler to maintain full compatibility -// including preloading functionality. Users can opt-in to new abstractions when ready. -func NewAPIHandler(db *gorm.DB) *APIHandlerCompat { - legacyHandler := NewLegacyAPIHandler(db) - - // Initialize new abstractions for future use - gormAdapter := NewGormAdapter(db) - registry := NewModelRegistry() - - // Initialize registry with existing models - models.IterateModels(func(name string, model interface{}) { - registry.RegisterModel(name, model) - }) - - newHandler := NewHandler(gormAdapter, registry) - - return &APIHandlerCompat{ - legacyHandler: legacyHandler, - newHandler: newHandler, - db: db, - } -} - -// APIHandlerCompat provides backward compatibility with the original APIHandler -type APIHandlerCompat struct { - legacyHandler *LegacyAPIHandler // For full backward compatibility - newHandler *Handler // New abstracted handler (optional use) - db *gorm.DB // Legacy GORM reference -} - -// Handle maintains the original signature for backward compatibility -func (a *APIHandlerCompat) Handle(w http.ResponseWriter, r *http.Request, params map[string]string) { - // Use legacy handler to maintain full compatibility including preloading - a.legacyHandler.Handle(w, r, params) -} - -// HandleGet maintains the original signature for backward compatibility -func (a *APIHandlerCompat) HandleGet(w http.ResponseWriter, r *http.Request, params map[string]string) { - // Use legacy handler for metadata - a.legacyHandler.HandleGet(w, r, params) -} - -// RegisterModel maintains the original signature for backward compatibility -func (a *APIHandlerCompat) RegisterModel(schema, name string, model interface{}) error { - // Register with both legacy handler and new handler - err1 := a.legacyHandler.RegisterModel(schema, name, model) - err2 := a.newHandler.RegisterModel(schema, name, model) - if err1 != nil { - return err1 - } - return err2 -} - -// GetNewHandler returns the new abstracted handler for advanced use cases -func (a *APIHandlerCompat) GetNewHandler() *Handler { - return a.newHandler -} - -// GetLegacyHandler returns the legacy handler for cases needing full GORM features -func (a *APIHandlerCompat) GetLegacyHandler() *LegacyAPIHandler { - return a.legacyHandler -} \ No newline at end of file diff --git a/pkg/resolvespec/crud.go b/pkg/resolvespec/crud.go deleted file mode 100644 index 686f73a..0000000 --- a/pkg/resolvespec/crud.go +++ /dev/null @@ -1,250 +0,0 @@ -package resolvespec - -import ( - "fmt" - "net/http" - "reflect" - "strings" - - "github.com/Warky-Devs/ResolveSpec/pkg/logger" - "gorm.io/gorm" -) - -// Read handler -func (h *LegacyAPIHandler) handleRead(w http.ResponseWriter, r *http.Request, schema, entity, id string, options RequestOptions) { - logger.Info("Reading records from %s.%s", schema, entity) - - // Get the model struct for the entity - model, err := h.getModelForEntity(schema, entity) - if err != nil { - logger.Error("Invalid entity: %v", err) - h.sendError(w, http.StatusBadRequest, "invalid_entity", "Invalid entity", err) - return - } - - GormTableNameInterface, ok := model.(GormTableNameInterface) - if !ok { - logger.Error("Model does not implement GormTableNameInterface") - h.sendError(w, http.StatusInternalServerError, "model_error", "Model does not implement GormTableNameInterface", nil) - return - } - query := h.db.Model(model).Table(GormTableNameInterface.TableName()) - - // Apply column selection - if len(options.Columns) > 0 { - logger.Debug("Selecting columns: %v", options.Columns) - query = query.Select(options.Columns) - } - - // Apply preloading - for _, preload := range options.Preload { - logger.Debug("Applying preload for relation: %s", preload.Relation) - query = query.Preload(preload.Relation, func(db *gorm.DB) *gorm.DB { - - if len(preload.Columns) > 0 { - db = db.Select(preload.Columns) - } - if len(preload.Filters) > 0 { - for _, filter := range preload.Filters { - db = h.applyFilter(db, filter) - } - } - return db - }) - - } - - // Apply filters - for _, filter := range options.Filters { - logger.Debug("Applying filter: %s %s %v", filter.Column, filter.Operator, filter.Value) - query = h.applyFilter(query, filter) - } - - // Apply sorting - for _, sort := range options.Sort { - direction := "ASC" - if strings.ToLower(sort.Direction) == "desc" { - direction = "DESC" - } - logger.Debug("Applying sort: %s %s", sort.Column, direction) - query = query.Order(fmt.Sprintf("%s %s", sort.Column, direction)) - } - - // Get total count before pagination - var total int64 - if err := query.Count(&total).Error; err != nil { - logger.Error("Error counting records: %v", err) - h.sendError(w, http.StatusInternalServerError, "query_error", "Error counting records", err) - return - } - logger.Debug("Total records before filtering: %d", total) - - // Apply pagination - if options.Limit != nil && *options.Limit > 0 { - logger.Debug("Applying limit: %d", *options.Limit) - query = query.Limit(*options.Limit) - } - if options.Offset != nil && *options.Offset > 0 { - logger.Debug("Applying offset: %d", *options.Offset) - query = query.Offset(*options.Offset) - } - - // Execute query - var result interface{} - if id != "" { - logger.Debug("Querying single record with ID: %s", id) - singleResult := model - if err := query.First(singleResult, id).Error; err != nil { - if err == gorm.ErrRecordNotFound { - logger.Warn("Record not found with ID: %s", id) - h.sendError(w, http.StatusNotFound, "not_found", "Record not found", nil) - return - } - logger.Error("Error querying record: %v", err) - h.sendError(w, http.StatusInternalServerError, "query_error", "Error executing query", err) - return - } - result = singleResult - } else { - logger.Debug("Querying multiple records") - sliceType := reflect.SliceOf(reflect.TypeOf(model)) - results := reflect.New(sliceType).Interface() - - if err := query.Find(results).Error; err != nil { - logger.Error("Error querying records: %v", err) - h.sendError(w, http.StatusInternalServerError, "query_error", "Error executing query", err) - return - } - result = reflect.ValueOf(results).Elem().Interface() - } - - logger.Info("Successfully retrieved records") - h.sendResponse(w, result, &Metadata{ - Total: total, - Filtered: total, - Limit: optionalInt(options.Limit), - Offset: optionalInt(options.Offset), - }) -} - -// Create handler -func (h *LegacyAPIHandler) handleCreate(w http.ResponseWriter, r *http.Request, schema, entity string, data any, options RequestOptions) { - logger.Info("Creating records for %s.%s", schema, entity) - query := h.db.Table(fmt.Sprintf("%s.%s", schema, entity)) - - switch v := data.(type) { - case map[string]interface{}: - result := query.Create(v) - if result.Error != nil { - logger.Error("Error creating record: %v", result.Error) - h.sendError(w, http.StatusInternalServerError, "create_error", "Error creating record", result.Error) - return - } - logger.Info("Successfully created record") - h.sendResponse(w, v, nil) - - case []map[string]interface{}: - result := query.Create(v) - if result.Error != nil { - logger.Error("Error creating records: %v", result.Error) - h.sendError(w, http.StatusInternalServerError, "create_error", "Error creating records", result.Error) - return - } - logger.Info("Successfully created %d records", len(v)) - h.sendResponse(w, v, nil) - case []interface{}: - list := make([]interface{}, 0) - for _, item := range v { - result := query.Create(item) - list = append(list, item) - if result.Error != nil { - logger.Error("Error creating records: %v", result.Error) - h.sendError(w, http.StatusInternalServerError, "create_error", "Error creating records", result.Error) - return - } - logger.Info("Successfully created %d records", len(v)) - } - h.sendResponse(w, list, nil) - default: - logger.Error("Invalid data type for create operation: %T", data) - } -} - -// Update handler -func (h *LegacyAPIHandler) handleUpdate(w http.ResponseWriter, r *http.Request, schema, entity string, urlID string, reqID any, data any, options RequestOptions) { - logger.Info("Updating records for %s.%s", schema, entity) - query := h.db.Table(fmt.Sprintf("%s.%s", schema, entity)) - - switch { - case urlID != "": - logger.Debug("Updating by URL ID: %s", urlID) - result := query.Where("id = ?", urlID).Updates(data) - handleUpdateResult(w, h, result, data) - - case reqID != nil: - switch id := reqID.(type) { - case string: - logger.Debug("Updating by request ID: %s", id) - result := query.Where("id = ?", id).Updates(data) - handleUpdateResult(w, h, result, data) - - case []string: - logger.Debug("Updating by multiple IDs: %v", id) - result := query.Where("id IN ?", id).Updates(data) - handleUpdateResult(w, h, result, data) - } - - case data != nil: - switch v := data.(type) { - case []map[string]interface{}: - logger.Debug("Performing bulk update with %d records", len(v)) - err := h.db.Transaction(func(tx *gorm.DB) error { - for _, item := range v { - if id, ok := item["id"].(string); ok { - if err := tx.Where("id = ?", id).Updates(item).Error; err != nil { - logger.Error("Error in bulk update transaction: %v", err) - return err - } - } - } - return nil - }) - if err != nil { - h.sendError(w, http.StatusInternalServerError, "update_error", "Error in bulk update", err) - return - } - logger.Info("Bulk update completed successfully") - h.sendResponse(w, data, nil) - } - default: - logger.Error("Invalid data type for update operation: %T", data) - - } -} - -// Delete handler -func (h *LegacyAPIHandler) handleDelete(w http.ResponseWriter, r *http.Request, schema, entity, id string) { - logger.Info("Deleting records from %s.%s", schema, entity) - query := h.db.Table(fmt.Sprintf("%s.%s", schema, entity)) - - if id == "" { - logger.Error("Delete operation requires an ID") - h.sendError(w, http.StatusBadRequest, "missing_id", "Delete operation requires an ID", nil) - return - } - - result := query.Delete("id = ?", id) - if result.Error != nil { - logger.Error("Error deleting record: %v", result.Error) - h.sendError(w, http.StatusInternalServerError, "delete_error", "Error deleting record", result.Error) - return - } - if result.RowsAffected == 0 { - logger.Warn("No record found to delete with ID: %s", id) - h.sendError(w, http.StatusNotFound, "not_found", "Record not found", nil) - return - } - - logger.Info("Successfully deleted record with ID: %s", id) - h.sendResponse(w, nil, nil) -} diff --git a/pkg/resolvespec/handler.go b/pkg/resolvespec/handler.go index 5d4fca3..8a84630 100644 --- a/pkg/resolvespec/handler.go +++ b/pkg/resolvespec/handler.go @@ -8,17 +8,18 @@ import ( "reflect" "strings" + "github.com/Warky-Devs/ResolveSpec/pkg/common" "github.com/Warky-Devs/ResolveSpec/pkg/logger" ) // Handler handles API requests using database and model abstractions type Handler struct { - db Database - registry ModelRegistry + db common.Database + registry common.ModelRegistry } // NewHandler creates a new API handler with database and registry abstractions -func NewHandler(db Database, registry ModelRegistry) *Handler { +func NewHandler(db common.Database, registry common.ModelRegistry) *Handler { return &Handler{ db: db, registry: registry, @@ -26,9 +27,9 @@ func NewHandler(db Database, registry ModelRegistry) *Handler { } // Handle processes API requests through router-agnostic interface -func (h *Handler) Handle(w ResponseWriter, r Request, params map[string]string) { +func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[string]string) { ctx := context.Background() - + body, err := r.Body() if err != nil { logger.Error("Failed to read request body: %v", err) @@ -36,7 +37,7 @@ func (h *Handler) Handle(w ResponseWriter, r Request, params map[string]string) return } - var req RequestBody + var req common.RequestBody if err := json.Unmarshal(body, &req); err != nil { logger.Error("Failed to decode request body: %v", err) h.sendError(w, http.StatusBadRequest, "invalid_request", "Invalid request body", err) @@ -65,7 +66,7 @@ func (h *Handler) Handle(w ResponseWriter, r Request, params map[string]string) } // HandleGet processes GET requests for metadata -func (h *Handler) HandleGet(w ResponseWriter, r Request, params map[string]string) { +func (h *Handler) HandleGet(w common.ResponseWriter, r common.Request, params map[string]string) { schema := params["schema"] entity := params["entity"] @@ -82,7 +83,7 @@ func (h *Handler) HandleGet(w ResponseWriter, r Request, params map[string]strin h.sendResponse(w, metadata, nil) } -func (h *Handler) handleRead(ctx context.Context, w ResponseWriter, schema, entity, id string, options RequestOptions) { +func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, schema, entity, id string, options common.RequestOptions) { logger.Info("Reading records from %s.%s", schema, entity) model, err := h.registry.GetModelByEntity(schema, entity) @@ -104,11 +105,9 @@ func (h *Handler) handleRead(ctx context.Context, w ResponseWriter, schema, enti query = query.Column(options.Columns...) } - // Note: Preloading is not implemented in the new database abstraction yet - // This is a limitation of the current interface design - // For now, preloading should use the legacy APIHandler + // Apply preloading if len(options.Preload) > 0 { - logger.Warn("Preloading not yet implemented in new handler - use legacy APIHandler for preload functionality") + query = h.applyPreloads(model, query, options.Preload) } // Apply filters @@ -172,18 +171,35 @@ func (h *Handler) handleRead(ctx context.Context, w ResponseWriter, schema, enti } logger.Info("Successfully retrieved records") - h.sendResponse(w, result, &Metadata{ + + limit := 0 + if options.Limit != nil { + limit = *options.Limit + } + offset := 0 + if options.Offset != nil { + offset = *options.Offset + } + + h.sendResponse(w, result, &common.Metadata{ Total: int64(total), Filtered: int64(total), - Limit: optionalInt(options.Limit), - Offset: optionalInt(options.Offset), + Limit: limit, + Offset: offset, }) } -func (h *Handler) handleCreate(ctx context.Context, w ResponseWriter, schema, entity string, data interface{}, options RequestOptions) { +func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, schema, entity string, data interface{}, options common.RequestOptions) { logger.Info("Creating records for %s.%s", schema, entity) - tableName := fmt.Sprintf("%s.%s", schema, entity) + // Get the model to determine the actual table name + model, err := h.registry.GetModelByEntity(schema, entity) + if err != nil { + logger.Warn("Model not found, using default table name") + model = nil + } + + tableName := h.getTableName(schema, entity, model) query := h.db.NewInsert().Table(tableName) switch v := data.(type) { @@ -201,7 +217,7 @@ func (h *Handler) handleCreate(ctx context.Context, w ResponseWriter, schema, en h.sendResponse(w, v, nil) case []map[string]interface{}: - err := h.db.RunInTransaction(ctx, func(tx Database) error { + err := h.db.RunInTransaction(ctx, func(tx common.Database) error { for _, item := range v { txQuery := tx.NewInsert().Table(tableName) for key, value := range item { @@ -224,7 +240,7 @@ func (h *Handler) handleCreate(ctx context.Context, w ResponseWriter, schema, en case []interface{}: // Handle []interface{} type from JSON unmarshaling list := make([]interface{}, 0) - err := h.db.RunInTransaction(ctx, func(tx Database) error { + err := h.db.RunInTransaction(ctx, func(tx common.Database) error { for _, item := range v { if itemMap, ok := item.(map[string]interface{}); ok { txQuery := tx.NewInsert().Table(tableName) @@ -253,10 +269,18 @@ func (h *Handler) handleCreate(ctx context.Context, w ResponseWriter, schema, en } } -func (h *Handler) handleUpdate(ctx context.Context, w ResponseWriter, schema, entity, urlID string, reqID interface{}, data interface{}, options RequestOptions) { +func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, schema, entity, urlID string, reqID interface{}, data interface{}, options common.RequestOptions) { logger.Info("Updating records for %s.%s", schema, entity) - tableName := fmt.Sprintf("%s.%s", schema, entity) + // Get the model to determine the actual table name + model, err := h.registry.GetModelByEntity(schema, entity) + if err != nil { + logger.Warn("Model not found, using default table name") + // Fallback to entity name (without schema for SQLite compatibility) + model = nil + } + + tableName := h.getTableName(schema, entity, model) query := h.db.NewUpdate().Table(tableName) switch updates := data.(type) { @@ -289,18 +313,18 @@ func (h *Handler) handleUpdate(ctx context.Context, w ResponseWriter, schema, en h.sendError(w, http.StatusInternalServerError, "update_error", "Error updating record(s)", err) return } - + if result.RowsAffected() == 0 { logger.Warn("No records found to update") h.sendError(w, http.StatusNotFound, "not_found", "No records found to update", nil) return } - + logger.Info("Successfully updated %d records", result.RowsAffected()) h.sendResponse(w, data, nil) } -func (h *Handler) handleDelete(ctx context.Context, w ResponseWriter, schema, entity, id string) { +func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, schema, entity, id string) { logger.Info("Deleting records from %s.%s", schema, entity) if id == "" { @@ -309,7 +333,14 @@ func (h *Handler) handleDelete(ctx context.Context, w ResponseWriter, schema, en return } - tableName := fmt.Sprintf("%s.%s", schema, entity) + // Get the model to determine the actual table name + model, err := h.registry.GetModelByEntity(schema, entity) + if err != nil { + logger.Warn("Model not found, using default table name") + model = nil + } + + tableName := h.getTableName(schema, entity, model) query := h.db.NewDelete().Table(tableName).Where("id = ?", id) result, err := query.Exec(ctx) @@ -318,7 +349,7 @@ func (h *Handler) handleDelete(ctx context.Context, w ResponseWriter, schema, en h.sendError(w, http.StatusInternalServerError, "delete_error", "Error deleting record", err) return } - + if result.RowsAffected() == 0 { logger.Warn("No record found to delete with ID: %s", id) h.sendError(w, http.StatusNotFound, "not_found", "Record not found", nil) @@ -329,7 +360,7 @@ func (h *Handler) handleDelete(ctx context.Context, w ResponseWriter, schema, en h.sendResponse(w, nil, nil) } -func (h *Handler) applyFilter(query SelectQuery, filter FilterOption) SelectQuery { +func (h *Handler) applyFilter(query common.SelectQuery, filter common.FilterOption) common.SelectQuery { switch filter.Operator { case "eq": return query.Where(fmt.Sprintf("%s = ?", filter.Column), filter.Value) @@ -355,22 +386,22 @@ func (h *Handler) applyFilter(query SelectQuery, filter FilterOption) SelectQuer } func (h *Handler) getTableName(schema, entity string, model interface{}) string { - if provider, ok := model.(TableNameProvider); ok { + if provider, ok := model.(common.TableNameProvider); ok { return provider.TableName() } return fmt.Sprintf("%s.%s", schema, entity) } -func (h *Handler) generateMetadata(schema, entity string, model interface{}) TableMetadata { +func (h *Handler) generateMetadata(schema, entity string, model interface{}) *common.TableMetadata { modelType := reflect.TypeOf(model) if modelType.Kind() == reflect.Ptr { modelType = modelType.Elem() } - metadata := TableMetadata{ + metadata := &common.TableMetadata{ Schema: schema, Table: entity, - Columns: make([]Column, 0), + Columns: make([]common.Column, 0), Relations: make([]string, 0), } @@ -400,7 +431,7 @@ func (h *Handler) generateMetadata(schema, entity string, model interface{}) Tab continue } - column := Column{ + column := common.Column{ Name: jsonName, Type: getColumnType(field), IsNullable: isNullable(field), @@ -415,21 +446,21 @@ func (h *Handler) generateMetadata(schema, entity string, model interface{}) Tab return metadata } -func (h *Handler) sendResponse(w ResponseWriter, data interface{}, metadata *Metadata) { +func (h *Handler) sendResponse(w common.ResponseWriter, data interface{}, metadata *common.Metadata) { w.SetHeader("Content-Type", "application/json") - w.WriteJSON(Response{ + w.WriteJSON(common.Response{ Success: true, Data: data, Metadata: metadata, }) } -func (h *Handler) sendError(w ResponseWriter, status int, code, message string, details interface{}) { +func (h *Handler) sendError(w common.ResponseWriter, status int, code, message string, details interface{}) { w.SetHeader("Content-Type", "application/json") w.WriteHeader(status) - w.WriteJSON(Response{ + w.WriteJSON(common.Response{ Success: false, - Error: &APIError{ + Error: &common.APIError{ Code: code, Message: message, Details: details, @@ -442,4 +473,142 @@ func (h *Handler) sendError(w ResponseWriter, status int, code, message string, func (h *Handler) RegisterModel(schema, name string, model interface{}) error { fullname := fmt.Sprintf("%s.%s", schema, name) return h.registry.RegisterModel(fullname, model) -} \ No newline at end of file +} + +// Helper functions + +func getColumnType(field reflect.StructField) string { + // Check GORM type tag first + gormTag := field.Tag.Get("gorm") + if strings.Contains(gormTag, "type:") { + parts := strings.Split(gormTag, "type:") + if len(parts) > 1 { + typePart := strings.Split(parts[1], ";")[0] + return typePart + } + } + + // Map Go types to SQL types + switch field.Type.Kind() { + case reflect.String: + return "string" + case reflect.Int, reflect.Int32: + return "integer" + case reflect.Int64: + return "bigint" + case reflect.Float32: + return "float" + case reflect.Float64: + return "double" + case reflect.Bool: + return "boolean" + default: + if field.Type.Name() == "Time" { + return "timestamp" + } + return "unknown" + } +} + +func isNullable(field reflect.StructField) bool { + // Check if it's a pointer type + if field.Type.Kind() == reflect.Ptr { + return true + } + + // Check if it's a null type from sql package + typeName := field.Type.Name() + if strings.HasPrefix(typeName, "Null") { + return true + } + + // Check GORM tags + gormTag := field.Tag.Get("gorm") + return !strings.Contains(gormTag, "not null") +} + +// Preload support functions + +type relationshipInfo struct { + fieldName string + jsonName string + relationType string // "belongsTo", "hasMany", "hasOne", "many2many" + foreignKey string + references string + joinTable string + relatedModel interface{} +} + +func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, preloads []common.PreloadOption) common.SelectQuery { + modelType := reflect.TypeOf(model) + if modelType.Kind() == reflect.Ptr { + modelType = modelType.Elem() + } + + for _, preload := range preloads { + logger.Debug("Processing preload for relation: %s", preload.Relation) + relInfo := h.getRelationshipInfo(modelType, preload.Relation) + if relInfo == nil { + logger.Warn("Relation %s not found in model", preload.Relation) + continue + } + + // Use the field name (capitalized) for ORM preloading + // ORMs like GORM and Bun expect the struct field name, not the JSON name + relationFieldName := relInfo.fieldName + + // For now, we'll preload without conditions + // TODO: Implement column selection and filtering for preloads + // This requires a more sophisticated approach with callbacks or query builders + query = query.Preload(relationFieldName) + logger.Debug("Applied Preload for relation: %s (field: %s)", preload.Relation, relationFieldName) + } + + return query +} + +func (h *Handler) getRelationshipInfo(modelType reflect.Type, relationName string) *relationshipInfo { + for i := 0; i < modelType.NumField(); i++ { + field := modelType.Field(i) + jsonTag := field.Tag.Get("json") + jsonName := strings.Split(jsonTag, ",")[0] + + if jsonName == relationName { + gormTag := field.Tag.Get("gorm") + info := &relationshipInfo{ + fieldName: field.Name, + jsonName: jsonName, + } + + // Parse GORM tag to determine relationship type and keys + if strings.Contains(gormTag, "foreignKey") { + info.foreignKey = h.extractTagValue(gormTag, "foreignKey") + info.references = h.extractTagValue(gormTag, "references") + + // Determine if it's belongsTo or hasMany/hasOne + if field.Type.Kind() == reflect.Slice { + info.relationType = "hasMany" + } else if field.Type.Kind() == reflect.Ptr || field.Type.Kind() == reflect.Struct { + info.relationType = "belongsTo" + } + } else if strings.Contains(gormTag, "many2many") { + info.relationType = "many2many" + info.joinTable = h.extractTagValue(gormTag, "many2many") + } + + return info + } + } + return nil +} + +func (h *Handler) extractTagValue(tag, key string) string { + parts := strings.Split(tag, ";") + for _, part := range parts { + part = strings.TrimSpace(part) + if strings.HasPrefix(part, key+":") { + return strings.TrimPrefix(part, key+":") + } + } + return "" +} diff --git a/pkg/resolvespec/meta.go b/pkg/resolvespec/meta.go deleted file mode 100644 index d8b5d70..0000000 --- a/pkg/resolvespec/meta.go +++ /dev/null @@ -1,131 +0,0 @@ -package resolvespec - -import ( - "net/http" - "reflect" - "strings" - - "github.com/Warky-Devs/ResolveSpec/pkg/logger" -) - -func (h *LegacyAPIHandler) HandleGet(w http.ResponseWriter, r *http.Request, params map[string]string) { - schema := params["schema"] - entity := params["entity"] - - logger.Info("Getting metadata for %s.%s", schema, entity) - - // Get model for the entity - model, err := h.getModelForEntity(schema, entity) - if err != nil { - logger.Error("Failed to get model: %v", err) - h.sendError(w, http.StatusBadRequest, "invalid_entity", "Invalid entity", err) - return - } - - modelType := reflect.TypeOf(model) - if modelType.Kind() == reflect.Ptr { - modelType = modelType.Elem() - } - - metadata := TableMetadata{ - Schema: schema, - Table: entity, - Columns: make([]Column, 0), - Relations: make([]string, 0), - } - - // Get field information using reflection - for i := 0; i < modelType.NumField(); i++ { - field := modelType.Field(i) - - // Skip unexported fields - if !field.IsExported() { - continue - } - - // Parse GORM tags - gormTag := field.Tag.Get("gorm") - jsonTag := field.Tag.Get("json") - - // Skip if json tag is "-" - if jsonTag == "-" { - continue - } - - // Get JSON field name - jsonName := strings.Split(jsonTag, ",")[0] - if jsonName == "" { - jsonName = field.Name - } - - // Check if it's a relation - if field.Type.Kind() == reflect.Slice || - (field.Type.Kind() == reflect.Struct && field.Type.Name() != "Time") { - metadata.Relations = append(metadata.Relations, jsonName) - continue - } - - column := Column{ - Name: jsonName, - Type: getColumnType(field), - IsNullable: isNullable(field), - IsPrimary: strings.Contains(gormTag, "primaryKey"), - IsUnique: strings.Contains(gormTag, "unique") || strings.Contains(gormTag, "uniqueIndex"), - HasIndex: strings.Contains(gormTag, "index") || strings.Contains(gormTag, "uniqueIndex"), - } - - metadata.Columns = append(metadata.Columns, column) - } - - h.sendResponse(w, metadata, nil) -} - -func getColumnType(field reflect.StructField) string { - // Check GORM type tag first - gormTag := field.Tag.Get("gorm") - if strings.Contains(gormTag, "type:") { - parts := strings.Split(gormTag, "type:") - if len(parts) > 1 { - typePart := strings.Split(parts[1], ";")[0] - return typePart - } - } - - // Map Go types to SQL types - switch field.Type.Kind() { - case reflect.String: - return "string" - case reflect.Int, reflect.Int32: - return "integer" - case reflect.Int64: - return "bigint" - case reflect.Float32: - return "float" - case reflect.Float64: - return "double" - case reflect.Bool: - return "boolean" - default: - if field.Type.Name() == "Time" { - return "timestamp" - } - return "unknown" - } -} - -func isNullable(field reflect.StructField) bool { - // Check if it's a pointer type - if field.Type.Kind() == reflect.Ptr { - return true - } - - // Check if it's a null type from sql package - typeName := field.Type.Name() - if strings.HasPrefix(typeName, "Null") { - return true - } - - // Check GORM tags - gormTag := field.Tag.Get("gorm") - return !strings.Contains(gormTag, "not null") -} diff --git a/pkg/resolvespec/resolvespec.go b/pkg/resolvespec/resolvespec.go index f660f72..5dca1de 100644 --- a/pkg/resolvespec/resolvespec.go +++ b/pkg/resolvespec/resolvespec.go @@ -3,145 +3,179 @@ package resolvespec import ( "net/http" + "github.com/Warky-Devs/ResolveSpec/pkg/common/adapters/database" + "github.com/Warky-Devs/ResolveSpec/pkg/common/adapters/router" + "github.com/Warky-Devs/ResolveSpec/pkg/modelregistry" "github.com/gorilla/mux" "github.com/uptrace/bun" + "github.com/uptrace/bunrouter" "gorm.io/gorm" ) -// NewAPIHandler creates a new APIHandler with GORM (backward compatibility) -func NewAPIHandlerWithGORM(db *gorm.DB) *APIHandlerCompat { - return NewAPIHandler(db) -} - // NewHandlerWithGORM creates a new Handler with GORM adapter func NewHandlerWithGORM(db *gorm.DB) *Handler { - gormAdapter := NewGormAdapter(db) - registry := NewModelRegistry() + gormAdapter := database.NewGormAdapter(db) + registry := modelregistry.NewModelRegistry() return NewHandler(gormAdapter, registry) } -// NewStandardRouter creates a router with standard HTTP handlers -func NewStandardRouter() *StandardMuxAdapter { - return NewStandardMuxAdapter() +// NewHandlerWithBun creates a new Handler with Bun adapter +func NewHandlerWithBun(db *bun.DB) *Handler { + bunAdapter := database.NewBunAdapter(db) + registry := modelregistry.NewModelRegistry() + return NewHandler(bunAdapter, registry) } -// SetupRoutes sets up routes for the ResolveSpec API with backward compatibility -func SetupRoutes(router *mux.Router, handler *APIHandlerCompat) { - router.HandleFunc("/{schema}/{entity}", func(w http.ResponseWriter, r *http.Request) { +// NewStandardMuxRouter creates a router with standard Mux HTTP handlers +func NewStandardMuxRouter() *router.StandardMuxAdapter { + return router.NewStandardMuxAdapter() +} + +// NewStandardBunRouter creates a router with standard BunRouter handlers +func NewStandardBunRouter() *router.StandardBunRouterAdapter { + return router.NewStandardBunRouterAdapter() +} + +// SetupMuxRoutes sets up routes for the ResolveSpec API with Mux +func SetupMuxRoutes(muxRouter *mux.Router, handler *Handler) { + muxRouter.HandleFunc("/{schema}/{entity}", func(w http.ResponseWriter, r *http.Request) { vars := mux.Vars(r) - handler.Handle(w, r, vars) + reqAdapter := router.NewHTTPRequest(r) + respAdapter := router.NewHTTPResponseWriter(w) + handler.Handle(respAdapter, reqAdapter, vars) }).Methods("POST") - router.HandleFunc("/{schema}/{entity}/{id}", func(w http.ResponseWriter, r *http.Request) { + muxRouter.HandleFunc("/{schema}/{entity}/{id}", func(w http.ResponseWriter, r *http.Request) { vars := mux.Vars(r) - handler.Handle(w, r, vars) + reqAdapter := router.NewHTTPRequest(r) + respAdapter := router.NewHTTPResponseWriter(w) + handler.Handle(respAdapter, reqAdapter, vars) }).Methods("POST") - router.HandleFunc("/{schema}/{entity}", func(w http.ResponseWriter, r *http.Request) { + muxRouter.HandleFunc("/{schema}/{entity}", func(w http.ResponseWriter, r *http.Request) { vars := mux.Vars(r) - handler.HandleGet(w, r, vars) + reqAdapter := router.NewHTTPRequest(r) + respAdapter := router.NewHTTPResponseWriter(w) + handler.HandleGet(respAdapter, reqAdapter, vars) }).Methods("GET") } // Example usage functions for documentation: -// ExampleWithGORM shows how to use ResolveSpec with GORM (current default) +// ExampleWithGORM shows how to use ResolveSpec with GORM func ExampleWithGORM(db *gorm.DB) { - // Create handler using GORM (backward compatible) - handler := NewAPIHandlerWithGORM(db) - + // Create handler using GORM + handler := NewHandlerWithGORM(db) + // Setup router - router := mux.NewRouter() - SetupRoutes(router, handler) - + muxRouter := mux.NewRouter() + SetupMuxRoutes(muxRouter, handler) + // Register models // handler.RegisterModel("public", "users", &User{}) } -// ExampleWithNewAPI shows how to use the new abstracted API -func ExampleWithNewAPI(db *gorm.DB) { - // Create database adapter - dbAdapter := NewGormAdapter(db) - - // Create model registry - registry := NewModelRegistry() - // registry.RegisterModel("public.users", &User{}) - - // Create handler with new API - handler := NewHandler(dbAdapter, registry) - - // Create router adapter - routerAdapter := NewStandardRouter() - - // Register routes using new API - routerAdapter.RegisterRoute("/{schema}/{entity}", func(w http.ResponseWriter, r *http.Request, params map[string]string) { - reqAdapter := NewHTTPRequest(r) - respAdapter := NewHTTPResponseWriter(w) - handler.Handle(respAdapter, reqAdapter, params) - }) -} - // ExampleWithBun shows how to switch to Bun ORM func ExampleWithBun(bunDB *bun.DB) { // Create Bun adapter - dbAdapter := NewBunAdapter(bunDB) - + dbAdapter := database.NewBunAdapter(bunDB) + // Create model registry - registry := NewModelRegistry() + registry := modelregistry.NewModelRegistry() // registry.RegisterModel("public.users", &User{}) - + // Create handler handler := NewHandler(dbAdapter, registry) - - // Setup routes same as with GORM - router := NewStandardRouter() - router.RegisterRoute("/{schema}/{entity}", func(w http.ResponseWriter, r *http.Request, params map[string]string) { - reqAdapter := NewHTTPRequest(r) - respAdapter := NewHTTPResponseWriter(w) + + // Setup routes + muxRouter := mux.NewRouter() + SetupMuxRoutes(muxRouter, handler) +} + +// SetupBunRouterRoutes sets up bunrouter routes for the ResolveSpec API +func SetupBunRouterRoutes(bunRouter *router.StandardBunRouterAdapter, handler *Handler) { + r := bunRouter.GetBunRouter() + + r.Handle("POST", "/:schema/:entity", func(w http.ResponseWriter, req bunrouter.Request) error { + params := map[string]string{ + "schema": req.Param("schema"), + "entity": req.Param("entity"), + } + reqAdapter := router.NewHTTPRequest(req.Request) + respAdapter := router.NewHTTPResponseWriter(w) handler.Handle(respAdapter, reqAdapter, params) + return nil + }) + + r.Handle("POST", "/:schema/:entity/:id", func(w http.ResponseWriter, req bunrouter.Request) error { + params := map[string]string{ + "schema": req.Param("schema"), + "entity": req.Param("entity"), + "id": req.Param("id"), + } + reqAdapter := router.NewHTTPRequest(req.Request) + respAdapter := router.NewHTTPResponseWriter(w) + handler.Handle(respAdapter, reqAdapter, params) + return nil + }) + + r.Handle("GET", "/:schema/:entity", func(w http.ResponseWriter, req bunrouter.Request) error { + params := map[string]string{ + "schema": req.Param("schema"), + "entity": req.Param("entity"), + } + reqAdapter := router.NewHTTPRequest(req.Request) + respAdapter := router.NewHTTPResponseWriter(w) + handler.HandleGet(respAdapter, reqAdapter, params) + return nil + }) + + r.Handle("GET", "/:schema/:entity/:id", func(w http.ResponseWriter, req bunrouter.Request) error { + params := map[string]string{ + "schema": req.Param("schema"), + "entity": req.Param("entity"), + "id": req.Param("id"), + } + reqAdapter := router.NewHTTPRequest(req.Request) + respAdapter := router.NewHTTPResponseWriter(w) + handler.HandleGet(respAdapter, reqAdapter, params) + return nil }) } // ExampleWithBunRouter shows how to use bunrouter from uptrace -func ExampleWithBunRouter(db *gorm.DB) { - // Create handler (can use any database adapter) - handler := NewAPIHandler(db) - +func ExampleWithBunRouter(bunDB *bun.DB) { + // Create handler with Bun adapter + handler := NewHandlerWithBun(bunDB) + // Create bunrouter - router := NewStandardBunRouterAdapter() - + bunRouter := router.NewStandardBunRouterAdapter() + // Setup ResolveSpec routes with bunrouter - SetupBunRouterWithResolveSpec(router.GetBunRouter(), handler) - + SetupBunRouterRoutes(bunRouter, handler) + // Start server - // http.ListenAndServe(":8080", router.GetBunRouter()) + // http.ListenAndServe(":8080", bunRouter.GetBunRouter()) } // ExampleBunRouterWithBunDB shows the full uptrace stack (bunrouter + Bun ORM) func ExampleBunRouterWithBunDB(bunDB *bun.DB) { // Create Bun database adapter - dbAdapter := NewBunAdapter(bunDB) - + dbAdapter := database.NewBunAdapter(bunDB) + // Create model registry - registry := NewModelRegistry() + registry := modelregistry.NewModelRegistry() // registry.RegisterModel("public.users", &User{}) - + // Create handler with Bun handler := NewHandler(dbAdapter, registry) - - // Create compatibility wrapper for existing APIs - compatHandler := &APIHandlerCompat{ - legacyHandler: nil, // No legacy handler needed - newHandler: handler, - db: nil, // No GORM dependency - } - + // Create bunrouter - router := NewStandardBunRouterAdapter() - + bunRouter := router.NewStandardBunRouterAdapter() + // Setup ResolveSpec routes - SetupBunRouterWithResolveSpec(router.GetBunRouter(), compatHandler) - + SetupBunRouterRoutes(bunRouter, handler) + // This gives you the full uptrace stack: bunrouter + Bun ORM - // http.ListenAndServe(":8080", router.GetBunRouter()) -} \ No newline at end of file + // http.ListenAndServe(":8080", bunRouter.GetBunRouter()) +} diff --git a/pkg/resolvespec/utils.go b/pkg/resolvespec/utils.go deleted file mode 100644 index 11d2dee..0000000 --- a/pkg/resolvespec/utils.go +++ /dev/null @@ -1,78 +0,0 @@ -package resolvespec - -import ( - "fmt" - "net/http" - - "github.com/Warky-Devs/ResolveSpec/pkg/logger" - "github.com/Warky-Devs/ResolveSpec/pkg/models" - "gorm.io/gorm" -) - -func handleUpdateResult(w http.ResponseWriter, h *LegacyAPIHandler, result *gorm.DB, data interface{}) { - if result.Error != nil { - logger.Error("Update error: %v", result.Error) - h.sendError(w, http.StatusInternalServerError, "update_error", "Error updating record(s)", result.Error) - return - } - if result.RowsAffected == 0 { - logger.Warn("No records found to update") - h.sendError(w, http.StatusNotFound, "not_found", "No records found to update", nil) - return - } - logger.Info("Successfully updated %d records", result.RowsAffected) - h.sendResponse(w, data, nil) -} - -func optionalInt(ptr *int) int { - if ptr == nil { - return 0 - } - return *ptr -} - -// Helper methods -func (h *LegacyAPIHandler) applyFilter(query *gorm.DB, filter FilterOption) *gorm.DB { - switch filter.Operator { - case "eq": - return query.Where(fmt.Sprintf("%s = ?", filter.Column), filter.Value) - case "neq": - return query.Where(fmt.Sprintf("%s != ?", filter.Column), filter.Value) - case "gt": - return query.Where(fmt.Sprintf("%s > ?", filter.Column), filter.Value) - case "gte": - return query.Where(fmt.Sprintf("%s >= ?", filter.Column), filter.Value) - case "lt": - return query.Where(fmt.Sprintf("%s < ?", filter.Column), filter.Value) - case "lte": - return query.Where(fmt.Sprintf("%s <= ?", filter.Column), filter.Value) - case "like": - return query.Where(fmt.Sprintf("%s LIKE ?", filter.Column), filter.Value) - case "ilike": - return query.Where(fmt.Sprintf("%s ILIKE ?", filter.Column), filter.Value) - case "in": - return query.Where(fmt.Sprintf("%s IN (?)", filter.Column), filter.Value) - default: - return query - } -} - -func (h *LegacyAPIHandler) getModelForEntity(schema, name string) (interface{}, error) { - model, err := models.GetModelByName(fmt.Sprintf("%s.%s", schema, name)) - - if err != nil { - model, err = models.GetModelByName(name) - } - return model, err -} - -func (h *LegacyAPIHandler) RegisterModel(schema, name string, model interface{}) error { - fullname := fmt.Sprintf("%s.%s", schema, name) - oldModel, err := models.GetModelByName(fullname) - if oldModel != nil && err != nil { - return fmt.Errorf("model %s already exists", fullname) - } - err = models.RegisterModel(model, fullname) - - return err -} diff --git a/pkg/restheadspec/HEADERS.md b/pkg/restheadspec/HEADERS.md new file mode 100644 index 0000000..c422404 --- /dev/null +++ b/pkg/restheadspec/HEADERS.md @@ -0,0 +1,614 @@ +# RestHeadSpec Headers Documentation + +RestHeadSpec provides a comprehensive header-based REST API where all query options are passed via HTTP headers instead of request body. This document describes all supported headers and their usage. + +## Overview + +RestHeadSpec uses HTTP headers for: +- Field selection +- Filtering and searching +- Joins and relationship loading +- Sorting and pagination +- Advanced query features +- Response formatting +- Transaction control + +### Header Naming Convention + +All headers support **optional identifiers** at the end to allow multiple instances of the same header type. This is useful when you need to specify multiple related filters or options. + +**Examples:** +``` +# Standard header +x-preload: employees + +# Headers with identifiers (both work the same) +x-preload-main: employees +x-preload-secondary: department +x-preload-1: projects +``` + +The system uses `strings.HasPrefix()` to match headers, so any suffix after the header name is ignored for matching purposes. This allows you to: +- Add descriptive identifiers: `x-sort-primary`, `x-sort-fallback` +- Add numeric identifiers: `x-fieldfilter-status-1`, `x-fieldfilter-status-2` +- Organize related headers: `x-preload-employee-data`, `x-preload-department-info` + +## Header Categories + +### 1. Field Selection + +#### `x-select-fields` +Specify which columns to include in the response. + +**Format:** Comma-separated list of column names +``` +x-select-fields: id,name,email,created_at +``` + +#### `x-not-select-fields` +Specify which columns to exclude from the response. + +**Format:** Comma-separated list of column names +``` +x-not-select-fields: password,internal_notes +``` + +#### `x-clean-json` +Remove null and empty fields from the response. + +**Format:** Boolean (true/false) +``` +x-clean-json: true +``` + +--- + +### 2. Filtering & Search + +#### `x-fieldfilter-{colname}` +Exact match filter on a specific column. + +**Format:** `x-fieldfilter-{columnName}: {value}` +``` +x-fieldfilter-status: active +x-fieldfilter-department_id: dept123 +``` + +#### `x-searchfilter-{colname}` +Fuzzy search (ILIKE) on a specific column. + +**Format:** `x-searchfilter-{columnName}: {searchTerm}` +``` +x-searchfilter-name: john +x-searchfilter-description: website +``` +This will match any records where the column contains the search term (case-insensitive). + +#### `x-searchop-{operator}-{colname}` +Search with specific operators (AND logic). + +**Supported Operators:** +- `contains` - Contains substring (case-insensitive) +- `beginswith` / `startswith` - Starts with (case-insensitive) +- `endswith` - Ends with (case-insensitive) +- `equals` / `eq` - Exact match +- `notequals` / `neq` / `ne` - Not equal +- `greaterthan` / `gt` - Greater than +- `lessthan` / `lt` - Less than +- `greaterthanorequal` / `gte` / `ge` - Greater than or equal +- `lessthanorequal` / `lte` / `le` - Less than or equal +- `between` - Between two values, **exclusive** (> val1 AND < val2) - format: `value1,value2` +- `betweeninclusive` - Between two values, **inclusive** (>= val1 AND <= val2) - format: `value1,value2` +- `in` - In a list of values - format: `value1,value2,value3` +- `empty` / `isnull` / `null` - Is NULL or empty string +- `notempty` / `isnotnull` / `notnull` - Is NOT NULL and not empty string + +**Type-Aware Features:** +- Text searches use case-insensitive matching (ILIKE with citext cast) +- Numeric comparisons work with integers, floats, and decimals +- Date/time comparisons handle timestamps correctly +- JSON field support for structured data + +**Examples:** +``` +# Text search (case-insensitive) +x-searchop-contains-name: smith + +# Numeric comparison +x-searchop-gt-age: 25 +x-searchop-gte-salary: 50000 + +# Date range (exclusive) +x-searchop-between-created_at: 2024-01-01,2024-12-31 + +# Date range (inclusive) +x-searchop-betweeninclusive-birth_date: 1990-01-01,2000-12-31 + +# List matching +x-searchop-in-status: active,pending,review + +# NULL checks +x-searchop-empty-deleted_at: true +x-searchop-notempty-email: true +``` + +#### `x-searchor-{operator}-{colname}` +Same as `x-searchop` but with OR logic instead of AND. + +``` +x-searchor-eq-status: active +x-searchor-eq-status: pending +``` + +#### `x-searchand-{operator}-{colname}` +Explicit AND logic (same as `x-searchop`). + +``` +x-searchand-gte-age: 18 +x-searchand-lte-age: 65 +``` + +#### `x-searchcols` +Specify columns for "all" search operations. + +**Format:** Comma-separated list +``` +x-searchcols: name,email,description +``` + +#### `x-custom-sql-w` +Raw SQL WHERE clause with AND condition. + +**Format:** SQL WHERE clause (without the WHERE keyword) +``` +x-custom-sql-w: status = 'active' AND created_at > '2024-01-01' +``` + +⚠️ **Warning:** Use with caution - ensure proper SQL injection prevention. + +#### `x-custom-sql-or` +Raw SQL WHERE clause with OR condition. + +**Format:** SQL WHERE clause +``` +x-custom-sql-or: status = 'archived' OR is_deleted = true +``` + +--- + +### 3. Joins & Relations + +#### `x-preload` +Preload related tables using the ORM's preload functionality. + +**Format:** `RelationName:field1,field2` or `RelationName` + +Multiple relations can be specified using multiple headers or by separating with `|` + +**Examples:** +``` +# Preload all fields from employees relation +x-preload: employees + +# Preload specific fields from employees +x-preload: employees:id,first_name,last_name,email + +# Multiple preloads using pipe separator +x-preload: employees:id,name|department:id,name + +# Multiple preloads using separate headers with identifiers +x-preload-1: employees:id,first_name,last_name +x-preload-2: department:id,name +x-preload-related: projects:id,name,status +``` + +#### `x-expand` +LEFT JOIN related tables and expand results inline. + +**Format:** Same as `x-preload` + +``` +x-expand: department:id,name,code +``` + +**Note:** Currently, expand falls back to preload behavior. Full JOIN expansion is planned for future implementation. + +#### `x-custom-sql-join` +Raw SQL JOIN statement. + +**Format:** SQL JOIN clause +``` +x-custom-sql-join: LEFT JOIN departments d ON d.id = employees.department_id +``` + +⚠️ **Note:** Not yet fully implemented. + +--- + +### 4. Sorting & Pagination + +#### `x-sort` +Sort results by one or more columns. + +**Format:** Comma-separated list with optional `+` (ASC) or `-` (DESC) prefix + +``` +# Single column ascending (default) +x-sort: name + +# Single column descending +x-sort: -created_at + +# Multiple columns +x-sort: +department,- created_at,name + +# Equivalent to: ORDER BY department ASC, created_at DESC, name ASC +``` + +#### `x-limit` +Limit the number of records returned. + +**Format:** Integer +``` +x-limit: 50 +``` + +#### `x-offset` +Skip a number of records (offset-based pagination). + +**Format:** Integer +``` +x-offset: 100 +``` + +#### `x-cursor-forward` +Cursor-based pagination (forward). + +**Format:** Cursor string +``` +x-cursor-forward: eyJpZCI6MTIzfQ== +``` + +⚠️ **Note:** Not yet fully implemented. + +#### `x-cursor-backward` +Cursor-based pagination (backward). + +**Format:** Cursor string +``` +x-cursor-backward: eyJpZCI6MTIzfQ== +``` + +⚠️ **Note:** Not yet fully implemented. + +--- + +### 5. Advanced Features + +#### `x-advsql-{colname}` +Advanced SQL expression for a specific column. + +**Format:** `x-advsql-{columnName}: {SQLExpression}` +``` +x-advsql-full_name: CONCAT(first_name, ' ', last_name) +x-advsql-age_years: EXTRACT(YEAR FROM AGE(birth_date)) +``` + +⚠️ **Note:** Not yet fully implemented in query execution. + +#### `x-cql-sel-{colname}` +Computed Query Language - custom SQL expressions aliased as columns. + +**Format:** `x-cql-sel-{aliasName}: {SQLExpression}` +``` +x-cql-sel-employee_count: COUNT(employees.id) +x-cql-sel-total_revenue: SUM(orders.amount) +``` + +⚠️ **Note:** Not yet fully implemented in query execution. + +#### `x-distinct` +Apply DISTINCT to the query. + +**Format:** Boolean (true/false) +``` +x-distinct: true +``` + +⚠️ **Note:** Implementation depends on ORM adapter support. + +#### `x-skipcount` +Skip counting total records (performance optimization). + +**Format:** Boolean (true/false) +``` +x-skipcount: true +``` + +When enabled, the total count will be -1 in the response metadata. + +#### `x-skipcache` +Bypass query cache (if caching is implemented). + +**Format:** Boolean (true/false) +``` +x-skipcache: true +``` + +#### `x-fetch-rownumber` +Get the row number of a specific record in the result set. + +**Format:** Record identifier +``` +x-fetch-rownumber: record123 +``` + +⚠️ **Note:** Not yet implemented. + +#### `x-pkrow` +Similar to `x-fetch-rownumber` - get row number by primary key. + +**Format:** Primary key value +``` +x-pkrow: 123 +``` + +⚠️ **Note:** Not yet implemented. + +--- + +### 6. Response Format + +#### `x-simpleapi` +Return simple format (just the data array). + +**Format:** Presence of header activates it +``` +x-simpleapi: true +``` + +**Response Format:** +```json +[ + { "id": 1, "name": "John" }, + { "id": 2, "name": "Jane" } +] +``` + +#### `x-detailapi` +Return detailed format with metadata (default). + +**Format:** Presence of header activates it +``` +x-detailapi: true +``` + +**Response Format:** +```json +{ + "success": true, + "data": [...], + "metadata": { + "total": 100, + "filtered": 100, + "limit": 50, + "offset": 0 + } +} +``` + +#### `x-syncfusion` +Format response for Syncfusion UI components. + +**Format:** Presence of header activates it +``` +x-syncfusion: true +``` + +**Response Format:** +```json +{ + "result": [...], + "count": 100 +} +``` + +--- + +### 7. Transaction Control + +#### `x-transaction-atomic` +Use atomic transactions for write operations. + +**Format:** Boolean (true/false) +``` +x-transaction-atomic: true +``` + +Ensures that all write operations in the request succeed or fail together. + +--- + +## Base64 Encoding + +Headers support base64 encoding for complex values. Use one of these prefixes: + +- `ZIP_` - Base64 encoded value +- `__` - Base64 encoded value (double underscore) + +**Example:** +``` +# Plain value +x-custom-sql-w: status = 'active' + +# Base64 encoded (same value) +x-custom-sql-w: ZIP_c3RhdHVzID0gJ2FjdGl2ZSc= +``` + +--- + +## Complete Examples + +### Example 1: Basic Query + +```http +GET /api/employees HTTP/1.1 +Host: example.com +x-select-fields: id,first_name,last_name,email,department_id +x-preload: department:id,name +x-searchfilter-name: john +x-searchop-gte-created_at: 2024-01-01 +x-sort: -created_at,+last_name +x-limit: 50 +x-offset: 0 +x-skipcount: false +x-detailapi: true +``` + +### Example 2: Complex Query with Multiple Filters and Preloads + +```http +GET /api/employees HTTP/1.1 +Host: example.com +x-select-fields-main: id,first_name,last_name,email,department_id,manager_id +x-preload-1: department:id,name,code +x-preload-2: manager:id,first_name,last_name +x-preload-3: projects:id,name,status +x-fieldfilter-status-1: active +x-searchop-gte-created_at-filter1: 2024-01-01 +x-searchop-lt-created_at-filter2: 2024-12-31 +x-searchfilter-name-query: smith +x-sort-primary: -created_at +x-sort-secondary: +last_name +x-limit-page: 100 +x-offset-page: 0 +x-detailapi: true +``` + +**Note:** The identifiers after the header names (like `-main`, `-1`, `-filter1`, etc.) are optional and help organize multiple headers of the same type. Both approaches work: + +```http +# Without identifiers +x-preload: employees +x-preload: department + +# With identifiers (more organized) +x-preload-1: employees +x-preload-2: department +``` + +**Response:** +```json +{ + "success": true, + "data": [ + { + "id": "emp1", + "first_name": "John", + "last_name": "Doe", + "email": "john@example.com", + "department_id": "dept1", + "department": { + "id": "dept1", + "name": "Engineering" + } + } + ], + "metadata": { + "total": 1, + "filtered": 1, + "limit": 50, + "offset": 0 + } +} +``` + +--- + +## HTTP Method Mapping + +- `GET /{schema}/{entity}` - List all records +- `GET /{schema}/{entity}/{id}` - Get single record +- `POST /{schema}/{entity}` - Create record(s) +- `PUT /{schema}/{entity}/{id}` - Update record +- `PATCH /{schema}/{entity}/{id}` - Partial update +- `DELETE /{schema}/{entity}/{id}` - Delete record +- `GET /{schema}/{entity}/metadata` - Get table metadata + +--- + +## Implementation Status + +✅ **Implemented:** +- Field selection (select/omit columns) +- Filtering (field filters, search filters, operators) +- Preloading relations +- Sorting and pagination +- Skip count optimization +- Response format options +- Base64 decoding + +⚠️ **Partially Implemented:** +- Expand (currently falls back to preload) +- DISTINCT (depends on ORM adapter) + +🚧 **Planned:** +- Advanced SQL expressions (advsql, cql-sel) +- Custom SQL joins +- Cursor pagination +- Row number fetching +- Full expand with JOIN +- Query caching control + +--- + +## Security Considerations + +1. **SQL Injection**: Custom SQL headers (`x-custom-sql-*`) should be properly sanitized or restricted to trusted users only. + +2. **Query Complexity**: Consider implementing query complexity limits to prevent resource exhaustion. + +3. **Authentication**: Implement proper authentication and authorization checks before processing requests. + +4. **Rate Limiting**: Apply rate limiting to prevent abuse. + +5. **Field Restrictions**: Consider implementing field-level permissions to restrict access to sensitive columns. + +--- + +## Performance Tips + +1. Use `x-skipcount: true` for large datasets when you don't need the total count +2. Select only needed columns with `x-select-fields` +3. Use preload wisely - only load relations you need +4. Implement proper database indexes for filtered and sorted columns +5. Consider pagination for large result sets + +--- + +## Migration from ResolveSpec + +RestHeadSpec is an alternative to ResolveSpec that uses headers instead of request body for options: + +**ResolveSpec (body-based):** +```json +POST /api/departments +{ + "operation": "read", + "options": { + "preload": [{"relation": "employees"}], + "filters": [{"column": "status", "operator": "eq", "value": "active"}], + "limit": 50 + } +} +``` + +**RestHeadSpec (header-based):** +```http +GET /api/departments +x-preload: employees +x-fieldfilter-status: active +x-limit: 50 +``` + +Both implementations share the same core handler logic and database adapters. diff --git a/pkg/restheadspec/handler.go b/pkg/restheadspec/handler.go new file mode 100644 index 0000000..c98b2f0 --- /dev/null +++ b/pkg/restheadspec/handler.go @@ -0,0 +1,616 @@ +package restheadspec + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "reflect" + "strings" + + "github.com/Warky-Devs/ResolveSpec/pkg/common" + "github.com/Warky-Devs/ResolveSpec/pkg/logger" +) + +// Handler handles API requests using database and model abstractions +// This handler reads filters, columns, and options from HTTP headers +type Handler struct { + db common.Database + registry common.ModelRegistry +} + +// NewHandler creates a new API handler with database and registry abstractions +func NewHandler(db common.Database, registry common.ModelRegistry) *Handler { + return &Handler{ + db: db, + registry: registry, + } +} + +// Handle processes API requests through router-agnostic interface +// Options are read from HTTP headers instead of request body +func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[string]string) { + ctx := context.Background() + + schema := params["schema"] + entity := params["entity"] + id := params["id"] + + // Parse options from headers (now returns ExtendedRequestOptions) + options := h.parseOptionsFromHeaders(r) + + // Determine operation based on HTTP method + method := r.Method() + + logger.Info("Handling %s request for %s.%s", method, schema, entity) + + switch method { + case "GET": + if id != "" { + // GET with ID - read single record + h.handleRead(ctx, w, schema, entity, id, options) + } else { + // GET without ID - read multiple records + h.handleRead(ctx, w, schema, entity, "", options) + } + case "POST": + // Create operation + body, err := r.Body() + if err != nil { + logger.Error("Failed to read request body: %v", err) + h.sendError(w, http.StatusBadRequest, "invalid_request", "Failed to read request body", err) + return + } + var data interface{} + if err := json.Unmarshal(body, &data); err != nil { + logger.Error("Failed to decode request body: %v", err) + h.sendError(w, http.StatusBadRequest, "invalid_request", "Invalid request body", err) + return + } + h.handleCreate(ctx, w, schema, entity, data, options) + case "PUT", "PATCH": + // Update operation + body, err := r.Body() + if err != nil { + logger.Error("Failed to read request body: %v", err) + h.sendError(w, http.StatusBadRequest, "invalid_request", "Failed to read request body", err) + return + } + var data interface{} + if err := json.Unmarshal(body, &data); err != nil { + logger.Error("Failed to decode request body: %v", err) + h.sendError(w, http.StatusBadRequest, "invalid_request", "Invalid request body", err) + return + } + h.handleUpdate(ctx, w, schema, entity, id, nil, data, options) + case "DELETE": + h.handleDelete(ctx, w, schema, entity, id) + default: + logger.Error("Invalid HTTP method: %s", method) + h.sendError(w, http.StatusMethodNotAllowed, "invalid_method", "Invalid HTTP method", nil) + } +} + +// HandleGet processes GET requests for metadata +func (h *Handler) HandleGet(w common.ResponseWriter, r common.Request, params map[string]string) { + schema := params["schema"] + entity := params["entity"] + + logger.Info("Getting metadata for %s.%s", schema, entity) + + model, err := h.registry.GetModelByEntity(schema, entity) + if err != nil { + logger.Error("Failed to get model: %v", err) + h.sendError(w, http.StatusBadRequest, "invalid_entity", "Invalid entity", err) + return + } + + metadata := h.generateMetadata(schema, entity, model) + h.sendResponse(w, metadata, nil) +} + +// parseOptionsFromHeaders is now implemented in headers.go + +func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, schema, entity, id string, options ExtendedRequestOptions) { + logger.Info("Reading records from %s.%s", schema, entity) + + model, err := h.registry.GetModelByEntity(schema, entity) + if err != nil { + logger.Error("Invalid entity: %v", err) + h.sendError(w, http.StatusBadRequest, "invalid_entity", "Invalid entity", err) + return + } + + query := h.db.NewSelect().Model(model) + + // Get table name + tableName := h.getTableName(schema, entity, model) + query = query.Table(tableName) + + // Apply column selection + if len(options.Columns) > 0 { + logger.Debug("Selecting columns: %v", options.Columns) + query = query.Column(options.Columns...) + } + + // Apply preloading + for _, preload := range options.Preload { + logger.Debug("Applying preload: %s", preload.Relation) + query = query.Preload(preload.Relation) + } + + // Apply expand (LEFT JOIN) + for _, expand := range options.Expand { + logger.Debug("Applying expand: %s", expand.Relation) + // Note: Expand would require JOIN implementation + // For now, we'll use Preload as a fallback + query = query.Preload(expand.Relation) + } + + // Apply DISTINCT if requested + if options.Distinct { + logger.Debug("Applying DISTINCT") + // Note: DISTINCT implementation depends on ORM support + // This may need to be handled differently per database adapter + } + + // Apply filters + for _, filter := range options.Filters { + logger.Debug("Applying filter: %s %s %v", filter.Column, filter.Operator, filter.Value) + query = h.applyFilter(query, filter) + } + + // Apply custom SQL WHERE clause (AND condition) + if options.CustomSQLWhere != "" { + logger.Debug("Applying custom SQL WHERE: %s", options.CustomSQLWhere) + query = query.Where(options.CustomSQLWhere) + } + + // Apply custom SQL WHERE clause (OR condition) + if options.CustomSQLOr != "" { + logger.Debug("Applying custom SQL OR: %s", options.CustomSQLOr) + query = query.WhereOr(options.CustomSQLOr) + } + + // If ID is provided, filter by ID + if id != "" { + logger.Debug("Filtering by ID: %s", id) + query = query.Where("id = ?", id) + } + + // Apply sorting + for _, sort := range options.Sort { + direction := "ASC" + if strings.ToLower(sort.Direction) == "desc" { + direction = "DESC" + } + logger.Debug("Applying sort: %s %s", sort.Column, direction) + query = query.Order(fmt.Sprintf("%s %s", sort.Column, direction)) + } + + // Get total count before pagination (unless skip count is requested) + var total int + if !options.SkipCount { + count, err := query.Count(ctx) + if err != nil { + logger.Error("Error counting records: %v", err) + h.sendError(w, http.StatusInternalServerError, "query_error", "Error counting records", err) + return + } + total = count + logger.Debug("Total records: %d", total) + } else { + logger.Debug("Skipping count as requested") + total = -1 // Indicate count was skipped + } + + // Apply pagination + if options.Limit != nil && *options.Limit > 0 { + logger.Debug("Applying limit: %d", *options.Limit) + query = query.Limit(*options.Limit) + } + if options.Offset != nil && *options.Offset > 0 { + logger.Debug("Applying offset: %d", *options.Offset) + query = query.Offset(*options.Offset) + } + + // Execute query + resultSlice := reflect.New(reflect.SliceOf(reflect.TypeOf(model))).Interface() + if err := query.Scan(ctx, resultSlice); err != nil { + logger.Error("Error executing query: %v", err) + h.sendError(w, http.StatusInternalServerError, "query_error", "Error executing query", err) + return + } + + limit := 0 + if options.Limit != nil { + limit = *options.Limit + } + offset := 0 + if options.Offset != nil { + offset = *options.Offset + } + + metadata := &common.Metadata{ + Total: int64(total), + Filtered: int64(total), + Limit: limit, + Offset: offset, + } + + h.sendFormattedResponse(w, resultSlice, metadata, options) +} + +func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, schema, entity string, data interface{}, options ExtendedRequestOptions) { + logger.Info("Creating record in %s.%s", schema, entity) + + model, err := h.registry.GetModelByEntity(schema, entity) + if err != nil { + logger.Error("Invalid entity: %v", err) + h.sendError(w, http.StatusBadRequest, "invalid_entity", "Invalid entity", err) + return + } + + tableName := h.getTableName(schema, entity, model) + + // Handle batch creation + dataValue := reflect.ValueOf(data) + if dataValue.Kind() == reflect.Slice || dataValue.Kind() == reflect.Array { + logger.Debug("Batch creation detected, count: %d", dataValue.Len()) + + // Use transaction for batch insert + err := h.db.RunInTransaction(ctx, func(tx common.Database) error { + for i := 0; i < dataValue.Len(); i++ { + item := dataValue.Index(i).Interface() + + // Convert item to model type + modelValue := reflect.New(reflect.TypeOf(model).Elem()).Interface() + jsonData, err := json.Marshal(item) + if err != nil { + return fmt.Errorf("failed to marshal item: %w", err) + } + if err := json.Unmarshal(jsonData, modelValue); err != nil { + return fmt.Errorf("failed to unmarshal item: %w", err) + } + + query := tx.NewInsert().Model(modelValue).Table(tableName) + if _, err := query.Exec(ctx); err != nil { + return fmt.Errorf("failed to insert record: %w", err) + } + } + return nil + }) + + if err != nil { + logger.Error("Error creating records: %v", err) + h.sendError(w, http.StatusInternalServerError, "create_error", "Error creating records", err) + return + } + + h.sendResponse(w, map[string]interface{}{"created": dataValue.Len()}, nil) + return + } + + // Single record creation + modelValue := reflect.New(reflect.TypeOf(model).Elem()).Interface() + jsonData, err := json.Marshal(data) + if err != nil { + logger.Error("Error marshaling data: %v", err) + h.sendError(w, http.StatusBadRequest, "invalid_data", "Invalid data format", err) + return + } + if err := json.Unmarshal(jsonData, modelValue); err != nil { + logger.Error("Error unmarshaling data: %v", err) + h.sendError(w, http.StatusBadRequest, "invalid_data", "Invalid data format", err) + return + } + + query := h.db.NewInsert().Model(modelValue).Table(tableName) + if _, err := query.Exec(ctx); err != nil { + logger.Error("Error creating record: %v", err) + h.sendError(w, http.StatusInternalServerError, "create_error", "Error creating record", err) + return + } + + h.sendResponse(w, modelValue, nil) +} + +func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, schema, entity, id string, idPtr *int64, data interface{}, options ExtendedRequestOptions) { + logger.Info("Updating record in %s.%s", schema, entity) + + model, err := h.registry.GetModelByEntity(schema, entity) + if err != nil { + logger.Error("Invalid entity: %v", err) + h.sendError(w, http.StatusBadRequest, "invalid_entity", "Invalid entity", err) + return + } + + tableName := h.getTableName(schema, entity, model) + + // Convert data to map + dataMap, ok := data.(map[string]interface{}) + if !ok { + jsonData, err := json.Marshal(data) + if err != nil { + logger.Error("Error marshaling data: %v", err) + h.sendError(w, http.StatusBadRequest, "invalid_data", "Invalid data format", err) + return + } + if err := json.Unmarshal(jsonData, &dataMap); err != nil { + logger.Error("Error unmarshaling data: %v", err) + h.sendError(w, http.StatusBadRequest, "invalid_data", "Invalid data format", err) + return + } + } + + query := h.db.NewUpdate().Table(tableName).SetMap(dataMap) + + // Apply ID filter + if id != "" { + query = query.Where("id = ?", id) + } else if idPtr != nil { + query = query.Where("id = ?", *idPtr) + } else { + h.sendError(w, http.StatusBadRequest, "missing_id", "ID is required for update", nil) + return + } + + result, err := query.Exec(ctx) + if err != nil { + logger.Error("Error updating record: %v", err) + h.sendError(w, http.StatusInternalServerError, "update_error", "Error updating record", err) + return + } + + h.sendResponse(w, map[string]interface{}{ + "updated": result.RowsAffected(), + }, nil) +} + +func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, schema, entity, id string) { + logger.Info("Deleting record from %s.%s", schema, entity) + + model, err := h.registry.GetModelByEntity(schema, entity) + if err != nil { + logger.Error("Invalid entity: %v", err) + h.sendError(w, http.StatusBadRequest, "invalid_entity", "Invalid entity", err) + return + } + + tableName := h.getTableName(schema, entity, model) + + query := h.db.NewDelete().Table(tableName) + + if id == "" { + h.sendError(w, http.StatusBadRequest, "missing_id", "ID is required for delete", nil) + return + } + + query = query.Where("id = ?", id) + + result, err := query.Exec(ctx) + if err != nil { + logger.Error("Error deleting record: %v", err) + h.sendError(w, http.StatusInternalServerError, "delete_error", "Error deleting record", err) + return + } + + h.sendResponse(w, map[string]interface{}{ + "deleted": result.RowsAffected(), + }, nil) +} + +func (h *Handler) applyFilter(query common.SelectQuery, filter common.FilterOption) common.SelectQuery { + switch strings.ToLower(filter.Operator) { + case "eq", "equals": + return query.Where(fmt.Sprintf("%s = ?", filter.Column), filter.Value) + case "neq", "not_equals", "ne": + return query.Where(fmt.Sprintf("%s != ?", filter.Column), filter.Value) + case "gt", "greater_than": + return query.Where(fmt.Sprintf("%s > ?", filter.Column), filter.Value) + case "gte", "greater_than_equals", "ge": + return query.Where(fmt.Sprintf("%s >= ?", filter.Column), filter.Value) + case "lt", "less_than": + return query.Where(fmt.Sprintf("%s < ?", filter.Column), filter.Value) + case "lte", "less_than_equals", "le": + return query.Where(fmt.Sprintf("%s <= ?", filter.Column), filter.Value) + case "like": + return query.Where(fmt.Sprintf("%s LIKE ?", filter.Column), filter.Value) + case "ilike": + // Use ILIKE for case-insensitive search (PostgreSQL) + // For other databases, cast to citext or use LOWER() + return query.Where(fmt.Sprintf("CAST(%s AS TEXT) ILIKE ?", filter.Column), filter.Value) + case "in": + return query.Where(fmt.Sprintf("%s IN (?)", filter.Column), filter.Value) + case "between": + // Handle between operator - exclusive (> val1 AND < val2) + if values, ok := filter.Value.([]interface{}); ok && len(values) == 2 { + return query.Where(fmt.Sprintf("%s > ? AND %s < ?", filter.Column, filter.Column), values[0], values[1]) + } else if values, ok := filter.Value.([]string); ok && len(values) == 2 { + return query.Where(fmt.Sprintf("%s > ? AND %s < ?", filter.Column, filter.Column), values[0], values[1]) + } + logger.Warn("Invalid BETWEEN filter value format") + return query + case "between_inclusive": + // Handle between inclusive operator - inclusive (>= val1 AND <= val2) + if values, ok := filter.Value.([]interface{}); ok && len(values) == 2 { + return query.Where(fmt.Sprintf("%s >= ? AND %s <= ?", filter.Column, filter.Column), values[0], values[1]) + } else if values, ok := filter.Value.([]string); ok && len(values) == 2 { + return query.Where(fmt.Sprintf("%s >= ? AND %s <= ?", filter.Column, filter.Column), values[0], values[1]) + } + logger.Warn("Invalid BETWEEN INCLUSIVE filter value format") + return query + case "is_null", "isnull": + // Check for NULL values + return query.Where(fmt.Sprintf("(%s IS NULL OR %s = '')", filter.Column, filter.Column)) + case "is_not_null", "isnotnull": + // Check for NOT NULL values + return query.Where(fmt.Sprintf("(%s IS NOT NULL AND %s != '')", filter.Column, filter.Column)) + default: + logger.Warn("Unknown filter operator: %s, defaulting to equals", filter.Operator) + return query.Where(fmt.Sprintf("%s = ?", filter.Column), filter.Value) + } +} + +func (h *Handler) getTableName(schema, entity string, model interface{}) string { + // Check if model implements TableNameProvider + if provider, ok := model.(common.TableNameProvider); ok { + tableName := provider.TableName() + if tableName != "" { + return tableName + } + } + + // Default to schema.entity + if schema != "" { + return fmt.Sprintf("%s.%s", schema, entity) + } + return entity +} + +func (h *Handler) generateMetadata(schema, entity string, model interface{}) *common.TableMetadata { + modelType := reflect.TypeOf(model) + if modelType.Kind() == reflect.Ptr { + modelType = modelType.Elem() + } + + tableName := h.getTableName(schema, entity, model) + + metadata := &common.TableMetadata{ + Schema: schema, + Table: tableName, + Columns: []common.Column{}, + } + + for i := 0; i < modelType.NumField(); i++ { + field := modelType.Field(i) + + // Get column name from gorm tag or json tag + columnName := field.Tag.Get("gorm") + if strings.Contains(columnName, "column:") { + parts := strings.Split(columnName, ";") + for _, part := range parts { + if strings.HasPrefix(part, "column:") { + columnName = strings.TrimPrefix(part, "column:") + break + } + } + } else { + columnName = field.Tag.Get("json") + if columnName == "" || columnName == "-" { + columnName = strings.ToLower(field.Name) + } + } + + // Check for primary key and unique constraint + gormTag := field.Tag.Get("gorm") + + column := common.Column{ + Name: columnName, + Type: h.getColumnType(field.Type), + IsNullable: h.isNullable(field), + IsPrimary: strings.Contains(gormTag, "primaryKey") || strings.Contains(gormTag, "primary_key"), + IsUnique: strings.Contains(gormTag, "unique"), + HasIndex: strings.Contains(gormTag, "index"), + } + + metadata.Columns = append(metadata.Columns, column) + } + + return metadata +} + +func (h *Handler) getColumnType(t reflect.Type) string { + switch t.Kind() { + case reflect.String: + return "string" + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return "integer" + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return "integer" + case reflect.Float32, reflect.Float64: + return "float" + case reflect.Bool: + return "boolean" + case reflect.Ptr: + return h.getColumnType(t.Elem()) + default: + return "unknown" + } +} + +func (h *Handler) isNullable(field reflect.StructField) bool { + return field.Type.Kind() == reflect.Ptr +} + +func (h *Handler) sendResponse(w common.ResponseWriter, data interface{}, metadata *common.Metadata) { + response := common.Response{ + Success: true, + Data: data, + Metadata: metadata, + } + w.WriteHeader(http.StatusOK) + w.WriteJSON(response) +} + +// sendFormattedResponse sends response with formatting options +func (h *Handler) sendFormattedResponse(w common.ResponseWriter, data interface{}, metadata *common.Metadata, options ExtendedRequestOptions) { + // Clean JSON if requested (remove null/empty fields) + if options.CleanJSON { + data = h.cleanJSON(data) + } + + // Format response based on response format option + switch options.ResponseFormat { + case "simple": + // Simple format: just return the data array + w.WriteHeader(http.StatusOK) + w.WriteJSON(data) + case "syncfusion": + // Syncfusion format: { result: data, count: total } + response := map[string]interface{}{ + "result": data, + } + if metadata != nil { + response["count"] = metadata.Total + } + w.WriteHeader(http.StatusOK) + w.WriteJSON(response) + default: + // Default/detail format: standard response with metadata + response := common.Response{ + Success: true, + Data: data, + Metadata: metadata, + } + w.WriteHeader(http.StatusOK) + w.WriteJSON(response) + } +} + +// cleanJSON removes null and empty fields from the response +func (h *Handler) cleanJSON(data interface{}) interface{} { + // This is a simplified implementation + // A full implementation would recursively clean nested structures + // For now, we'll return the data as-is + // TODO: Implement recursive cleaning + return data +} + +func (h *Handler) sendError(w common.ResponseWriter, statusCode int, code, message string, err error) { + var details string + if err != nil { + details = err.Error() + } + + response := common.Response{ + Success: false, + Error: &common.APIError{ + Code: code, + Message: message, + Details: details, + }, + } + w.WriteHeader(statusCode) + w.WriteJSON(response) +} diff --git a/pkg/restheadspec/headers.go b/pkg/restheadspec/headers.go new file mode 100644 index 0000000..c6774b4 --- /dev/null +++ b/pkg/restheadspec/headers.go @@ -0,0 +1,441 @@ +package restheadspec + +import ( + "encoding/base64" + "encoding/json" + "fmt" + "strconv" + "strings" + + "github.com/Warky-Devs/ResolveSpec/pkg/common" + "github.com/Warky-Devs/ResolveSpec/pkg/logger" +) + +// ExtendedRequestOptions extends common.RequestOptions with additional features +type ExtendedRequestOptions struct { + common.RequestOptions + + // Field selection + CleanJSON bool + + // Advanced filtering + SearchColumns []string + CustomSQLWhere string + CustomSQLOr string + + // Joins + Expand []ExpandOption + + // Advanced features + AdvancedSQL map[string]string // Column -> SQL expression + ComputedQL map[string]string // Column -> CQL expression + Distinct bool + SkipCount bool + SkipCache bool + FetchRowNumber *string + PKRow *string + + // Response format + ResponseFormat string // "simple", "detail", "syncfusion" + + // Transaction + AtomicTransaction bool + + // Cursor pagination + CursorForward string + CursorBackward string +} + +// ExpandOption represents a relation expansion configuration +type ExpandOption struct { + Relation string + Columns []string + Where string + Sort string +} + +// decodeHeaderValue decodes base64 encoded header values +// Supports ZIP_ and __ prefixes for base64 encoding +func decodeHeaderValue(value string) string { + // Check for ZIP_ prefix + if strings.HasPrefix(value, "ZIP_") { + decoded, err := base64.StdEncoding.DecodeString(value[4:]) + if err == nil { + return string(decoded) + } + logger.Warn("Failed to decode ZIP_ prefixed value: %v", err) + return value + } + + // Check for __ prefix + if strings.HasPrefix(value, "__") { + decoded, err := base64.StdEncoding.DecodeString(value[2:]) + if err == nil { + return string(decoded) + } + logger.Warn("Failed to decode __ prefixed value: %v", err) + return value + } + + return value +} + +// parseOptionsFromHeaders parses all request options from HTTP headers +func (h *Handler) parseOptionsFromHeaders(r common.Request) ExtendedRequestOptions { + options := ExtendedRequestOptions{ + RequestOptions: common.RequestOptions{ + Filters: make([]common.FilterOption, 0), + Sort: make([]common.SortOption, 0), + Preload: make([]common.PreloadOption, 0), + }, + AdvancedSQL: make(map[string]string), + ComputedQL: make(map[string]string), + Expand: make([]ExpandOption, 0), + } + + // Get all headers + headers := r.AllHeaders() + + // Process each header + for key, value := range headers { + // Normalize header key to lowercase for consistent matching + normalizedKey := strings.ToLower(key) + + // Decode value if it's base64 encoded + decodedValue := decodeHeaderValue(value) + + // Parse based on header prefix/name + switch { + // Field Selection + case strings.HasPrefix(normalizedKey, "x-select-fields"): + h.parseSelectFields(&options, decodedValue) + case strings.HasPrefix(normalizedKey, "x-not-select-fields"): + h.parseNotSelectFields(&options, decodedValue) + case strings.HasPrefix(normalizedKey, "x-clean-json"): + options.CleanJSON = strings.ToLower(decodedValue) == "true" + + // Filtering & Search + case strings.HasPrefix(normalizedKey, "x-fieldfilter-"): + h.parseFieldFilter(&options, normalizedKey, decodedValue) + case strings.HasPrefix(normalizedKey, "x-searchfilter-"): + h.parseSearchFilter(&options, normalizedKey, decodedValue) + case strings.HasPrefix(normalizedKey, "x-searchop-"): + h.parseSearchOp(&options, normalizedKey, decodedValue, "AND") + case strings.HasPrefix(normalizedKey, "x-searchor-"): + h.parseSearchOp(&options, normalizedKey, decodedValue, "OR") + case strings.HasPrefix(normalizedKey, "x-searchand-"): + h.parseSearchOp(&options, normalizedKey, decodedValue, "AND") + case strings.HasPrefix(normalizedKey, "x-searchcols"): + options.SearchColumns = h.parseCommaSeparated(decodedValue) + case strings.HasPrefix(normalizedKey, "x-custom-sql-w"): + options.CustomSQLWhere = decodedValue + case strings.HasPrefix(normalizedKey, "x-custom-sql-or"): + options.CustomSQLOr = decodedValue + + // Joins & Relations + case strings.HasPrefix(normalizedKey, "x-preload"): + h.parsePreload(&options, decodedValue) + case strings.HasPrefix(normalizedKey, "x-expand"): + h.parseExpand(&options, decodedValue) + case strings.HasPrefix(normalizedKey, "x-custom-sql-join"): + // TODO: Implement custom SQL join + logger.Debug("Custom SQL join not yet implemented: %s", decodedValue) + + // Sorting & Pagination + case strings.HasPrefix(normalizedKey, "x-sort"): + h.parseSorting(&options, decodedValue) + case strings.HasPrefix(normalizedKey, "x-limit"): + if limit, err := strconv.Atoi(decodedValue); err == nil { + options.Limit = &limit + } + case strings.HasPrefix(normalizedKey, "x-offset"): + if offset, err := strconv.Atoi(decodedValue); err == nil { + options.Offset = &offset + } + case strings.HasPrefix(normalizedKey, "x-cursor-forward"): + options.CursorForward = decodedValue + case strings.HasPrefix(normalizedKey, "x-cursor-backward"): + options.CursorBackward = decodedValue + + // Advanced Features + case strings.HasPrefix(normalizedKey, "x-advsql-"): + colName := strings.TrimPrefix(normalizedKey, "x-advsql-") + options.AdvancedSQL[colName] = decodedValue + case strings.HasPrefix(normalizedKey, "x-cql-sel-"): + colName := strings.TrimPrefix(normalizedKey, "x-cql-sel-") + options.ComputedQL[colName] = decodedValue + case strings.HasPrefix(normalizedKey, "x-distinct"): + options.Distinct = strings.ToLower(decodedValue) == "true" + case strings.HasPrefix(normalizedKey, "x-skipcount"): + options.SkipCount = strings.ToLower(decodedValue) == "true" + case strings.HasPrefix(normalizedKey, "x-skipcache"): + options.SkipCache = strings.ToLower(decodedValue) == "true" + case strings.HasPrefix(normalizedKey, "x-fetch-rownumber"): + options.FetchRowNumber = &decodedValue + case strings.HasPrefix(normalizedKey, "x-pkrow"): + options.PKRow = &decodedValue + + // Response Format + case strings.HasPrefix(normalizedKey, "x-simpleapi"): + options.ResponseFormat = "simple" + case strings.HasPrefix(normalizedKey, "x-detailapi"): + options.ResponseFormat = "detail" + case strings.HasPrefix(normalizedKey, "x-syncfusion"): + options.ResponseFormat = "syncfusion" + + // Transaction Control + case strings.HasPrefix(normalizedKey, "x-transaction-atomic"): + options.AtomicTransaction = strings.ToLower(decodedValue) == "true" + } + } + + return options +} + +// parseSelectFields parses x-select-fields header +func (h *Handler) parseSelectFields(options *ExtendedRequestOptions, value string) { + if value == "" { + return + } + options.Columns = h.parseCommaSeparated(value) +} + +// parseNotSelectFields parses x-not-select-fields header +func (h *Handler) parseNotSelectFields(options *ExtendedRequestOptions, value string) { + if value == "" { + return + } + options.OmitColumns = h.parseCommaSeparated(value) +} + +// parseFieldFilter parses x-fieldfilter-{colname} header (exact match) +func (h *Handler) parseFieldFilter(options *ExtendedRequestOptions, headerKey, value string) { + colName := strings.TrimPrefix(headerKey, "x-fieldfilter-") + options.Filters = append(options.Filters, common.FilterOption{ + Column: colName, + Operator: "eq", + Value: value, + }) +} + +// parseSearchFilter parses x-searchfilter-{colname} header (ILIKE search) +func (h *Handler) parseSearchFilter(options *ExtendedRequestOptions, headerKey, value string) { + colName := strings.TrimPrefix(headerKey, "x-searchfilter-") + // Use ILIKE for fuzzy search + options.Filters = append(options.Filters, common.FilterOption{ + Column: colName, + Operator: "ilike", + Value: "%" + value + "%", + }) +} + +// parseSearchOp parses x-searchop-{operator}-{colname} and x-searchor-{operator}-{colname} +func (h *Handler) parseSearchOp(options *ExtendedRequestOptions, headerKey, value, logicOp string) { + // Extract operator and column name + // Format: x-searchop-{operator}-{colname} or x-searchor-{operator}-{colname} + var prefix string + if logicOp == "OR" { + prefix = "x-searchor-" + } else { + prefix = "x-searchop-" + if strings.HasPrefix(headerKey, "x-searchand-") { + prefix = "x-searchand-" + } + } + + rest := strings.TrimPrefix(headerKey, prefix) + parts := strings.SplitN(rest, "-", 2) + if len(parts) != 2 { + logger.Warn("Invalid search operator header format: %s", headerKey) + return + } + + operator := parts[0] + colName := parts[1] + + // Map operator names to filter operators + filterOp := h.mapSearchOperator(operator, value) + + options.Filters = append(options.Filters, filterOp) + + // Note: OR logic would need special handling in query builder + // For now, we'll add a comment to indicate OR logic + if logicOp == "OR" { + // TODO: Implement OR logic in query builder + logger.Debug("OR logic filter: %s %s %v", colName, filterOp.Operator, filterOp.Value) + } +} + +// mapSearchOperator maps search operator names to filter operators +func (h *Handler) mapSearchOperator(operator, value string) common.FilterOption { + operator = strings.ToLower(operator) + + switch operator { + case "contains": + return common.FilterOption{Operator: "ilike", Value: "%" + value + "%"} + case "beginswith", "startswith": + return common.FilterOption{Operator: "ilike", Value: value + "%"} + case "endswith": + return common.FilterOption{Operator: "ilike", Value: "%" + value} + case "equals", "eq": + return common.FilterOption{Operator: "eq", Value: value} + case "notequals", "neq", "ne": + return common.FilterOption{Operator: "neq", Value: value} + case "greaterthan", "gt": + return common.FilterOption{Operator: "gt", Value: value} + case "lessthan", "lt": + return common.FilterOption{Operator: "lt", Value: value} + case "greaterthanorequal", "gte", "ge": + return common.FilterOption{Operator: "gte", Value: value} + case "lessthanorequal", "lte", "le": + return common.FilterOption{Operator: "lte", Value: value} + case "between": + // Parse between values (format: "value1,value2") + // Between is exclusive (> value1 AND < value2) + parts := strings.Split(value, ",") + if len(parts) == 2 { + return common.FilterOption{Operator: "between", Value: parts} + } + return common.FilterOption{Operator: "eq", Value: value} + case "betweeninclusive": + // Parse between values (format: "value1,value2") + // Between inclusive is >= value1 AND <= value2 + parts := strings.Split(value, ",") + if len(parts) == 2 { + return common.FilterOption{Operator: "between_inclusive", Value: parts} + } + return common.FilterOption{Operator: "eq", Value: value} + case "in": + // Parse IN values (format: "value1,value2,value3") + values := strings.Split(value, ",") + return common.FilterOption{Operator: "in", Value: values} + case "empty", "isnull", "null": + // Check for NULL or empty string + return common.FilterOption{Operator: "is_null", Value: nil} + case "notempty", "isnotnull", "notnull": + // Check for NOT NULL + return common.FilterOption{Operator: "is_not_null", Value: nil} + default: + logger.Warn("Unknown search operator: %s, defaulting to equals", operator) + return common.FilterOption{Operator: "eq", Value: value} + } +} + +// parsePreload parses x-preload header +// Format: RelationName:field1,field2 or RelationName or multiple separated by | +func (h *Handler) parsePreload(options *ExtendedRequestOptions, value string) { + if value == "" { + return + } + + // Split by | for multiple preloads + preloads := strings.Split(value, "|") + for _, preloadStr := range preloads { + preloadStr = strings.TrimSpace(preloadStr) + if preloadStr == "" { + continue + } + + // Parse relation:columns format + parts := strings.SplitN(preloadStr, ":", 2) + preload := common.PreloadOption{ + Relation: strings.TrimSpace(parts[0]), + } + + if len(parts) == 2 { + // Parse columns + preload.Columns = h.parseCommaSeparated(parts[1]) + } + + options.Preload = append(options.Preload, preload) + } +} + +// parseExpand parses x-expand header (LEFT JOIN expansion) +// Format: RelationName:field1,field2 or RelationName or multiple separated by | +func (h *Handler) parseExpand(options *ExtendedRequestOptions, value string) { + if value == "" { + return + } + + // Split by | for multiple expands + expands := strings.Split(value, "|") + for _, expandStr := range expands { + expandStr = strings.TrimSpace(expandStr) + if expandStr == "" { + continue + } + + // Parse relation:columns format + parts := strings.SplitN(expandStr, ":", 2) + expand := ExpandOption{ + Relation: strings.TrimSpace(parts[0]), + } + + if len(parts) == 2 { + // Parse columns + expand.Columns = h.parseCommaSeparated(parts[1]) + } + + options.Expand = append(options.Expand, expand) + } +} + +// parseSorting parses x-sort header +// Format: +field1,-field2,field3 (+ for ASC, - for DESC, default ASC) +func (h *Handler) parseSorting(options *ExtendedRequestOptions, value string) { + if value == "" { + return + } + + sortFields := h.parseCommaSeparated(value) + for _, field := range sortFields { + field = strings.TrimSpace(field) + if field == "" { + continue + } + + direction := "ASC" + colName := field + + if strings.HasPrefix(field, "-") { + direction = "DESC" + colName = strings.TrimPrefix(field, "-") + } else if strings.HasPrefix(field, "+") { + direction = "ASC" + colName = strings.TrimPrefix(field, "+") + } + + options.Sort = append(options.Sort, common.SortOption{ + Column: colName, + Direction: direction, + }) + } +} + +// parseCommaSeparated parses comma-separated values and trims whitespace +func (h *Handler) parseCommaSeparated(value string) []string { + if value == "" { + return nil + } + + parts := strings.Split(value, ",") + result := make([]string, 0, len(parts)) + for _, part := range parts { + part = strings.TrimSpace(part) + if part != "" { + result = append(result, part) + } + } + return result +} + +// parseJSONHeader parses a header value as JSON +func (h *Handler) parseJSONHeader(value string) (map[string]interface{}, error) { + var result map[string]interface{} + err := json.Unmarshal([]byte(value), &result) + if err != nil { + return nil, fmt.Errorf("failed to parse JSON header: %w", err) + } + return result, nil +} diff --git a/pkg/restheadspec/restheadspec.go b/pkg/restheadspec/restheadspec.go new file mode 100644 index 0000000..6e7cadd --- /dev/null +++ b/pkg/restheadspec/restheadspec.go @@ -0,0 +1,203 @@ +package restheadspec + +import ( + "net/http" + + "github.com/Warky-Devs/ResolveSpec/pkg/common/adapters/database" + "github.com/Warky-Devs/ResolveSpec/pkg/common/adapters/router" + "github.com/Warky-Devs/ResolveSpec/pkg/modelregistry" + "github.com/gorilla/mux" + "github.com/uptrace/bun" + "github.com/uptrace/bunrouter" + "gorm.io/gorm" +) + +// NewHandlerWithGORM creates a new Handler with GORM adapter +func NewHandlerWithGORM(db *gorm.DB) *Handler { + gormAdapter := database.NewGormAdapter(db) + registry := modelregistry.NewModelRegistry() + return NewHandler(gormAdapter, registry) +} + +// NewHandlerWithBun creates a new Handler with Bun adapter +func NewHandlerWithBun(db *bun.DB) *Handler { + bunAdapter := database.NewBunAdapter(db) + registry := modelregistry.NewModelRegistry() + return NewHandler(bunAdapter, registry) +} + +// NewStandardMuxRouter creates a router with standard Mux HTTP handlers +func NewStandardMuxRouter() *router.StandardMuxAdapter { + return router.NewStandardMuxAdapter() +} + +// NewStandardBunRouter creates a router with standard BunRouter handlers +func NewStandardBunRouter() *router.StandardBunRouterAdapter { + return router.NewStandardBunRouterAdapter() +} + +// SetupMuxRoutes sets up routes for the RestHeadSpec API with Mux +func SetupMuxRoutes(muxRouter *mux.Router, handler *Handler) { + // GET, POST, PUT, PATCH, DELETE for /{schema}/{entity} + muxRouter.HandleFunc("/{schema}/{entity}", func(w http.ResponseWriter, r *http.Request) { + vars := mux.Vars(r) + reqAdapter := router.NewHTTPRequest(r) + respAdapter := router.NewHTTPResponseWriter(w) + handler.Handle(respAdapter, reqAdapter, vars) + }).Methods("GET", "POST") + + // GET, PUT, PATCH, DELETE for /{schema}/{entity}/{id} + muxRouter.HandleFunc("/{schema}/{entity}/{id}", func(w http.ResponseWriter, r *http.Request) { + vars := mux.Vars(r) + reqAdapter := router.NewHTTPRequest(r) + respAdapter := router.NewHTTPResponseWriter(w) + handler.Handle(respAdapter, reqAdapter, vars) + }).Methods("GET", "PUT", "PATCH", "DELETE") + + // GET for metadata (using HandleGet) + muxRouter.HandleFunc("/{schema}/{entity}/metadata", func(w http.ResponseWriter, r *http.Request) { + vars := mux.Vars(r) + reqAdapter := router.NewHTTPRequest(r) + respAdapter := router.NewHTTPResponseWriter(w) + handler.HandleGet(respAdapter, reqAdapter, vars) + }).Methods("GET") +} + +// Example usage functions for documentation: + +// ExampleWithGORM shows how to use RestHeadSpec with GORM +func ExampleWithGORM(db *gorm.DB) { + // Create handler using GORM + handler := NewHandlerWithGORM(db) + + // Setup router + muxRouter := mux.NewRouter() + SetupMuxRoutes(muxRouter, handler) + + // Register models + // handler.registry.RegisterModel("public.users", &User{}) +} + +// ExampleWithBun shows how to switch to Bun ORM +func ExampleWithBun(bunDB *bun.DB) { + // Create Bun adapter + dbAdapter := database.NewBunAdapter(bunDB) + + // Create model registry + registry := modelregistry.NewModelRegistry() + // registry.RegisterModel("public.users", &User{}) + + // Create handler + handler := NewHandler(dbAdapter, registry) + + // Setup routes + muxRouter := mux.NewRouter() + SetupMuxRoutes(muxRouter, handler) +} + +// SetupBunRouterRoutes sets up bunrouter routes for the RestHeadSpec API +func SetupBunRouterRoutes(bunRouter *router.StandardBunRouterAdapter, handler *Handler) { + r := bunRouter.GetBunRouter() + + // GET and POST for /:schema/:entity + r.Handle("GET", "/:schema/:entity", func(w http.ResponseWriter, req bunrouter.Request) error { + params := map[string]string{ + "schema": req.Param("schema"), + "entity": req.Param("entity"), + } + reqAdapter := router.NewBunRouterRequest(req) + respAdapter := router.NewHTTPResponseWriter(w) + handler.Handle(respAdapter, reqAdapter, params) + return nil + }) + + r.Handle("POST", "/:schema/:entity", func(w http.ResponseWriter, req bunrouter.Request) error { + params := map[string]string{ + "schema": req.Param("schema"), + "entity": req.Param("entity"), + } + reqAdapter := router.NewBunRouterRequest(req) + respAdapter := router.NewHTTPResponseWriter(w) + handler.Handle(respAdapter, reqAdapter, params) + return nil + }) + + // GET, PUT, PATCH, DELETE for /:schema/:entity/:id + r.Handle("GET", "/:schema/:entity/:id", func(w http.ResponseWriter, req bunrouter.Request) error { + params := map[string]string{ + "schema": req.Param("schema"), + "entity": req.Param("entity"), + "id": req.Param("id"), + } + reqAdapter := router.NewBunRouterRequest(req) + respAdapter := router.NewHTTPResponseWriter(w) + handler.Handle(respAdapter, reqAdapter, params) + return nil + }) + + r.Handle("PUT", "/:schema/:entity/:id", func(w http.ResponseWriter, req bunrouter.Request) error { + params := map[string]string{ + "schema": req.Param("schema"), + "entity": req.Param("entity"), + "id": req.Param("id"), + } + reqAdapter := router.NewBunRouterRequest(req) + respAdapter := router.NewHTTPResponseWriter(w) + handler.Handle(respAdapter, reqAdapter, params) + return nil + }) + + r.Handle("PATCH", "/:schema/:entity/:id", func(w http.ResponseWriter, req bunrouter.Request) error { + params := map[string]string{ + "schema": req.Param("schema"), + "entity": req.Param("entity"), + "id": req.Param("id"), + } + reqAdapter := router.NewBunRouterRequest(req) + respAdapter := router.NewHTTPResponseWriter(w) + handler.Handle(respAdapter, reqAdapter, params) + return nil + }) + + r.Handle("DELETE", "/:schema/:entity/:id", func(w http.ResponseWriter, req bunrouter.Request) error { + params := map[string]string{ + "schema": req.Param("schema"), + "entity": req.Param("entity"), + "id": req.Param("id"), + } + reqAdapter := router.NewBunRouterRequest(req) + respAdapter := router.NewHTTPResponseWriter(w) + handler.Handle(respAdapter, reqAdapter, params) + return nil + }) + + // Metadata endpoint + r.Handle("GET", "/:schema/:entity/metadata", func(w http.ResponseWriter, req bunrouter.Request) error { + params := map[string]string{ + "schema": req.Param("schema"), + "entity": req.Param("entity"), + } + reqAdapter := router.NewBunRouterRequest(req) + respAdapter := router.NewHTTPResponseWriter(w) + handler.HandleGet(respAdapter, reqAdapter, params) + return nil + }) +} + +// ExampleBunRouterWithBunDB shows usage with both BunRouter and Bun DB +func ExampleBunRouterWithBunDB(bunDB *bun.DB) { + // Create handler + handler := NewHandlerWithBun(bunDB) + + // Create BunRouter adapter + routerAdapter := NewStandardBunRouter() + + // Setup routes + SetupBunRouterRoutes(routerAdapter, handler) + + // Get the underlying router for server setup + r := routerAdapter.GetBunRouter() + + // Start server + http.ListenAndServe(":8080", r) +} diff --git a/pkg/testmodels/business.go b/pkg/testmodels/business.go index d348654..d68d8d1 100644 --- a/pkg/testmodels/business.go +++ b/pkg/testmodels/business.go @@ -3,7 +3,7 @@ package testmodels import ( "time" - "github.com/Warky-Devs/ResolveSpec/pkg/models" + "github.com/Warky-Devs/ResolveSpec/pkg/modelregistry" ) // Department represents a company department @@ -138,11 +138,24 @@ func (Comment) TableName() string { return "comments" } -func RegisterTestModels() { - models.RegisterModel(&Department{}, "departments") - models.RegisterModel(&Employee{}, "employees") - models.RegisterModel(&Project{}, "projects") - models.RegisterModel(&ProjectTask{}, "project_tasks") - models.RegisterModel(&Document{}, "documents") - models.RegisterModel(&Comment{}, "comments") +// RegisterTestModels registers all test models with the provided registry +func RegisterTestModels(registry *modelregistry.DefaultModelRegistry) { + registry.RegisterModel("departments", &Department{}) + registry.RegisterModel("employees", &Employee{}) + registry.RegisterModel("projects", &Project{}) + registry.RegisterModel("project_tasks", &ProjectTask{}) + registry.RegisterModel("documents", &Document{}) + registry.RegisterModel("comments", &Comment{}) +} + +// GetTestModels returns a list of all test model instances +func GetTestModels() []interface{} { + return []interface{}{ + &Department{}, + &Employee{}, + &Project{}, + &ProjectTask{}, + &Document{}, + &Comment{}, + } } diff --git a/tests/test_helpers.go b/tests/test_helpers.go index 1691992..7127750 100644 --- a/tests/test_helpers.go +++ b/tests/test_helpers.go @@ -11,7 +11,7 @@ import ( "testing" "github.com/Warky-Devs/ResolveSpec/pkg/logger" - "github.com/Warky-Devs/ResolveSpec/pkg/models" + "github.com/Warky-Devs/ResolveSpec/pkg/modelregistry" "github.com/Warky-Devs/ResolveSpec/pkg/resolvespec" "github.com/Warky-Devs/ResolveSpec/pkg/testmodels" "github.com/glebarez/sqlite" @@ -104,9 +104,6 @@ func setupTestDB() (*gorm.DB, error) { return nil, fmt.Errorf("failed to open database: %v", err) } - // Init Models - testmodels.RegisterTestModels() - // Auto migrate all test models err = autoMigrateModels(db) if err != nil { @@ -119,17 +116,24 @@ func setupTestDB() (*gorm.DB, error) { // setupTestRouter creates and configures the test router func setupTestRouter(db *gorm.DB) http.Handler { r := mux.NewRouter() - handler := resolvespec.NewAPIHandler(db) - r.HandleFunc("/{schema}/{entity}", func(w http.ResponseWriter, r *http.Request) { - vars := mux.Vars(r) - handler.Handle(w, r, vars) - }).Methods("POST") + // Create a new registry instance + registry := modelregistry.NewModelRegistry() - r.HandleFunc("/{schema}/{entity}/{id}", func(w http.ResponseWriter, r *http.Request) { - vars := mux.Vars(r) - handler.Handle(w, r, vars) - }).Methods("POST") + // Register test models with the registry + testmodels.RegisterTestModels(registry) + + // Create handler with GORM adapter and the registry + handler := resolvespec.NewHandlerWithGORM(db) + + // Register test models with the handler for the "test" schema + models := testmodels.GetTestModels() + modelNames := []string{"departments", "employees", "projects", "project_tasks", "documents", "comments"} + for i, model := range models { + handler.RegisterModel("test", modelNames[i], model) + } + + resolvespec.SetupMuxRoutes(r, handler) return r } @@ -147,6 +151,6 @@ func cleanup() { // autoMigrateModels performs automigration for all test models func autoMigrateModels(db *gorm.DB) error { - modelList := models.GetModels() + modelList := testmodels.GetTestModels() return db.AutoMigrate(modelList...) }