package database

import (
	"database/sql"
	"fmt"
	"log"
	"reflect"
	"strings"
)

// maxIdentLen is the longest identifier, in bytes, that a default PostgreSQL
// build keeps (NAMEDATALEN - 1). The server silently truncates longer names,
// so any generated name we later look up must be truncated the same way.
const maxIdentLen = 63

// constraintKeywords are words that start a column constraint or option in an
// `sql` tag. The data type is everything before the first of them, e.g.
// "VARCHAR(255)" in "VARCHAR(255) NOT NULL DEFAULT ''".
var constraintKeywords = map[string]bool{
	"not": true, "null": true, "default": true, "primary": true, "unique": true,
	"check": true, "references": true, "constraint": true, "collate": true,
	"generated": true,
}

// pgTypeAliases maps type names accepted in an `sql` tag to the spelling that
// information_schema.columns.data_type reports. Names not listed are compared
// as written (lower-cased, modifiers removed). float(p) with p <= 24 is stored
// as real, so such a tag will keep being reported as a type change.
var pgTypeAliases = map[string]string{
	"int": "integer", "int4": "integer", "serial": "integer", "serial4": "integer",
	"int8": "bigint", "bigserial": "bigint", "serial8": "bigint",
	"int2": "smallint", "smallserial": "smallint", "serial2": "smallint",
	"varchar": "character varying", "char": "character", "bpchar": "character",
	"varbit": "bit varying", "bool": "boolean", "decimal": "numeric",
	"float": "double precision", "float8": "double precision", "float4": "real",
	"timestamp": "timestamp without time zone", "timestamptz": "timestamp with time zone",
	"time": "time without time zone", "timetz": "time with time zone",
}

// serialTypes are CREATE TABLE shorthands (an integer column backed by a
// sequence) rather than real types, so they cannot appear in ALTER ... TYPE.
var serialTypes = map[string]bool{
	"serial": true, "serial4": true, "bigserial": true, "serial8": true,
	"smallserial": true, "serial2": true,
}

// modelColumn is one database column derived from a struct field.
type modelColumn struct {
	name    string // column name: the `db` tag, or the lower-cased field name
	sqlType string // full column definition from the `sql` tag (type plus constraints)
	index   bool   // true when the field carries `index:"true"`
}

// CreateOrUpdateTablePG makes the PostgreSQL table named table match model.
//
// model must be a struct, or a pointer to one, whose fields describe columns:
//   - `db:"name"` sets the column name (default: the lower-cased field name);
//   - `sql:"TYPE [constraints]"` is the column definition and is required;
//   - `index:"true"` requests an index named "<table>_<column>_idx";
//   - `ignore:"true"` or `db:"-"` skips the field.
//
// If no such table exists in the connection's current schema, it is created
// together with its indexes. Otherwise missing columns and indexes are added,
// and columns whose base type differs from the tag are converted with
// ALTER COLUMN ... TYPE. Columns and indexes are never dropped, and
// constraints on existing columns are left untouched.
func CreateOrUpdateTablePG(db *sql.DB, table string, model interface{}) error {
	// Reject a bad model before making any database round trip.
	if _, err := modelColumns(model); err != nil {
		return err
	}
	exists, err := tableExists(db, table)
	if err != nil {
		return err
	}
	if !exists {
		return createTable(db, table, model)
	}
	return alterTable(db, table, model)
}

// tableExists reports whether table is listed in information_schema.tables
// for the connection's current schema (where an unqualified CREATE TABLE
// puts it), so same-named tables in other schemas are not mistaken for it.
func tableExists(db *sql.DB, table string) (bool, error) {
	const q = `SELECT EXISTS (SELECT 1 FROM information_schema.tables
		WHERE table_schema = current_schema() AND table_name = $1);`
	var exists bool
	if err := db.QueryRow(q, table).Scan(&exists); err != nil {
		return false, fmt.Errorf("check whether table %q exists: %w", table, err)
	}
	return exists, nil
}

// createTable issues CREATE TABLE with every column of model, then one
// CREATE INDEX per column tagged `index:"true"`.
func createTable(db *sql.DB, table string, model interface{}) error {
	cols, err := modelColumns(model)
	if err != nil {
		return err
	}
	defs := make([]string, 0, len(cols))
	for _, c := range cols {
		defs = append(defs, quoteIdent(c.name)+" "+c.sqlType)
	}
	query := fmt.Sprintf("CREATE TABLE %s (%s);", quoteIdent(table), strings.Join(defs, ", "))
	log.Println("[CREATE TABLE]", query)
	if _, err := db.Exec(query); err != nil {
		return fmt.Errorf("create table %q: %w", table, err)
	}
	for _, c := range cols {
		if !c.index {
			continue
		}
		if err := createIndex(db, table, c.name); err != nil {
			return err
		}
	}
	return nil
}

// alterTable brings an existing table in line with model. It adds missing
// columns (with their full `sql` definition), converts columns whose base type
// differs from the tag (see typeChange), and creates missing indexes for
// columns tagged `index:"true"`.
func alterTable(db *sql.DB, table string, model interface{}) error {
	cols, err := modelColumns(model)
	if err != nil {
		return err
	}
	currentCols, err := getColumnsPG(db, table)
	if err != nil {
		return err
	}
	currentIdx, err := getIndexesPG(db, table)
	if err != nil {
		return err
	}

	for _, c := range cols {
		pgType, exists := currentCols[c.name]
		if !exists {
			q := fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s %s;", quoteIdent(table), quoteIdent(c.name), c.sqlType)
			log.Println("[ADD COLUMN]", q)
			if _, err := db.Exec(q); err != nil {
				return fmt.Errorf("add column %q to %q: %w", c.name, table, err)
			}
		} else if newType, differs := typeChange(pgType, c.sqlType); differs {
			q := fmt.Sprintf("ALTER TABLE %s ALTER COLUMN %s TYPE %s;", quoteIdent(table), quoteIdent(c.name), newType)
			log.Println("[ALTER TYPE]", q)
			if _, err := db.Exec(q); err != nil {
				return fmt.Errorf("alter type of column %q in %q: %w", c.name, table, err)
			}
		}

		// Also runs for a column added just above, so its index is created in
		// the same call rather than on the next one.
		if c.index && !currentIdx[indexName(table, c.name)] {
			if err := createIndex(db, table, c.name); err != nil {
				return err
			}
		}
	}
	return nil
}

// Helpers

// getColumnsPG returns column name -> information_schema data_type for table
// in the connection's current schema.
func getColumnsPG(db *sql.DB, table string) (map[string]string, error) {
	const q = `SELECT column_name, data_type FROM information_schema.columns
		WHERE table_schema = current_schema() AND table_name = $1;`
	rows, err := db.Query(q, table)
	if err != nil {
		return nil, fmt.Errorf("list columns of %q: %w", table, err)
	}
	defer rows.Close()

	cols := make(map[string]string)
	for rows.Next() {
		var name, dtype string
		if err := rows.Scan(&name, &dtype); err != nil {
			return nil, fmt.Errorf("list columns of %q: %w", table, err)
		}
		cols[name] = dtype
	}
	// Next returns false both at the end and on failure; only Err tells them apart.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("list columns of %q: %w", table, err)
	}
	return cols, nil
}

// getIndexesPG returns the set of index names defined on table in the
// connection's current schema.
func getIndexesPG(db *sql.DB, table string) (map[string]bool, error) {
	const q = `SELECT indexname FROM pg_indexes
		WHERE schemaname = current_schema() AND tablename = $1;`
	rows, err := db.Query(q, table)
	if err != nil {
		return nil, fmt.Errorf("list indexes of %q: %w", table, err)
	}
	defer rows.Close()

	idx := make(map[string]bool)
	for rows.Next() {
		var name string
		if err := rows.Scan(&name); err != nil {
			return nil, fmt.Errorf("list indexes of %q: %w", table, err)
		}
		idx[name] = true
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("list indexes of %q: %w", table, err)
	}
	return idx, nil
}

// modelColumns reflects over model (a struct or pointer to struct) and returns
// the columns it describes, in field order. Fields tagged `ignore:"true"` or
// `db:"-"` are skipped; every other field must carry an `sql` tag.
func modelColumns(model interface{}) ([]modelColumn, error) {
	t := reflect.TypeOf(model)
	for t != nil && t.Kind() == reflect.Ptr {
		t = t.Elem()
	}
	if t == nil || t.Kind() != reflect.Struct {
		return nil, fmt.Errorf("model must be a struct")
	}

	cols := make([]modelColumn, 0, t.NumField())
	for i := 0; i < t.NumField(); i++ {
		f := t.Field(i)
		if f.Tag.Get("ignore") == "true" || f.Tag.Get("db") == "-" {
			continue
		}
		name := f.Tag.Get("db")
		if name == "" {
			name = strings.ToLower(f.Name)
		}
		sqlType := f.Tag.Get("sql")
		if sqlType == "" {
			return nil, fmt.Errorf("field '%s' missing sql tag", f.Name)
		}
		cols = append(cols, modelColumn{
			name:    name,
			sqlType: sqlType,
			index:   f.Tag.Get("index") == "true",
		})
	}
	return cols, nil
}

// createIndex creates the single-column index named indexName(table, col).
func createIndex(db *sql.DB, table, col string) error {
	q := fmt.Sprintf("CREATE INDEX %s ON %s (%s);",
		quoteIdent(indexName(table, col)), quoteIdent(table), quoteIdent(col))
	log.Println("[CREATE INDEX]", q)
	if _, err := db.Exec(q); err != nil {
		return fmt.Errorf("create index on %q(%q): %w", table, col, err)
	}
	return nil
}

// indexName is the name used for the index on table.col, truncated exactly as
// the server would store it so that existence checks find it again.
func indexName(table, col string) string {
	return truncateIdent(table + "_" + col + "_idx")
}

// truncateIdent shortens name to at most maxIdentLen bytes without splitting a
// UTF-8 sequence, mirroring how PostgreSQL truncates over-long identifiers.
func truncateIdent(name string) string {
	if len(name) <= maxIdentLen {
		return name
	}
	cut := 0
	for i := range name { // i is the byte offset where each rune starts
		if i > maxIdentLen {
			break
		}
		cut = i
	}
	return name[:cut]
}

// quoteIdent wraps name in double quotes, doubling any embedded quote, so it
// is always read as a single identifier.
func quoteIdent(name string) string {
	return `"` + strings.ReplaceAll(name, `"`, `""`) + `"`
}

// sqlTypeClause returns the data-type part of an `sql` tag: the words before
// the first column-constraint keyword, with whitespace normalized.
func sqlTypeClause(sqlTag string) string {
	fields := strings.Fields(sqlTag)
	for i, f := range fields {
		word := strings.ToLower(f)
		// Handle keywords glued to their argument, e.g. "CHECK(x>0)".
		if p := strings.IndexByte(word, '('); p >= 0 {
			word = word[:p]
		}
		if constraintKeywords[word] {
			return strings.Join(fields[:i], " ")
		}
	}
	return strings.Join(fields, " ")
}

// canonicalPGType converts a type clause such as "VARCHAR(255)" or
// "timestamp(3) with time zone" into the lower-case name reported by
// information_schema.columns.data_type. Array types map to "array".
func canonicalPGType(typeClause string) string {
	t := strings.ToLower(typeClause)
	if strings.HasSuffix(t, "]") || strings.HasSuffix(t, " array") {
		return "array"
	}
	// Drop type modifiers: lengths, precisions and scales in parentheses.
	var b strings.Builder
	depth := 0
	for _, r := range t {
		switch {
		case r == '(':
			depth++
		case r == ')':
			if depth > 0 {
				depth--
			}
		case depth == 0:
			b.WriteRune(r)
		}
	}
	base := strings.Join(strings.Fields(b.String()), " ")
	if alias, ok := pgTypeAliases[base]; ok {
		return alias
	}
	return base
}

// typeChange compares a column's current information_schema data_type with
// the data type in its `sql` tag. It returns the type to use in
// ALTER COLUMN ... TYPE and true when the base types differ.
//
// Type modifiers (lengths, precisions) are not compared, element types of
// ARRAY columns are not compared, and USER-DEFINED columns (e.g. enums) are
// skipped because data_type alone cannot identify them.
func typeChange(currentType, sqlTag string) (string, bool) {
	clause := sqlTypeClause(sqlTag)
	current := strings.ToLower(currentType)
	if clause == "" || current == "user-defined" {
		return "", false
	}
	want := canonicalPGType(clause)
	if want == current {
		return "", false
	}
	if serialTypes[strings.ToLower(clause)] {
		// serial is not a real type; ALTER needs the underlying integer type.
		return want, true
	}
	return clause, true
}