ADD: added function to update db scheme automaticly
This commit is contained in:
188
backend/internal/database/migrateScheme.go
Normal file
188
backend/internal/database/migrateScheme.go
Normal file
@@ -0,0 +1,188 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"log"
|
||||
"reflect"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Auto-create table, columns, types, indexes for PostgreSQL
|
||||
func CreateOrUpdateTablePG(db *sql.DB, table string, model interface{}) error {
|
||||
// 0) Check if table exists
|
||||
exists, err := tableExists(db, table)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !exists {
|
||||
// Create table from model
|
||||
if err := createTable(db, table, model); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Update table if it exists
|
||||
return alterTable(db, table, model)
|
||||
}
|
||||
|
||||
func tableExists(db *sql.DB, table string) (bool, error) {
|
||||
query := `SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name=$1);`
|
||||
var exists bool
|
||||
err := db.QueryRow(query, table).Scan(&exists)
|
||||
return exists, err
|
||||
}
|
||||
|
||||
func createTable(db *sql.DB, table string, model interface{}) error {
|
||||
t := reflect.TypeOf(model)
|
||||
if t.Kind() != reflect.Struct {
|
||||
return fmt.Errorf("model must be a struct")
|
||||
}
|
||||
|
||||
cols := []string{}
|
||||
idx := []string{}
|
||||
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
f := t.Field(i)
|
||||
|
||||
// ignore fields not meant for DB
|
||||
if f.Tag.Get("ignore") == "true" || f.Tag.Get("db") == "-" {
|
||||
continue
|
||||
}
|
||||
|
||||
col := f.Tag.Get("db")
|
||||
if col == "" {
|
||||
col = strings.ToLower(f.Name)
|
||||
}
|
||||
sqlType := f.Tag.Get("sql")
|
||||
if sqlType == "" {
|
||||
return fmt.Errorf("field '%s' missing sql tag", f.Name)
|
||||
}
|
||||
|
||||
cols = append(cols, fmt.Sprintf("\"%s\" %s", col, sqlType))
|
||||
|
||||
if f.Tag.Get("index") == "true" {
|
||||
idx = append(idx, col)
|
||||
}
|
||||
}
|
||||
|
||||
query := fmt.Sprintf(`CREATE TABLE "%s" (%s);`, table, strings.Join(cols, ", "))
|
||||
log.Println("[CREATE TABLE]", query)
|
||||
|
||||
if _, err := db.Exec(query); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Create indexes
|
||||
for _, col := range idx {
|
||||
indexName := fmt.Sprintf("%s_%s_idx", table, col)
|
||||
iq := fmt.Sprintf(`CREATE INDEX "%s" ON "%s" ("%s");`, indexName, table, col)
|
||||
log.Println("[CREATE INDEX]", iq)
|
||||
if _, err := db.Exec(iq); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func alterTable(db *sql.DB, table string, model interface{}) error {
|
||||
currentCols, err := getColumnsPG(db, table)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
currentIdx, err := getIndexesPG(db, table)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
t := reflect.TypeOf(model)
|
||||
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
f := t.Field(i)
|
||||
col := f.Tag.Get("db")
|
||||
if col == "" {
|
||||
col = strings.ToLower(f.Name)
|
||||
}
|
||||
sqlType := f.Tag.Get("sql")
|
||||
|
||||
// Missing column
|
||||
if _, ok := currentCols[col]; !ok {
|
||||
query := fmt.Sprintf(`ALTER TABLE "%s" ADD COLUMN "%s" %s;`, table, col, sqlType)
|
||||
log.Println("[ADD COLUMN]", query)
|
||||
if _, err := db.Exec(query); err != nil {
|
||||
return err
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Type mismatch → alter
|
||||
pgType := strings.ToLower(currentCols[col])
|
||||
expType := strings.ToLower(sqlType)
|
||||
|
||||
if !strings.Contains(pgType, strings.Split(expType, "(")[0]) {
|
||||
query := fmt.Sprintf(`ALTER TABLE "%s" ALTER COLUMN "%s" TYPE %s;`, table, col, sqlType)
|
||||
log.Println("[ALTER TYPE]", query)
|
||||
if _, err := db.Exec(query); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Indexing
|
||||
if f.Tag.Get("index") == "true" {
|
||||
idxName := fmt.Sprintf("%s_%s_idx", table, col)
|
||||
|
||||
if !currentIdx[idxName] {
|
||||
iq := fmt.Sprintf(`CREATE INDEX "%s" ON "%s" ("%s");`, idxName, table, col)
|
||||
log.Println("[CREATE INDEX]", iq)
|
||||
if _, err := db.Exec(iq); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Helpers
|
||||
func getColumnsPG(db *sql.DB, table string) (map[string]string, error) {
|
||||
q := `SELECT column_name, data_type FROM information_schema.columns WHERE table_name=$1;`
|
||||
rows, err := db.Query(q, table)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
cols := map[string]string{}
|
||||
for rows.Next() {
|
||||
var name, dtype string
|
||||
if err := rows.Scan(&name, &dtype); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cols[name] = dtype
|
||||
}
|
||||
return cols, nil
|
||||
}
|
||||
|
||||
func getIndexesPG(db *sql.DB, table string) (map[string]bool, error) {
|
||||
q := `SELECT indexname FROM pg_indexes WHERE tablename=$1;`
|
||||
rows, err := db.Query(q, table)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
idx := map[string]bool{}
|
||||
for rows.Next() {
|
||||
var i string
|
||||
if err := rows.Scan(&i); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
idx[i] = true
|
||||
}
|
||||
return idx, nil
|
||||
}
|
||||
Reference in New Issue
Block a user