Added cursor filters and hooks

This commit is contained in:
Hein
2025-11-10 10:22:55 +02:00
parent fc82a9bc50
commit c8704c07dd
8 changed files with 1487 additions and 5 deletions

View File

@@ -1,6 +1,11 @@
package database
import "strings"
import (
"reflect"
"strings"
"github.com/Warky-Devs/ResolveSpec/pkg/common"
)
// parseTableName splits a table name that may contain schema into separate schema and table
// For example: "public.users" -> ("public", "users")
@@ -11,3 +16,157 @@ func parseTableName(fullTableName string) (schema, table string) {
}
return "", fullTableName
}
// GetPrimaryKeyName extracts the primary key column name from a model
// It first checks if the model implements PrimaryKeyNameProvider (GetIDName method)
// Falls back to reflection to find bun:",pk" tag, then gorm:"primaryKey" tag
func GetPrimaryKeyName(model any) string {
// Check if model implements PrimaryKeyNameProvider
if provider, ok := model.(common.PrimaryKeyNameProvider); ok {
return provider.GetIDName()
}
// Try Bun tag first
if pkName := getPrimaryKeyFromReflection(model, "bun"); pkName != "" {
return pkName
}
// Fall back to GORM tag
return getPrimaryKeyFromReflection(model, "gorm")
}
// GetModelColumns extracts all column names from a model using reflection
// It checks bun tags first, then gorm tags, then json tags, and finally falls back to lowercase field names
func GetModelColumns(model any) []string {
var columns []string
modelType := reflect.TypeOf(model)
// Unwrap pointers, slices, and arrays to get to the base struct type
for modelType != nil && (modelType.Kind() == reflect.Pointer || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
modelType = modelType.Elem()
}
// Validate that we have a struct type
if modelType == nil || modelType.Kind() != reflect.Struct {
return columns
}
for i := 0; i < modelType.NumField(); i++ {
field := modelType.Field(i)
// Get column name using the same logic as primary key extraction
columnName := getColumnNameFromField(field)
if columnName != "" {
columns = append(columns, columnName)
}
}
return columns
}
// getColumnNameFromField extracts the column name from a struct field
// Priority: bun tag -> gorm tag -> json tag -> lowercase field name
func getColumnNameFromField(field reflect.StructField) string {
// Try bun tag first
bunTag := field.Tag.Get("bun")
if bunTag != "" && bunTag != "-" {
if colName := extractColumnFromBunTag(bunTag); colName != "" {
return colName
}
}
// Try gorm tag
gormTag := field.Tag.Get("gorm")
if gormTag != "" && gormTag != "-" {
if colName := extractColumnFromGormTag(gormTag); colName != "" {
return colName
}
}
// Fall back to json tag
jsonTag := field.Tag.Get("json")
if jsonTag != "" && jsonTag != "-" {
// Extract just the field name before any options
parts := strings.Split(jsonTag, ",")
if len(parts) > 0 && parts[0] != "" {
return parts[0]
}
}
// Last resort: use field name in lowercase
return strings.ToLower(field.Name)
}
// getPrimaryKeyFromReflection uses reflection to find the primary key field
func getPrimaryKeyFromReflection(model any, ormType string) string {
val := reflect.ValueOf(model)
if val.Kind() == reflect.Pointer {
val = val.Elem()
}
if val.Kind() != reflect.Struct {
return ""
}
typ := val.Type()
for i := 0; i < typ.NumField(); i++ {
field := typ.Field(i)
switch ormType {
case "gorm":
// Check for gorm tag with primaryKey
gormTag := field.Tag.Get("gorm")
if strings.Contains(gormTag, "primaryKey") {
// Try to extract column name from gorm tag
if colName := extractColumnFromGormTag(gormTag); colName != "" {
return colName
}
// Fall back to json tag
if jsonTag := field.Tag.Get("json"); jsonTag != "" {
return strings.Split(jsonTag, ",")[0]
}
}
case "bun":
// Check for bun tag with pk flag
bunTag := field.Tag.Get("bun")
if strings.Contains(bunTag, "pk") {
// Extract column name from bun tag
if colName := extractColumnFromBunTag(bunTag); colName != "" {
return colName
}
// Fall back to json tag
if jsonTag := field.Tag.Get("json"); jsonTag != "" {
return strings.Split(jsonTag, ",")[0]
}
}
}
}
return ""
}
// extractColumnFromGormTag extracts the column name from a gorm tag
// Example: "column:id;primaryKey" -> "id"
func extractColumnFromGormTag(tag string) string {
parts := strings.Split(tag, ";")
for _, part := range parts {
part = strings.TrimSpace(part)
if colName, found := strings.CutPrefix(part, "column:"); found {
return colName
}
}
return ""
}
// extractColumnFromBunTag extracts the column name from a bun tag
// Example: "id,pk" -> "id"
// Example: ",pk" -> "" (will fall back to json tag)
func extractColumnFromBunTag(tag string) string {
parts := strings.Split(tag, ",")
if len(parts) > 0 && parts[0] != "" {
return parts[0]
}
return ""
}

View File

@@ -0,0 +1,233 @@
package database
import (
"testing"
)
// Test models for GORM
type GormModelWithGetIDName struct {
ID int `gorm:"column:rid_test;primaryKey" json:"id"`
Name string `json:"name"`
}
func (m GormModelWithGetIDName) GetIDName() string {
return "rid_test"
}
type GormModelWithColumnTag struct {
ID int `gorm:"column:custom_id;primaryKey" json:"id"`
Name string `json:"name"`
}
type GormModelWithJSONFallback struct {
ID int `gorm:"primaryKey" json:"user_id"`
Name string `json:"name"`
}
// Test models for Bun
type BunModelWithGetIDName struct {
ID int `bun:"rid_test,pk" json:"id"`
Name string `json:"name"`
}
func (m BunModelWithGetIDName) GetIDName() string {
return "rid_test"
}
type BunModelWithColumnTag struct {
ID int `bun:"custom_id,pk" json:"id"`
Name string `json:"name"`
}
type BunModelWithJSONFallback struct {
ID int `bun:",pk" json:"user_id"`
Name string `json:"name"`
}
func TestGetPrimaryKeyName(t *testing.T) {
tests := []struct {
name string
model any
expected string
}{
{
name: "GORM model with GetIDName method",
model: GormModelWithGetIDName{},
expected: "rid_test",
},
{
name: "GORM model with column tag",
model: GormModelWithColumnTag{},
expected: "custom_id",
},
{
name: "GORM model with JSON fallback",
model: GormModelWithJSONFallback{},
expected: "user_id",
},
{
name: "GORM model pointer with GetIDName",
model: &GormModelWithGetIDName{},
expected: "rid_test",
},
{
name: "GORM model pointer with column tag",
model: &GormModelWithColumnTag{},
expected: "custom_id",
},
{
name: "Bun model with GetIDName method",
model: BunModelWithGetIDName{},
expected: "rid_test",
},
{
name: "Bun model with column tag",
model: BunModelWithColumnTag{},
expected: "custom_id",
},
{
name: "Bun model with JSON fallback",
model: BunModelWithJSONFallback{},
expected: "user_id",
},
{
name: "Bun model pointer with GetIDName",
model: &BunModelWithGetIDName{},
expected: "rid_test",
},
{
name: "Bun model pointer with column tag",
model: &BunModelWithColumnTag{},
expected: "custom_id",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := GetPrimaryKeyName(tt.model)
if result != tt.expected {
t.Errorf("GetPrimaryKeyName() = %v, want %v", result, tt.expected)
}
})
}
}
func TestExtractColumnFromGormTag(t *testing.T) {
tests := []struct {
name string
tag string
expected string
}{
{
name: "column tag with primaryKey",
tag: "column:rid_test;primaryKey",
expected: "rid_test",
},
{
name: "column tag with spaces",
tag: "column:user_id ; primaryKey ; autoIncrement",
expected: "user_id",
},
{
name: "no column tag",
tag: "primaryKey;autoIncrement",
expected: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := extractColumnFromGormTag(tt.tag)
if result != tt.expected {
t.Errorf("extractColumnFromGormTag() = %v, want %v", result, tt.expected)
}
})
}
}
func TestExtractColumnFromBunTag(t *testing.T) {
tests := []struct {
name string
tag string
expected string
}{
{
name: "column name with pk flag",
tag: "rid_test,pk",
expected: "rid_test",
},
{
name: "only pk flag",
tag: ",pk",
expected: "",
},
{
name: "column with multiple flags",
tag: "user_id,pk,autoincrement",
expected: "user_id",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := extractColumnFromBunTag(tt.tag)
if result != tt.expected {
t.Errorf("extractColumnFromBunTag() = %v, want %v", result, tt.expected)
}
})
}
}
func TestGetModelColumns(t *testing.T) {
tests := []struct {
name string
model any
expected []string
}{
{
name: "Bun model with multiple columns",
model: BunModelWithColumnTag{},
expected: []string{"custom_id", "name"},
},
{
name: "GORM model with multiple columns",
model: GormModelWithColumnTag{},
expected: []string{"custom_id", "name"},
},
{
name: "Bun model pointer",
model: &BunModelWithColumnTag{},
expected: []string{"custom_id", "name"},
},
{
name: "GORM model pointer",
model: &GormModelWithColumnTag{},
expected: []string{"custom_id", "name"},
},
{
name: "Bun model with JSON fallback",
model: BunModelWithJSONFallback{},
expected: []string{"user_id", "name"},
},
{
name: "GORM model with JSON fallback",
model: GormModelWithJSONFallback{},
expected: []string{"user_id", "name"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := GetModelColumns(tt.model)
if len(result) != len(tt.expected) {
t.Errorf("GetModelColumns() returned %d columns, want %d", len(result), len(tt.expected))
return
}
for i, col := range result {
if col != tt.expected[i] {
t.Errorf("GetModelColumns()[%d] = %v, want %v", i, col, tt.expected[i])
}
}
})
}
}

View File

@@ -131,6 +131,11 @@ type TableNameProvider interface {
TableName() string
}
// PrimaryKeyNameProvider interface for models that provide primary key column names
type PrimaryKeyNameProvider interface {
GetIDName() string
}
// SchemaProvider interface for models that provide schema names
type SchemaProvider interface {
SchemaName() string