]> 127.0.0.1 Git - part/.git/commitdiff
sql support postgresql v0.28.0+202306188e4660c
authorqydysky <qydysky@foxmail.com>
Sun, 18 Jun 2023 13:01:56 +0000 (13:01 +0000)
committerqydysky <qydysky@foxmail.com>
Sun, 18 Jun 2023 13:01:56 +0000 (13:01 +0000)
go.mod
go.sum
sql/Sql.go
sql/Sql_test.go

diff --git a/go.mod b/go.mod
index ba4797a1112957d302a3c0f39fdfcd70b41f8a3b..5179b44ecd4b37a631b96418c7a80429ee2cbdf9 100644 (file)
--- a/go.mod
+++ b/go.mod
@@ -37,6 +37,7 @@ require (
 require (
        github.com/dustin/go-humanize v1.0.1
        github.com/go-ole/go-ole v1.2.6 // indirect
+       github.com/lib/pq v1.10.9
        github.com/stretchr/testify v1.8.2 // indirect
        golang.org/x/net v0.10.0 // indirect
        golang.org/x/sys v0.8.0 // indirect
diff --git a/go.sum b/go.sum
index 976bffdc8ea6231541b0baef39134698b1e48aaa..9d742efc1c85c752c2a7700705039cb16ebccef2 100644 (file)
--- a/go.sum
+++ b/go.sum
@@ -17,6 +17,8 @@ github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 h1:Z9n2FFNU
 github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51/go.mod h1:CzGEWj7cYgsdH8dAjBGEr58BoE7ScuLd+fwFZ44+/x8=
 github.com/klauspost/compress v1.16.5 h1:IFV2oUNUzZaz+XyusxpLzpzS8Pt5rh0Z16For/djlyI=
 github.com/klauspost/compress v1.16.5/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE=
+github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
+github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
 github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
 github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
 github.com/mattn/go-sqlite3 v1.14.16 h1:yOQRA0RpS5PFz/oikGwBEqvAWhWg5ufRz4ETLjwpU1Y=
index b3dfbc6197b1ca59fc7589fd7f936f164ea49f97..15b4784b49b38389e208d5ed991b180907bda629 100644 (file)
@@ -59,14 +59,20 @@ func (t *SqlTx[T]) Do(sqlf SqlFunc[T]) *SqlTx[T] {
        return t
 }
 
-func (t *SqlTx[T]) DoPlaceHolder(sqlf SqlFunc[T], ptr any) *SqlTx[T] {
+func (t *SqlTx[T]) DoPlaceHolder(sqlf SqlFunc[T], ptr any, replaceF ...func(index int, holder string) (replaceTo string)) *SqlTx[T] {
        dataR := reflect.ValueOf(ptr).Elem()
+       index := 0
        for i := 0; i < dataR.NumField(); i++ {
                field := dataR.Field(i)
                if field.IsValid() && field.CanSet() {
                        replaceS := "{" + dataR.Type().Field(i).Name + "}"
                        if strings.Contains(sqlf.Query, replaceS) {
-                               sqlf.Query = strings.ReplaceAll(sqlf.Query, replaceS, "?")
+                               if len(replaceF) == 0 {
+                                       sqlf.Query = strings.ReplaceAll(sqlf.Query, replaceS, "?")
+                               } else {
+                                       sqlf.Query = strings.ReplaceAll(sqlf.Query, replaceS, replaceF[0](index, replaceS))
+                                       index += 1
+                               }
                                sqlf.Args = append(sqlf.Args, field.Interface())
                        }
                }
@@ -100,9 +106,12 @@ func (t *SqlTx[T]) Fin() (dataP *T, e error) {
                return nil, fmt.Errorf("BeginTx; [] >> fin")
        }
 
+       var hasErr bool
+
        tx, err := t.canTx.BeginTx(t.ctx, t.opts)
        if err != nil {
                e = fmt.Errorf("BeginTx; [] >> %s", err)
+               hasErr = true
        } else {
                for i := 0; i < len(t.sqlFuncs); i++ {
                        sqlf := t.sqlFuncs[i]
@@ -110,6 +119,7 @@ func (t *SqlTx[T]) Fin() (dataP *T, e error) {
                        if sqlf.beforeF != nil {
                                if datap, err := sqlf.beforeF(t.dataP, sqlf, e); err != nil {
                                        e = errors.Join(e, fmt.Errorf("%s; >> %s", sqlf.Query, err))
+                                       hasErr = true
                                } else {
                                        t.dataP = datap
                                }
@@ -133,11 +143,13 @@ func (t *SqlTx[T]) Fin() (dataP *T, e error) {
                        switch sqlf.Ty {
                        case Execf:
                                if res, err := tx.ExecContext(sqlf.Ctx, sqlf.Query, sqlf.Args...); err != nil {
+                                       hasErr = true
                                        if !sqlf.SkipSqlErr {
                                                e = errors.Join(e, fmt.Errorf("%s; %s >> %s", sqlf.Query, sqlf.Args, err))
                                        }
                                } else if sqlf.afterEF != nil {
                                        if datap, err := sqlf.afterEF(t.dataP, res, e); err != nil {
+                                               hasErr = true
                                                e = errors.Join(e, fmt.Errorf("%s; %s >> %s", sqlf.Query, sqlf.Args, err))
                                        } else {
                                                t.dataP = datap
@@ -145,11 +157,13 @@ func (t *SqlTx[T]) Fin() (dataP *T, e error) {
                                }
                        case Queryf:
                                if res, err := tx.QueryContext(sqlf.Ctx, sqlf.Query, sqlf.Args...); err != nil {
+                                       hasErr = true
                                        if !sqlf.SkipSqlErr {
                                                e = errors.Join(e, fmt.Errorf("%s; %s >> %s", sqlf.Query, sqlf.Args, err))
                                        }
                                } else if sqlf.afterQF != nil {
                                        if datap, err := sqlf.afterQF(t.dataP, res, e); err != nil {
+                                               hasErr = true
                                                e = errors.Join(e, fmt.Errorf("%s; %s >> %s", sqlf.Query, sqlf.Args, err))
                                        } else {
                                                t.dataP = datap
@@ -158,7 +172,7 @@ func (t *SqlTx[T]) Fin() (dataP *T, e error) {
                        }
                }
        }
-       if e != nil {
+       if hasErr {
                if tx != nil {
                        if err := tx.Rollback(); err != nil {
                                e = errors.Join(e, fmt.Errorf("Rollback; [] >> %s", err))
@@ -177,7 +191,7 @@ func IsFin[T any](t *SqlTx[T]) bool {
        return t == nil || t.fin
 }
 
-func DealRows[T any](rows *sql.Rows, createF func() T) (*[]T, error) {
+func DealRows[T any](rows *sql.Rows, newT func() T) (*[]T, error) {
        rowNames, err := rows.Columns()
        if err != nil {
                return nil, err
@@ -195,30 +209,50 @@ func DealRows[T any](rows *sql.Rows, createF func() T) (*[]T, error) {
                        return nil, err
                }
 
-               var stu = createF()
+               var (
+                       stu      = newT()
+                       refV     = reflect.ValueOf(&stu).Elem()
+                       refT     = reflect.TypeOf(&stu).Elem()
+                       FieldMap = make(map[string]*reflect.Value)
+               )
+
+               for NumField := refV.NumField() - 1; NumField >= 0; NumField-- {
+                       field := refV.Field(NumField)
+                       fieldT := refT.Field(NumField)
+                       fieldTName := fieldT.Name
+                       if value, ok := fieldT.Tag.Lookup("sql"); ok {
+                               fieldTName = value
+                       }
+                       if !field.IsValid() {
+                               continue
+                       }
+                       if !field.CanSet() {
+                               FieldMap[strings.ToUpper(fieldTName)] = nil
+                               continue
+                       }
+                       FieldMap[strings.ToUpper(fieldTName)] = &field
+               }
+
                for i := 0; i < len(rowNames); i++ {
-                       v := reflect.ValueOf(&stu).Elem().FieldByName(rowNames[i])
-                       if v.IsValid() {
-                               refT := reflect.TypeOf(&stu).Elem()
-                               if v.CanSet() {
-                                       val := reflect.ValueOf(*rowP[i].(*any))
-                                       if reflect.TypeOf(*rowP[i].(*any)).ConvertibleTo(v.Type()) {
-                                               v.Set(val)
-                                       } else {
-                                               return nil, fmt.Errorf("DealRows:KindNotMatch:[sql] %v !> [%s.%s] %v", val.Kind(), refT.Name(), rowNames[i], v.Type())
-                                       }
+                       if field, ok := FieldMap[strings.ToUpper(rowNames[i])]; ok {
+                               if field == nil {
+                                       return nil, fmt.Errorf("DealRows:%s.%s CanSet:false", refT.Name(), rowNames[i])
+                               }
+                               val := reflect.ValueOf(*rowP[i].(*any))
+                               if reflect.TypeOf(*rowP[i].(*any)).ConvertibleTo(field.Type()) {
+                                       field.Set(val)
                                } else {
-                                       return nil, fmt.Errorf("DealRows:%s.%s CanSet:%v", refT.Name(), rowNames[i], v.CanSet())
+                                       return nil, fmt.Errorf("DealRows:KindNotMatch:[sql] %v !> [%s.%s] %v", val.Kind(), refT.Name(), rowNames[i], field.Type())
                                }
                        }
                }
                res = append(res, stu)
-
        }
 
        return &res, nil
 }
 
+// for mysql,oracle not postgresql
 func SimpleQ[T any](canTx CanTx, query string, ptr *T) (*[]T, error) {
        tx := BeginTx[[]T](canTx, context.Background())
        tx.DoPlaceHolder(SqlFunc[[]T]{Query: query}, ptr)
index 8b96d9247cf4739eb0502e8505fcf2acefcd8ded..2c92aeba3de221173fc55194148ab5c10eb7c108 100644 (file)
@@ -4,10 +4,12 @@ import (
        "context"
        "database/sql"
        "errors"
+       "fmt"
        "sync"
        "testing"
        "time"
 
+       _ "github.com/lib/pq"
        file "github.com/qydysky/part/file"
        _ "modernc.org/sqlite"
 )
@@ -157,7 +159,9 @@ func TestMain3(t *testing.T) {
                t.Fatal(err)
        }
        defer db.Close()
-       defer file.New("test.sqlite3", 0, true).Delete()
+       defer func() {
+               _ = file.New("test.sqlite3", 0, true).Delete()
+       }()
 
        {
                tx := BeginTx[any](db, context.Background())
@@ -220,7 +224,9 @@ func TestMain4(t *testing.T) {
                t.Fatal(err)
        }
        defer db.Close()
-       defer file.New("test.sqlite3", 0, true).Delete()
+       defer func() {
+               _ = file.New("test.sqlite3", 0, true).Delete()
+       }()
 
        conn, _ := db.Conn(context.Background())
        if _, e := BeginTx[any](conn, context.Background(), &sql.TxOptions{}).Do(SqlFunc[any]{
@@ -244,3 +250,50 @@ func TestMain4(t *testing.T) {
                t.Fatal()
        }
 }
+
+func Local_TestPostgresql(t *testing.T) {
+       // connect
+       db, err := sql.Open("postgres", "postgres://postgres:qydysky@192.168.31.103:5432/postgres?sslmode=disable")
+       if err != nil {
+               t.Fatal(err)
+       }
+       defer db.Close()
+
+       type test1 struct {
+               Created string `sql:"sss"`
+       }
+
+       if _, e := BeginTx[any](db, context.Background(), &sql.TxOptions{}).Do(SqlFunc[any]{
+               Query:      "create table test (created varchar(20))",
+               SkipSqlErr: true,
+       }).Fin(); e != nil {
+               t.Fatal(e)
+       }
+
+       if _, e := BeginTx[any](db, context.Background(), &sql.TxOptions{}).DoPlaceHolder(SqlFunc[any]{
+               Query: "insert into test (created) values ({Created})",
+       }, &test1{"1"}, func(index int, holder string) (replaceTo string) {
+               return fmt.Sprintf("$%d", index+1)
+       }).Fin(); e != nil {
+               t.Fatal(e)
+       }
+
+       if _, e := BeginTx[any](db, context.Background(), &sql.TxOptions{}).Do(SqlFunc[any]{
+               Query: "select created as sss from test",
+               afterQF: func(_ *any, rows *sql.Rows, txE error) (dataPR *any, stopErr error) {
+                       if rowsP, e := DealRows[test1](rows, func() test1 { return test1{} }); e != nil {
+                               return nil, e
+                       } else {
+                               if len(*rowsP) != 1 {
+                                       return nil, errors.New("no match")
+                               }
+                               if (*rowsP)[0].Created != "1" {
+                                       return nil, errors.New("no match")
+                               }
+                       }
+                       return nil, nil
+               },
+       }).Fin(); e != nil {
+               t.Fatal(e)
+       }
+}