Compare commits

...

3 Commits

Author SHA1 Message Date
Hein
7a498edab7 fix(headers): enhance relation name resolution logic
* Allow resolution for both regular headers and X-Files.
* Introduce join-key-aware resolution for disambiguation.
* Add new function to handle multiple fields pointing to the same type.
2026-03-25 12:09:03 +02:00
Hein
f10bb0827e fix(sql_helpers): ensure case-insensitive matching for allowed prefixes 2026-03-25 10:57:42 +02:00
Hein
22a4ab345a feat(security): add session cookie management functions
* Introduce SessionCookieOptions for configurable session cookies
* Implement SetSessionCookie, GetSessionCookie, and ClearSessionCookie functions
* Enhance cookie handling in DatabaseAuthenticator
2026-03-24 17:11:53 +02:00
4 changed files with 242 additions and 13 deletions

View File

@@ -168,16 +168,17 @@ func SanitizeWhereClause(where string, tableName string, options ...*RequestOpti
} }
// Build a set of allowed table prefixes (main table + preloaded relations) // Build a set of allowed table prefixes (main table + preloaded relations)
// Keys are stored lowercase for case-insensitive matching
allowedPrefixes := make(map[string]bool) allowedPrefixes := make(map[string]bool)
if tableName != "" { if tableName != "" {
allowedPrefixes[tableName] = true allowedPrefixes[strings.ToLower(tableName)] = true
} }
// Add preload relation names as allowed prefixes // Add preload relation names as allowed prefixes
if len(options) > 0 && options[0] != nil { if len(options) > 0 && options[0] != nil {
for pi := range options[0].Preload { for pi := range options[0].Preload {
if options[0].Preload[pi].Relation != "" { if options[0].Preload[pi].Relation != "" {
allowedPrefixes[options[0].Preload[pi].Relation] = true allowedPrefixes[strings.ToLower(options[0].Preload[pi].Relation)] = true
logger.Debug("Added preload relation '%s' as allowed table prefix", options[0].Preload[pi].Relation) logger.Debug("Added preload relation '%s' as allowed table prefix", options[0].Preload[pi].Relation)
} }
} }
@@ -185,7 +186,7 @@ func SanitizeWhereClause(where string, tableName string, options ...*RequestOpti
// Add join aliases as allowed prefixes // Add join aliases as allowed prefixes
for _, alias := range options[0].JoinAliases { for _, alias := range options[0].JoinAliases {
if alias != "" { if alias != "" {
allowedPrefixes[alias] = true allowedPrefixes[strings.ToLower(alias)] = true
logger.Debug("Added join alias '%s' as allowed table prefix", alias) logger.Debug("Added join alias '%s' as allowed table prefix", alias)
} }
} }
@@ -217,8 +218,8 @@ func SanitizeWhereClause(where string, tableName string, options ...*RequestOpti
currentPrefix, columnName := extractTableAndColumn(condToCheck) currentPrefix, columnName := extractTableAndColumn(condToCheck)
if currentPrefix != "" && columnName != "" { if currentPrefix != "" && columnName != "" {
// Check if the prefix is allowed (main table or preload relation) // Check if the prefix is allowed (main table or preload relation) - case-insensitive
if !allowedPrefixes[currentPrefix] { if !allowedPrefixes[strings.ToLower(currentPrefix)] {
// Prefix is not in the allowed list - only fix if it's a valid column in the main table // Prefix is not in the allowed list - only fix if it's a valid column in the main table
if validColumns == nil || isValidColumn(columnName, validColumns) { if validColumns == nil || isValidColumn(columnName, validColumns) {
// Replace the incorrect prefix with the correct main table name // Replace the incorrect prefix with the correct main table name

View File

@@ -274,9 +274,11 @@ func (h *Handler) parseOptionsFromHeaders(r common.Request, model interface{}) E
} }
} }
// Resolve relation names (convert table names to field names) if model is provided // Resolve relation names (convert table names/prefixes to actual model field names) if model is provided.
// Skip resolution if X-Files header was provided, as XFiles uses Prefix which already contains the correct field names // This runs for both regular headers and X-Files, because XFile prefixes don't always match model
if model != nil && !options.XFilesPresent { // field names (e.g., prefix "HUB" vs field "HUB_RID_HUB"). RelatedKey/ForeignKey are used to
// disambiguate when multiple fields point to the same related type.
if model != nil {
h.resolveRelationNamesInOptions(&options, model) h.resolveRelationNamesInOptions(&options, model)
} }
@@ -863,8 +865,21 @@ func (h *Handler) resolveRelationNamesInOptions(options *ExtendedRequestOptions,
// Resolve each part of the path // Resolve each part of the path
currentModel := model currentModel := model
for _, part := range parts { for partIdx, part := range parts {
resolvedPart := h.resolveRelationName(currentModel, part) isLast := partIdx == len(parts)-1
var resolvedPart string
if isLast {
// For the final part, use join-key-aware resolution to disambiguate when
// multiple fields point to the same type (e.g., HUB_RID_HUB vs HUB_RID_ASSIGNEDTO).
// RelatedKey = parent's local column linking to child; ForeignKey = local column linking to parent.
localKey := preload.RelatedKey
if localKey == "" {
localKey = preload.ForeignKey
}
resolvedPart = h.resolveRelationNameWithJoinKey(currentModel, part, localKey)
} else {
resolvedPart = h.resolveRelationName(currentModel, part)
}
resolvedParts = append(resolvedParts, resolvedPart) resolvedParts = append(resolvedParts, resolvedPart)
// Try to get the model type for the next level // Try to get the model type for the next level
@@ -980,6 +995,101 @@ func (h *Handler) resolveRelationName(model interface{}, nameOrTable string) str
return nameOrTable return nameOrTable
} }
// resolveRelationNameWithJoinKey resolves a relation name like resolveRelationName, but when
// multiple fields point to the same related type, uses localKey to pick the one whose bun join
// tag starts with "join:localKey=". Falls back to resolveRelationName if no key match is found.
func (h *Handler) resolveRelationNameWithJoinKey(model interface{}, nameOrTable string, localKey string) string {
if localKey == "" {
return h.resolveRelationName(model, nameOrTable)
}
modelType := reflect.TypeOf(model)
if modelType == nil {
return nameOrTable
}
if modelType.Kind() == reflect.Ptr {
modelType = modelType.Elem()
}
if modelType == nil || modelType.Kind() != reflect.Struct {
return nameOrTable
}
// If it's already a direct field name, return as-is (no ambiguity).
for i := 0; i < modelType.NumField(); i++ {
if modelType.Field(i).Name == nameOrTable {
return nameOrTable
}
}
normalizedInput := strings.ToLower(strings.ReplaceAll(nameOrTable, "_", ""))
localKeyLower := strings.ToLower(localKey)
// Find all fields whose related type matches nameOrTable, then pick the one
// whose bun join tag local key matches localKey.
var fallbackField string
for i := 0; i < modelType.NumField(); i++ {
field := modelType.Field(i)
fieldType := field.Type
var targetType reflect.Type
if fieldType.Kind() == reflect.Slice {
targetType = fieldType.Elem()
} else if fieldType.Kind() == reflect.Ptr {
targetType = fieldType.Elem()
}
if targetType != nil && targetType.Kind() == reflect.Ptr {
targetType = targetType.Elem()
}
if targetType == nil || targetType.Kind() != reflect.Struct {
continue
}
normalizedTypeName := strings.ToLower(targetType.Name())
normalizedTypeName = strings.TrimPrefix(normalizedTypeName, "modelcore")
normalizedTypeName = strings.TrimPrefix(normalizedTypeName, "model")
if normalizedTypeName != normalizedInput {
continue
}
// Type name matches; record as fallback.
if fallbackField == "" {
fallbackField = field.Name
}
// Check bun join tag: "join:localKey=foreignKey"
bunTag := field.Tag.Get("bun")
for _, tagPart := range strings.Split(bunTag, ",") {
tagPart = strings.TrimSpace(tagPart)
if !strings.HasPrefix(tagPart, "join:") {
continue
}
joinSpec := strings.TrimPrefix(tagPart, "join:")
// joinSpec can be "col1=col2" or "col1=col2 col3=col4" (multi-col joins)
joinCols := strings.Fields(joinSpec)
if len(joinCols) == 0 {
joinCols = []string{joinSpec}
}
for _, joinCol := range joinCols {
eqIdx := strings.Index(joinCol, "=")
if eqIdx < 0 {
continue
}
joinLocalKey := strings.ToLower(joinCol[:eqIdx])
if joinLocalKey == localKeyLower {
logger.Debug("Resolved '%s' (localKey: %s) -> field '%s'", nameOrTable, localKey, field.Name)
return field.Name
}
}
}
}
if fallbackField != "" {
logger.Debug("No join key match for '%s' (localKey: %s), using first type match: '%s'", nameOrTable, localKey, fallbackField)
return fallbackField
}
return h.resolveRelationName(model, nameOrTable)
}
// addXFilesPreload converts an XFiles relation into a PreloadOption // addXFilesPreload converts an XFiles relation into a PreloadOption
// and recursively processes its children // and recursively processes its children
func (h *Handler) addXFilesPreload(xfile *XFiles, options *ExtendedRequestOptions, basePath string) { func (h *Handler) addXFilesPreload(xfile *XFiles, options *ExtendedRequestOptions, basePath string) {

View File

@@ -456,6 +456,125 @@ func GetUserMeta(ctx context.Context) (map[string]any, bool) {
return meta, ok return meta, ok
} }
// SessionCookieOptions configures the session cookie set by SetSessionCookie.
// All fields are optional; sensible secure defaults are applied when omitted.
type SessionCookieOptions struct {
// Name is the cookie name. Defaults to "session_token".
Name string
// Path is the cookie path. Defaults to "/".
Path string
// Domain restricts the cookie to a specific domain. Empty means current host.
Domain string
// Secure sets the Secure flag. Defaults to true.
// Set to false only in local development over HTTP.
Secure *bool
// SameSite sets the SameSite policy. Defaults to http.SameSiteLaxMode.
SameSite http.SameSite
}
func (o SessionCookieOptions) name() string {
if o.Name != "" {
return o.Name
}
return "session_token"
}
func (o SessionCookieOptions) path() string {
if o.Path != "" {
return o.Path
}
return "/"
}
func (o SessionCookieOptions) secure() bool {
if o.Secure != nil {
return *o.Secure
}
return true
}
func (o SessionCookieOptions) sameSite() http.SameSite {
if o.SameSite != 0 {
return o.SameSite
}
return http.SameSiteLaxMode
}
// SetSessionCookie writes the session_token cookie to the response after a successful login.
// Call this immediately after a successful Authenticator.Login() call.
//
// Example:
//
// resp, err := auth.Login(r.Context(), req)
// if err != nil { ... }
// security.SetSessionCookie(w, resp)
// json.NewEncoder(w).Encode(resp)
func SetSessionCookie(w http.ResponseWriter, loginResp *LoginResponse, opts ...SessionCookieOptions) {
var o SessionCookieOptions
if len(opts) > 0 {
o = opts[0]
}
maxAge := 0
if loginResp.ExpiresIn > 0 {
maxAge = int(loginResp.ExpiresIn)
}
http.SetCookie(w, &http.Cookie{
Name: o.name(),
Value: loginResp.Token,
Path: o.path(),
Domain: o.Domain,
MaxAge: maxAge,
HttpOnly: true,
Secure: o.secure(),
SameSite: o.sameSite(),
})
}
// GetSessionCookie returns the session token value from the request cookie, or empty string if not present.
//
// Example:
//
// token := security.GetSessionCookie(r)
func GetSessionCookie(r *http.Request, opts ...SessionCookieOptions) string {
var o SessionCookieOptions
if len(opts) > 0 {
o = opts[0]
}
cookie, err := r.Cookie(o.name())
if err != nil {
return ""
}
return cookie.Value
}
// ClearSessionCookie expires the session_token cookie, effectively logging the user out on the browser side.
// Call this after a successful Authenticator.Logout() call.
//
// Example:
//
// err := auth.Logout(r.Context(), req)
// if err != nil { ... }
// security.ClearSessionCookie(w)
func ClearSessionCookie(w http.ResponseWriter, opts ...SessionCookieOptions) {
var o SessionCookieOptions
if len(opts) > 0 {
o = opts[0]
}
http.SetCookie(w, &http.Cookie{
Name: o.name(),
Value: "",
Path: o.path(),
Domain: o.Domain,
MaxAge: -1,
HttpOnly: true,
Secure: o.secure(),
SameSite: o.sameSite(),
})
}
// GetModelRulesFromContext extracts ModelRules stored by NewModelAuthMiddleware // GetModelRulesFromContext extracts ModelRules stored by NewModelAuthMiddleware
func GetModelRulesFromContext(ctx context.Context) (modelregistry.ModelRules, bool) { func GetModelRulesFromContext(ctx context.Context) (modelregistry.ModelRules, bool) {
rules, ok := ctx.Value(ModelRulesKey).(modelregistry.ModelRules) rules, ok := ctx.Value(ModelRulesKey).(modelregistry.ModelRules)

View File

@@ -222,9 +222,8 @@ func (a *DatabaseAuthenticator) Authenticate(r *http.Request) (*UserContext, err
if sessionToken == "" { if sessionToken == "" {
// Try cookie // Try cookie
cookie, err := r.Cookie("session_token") if token := GetSessionCookie(r); token != "" {
if err == nil { tokens = []string{token}
tokens = []string{cookie.Value}
reference = "cookie" reference = "cookie"
} }
} else { } else {