mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2025-12-29 15:54:26 +00:00
Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
aa99e8e4bc | ||
|
|
163593901f | ||
|
|
1261960e97 | ||
|
|
76bbf33db2 |
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/uptrace/bun"
|
||||
@@ -99,12 +100,20 @@ func (b *BunAdapter) RunInTransaction(ctx context.Context, fn func(common.Databa
|
||||
|
||||
// BunSelectQuery implements SelectQuery for Bun
|
||||
type BunSelectQuery struct {
|
||||
query *bun.SelectQuery
|
||||
db bun.IDB // Store DB connection for count queries
|
||||
hasModel bool // Track if Model() was called
|
||||
schema string // Separated schema name
|
||||
tableName string // Just the table name, without schema
|
||||
tableAlias string
|
||||
query *bun.SelectQuery
|
||||
db bun.IDB // Store DB connection for count queries
|
||||
hasModel bool // Track if Model() was called
|
||||
schema string // Separated schema name
|
||||
tableName string // Just the table name, without schema
|
||||
tableAlias string
|
||||
deferredPreloads []deferredPreload // Preloads to execute as separate queries
|
||||
}
|
||||
|
||||
// deferredPreload represents a preload that will be executed as a separate query
|
||||
// to avoid PostgreSQL identifier length limits
|
||||
type deferredPreload struct {
|
||||
relation string
|
||||
apply []func(common.SelectQuery) common.SelectQuery
|
||||
}
|
||||
|
||||
func (b *BunSelectQuery) Model(model interface{}) common.SelectQuery {
|
||||
@@ -233,7 +242,92 @@ func (b *BunSelectQuery) Preload(relation string, conditions ...interface{}) com
|
||||
return b
|
||||
}
|
||||
|
||||
// // shortenAliasForPostgres shortens a table/relation alias if it would exceed PostgreSQL's 63-char limit
|
||||
// // when combined with typical column names
|
||||
// func shortenAliasForPostgres(relationPath string) (string, bool) {
|
||||
// // Convert relation path to the alias format Bun uses: dots become double underscores
|
||||
// // Also convert to lowercase and use snake_case as Bun does
|
||||
// parts := strings.Split(relationPath, ".")
|
||||
// alias := strings.ToLower(strings.Join(parts, "__"))
|
||||
|
||||
// // PostgreSQL truncates identifiers to 63 chars
|
||||
// // If the alias + typical column name would exceed this, we need to shorten
|
||||
// // Reserve at least 30 chars for column names (e.g., "__rid_mastertype_hubtype")
|
||||
// const maxAliasLength = 30
|
||||
|
||||
// if len(alias) > maxAliasLength {
|
||||
// // Create a shortened alias using a hash of the original
|
||||
// hash := md5.Sum([]byte(alias))
|
||||
// hashStr := hex.EncodeToString(hash[:])[:8]
|
||||
|
||||
// // Keep first few chars of original for readability + hash
|
||||
// prefixLen := maxAliasLength - 9 // 9 = 1 underscore + 8 hash chars
|
||||
// if prefixLen > len(alias) {
|
||||
// prefixLen = len(alias)
|
||||
// }
|
||||
|
||||
// shortened := alias[:prefixLen] + "_" + hashStr
|
||||
// logger.Debug("Shortened alias '%s' (%d chars) to '%s' (%d chars) to avoid PostgreSQL 63-char limit",
|
||||
// alias, len(alias), shortened, len(shortened))
|
||||
// return shortened, true
|
||||
// }
|
||||
|
||||
// return alias, false
|
||||
// }
|
||||
|
||||
// // estimateColumnAliasLength estimates the length of a column alias in a nested preload
|
||||
// // Bun creates aliases like: relationChain__columnName
|
||||
// func estimateColumnAliasLength(relationPath string, columnName string) int {
|
||||
// relationParts := strings.Split(relationPath, ".")
|
||||
// aliasChain := strings.ToLower(strings.Join(relationParts, "__"))
|
||||
// // Bun adds "__" between alias and column name
|
||||
// return len(aliasChain) + 2 + len(columnName)
|
||||
// }
|
||||
|
||||
func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
|
||||
// Check if this relation chain would create problematic long aliases
|
||||
relationParts := strings.Split(relation, ".")
|
||||
aliasChain := strings.ToLower(strings.Join(relationParts, "__"))
|
||||
|
||||
// PostgreSQL's identifier limit is 63 characters
|
||||
const postgresIdentifierLimit = 63
|
||||
const safeAliasLimit = 35 // Leave room for column names
|
||||
|
||||
// If the alias chain is too long, defer this preload to be executed as a separate query
|
||||
if len(aliasChain) > safeAliasLimit {
|
||||
logger.Info("Preload relation '%s' creates long alias chain '%s' (%d chars). "+
|
||||
"Using separate query to avoid PostgreSQL %d-char identifier limit.",
|
||||
relation, aliasChain, len(aliasChain), postgresIdentifierLimit)
|
||||
|
||||
// For nested preloads (e.g., "Parent.Child"), split into separate preloads
|
||||
// This avoids the long concatenated alias
|
||||
if len(relationParts) > 1 {
|
||||
// Load first level normally: "Parent"
|
||||
firstLevel := relationParts[0]
|
||||
remainingPath := strings.Join(relationParts[1:], ".")
|
||||
|
||||
logger.Info("Splitting nested preload: loading '%s' first, then '%s' separately",
|
||||
firstLevel, remainingPath)
|
||||
|
||||
// Apply the first level preload normally
|
||||
b.query = b.query.Relation(firstLevel)
|
||||
|
||||
// Store the remaining nested preload to be executed after the main query
|
||||
b.deferredPreloads = append(b.deferredPreloads, deferredPreload{
|
||||
relation: relation,
|
||||
apply: apply,
|
||||
})
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
// Single level but still too long - just warn and continue
|
||||
logger.Warn("Single-level preload '%s' has a very long name (%d chars). "+
|
||||
"Consider renaming the field to avoid potential issues.",
|
||||
relation, len(aliasChain))
|
||||
}
|
||||
|
||||
// Normal preload handling
|
||||
b.query = b.query.Relation(relation, func(sq *bun.SelectQuery) *bun.SelectQuery {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
@@ -309,7 +403,23 @@ func (b *BunSelectQuery) Scan(ctx context.Context, dest interface{}) (err error)
|
||||
if dest == nil {
|
||||
return fmt.Errorf("destination cannot be nil")
|
||||
}
|
||||
return b.query.Scan(ctx, dest)
|
||||
|
||||
// Execute the main query first
|
||||
err = b.query.Scan(ctx, dest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Execute any deferred preloads
|
||||
if len(b.deferredPreloads) > 0 {
|
||||
err = b.executeDeferredPreloads(ctx, dest)
|
||||
if err != nil {
|
||||
logger.Warn("Failed to execute deferred preloads: %v", err)
|
||||
// Don't fail the whole query, just log the warning
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) {
|
||||
@@ -322,7 +432,132 @@ func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) {
|
||||
return fmt.Errorf("model is nil")
|
||||
}
|
||||
|
||||
return b.query.Scan(ctx)
|
||||
// Execute the main query first
|
||||
err = b.query.Scan(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Execute any deferred preloads
|
||||
if len(b.deferredPreloads) > 0 {
|
||||
model := b.query.GetModel()
|
||||
err = b.executeDeferredPreloads(ctx, model.Value())
|
||||
if err != nil {
|
||||
logger.Warn("Failed to execute deferred preloads: %v", err)
|
||||
// Don't fail the whole query, just log the warning
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// executeDeferredPreloads executes preloads that were deferred to avoid PostgreSQL identifier length limits
|
||||
func (b *BunSelectQuery) executeDeferredPreloads(ctx context.Context, dest interface{}) error {
|
||||
if len(b.deferredPreloads) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, dp := range b.deferredPreloads {
|
||||
err := b.executeSingleDeferredPreload(ctx, dest, dp)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to execute deferred preload '%s': %w", dp.relation, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// executeSingleDeferredPreload executes a single deferred preload
|
||||
// For a relation like "Parent.Child", it:
|
||||
// 1. Finds all loaded Parent records in dest
|
||||
// 2. Loads Child records for those Parents using a separate query (loading only "Child", not "Parent.Child")
|
||||
// 3. Bun automatically assigns the Child records to the appropriate Parent.Child field
|
||||
func (b *BunSelectQuery) executeSingleDeferredPreload(ctx context.Context, dest interface{}, dp deferredPreload) error {
|
||||
relationParts := strings.Split(dp.relation, ".")
|
||||
if len(relationParts) < 2 {
|
||||
return fmt.Errorf("deferred preload must be nested (e.g., 'Parent.Child'), got: %s", dp.relation)
|
||||
}
|
||||
|
||||
// The parent relation that was already loaded
|
||||
parentRelation := relationParts[0]
|
||||
// The child relation we need to load
|
||||
childRelation := strings.Join(relationParts[1:], ".")
|
||||
|
||||
logger.Debug("Executing deferred preload: loading '%s' on already-loaded '%s'", childRelation, parentRelation)
|
||||
|
||||
// Use reflection to access the parent relation field(s) in the loaded records
|
||||
// Then load the child relation for those parent records
|
||||
destValue := reflect.ValueOf(dest)
|
||||
if destValue.Kind() == reflect.Ptr {
|
||||
destValue = destValue.Elem()
|
||||
}
|
||||
|
||||
// Handle both slice and single record
|
||||
if destValue.Kind() == reflect.Slice {
|
||||
// Iterate through each record in the slice
|
||||
for i := 0; i < destValue.Len(); i++ {
|
||||
record := destValue.Index(i)
|
||||
if err := b.loadChildRelationForRecord(ctx, record, parentRelation, childRelation, dp.apply); err != nil {
|
||||
logger.Warn("Failed to load child relation '%s' for record %d: %v", childRelation, i, err)
|
||||
// Continue with other records
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Single record
|
||||
if err := b.loadChildRelationForRecord(ctx, destValue, parentRelation, childRelation, dp.apply); err != nil {
|
||||
return fmt.Errorf("failed to load child relation '%s': %w", childRelation, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadChildRelationForRecord loads a child relation for a single parent record
|
||||
func (b *BunSelectQuery) loadChildRelationForRecord(ctx context.Context, record reflect.Value, parentRelation, childRelation string, apply []func(common.SelectQuery) common.SelectQuery) error {
|
||||
// Ensure we're working with the actual struct value, not a pointer
|
||||
if record.Kind() == reflect.Ptr {
|
||||
record = record.Elem()
|
||||
}
|
||||
|
||||
// Get the parent relation field
|
||||
parentField := record.FieldByName(parentRelation)
|
||||
if !parentField.IsValid() {
|
||||
// Parent relation field doesn't exist
|
||||
logger.Debug("Parent relation field '%s' not found in record", parentRelation)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if the parent field is nil (for pointer fields)
|
||||
if parentField.Kind() == reflect.Ptr && parentField.IsNil() {
|
||||
// Parent relation not loaded or nil, skip
|
||||
logger.Debug("Parent relation field '%s' is nil, skipping child preload", parentRelation)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get the interface value to pass to Bun
|
||||
parentValue := parentField.Interface()
|
||||
|
||||
// Load the child relation on the parent record
|
||||
// This uses a shorter alias since we're only loading "Child", not "Parent.Child"
|
||||
return b.db.NewSelect().
|
||||
Model(parentValue).
|
||||
Relation(childRelation, func(sq *bun.SelectQuery) *bun.SelectQuery {
|
||||
// Apply any custom query modifications
|
||||
if len(apply) > 0 {
|
||||
wrapper := &BunSelectQuery{query: sq, db: b.db}
|
||||
current := common.SelectQuery(wrapper)
|
||||
for _, fn := range apply {
|
||||
if fn != nil {
|
||||
current = fn(current)
|
||||
}
|
||||
}
|
||||
if finalBun, ok := current.(*BunSelectQuery); ok {
|
||||
return finalBun.query
|
||||
}
|
||||
}
|
||||
return sq
|
||||
}).
|
||||
Scan(ctx)
|
||||
}
|
||||
|
||||
func (b *BunSelectQuery) Count(ctx context.Context) (count int, err error) {
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
package common
|
||||
|
||||
import "context"
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// Database interface designed to work with both GORM and Bun
|
||||
type Database interface {
|
||||
@@ -130,6 +135,99 @@ type ResponseWriter interface {
|
||||
// HTTPHandlerFunc type for HTTP handlers
|
||||
type HTTPHandlerFunc func(ResponseWriter, Request)
|
||||
|
||||
// WrapHTTPRequest wraps standard http.ResponseWriter and *http.Request into common interfaces
|
||||
func WrapHTTPRequest(w http.ResponseWriter, r *http.Request) (ResponseWriter, Request) {
|
||||
return &StandardResponseWriter{w: w}, &StandardRequest{r: r}
|
||||
}
|
||||
|
||||
// StandardResponseWriter adapts http.ResponseWriter to ResponseWriter interface
|
||||
type StandardResponseWriter struct {
|
||||
w http.ResponseWriter
|
||||
status int
|
||||
}
|
||||
|
||||
func (s *StandardResponseWriter) SetHeader(key, value string) {
|
||||
s.w.Header().Set(key, value)
|
||||
}
|
||||
|
||||
func (s *StandardResponseWriter) WriteHeader(statusCode int) {
|
||||
s.status = statusCode
|
||||
s.w.WriteHeader(statusCode)
|
||||
}
|
||||
|
||||
func (s *StandardResponseWriter) Write(data []byte) (int, error) {
|
||||
return s.w.Write(data)
|
||||
}
|
||||
|
||||
func (s *StandardResponseWriter) WriteJSON(data interface{}) error {
|
||||
s.SetHeader("Content-Type", "application/json")
|
||||
return json.NewEncoder(s.w).Encode(data)
|
||||
}
|
||||
|
||||
// StandardRequest adapts *http.Request to Request interface
|
||||
type StandardRequest struct {
|
||||
r *http.Request
|
||||
body []byte
|
||||
}
|
||||
|
||||
func (s *StandardRequest) Method() string {
|
||||
return s.r.Method
|
||||
}
|
||||
|
||||
func (s *StandardRequest) URL() string {
|
||||
return s.r.URL.String()
|
||||
}
|
||||
|
||||
func (s *StandardRequest) Header(key string) string {
|
||||
return s.r.Header.Get(key)
|
||||
}
|
||||
|
||||
func (s *StandardRequest) AllHeaders() map[string]string {
|
||||
headers := make(map[string]string)
|
||||
for key, values := range s.r.Header {
|
||||
if len(values) > 0 {
|
||||
headers[key] = values[0]
|
||||
}
|
||||
}
|
||||
return headers
|
||||
}
|
||||
|
||||
func (s *StandardRequest) Body() ([]byte, error) {
|
||||
if s.body != nil {
|
||||
return s.body, nil
|
||||
}
|
||||
if s.r.Body == nil {
|
||||
return nil, nil
|
||||
}
|
||||
defer s.r.Body.Close()
|
||||
body, err := io.ReadAll(s.r.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.body = body
|
||||
return body, nil
|
||||
}
|
||||
|
||||
func (s *StandardRequest) PathParam(key string) string {
|
||||
// Standard http.Request doesn't have path params
|
||||
// This should be set by the router
|
||||
return ""
|
||||
}
|
||||
|
||||
func (s *StandardRequest) QueryParam(key string) string {
|
||||
return s.r.URL.Query().Get(key)
|
||||
}
|
||||
|
||||
func (s *StandardRequest) AllQueryParams() map[string]string {
|
||||
params := make(map[string]string)
|
||||
for key, values := range s.r.URL.Query() {
|
||||
if len(values) > 0 {
|
||||
params[key] = values[0]
|
||||
}
|
||||
}
|
||||
return params
|
||||
}
|
||||
|
||||
// TableNameProvider interface for models that provide table names
|
||||
type TableNameProvider interface {
|
||||
TableName() string
|
||||
|
||||
@@ -199,7 +199,9 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
||||
// Apply column selection
|
||||
if len(options.Columns) > 0 {
|
||||
logger.Debug("Selecting columns: %v", options.Columns)
|
||||
query = query.Column(options.Columns...)
|
||||
for _, col := range options.Columns {
|
||||
query = query.Column(reflection.ExtractSourceColumn(col))
|
||||
}
|
||||
}
|
||||
|
||||
if len(options.ComputedColumns) > 0 {
|
||||
|
||||
@@ -213,6 +213,10 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
||||
tableName := GetTableName(ctx)
|
||||
model := GetModel(ctx)
|
||||
|
||||
if id == "" {
|
||||
options.SingleRecordAsObject = false
|
||||
}
|
||||
|
||||
// Execute BeforeRead hooks
|
||||
hookCtx := &HookContext{
|
||||
Context: ctx,
|
||||
@@ -299,7 +303,10 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
||||
// Apply column selection
|
||||
if len(options.Columns) > 0 {
|
||||
logger.Debug("Selecting columns: %v", options.Columns)
|
||||
query = query.Column(options.Columns...)
|
||||
for _, col := range options.Columns {
|
||||
query = query.Column(reflection.ExtractSourceColumn(col))
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// Apply expand (Just expand to Preload for now)
|
||||
@@ -652,7 +659,6 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co
|
||||
sq = sq.Order(fmt.Sprintf("%s %s", sort.Column, sort.Direction))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Apply WHERE clause
|
||||
if len(preload.Where) > 0 {
|
||||
|
||||
@@ -162,9 +162,17 @@ func (h *Handler) parseOptionsFromHeaders(r common.Request, model interface{}) E
|
||||
case strings.HasPrefix(key, "x-searchcols"):
|
||||
options.SearchColumns = h.parseCommaSeparated(decodedValue)
|
||||
case strings.HasPrefix(key, "x-custom-sql-w"):
|
||||
options.CustomSQLWhere = decodedValue
|
||||
if options.CustomSQLWhere != "" {
|
||||
options.CustomSQLWhere = fmt.Sprintf("%s AND (%s)", options.CustomSQLWhere, decodedValue)
|
||||
} else {
|
||||
options.CustomSQLWhere = decodedValue
|
||||
}
|
||||
case strings.HasPrefix(key, "x-custom-sql-or"):
|
||||
options.CustomSQLOr = decodedValue
|
||||
if options.CustomSQLOr != "" {
|
||||
options.CustomSQLOr = fmt.Sprintf("%s OR (%s)", options.CustomSQLOr, decodedValue)
|
||||
} else {
|
||||
options.CustomSQLOr = decodedValue
|
||||
}
|
||||
|
||||
// Joins & Relations
|
||||
case strings.HasPrefix(key, "x-preload"):
|
||||
@@ -226,6 +234,7 @@ func (h *Handler) parseOptionsFromHeaders(r common.Request, model interface{}) E
|
||||
case strings.HasPrefix(key, "x-cql-sel-"):
|
||||
colName := strings.TrimPrefix(key, "x-cql-sel-")
|
||||
options.ComputedQL[colName] = decodedValue
|
||||
|
||||
case strings.HasPrefix(key, "x-distinct"):
|
||||
options.Distinct = strings.EqualFold(decodedValue, "true")
|
||||
case strings.HasPrefix(key, "x-skipcount"):
|
||||
@@ -267,7 +276,7 @@ func (h *Handler) parseOptionsFromHeaders(r common.Request, model interface{}) E
|
||||
h.resolveRelationNamesInOptions(&options, model)
|
||||
}
|
||||
|
||||
//Always sort according to the primary key if no sorting is specified
|
||||
// Always sort according to the primary key if no sorting is specified
|
||||
if len(options.Sort) == 0 {
|
||||
pkName := reflection.GetPrimaryKeyName(model)
|
||||
options.Sort = []common.SortOption{{Column: pkName, Direction: "ASC"}}
|
||||
@@ -783,7 +792,7 @@ func (h *Handler) resolveRelationName(model interface{}, nameOrTable string) str
|
||||
field := modelType.Field(i)
|
||||
if field.Name == nameOrTable {
|
||||
// It's already a field name
|
||||
logger.Debug("Input '%s' is a field name", nameOrTable)
|
||||
// logger.Debug("Input '%s' is a field name", nameOrTable)
|
||||
return nameOrTable
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user