Compare commits

...

7 Commits

Author SHA1 Message Date
Hein
22a4ab345a feat(security): add session cookie management functions
* Introduce SessionCookieOptions for configurable session cookies
* Implement SetSessionCookie, GetSessionCookie, and ClearSessionCookie functions
* Enhance cookie handling in DatabaseAuthenticator
2026-03-24 17:11:53 +02:00
Hein
e289c2ed8f fix(handler): restore JoinAliases for proper WHERE sanitization 2026-03-24 12:00:02 +02:00
Hein
0d50bcfee6 fix(provider): enhance file opening logic with alternate path. Handling broken cases to be compatible with Bitech clients
* Implemented alternate path handling for file retrieval
* Improved error messaging for file not found scenarios
2026-03-24 09:02:17 +02:00
4df626ea71 chore(license): update project notice and clarify licensing terms 2026-03-23 20:32:09 +02:00
Hein
7dd630dec2 fix(handler): set default sort to primary key if none provided
Some checks failed
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Successful in -26m15s
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Successful in -26m11s
Build , Vet Test, and Lint / Lint Code (push) Failing after -30m52s
Build , Vet Test, and Lint / Build (push) Successful in -30m44s
Tests / Integration Tests (push) Failing after -31m5s
Tests / Unit Tests (push) Successful in -29m6s
2026-03-11 14:37:04 +02:00
Hein
613bf22cbd fix(cursor): use full schema-qualified table name in filters 2026-03-11 14:25:44 +02:00
d1ae4fe64e refactor(handler): unify filter operator handling for consistency
Some checks failed
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Successful in -30m26s
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Successful in -29m58s
Build , Vet Test, and Lint / Lint Code (push) Successful in -29m48s
Build , Vet Test, and Lint / Build (push) Successful in -30m4s
Tests / Integration Tests (push) Failing after -30m39s
Tests / Unit Tests (push) Successful in -30m29s
2026-03-01 13:21:38 +02:00
13 changed files with 263 additions and 70 deletions

15
LICENSE
View File

@@ -1,3 +1,18 @@
Project Notice
This project was independently developed.
The contents of this repository were prepared and published outside any time
allocated to Bitech Systems CC and do not contain, incorporate, disclose,
or rely upon any proprietary or confidential information, trade secrets,
protected designs, or other intellectual property of Bitech Systems CC.
No portion of this repository reproduces any Bitech Systems CC-specific
implementation, design asset, confidential workflow, or non-public technical material.
This notice is provided for clarification only and does not modify the terms of
the Apache License, Version 2.0.
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

View File

@@ -32,7 +32,8 @@ func GetCursorFilter(
modelColumns []string,
options common.RequestOptions,
) (string, error) {
// Remove schema prefix if present
// Separate schema prefix from bare table name
fullTableName := tableName
if strings.Contains(tableName, ".") {
tableName = strings.SplitN(tableName, ".", 2)[1]
}
@@ -115,7 +116,7 @@ func GetCursorFilter(
WHERE cursor_select.%s = %s
AND (%s)
)`,
tableName,
fullTableName,
pkName,
cursorID,
orSQL,

View File

@@ -175,9 +175,9 @@ func TestGetCursorFilter_WithSchemaPrefix(t *testing.T) {
t.Fatalf("GetCursorFilter failed: %v", err)
}
// Should handle schema prefix properly
if !strings.Contains(filter, "users") {
t.Errorf("Filter should reference table name users, got: %s", filter)
// Should include full schema-qualified name in FROM clause
if !strings.Contains(filter, "public.users") {
t.Errorf("Filter FROM clause should use schema-qualified name public.users, got: %s", filter)
}
t.Logf("Generated cursor filter with schema: %s", filter)

View File

@@ -329,6 +329,11 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
// Extract model columns for validation
modelColumns := reflection.GetModelColumns(model)
// Default sort to primary key when none provided
if len(options.Sort) == 0 {
options.Sort = []common.SortOption{{Column: pkName, Direction: "ASC"}}
}
// Get cursor filter SQL
cursorFilter, err := GetCursorFilter(tableName, pkName, modelColumns, options)
if err != nil {
@@ -1521,22 +1526,22 @@ func (h *Handler) buildFilterCondition(filter common.FilterOption) (conditionStr
var args []interface{}
switch filter.Operator {
case "eq":
case "eq", "=":
condition = fmt.Sprintf("%s = ?", filter.Column)
args = []interface{}{filter.Value}
case "neq":
case "neq", "!=", "<>":
condition = fmt.Sprintf("%s != ?", filter.Column)
args = []interface{}{filter.Value}
case "gt":
case "gt", ">":
condition = fmt.Sprintf("%s > ?", filter.Column)
args = []interface{}{filter.Value}
case "gte":
case "gte", ">=":
condition = fmt.Sprintf("%s >= ?", filter.Column)
args = []interface{}{filter.Value}
case "lt":
case "lt", "<":
condition = fmt.Sprintf("%s < ?", filter.Column)
args = []interface{}{filter.Value}
case "lte":
case "lte", "<=":
condition = fmt.Sprintf("%s <= ?", filter.Column)
args = []interface{}{filter.Value}
case "like":
@@ -1565,22 +1570,22 @@ func (h *Handler) applyFilter(query common.SelectQuery, filter common.FilterOpti
var args []interface{}
switch filter.Operator {
case "eq":
case "eq", "=":
condition = fmt.Sprintf("%s = ?", filter.Column)
args = []interface{}{filter.Value}
case "neq":
case "neq", "!=", "<>":
condition = fmt.Sprintf("%s != ?", filter.Column)
args = []interface{}{filter.Value}
case "gt":
case "gt", ">":
condition = fmt.Sprintf("%s > ?", filter.Column)
args = []interface{}{filter.Value}
case "gte":
case "gte", ">=":
condition = fmt.Sprintf("%s >= ?", filter.Column)
args = []interface{}{filter.Value}
case "lt":
case "lt", "<":
condition = fmt.Sprintf("%s < ?", filter.Column)
args = []interface{}{filter.Value}
case "lte":
case "lte", "<=":
condition = fmt.Sprintf("%s <= ?", filter.Column)
args = []interface{}{filter.Value}
case "like":

View File

@@ -32,6 +32,8 @@ func (opts *ExtendedRequestOptions) GetCursorFilter(
modelColumns []string, // optional: for validation
expandJoins map[string]string, // optional: alias → JOIN SQL
) (string, error) {
// Separate schema prefix from bare table name
fullTableName := tableName
if strings.Contains(tableName, ".") {
tableName = strings.SplitN(tableName, ".", 2)[1]
}
@@ -127,7 +129,7 @@ func (opts *ExtendedRequestOptions) GetCursorFilter(
WHERE cursor_select.%s = %s
AND (%s)
)`,
tableName,
fullTableName,
joinSQL,
pkName,
cursorID,

View File

@@ -187,9 +187,9 @@ func TestGetCursorFilter_WithSchemaPrefix(t *testing.T) {
t.Fatalf("GetCursorFilter failed: %v", err)
}
// Should handle schema prefix properly
if !strings.Contains(filter, "users") {
t.Errorf("Filter should reference table name users, got: %s", filter)
// Should include full schema-qualified name in FROM clause
if !strings.Contains(filter, "public.users") {
t.Errorf("Filter FROM clause should use schema-qualified name public.users, got: %s", filter)
}
t.Logf("Generated cursor filter with schema: %s", filter)

View File

@@ -731,6 +731,11 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
// For now, pass empty map as joins are handled via Preload
}
// Default sort to primary key when none provided
if len(options.Sort) == 0 {
options.Sort = []common.SortOption{{Column: pkName, Direction: "ASC"}}
}
// Get cursor filter SQL
cursorFilter, err := options.GetCursorFilter(tableName, pkName, modelColumns, expandJoins)
if err != nil {
@@ -2226,17 +2231,17 @@ func (h *Handler) applyOrFilterGroup(query common.SelectQuery, filters []*common
// buildFilterCondition builds a single filter condition and returns the condition string and args
func (h *Handler) buildFilterCondition(qualifiedColumn string, filter *common.FilterOption, tableName string) (filterStr string, filterInterface []interface{}) {
switch strings.ToLower(filter.Operator) {
case "eq", "equals":
case "eq", "equals", "=":
return fmt.Sprintf("%s = ?", qualifiedColumn), []interface{}{filter.Value}
case "neq", "not_equals", "ne":
case "neq", "not_equals", "ne", "!=", "<>":
return fmt.Sprintf("%s != ?", qualifiedColumn), []interface{}{filter.Value}
case "gt", "greater_than":
case "gt", "greater_than", ">":
return fmt.Sprintf("%s > ?", qualifiedColumn), []interface{}{filter.Value}
case "gte", "greater_than_equals", "ge":
case "gte", "greater_than_equals", "ge", ">=":
return fmt.Sprintf("%s >= ?", qualifiedColumn), []interface{}{filter.Value}
case "lt", "less_than":
case "lt", "less_than", "<":
return fmt.Sprintf("%s < ?", qualifiedColumn), []interface{}{filter.Value}
case "lte", "less_than_equals", "le":
case "lte", "less_than_equals", "le", "<=":
return fmt.Sprintf("%s <= ?", qualifiedColumn), []interface{}{filter.Value}
case "like":
return fmt.Sprintf("%s LIKE ?", qualifiedColumn), []interface{}{filter.Value}
@@ -2879,6 +2884,8 @@ func (h *Handler) filterExtendedOptions(validator *common.ColumnValidator, optio
// Filter base RequestOptions
filtered.RequestOptions = validator.FilterRequestOptions(options.RequestOptions)
// Restore JoinAliases cleared by FilterRequestOptions — still needed for SanitizeWhereClause
filtered.RequestOptions.JoinAliases = options.JoinAliases
// Filter SearchColumns
filtered.SearchColumns = validator.FilterValidColumns(options.SearchColumns)

View File

@@ -1061,15 +1061,42 @@ func (h *Handler) addXFilesPreload(xfile *XFiles, options *ExtendedRequestOption
}
}
// Transfer SqlJoins from XFiles to PreloadOption first, so aliases are available for WHERE sanitization
if len(xfile.SqlJoins) > 0 {
preloadOpt.SqlJoins = make([]string, 0, len(xfile.SqlJoins))
preloadOpt.JoinAliases = make([]string, 0, len(xfile.SqlJoins))
for _, joinClause := range xfile.SqlJoins {
// Sanitize the join clause
sanitizedJoin := common.SanitizeWhereClause(joinClause, "", nil)
if sanitizedJoin == "" {
logger.Warn("X-Files: SqlJoin failed sanitization for %s: %s", relationPath, joinClause)
continue
}
preloadOpt.SqlJoins = append(preloadOpt.SqlJoins, sanitizedJoin)
// Extract join alias for validation
alias := extractJoinAlias(sanitizedJoin)
if alias != "" {
preloadOpt.JoinAliases = append(preloadOpt.JoinAliases, alias)
logger.Debug("X-Files: Extracted join alias for %s: %s", relationPath, alias)
}
}
logger.Debug("X-Files: Added %d SQL joins to preload %s", len(preloadOpt.SqlJoins), relationPath)
}
// Add WHERE clause if SQL conditions specified
// SqlJoins must be processed first so join aliases are known and not incorrectly replaced
whereConditions := make([]string, 0)
if len(xfile.SqlAnd) > 0 {
// Process each SQL condition
// Note: We don't add table prefixes here because they're only needed for JOINs
// The handler will add prefixes later if SqlJoins are present
var sqlAndOpts *common.RequestOptions
if len(preloadOpt.JoinAliases) > 0 {
sqlAndOpts = &common.RequestOptions{JoinAliases: preloadOpt.JoinAliases}
}
for _, sqlCond := range xfile.SqlAnd {
// Sanitize the condition without adding prefixes
sanitizedCond := common.SanitizeWhereClause(sqlCond, xfile.TableName)
sanitizedCond := common.SanitizeWhereClause(sqlCond, xfile.TableName, sqlAndOpts)
if sanitizedCond != "" {
whereConditions = append(whereConditions, sanitizedCond)
}
@@ -1114,32 +1141,6 @@ func (h *Handler) addXFilesPreload(xfile *XFiles, options *ExtendedRequestOption
logger.Debug("X-Files: Set foreign key for %s: %s", relationPath, xfile.ForeignKey)
}
// Transfer SqlJoins from XFiles to PreloadOption
if len(xfile.SqlJoins) > 0 {
preloadOpt.SqlJoins = make([]string, 0, len(xfile.SqlJoins))
preloadOpt.JoinAliases = make([]string, 0, len(xfile.SqlJoins))
for _, joinClause := range xfile.SqlJoins {
// Sanitize the join clause
sanitizedJoin := common.SanitizeWhereClause(joinClause, "", nil)
if sanitizedJoin == "" {
logger.Warn("X-Files: SqlJoin failed sanitization for %s: %s", relationPath, joinClause)
continue
}
preloadOpt.SqlJoins = append(preloadOpt.SqlJoins, sanitizedJoin)
// Extract join alias for validation
alias := extractJoinAlias(sanitizedJoin)
if alias != "" {
preloadOpt.JoinAliases = append(preloadOpt.JoinAliases, alias)
logger.Debug("X-Files: Extracted join alias for %s: %s", relationPath, alias)
}
}
logger.Debug("X-Files: Added %d SQL joins to preload %s", len(preloadOpt.SqlJoins), relationPath)
}
// Check if this table has a recursive child - if so, mark THIS preload as recursive
// and store the recursive child's RelatedKey for recursion generation
hasRecursiveChild := false

View File

@@ -456,6 +456,125 @@ func GetUserMeta(ctx context.Context) (map[string]any, bool) {
return meta, ok
}
// SessionCookieOptions configures the session cookie set by SetSessionCookie.
// All fields are optional; sensible secure defaults are applied when omitted.
type SessionCookieOptions struct {
// Name is the cookie name. Defaults to "session_token".
Name string
// Path is the cookie path. Defaults to "/".
Path string
// Domain restricts the cookie to a specific domain. Empty means current host.
Domain string
// Secure sets the Secure flag. Defaults to true.
// Set to false only in local development over HTTP.
Secure *bool
// SameSite sets the SameSite policy. Defaults to http.SameSiteLaxMode.
SameSite http.SameSite
}
func (o SessionCookieOptions) name() string {
if o.Name != "" {
return o.Name
}
return "session_token"
}
func (o SessionCookieOptions) path() string {
if o.Path != "" {
return o.Path
}
return "/"
}
func (o SessionCookieOptions) secure() bool {
if o.Secure != nil {
return *o.Secure
}
return true
}
func (o SessionCookieOptions) sameSite() http.SameSite {
if o.SameSite != 0 {
return o.SameSite
}
return http.SameSiteLaxMode
}
// SetSessionCookie writes the session_token cookie to the response after a successful login.
// Call this immediately after a successful Authenticator.Login() call.
//
// Example:
//
// resp, err := auth.Login(r.Context(), req)
// if err != nil { ... }
// security.SetSessionCookie(w, resp)
// json.NewEncoder(w).Encode(resp)
func SetSessionCookie(w http.ResponseWriter, loginResp *LoginResponse, opts ...SessionCookieOptions) {
var o SessionCookieOptions
if len(opts) > 0 {
o = opts[0]
}
maxAge := 0
if loginResp.ExpiresIn > 0 {
maxAge = int(loginResp.ExpiresIn)
}
http.SetCookie(w, &http.Cookie{
Name: o.name(),
Value: loginResp.Token,
Path: o.path(),
Domain: o.Domain,
MaxAge: maxAge,
HttpOnly: true,
Secure: o.secure(),
SameSite: o.sameSite(),
})
}
// GetSessionCookie returns the session token value from the request cookie, or empty string if not present.
//
// Example:
//
// token := security.GetSessionCookie(r)
func GetSessionCookie(r *http.Request, opts ...SessionCookieOptions) string {
var o SessionCookieOptions
if len(opts) > 0 {
o = opts[0]
}
cookie, err := r.Cookie(o.name())
if err != nil {
return ""
}
return cookie.Value
}
// ClearSessionCookie expires the session_token cookie, effectively logging the user out on the browser side.
// Call this after a successful Authenticator.Logout() call.
//
// Example:
//
// err := auth.Logout(r.Context(), req)
// if err != nil { ... }
// security.ClearSessionCookie(w)
func ClearSessionCookie(w http.ResponseWriter, opts ...SessionCookieOptions) {
var o SessionCookieOptions
if len(opts) > 0 {
o = opts[0]
}
http.SetCookie(w, &http.Cookie{
Name: o.name(),
Value: "",
Path: o.path(),
Domain: o.Domain,
MaxAge: -1,
HttpOnly: true,
Secure: o.secure(),
SameSite: o.sameSite(),
})
}
// GetModelRulesFromContext extracts ModelRules stored by NewModelAuthMiddleware
func GetModelRulesFromContext(ctx context.Context) (modelregistry.ModelRules, bool) {
rules, ok := ctx.Value(ModelRulesKey).(modelregistry.ModelRules)

View File

@@ -222,9 +222,8 @@ func (a *DatabaseAuthenticator) Authenticate(r *http.Request) (*UserContext, err
if sessionToken == "" {
// Try cookie
cookie, err := r.Cookie("session_token")
if err == nil {
tokens = []string{cookie.Value}
if token := GetSessionCookie(r); token != "" {
tokens = []string{token}
reference = "cookie"
}
} else {

View File

@@ -98,6 +98,7 @@ func (p *EmbedFSProvider) Open(name string) (fs.File, error) {
// Apply prefix stripping by prepending the prefix to the requested path
actualPath := name
alternatePath := ""
if p.stripPrefix != "" {
// Clean the paths to handle leading/trailing slashes
prefix := strings.Trim(p.stripPrefix, "/")
@@ -105,12 +106,25 @@ func (p *EmbedFSProvider) Open(name string) (fs.File, error) {
if prefix != "" {
actualPath = path.Join(prefix, cleanName)
alternatePath = cleanName
} else {
actualPath = cleanName
}
}
// First try the actual path with prefix
if file, err := p.fs.Open(actualPath); err == nil {
return file, nil
}
return p.fs.Open(actualPath)
// If alternate path is different, try it as well
if alternatePath != "" && alternatePath != actualPath {
if file, err := p.fs.Open(alternatePath); err == nil {
return file, nil
}
}
// If both attempts fail, return the error from the first attempt
return nil, fmt.Errorf("file not found: %s", name)
}
// Close releases any resources held by the provider.

View File

@@ -53,6 +53,7 @@ func (p *LocalFSProvider) Open(name string) (fs.File, error) {
// Apply prefix stripping by prepending the prefix to the requested path
actualPath := name
alternatePath := ""
if p.stripPrefix != "" {
// Clean the paths to handle leading/trailing slashes
prefix := strings.Trim(p.stripPrefix, "/")
@@ -60,12 +61,26 @@ func (p *LocalFSProvider) Open(name string) (fs.File, error) {
if prefix != "" {
actualPath = path.Join(prefix, cleanName)
alternatePath = cleanName
} else {
actualPath = cleanName
}
}
return p.fs.Open(actualPath)
// First try the actual path with prefix
if file, err := p.fs.Open(actualPath); err == nil {
return file, nil
}
// If alternate path is different, try it as well
if alternatePath != "" && alternatePath != actualPath {
if file, err := p.fs.Open(alternatePath); err == nil {
return file, nil
}
}
// If both attempts fail, return the error from the first attempt
return nil, fmt.Errorf("file not found: %s", name)
}
// Close releases any resources held by the provider.

View File

@@ -56,6 +56,7 @@ func (p *ZipFSProvider) Open(name string) (fs.File, error) {
// Apply prefix stripping by prepending the prefix to the requested path
actualPath := name
alternatePath := ""
if p.stripPrefix != "" {
// Clean the paths to handle leading/trailing slashes
prefix := strings.Trim(p.stripPrefix, "/")
@@ -63,12 +64,26 @@ func (p *ZipFSProvider) Open(name string) (fs.File, error) {
if prefix != "" {
actualPath = path.Join(prefix, cleanName)
alternatePath = cleanName
} else {
actualPath = cleanName
}
}
return p.zipFS.Open(actualPath)
// First try the actual path with prefix
if file, err := p.zipFS.Open(actualPath); err == nil {
return file, nil
}
// If alternate path is different, try it as well
if alternatePath != "" && alternatePath != actualPath {
if file, err := p.zipFS.Open(alternatePath); err == nil {
return file, nil
}
}
// If both attempts fail, return the error from the first attempt
return nil, fmt.Errorf("file not found: %s", name)
}
// Close releases resources held by the zip reader.