// relspecgo/pkg/readers/gorm/reader.go

package gorm
import (
"fmt"
"go/ast"
"go/parser"
"go/token"
"os"
"path/filepath"
"reflect"
"strconv"
"strings"
"git.warky.dev/wdevs/relspecgo/pkg/models"
"git.warky.dev/wdevs/relspecgo/pkg/readers"
)
// Reader implements the readers.Reader interface for GORM Go model files
type Reader struct {
options *readers.ReaderOptions
}
// NewReader creates a new GORM reader with the given options
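//
// A minimal usage sketch (assuming FilePath is the only option the reader
// needs, as in ReadDatabase below; error handling elided):
//
//	r := NewReader(&readers.ReaderOptions{FilePath: "./models"})
//	db, err := r.ReadDatabase()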
func NewReader(options *readers.ReaderOptions) *Reader {
return &Reader{
options: options,
}
}
// ReadDatabase reads GORM Go model files and returns a Database model
func (r *Reader) ReadDatabase() (*models.Database, error) {
if r.options.FilePath == "" {
return nil, fmt.Errorf("file path is required for GORM reader")
}
// Check if path is a directory or file
info, err := os.Stat(r.options.FilePath)
if err != nil {
return nil, fmt.Errorf("failed to stat path: %w", err)
}
var files []string
if info.IsDir() {
// Read all .go files in directory
entries, err := os.ReadDir(r.options.FilePath)
if err != nil {
return nil, fmt.Errorf("failed to read directory: %w", err)
}
for _, entry := range entries {
if !entry.IsDir() && strings.HasSuffix(entry.Name(), ".go") && !strings.HasSuffix(entry.Name(), "_test.go") {
files = append(files, filepath.Join(r.options.FilePath, entry.Name()))
}
}
} else {
files = append(files, r.options.FilePath)
}
if len(files) == 0 {
return nil, fmt.Errorf("no Go files found")
}
// Parse all files and collect tables
db := models.InitDatabase("database")
schemaMap := make(map[string]*models.Schema)
for _, file := range files {
tables, err := r.parseFile(file)
if err != nil {
return nil, fmt.Errorf("failed to parse file %s: %w", file, err)
}
for _, table := range tables {
// Get or create schema
schema, ok := schemaMap[table.Schema]
if !ok {
schema = models.InitSchema(table.Schema)
schemaMap[table.Schema] = schema
}
schema.Tables = append(schema.Tables, table)
}
}
// Convert schema map to slice
for _, schema := range schemaMap {
db.Schemas = append(db.Schemas, schema)
}
return db, nil
}
// ReadSchema reads GORM Go model files and returns the first Schema found
func (r *Reader) ReadSchema() (*models.Schema, error) {
db, err := r.ReadDatabase()
if err != nil {
return nil, err
}
if len(db.Schemas) == 0 {
return nil, fmt.Errorf("no schemas found")
}
return db.Schemas[0], nil
}
// ReadTable reads GORM Go model files and returns the first Table of the first Schema found
func (r *Reader) ReadTable() (*models.Table, error) {
schema, err := r.ReadSchema()
if err != nil {
return nil, err
}
if len(schema.Tables) == 0 {
return nil, fmt.Errorf("no tables found")
}
return schema.Tables[0], nil
}
// parseFile parses a single Go file and extracts table models
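//
// A sketch of the kind of model this parser targets (hypothetical, not taken
// from this repository):
//
//	type ModelUser struct {
//		ID    int64  `gorm:"column:id;primaryKey;autoIncrement"`
//		Email string `gorm:"column:email;type:varchar(255);not null;uniqueIndex"`
//	}
//
//	func (ModelUser) TableName() string { return "public.user" }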
func (r *Reader) parseFile(filename string) ([]*models.Table, error) {
fset := token.NewFileSet()
node, err := parser.ParseFile(fset, filename, nil, parser.ParseComments)
if err != nil {
return nil, fmt.Errorf("failed to parse Go file: %w", err)
}
var tables []*models.Table
structMap := make(map[string]*models.Table)
// First pass: collect struct definitions
for _, decl := range node.Decls {
genDecl, ok := decl.(*ast.GenDecl)
if !ok || genDecl.Tok != token.TYPE {
continue
}
for _, spec := range genDecl.Specs {
typeSpec, ok := spec.(*ast.TypeSpec)
if !ok {
continue
}
structType, ok := typeSpec.Type.(*ast.StructType)
if !ok {
continue
}
// Check if this struct has gorm tags (indicates it's a model)
if r.hasModelFields(structType) {
table := r.parseStruct(typeSpec.Name.Name, structType)
if table != nil {
structMap[typeSpec.Name.Name] = table
tables = append(tables, table)
}
}
}
}
// Second pass: find TableName() methods
for _, decl := range node.Decls {
funcDecl, ok := decl.(*ast.FuncDecl)
if !ok || funcDecl.Name.Name != "TableName" {
continue
}
// Get receiver type
if funcDecl.Recv == nil || len(funcDecl.Recv.List) == 0 {
continue
}
receiverType := r.getReceiverType(funcDecl.Recv.List[0].Type)
if receiverType == "" {
continue
}
// Find the table for this struct
table, ok := structMap[receiverType]
if !ok {
continue
}
// Parse the return value
tableName, schemaName := r.parseTableNameMethod(funcDecl)
if tableName != "" {
table.Name = tableName
if schemaName != "" {
table.Schema = schemaName
}
// Update columns and indexes
for _, col := range table.Columns {
col.Table = tableName
col.Schema = table.Schema
}
for _, idx := range table.Indexes {
idx.Table = tableName
idx.Schema = table.Schema
}
}
}
// Third pass: walk the struct declarations again to extract relationship
// constraints, now that all table names (including TableName overrides) are known
for _, decl := range node.Decls {
genDecl, ok := decl.(*ast.GenDecl)
if !ok || genDecl.Tok != token.TYPE {
continue
}
for _, spec := range genDecl.Specs {
typeSpec, ok := spec.(*ast.TypeSpec)
if !ok {
continue
}
structType, ok := typeSpec.Type.(*ast.StructType)
if !ok {
continue
}
table, ok := structMap[typeSpec.Name.Name]
if !ok {
continue
}
// Parse relationship fields
r.parseRelationshipConstraints(table, structType, structMap)
}
}
return tables, nil
}
// getReceiverType extracts the type name from a receiver
func (r *Reader) getReceiverType(expr ast.Expr) string {
switch t := expr.(type) {
case *ast.Ident:
return t.Name
case *ast.StarExpr:
if ident, ok := t.X.(*ast.Ident); ok {
return ident.Name
}
}
return ""
}
// parseTableNameMethod parses a TableName() method and extracts the table and schema name
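//
// For example, a method returning "auth.users" yields ("users", "auth"), while
// a method returning "users" yields ("users", "public").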
func (r *Reader) parseTableNameMethod(funcDecl *ast.FuncDecl) (tableName string, schemaName string) {
if funcDecl.Body == nil {
return "", ""
}
// Look for return statement
for _, stmt := range funcDecl.Body.List {
retStmt, ok := stmt.(*ast.ReturnStmt)
if !ok {
continue
}
if len(retStmt.Results) == 0 {
continue
}
// Get the return value (should be a string literal)
if basicLit, ok := retStmt.Results[0].(*ast.BasicLit); ok {
if basicLit.Kind == token.STRING {
// Remove quotes
fullName := strings.Trim(basicLit.Value, "\"")
// Split schema.table
if strings.Contains(fullName, ".") {
parts := strings.SplitN(fullName, ".", 2)
return parts[1], parts[0]
}
return fullName, "public"
}
}
}
return "", ""
}
// hasModelFields checks if the struct has fields with gorm tags
func (r *Reader) hasModelFields(structType *ast.StructType) bool {
for _, field := range structType.Fields.List {
if field.Tag != nil {
tag := field.Tag.Value
if strings.Contains(tag, "gorm:") {
return true
}
}
}
return false
}
// parseStruct converts an AST struct to a Table model
func (r *Reader) parseStruct(structName string, structType *ast.StructType) *models.Table {
tableName := r.deriveTableName(structName)
schemaName := "public"
table := models.InitTable(tableName, schemaName)
sequence := uint(1)
// Parse fields
for _, field := range structType.Fields.List {
if field.Tag == nil {
continue
}
tag := field.Tag.Value
if !strings.Contains(tag, "gorm:") {
continue
}
// Skip embedded GORM model
if r.isGORMModel(field) {
continue
}
// Parse relationship fields for foreign key constraints
if r.isRelationship(tag) {
// Constraints are extracted in a later pass (parseRelationshipConstraints), once all table names are known
continue
}
// Get field name
fieldName := ""
if len(field.Names) > 0 {
fieldName = field.Names[0].Name
}
// Parse column from tag
column := r.parseColumn(fieldName, field.Type, tag, sequence)
if column != nil {
// Table/schema overrides from the gorm tag (currently always empty, see
// extractTableFromGormTag); TableName() methods are applied in a later pass
tablePart, schemaPart := r.extractTableFromGormTag(tag)
if tablePart != "" {
tableName = tablePart
}
if schemaPart != "" {
schemaName = schemaPart
}
column.Table = tableName
column.Schema = schemaName
table.Name = tableName
table.Schema = schemaName
table.Columns[column.Name] = column
// Parse indexes from GORM tags
r.parseIndexesFromTag(table, column, tag)
sequence++
}
}
return table
}
// isGORMModel checks if a field is gorm.Model
func (r *Reader) isGORMModel(field *ast.Field) bool {
if len(field.Names) > 0 {
return false // gorm.Model is embedded, so it has no name
}
// Check if the type is gorm.Model
selExpr, ok := field.Type.(*ast.SelectorExpr)
if !ok {
return false
}
ident, ok := selExpr.X.(*ast.Ident)
if !ok {
return false
}
return ident.Name == "gorm" && selExpr.Sel.Name == "Model"
}
// isRelationship checks if a field is a relationship based on gorm tag
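//
// For example (hypothetical field names), tags such as
//
//	Posts []*ModelPost `gorm:"foreignKey:UserID"`
//	Tags  []*ModelTag  `gorm:"many2many:post_tags"`
//
// mark the field as a relationship rather than a column.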
func (r *Reader) isRelationship(tag string) bool {
gormTag := r.extractGormTag(tag)
return strings.Contains(gormTag, "foreignKey:") ||
strings.Contains(gormTag, "references:") ||
strings.Contains(gormTag, "many2many:")
}
// parseRelationshipConstraints parses relationship fields to extract foreign key constraints
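//
// For example (hypothetical structs), a has-many field on ModelUser such as
//
//	Posts []*ModelPost `gorm:"foreignKey:UserID;constraint:OnDelete:CASCADE"`
//
// adds a constraint named "fk_post_user" to the post table, with post.user_id
// referencing user.id and ON DELETE CASCADE (assuming no TableName overrides).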
func (r *Reader) parseRelationshipConstraints(table *models.Table, structType *ast.StructType, structMap map[string]*models.Table) {
for _, field := range structType.Fields.List {
if field.Tag == nil {
continue
}
tag := field.Tag.Value
if !r.isRelationship(tag) {
continue
}
gormTag := r.extractGormTag(tag)
parts := r.parseGormTag(gormTag)
// Get the referenced type name from the field type
referencedType := r.getRelationshipType(field.Type)
if referencedType == "" {
continue
}
// Find the referenced table
referencedTable, ok := structMap[referencedType]
if !ok {
continue
}
// Extract foreign key information
foreignKey, hasForeignKey := parts["foreignKey"]
if !hasForeignKey {
continue
}
// Convert field name to column name
fkColumn := r.fieldNameToColumnName(foreignKey)
// Determine constraint behavior
onDelete := "NO ACTION"
onUpdate := "NO ACTION"
if constraintStr, hasConstraint := parts["constraint"]; hasConstraint {
// Parse constraint:OnDelete:CASCADE,OnUpdate:CASCADE
if strings.Contains(constraintStr, "OnDelete:CASCADE") {
onDelete = "CASCADE"
} else if strings.Contains(constraintStr, "OnDelete:SET NULL") {
onDelete = "SET NULL"
}
if strings.Contains(constraintStr, "OnUpdate:CASCADE") {
onUpdate = "CASCADE"
} else if strings.Contains(constraintStr, "OnUpdate:SET NULL") {
onUpdate = "SET NULL"
}
}
// For a has-many relationship the foreign key column lives on the related
// ("many") table and points back to this table's primary key
constraint := &models.Constraint{
Name: fmt.Sprintf("fk_%s_%s", referencedTable.Name, table.Name),
Type: models.ForeignKeyConstraint,
Table: referencedTable.Name,
Schema: referencedTable.Schema,
Columns: []string{fkColumn},
ReferencedTable: table.Name,
ReferencedSchema: table.Schema,
ReferencedColumns: []string{"id"}, // Typically references the primary key
OnDelete: onDelete,
OnUpdate: onUpdate,
}
referencedTable.Constraints[constraint.Name] = constraint
}
}
// getRelationshipType extracts the type name from a relationship field
func (r *Reader) getRelationshipType(expr ast.Expr) string {
switch t := expr.(type) {
case *ast.ArrayType:
// []*ModelPost -> ModelPost
if starExpr, ok := t.Elt.(*ast.StarExpr); ok {
if ident, ok := starExpr.X.(*ast.Ident); ok {
return ident.Name
}
}
case *ast.StarExpr:
// *ModelPost -> ModelPost
if ident, ok := t.X.(*ast.Ident); ok {
return ident.Name
}
}
return ""
}
// parseIndexesFromTag extracts index definitions from GORM tags
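//
// For example (hypothetical tags): `gorm:"index"` creates a non-unique btree
// index with the auto-generated name idx_<table>_<column>,
// `gorm:"uniqueIndex:idx_users_email"` creates a unique index named
// idx_users_email, and repeating the same index name on several columns
// produces one composite index.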
func (r *Reader) parseIndexesFromTag(table *models.Table, column *models.Column, tag string) {
gormTag := r.extractGormTag(tag)
parts := r.parseGormTag(gormTag)
// Check for regular index: index:idx_name or index
if indexName, ok := parts["index"]; ok {
if indexName == "" {
// Auto-generated index name
indexName = fmt.Sprintf("idx_%s_%s", table.Name, column.Name)
}
// Check if index already exists
if _, exists := table.Indexes[indexName]; !exists {
index := &models.Index{
Name: indexName,
Table: table.Name,
Schema: table.Schema,
Columns: []string{column.Name},
Unique: false,
Type: "btree",
}
table.Indexes[indexName] = index
} else {
// Add column to existing index
table.Indexes[indexName].Columns = append(table.Indexes[indexName].Columns, column.Name)
}
}
// Check for unique index: uniqueIndex:idx_name or uniqueIndex
if indexName, ok := parts["uniqueIndex"]; ok {
if indexName == "" {
// Auto-generated index name
indexName = fmt.Sprintf("idx_%s_%s", table.Name, column.Name)
}
// Check if index already exists
if _, exists := table.Indexes[indexName]; !exists {
index := &models.Index{
Name: indexName,
Table: table.Name,
Schema: table.Schema,
Columns: []string{column.Name},
Unique: true,
Type: "btree",
}
table.Indexes[indexName] = index
} else {
// Add column to existing index
table.Indexes[indexName].Columns = append(table.Indexes[indexName].Columns, column.Name)
}
}
// Check for simple unique flag (creates a unique index for this column)
if _, ok := parts["unique"]; ok {
// Auto-generated index name for unique constraint
indexName := fmt.Sprintf("idx_%s_%s", table.Name, column.Name)
if _, exists := table.Indexes[indexName]; !exists {
index := &models.Index{
Name: indexName,
Table: table.Name,
Schema: table.Schema,
Columns: []string{column.Name},
Unique: true,
Type: "btree",
}
table.Indexes[indexName] = index
}
}
}
// extractTableFromGormTag extracts table and schema overrides from a gorm tag.
// GORM sets the table (and schema) via the TableName() method rather than
// struct tags, so this currently always returns empty strings and parseStruct
// falls back to deriveTableName.
func (r *Reader) extractTableFromGormTag(_ string) (tableName string, schemaName string) {
return "", ""
}
// deriveTableName derives a table name from struct name
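// e.g. "ModelUserPost" -> "user_post" (hypothetical struct name)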
func (r *Reader) deriveTableName(structName string) string {
// Strip the "Model" prefix if present, then convert PascalCase to snake_case
name := strings.TrimPrefix(structName, "Model")
return r.fieldNameToColumnName(name)
}
// parseColumn parses a struct field into a Column model
func (r *Reader) parseColumn(fieldName string, fieldType ast.Expr, tag string, sequence uint) *models.Column {
// Extract gorm tag
gormTag := r.extractGormTag(tag)
if gormTag == "" {
return nil
}
column := models.InitColumn("", "", "")
column.Sequence = sequence
// Parse gorm tag
parts := r.parseGormTag(gormTag)
// Get column name
if colName, ok := parts["column"]; ok {
column.Name = colName
} else if fieldName != "" {
// Derive column name from field name
column.Name = r.fieldNameToColumnName(fieldName)
}
// Parse tag attributes
if typ, ok := parts["type"]; ok {
// Parse type and extract length if present (e.g., varchar(255))
column.Type, column.Length = r.parseTypeWithLength(typ)
}
if _, ok := parts["primaryKey"]; ok {
column.IsPrimaryKey = true
}
if _, ok := parts["not null"]; ok {
column.NotNull = true
}
if _, ok := parts["autoIncrement"]; ok {
column.AutoIncrement = true
}
if def, ok := parts["default"]; ok {
// Default value from GORM tag (e.g., default:gen_random_uuid())
column.Default = def
}
if size, ok := parts["size"]; ok {
if s, err := strconv.Atoi(size); err == nil {
column.Length = s
}
}
// If no type specified in tag, derive from Go type
if column.Type == "" {
column.Type = r.goTypeToSQL(fieldType)
}
// Determine nullability based on the GORM tag and the Go type.
// In GORM terms:
//   - an explicit "not null" tag means NOT NULL
//   - without it, nullable wrapper types (pointers, sql_types.Sql*) stay nullable
//   - plain primitives (string, int64, bool) are also left nullable here unless
//     the tag says otherwise
if _, hasNotNull := parts["not null"]; hasNotNull {
column.NotNull = true
} else if r.isNullableGoType(fieldType) {
// Nullable wrapper types such as *string or sql_types.SqlString
column.NotNull = false
} else {
// Plain primitive without an explicit "not null" tag: left nullable
column.NotNull = false
}
// Primary keys are always NOT NULL
if column.IsPrimaryKey {
column.NotNull = true
}
return column
}
// extractGormTag extracts the gorm tag value from a struct tag
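//
// For example, the raw tag `json:"id" gorm:"column:id;primaryKey"` yields
// "column:id;primaryKey".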
func (r *Reader) extractGormTag(tag string) string {
// Remove backticks
tag = strings.Trim(tag, "`")
// Use reflect.StructTag to properly parse
st := reflect.StructTag(tag)
return st.Get("gorm")
}
// parseTypeWithLength parses a type string and extracts length if present
// e.g., "varchar(255)" returns ("varchar", 255)
func (r *Reader) parseTypeWithLength(typeStr string) (baseType string, length int) {
// Check for type with length: varchar(255), char(10), etc.
// Also handle precision/scale: numeric(10,2)
if strings.Contains(typeStr, "(") {
idx := strings.Index(typeStr, "(")
baseType = strings.TrimSpace(typeStr[:idx])
// Extract numbers from parentheses
parens := typeStr[idx+1:]
if endIdx := strings.Index(parens, ")"); endIdx > 0 {
parens = parens[:endIdx]
}
// Only a plain length (e.g. 255) is parsed; precision/scale forms such as
// numeric(10,2) fall through below and keep the full type string
if !strings.Contains(parens, ",") {
if _, err := fmt.Sscanf(parens, "%d", &length); err == nil {
return
}
}
}
baseType = typeStr
return
}
// parseGormTag parses a gorm tag string into a map
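//
// For example, "column:id;primaryKey;autoIncrement" becomes
// {"column": "id", "primaryKey": "", "autoIncrement": ""}.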
func (r *Reader) parseGormTag(gormTag string) map[string]string {
result := make(map[string]string)
// Split by semicolon
parts := strings.Split(gormTag, ";")
for _, part := range parts {
part = strings.TrimSpace(part)
if part == "" {
continue
}
// Check for key:value pairs
if strings.Contains(part, ":") {
kv := strings.SplitN(part, ":", 2)
result[kv[0]] = kv[1]
} else {
// Flags like "primaryKey", "not null", etc.
result[part] = ""
}
}
return result
}
// fieldNameToColumnName converts a Go field name to a snake_case column name,
// keeping acronym runs together (e.g. "UserID" -> "user_id", not "user_i_d")
func (r *Reader) fieldNameToColumnName(fieldName string) string {
runes := []rune(fieldName)
var result strings.Builder
for i, ch := range runes {
if i > 0 && ch >= 'A' && ch <= 'Z' {
prevIsUpper := runes[i-1] >= 'A' && runes[i-1] <= 'Z'
nextIsLower := i+1 < len(runes) && runes[i+1] >= 'a' && runes[i+1] <= 'z'
// Underscore only at word boundaries: after a lowercase rune, or before
// the last capital of an acronym that is followed by a lowercase rune
if !prevIsUpper || nextIsLower {
result.WriteRune('_')
}
}
result.WriteRune(ch)
}
return strings.ToLower(result.String())
}
// goTypeToSQL maps Go types to SQL types
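// Pointer types map to the SQL type of their element (e.g. *time.Time ->
// "timestamp"); unknown or unhandled types fall back to "text".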
func (r *Reader) goTypeToSQL(expr ast.Expr) string {
switch t := expr.(type) {
case *ast.Ident:
switch t.Name {
case "int", "int32":
return "integer"
case "int64":
return "bigint"
case "string":
return "text"
case "bool":
return "boolean"
case "float32":
return "real"
case "float64":
return "double precision"
}
case *ast.SelectorExpr:
// Handle types like time.Time, sql_types.SqlString, etc.
if ident, ok := t.X.(*ast.Ident); ok {
switch ident.Name {
case "time":
if t.Sel.Name == "Time" {
return "timestamp"
}
case "sql_types":
return r.sqlTypeToSQL(t.Sel.Name)
}
}
case *ast.StarExpr:
// Pointer type: map the underlying element type (nullability is handled separately)
return r.goTypeToSQL(t.X)
}
return "text"
}
// sqlTypeToSQL maps sql_types types to SQL types
func (r *Reader) sqlTypeToSQL(typeName string) string {
switch typeName {
case "SqlString":
return "text"
case "SqlInt":
return "integer"
case "SqlInt64":
return "bigint"
case "SqlFloat":
return "double precision"
case "SqlBool":
return "boolean"
case "SqlTime":
return "timestamp"
default:
return "text"
}
}
// isNullableGoType checks if a Go type represents a nullable field type
// (this is for types that CAN be nullable, not whether they ARE nullable)
func (r *Reader) isNullableGoType(expr ast.Expr) bool {
switch t := expr.(type) {
case *ast.StarExpr:
// Pointer type can be nullable
return true
case *ast.SelectorExpr:
// Check for sql_types nullable types
if ident, ok := t.X.(*ast.Ident); ok {
if ident.Name == "sql_types" {
return strings.HasPrefix(t.Sel.Name, "Sql")
}
}
}
return false
}