// Package security caches per-user column and row security rules and applies
// them to records through reflection.
package security

import (
	"context"
	"fmt"
	"net/http"
	"reflect"
	"strings"
	"sync"

	"github.com/bitechdev/ResolveSpec/pkg/logger"
	"github.com/bitechdev/ResolveSpec/pkg/reflection"
	"github.com/tidwall/gjson"
	"github.com/tidwall/sjson"
)

// ColumnSecurity describes one column-level rule for a user on a table.
type ColumnSecurity struct {
	Schema    string
	Tablename string
	// Path is walked element by element through nested structs; each element
	// is matched case-insensitively against a column's SQL name or struct
	// field name.
	Path []string
	// ExtraFilters is not read anywhere in this file.
	ExtraFilters map[string]string
	UserID       int
	// Accesstype selects the action; only "mask" and "hide" (case-insensitive)
	// are acted upon in this file, every other value is skipped.
	Accesstype string `json:"accesstype"`
	// MaskStart, MaskEnd, MaskInvert and MaskChar are forwarded to maskString
	// when a string value is masked.
	MaskStart  int
	MaskEnd    int
	MaskInvert bool
	MaskChar   string
	Control    string `json:"control"`
	ID         int    `json:"id"`
}

// RowSecurity holds the row-level rule for a user on a table. Template may
// contain the placeholders expanded by GetTemplate.
type RowSecurity struct {
	Schema    string
	Tablename string
	Template  string
	HasBlock  bool
	UserID    int
}

// GetTemplate returns m.Template with the placeholders {PrimaryKeyName},
// {TableName}, {SchemaName} and {UserID} substituted, in that order.
// pModelType is accepted for interface compatibility and is not used.
func (m *RowSecurity) GetTemplate(pPrimaryKeyName string, pModelType reflect.Type) string {
	str := m.Template
	str = strings.ReplaceAll(str, "{PrimaryKeyName}", pPrimaryKeyName)
	str = strings.ReplaceAll(str, "{TableName}", m.Tablename)
	str = strings.ReplaceAll(str, "{SchemaName}", m.Schema)
	str = strings.ReplaceAll(str, "{UserID}", fmt.Sprintf("%d", m.UserID))
	return str
}

// Callback function types for customizing security behavior
type (
	// AuthenticateFunc extracts user ID and roles from HTTP request
	// Return userID, roles, error. If error is not nil, request will be rejected.
	AuthenticateFunc func(r *http.Request) (userID int, roles string, err error)

	// LoadColumnSecurityFunc loads column security rules for a user and entity
	// Override this to customize how column security is loaded from your data source
	LoadColumnSecurityFunc func(pUserID int, pSchema, pTablename string) ([]ColumnSecurity, error)

	// LoadRowSecurityFunc loads row security rules for a user and entity
	// Override this to customize how row security is loaded from your data source
	LoadRowSecurityFunc func(pUserID int, pSchema, pTablename string) (RowSecurity, error)
)

// SecurityList caches column and row security rules keyed by
// "<schema>.<table>@<userID>". Each map is guarded by the mutex declared
// directly above it. The struct contains mutexes and must not be copied.
type SecurityList struct {
	ColumnSecurityMutex sync.RWMutex
	ColumnSecurity      map[string][]ColumnSecurity
	RowSecurityMutex    sync.RWMutex
	RowSecurity         map[string]RowSecurity

	// Overridable callbacks
	AuthenticateCallback       AuthenticateFunc
	LoadColumnSecurityCallback LoadColumnSecurityFunc
	LoadRowSecurityCallback    LoadRowSecurityFunc
}

// CONTEXT_KEY is the type of the keys this package stores in a request context.
type CONTEXT_KEY string

// SECURITY_CONTEXT_KEY is the context key under which SetSecurityMiddleware
// stores a pointer to GlobalSecurity.
const SECURITY_CONTEXT_KEY CONTEXT_KEY = "SecurityList"

// GlobalSecurity is the package-wide SecurityList exposed to handlers by
// SetSecurityMiddleware.
var GlobalSecurity SecurityList

// SetSecurityMiddleware adds security context to requests
func SetSecurityMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		ctx := context.WithValue(r.Context(), SECURITY_CONTEXT_KEY, &GlobalSecurity)
		next.ServeHTTP(w, r.WithContext(ctx))
	})
}

// maskString replaces part of pString with maskChar ("*" when empty).
//
// Positions are counted in characters (runes), not bytes, so multi-byte text
// is masked the same way as ASCII text.
//
//   - maskStart == 0 && maskEnd == 0 masks the whole string.
//   - invert == false masks positions 0..maskStart and
//     (len-1-maskEnd)..len-1, both bounds inclusive.
//   - invert == true masks positions (mid-maskStart)..(mid+maskEnd) around the
//     middle position len/2, both bounds inclusive.
//
// maskStart and maskEnd are clamped to the string length.
func maskString(pString string, maskStart, maskEnd int, maskChar string, invert bool) string {
	runes := []rune(pString)
	strLen := len(runes)
	middleIndex := strLen / 2

	if maskStart == 0 && maskEnd == 0 {
		maskStart = strLen
		maskEnd = strLen
	}
	if maskEnd > strLen {
		maskEnd = strLen
	}
	if maskStart > strLen {
		maskStart = strLen
	}
	if maskChar == "" {
		maskChar = "*"
	}

	var b strings.Builder
	b.Grow(len(pString))
	for index, char := range runes {
		var masked bool
		if invert {
			masked = (index >= middleIndex-maskStart && index <= middleIndex) ||
				(index <= middleIndex+maskEnd && index >= middleIndex)
		} else {
			masked = index <= maskStart || index >= strLen-1-maskEnd
		}
		if masked {
			b.WriteString(maskChar)
			continue
		}
		b.WriteRune(char)
	}
	return b.String()
}

// ColumSecurityApplyOnRecord restores protected columns of newRecord from
// prevRecord, so that values covered by a "mask" or "hide" rule keep the value
// they had in prevRecord.
//
// For every applicable rule the rule's Path is walked in parallel through both
// records. At the last path element:
//   - a field whose matched name contains "json" has the value at the joined
//     path copied from the previous JSON document into the new one;
//   - a field matched by SQL name is overwritten with the previous value and
//     the joined path is appended to the returned list.
//
// It returns the collected paths and an error when the security map is not
// initialized, the two records have different types, or no rules are cached
// for the schema/table/user key. modelType is not used.
//
// reflect.Value.Set panics on a value that is not settable, so callers are
// expected to pass addressable records.
func (m *SecurityList) ColumSecurityApplyOnRecord(prevRecord reflect.Value, newRecord reflect.Value, modelType reflect.Type, pUserID int, pSchema, pTablename string) ([]string, error) {
	cols := make([]string, 0)

	// Take the read lock before touching the map: LoadColumnSecurity assigns
	// and mutates it under the write lock.
	m.ColumnSecurityMutex.RLock()
	defer m.ColumnSecurityMutex.RUnlock()

	if m.ColumnSecurity == nil {
		return cols, fmt.Errorf("security not initialized")
	}

	if prevRecord.Type() != newRecord.Type() {
		logger.Error("prev:%s and new:%s record type mismatch", prevRecord.Type(), newRecord.Type())
		return cols, fmt.Errorf("prev and new record type mismatch")
	}

	colsecList, ok := m.ColumnSecurity[fmt.Sprintf("%s.%s@%d", pSchema, pTablename, pUserID)]
	if !ok || colsecList == nil {
		return cols, fmt.Errorf("no security data")
	}

	for ruleIdx := range colsecList {
		colsec := &colsecList[ruleIdx]
		if !strings.EqualFold(colsec.Accesstype, "mask") && !strings.EqualFold(colsec.Accesstype, "hide") {
			continue
		}

		lastRecords := interateStruct(prevRecord)
		newRecords := interateStruct(newRecord)
		var lastLoopField, lastLoopNewField reflect.Value
		pathLen := len(colsec.Path)

		for pathIdx, path := range colsec.Path {
			var nameType, fieldName string

			if len(newRecords) == 0 {
				// The new record has nothing at this level: fall back to
				// restoring the last field that was matched on both sides.
				if lastLoopNewField.IsValid() && lastLoopField.IsValid() && pathIdx < pathLen-1 {
					lastLoopNewField.Set(lastLoopField)
				}
				break
			}

			for ri := range newRecords {
				// Both slices are reassigned further down while this loop is
				// running, and the previous record may hold fewer nested
				// structs than the new one, so re-check the bounds each turn.
				if ri >= len(newRecords) || ri >= len(lastRecords) {
					break
				}
				if !newRecords[ri].IsValid() || !lastRecords[ri].IsValid() {
					break
				}

				var field, oldField reflect.Value
				columnData := reflection.GetModelColumnDetail(newRecords[ri])
				lastColumnData := reflection.GetModelColumnDetail(lastRecords[ri])
				for ci, col := range columnData {
					if ci >= len(lastColumnData) {
						break
					}
					if col.SQLName != "" && strings.EqualFold(col.SQLName, path) {
						nameType = "sql"
						fieldName = col.SQLName
						field = col.FieldValue
						oldField = lastColumnData[ci].FieldValue
						break
					}
					if col.Name != "" && strings.EqualFold(col.Name, path) {
						nameType = "struct"
						fieldName = col.Name
						field = col.FieldValue
						oldField = lastColumnData[ci].FieldValue
						break
					}
				}

				if !field.IsValid() || !oldField.IsValid() {
					break
				}
				lastLoopField = oldField
				lastLoopNewField = field

				if pathIdx == pathLen-1 {
					if strings.Contains(strings.ToLower(fieldName), "json") {
						// Assumes the field is a byte slice holding a JSON
						// document; Bytes panics otherwise.
						prevSrc := oldField.Bytes()
						newSrc := field.Bytes()
						pathstr := strings.Join(colsec.Path, ".")
						prevPathValue := gjson.GetBytes(prevSrc, pathstr)
						// NOTE(review): Result.Str is only populated for JSON
						// strings, so non-string values are restored as "".
						newBytes, err := sjson.SetBytes(newSrc, pathstr, prevPathValue.Str)
						if err == nil {
							if field.CanSet() {
								field.SetBytes(newBytes)
							} else {
								// NOTE(review): the path is reported only when
								// the value could NOT be restored; confirm
								// this is what callers expect.
								logger.Warn("Value not settable: %v", field)
								cols = append(cols, pathstr)
							}
						}
						break
					}

					// Only mask/hide rules reach this point (see the filter at
					// the top of the rule loop).
					if nameType == "sql" {
						field.Set(oldField)
						cols = append(cols, strings.Join(colsec.Path, "."))
					}
					break
				}

				// Descend one level: previous values come from oldField and
				// new values from field (these two were previously swapped).
				lastRecords = interateStruct(oldField)
				newRecords = interateStruct(field)
			}
		}
	}

	return cols, nil
}

// interateStruct flattens val into the list of struct values it contains.
// Pointers and interfaces are dereferenced, arrays and slices are expanded
// element by element (recursively), a struct is returned as a single-element
// list, and any other kind (including nil pointers) yields an empty list.
func interateStruct(val reflect.Value) []reflect.Value {
	list := make([]reflect.Value, 0)

	switch val.Kind() {
	case reflect.Pointer, reflect.Interface:
		elem := val.Elem()
		if elem.IsValid() {
			list = append(list, interateStruct(elem)...)
		}
		return list
	case reflect.Array, reflect.Slice:
		for i := 0; i < val.Len(); i++ {
			elem := val.Index(i)
			if !elem.IsValid() {
				continue
			}
			list = append(list, interateStruct(elem)...)
		}
		return list
	case reflect.Struct:
		list = append(list, val)
		return list
	default:
		return list
	}
}

// setColSecValue masks or hides the value held by fieldsrc according to
// colsec. A pointer or interface is dereferenced once first; a nil one holds
// no value to protect and is left untouched.
//
// The action is chosen from the lower-cased reflect kind name of the value:
//   - kinds containing "int" (signed and unsigned) are set to 0;
//   - kinds containing "time" or "date" are set to their zero value;
//   - strings are masked with maskString ("mask") or emptied ("hide");
//   - otherwise, when fieldTypeName contains "json", the value at the joined
//     colsec.Path inside the JSON bytes is masked or emptied.
//
// It returns (1, dereferenced value) when a JSON rule has fewer than two path
// elements, and (0, fieldsrc) in every other case.
func setColSecValue(fieldsrc reflect.Value, colsec ColumnSecurity, fieldTypeName string) (int, reflect.Value) {
	fieldval := fieldsrc
	if fieldsrc.Kind() == reflect.Pointer || fieldsrc.Kind() == reflect.Interface {
		fieldval = fieldval.Elem()
	}
	if !fieldval.IsValid() {
		// Nil pointer/interface: nothing to mask, and Bytes/String on an
		// invalid Value would panic or return a placeholder.
		return 0, fieldsrc
	}

	isMask := strings.EqualFold(colsec.Accesstype, "mask")
	isHide := strings.EqualFold(colsec.Accesstype, "hide")
	fieldKindLower := strings.ToLower(fieldval.Kind().String())

	switch {
	case strings.Contains(fieldKindLower, "int") && (isMask || isHide):
		// "int" also matches the unsigned kinds, for which CanInt is false;
		// zero those through SetUint so they are not left exposed.
		if fieldval.CanSet() {
			if fieldval.CanInt() {
				fieldval.SetInt(0)
			} else if fieldval.CanUint() {
				fieldval.SetUint(0)
			}
		}
	case (strings.Contains(fieldKindLower, "time") || strings.Contains(fieldKindLower, "date")) && (isMask || isHide):
		// NOTE(review): no reflect kind name contains "time" or "date", so
		// this branch looks unreachable as written; confirm the intent.
		fieldval.SetZero()
	case strings.Contains(fieldKindLower, "string"):
		strVal := fieldval.String()
		if isMask {
			fieldval.SetString(maskString(strVal, colsec.MaskStart, colsec.MaskEnd, colsec.MaskChar, colsec.MaskInvert))
		} else if isHide {
			fieldval.SetString("")
		}
	case strings.Contains(fieldTypeName, "json") && (isMask || isHide):
		if len(colsec.Path) < 2 {
			return 1, fieldval
		}
		pathstr := strings.Join(colsec.Path, ".")
		src := fieldval.Bytes()
		pathValue := gjson.GetBytes(src, pathstr)
		strValue := pathValue.String()
		if isMask {
			strValue = maskString(strValue, colsec.MaskStart, colsec.MaskEnd, colsec.MaskChar, colsec.MaskInvert)
		} else if isHide {
			strValue = ""
		}
		newBytes, err := sjson.SetBytes(src, pathstr, strValue)
		if err == nil {
			fieldval.SetBytes(newBytes)
		}
	}

	return 0, fieldsrc
}

// ApplyColumnSecurity applies every cached "mask"/"hide" rule for the given
// schema, table and user to records, in place, and returns records.
//
// Rules are only applied when records is an array or slice. For each element
// the rule's Path is walked through nested structs (following the first struct
// found at each level) and the field reached by the last path element is
// passed to setColSecValue. modelType is not used.
//
// An error is returned when the security map is not initialized or no rules
// are cached for the key. logger.CatchPanic is deferred around the whole call.
func (m *SecurityList) ApplyColumnSecurity(records reflect.Value, modelType reflect.Type, pUserID int, pSchema, pTablename string) (reflect.Value, error) {
	defer logger.CatchPanic("ApplyColumnSecurity")

	// Take the read lock before touching the map: LoadColumnSecurity assigns
	// and mutates it under the write lock.
	m.ColumnSecurityMutex.RLock()
	defer m.ColumnSecurityMutex.RUnlock()

	if m.ColumnSecurity == nil {
		return records, fmt.Errorf("security not initialized")
	}

	colsecList, ok := m.ColumnSecurity[fmt.Sprintf("%s.%s@%d", pSchema, pTablename, pUserID)]
	if !ok || colsecList == nil {
		return records, fmt.Errorf("no security data")
	}

	for ruleIdx := range colsecList {
		colsec := &colsecList[ruleIdx]
		if !strings.EqualFold(colsec.Accesstype, "mask") && !strings.EqualFold(colsec.Accesstype, "hide") {
			continue
		}
		if records.Kind() != reflect.Array && records.Kind() != reflect.Slice {
			continue
		}

		for recIdx := 0; recIdx < records.Len(); recIdx++ {
			record := records.Index(recIdx)
			if !record.IsValid() {
				continue
			}

			lastRecord := interateStruct(record)
			pathLen := len(colsec.Path)
			for pathIdx, path := range colsec.Path {
				var field reflect.Value
				var nameType, fieldName string
				if len(lastRecord) == 0 {
					break
				}

				columnData := reflection.GetModelColumnDetail(lastRecord[0])
				for _, col := range columnData {
					if col.SQLName != "" && strings.EqualFold(col.SQLName, path) {
						nameType = "sql"
						fieldName = col.SQLName
						field = col.FieldValue
						break
					}
					if col.Name != "" && strings.EqualFold(col.Name, path) {
						nameType = "struct"
						fieldName = col.Name
						field = col.FieldValue
						break
					}
				}

				if pathIdx == pathLen-1 {
					// nameType is only set when a column matched, so field is
					// the matched column's value here.
					if nameType == "sql" || nameType == "struct" {
						setColSecValue(field, *colsec, fieldName)
					}
					break
				}
				if field.IsValid() {
					lastRecord = interateStruct(field)
				}
			}
		}
	}

	return records, nil
}

// LoadColumnSecurity loads the column rules for the given user, schema and
// table through LoadColumnSecurityCallback and caches them under
// "<schema>.<table>@<userID>".
//
// The cache entry is reset to an empty list first when pOverwrite is true or
// no entry exists yet; on success the entry is always replaced by the
// callback's result. The write lock is held while the callback runs.
//
// NOTE(review): when the callback fails after the reset, the entry stays an
// empty non-nil list, which ApplyColumnSecurity treats as "no rules" rather
// than "no security data"; confirm that this is acceptable.
//
// It returns an error when the callback is not set or fails; the callback's
// error is wrapped and can be inspected with errors.Is / errors.As.
func (m *SecurityList) LoadColumnSecurity(pUserID int, pSchema, pTablename string, pOverwrite bool) error {
	// Use the callback if provided
	if m.LoadColumnSecurityCallback == nil {
		return fmt.Errorf("LoadColumnSecurityCallback not set - you must provide a callback function")
	}

	m.ColumnSecurityMutex.Lock()
	defer m.ColumnSecurityMutex.Unlock()

	if m.ColumnSecurity == nil {
		m.ColumnSecurity = make(map[string][]ColumnSecurity, 0)
	}
	secKey := fmt.Sprintf("%s.%s@%d", pSchema, pTablename, pUserID)

	if pOverwrite || m.ColumnSecurity[secKey] == nil {
		m.ColumnSecurity[secKey] = make([]ColumnSecurity, 0)
	}

	// Call the user-provided callback to load security rules
	colSecList, err := m.LoadColumnSecurityCallback(pUserID, pSchema, pTablename)
	if err != nil {
		return fmt.Errorf("LoadColumnSecurityCallback failed: %w", err)
	}

	m.ColumnSecurity[secKey] = colSecList
	return nil
}

// ClearSecurity removes, from the cache entry for the given user, schema and
// table, every rule that belongs to exactly that schema, table and user.
// Rules that differ in at least one of the three fields are kept. It is a
// no-op when no entry exists and always returns nil.
func (m *SecurityList) ClearSecurity(pUserID int, pSchema, pTablename string) error {
	var filtered []ColumnSecurity

	m.ColumnSecurityMutex.Lock()
	defer m.ColumnSecurityMutex.Unlock()

	secKey := fmt.Sprintf("%s.%s@%d", pSchema, pTablename, pUserID)
	list, ok := m.ColumnSecurity[secKey]
	if !ok {
		return nil
	}

	for i := range list {
		cs := &list[i]
		// Keep the rule unless it matches schema AND table AND user; the
		// negation of that conjunction is a disjunction of inequalities.
		if cs.Schema != pSchema || cs.Tablename != pTablename || cs.UserID != pUserID {
			filtered = append(filtered, *cs)
		}
	}

	m.ColumnSecurity[secKey] = filtered
	return nil
}

// LoadRowSecurity loads the row rule for the given user, schema and table
// through LoadRowSecurityCallback, caches it under
// "<schema>.<table>@<userID>" and returns it. The write lock is held while the
// callback runs. pOverwrite is accepted for interface compatibility; the
// cached entry is always replaced on success.
//
// It returns an error when the callback is not set or fails; the callback's
// error is wrapped and can be inspected with errors.Is / errors.As.
func (m *SecurityList) LoadRowSecurity(pUserID int, pSchema, pTablename string, pOverwrite bool) (RowSecurity, error) {
	// Use the callback if provided
	if m.LoadRowSecurityCallback == nil {
		return RowSecurity{}, fmt.Errorf("LoadRowSecurityCallback not set - you must provide a callback function")
	}

	m.RowSecurityMutex.Lock()
	defer m.RowSecurityMutex.Unlock()

	if m.RowSecurity == nil {
		m.RowSecurity = make(map[string]RowSecurity, 0)
	}
	secKey := fmt.Sprintf("%s.%s@%d", pSchema, pTablename, pUserID)

	// Call the user-provided callback to load security rules
	record, err := m.LoadRowSecurityCallback(pUserID, pSchema, pTablename)
	if err != nil {
		return RowSecurity{}, fmt.Errorf("LoadRowSecurityCallback failed: %w", err)
	}

	m.RowSecurity[secKey] = record
	return record, nil
}

// GetRowSecurityTemplate returns the cached row rule for the given user,
// schema and table. It returns an error when the row security map is not
// initialized or holds no entry for the key. logger.CatchPanic is deferred
// around the whole call.
func (m *SecurityList) GetRowSecurityTemplate(pUserID int, pSchema, pTablename string) (RowSecurity, error) {
	defer logger.CatchPanic("GetRowSecurityTemplate")

	// Take the read lock before touching the map: LoadRowSecurity assigns and
	// mutates it under the write lock.
	m.RowSecurityMutex.RLock()
	defer m.RowSecurityMutex.RUnlock()

	if m.RowSecurity == nil {
		return RowSecurity{}, fmt.Errorf("security not initialized")
	}

	rowSec, ok := m.RowSecurity[fmt.Sprintf("%s.%s@%d", pSchema, pTablename, pUserID)]
	if !ok {
		return RowSecurity{}, fmt.Errorf("no security data")
	}

	return rowSec, nil
}