// Package database contains helpers that create and update PostgreSQL tables
// from tagged Go structs.
package database

import (
	"database/sql"
	"fmt"
	"log"
	"reflect"
	"strings"
)

// Auto-create table, columns, types, indexes for PostgreSQL
|
|
func CreateOrUpdateTablePG(db *sql.DB, table string, model interface{}) error {
|
|
// 0) Check if table exists
|
|
exists, err := tableExists(db, table)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if !exists {
|
|
// Create table from model
|
|
if err := createTable(db, table, model); err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Update table if it exists
|
|
return alterTable(db, table, model)
|
|
}
// tableExists reports whether a table named table is listed in
// information_schema.tables for one of the schemas on the connection's
// current search path.
//
// The lookup is limited to current_schemas(false) so that an identically
// named table living in an unrelated schema is not mistaken for this one;
// the unqualified CREATE TABLE / ALTER TABLE statements issued elsewhere in
// this file resolve names through the search path as well.
//
// On a query or scan error it returns false and the wrapped error.
func tableExists(db *sql.DB, table string) (bool, error) {
	const query = `SELECT EXISTS (
	SELECT 1
	FROM information_schema.tables
	WHERE table_schema = ANY (current_schemas(false))
	  AND table_name = $1
);`

	var exists bool
	if err := db.QueryRow(query, table).Scan(&exists); err != nil {
		return false, fmt.Errorf("querying information_schema.tables for %q: %w", table, err)
	}
	return exists, nil
}
// createTable issues a CREATE TABLE statement for table with one column per
// struct field of model, followed by one CREATE INDEX statement for every
// field tagged index:"true".
//
// model may be a struct value or a pointer to a struct; only its type is
// inspected, so a nil pointer of a struct type is accepted too. A nil
// interface or any other kind is rejected with an error instead of a panic.
//
// Field tags read here:
//
//	db:"name"      column name; defaults to the lower-cased field name.
//	               db:"-" leaves the field out of the table.
//	ignore:"true"  leaves the field out of the table.
//	sql:"..."      text placed after the quoted column name (required).
//	index:"true"   creates index "<table>_<column>_idx" on the column.
//
// All fields are validated before the first statement is executed. The
// statements are not wrapped in a transaction: if an index cannot be created,
// the table and any earlier indexes stay in place and the error is returned.
func createTable(db *sql.DB, table string, model interface{}) error {
	// reflect.TypeOf(nil) is a nil Type; calling Kind on it would panic, so
	// it has to be checked before anything else.
	t := reflect.TypeOf(model)
	if t != nil && t.Kind() == reflect.Ptr {
		t = t.Elem()
	}
	if t == nil || t.Kind() != reflect.Struct {
		return fmt.Errorf("model must be a struct or a pointer to a struct, got %T", model)
	}

	cols := make([]string, 0, t.NumField())
	var indexed []string

	for i := 0; i < t.NumField(); i++ {
		f := t.Field(i)

		// Fields explicitly excluded from the schema.
		if f.Tag.Get("ignore") == "true" || f.Tag.Get("db") == "-" {
			continue
		}

		col := f.Tag.Get("db")
		if col == "" {
			col = strings.ToLower(f.Name)
		}

		sqlType := f.Tag.Get("sql")
		if sqlType == "" {
			return fmt.Errorf("field '%s' missing sql tag", f.Name)
		}

		cols = append(cols, fmt.Sprintf("\"%s\" %s", col, sqlType))
		if f.Tag.Get("index") == "true" {
			indexed = append(indexed, col)
		}
	}

	query := fmt.Sprintf(`CREATE TABLE "%s" (%s);`, table, strings.Join(cols, ", "))
	log.Println("[CREATE TABLE]", query)
	if _, err := db.Exec(query); err != nil {
		return fmt.Errorf("create table %q: %w", table, err)
	}

	for _, col := range indexed {
		indexName := fmt.Sprintf("%s_%s_idx", table, col)
		iq := fmt.Sprintf(`CREATE INDEX "%s" ON "%s" ("%s");`, indexName, table, col)
		log.Println("[CREATE INDEX]", iq)
		if _, err := db.Exec(iq); err != nil {
			return fmt.Errorf("create index %q: %w", indexName, err)
		}
	}

	return nil
}
func alterTable(db *sql.DB, table string, model interface{}) error {
|
|
currentCols, err := getColumnsPG(db, table)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
currentIdx, err := getIndexesPG(db, table)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
t := reflect.TypeOf(model)
|
|
|
|
for i := 0; i < t.NumField(); i++ {
|
|
f := t.Field(i)
|
|
col := f.Tag.Get("db")
|
|
if col == "" {
|
|
col = strings.ToLower(f.Name)
|
|
}
|
|
sqlType := f.Tag.Get("sql")
|
|
|
|
// Missing column
|
|
if _, ok := currentCols[col]; !ok {
|
|
query := fmt.Sprintf(`ALTER TABLE "%s" ADD COLUMN "%s" %s;`, table, col, sqlType)
|
|
log.Println("[ADD COLUMN]", query)
|
|
if _, err := db.Exec(query); err != nil {
|
|
return err
|
|
}
|
|
continue
|
|
}
|
|
|
|
// Type mismatch → alter
|
|
pgType := strings.ToLower(currentCols[col])
|
|
expType := strings.ToLower(sqlType)
|
|
|
|
if !strings.Contains(pgType, strings.Split(expType, "(")[0]) {
|
|
query := fmt.Sprintf(`ALTER TABLE "%s" ALTER COLUMN "%s" TYPE %s;`, table, col, sqlType)
|
|
log.Println("[ALTER TYPE]", query)
|
|
if _, err := db.Exec(query); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
// Indexing
|
|
if f.Tag.Get("index") == "true" {
|
|
idxName := fmt.Sprintf("%s_%s_idx", table, col)
|
|
|
|
if !currentIdx[idxName] {
|
|
iq := fmt.Sprintf(`CREATE INDEX "%s" ON "%s" ("%s");`, idxName, table, col)
|
|
log.Println("[CREATE INDEX]", iq)
|
|
if _, err := db.Exec(iq); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}

// Helpers

// getColumnsPG returns the columns of table as a map from column name to the
// data_type value reported by information_schema.columns.
//
// Only schemas on the connection's current search path are considered, so
// columns of an identically named table in an unrelated schema are not mixed
// in. A table that is not found yields an empty, non-nil map.
//
// The map is nil whenever an error is returned. Errors hit while iterating
// the result set are reported through rows.Err instead of being silently
// dropped together with the remaining rows.
func getColumnsPG(db *sql.DB, table string) (map[string]string, error) {
	const q = `SELECT column_name, data_type
FROM information_schema.columns
WHERE table_schema = ANY (current_schemas(false))
  AND table_name = $1;`

	rows, err := db.Query(q, table)
	if err != nil {
		return nil, fmt.Errorf("listing columns of %q: %w", table, err)
	}
	defer rows.Close()

	cols := make(map[string]string)
	for rows.Next() {
		var name, dtype string
		if err := rows.Scan(&name, &dtype); err != nil {
			return nil, fmt.Errorf("listing columns of %q: %w", table, err)
		}
		cols[name] = dtype
	}
	// Next returns false both at the end of the data and on failure; Err
	// tells the two apart.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("listing columns of %q: %w", table, err)
	}
	return cols, nil
}
// getIndexesPG returns the set of index names that pg_indexes lists for
// table, as a map whose value is true for every index present.
//
// Only schemas on the connection's current search path are considered, so
// indexes of an identically named table in an unrelated schema are not mixed
// in. A table without indexes yields an empty, non-nil map.
//
// The map is nil whenever an error is returned. Errors hit while iterating
// the result set are reported through rows.Err instead of being silently
// dropped together with the remaining rows.
func getIndexesPG(db *sql.DB, table string) (map[string]bool, error) {
	const q = `SELECT indexname
FROM pg_indexes
WHERE schemaname = ANY (current_schemas(false))
  AND tablename = $1;`

	rows, err := db.Query(q, table)
	if err != nil {
		return nil, fmt.Errorf("listing indexes of %q: %w", table, err)
	}
	defer rows.Close()

	idx := make(map[string]bool)
	for rows.Next() {
		var name string
		if err := rows.Scan(&name); err != nil {
			return nil, fmt.Errorf("listing indexes of %q: %w", table, err)
		}
		idx[name] = true
	}
	// Next returns false both at the end of the data and on failure; Err
	// tells the two apart.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("listing indexes of %q: %w", table, err)
	}
	return idx, nil
}