]> 127.0.0.1 Git - part/.git/commitdiff
Improve v0.25.2
authorqydysky <qydysky@foxmail.com>
Tue, 2 May 2023 08:56:03 +0000 (16:56 +0800)
committerqydysky <qydysky@foxmail.com>
Tue, 2 May 2023 08:56:03 +0000 (16:56 +0800)
sql/Sql.go
sql/Sql_test.go

index 461827012ce30a88222de59d01e348422dc592d4..b3dfbc6197b1ca59fc7589fd7f936f164ea49f97 100644 (file)
@@ -6,10 +6,12 @@ import (
        "errors"
        "fmt"
        "reflect"
+       "strings"
 )
 
 const (
-       Execf = iota
+       null = iota
+       Execf
        Queryf
 )
 
@@ -17,6 +19,10 @@ type CanTx interface {
        BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
 }
 
+type BeforeF[T any] func(dataP *T, sqlf *SqlFunc[T], txE error) (dataPR *T, stopErr error)
+type AfterEF[T any] func(dataP *T, result sql.Result, txE error) (dataPR *T, stopErr error)
+type AfterQF[T any] func(dataP *T, rows *sql.Rows, txE error) (dataPR *T, stopErr error)
+
 type SqlTx[T any] struct {
        canTx    CanTx
        ctx      context.Context
@@ -32,18 +38,20 @@ type SqlFunc[T any] struct {
        Query      string
        Args       []any
        SkipSqlErr bool
-       BeforeEF   func(dataP *T, sqlf *SqlFunc[T], txE error) (dataPR *T, stopErr error)
-       BeforeQF   func(dataP *T, sqlf *SqlFunc[T], txE error) (dataPR *T, stopErr error)
-       AfterEF    func(dataP *T, result sql.Result, txE error) (dataPR *T, stopErr error)
-       AfterQF    func(dataP *T, rows *sql.Rows, txE error) (dataPR *T, stopErr error)
+       beforeF    BeforeF[T]
+       afterEF    AfterEF[T]
+       afterQF    AfterQF[T]
 }
 
-func BeginTx[T any](canTx CanTx, ctx context.Context, opts *sql.TxOptions) *SqlTx[T] {
-       return &SqlTx[T]{
+func BeginTx[T any](canTx CanTx, ctx context.Context, opts ...*sql.TxOptions) *SqlTx[T] {
+       var tx = SqlTx[T]{
                canTx: canTx,
                ctx:   ctx,
-               opts:  opts,
        }
+       if len(opts) > 0 {
+               tx.opts = opts[0]
+       }
+       return &tx
 }
 
 func (t *SqlTx[T]) Do(sqlf SqlFunc[T]) *SqlTx[T] {
@@ -51,9 +59,45 @@ func (t *SqlTx[T]) Do(sqlf SqlFunc[T]) *SqlTx[T] {
        return t
 }
 
-func (t *SqlTx[T]) Fin() (e error) {
+func (t *SqlTx[T]) DoPlaceHolder(sqlf SqlFunc[T], ptr any) *SqlTx[T] {
+       dataR := reflect.ValueOf(ptr).Elem()
+       for i := 0; i < dataR.NumField(); i++ {
+               field := dataR.Field(i)
+               if field.IsValid() && field.CanSet() {
+                       replaceS := "{" + dataR.Type().Field(i).Name + "}"
+                       if strings.Contains(sqlf.Query, replaceS) {
+                               sqlf.Query = strings.ReplaceAll(sqlf.Query, replaceS, "?")
+                               sqlf.Args = append(sqlf.Args, field.Interface())
+                       }
+               }
+       }
+       return t.Do(sqlf)
+}
+
+func (t *SqlTx[T]) BeforeF(f BeforeF[T]) *SqlTx[T] {
+       if len(t.sqlFuncs) > 0 {
+               t.sqlFuncs[len(t.sqlFuncs)-1].beforeF = f
+       }
+       return t
+}
+
+func (t *SqlTx[T]) AfterEF(f AfterEF[T]) *SqlTx[T] {
+       if len(t.sqlFuncs) > 0 {
+               t.sqlFuncs[len(t.sqlFuncs)-1].afterEF = f
+       }
+       return t
+}
+
+func (t *SqlTx[T]) AfterQF(f AfterQF[T]) *SqlTx[T] {
+       if len(t.sqlFuncs) > 0 {
+               t.sqlFuncs[len(t.sqlFuncs)-1].afterQF = f
+       }
+       return t
+}
+
+func (t *SqlTx[T]) Fin() (dataP *T, e error) {
        if t.fin {
-               return fmt.Errorf("BeginTx; [] >> fin")
+               return nil, fmt.Errorf("BeginTx; [] >> fin")
        }
 
        tx, err := t.canTx.BeginTx(t.ctx, t.opts)
@@ -62,43 +106,50 @@ func (t *SqlTx[T]) Fin() (e error) {
        } else {
                for i := 0; i < len(t.sqlFuncs); i++ {
                        sqlf := t.sqlFuncs[i]
+
+                       if sqlf.beforeF != nil {
+                               if datap, err := sqlf.beforeF(t.dataP, sqlf, e); err != nil {
+                                       e = errors.Join(e, fmt.Errorf("%s; >> %s", sqlf.Query, err))
+                               } else {
+                                       t.dataP = datap
+                               }
+                       }
+
+                       if strings.TrimSpace(sqlf.Query) == "" {
+                               continue
+                       }
+
                        if sqlf.Ctx == nil {
                                sqlf.Ctx = t.ctx
                        }
+
+                       if sqlf.Ty == null {
+                               sqlf.Ty = Execf
+                               if uquery := strings.ToUpper(strings.TrimSpace(sqlf.Query)); strings.HasPrefix(uquery, "SELECT") {
+                                       sqlf.Ty = Queryf
+                               }
+                       }
+
                        switch sqlf.Ty {
                        case Execf:
-                               if sqlf.BeforeEF != nil {
-                                       if datap, err := sqlf.BeforeEF(t.dataP, sqlf, e); err != nil {
-                                               e = errors.Join(e, fmt.Errorf("%s >> %s", sqlf.Query, err))
-                                       } else {
-                                               t.dataP = datap
-                                       }
-                               }
                                if res, err := tx.ExecContext(sqlf.Ctx, sqlf.Query, sqlf.Args...); err != nil {
                                        if !sqlf.SkipSqlErr {
                                                e = errors.Join(e, fmt.Errorf("%s; %s >> %s", sqlf.Query, sqlf.Args, err))
                                        }
-                               } else if sqlf.AfterEF != nil {
-                                       if datap, err := sqlf.AfterEF(t.dataP, res, e); err != nil {
+                               } else if sqlf.afterEF != nil {
+                                       if datap, err := sqlf.afterEF(t.dataP, res, e); err != nil {
                                                e = errors.Join(e, fmt.Errorf("%s; %s >> %s", sqlf.Query, sqlf.Args, err))
                                        } else {
                                                t.dataP = datap
                                        }
                                }
                        case Queryf:
-                               if sqlf.BeforeQF != nil {
-                                       if datap, err := sqlf.BeforeQF(t.dataP, sqlf, e); err != nil {
-                                               e = errors.Join(e, fmt.Errorf("%s; %s >> %s", sqlf.Query, sqlf.Args, err))
-                                       } else {
-                                               t.dataP = datap
-                                       }
-                               }
                                if res, err := tx.QueryContext(sqlf.Ctx, sqlf.Query, sqlf.Args...); err != nil {
                                        if !sqlf.SkipSqlErr {
                                                e = errors.Join(e, fmt.Errorf("%s; %s >> %s", sqlf.Query, sqlf.Args, err))
                                        }
-                               } else if sqlf.AfterQF != nil {
-                                       if datap, err := sqlf.AfterQF(t.dataP, res, e); err != nil {
+                               } else if sqlf.afterQF != nil {
+                                       if datap, err := sqlf.afterQF(t.dataP, res, e); err != nil {
                                                e = errors.Join(e, fmt.Errorf("%s; %s >> %s", sqlf.Query, sqlf.Args, err))
                                        } else {
                                                t.dataP = datap
@@ -119,14 +170,14 @@ func (t *SqlTx[T]) Fin() (e error) {
                }
        }
        t.fin = true
-       return e
+       return t.dataP, e
 }
 
 func IsFin[T any](t *SqlTx[T]) bool {
        return t == nil || t.fin
 }
 
-func DealRows[T any](rows *sql.Rows, createF func() T) ([]T, error) {
+func DealRows[T any](rows *sql.Rows, createF func() T) (*[]T, error) {
        rowNames, err := rows.Columns()
        if err != nil {
                return nil, err
@@ -144,26 +195,38 @@ func DealRows[T any](rows *sql.Rows, createF func() T) ([]T, error) {
                        return nil, err
                }
 
-               var rowT = createF()
-               refV := reflect.ValueOf(&rowT)
+               var stu = createF()
                for i := 0; i < len(rowNames); i++ {
-                       v := refV.Elem().FieldByName(rowNames[i])
+                       v := reflect.ValueOf(&stu).Elem().FieldByName(rowNames[i])
                        if v.IsValid() {
+                               refT := reflect.TypeOf(&stu).Elem()
                                if v.CanSet() {
                                        val := reflect.ValueOf(*rowP[i].(*any))
-                                       if val.Kind() == v.Kind() {
+                                       if reflect.TypeOf(*rowP[i].(*any)).ConvertibleTo(v.Type()) {
                                                v.Set(val)
                                        } else {
-                                               return nil, fmt.Errorf("reflectFail:%s KindNotMatch:%v !> %v", rowNames[i], val.Kind(), v.Kind())
+                                               return nil, fmt.Errorf("DealRows:KindNotMatch:[sql] %v !> [%s.%s] %v", val.Kind(), refT.Name(), rowNames[i], v.Type())
                                        }
                                } else {
-                                       return nil, fmt.Errorf("reflectFail:%s CanSet:%v", rowNames[i], v.CanSet())
+                                       return nil, fmt.Errorf("DealRows:%s.%s CanSet:%v", refT.Name(), rowNames[i], v.CanSet())
                                }
                        }
                }
-               res = append(res, rowT)
+               res = append(res, stu)
 
        }
 
-       return res, nil
+       return &res, nil
+}
+
+func SimpleQ[T any](canTx CanTx, query string, ptr *T) (*[]T, error) {
+       tx := BeginTx[[]T](canTx, context.Background())
+       tx.DoPlaceHolder(SqlFunc[[]T]{Query: query}, ptr)
+       tx.AfterQF(func(_ *[]T, rows *sql.Rows, txE error) (dataPR *[]T, stopErr error) {
+               if txE != nil {
+                       return nil, txE
+               }
+               return DealRows(rows, func() T { return *ptr })
+       })
+       return tx.Fin()
 }
index c1706ad298135f061ec263369e4d97c42c0a1e52..8b96d9247cf4739eb0502e8505fcf2acefcd8ded 100644 (file)
@@ -55,57 +55,54 @@ func TestMain(t *testing.T) {
                Ty:    Queryf,
                Ctx:   ctx,
                Query: "select msg from log",
-               AfterQF: func(dataP *[]string, rows *sql.Rows, err error) (dataPR *[]string, stopErr error) {
-                       names := make([]string, 0)
-                       for rows.Next() {
-                               var name string
-                               if err := rows.Scan(&name); err != nil {
-                                       return nil, err
-                               }
-                               names = append(names, name)
+       }).AfterQF(func(dataP *[]string, rows *sql.Rows, err error) (dataPR *[]string, stopErr error) {
+               names := make([]string, 0)
+               for rows.Next() {
+                       var name string
+                       if err := rows.Scan(&name); err != nil {
+                               return nil, err
                        }
-                       rows.Close()
+                       names = append(names, name)
+               }
+               rows.Close()
 
-                       if len(names) != 1 || dateTime != names[0] {
-                               return nil, errors.New("no")
-                       }
+               if len(names) != 1 || dateTime != names[0] {
+                       return nil, errors.New("no")
+               }
 
-                       return &names, nil
-               },
+               return &names, nil
        })
        tx = tx.Do(SqlFunc[[]string]{
                Ty:  Execf,
                Ctx: ctx,
-               BeforeEF: func(dataP *[]string, sqlf *SqlFunc[[]string], txE error) (dataPR *[]string, stopErr error) {
-                       sqlf.Query = "insert into log2 values (?)"
-                       sqlf.Args = append(sqlf.Args, (*dataP)[0])
-                       return dataP, nil
-               },
+       }).BeforeF(func(dataP *[]string, sqlf *SqlFunc[[]string], txE error) (dataPR *[]string, stopErr error) {
+               sqlf.Query = "insert into log2 values (?)"
+               sqlf.Args = append(sqlf.Args, (*dataP)[0])
+               return dataP, nil
        })
        tx = tx.Do(SqlFunc[[]string]{
                Ty:    Queryf,
                Ctx:   ctx,
                Query: "select msg from log2",
-               AfterQF: func(dataP *[]string, rows *sql.Rows, err error) (dataPR *[]string, stopErr error) {
-                       names := make([]string, 0)
-                       for rows.Next() {
-                               var name string
-                               if err := rows.Scan(&name); err != nil {
-                                       return nil, err
-                               }
-                               names = append(names, name)
+       }).AfterQF(func(dataP *[]string, rows *sql.Rows, err error) (dataPR *[]string, stopErr error) {
+               names := make([]string, 0)
+               for rows.Next() {
+                       var name string
+                       if err := rows.Scan(&name); err != nil {
+                               return nil, err
                        }
-                       rows.Close()
+                       names = append(names, name)
+               }
+               rows.Close()
 
-                       if len(names) != 1 || dateTime != names[0] {
-                               return nil, errors.New("no2")
-                       }
+               if len(names) != 1 || dateTime != names[0] {
+                       return nil, errors.New("no2")
+               }
 
-                       return &names, nil
-               },
+               return &names, nil
        })
 
-       if e := tx.Fin(); e != nil {
+       if _, e := tx.Fin(); e != nil {
                t.Fatal(e)
        }
 }
@@ -120,7 +117,7 @@ func TestMain2(t *testing.T) {
        defer db.Close()
 
        conn, _ := db.Conn(context.Background())
-       if e := BeginTx[any](conn, context.Background(), &sql.TxOptions{}).Do(SqlFunc[any]{
+       if _, e := BeginTx[any](conn, context.Background(), &sql.TxOptions{}).Do(SqlFunc[any]{
                Ty:    Execf,
                Query: "create table log123 (msg text)",
        }).Fin(); e != nil {
@@ -140,7 +137,7 @@ func TestMain2(t *testing.T) {
                                Query: "insert into log123 values (?)",
                                Args:  []any{"1"},
                        })
-                       if e := x.Fin(); e != nil {
+                       if _, e := x.Fin(); e != nil {
                                res <- e.Error()
                        }
                        wg.Done()
@@ -162,56 +159,57 @@ func TestMain3(t *testing.T) {
        defer db.Close()
        defer file.New("test.sqlite3", 0, true).Delete()
 
-       conn, _ := db.Conn(context.Background())
-       if e := BeginTx[any](conn, context.Background(), &sql.TxOptions{}).Do(SqlFunc[any]{
-               Ty:    Execf,
-               Query: "create table log123 (msg text,msg2 text)",
-       }).Fin(); e != nil {
-               t.Fatal(e)
+       {
+               tx := BeginTx[any](db, context.Background())
+               tx.Do(SqlFunc[any]{Query: "create table log123 (msg INT,msg2 text)"})
+               if _, e := tx.Fin(); e != nil {
+                       t.Fatal(e)
+               }
        }
-       conn.Close()
-
-       tx1 := BeginTx[any](db, context.Background(), &sql.TxOptions{}).Do(SqlFunc[any]{
-               Ty:    Execf,
-               Query: "insert into log123 values ('1','a')",
-       })
 
-       tx2 := BeginTx[any](db, context.Background(), &sql.TxOptions{}).Do(SqlFunc[any]{
-               Ty:    Execf,
-               Query: "insert into log123 values ('2','b')",
-       })
+       type logg struct {
+               Msg  int64
+               Msg2 string
+       }
 
-       if e := tx1.Fin(); e != nil {
-               t.Log(e)
+       insertLog123 := SqlFunc[any]{Query: "insert into log123 values ({Msg},{Msg2})"}
+       {
+               tx := BeginTx[any](db, context.Background())
+               tx.DoPlaceHolder(insertLog123, &logg{Msg: 1, Msg2: "a"})
+               tx.DoPlaceHolder(insertLog123, &logg{Msg: 2, Msg2: "b"})
+               if _, e := tx.Fin(); e != nil {
+                       t.Log(e)
+               }
+               if _, err := SimpleQ(db, "insert into log123 values ({Msg},{Msg2})", &logg{Msg: 3, Msg2: "b"}); err != nil {
+                       t.Fatal(err)
+               }
        }
-       if e := tx2.Fin(); e != nil {
-               t.Log(e)
+       {
+               selectLog123 := SqlFunc[[]logg]{Query: "select msg as Msg, msg2 as Msg2 from log123 where msg = {Msg}"}
+               tx := BeginTx[[]logg](db, context.Background())
+               tx.DoPlaceHolder(selectLog123, &logg{Msg: 2, Msg2: "b"})
+               tx.AfterQF(func(_ *[]logg, rows *sql.Rows, txE error) (dataPR *[]logg, stopErr error) {
+                       return DealRows(rows, func() logg { return logg{} })
+               })
+               if v, e := tx.Fin(); e != nil {
+                       t.Fatal(e)
+               } else {
+                       if (*v)[0].Msg2 != "b" || (*v)[0].Msg != 2 {
+                               t.Fatal()
+                       }
+               }
        }
-
-       tx1 = BeginTx[any](db, context.Background(), &sql.TxOptions{}).Do(SqlFunc[any]{
-               Ty:    Queryf,
-               Query: "select 1 as Msg, msg2 as Msg2 from log123",
-               AfterQF: func(_ *any, rows *sql.Rows, txE error) (_ *any, stopErr error) {
-                       type logg struct {
-                               Msg  int64
-                               Msg2 string
+       {
+               if v, err := SimpleQ(db, "select msg as Msg, msg2 as Msg2 from log123 where msg2 = {Msg2}", &logg{Msg2: "b"}); err != nil {
+                       t.Fatal(err)
+               } else {
+                       if (*v)[0].Msg2 != "b" || (*v)[0].Msg != 2 {
+                               t.Fatal()
                        }
-
-                       if v, err := DealRows(rows, func() logg { return logg{} }); err != nil {
-                               return nil, err
-                       } else {
-                               if v[0].Msg2 != "a" {
-                                       t.Fatal()
-                               }
-                               if v[1].Msg2 != "b" {
-                                       t.Fatal()
-                               }
+                       if (*v)[1].Msg2 != "b" || (*v)[1].Msg != 3 {
+                               t.Fatal()
                        }
-                       return
-               },
-       })
-       if e := tx1.Fin(); e != nil {
-               t.Fatal(e)
+               }
        }
 }
 
@@ -225,7 +223,7 @@ func TestMain4(t *testing.T) {
        defer file.New("test.sqlite3", 0, true).Delete()
 
        conn, _ := db.Conn(context.Background())
-       if e := BeginTx[any](conn, context.Background(), &sql.TxOptions{}).Do(SqlFunc[any]{
+       if _, e := BeginTx[any](conn, context.Background(), &sql.TxOptions{}).Do(SqlFunc[any]{
                Ty:    Execf,
                Query: "create table log123 (msg text)",
        }).Fin(); e != nil {
@@ -238,7 +236,7 @@ func TestMain4(t *testing.T) {
                Query: "insert into log123 values ('1')",
        })
 
-       if e := tx1.Fin(); e != nil {
+       if _, e := tx1.Fin(); e != nil {
                t.Log(e)
        }