package gorm import ( "fmt" "go/ast" "go/parser" "go/token" "os" "path/filepath" "reflect" "strconv" "strings" "git.warky.dev/wdevs/relspecgo/pkg/models" "git.warky.dev/wdevs/relspecgo/pkg/readers" ) // Reader implements the readers.Reader interface for GORM Go model files type Reader struct { options *readers.ReaderOptions } // NewReader creates a new GORM reader with the given options func NewReader(options *readers.ReaderOptions) *Reader { return &Reader{ options: options, } } // ReadDatabase reads GORM Go model files and returns a Database model func (r *Reader) ReadDatabase() (*models.Database, error) { if r.options.FilePath == "" { return nil, fmt.Errorf("file path is required for GORM reader") } // Check if path is a directory or file info, err := os.Stat(r.options.FilePath) if err != nil { return nil, fmt.Errorf("failed to stat path: %w", err) } var files []string if info.IsDir() { // Read all .go files in directory entries, err := os.ReadDir(r.options.FilePath) if err != nil { return nil, fmt.Errorf("failed to read directory: %w", err) } for _, entry := range entries { if !entry.IsDir() && strings.HasSuffix(entry.Name(), ".go") && !strings.HasSuffix(entry.Name(), "_test.go") { files = append(files, filepath.Join(r.options.FilePath, entry.Name())) } } } else { files = append(files, r.options.FilePath) } if len(files) == 0 { return nil, fmt.Errorf("no Go files found") } // Parse all files and collect tables db := models.InitDatabase("database") schemaMap := make(map[string]*models.Schema) for _, file := range files { tables, err := r.parseFile(file) if err != nil { return nil, fmt.Errorf("failed to parse file %s: %w", file, err) } for _, table := range tables { // Get or create schema schema, ok := schemaMap[table.Schema] if !ok { schema = models.InitSchema(table.Schema) schemaMap[table.Schema] = schema } schema.Tables = append(schema.Tables, table) } } // Convert schema map to slice for _, schema := range schemaMap { db.Schemas = append(db.Schemas, schema) } return db, nil } // ReadSchema reads GORM Go model files and returns a Schema model func (r *Reader) ReadSchema() (*models.Schema, error) { db, err := r.ReadDatabase() if err != nil { return nil, err } if len(db.Schemas) == 0 { return nil, fmt.Errorf("no schemas found") } return db.Schemas[0], nil } // ReadTable reads a GORM Go model file and returns a Table model func (r *Reader) ReadTable() (*models.Table, error) { schema, err := r.ReadSchema() if err != nil { return nil, err } if len(schema.Tables) == 0 { return nil, fmt.Errorf("no tables found") } return schema.Tables[0], nil } // parseFile parses a single Go file and extracts table models func (r *Reader) parseFile(filename string) ([]*models.Table, error) { fset := token.NewFileSet() node, err := parser.ParseFile(fset, filename, nil, parser.ParseComments) if err != nil { return nil, fmt.Errorf("failed to parse Go file: %w", err) } var tables []*models.Table structMap := make(map[string]*models.Table) // First pass: collect struct definitions for _, decl := range node.Decls { genDecl, ok := decl.(*ast.GenDecl) if !ok || genDecl.Tok != token.TYPE { continue } for _, spec := range genDecl.Specs { typeSpec, ok := spec.(*ast.TypeSpec) if !ok { continue } structType, ok := typeSpec.Type.(*ast.StructType) if !ok { continue } // Check if this struct has gorm tags (indicates it's a model) if r.hasModelFields(structType) { table := r.parseStruct(typeSpec.Name.Name, structType) if table != nil { structMap[typeSpec.Name.Name] = table tables = append(tables, table) } } } } // Second pass: find TableName() methods for _, decl := range node.Decls { funcDecl, ok := decl.(*ast.FuncDecl) if !ok || funcDecl.Name.Name != "TableName" { continue } // Get receiver type if funcDecl.Recv == nil || len(funcDecl.Recv.List) == 0 { continue } receiverType := r.getReceiverType(funcDecl.Recv.List[0].Type) if receiverType == "" { continue } // Find the table for this struct table, ok := structMap[receiverType] if !ok { continue } // Parse the return value tableName, schemaName := r.parseTableNameMethod(funcDecl) if tableName != "" { table.Name = tableName if schemaName != "" { table.Schema = schemaName } // Update columns and indexes for _, col := range table.Columns { col.Table = tableName col.Schema = table.Schema } for _, idx := range table.Indexes { idx.Table = tableName idx.Schema = table.Schema } } } // Third pass: parse relationship fields for constraints // Re-parse the file to get relationship information for _, decl := range node.Decls { genDecl, ok := decl.(*ast.GenDecl) if !ok || genDecl.Tok != token.TYPE { continue } for _, spec := range genDecl.Specs { typeSpec, ok := spec.(*ast.TypeSpec) if !ok { continue } structType, ok := typeSpec.Type.(*ast.StructType) if !ok { continue } table, ok := structMap[typeSpec.Name.Name] if !ok { continue } // Parse relationship fields r.parseRelationshipConstraints(table, structType, structMap) } } return tables, nil } // getReceiverType extracts the type name from a receiver func (r *Reader) getReceiverType(expr ast.Expr) string { switch t := expr.(type) { case *ast.Ident: return t.Name case *ast.StarExpr: if ident, ok := t.X.(*ast.Ident); ok { return ident.Name } } return "" } // parseTableNameMethod parses a TableName() method and extracts the table and schema name func (r *Reader) parseTableNameMethod(funcDecl *ast.FuncDecl) (tableName string, schemaName string) { if funcDecl.Body == nil { return "", "" } // Look for return statement for _, stmt := range funcDecl.Body.List { retStmt, ok := stmt.(*ast.ReturnStmt) if !ok { continue } if len(retStmt.Results) == 0 { continue } // Get the return value (should be a string literal) if basicLit, ok := retStmt.Results[0].(*ast.BasicLit); ok { if basicLit.Kind == token.STRING { // Remove quotes fullName := strings.Trim(basicLit.Value, "\"") // Split schema.table if strings.Contains(fullName, ".") { parts := strings.SplitN(fullName, ".", 2) return parts[1], parts[0] } return fullName, "public" } } } return "", "" } // hasModelFields checks if the struct has fields with gorm tags func (r *Reader) hasModelFields(structType *ast.StructType) bool { for _, field := range structType.Fields.List { if field.Tag != nil { tag := field.Tag.Value if strings.Contains(tag, "gorm:") { return true } } } return false } // parseStruct converts an AST struct to a Table model func (r *Reader) parseStruct(structName string, structType *ast.StructType) *models.Table { tableName := r.deriveTableName(structName) schemaName := "public" table := models.InitTable(tableName, schemaName) sequence := uint(1) // Parse fields for _, field := range structType.Fields.List { if field.Tag == nil { continue } tag := field.Tag.Value if !strings.Contains(tag, "gorm:") { continue } // Skip embedded GORM model if r.isGORMModel(field) { continue } // Parse relationship fields for foreign key constraints if r.isRelationship(tag) { // We'll parse constraints in a second pass after we know all table names continue } // Get field name fieldName := "" if len(field.Names) > 0 { fieldName = field.Names[0].Name } // Parse column from tag column, inlineRef := r.parseColumn(fieldName, field.Type, tag, sequence) if column != nil { // Extract schema and table name from TableName() method if present if strings.Contains(tag, "gorm:") { tablePart, schemaPart := r.extractTableFromGormTag(tag) if tablePart != "" { tableName = tablePart } if schemaPart != "" { schemaName = schemaPart } } column.Table = tableName column.Schema = schemaName table.Name = tableName table.Schema = schemaName table.Columns[column.Name] = column // Parse indexes from GORM tags r.parseIndexesFromTag(table, column, tag) // Handle inline references (e.g., "bigint references mainaccount(id_mainaccount)") if inlineRef != "" { r.createInlineReferenceConstraint(table, column, inlineRef) } sequence++ } } return table } // createInlineReferenceConstraint creates a foreign key constraint from inline reference // e.g., "mainaccount(id_mainaccount) ON DELETE CASCADE" creates an FK constraint func (r *Reader) createInlineReferenceConstraint(table *models.Table, column *models.Column, refInfo string) { // Parse refInfo: "mainaccount(id_mainaccount) ON DELETE CASCADE ON UPDATE RESTRICT" // Extract table name, column name, and constraint actions openParen := strings.Index(refInfo, "(") closeParen := strings.Index(refInfo, ")") if openParen == -1 || closeParen == -1 || openParen >= closeParen { // Invalid format, skip return } refTableFull := strings.TrimSpace(refInfo[:openParen]) refColumn := strings.TrimSpace(refInfo[openParen+1 : closeParen]) // Extract ON DELETE/UPDATE clauses (everything after the closing paren) constraintActions := "" if closeParen+1 < len(refInfo) { constraintActions = strings.TrimSpace(refInfo[closeParen+1:]) } // Parse schema.table or just table refSchema := "public" refTable := refTableFull if strings.Contains(refTableFull, ".") { parts := strings.SplitN(refTableFull, ".", 2) refSchema = parts[0] refTable = parts[1] } // Parse ON DELETE and ON UPDATE actions onDelete := "NO ACTION" onUpdate := "NO ACTION" if constraintActions != "" { constraintActionsLower := strings.ToLower(constraintActions) // Parse ON DELETE if strings.Contains(constraintActionsLower, "on delete cascade") { onDelete = "CASCADE" } else if strings.Contains(constraintActionsLower, "on delete set null") { onDelete = "SET NULL" } else if strings.Contains(constraintActionsLower, "on delete restrict") { onDelete = "RESTRICT" } else if strings.Contains(constraintActionsLower, "on delete no action") { onDelete = "NO ACTION" } // Parse ON UPDATE if strings.Contains(constraintActionsLower, "on update cascade") { onUpdate = "CASCADE" } else if strings.Contains(constraintActionsLower, "on update set null") { onUpdate = "SET NULL" } else if strings.Contains(constraintActionsLower, "on update restrict") { onUpdate = "RESTRICT" } else if strings.Contains(constraintActionsLower, "on update no action") { onUpdate = "NO ACTION" } } // Create a foreign key constraint constraintName := fmt.Sprintf("fk_%s_%s", table.Name, column.Name) constraint := models.InitConstraint(constraintName, models.ForeignKeyConstraint) constraint.Schema = table.Schema constraint.Table = table.Name constraint.Columns = []string{column.Name} constraint.ReferencedSchema = refSchema constraint.ReferencedTable = refTable constraint.ReferencedColumns = []string{refColumn} constraint.OnDelete = onDelete constraint.OnUpdate = onUpdate table.Constraints[constraintName] = constraint } // isGORMModel checks if a field is gorm.Model func (r *Reader) isGORMModel(field *ast.Field) bool { if len(field.Names) > 0 { return false // gorm.Model is embedded, so it has no name } // Check if the type is gorm.Model selExpr, ok := field.Type.(*ast.SelectorExpr) if !ok { return false } ident, ok := selExpr.X.(*ast.Ident) if !ok { return false } return ident.Name == "gorm" && selExpr.Sel.Name == "Model" } // isRelationship checks if a field is a relationship based on gorm tag func (r *Reader) isRelationship(tag string) bool { gormTag := r.extractGormTag(tag) gormTagLower := strings.ToLower(gormTag) return strings.Contains(gormTagLower, "foreignkey:") || strings.Contains(gormTagLower, "references:") || strings.Contains(gormTagLower, "many2many:") || strings.Contains(gormTagLower, "preload") } // parseRelationshipConstraints parses relationship fields to extract foreign key constraints func (r *Reader) parseRelationshipConstraints(table *models.Table, structType *ast.StructType, structMap map[string]*models.Table) { for _, field := range structType.Fields.List { if field.Tag == nil { continue } tag := field.Tag.Value if !r.isRelationship(tag) { continue } gormTag := r.extractGormTag(tag) parts := r.parseGormTag(gormTag) // Get the referenced type name from the field type referencedType := r.getRelationshipType(field.Type) if referencedType == "" { continue } // Determine if this is a has-many or belongs-to relationship isSlice := r.isSliceRelationship(field.Type) // Only process has-many relationships (slices) // Belongs-to relationships (*Type) should be handled by the other side's has-many if !isSlice { continue } // Find the referenced table referencedTable, ok := structMap[referencedType] if !ok { continue } // Extract foreign key information foreignKey, hasForeignKey := parts["foreignkey"] if !hasForeignKey { continue } // GORM's foreignkey tag contains the column name (already in snake_case), not field name fkColumn := foreignKey // Extract association_foreignkey (the column being referenced) assocForeignKey, hasAssocFK := parts["association_foreignkey"] if !hasAssocFK { // Default to "id" if not specified assocForeignKey = "id" } // Determine constraint behavior onDelete := "NO ACTION" onUpdate := "NO ACTION" if constraintStr, hasConstraint := parts["constraint"]; hasConstraint { // Parse constraint:OnDelete:CASCADE,OnUpdate:CASCADE constraintLower := strings.ToLower(constraintStr) if strings.Contains(constraintLower, "ondelete:cascade") { onDelete = "CASCADE" } else if strings.Contains(constraintLower, "ondelete:set null") { onDelete = "SET NULL" } else if strings.Contains(constraintLower, "ondelete:restrict") { onDelete = "RESTRICT" } if strings.Contains(constraintLower, "onupdate:cascade") { onUpdate = "CASCADE" } else if strings.Contains(constraintLower, "onupdate:set null") { onUpdate = "SET NULL" } else if strings.Contains(constraintLower, "onupdate:restrict") { onUpdate = "RESTRICT" } } // The FK is on the referenced table, pointing back to this table // For has-many, the FK is on the "many" side constraint := &models.Constraint{ Name: fmt.Sprintf("fk_%s_%s", referencedTable.Name, fkColumn), Type: models.ForeignKeyConstraint, Table: referencedTable.Name, Schema: referencedTable.Schema, Columns: []string{fkColumn}, ReferencedTable: table.Name, ReferencedSchema: table.Schema, ReferencedColumns: []string{assocForeignKey}, OnDelete: onDelete, OnUpdate: onUpdate, } referencedTable.Constraints[constraint.Name] = constraint } } // isSliceRelationship checks if a relationship field is a slice (has-many) or pointer (belongs-to) func (r *Reader) isSliceRelationship(expr ast.Expr) bool { switch expr.(type) { case *ast.ArrayType: // []*Type is a has-many relationship return true case *ast.StarExpr: // *Type is a belongs-to relationship return false } return false } // getRelationshipType extracts the type name from a relationship field func (r *Reader) getRelationshipType(expr ast.Expr) string { switch t := expr.(type) { case *ast.ArrayType: // []*ModelPost -> ModelPost if starExpr, ok := t.Elt.(*ast.StarExpr); ok { if ident, ok := starExpr.X.(*ast.Ident); ok { return ident.Name } } case *ast.StarExpr: // *ModelPost -> ModelPost if ident, ok := t.X.(*ast.Ident); ok { return ident.Name } } return "" } // parseIndexesFromTag extracts index definitions from GORM tags func (r *Reader) parseIndexesFromTag(table *models.Table, column *models.Column, tag string) { gormTag := r.extractGormTag(tag) parts := r.parseGormTag(gormTag) // Check for regular index: index:idx_name or index if indexName, ok := parts["index"]; ok { if indexName == "" { // Auto-generated index name indexName = fmt.Sprintf("idx_%s_%s", table.Name, column.Name) } // Check if index already exists if _, exists := table.Indexes[indexName]; !exists { index := &models.Index{ Name: indexName, Table: table.Name, Schema: table.Schema, Columns: []string{column.Name}, Unique: false, Type: "btree", } table.Indexes[indexName] = index } else { // Add column to existing index table.Indexes[indexName].Columns = append(table.Indexes[indexName].Columns, column.Name) } } // Check for unique index: uniqueIndex:idx_name or uniqueIndex if indexName, ok := parts["uniqueindex"]; ok { if indexName == "" { // Auto-generated index name indexName = fmt.Sprintf("idx_%s_%s", table.Name, column.Name) } // Check if index already exists if _, exists := table.Indexes[indexName]; !exists { index := &models.Index{ Name: indexName, Table: table.Name, Schema: table.Schema, Columns: []string{column.Name}, Unique: true, Type: "btree", } table.Indexes[indexName] = index } else { // Add column to existing index table.Indexes[indexName].Columns = append(table.Indexes[indexName].Columns, column.Name) } } // Check for simple unique flag (creates a unique index for this column) if _, ok := parts["unique"]; ok { // Auto-generated index name for unique constraint indexName := fmt.Sprintf("idx_%s_%s", table.Name, column.Name) if _, exists := table.Indexes[indexName]; !exists { index := &models.Index{ Name: indexName, Table: table.Name, Schema: table.Schema, Columns: []string{column.Name}, Unique: true, Type: "btree", } table.Indexes[indexName] = index } } } // extractTableFromGormTag extracts table and schema from gorm tag func (r *Reader) extractTableFromGormTag(tag string) (tablename string, schemaName string) { // This is typically set via TableName() method, not in tags // We'll return empty strings and rely on deriveTableName return "", "" } // deriveTableName derives a table name from struct name func (r *Reader) deriveTableName(structName string) string { // Remove "Model" prefix if present name := strings.TrimPrefix(structName, "Model") // Convert PascalCase to snake_case var result strings.Builder for i, r := range name { if i > 0 && r >= 'A' && r <= 'Z' { result.WriteRune('_') } result.WriteRune(r) } return strings.ToLower(result.String()) } // parseColumn parses a struct field into a Column model // Returns the column and any inline reference information (e.g., "mainaccount(id_mainaccount)") func (r *Reader) parseColumn(fieldName string, fieldType ast.Expr, tag string, sequence uint) (col *models.Column, ref string) { // Extract gorm tag gormTag := r.extractGormTag(tag) if gormTag == "" { return nil, "" } column := models.InitColumn("", "", "") column.Sequence = sequence // Parse gorm tag parts := r.parseGormTag(gormTag) // Get column name if colName, ok := parts["column"]; ok { column.Name = colName } else if fieldName != "" { // Derive column name from field name column.Name = r.fieldNameToColumnName(fieldName) } // Parse tag attributes var inlineRef string if typ, ok := parts["type"]; ok { // Parse type and extract length if present (e.g., varchar(255)) // Also extract references if present (e.g., bigint references table(column)) baseType, length, refInfo := r.parseTypeWithReferences(typ) column.Type = baseType column.Length = length inlineRef = refInfo } if _, ok := parts["primarykey"]; ok { column.IsPrimaryKey = true } if _, ok := parts["primary_key"]; ok { column.IsPrimaryKey = true } if _, ok := parts["not null"]; ok { column.NotNull = true } if _, ok := parts["autoincrement"]; ok { column.AutoIncrement = true } if def, ok := parts["default"]; ok { // Default value from GORM tag (e.g., default:gen_random_uuid()) column.Default = def } if size, ok := parts["size"]; ok { if s, err := strconv.Atoi(size); err == nil { column.Length = s } } // If no type specified in tag, derive from Go type if column.Type == "" { column.Type = r.goTypeToSQL(fieldType) } // Determine if nullable based on GORM tags and Go type // In GORM: // - explicit "not null" tag means NOT NULL // - absence of "not null" tag with sql_types means nullable // - primitive types (string, int64, bool) default to NOT NULL unless explicitly nullable // Primary keys are always NOT NULL column.NotNull = false if _, hasNotNull := parts["not null"]; hasNotNull { column.NotNull = true } else { // sql_types.SqlString, etc. are nullable by default column.NotNull = !r.isNullableGoType(fieldType) } if column.IsPrimaryKey { column.NotNull = true } return column, inlineRef } // extractGormTag extracts the gorm tag value from a struct tag func (r *Reader) extractGormTag(tag string) string { // Remove backticks tag = strings.Trim(tag, "`") // Use reflect.StructTag to properly parse st := reflect.StructTag(tag) return st.Get("gorm") } // parseTypeWithLength parses a type string and extracts length if present // e.g., "varchar(255)" returns ("varchar", 255) func (r *Reader) parseTypeWithLength(typeStr string) (baseType string, length int) { // Check for type with length: varchar(255), char(10), etc. // Also handle precision/scale: numeric(10,2) if strings.Contains(typeStr, "(") { idx := strings.Index(typeStr, "(") baseType = strings.TrimSpace(typeStr[:idx]) // Extract numbers from parentheses parens := typeStr[idx+1:] if endIdx := strings.Index(parens, ")"); endIdx > 0 { parens = parens[:endIdx] } // For now, just handle single number (length) if !strings.Contains(parens, ",") { if _, err := fmt.Sscanf(parens, "%d", &length); err == nil { return } } } baseType = typeStr return } // parseTypeWithReferences parses a type string and extracts base type, length, and references // e.g., "bigint references mainaccount(id_mainaccount) ON DELETE CASCADE" returns ("bigint", 0, "mainaccount(id_mainaccount) ON DELETE CASCADE") func (r *Reader) parseTypeWithReferences(typeStr string) (baseType string, length int, refInfo string) { // Check if the type contains "references" (case-insensitive) lowerType := strings.ToLower(typeStr) if strings.Contains(lowerType, " references ") { // Split on "references" (case-insensitive) idx := strings.Index(lowerType, " references ") baseTypePart := strings.TrimSpace(typeStr[:idx]) // Keep the entire reference info including ON DELETE/UPDATE clauses refInfo = strings.TrimSpace(typeStr[idx+len(" references "):]) // Parse base type for length baseType, length = r.parseTypeWithLength(baseTypePart) return } // No references, just parse type and length baseType, length = r.parseTypeWithLength(typeStr) return } // parseGormTag parses a gorm tag string into a map func (r *Reader) parseGormTag(gormTag string) map[string]string { result := make(map[string]string) // Split by semicolon parts := strings.Split(gormTag, ";") for _, part := range parts { part = strings.TrimSpace(part) if part == "" { continue } // Check for key:value pairs if strings.Contains(part, ":") { kv := strings.SplitN(part, ":", 2) // Normalize key to lowercase for case-insensitive matching key := strings.ToLower(kv[0]) result[key] = kv[1] } else { // Flags like "primaryKey", "not null", etc. // Normalize to lowercase result[strings.ToLower(part)] = "" } } return result } // fieldNameToColumnName converts a field name to a column name func (r *Reader) fieldNameToColumnName(fieldName string) string { var result strings.Builder for i, r := range fieldName { if i > 0 && r >= 'A' && r <= 'Z' { result.WriteRune('_') } result.WriteRune(r) } return strings.ToLower(result.String()) } // goTypeToSQL maps Go types to SQL types func (r *Reader) goTypeToSQL(expr ast.Expr) string { switch t := expr.(type) { case *ast.Ident: switch t.Name { case "int", "int32": return "integer" case "int64": return "bigint" case "string": return "text" case "bool": return "boolean" case "float32": return "real" case "float64": return "double precision" } case *ast.SelectorExpr: // Handle types like time.Time, sql_types.SqlString, etc. if ident, ok := t.X.(*ast.Ident); ok { switch ident.Name { case "time": if t.Sel.Name == "Time" { return "timestamp" } case "sql_types": return r.sqlTypeToSQL(t.Sel.Name) } } case *ast.StarExpr: // Pointer type - nullable version return r.goTypeToSQL(t.X) } return "text" } // sqlTypeToSQL maps sql_types types to SQL types func (r *Reader) sqlTypeToSQL(typeName string) string { switch typeName { case "SqlString": return "text" case "SqlInt": return "integer" case "SqlInt64": return "bigint" case "SqlFloat": return "double precision" case "SqlBool": return "boolean" case "SqlTime": return "timestamp" default: return "text" } } // isNullableGoType checks if a Go type represents a nullable field type // (this is for types that CAN be nullable, not whether they ARE nullable) func (r *Reader) isNullableGoType(expr ast.Expr) bool { switch t := expr.(type) { case *ast.StarExpr: // Pointer type can be nullable return true case *ast.SelectorExpr: // Check for sql_types nullable types if ident, ok := t.X.(*ast.Ident); ok { if ident.Name == "sql_types" { return strings.HasPrefix(t.Sel.Name, "Sql") } } } return false }