Compare commits

..

11 Commits

Author SHA1 Message Date
Hein
7a498edab7 fix(headers): enhance relation name resolution logic
* Allow resolution for both regular headers and X-Files.
* Introduce join-key-aware resolution for disambiguation.
* Add new function to handle multiple fields pointing to the same type.
2026-03-25 12:09:03 +02:00
Hein
f10bb0827e fix(sql_helpers): ensure case-insensitive matching for allowed prefixes 2026-03-25 10:57:42 +02:00
Hein
22a4ab345a feat(security): add session cookie management functions
* Introduce SessionCookieOptions for configurable session cookies
* Implement SetSessionCookie, GetSessionCookie, and ClearSessionCookie functions
* Enhance cookie handling in DatabaseAuthenticator
2026-03-24 17:11:53 +02:00
Hein
e289c2ed8f fix(handler): restore JoinAliases for proper WHERE sanitization 2026-03-24 12:00:02 +02:00
Hein
0d50bcfee6 fix(provider): enhance file opening logic with alternate path. Handling broken cases to be compatible with Bitech clients
* Implemented alternate path handling for file retrieval
* Improved error messaging for file not found scenarios
2026-03-24 09:02:17 +02:00
4df626ea71 chore(license): update project notice and clarify licensing terms 2026-03-23 20:32:09 +02:00
Hein
7dd630dec2 fix(handler): set default sort to primary key if none provided
Some checks failed
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Successful in -26m15s
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Successful in -26m11s
Build , Vet Test, and Lint / Lint Code (push) Failing after -30m52s
Build , Vet Test, and Lint / Build (push) Successful in -30m44s
Tests / Integration Tests (push) Failing after -31m5s
Tests / Unit Tests (push) Successful in -29m6s
2026-03-11 14:37:04 +02:00
Hein
613bf22cbd fix(cursor): use full schema-qualified table name in filters 2026-03-11 14:25:44 +02:00
d1ae4fe64e refactor(handler): unify filter operator handling for consistency
Some checks failed
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Successful in -30m26s
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Successful in -29m58s
Build , Vet Test, and Lint / Lint Code (push) Successful in -29m48s
Build , Vet Test, and Lint / Build (push) Successful in -30m4s
Tests / Integration Tests (push) Failing after -30m39s
Tests / Unit Tests (push) Successful in -30m29s
2026-03-01 13:21:38 +02:00
254102bfac refactor(auth): simplify handler type assertions for middleware
Some checks failed
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Successful in -30m2s
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Successful in -29m31s
Build , Vet Test, and Lint / Lint Code (push) Successful in -29m19s
Build , Vet Test, and Lint / Build (push) Successful in -29m42s
Tests / Integration Tests (push) Failing after -30m35s
Tests / Unit Tests (push) Successful in -30m17s
2026-03-01 12:08:36 +02:00
6c27419dbc refactor(auth): enhance request handling with middleware-enriched context 2026-03-01 12:06:43 +02:00
16 changed files with 406 additions and 94 deletions

15
LICENSE
View File

@@ -1,3 +1,18 @@
Project Notice
This project was independently developed.
The contents of this repository were prepared and published outside any time
allocated to Bitech Systems CC and do not contain, incorporate, disclose,
or rely upon any proprietary or confidential information, trade secrets,
protected designs, or other intellectual property of Bitech Systems CC.
No portion of this repository reproduces any Bitech Systems CC-specific
implementation, design asset, confidential workflow, or non-public technical material.
This notice is provided for clarification only and does not modify the terms of
the Apache License, Version 2.0.
Apache License Apache License
Version 2.0, January 2004 Version 2.0, January 2004
http://www.apache.org/licenses/ http://www.apache.org/licenses/

View File

@@ -168,16 +168,17 @@ func SanitizeWhereClause(where string, tableName string, options ...*RequestOpti
} }
// Build a set of allowed table prefixes (main table + preloaded relations) // Build a set of allowed table prefixes (main table + preloaded relations)
// Keys are stored lowercase for case-insensitive matching
allowedPrefixes := make(map[string]bool) allowedPrefixes := make(map[string]bool)
if tableName != "" { if tableName != "" {
allowedPrefixes[tableName] = true allowedPrefixes[strings.ToLower(tableName)] = true
} }
// Add preload relation names as allowed prefixes // Add preload relation names as allowed prefixes
if len(options) > 0 && options[0] != nil { if len(options) > 0 && options[0] != nil {
for pi := range options[0].Preload { for pi := range options[0].Preload {
if options[0].Preload[pi].Relation != "" { if options[0].Preload[pi].Relation != "" {
allowedPrefixes[options[0].Preload[pi].Relation] = true allowedPrefixes[strings.ToLower(options[0].Preload[pi].Relation)] = true
logger.Debug("Added preload relation '%s' as allowed table prefix", options[0].Preload[pi].Relation) logger.Debug("Added preload relation '%s' as allowed table prefix", options[0].Preload[pi].Relation)
} }
} }
@@ -185,7 +186,7 @@ func SanitizeWhereClause(where string, tableName string, options ...*RequestOpti
// Add join aliases as allowed prefixes // Add join aliases as allowed prefixes
for _, alias := range options[0].JoinAliases { for _, alias := range options[0].JoinAliases {
if alias != "" { if alias != "" {
allowedPrefixes[alias] = true allowedPrefixes[strings.ToLower(alias)] = true
logger.Debug("Added join alias '%s' as allowed table prefix", alias) logger.Debug("Added join alias '%s' as allowed table prefix", alias)
} }
} }
@@ -217,8 +218,8 @@ func SanitizeWhereClause(where string, tableName string, options ...*RequestOpti
currentPrefix, columnName := extractTableAndColumn(condToCheck) currentPrefix, columnName := extractTableAndColumn(condToCheck)
if currentPrefix != "" && columnName != "" { if currentPrefix != "" && columnName != "" {
// Check if the prefix is allowed (main table or preload relation) // Check if the prefix is allowed (main table or preload relation) - case-insensitive
if !allowedPrefixes[currentPrefix] { if !allowedPrefixes[strings.ToLower(currentPrefix)] {
// Prefix is not in the allowed list - only fix if it's a valid column in the main table // Prefix is not in the allowed list - only fix if it's a valid column in the main table
if validColumns == nil || isValidColumn(columnName, validColumns) { if validColumns == nil || isValidColumn(columnName, validColumns) {
// Replace the incorrect prefix with the correct main table name // Replace the incorrect prefix with the correct main table name

View File

@@ -32,7 +32,8 @@ func GetCursorFilter(
modelColumns []string, modelColumns []string,
options common.RequestOptions, options common.RequestOptions,
) (string, error) { ) (string, error) {
// Remove schema prefix if present // Separate schema prefix from bare table name
fullTableName := tableName
if strings.Contains(tableName, ".") { if strings.Contains(tableName, ".") {
tableName = strings.SplitN(tableName, ".", 2)[1] tableName = strings.SplitN(tableName, ".", 2)[1]
} }
@@ -115,7 +116,7 @@ func GetCursorFilter(
WHERE cursor_select.%s = %s WHERE cursor_select.%s = %s
AND (%s) AND (%s)
)`, )`,
tableName, fullTableName,
pkName, pkName,
cursorID, cursorID,
orSQL, orSQL,

View File

@@ -175,9 +175,9 @@ func TestGetCursorFilter_WithSchemaPrefix(t *testing.T) {
t.Fatalf("GetCursorFilter failed: %v", err) t.Fatalf("GetCursorFilter failed: %v", err)
} }
// Should handle schema prefix properly // Should include full schema-qualified name in FROM clause
if !strings.Contains(filter, "users") { if !strings.Contains(filter, "public.users") {
t.Errorf("Filter should reference table name users, got: %s", filter) t.Errorf("Filter FROM clause should use schema-qualified name public.users, got: %s", filter)
} }
t.Logf("Generated cursor filter with schema: %s", filter) t.Logf("Generated cursor filter with schema: %s", filter)

View File

@@ -329,6 +329,11 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
// Extract model columns for validation // Extract model columns for validation
modelColumns := reflection.GetModelColumns(model) modelColumns := reflection.GetModelColumns(model)
// Default sort to primary key when none provided
if len(options.Sort) == 0 {
options.Sort = []common.SortOption{{Column: pkName, Direction: "ASC"}}
}
// Get cursor filter SQL // Get cursor filter SQL
cursorFilter, err := GetCursorFilter(tableName, pkName, modelColumns, options) cursorFilter, err := GetCursorFilter(tableName, pkName, modelColumns, options)
if err != nil { if err != nil {
@@ -1521,22 +1526,22 @@ func (h *Handler) buildFilterCondition(filter common.FilterOption) (conditionStr
var args []interface{} var args []interface{}
switch filter.Operator { switch filter.Operator {
case "eq": case "eq", "=":
condition = fmt.Sprintf("%s = ?", filter.Column) condition = fmt.Sprintf("%s = ?", filter.Column)
args = []interface{}{filter.Value} args = []interface{}{filter.Value}
case "neq": case "neq", "!=", "<>":
condition = fmt.Sprintf("%s != ?", filter.Column) condition = fmt.Sprintf("%s != ?", filter.Column)
args = []interface{}{filter.Value} args = []interface{}{filter.Value}
case "gt": case "gt", ">":
condition = fmt.Sprintf("%s > ?", filter.Column) condition = fmt.Sprintf("%s > ?", filter.Column)
args = []interface{}{filter.Value} args = []interface{}{filter.Value}
case "gte": case "gte", ">=":
condition = fmt.Sprintf("%s >= ?", filter.Column) condition = fmt.Sprintf("%s >= ?", filter.Column)
args = []interface{}{filter.Value} args = []interface{}{filter.Value}
case "lt": case "lt", "<":
condition = fmt.Sprintf("%s < ?", filter.Column) condition = fmt.Sprintf("%s < ?", filter.Column)
args = []interface{}{filter.Value} args = []interface{}{filter.Value}
case "lte": case "lte", "<=":
condition = fmt.Sprintf("%s <= ?", filter.Column) condition = fmt.Sprintf("%s <= ?", filter.Column)
args = []interface{}{filter.Value} args = []interface{}{filter.Value}
case "like": case "like":
@@ -1565,22 +1570,22 @@ func (h *Handler) applyFilter(query common.SelectQuery, filter common.FilterOpti
var args []interface{} var args []interface{}
switch filter.Operator { switch filter.Operator {
case "eq": case "eq", "=":
condition = fmt.Sprintf("%s = ?", filter.Column) condition = fmt.Sprintf("%s = ?", filter.Column)
args = []interface{}{filter.Value} args = []interface{}{filter.Value}
case "neq": case "neq", "!=", "<>":
condition = fmt.Sprintf("%s != ?", filter.Column) condition = fmt.Sprintf("%s != ?", filter.Column)
args = []interface{}{filter.Value} args = []interface{}{filter.Value}
case "gt": case "gt", ">":
condition = fmt.Sprintf("%s > ?", filter.Column) condition = fmt.Sprintf("%s > ?", filter.Column)
args = []interface{}{filter.Value} args = []interface{}{filter.Value}
case "gte": case "gte", ">=":
condition = fmt.Sprintf("%s >= ?", filter.Column) condition = fmt.Sprintf("%s >= ?", filter.Column)
args = []interface{}{filter.Value} args = []interface{}{filter.Value}
case "lt": case "lt", "<":
condition = fmt.Sprintf("%s < ?", filter.Column) condition = fmt.Sprintf("%s < ?", filter.Column)
args = []interface{}{filter.Value} args = []interface{}{filter.Value}
case "lte": case "lte", "<=":
condition = fmt.Sprintf("%s <= ?", filter.Column) condition = fmt.Sprintf("%s <= ?", filter.Column)
args = []interface{}{filter.Value} args = []interface{}{filter.Value}
case "like": case "like":

View File

@@ -70,17 +70,17 @@ func SetupMuxRoutes(muxRouter *mux.Router, handler *Handler, authMiddleware Midd
entityWithIDPath := buildRoutePath(schema, entity) + "/{id}" entityWithIDPath := buildRoutePath(schema, entity) + "/{id}"
// Create handler functions for this specific entity // Create handler functions for this specific entity
postEntityHandler := createMuxHandler(handler, schema, entity, "") var postEntityHandler http.Handler = createMuxHandler(handler, schema, entity, "")
postEntityWithIDHandler := createMuxHandler(handler, schema, entity, "id") var postEntityWithIDHandler http.Handler = createMuxHandler(handler, schema, entity, "id")
getEntityHandler := createMuxGetHandler(handler, schema, entity, "") var getEntityHandler http.Handler = createMuxGetHandler(handler, schema, entity, "")
optionsEntityHandler := createMuxOptionsHandler(handler, schema, entity, []string{"GET", "POST", "OPTIONS"}) optionsEntityHandler := createMuxOptionsHandler(handler, schema, entity, []string{"GET", "POST", "OPTIONS"})
optionsEntityWithIDHandler := createMuxOptionsHandler(handler, schema, entity, []string{"POST", "OPTIONS"}) optionsEntityWithIDHandler := createMuxOptionsHandler(handler, schema, entity, []string{"POST", "OPTIONS"})
// Apply authentication middleware if provided // Apply authentication middleware if provided
if authMiddleware != nil { if authMiddleware != nil {
postEntityHandler = authMiddleware(postEntityHandler).(http.HandlerFunc) postEntityHandler = authMiddleware(postEntityHandler)
postEntityWithIDHandler = authMiddleware(postEntityWithIDHandler).(http.HandlerFunc) postEntityWithIDHandler = authMiddleware(postEntityWithIDHandler)
getEntityHandler = authMiddleware(getEntityHandler).(http.HandlerFunc) getEntityHandler = authMiddleware(getEntityHandler)
// Don't apply auth middleware to OPTIONS - CORS preflight must not require auth // Don't apply auth middleware to OPTIONS - CORS preflight must not require auth
} }
@@ -225,7 +225,11 @@ func wrapBunRouterHandler(handler bunrouter.HandlerFunc, authMiddleware Middlewa
return func(w http.ResponseWriter, req bunrouter.Request) error { return func(w http.ResponseWriter, req bunrouter.Request) error {
// Create an http.Handler that calls the bunrouter handler // Create an http.Handler that calls the bunrouter handler
httpHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { httpHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_ = handler(w, req) // Replace the embedded *http.Request with the middleware-enriched one
// so that auth context (user ID, etc.) is visible to the handler.
enrichedReq := req
enrichedReq.Request = r
_ = handler(w, enrichedReq)
}) })
// Wrap with auth middleware and execute // Wrap with auth middleware and execute

View File

@@ -32,6 +32,8 @@ func (opts *ExtendedRequestOptions) GetCursorFilter(
modelColumns []string, // optional: for validation modelColumns []string, // optional: for validation
expandJoins map[string]string, // optional: alias → JOIN SQL expandJoins map[string]string, // optional: alias → JOIN SQL
) (string, error) { ) (string, error) {
// Separate schema prefix from bare table name
fullTableName := tableName
if strings.Contains(tableName, ".") { if strings.Contains(tableName, ".") {
tableName = strings.SplitN(tableName, ".", 2)[1] tableName = strings.SplitN(tableName, ".", 2)[1]
} }
@@ -127,7 +129,7 @@ func (opts *ExtendedRequestOptions) GetCursorFilter(
WHERE cursor_select.%s = %s WHERE cursor_select.%s = %s
AND (%s) AND (%s)
)`, )`,
tableName, fullTableName,
joinSQL, joinSQL,
pkName, pkName,
cursorID, cursorID,

View File

@@ -187,9 +187,9 @@ func TestGetCursorFilter_WithSchemaPrefix(t *testing.T) {
t.Fatalf("GetCursorFilter failed: %v", err) t.Fatalf("GetCursorFilter failed: %v", err)
} }
// Should handle schema prefix properly // Should include full schema-qualified name in FROM clause
if !strings.Contains(filter, "users") { if !strings.Contains(filter, "public.users") {
t.Errorf("Filter should reference table name users, got: %s", filter) t.Errorf("Filter FROM clause should use schema-qualified name public.users, got: %s", filter)
} }
t.Logf("Generated cursor filter with schema: %s", filter) t.Logf("Generated cursor filter with schema: %s", filter)

View File

@@ -731,6 +731,11 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
// For now, pass empty map as joins are handled via Preload // For now, pass empty map as joins are handled via Preload
} }
// Default sort to primary key when none provided
if len(options.Sort) == 0 {
options.Sort = []common.SortOption{{Column: pkName, Direction: "ASC"}}
}
// Get cursor filter SQL // Get cursor filter SQL
cursorFilter, err := options.GetCursorFilter(tableName, pkName, modelColumns, expandJoins) cursorFilter, err := options.GetCursorFilter(tableName, pkName, modelColumns, expandJoins)
if err != nil { if err != nil {
@@ -2226,17 +2231,17 @@ func (h *Handler) applyOrFilterGroup(query common.SelectQuery, filters []*common
// buildFilterCondition builds a single filter condition and returns the condition string and args // buildFilterCondition builds a single filter condition and returns the condition string and args
func (h *Handler) buildFilterCondition(qualifiedColumn string, filter *common.FilterOption, tableName string) (filterStr string, filterInterface []interface{}) { func (h *Handler) buildFilterCondition(qualifiedColumn string, filter *common.FilterOption, tableName string) (filterStr string, filterInterface []interface{}) {
switch strings.ToLower(filter.Operator) { switch strings.ToLower(filter.Operator) {
case "eq", "equals": case "eq", "equals", "=":
return fmt.Sprintf("%s = ?", qualifiedColumn), []interface{}{filter.Value} return fmt.Sprintf("%s = ?", qualifiedColumn), []interface{}{filter.Value}
case "neq", "not_equals", "ne": case "neq", "not_equals", "ne", "!=", "<>":
return fmt.Sprintf("%s != ?", qualifiedColumn), []interface{}{filter.Value} return fmt.Sprintf("%s != ?", qualifiedColumn), []interface{}{filter.Value}
case "gt", "greater_than": case "gt", "greater_than", ">":
return fmt.Sprintf("%s > ?", qualifiedColumn), []interface{}{filter.Value} return fmt.Sprintf("%s > ?", qualifiedColumn), []interface{}{filter.Value}
case "gte", "greater_than_equals", "ge": case "gte", "greater_than_equals", "ge", ">=":
return fmt.Sprintf("%s >= ?", qualifiedColumn), []interface{}{filter.Value} return fmt.Sprintf("%s >= ?", qualifiedColumn), []interface{}{filter.Value}
case "lt", "less_than": case "lt", "less_than", "<":
return fmt.Sprintf("%s < ?", qualifiedColumn), []interface{}{filter.Value} return fmt.Sprintf("%s < ?", qualifiedColumn), []interface{}{filter.Value}
case "lte", "less_than_equals", "le": case "lte", "less_than_equals", "le", "<=":
return fmt.Sprintf("%s <= ?", qualifiedColumn), []interface{}{filter.Value} return fmt.Sprintf("%s <= ?", qualifiedColumn), []interface{}{filter.Value}
case "like": case "like":
return fmt.Sprintf("%s LIKE ?", qualifiedColumn), []interface{}{filter.Value} return fmt.Sprintf("%s LIKE ?", qualifiedColumn), []interface{}{filter.Value}
@@ -2879,6 +2884,8 @@ func (h *Handler) filterExtendedOptions(validator *common.ColumnValidator, optio
// Filter base RequestOptions // Filter base RequestOptions
filtered.RequestOptions = validator.FilterRequestOptions(options.RequestOptions) filtered.RequestOptions = validator.FilterRequestOptions(options.RequestOptions)
// Restore JoinAliases cleared by FilterRequestOptions — still needed for SanitizeWhereClause
filtered.RequestOptions.JoinAliases = options.JoinAliases
// Filter SearchColumns // Filter SearchColumns
filtered.SearchColumns = validator.FilterValidColumns(options.SearchColumns) filtered.SearchColumns = validator.FilterValidColumns(options.SearchColumns)

View File

@@ -274,9 +274,11 @@ func (h *Handler) parseOptionsFromHeaders(r common.Request, model interface{}) E
} }
} }
// Resolve relation names (convert table names to field names) if model is provided // Resolve relation names (convert table names/prefixes to actual model field names) if model is provided.
// Skip resolution if X-Files header was provided, as XFiles uses Prefix which already contains the correct field names // This runs for both regular headers and X-Files, because XFile prefixes don't always match model
if model != nil && !options.XFilesPresent { // field names (e.g., prefix "HUB" vs field "HUB_RID_HUB"). RelatedKey/ForeignKey are used to
// disambiguate when multiple fields point to the same related type.
if model != nil {
h.resolveRelationNamesInOptions(&options, model) h.resolveRelationNamesInOptions(&options, model)
} }
@@ -863,8 +865,21 @@ func (h *Handler) resolveRelationNamesInOptions(options *ExtendedRequestOptions,
// Resolve each part of the path // Resolve each part of the path
currentModel := model currentModel := model
for _, part := range parts { for partIdx, part := range parts {
resolvedPart := h.resolveRelationName(currentModel, part) isLast := partIdx == len(parts)-1
var resolvedPart string
if isLast {
// For the final part, use join-key-aware resolution to disambiguate when
// multiple fields point to the same type (e.g., HUB_RID_HUB vs HUB_RID_ASSIGNEDTO).
// RelatedKey = parent's local column linking to child; ForeignKey = local column linking to parent.
localKey := preload.RelatedKey
if localKey == "" {
localKey = preload.ForeignKey
}
resolvedPart = h.resolveRelationNameWithJoinKey(currentModel, part, localKey)
} else {
resolvedPart = h.resolveRelationName(currentModel, part)
}
resolvedParts = append(resolvedParts, resolvedPart) resolvedParts = append(resolvedParts, resolvedPart)
// Try to get the model type for the next level // Try to get the model type for the next level
@@ -980,6 +995,101 @@ func (h *Handler) resolveRelationName(model interface{}, nameOrTable string) str
return nameOrTable return nameOrTable
} }
// resolveRelationNameWithJoinKey resolves a relation name like resolveRelationName, but when
// multiple fields point to the same related type, uses localKey to pick the one whose bun join
// tag starts with "join:localKey=". Falls back to resolveRelationName if no key match is found.
func (h *Handler) resolveRelationNameWithJoinKey(model interface{}, nameOrTable string, localKey string) string {
if localKey == "" {
return h.resolveRelationName(model, nameOrTable)
}
modelType := reflect.TypeOf(model)
if modelType == nil {
return nameOrTable
}
if modelType.Kind() == reflect.Ptr {
modelType = modelType.Elem()
}
if modelType == nil || modelType.Kind() != reflect.Struct {
return nameOrTable
}
// If it's already a direct field name, return as-is (no ambiguity).
for i := 0; i < modelType.NumField(); i++ {
if modelType.Field(i).Name == nameOrTable {
return nameOrTable
}
}
normalizedInput := strings.ToLower(strings.ReplaceAll(nameOrTable, "_", ""))
localKeyLower := strings.ToLower(localKey)
// Find all fields whose related type matches nameOrTable, then pick the one
// whose bun join tag local key matches localKey.
var fallbackField string
for i := 0; i < modelType.NumField(); i++ {
field := modelType.Field(i)
fieldType := field.Type
var targetType reflect.Type
if fieldType.Kind() == reflect.Slice {
targetType = fieldType.Elem()
} else if fieldType.Kind() == reflect.Ptr {
targetType = fieldType.Elem()
}
if targetType != nil && targetType.Kind() == reflect.Ptr {
targetType = targetType.Elem()
}
if targetType == nil || targetType.Kind() != reflect.Struct {
continue
}
normalizedTypeName := strings.ToLower(targetType.Name())
normalizedTypeName = strings.TrimPrefix(normalizedTypeName, "modelcore")
normalizedTypeName = strings.TrimPrefix(normalizedTypeName, "model")
if normalizedTypeName != normalizedInput {
continue
}
// Type name matches; record as fallback.
if fallbackField == "" {
fallbackField = field.Name
}
// Check bun join tag: "join:localKey=foreignKey"
bunTag := field.Tag.Get("bun")
for _, tagPart := range strings.Split(bunTag, ",") {
tagPart = strings.TrimSpace(tagPart)
if !strings.HasPrefix(tagPart, "join:") {
continue
}
joinSpec := strings.TrimPrefix(tagPart, "join:")
// joinSpec can be "col1=col2" or "col1=col2 col3=col4" (multi-col joins)
joinCols := strings.Fields(joinSpec)
if len(joinCols) == 0 {
joinCols = []string{joinSpec}
}
for _, joinCol := range joinCols {
eqIdx := strings.Index(joinCol, "=")
if eqIdx < 0 {
continue
}
joinLocalKey := strings.ToLower(joinCol[:eqIdx])
if joinLocalKey == localKeyLower {
logger.Debug("Resolved '%s' (localKey: %s) -> field '%s'", nameOrTable, localKey, field.Name)
return field.Name
}
}
}
}
if fallbackField != "" {
logger.Debug("No join key match for '%s' (localKey: %s), using first type match: '%s'", nameOrTable, localKey, fallbackField)
return fallbackField
}
return h.resolveRelationName(model, nameOrTable)
}
// addXFilesPreload converts an XFiles relation into a PreloadOption // addXFilesPreload converts an XFiles relation into a PreloadOption
// and recursively processes its children // and recursively processes its children
func (h *Handler) addXFilesPreload(xfile *XFiles, options *ExtendedRequestOptions, basePath string) { func (h *Handler) addXFilesPreload(xfile *XFiles, options *ExtendedRequestOptions, basePath string) {
@@ -1061,15 +1171,42 @@ func (h *Handler) addXFilesPreload(xfile *XFiles, options *ExtendedRequestOption
} }
} }
// Transfer SqlJoins from XFiles to PreloadOption first, so aliases are available for WHERE sanitization
if len(xfile.SqlJoins) > 0 {
preloadOpt.SqlJoins = make([]string, 0, len(xfile.SqlJoins))
preloadOpt.JoinAliases = make([]string, 0, len(xfile.SqlJoins))
for _, joinClause := range xfile.SqlJoins {
// Sanitize the join clause
sanitizedJoin := common.SanitizeWhereClause(joinClause, "", nil)
if sanitizedJoin == "" {
logger.Warn("X-Files: SqlJoin failed sanitization for %s: %s", relationPath, joinClause)
continue
}
preloadOpt.SqlJoins = append(preloadOpt.SqlJoins, sanitizedJoin)
// Extract join alias for validation
alias := extractJoinAlias(sanitizedJoin)
if alias != "" {
preloadOpt.JoinAliases = append(preloadOpt.JoinAliases, alias)
logger.Debug("X-Files: Extracted join alias for %s: %s", relationPath, alias)
}
}
logger.Debug("X-Files: Added %d SQL joins to preload %s", len(preloadOpt.SqlJoins), relationPath)
}
// Add WHERE clause if SQL conditions specified // Add WHERE clause if SQL conditions specified
// SqlJoins must be processed first so join aliases are known and not incorrectly replaced
whereConditions := make([]string, 0) whereConditions := make([]string, 0)
if len(xfile.SqlAnd) > 0 { if len(xfile.SqlAnd) > 0 {
// Process each SQL condition var sqlAndOpts *common.RequestOptions
// Note: We don't add table prefixes here because they're only needed for JOINs if len(preloadOpt.JoinAliases) > 0 {
// The handler will add prefixes later if SqlJoins are present sqlAndOpts = &common.RequestOptions{JoinAliases: preloadOpt.JoinAliases}
}
for _, sqlCond := range xfile.SqlAnd { for _, sqlCond := range xfile.SqlAnd {
// Sanitize the condition without adding prefixes sanitizedCond := common.SanitizeWhereClause(sqlCond, xfile.TableName, sqlAndOpts)
sanitizedCond := common.SanitizeWhereClause(sqlCond, xfile.TableName)
if sanitizedCond != "" { if sanitizedCond != "" {
whereConditions = append(whereConditions, sanitizedCond) whereConditions = append(whereConditions, sanitizedCond)
} }
@@ -1114,32 +1251,6 @@ func (h *Handler) addXFilesPreload(xfile *XFiles, options *ExtendedRequestOption
logger.Debug("X-Files: Set foreign key for %s: %s", relationPath, xfile.ForeignKey) logger.Debug("X-Files: Set foreign key for %s: %s", relationPath, xfile.ForeignKey)
} }
// Transfer SqlJoins from XFiles to PreloadOption
if len(xfile.SqlJoins) > 0 {
preloadOpt.SqlJoins = make([]string, 0, len(xfile.SqlJoins))
preloadOpt.JoinAliases = make([]string, 0, len(xfile.SqlJoins))
for _, joinClause := range xfile.SqlJoins {
// Sanitize the join clause
sanitizedJoin := common.SanitizeWhereClause(joinClause, "", nil)
if sanitizedJoin == "" {
logger.Warn("X-Files: SqlJoin failed sanitization for %s: %s", relationPath, joinClause)
continue
}
preloadOpt.SqlJoins = append(preloadOpt.SqlJoins, sanitizedJoin)
// Extract join alias for validation
alias := extractJoinAlias(sanitizedJoin)
if alias != "" {
preloadOpt.JoinAliases = append(preloadOpt.JoinAliases, alias)
logger.Debug("X-Files: Extracted join alias for %s: %s", relationPath, alias)
}
}
logger.Debug("X-Files: Added %d SQL joins to preload %s", len(preloadOpt.SqlJoins), relationPath)
}
// Check if this table has a recursive child - if so, mark THIS preload as recursive // Check if this table has a recursive child - if so, mark THIS preload as recursive
// and store the recursive child's RelatedKey for recursion generation // and store the recursive child's RelatedKey for recursion generation
hasRecursiveChild := false hasRecursiveChild := false

View File

@@ -125,17 +125,17 @@ func SetupMuxRoutes(muxRouter *mux.Router, handler *Handler, authMiddleware Midd
metadataPath := buildRoutePath(schema, entity) + "/metadata" metadataPath := buildRoutePath(schema, entity) + "/metadata"
// Create handler functions for this specific entity // Create handler functions for this specific entity
entityHandler := createMuxHandler(handler, schema, entity, "") var entityHandler http.Handler = createMuxHandler(handler, schema, entity, "")
entityWithIDHandler := createMuxHandler(handler, schema, entity, "id") var entityWithIDHandler http.Handler = createMuxHandler(handler, schema, entity, "id")
metadataHandler := createMuxGetHandler(handler, schema, entity, "") var metadataHandler http.Handler = createMuxGetHandler(handler, schema, entity, "")
optionsEntityHandler := createMuxOptionsHandler(handler, schema, entity, []string{"GET", "POST", "OPTIONS"}) optionsEntityHandler := createMuxOptionsHandler(handler, schema, entity, []string{"GET", "POST", "OPTIONS"})
optionsEntityWithIDHandler := createMuxOptionsHandler(handler, schema, entity, []string{"GET", "PUT", "PATCH", "DELETE", "POST", "OPTIONS"}) optionsEntityWithIDHandler := createMuxOptionsHandler(handler, schema, entity, []string{"GET", "PUT", "PATCH", "DELETE", "POST", "OPTIONS"})
// Apply authentication middleware if provided // Apply authentication middleware if provided
if authMiddleware != nil { if authMiddleware != nil {
entityHandler = authMiddleware(entityHandler).(http.HandlerFunc) entityHandler = authMiddleware(entityHandler)
entityWithIDHandler = authMiddleware(entityWithIDHandler).(http.HandlerFunc) entityWithIDHandler = authMiddleware(entityWithIDHandler)
metadataHandler = authMiddleware(metadataHandler).(http.HandlerFunc) metadataHandler = authMiddleware(metadataHandler)
// Don't apply auth middleware to OPTIONS - CORS preflight must not require auth // Don't apply auth middleware to OPTIONS - CORS preflight must not require auth
} }
@@ -289,7 +289,11 @@ func wrapBunRouterHandler(handler bunrouter.HandlerFunc, authMiddleware Middlewa
return func(w http.ResponseWriter, req bunrouter.Request) error { return func(w http.ResponseWriter, req bunrouter.Request) error {
// Create an http.Handler that calls the bunrouter handler // Create an http.Handler that calls the bunrouter handler
httpHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { httpHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_ = handler(w, req) // Replace the embedded *http.Request with the middleware-enriched one
// so that auth context (user ID, etc.) is visible to the handler.
enrichedReq := req
enrichedReq.Request = r
_ = handler(w, enrichedReq)
}) })
// Wrap with auth middleware and execute // Wrap with auth middleware and execute

View File

@@ -456,6 +456,125 @@ func GetUserMeta(ctx context.Context) (map[string]any, bool) {
return meta, ok return meta, ok
} }
// SessionCookieOptions configures the session cookie set by SetSessionCookie.
// All fields are optional; sensible secure defaults are applied when omitted.
type SessionCookieOptions struct {
// Name is the cookie name. Defaults to "session_token".
Name string
// Path is the cookie path. Defaults to "/".
Path string
// Domain restricts the cookie to a specific domain. Empty means current host.
Domain string
// Secure sets the Secure flag. Defaults to true.
// Set to false only in local development over HTTP.
Secure *bool
// SameSite sets the SameSite policy. Defaults to http.SameSiteLaxMode.
SameSite http.SameSite
}
func (o SessionCookieOptions) name() string {
if o.Name != "" {
return o.Name
}
return "session_token"
}
func (o SessionCookieOptions) path() string {
if o.Path != "" {
return o.Path
}
return "/"
}
func (o SessionCookieOptions) secure() bool {
if o.Secure != nil {
return *o.Secure
}
return true
}
func (o SessionCookieOptions) sameSite() http.SameSite {
if o.SameSite != 0 {
return o.SameSite
}
return http.SameSiteLaxMode
}
// SetSessionCookie writes the session_token cookie to the response after a successful login.
// Call this immediately after a successful Authenticator.Login() call.
//
// Example:
//
// resp, err := auth.Login(r.Context(), req)
// if err != nil { ... }
// security.SetSessionCookie(w, resp)
// json.NewEncoder(w).Encode(resp)
func SetSessionCookie(w http.ResponseWriter, loginResp *LoginResponse, opts ...SessionCookieOptions) {
var o SessionCookieOptions
if len(opts) > 0 {
o = opts[0]
}
maxAge := 0
if loginResp.ExpiresIn > 0 {
maxAge = int(loginResp.ExpiresIn)
}
http.SetCookie(w, &http.Cookie{
Name: o.name(),
Value: loginResp.Token,
Path: o.path(),
Domain: o.Domain,
MaxAge: maxAge,
HttpOnly: true,
Secure: o.secure(),
SameSite: o.sameSite(),
})
}
// GetSessionCookie returns the session token value from the request cookie, or empty string if not present.
//
// Example:
//
// token := security.GetSessionCookie(r)
func GetSessionCookie(r *http.Request, opts ...SessionCookieOptions) string {
var o SessionCookieOptions
if len(opts) > 0 {
o = opts[0]
}
cookie, err := r.Cookie(o.name())
if err != nil {
return ""
}
return cookie.Value
}
// ClearSessionCookie expires the session_token cookie, effectively logging the user out on the browser side.
// Call this after a successful Authenticator.Logout() call.
//
// Example:
//
// err := auth.Logout(r.Context(), req)
// if err != nil { ... }
// security.ClearSessionCookie(w)
func ClearSessionCookie(w http.ResponseWriter, opts ...SessionCookieOptions) {
var o SessionCookieOptions
if len(opts) > 0 {
o = opts[0]
}
http.SetCookie(w, &http.Cookie{
Name: o.name(),
Value: "",
Path: o.path(),
Domain: o.Domain,
MaxAge: -1,
HttpOnly: true,
Secure: o.secure(),
SameSite: o.sameSite(),
})
}
// GetModelRulesFromContext extracts ModelRules stored by NewModelAuthMiddleware // GetModelRulesFromContext extracts ModelRules stored by NewModelAuthMiddleware
func GetModelRulesFromContext(ctx context.Context) (modelregistry.ModelRules, bool) { func GetModelRulesFromContext(ctx context.Context) (modelregistry.ModelRules, bool) {
rules, ok := ctx.Value(ModelRulesKey).(modelregistry.ModelRules) rules, ok := ctx.Value(ModelRulesKey).(modelregistry.ModelRules)

View File

@@ -222,9 +222,8 @@ func (a *DatabaseAuthenticator) Authenticate(r *http.Request) (*UserContext, err
if sessionToken == "" { if sessionToken == "" {
// Try cookie // Try cookie
cookie, err := r.Cookie("session_token") if token := GetSessionCookie(r); token != "" {
if err == nil { tokens = []string{token}
tokens = []string{cookie.Value}
reference = "cookie" reference = "cookie"
} }
} else { } else {

View File

@@ -98,6 +98,7 @@ func (p *EmbedFSProvider) Open(name string) (fs.File, error) {
// Apply prefix stripping by prepending the prefix to the requested path // Apply prefix stripping by prepending the prefix to the requested path
actualPath := name actualPath := name
alternatePath := ""
if p.stripPrefix != "" { if p.stripPrefix != "" {
// Clean the paths to handle leading/trailing slashes // Clean the paths to handle leading/trailing slashes
prefix := strings.Trim(p.stripPrefix, "/") prefix := strings.Trim(p.stripPrefix, "/")
@@ -105,12 +106,25 @@ func (p *EmbedFSProvider) Open(name string) (fs.File, error) {
if prefix != "" { if prefix != "" {
actualPath = path.Join(prefix, cleanName) actualPath = path.Join(prefix, cleanName)
alternatePath = cleanName
} else { } else {
actualPath = cleanName actualPath = cleanName
} }
} }
// First try the actual path with prefix
if file, err := p.fs.Open(actualPath); err == nil {
return file, nil
}
return p.fs.Open(actualPath) // If alternate path is different, try it as well
if alternatePath != "" && alternatePath != actualPath {
if file, err := p.fs.Open(alternatePath); err == nil {
return file, nil
}
}
// If both attempts fail, return the error from the first attempt
return nil, fmt.Errorf("file not found: %s", name)
} }
// Close releases any resources held by the provider. // Close releases any resources held by the provider.

View File

@@ -53,6 +53,7 @@ func (p *LocalFSProvider) Open(name string) (fs.File, error) {
// Apply prefix stripping by prepending the prefix to the requested path // Apply prefix stripping by prepending the prefix to the requested path
actualPath := name actualPath := name
alternatePath := ""
if p.stripPrefix != "" { if p.stripPrefix != "" {
// Clean the paths to handle leading/trailing slashes // Clean the paths to handle leading/trailing slashes
prefix := strings.Trim(p.stripPrefix, "/") prefix := strings.Trim(p.stripPrefix, "/")
@@ -60,12 +61,26 @@ func (p *LocalFSProvider) Open(name string) (fs.File, error) {
if prefix != "" { if prefix != "" {
actualPath = path.Join(prefix, cleanName) actualPath = path.Join(prefix, cleanName)
alternatePath = cleanName
} else { } else {
actualPath = cleanName actualPath = cleanName
} }
} }
return p.fs.Open(actualPath) // First try the actual path with prefix
if file, err := p.fs.Open(actualPath); err == nil {
return file, nil
}
// If alternate path is different, try it as well
if alternatePath != "" && alternatePath != actualPath {
if file, err := p.fs.Open(alternatePath); err == nil {
return file, nil
}
}
// If both attempts fail, return the error from the first attempt
return nil, fmt.Errorf("file not found: %s", name)
} }
// Close releases any resources held by the provider. // Close releases any resources held by the provider.

View File

@@ -56,6 +56,7 @@ func (p *ZipFSProvider) Open(name string) (fs.File, error) {
// Apply prefix stripping by prepending the prefix to the requested path // Apply prefix stripping by prepending the prefix to the requested path
actualPath := name actualPath := name
alternatePath := ""
if p.stripPrefix != "" { if p.stripPrefix != "" {
// Clean the paths to handle leading/trailing slashes // Clean the paths to handle leading/trailing slashes
prefix := strings.Trim(p.stripPrefix, "/") prefix := strings.Trim(p.stripPrefix, "/")
@@ -63,12 +64,26 @@ func (p *ZipFSProvider) Open(name string) (fs.File, error) {
if prefix != "" { if prefix != "" {
actualPath = path.Join(prefix, cleanName) actualPath = path.Join(prefix, cleanName)
alternatePath = cleanName
} else { } else {
actualPath = cleanName actualPath = cleanName
} }
} }
return p.zipFS.Open(actualPath) // First try the actual path with prefix
if file, err := p.zipFS.Open(actualPath); err == nil {
return file, nil
}
// If alternate path is different, try it as well
if alternatePath != "" && alternatePath != actualPath {
if file, err := p.zipFS.Open(alternatePath); err == nil {
return file, nil
}
}
// If both attempts fail, return the error from the first attempt
return nil, fmt.Errorf("file not found: %s", name)
} }
// Close releases resources held by the zip reader. // Close releases resources held by the zip reader.