mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2026-01-29 14:04:26 +00:00
feat(preload): ✨ Add support for custom SQL joins
* Introduce SqlJoins and JoinAliases in PreloadOption. * Preserve SqlJoins and JoinAliases during filter processing. * Implement logic to apply custom SQL joins in handler. * Add tests for SqlJoins handling and join alias extraction.
This commit is contained in:
@@ -52,6 +52,10 @@ type PreloadOption struct {
|
|||||||
PrimaryKey string `json:"primary_key"` // Primary key of the related table
|
PrimaryKey string `json:"primary_key"` // Primary key of the related table
|
||||||
RelatedKey string `json:"related_key"` // For child tables: column in child that references parent
|
RelatedKey string `json:"related_key"` // For child tables: column in child that references parent
|
||||||
ForeignKey string `json:"foreign_key"` // For parent tables: column in current table that references parent
|
ForeignKey string `json:"foreign_key"` // For parent tables: column in current table that references parent
|
||||||
|
|
||||||
|
// Custom SQL JOINs from XFiles - used when preload needs additional joins
|
||||||
|
SqlJoins []string `json:"sql_joins"` // Custom SQL JOIN clauses
|
||||||
|
JoinAliases []string `json:"join_aliases"` // Extracted table aliases from SqlJoins for validation
|
||||||
}
|
}
|
||||||
|
|
||||||
type FilterOption struct {
|
type FilterOption struct {
|
||||||
|
|||||||
@@ -272,15 +272,31 @@ func (v *ColumnValidator) FilterRequestOptions(options RequestOptions) RequestOp
|
|||||||
filteredPreload.Columns = v.FilterValidColumns(preload.Columns)
|
filteredPreload.Columns = v.FilterValidColumns(preload.Columns)
|
||||||
filteredPreload.OmitColumns = v.FilterValidColumns(preload.OmitColumns)
|
filteredPreload.OmitColumns = v.FilterValidColumns(preload.OmitColumns)
|
||||||
|
|
||||||
|
// Preserve SqlJoins and JoinAliases for preloads with custom joins
|
||||||
|
filteredPreload.SqlJoins = preload.SqlJoins
|
||||||
|
filteredPreload.JoinAliases = preload.JoinAliases
|
||||||
|
|
||||||
// Filter preload filters
|
// Filter preload filters
|
||||||
validPreloadFilters := make([]FilterOption, 0, len(preload.Filters))
|
validPreloadFilters := make([]FilterOption, 0, len(preload.Filters))
|
||||||
for _, filter := range preload.Filters {
|
for _, filter := range preload.Filters {
|
||||||
if v.IsValidColumn(filter.Column) {
|
if v.IsValidColumn(filter.Column) {
|
||||||
validPreloadFilters = append(validPreloadFilters, filter)
|
validPreloadFilters = append(validPreloadFilters, filter)
|
||||||
|
} else {
|
||||||
|
// Check if the filter column references a joined table alias
|
||||||
|
foundJoin := false
|
||||||
|
for _, alias := range preload.JoinAliases {
|
||||||
|
if strings.Contains(filter.Column, alias) {
|
||||||
|
foundJoin = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if foundJoin {
|
||||||
|
validPreloadFilters = append(validPreloadFilters, filter)
|
||||||
} else {
|
} else {
|
||||||
logger.Warn("Invalid column in preload '%s' filter '%s' removed", preload.Relation, filter.Column)
|
logger.Warn("Invalid column in preload '%s' filter '%s' removed", preload.Relation, filter.Column)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
filteredPreload.Filters = validPreloadFilters
|
filteredPreload.Filters = validPreloadFilters
|
||||||
|
|
||||||
// Filter preload sort columns
|
// Filter preload sort columns
|
||||||
|
|||||||
@@ -882,6 +882,15 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Apply custom SQL joins from XFiles
|
||||||
|
if len(preload.SqlJoins) > 0 {
|
||||||
|
logger.Debug("Applying %d SQL joins to preload %s", len(preload.SqlJoins), preload.Relation)
|
||||||
|
for _, joinClause := range preload.SqlJoins {
|
||||||
|
sq = sq.Join(joinClause)
|
||||||
|
logger.Debug("Applied SQL join to preload %s: %s", preload.Relation, joinClause)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Apply filters
|
// Apply filters
|
||||||
if len(preload.Filters) > 0 {
|
if len(preload.Filters) > 0 {
|
||||||
for _, filter := range preload.Filters {
|
for _, filter := range preload.Filters {
|
||||||
|
|||||||
@@ -1088,6 +1088,32 @@ func (h *Handler) addXFilesPreload(xfile *XFiles, options *ExtendedRequestOption
|
|||||||
logger.Debug("X-Files: Set foreign key for %s: %s", relationPath, xfile.ForeignKey)
|
logger.Debug("X-Files: Set foreign key for %s: %s", relationPath, xfile.ForeignKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Transfer SqlJoins from XFiles to PreloadOption
|
||||||
|
if len(xfile.SqlJoins) > 0 {
|
||||||
|
preloadOpt.SqlJoins = make([]string, 0, len(xfile.SqlJoins))
|
||||||
|
preloadOpt.JoinAliases = make([]string, 0, len(xfile.SqlJoins))
|
||||||
|
|
||||||
|
for _, joinClause := range xfile.SqlJoins {
|
||||||
|
// Sanitize the join clause
|
||||||
|
sanitizedJoin := common.SanitizeWhereClause(joinClause, "", nil)
|
||||||
|
if sanitizedJoin == "" {
|
||||||
|
logger.Warn("X-Files: SqlJoin failed sanitization for %s: %s", relationPath, joinClause)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
preloadOpt.SqlJoins = append(preloadOpt.SqlJoins, sanitizedJoin)
|
||||||
|
|
||||||
|
// Extract join alias for validation
|
||||||
|
alias := extractJoinAlias(sanitizedJoin)
|
||||||
|
if alias != "" {
|
||||||
|
preloadOpt.JoinAliases = append(preloadOpt.JoinAliases, alias)
|
||||||
|
logger.Debug("X-Files: Extracted join alias for %s: %s", relationPath, alias)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Debug("X-Files: Added %d SQL joins to preload %s", len(preloadOpt.SqlJoins), relationPath)
|
||||||
|
}
|
||||||
|
|
||||||
// Add the preload option
|
// Add the preload option
|
||||||
options.Preload = append(options.Preload, preloadOpt)
|
options.Preload = append(options.Preload, preloadOpt)
|
||||||
|
|
||||||
|
|||||||
@@ -2,6 +2,8 @@ package restheadspec
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestDecodeHeaderValue(t *testing.T) {
|
func TestDecodeHeaderValue(t *testing.T) {
|
||||||
@@ -37,6 +39,121 @@ func TestDecodeHeaderValue(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAddXFilesPreload_WithSqlJoins(t *testing.T) {
|
||||||
|
handler := &Handler{}
|
||||||
|
options := &ExtendedRequestOptions{
|
||||||
|
RequestOptions: common.RequestOptions{
|
||||||
|
Preload: make([]common.PreloadOption, 0),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create an XFiles with SqlJoins
|
||||||
|
xfile := &XFiles{
|
||||||
|
TableName: "users",
|
||||||
|
SqlJoins: []string{
|
||||||
|
"LEFT JOIN departments d ON d.id = users.department_id",
|
||||||
|
"INNER JOIN roles r ON r.id = users.role_id",
|
||||||
|
},
|
||||||
|
FilterFields: []struct {
|
||||||
|
Field string `json:"field"`
|
||||||
|
Value string `json:"value"`
|
||||||
|
Operator string `json:"operator"`
|
||||||
|
}{
|
||||||
|
{Field: "d.active", Value: "true", Operator: "eq"},
|
||||||
|
{Field: "r.name", Value: "admin", Operator: "eq"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add the XFiles preload
|
||||||
|
handler.addXFilesPreload(xfile, options, "")
|
||||||
|
|
||||||
|
// Verify that a preload was added
|
||||||
|
if len(options.Preload) != 1 {
|
||||||
|
t.Fatalf("Expected 1 preload, got %d", len(options.Preload))
|
||||||
|
}
|
||||||
|
|
||||||
|
preload := options.Preload[0]
|
||||||
|
|
||||||
|
// Verify relation name
|
||||||
|
if preload.Relation != "users" {
|
||||||
|
t.Errorf("Expected relation 'users', got '%s'", preload.Relation)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify SqlJoins were transferred
|
||||||
|
if len(preload.SqlJoins) != 2 {
|
||||||
|
t.Fatalf("Expected 2 SQL joins, got %d", len(preload.SqlJoins))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify JoinAliases were extracted
|
||||||
|
if len(preload.JoinAliases) != 2 {
|
||||||
|
t.Fatalf("Expected 2 join aliases, got %d", len(preload.JoinAliases))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the aliases are correct
|
||||||
|
expectedAliases := []string{"d", "r"}
|
||||||
|
for i, expected := range expectedAliases {
|
||||||
|
if preload.JoinAliases[i] != expected {
|
||||||
|
t.Errorf("Expected alias '%s', got '%s'", expected, preload.JoinAliases[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify filters were added
|
||||||
|
if len(preload.Filters) != 2 {
|
||||||
|
t.Fatalf("Expected 2 filters, got %d", len(preload.Filters))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify filter columns reference joined tables
|
||||||
|
if preload.Filters[0].Column != "d.active" {
|
||||||
|
t.Errorf("Expected filter column 'd.active', got '%s'", preload.Filters[0].Column)
|
||||||
|
}
|
||||||
|
if preload.Filters[1].Column != "r.name" {
|
||||||
|
t.Errorf("Expected filter column 'r.name', got '%s'", preload.Filters[1].Column)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractJoinAlias(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
joinClause string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "LEFT JOIN with alias",
|
||||||
|
joinClause: "LEFT JOIN departments d ON d.id = users.department_id",
|
||||||
|
expected: "d",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "INNER JOIN with AS keyword",
|
||||||
|
joinClause: "INNER JOIN users AS u ON u.id = orders.user_id",
|
||||||
|
expected: "u",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "JOIN without alias",
|
||||||
|
joinClause: "JOIN roles ON roles.id = users.role_id",
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Complex join with multiple conditions",
|
||||||
|
joinClause: "LEFT OUTER JOIN products p ON p.id = items.product_id AND p.active = true",
|
||||||
|
expected: "p",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Invalid join (no ON clause)",
|
||||||
|
joinClause: "LEFT JOIN departments",
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := extractJoinAlias(tt.joinClause)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("Expected alias '%s', got '%s'", tt.expected, result)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Note: The following functions are unexported (lowercase) and cannot be tested directly:
|
// Note: The following functions are unexported (lowercase) and cannot be tested directly:
|
||||||
// - parseSelectFields
|
// - parseSelectFields
|
||||||
// - parseFieldFilter
|
// - parseFieldFilter
|
||||||
|
|||||||
Reference in New Issue
Block a user