mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2026-01-29 14:04:26 +00:00
feat(preload): ✨ Add support for custom SQL joins
* Introduce SqlJoins and JoinAliases in PreloadOption. * Preserve SqlJoins and JoinAliases during filter processing. * Implement logic to apply custom SQL joins in handler. * Add tests for SqlJoins handling and join alias extraction.
This commit is contained in:
@@ -52,6 +52,10 @@ type PreloadOption struct {
|
||||
PrimaryKey string `json:"primary_key"` // Primary key of the related table
|
||||
RelatedKey string `json:"related_key"` // For child tables: column in child that references parent
|
||||
ForeignKey string `json:"foreign_key"` // For parent tables: column in current table that references parent
|
||||
|
||||
// Custom SQL JOINs from XFiles - used when preload needs additional joins
|
||||
SqlJoins []string `json:"sql_joins"` // Custom SQL JOIN clauses
|
||||
JoinAliases []string `json:"join_aliases"` // Extracted table aliases from SqlJoins for validation
|
||||
}
|
||||
|
||||
type FilterOption struct {
|
||||
|
||||
@@ -272,13 +272,29 @@ func (v *ColumnValidator) FilterRequestOptions(options RequestOptions) RequestOp
|
||||
filteredPreload.Columns = v.FilterValidColumns(preload.Columns)
|
||||
filteredPreload.OmitColumns = v.FilterValidColumns(preload.OmitColumns)
|
||||
|
||||
// Preserve SqlJoins and JoinAliases for preloads with custom joins
|
||||
filteredPreload.SqlJoins = preload.SqlJoins
|
||||
filteredPreload.JoinAliases = preload.JoinAliases
|
||||
|
||||
// Filter preload filters
|
||||
validPreloadFilters := make([]FilterOption, 0, len(preload.Filters))
|
||||
for _, filter := range preload.Filters {
|
||||
if v.IsValidColumn(filter.Column) {
|
||||
validPreloadFilters = append(validPreloadFilters, filter)
|
||||
} else {
|
||||
logger.Warn("Invalid column in preload '%s' filter '%s' removed", preload.Relation, filter.Column)
|
||||
// Check if the filter column references a joined table alias
|
||||
foundJoin := false
|
||||
for _, alias := range preload.JoinAliases {
|
||||
if strings.Contains(filter.Column, alias) {
|
||||
foundJoin = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if foundJoin {
|
||||
validPreloadFilters = append(validPreloadFilters, filter)
|
||||
} else {
|
||||
logger.Warn("Invalid column in preload '%s' filter '%s' removed", preload.Relation, filter.Column)
|
||||
}
|
||||
}
|
||||
}
|
||||
filteredPreload.Filters = validPreloadFilters
|
||||
|
||||
@@ -882,6 +882,15 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co
|
||||
}
|
||||
}
|
||||
|
||||
// Apply custom SQL joins from XFiles
|
||||
if len(preload.SqlJoins) > 0 {
|
||||
logger.Debug("Applying %d SQL joins to preload %s", len(preload.SqlJoins), preload.Relation)
|
||||
for _, joinClause := range preload.SqlJoins {
|
||||
sq = sq.Join(joinClause)
|
||||
logger.Debug("Applied SQL join to preload %s: %s", preload.Relation, joinClause)
|
||||
}
|
||||
}
|
||||
|
||||
// Apply filters
|
||||
if len(preload.Filters) > 0 {
|
||||
for _, filter := range preload.Filters {
|
||||
|
||||
@@ -1088,6 +1088,32 @@ func (h *Handler) addXFilesPreload(xfile *XFiles, options *ExtendedRequestOption
|
||||
logger.Debug("X-Files: Set foreign key for %s: %s", relationPath, xfile.ForeignKey)
|
||||
}
|
||||
|
||||
// Transfer SqlJoins from XFiles to PreloadOption
|
||||
if len(xfile.SqlJoins) > 0 {
|
||||
preloadOpt.SqlJoins = make([]string, 0, len(xfile.SqlJoins))
|
||||
preloadOpt.JoinAliases = make([]string, 0, len(xfile.SqlJoins))
|
||||
|
||||
for _, joinClause := range xfile.SqlJoins {
|
||||
// Sanitize the join clause
|
||||
sanitizedJoin := common.SanitizeWhereClause(joinClause, "", nil)
|
||||
if sanitizedJoin == "" {
|
||||
logger.Warn("X-Files: SqlJoin failed sanitization for %s: %s", relationPath, joinClause)
|
||||
continue
|
||||
}
|
||||
|
||||
preloadOpt.SqlJoins = append(preloadOpt.SqlJoins, sanitizedJoin)
|
||||
|
||||
// Extract join alias for validation
|
||||
alias := extractJoinAlias(sanitizedJoin)
|
||||
if alias != "" {
|
||||
preloadOpt.JoinAliases = append(preloadOpt.JoinAliases, alias)
|
||||
logger.Debug("X-Files: Extracted join alias for %s: %s", relationPath, alias)
|
||||
}
|
||||
}
|
||||
|
||||
logger.Debug("X-Files: Added %d SQL joins to preload %s", len(preloadOpt.SqlJoins), relationPath)
|
||||
}
|
||||
|
||||
// Add the preload option
|
||||
options.Preload = append(options.Preload, preloadOpt)
|
||||
|
||||
|
||||
@@ -2,6 +2,8 @@ package restheadspec
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
)
|
||||
|
||||
func TestDecodeHeaderValue(t *testing.T) {
|
||||
@@ -37,6 +39,121 @@ func TestDecodeHeaderValue(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddXFilesPreload_WithSqlJoins(t *testing.T) {
|
||||
handler := &Handler{}
|
||||
options := &ExtendedRequestOptions{
|
||||
RequestOptions: common.RequestOptions{
|
||||
Preload: make([]common.PreloadOption, 0),
|
||||
},
|
||||
}
|
||||
|
||||
// Create an XFiles with SqlJoins
|
||||
xfile := &XFiles{
|
||||
TableName: "users",
|
||||
SqlJoins: []string{
|
||||
"LEFT JOIN departments d ON d.id = users.department_id",
|
||||
"INNER JOIN roles r ON r.id = users.role_id",
|
||||
},
|
||||
FilterFields: []struct {
|
||||
Field string `json:"field"`
|
||||
Value string `json:"value"`
|
||||
Operator string `json:"operator"`
|
||||
}{
|
||||
{Field: "d.active", Value: "true", Operator: "eq"},
|
||||
{Field: "r.name", Value: "admin", Operator: "eq"},
|
||||
},
|
||||
}
|
||||
|
||||
// Add the XFiles preload
|
||||
handler.addXFilesPreload(xfile, options, "")
|
||||
|
||||
// Verify that a preload was added
|
||||
if len(options.Preload) != 1 {
|
||||
t.Fatalf("Expected 1 preload, got %d", len(options.Preload))
|
||||
}
|
||||
|
||||
preload := options.Preload[0]
|
||||
|
||||
// Verify relation name
|
||||
if preload.Relation != "users" {
|
||||
t.Errorf("Expected relation 'users', got '%s'", preload.Relation)
|
||||
}
|
||||
|
||||
// Verify SqlJoins were transferred
|
||||
if len(preload.SqlJoins) != 2 {
|
||||
t.Fatalf("Expected 2 SQL joins, got %d", len(preload.SqlJoins))
|
||||
}
|
||||
|
||||
// Verify JoinAliases were extracted
|
||||
if len(preload.JoinAliases) != 2 {
|
||||
t.Fatalf("Expected 2 join aliases, got %d", len(preload.JoinAliases))
|
||||
}
|
||||
|
||||
// Verify the aliases are correct
|
||||
expectedAliases := []string{"d", "r"}
|
||||
for i, expected := range expectedAliases {
|
||||
if preload.JoinAliases[i] != expected {
|
||||
t.Errorf("Expected alias '%s', got '%s'", expected, preload.JoinAliases[i])
|
||||
}
|
||||
}
|
||||
|
||||
// Verify filters were added
|
||||
if len(preload.Filters) != 2 {
|
||||
t.Fatalf("Expected 2 filters, got %d", len(preload.Filters))
|
||||
}
|
||||
|
||||
// Verify filter columns reference joined tables
|
||||
if preload.Filters[0].Column != "d.active" {
|
||||
t.Errorf("Expected filter column 'd.active', got '%s'", preload.Filters[0].Column)
|
||||
}
|
||||
if preload.Filters[1].Column != "r.name" {
|
||||
t.Errorf("Expected filter column 'r.name', got '%s'", preload.Filters[1].Column)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractJoinAlias(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
joinClause string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "LEFT JOIN with alias",
|
||||
joinClause: "LEFT JOIN departments d ON d.id = users.department_id",
|
||||
expected: "d",
|
||||
},
|
||||
{
|
||||
name: "INNER JOIN with AS keyword",
|
||||
joinClause: "INNER JOIN users AS u ON u.id = orders.user_id",
|
||||
expected: "u",
|
||||
},
|
||||
{
|
||||
name: "JOIN without alias",
|
||||
joinClause: "JOIN roles ON roles.id = users.role_id",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "Complex join with multiple conditions",
|
||||
joinClause: "LEFT OUTER JOIN products p ON p.id = items.product_id AND p.active = true",
|
||||
expected: "p",
|
||||
},
|
||||
{
|
||||
name: "Invalid join (no ON clause)",
|
||||
joinClause: "LEFT JOIN departments",
|
||||
expected: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := extractJoinAlias(tt.joinClause)
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected alias '%s', got '%s'", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Note: The following functions are unexported (lowercase) and cannot be tested directly:
|
||||
// - parseSelectFields
|
||||
// - parseFieldFilter
|
||||
|
||||
Reference in New Issue
Block a user