Files
Volleyball/backend/internal/database/migrateScheme.go

189 lines
4.1 KiB
Go

package database
import (
"database/sql"
"fmt"
"log"
"reflect"
"strings"
)
// CreateOrUpdateTablePG makes sure the PostgreSQL table named table matches
// model. When the table is not present it is created from the model's struct
// tags (columns and indexes); when it is already present it is altered in
// place instead.
//
// Any error from the existence check, the creation or the alteration is
// returned to the caller unchanged.
func CreateOrUpdateTablePG(db *sql.DB, table string, model interface{}) error {
	found, err := tableExists(db, table)
	switch {
	case err != nil:
		return err
	case found:
		// Table is already there: reconcile it with the model.
		return alterTable(db, table, model)
	default:
		// First run for this table: build it from scratch.
		return createTable(db, table, model)
	}
}
// tableExists reports whether information_schema.tables lists a table whose
// name equals table. The lookup filters on the table name only; it is not
// restricted to a particular schema.
func tableExists(db *sql.DB, table string) (bool, error) {
	const lookup = `SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name=$1);`

	var found bool
	row := db.QueryRow(lookup, table)
	scanErr := row.Scan(&found)
	return found, scanErr
}
// createTable issues a CREATE TABLE statement for table built from the struct
// fields of model, followed by one CREATE INDEX statement per indexed field.
//
// model must be a struct or a pointer to a struct. The following struct tags
// are read for every field:
//
//   - db: column name. "-" skips the field; when empty the lower-cased Go
//     field name is used.
//   - ignore: "true" skips the field.
//   - sql: column definition placed verbatim after the column name. It is
//     required for every field that is not skipped.
//   - index: "true" creates an index named <table>_<column>_idx.
//
// The model is validated completely before any statement is sent to db. The
// statements are executed one by one without a transaction, so a failing
// CREATE INDEX leaves the table and any earlier indexes in place.
func createTable(db *sql.DB, table string, model interface{}) error {
	t := reflect.TypeOf(model)
	// reflect.TypeOf returns nil for an untyped nil model; calling Kind on
	// that would panic, so it is guarded here and rejected below.
	if t != nil && t.Kind() == reflect.Ptr {
		t = t.Elem()
	}
	if t == nil || t.Kind() != reflect.Struct {
		return fmt.Errorf("model must be a struct")
	}

	cols := make([]string, 0, t.NumField())
	var idx []string
	for i := 0; i < t.NumField(); i++ {
		f := t.Field(i)

		// ignore fields not meant for DB
		if f.Tag.Get("ignore") == "true" || f.Tag.Get("db") == "-" {
			continue
		}

		col := f.Tag.Get("db")
		if col == "" {
			col = strings.ToLower(f.Name)
		}

		sqlType := f.Tag.Get("sql")
		if sqlType == "" {
			return fmt.Errorf("field '%s' missing sql tag", f.Name)
		}

		cols = append(cols, fmt.Sprintf("\"%s\" %s", col, sqlType))
		if f.Tag.Get("index") == "true" {
			idx = append(idx, col)
		}
	}

	query := fmt.Sprintf(`CREATE TABLE "%s" (%s);`, table, strings.Join(cols, ", "))
	log.Println("[CREATE TABLE]", query)
	if _, err := db.Exec(query); err != nil {
		return fmt.Errorf("creating table %q: %w", table, err)
	}

	// Create indexes
	for _, col := range idx {
		indexName := fmt.Sprintf("%s_%s_idx", table, col)
		iq := fmt.Sprintf(`CREATE INDEX "%s" ON "%s" ("%s");`, indexName, table, col)
		log.Println("[CREATE INDEX]", iq)
		if _, err := db.Exec(iq); err != nil {
			return fmt.Errorf("creating index %q: %w", indexName, err)
		}
	}
	return nil
}
// alterTable reconciles an existing table with model.
//
// model must be a struct or a pointer to a struct. For every field that is
// not skipped (ignore:"true" or db:"-", the same rule createTable applies):
//
//   - a column missing from the table is added with the field's sql tag;
//   - an existing column whose reported data type does not match the type
//     portion of the sql tag is changed with ALTER COLUMN ... TYPE;
//   - a field tagged index:"true" gets an index named <table>_<column>_idx
//     unless an index of that name already exists. This also applies to
//     columns that were just added.
//
// Columns that exist in the table but not in the model are left alone.
// Statements are executed one by one without a transaction.
func alterTable(db *sql.DB, table string, model interface{}) error {
	t := reflect.TypeOf(model)
	// reflect.TypeOf returns nil for an untyped nil model, and NumField
	// panics on anything that is not a struct, so validate up front.
	if t != nil && t.Kind() == reflect.Ptr {
		t = t.Elem()
	}
	if t == nil || t.Kind() != reflect.Struct {
		return fmt.Errorf("model must be a struct")
	}

	currentCols, err := getColumnsPG(db, table)
	if err != nil {
		return fmt.Errorf("reading columns of %q: %w", table, err)
	}
	currentIdx, err := getIndexesPG(db, table)
	if err != nil {
		return fmt.Errorf("reading indexes of %q: %w", table, err)
	}

	for i := 0; i < t.NumField(); i++ {
		f := t.Field(i)

		// ignore fields not meant for DB
		if f.Tag.Get("ignore") == "true" || f.Tag.Get("db") == "-" {
			continue
		}

		col := f.Tag.Get("db")
		if col == "" {
			col = strings.ToLower(f.Name)
		}
		sqlType := f.Tag.Get("sql")

		pgType, ok := currentCols[col]
		switch {
		case !ok:
			// Missing column. Without a definition the statement would be
			// invalid SQL, so report it the same way createTable does.
			if sqlType == "" {
				return fmt.Errorf("field '%s' missing sql tag", f.Name)
			}
			query := fmt.Sprintf(`ALTER TABLE "%s" ADD COLUMN "%s" %s;`, table, col, sqlType)
			log.Println("[ADD COLUMN]", query)
			if _, err := db.Exec(query); err != nil {
				return fmt.Errorf("adding column %q to %q: %w", col, table, err)
			}
		default:
			// Type mismatch → alter. Only the type portion of the tag is
			// compared and sent: constraints such as NOT NULL or DEFAULT are
			// not accepted by ALTER COLUMN ... TYPE.
			wantType := pgColumnType(sqlType)
			if pgSameType(pgType, wantType) {
				break
			}
			query := fmt.Sprintf(`ALTER TABLE "%s" ALTER COLUMN "%s" TYPE %s;`, table, col, wantType)
			log.Println("[ALTER TYPE]", query)
			if _, err := db.Exec(query); err != nil {
				return fmt.Errorf("altering type of column %q in %q: %w", col, table, err)
			}
		}

		// Indexing
		if f.Tag.Get("index") != "true" {
			continue
		}
		idxName := fmt.Sprintf("%s_%s_idx", table, col)
		if currentIdx[idxName] {
			continue
		}
		iq := fmt.Sprintf(`CREATE INDEX "%s" ON "%s" ("%s");`, idxName, table, col)
		log.Println("[CREATE INDEX]", iq)
		if _, err := db.Exec(iq); err != nil {
			return fmt.Errorf("creating index %q: %w", idxName, err)
		}
	}
	return nil
}

// pgColumnType returns the type portion of a column definition: everything
// before the first column-constraint keyword. For example
// "VARCHAR(255) NOT NULL DEFAULT ''" yields "VARCHAR(255)".
func pgColumnType(def string) string {
	words := strings.Fields(def)
	for i, w := range words {
		switch strings.ToLower(w) {
		case "not", "null", "default", "primary", "unique",
			"references", "check", "constraint", "generated":
			return strings.Join(words[:i], " ")
		}
	}
	return strings.Join(words, " ")
}

// pgSameType reports whether wantType, a type as written in DDL, names the
// type PostgreSQL reports as pgType in information_schema.columns.data_type.
// Length/precision arguments are ignored and the comparison is a
// case-insensitive substring match, tried first on the spelling as written
// and then on its canonical name.
func pgSameType(pgType, wantType string) bool {
	have := strings.ToLower(pgType)
	base := strings.ToLower(wantType)
	if i := strings.Index(base, "("); i >= 0 {
		base = base[:i]
	}
	base = strings.TrimSpace(base)
	return strings.Contains(have, base) || strings.Contains(have, pgCanonicalType(base))
}

// pgCanonicalType maps a lower-case type alias accepted in DDL to the name
// information_schema.columns reports for it. Unknown names are returned
// unchanged.
func pgCanonicalType(name string) string {
	switch name {
	case "int", "int4", "serial", "serial4":
		return "integer"
	case "int8", "bigserial", "serial8":
		return "bigint"
	case "int2", "smallserial", "serial2":
		return "smallint"
	case "varchar":
		return "character varying"
	case "char", "bpchar":
		return "character"
	case "bool":
		return "boolean"
	case "float", "float8":
		return "double precision"
	case "float4":
		return "real"
	case "decimal":
		return "numeric"
	case "timestamptz":
		return "timestamp with time zone"
	case "timetz":
		return "time with time zone"
	}
	return name
}
// Helpers

// getColumnsPG returns the columns information_schema.columns lists for
// table, keyed by column name with the reported data_type as value. The
// lookup filters on the table name only; it is not restricted to a schema.
//
// On any query, scan or iteration error the map is nil and the error is
// returned.
func getColumnsPG(db *sql.DB, table string) (map[string]string, error) {
	q := `SELECT column_name, data_type FROM information_schema.columns WHERE table_name=$1;`
	rows, err := db.Query(q, table)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	cols := map[string]string{}
	for rows.Next() {
		var name, dtype string
		if err := rows.Scan(&name, &dtype); err != nil {
			return nil, err
		}
		cols[name] = dtype
	}
	// rows.Next also returns false when iteration fails midway; without this
	// check a truncated column list would be reported as complete.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return cols, nil
}
// getIndexesPG returns the set of index names pg_indexes lists for table.
// The lookup filters on the table name only; it is not restricted to a
// schema.
//
// On any query, scan or iteration error the map is nil and the error is
// returned.
func getIndexesPG(db *sql.DB, table string) (map[string]bool, error) {
	q := `SELECT indexname FROM pg_indexes WHERE tablename=$1;`
	rows, err := db.Query(q, table)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	idx := map[string]bool{}
	for rows.Next() {
		var name string
		if err := rows.Scan(&name); err != nil {
			return nil, err
		}
		idx[name] = true
	}
	// rows.Next also returns false when iteration fails midway; without this
	// check a truncated index set would be reported as complete.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return idx, nil
}