]> 127.0.0.1 Git - part/.git/commitdiff
Fix
authorqydysky <qydysky@foxmail.com>
Tue, 25 Apr 2023 15:52:48 +0000 (23:52 +0800)
committerqydysky <qydysky@foxmail.com>
Tue, 25 Apr 2023 15:52:48 +0000 (23:52 +0800)
sqlite/Sqlite.go
sqlite/Sqlite_test.go

index 54761148f3db755334ea9fa1aff0b88b0eebdd02..fa0ffdefcfe770ce883f9a85c6daea563de08a6b 100644 (file)
@@ -17,9 +17,11 @@ type CanTx interface {
 }
 
 type SqlTx[T any] struct {
-       tx    *sql.Tx
-       dataP *T
-       err   error
+       canTx    CanTx
+       ctx      context.Context
+       opts     *sql.TxOptions
+       sqlFuncs []*SqlFunc[T]
+       dataP    *T
 }
 
 type SqlFunc[T any] struct {
@@ -35,80 +37,80 @@ type SqlFunc[T any] struct {
 }
 
 func BeginTx[T any](canTx CanTx, ctx context.Context, opts *sql.TxOptions) *SqlTx[T] {
-       var sqlTX = SqlTx[T]{}
-
-       if tx, err := canTx.BeginTx(ctx, opts); err != nil {
-               sqlTX.err = fmt.Errorf("BeginTx; [] >> %s", err)
-       } else {
-               sqlTX.tx = tx
+       return &SqlTx[T]{
+               canTx: canTx,
+               ctx:   ctx,
+               opts:  opts,
        }
-
-       return &sqlTX
 }
 
 func (t *SqlTx[T]) Do(sqlf SqlFunc[T]) *SqlTx[T] {
-       if t.err != nil {
-               return t
-       }
-
-       if sqlf.Ctx == nil {
-               sqlf.Ctx = context.Background()
-       }
+       t.sqlFuncs = append(t.sqlFuncs, &sqlf)
+       return t
+}
 
-       switch sqlf.Ty {
-       case Execf:
-               if sqlf.BeforeEF != nil {
-                       if datap, err := sqlf.BeforeEF(t.dataP, &sqlf, t.err); err != nil {
-                               t.err = errors.Join(t.err, fmt.Errorf("%s >> %s", sqlf.Query, err))
-                       } else {
-                               t.dataP = datap
-                       }
-               }
-               if res, err := t.tx.ExecContext(sqlf.Ctx, sqlf.Query, sqlf.Args...); err != nil {
-                       if !sqlf.SkipSqlErr {
-                               t.err = errors.Join(t.err, fmt.Errorf("%s; %s >> %s", sqlf.Query, sqlf.Args, err))
-                       }
-               } else if sqlf.AfterEF != nil {
-                       if datap, err := sqlf.AfterEF(t.dataP, res, t.err); err != nil {
-                               t.err = errors.Join(t.err, fmt.Errorf("%s; %s >> %s", sqlf.Query, sqlf.Args, err))
-                       } else {
-                               t.dataP = datap
-                       }
-               }
-       case Queryf:
-               if sqlf.BeforeQF != nil {
-                       if datap, err := sqlf.BeforeQF(t.dataP, &sqlf, t.err); err != nil {
-                               t.err = errors.Join(t.err, fmt.Errorf("%s; %s >> %s", sqlf.Query, sqlf.Args, err))
-                       } else {
-                               t.dataP = datap
-                       }
-               }
-               if res, err := t.tx.QueryContext(sqlf.Ctx, sqlf.Query, sqlf.Args...); err != nil {
-                       if !sqlf.SkipSqlErr {
-                               t.err = errors.Join(t.err, fmt.Errorf("%s; %s >> %s", sqlf.Query, sqlf.Args, err))
+func (t *SqlTx[T]) Fin() (e error) {
+       tx, err := t.canTx.BeginTx(t.ctx, t.opts)
+       if err != nil {
+               e = fmt.Errorf("BeginTx; [] >> %s", err)
+       } else {
+               for i := 0; i < len(t.sqlFuncs); i++ {
+                       sqlf := t.sqlFuncs[i]
+                       if sqlf.Ctx == nil {
+                               sqlf.Ctx = t.ctx
                        }
-               } else if sqlf.AfterQF != nil {
-                       if datap, err := sqlf.AfterQF(t.dataP, res, t.err); err != nil {
-                               t.err = errors.Join(t.err, fmt.Errorf("%s; %s >> %s", sqlf.Query, sqlf.Args, err))
-                       } else {
-                               t.dataP = datap
+                       switch sqlf.Ty {
+                       case Execf:
+                               if sqlf.BeforeEF != nil {
+                                       if datap, err := sqlf.BeforeEF(t.dataP, sqlf, e); err != nil {
+                                               e = errors.Join(e, fmt.Errorf("%s >> %s", sqlf.Query, err))
+                                       } else {
+                                               t.dataP = datap
+                                       }
+                               }
+                               if res, err := tx.ExecContext(sqlf.Ctx, sqlf.Query, sqlf.Args...); err != nil {
+                                       if !sqlf.SkipSqlErr {
+                                               e = errors.Join(e, fmt.Errorf("%s; %s >> %s", sqlf.Query, sqlf.Args, err))
+                                       }
+                               } else if sqlf.AfterEF != nil {
+                                       if datap, err := sqlf.AfterEF(t.dataP, res, e); err != nil {
+                                               e = errors.Join(e, fmt.Errorf("%s; %s >> %s", sqlf.Query, sqlf.Args, err))
+                                       } else {
+                                               t.dataP = datap
+                                       }
+                               }
+                       case Queryf:
+                               if sqlf.BeforeQF != nil {
+                                       if datap, err := sqlf.BeforeQF(t.dataP, sqlf, e); err != nil {
+                                               e = errors.Join(e, fmt.Errorf("%s; %s >> %s", sqlf.Query, sqlf.Args, err))
+                                       } else {
+                                               t.dataP = datap
+                                       }
+                               }
+                               if res, err := tx.QueryContext(sqlf.Ctx, sqlf.Query, sqlf.Args...); err != nil {
+                                       if !sqlf.SkipSqlErr {
+                                               e = errors.Join(e, fmt.Errorf("%s; %s >> %s", sqlf.Query, sqlf.Args, err))
+                                       }
+                               } else if sqlf.AfterQF != nil {
+                                       if datap, err := sqlf.AfterQF(t.dataP, res, e); err != nil {
+                                               e = errors.Join(e, fmt.Errorf("%s; %s >> %s", sqlf.Query, sqlf.Args, err))
+                                       } else {
+                                               t.dataP = datap
+                                       }
+                               }
                        }
                }
        }
-       return t
-}
-
-func (t *SqlTx[T]) Fin() error {
-       if t.err != nil {
-               if t.tx != nil {
-                       if err := t.tx.Rollback(); err != nil {
-                               t.err = errors.Join(t.err, fmt.Errorf("Rollback; [] >> %s", err))
+       if e != nil {
+               if tx != nil {
+                       if err := tx.Rollback(); err != nil {
+                               e = errors.Join(e, fmt.Errorf("Rollback; [] >> %s", err))
                        }
                }
        } else {
-               if err := t.tx.Commit(); err != nil {
-                       t.err = errors.Join(t.err, fmt.Errorf("Commit; [] >> %s", err))
+               if err := tx.Commit(); err != nil {
+                       e = errors.Join(e, fmt.Errorf("Commit; [] >> %s", err))
                }
        }
-       return t.err
+       return e
 }
index d4fea56265acdad80085175884ae034e3d1417c6..3fabf5474b5f7ac1a7be6b973f93db0fce3a1c57 100644 (file)
@@ -8,12 +8,13 @@ import (
        "testing"
        "time"
 
-       _ "github.com/mattn/go-sqlite3"
+       file "github.com/qydysky/part/file"
+       _ "modernc.org/sqlite"
 )
 
 func TestMain(t *testing.T) {
        // connect
-       db, err := sql.Open("sqlite3", ":memory:")
+       db, err := sql.Open("sqlite", ":memory:")
        if err != nil {
                t.Fatal(err)
        }
@@ -111,7 +112,7 @@ func TestMain(t *testing.T) {
 
 func TestMain2(t *testing.T) {
        // connect
-       db, err := sql.Open("sqlite3", ":memory:")
+       db, err := sql.Open("sqlite", ":memory:")
        if err != nil {
                t.Fatal(err)
        }
@@ -151,3 +152,59 @@ func TestMain2(t *testing.T) {
                t.Fatal(<-res)
        }
 }
+
+func TestMain3(t *testing.T) {
+       // connect
+       db, err := sql.Open("sqlite", "test.sqlite3")
+       if err != nil {
+               t.Fatal(err)
+       }
+       defer db.Close()
+       defer file.New("test.sqlite3", 0, true).Delete()
+
+       conn, _ := db.Conn(context.Background())
+       if e := BeginTx[any](conn, context.Background(), &sql.TxOptions{}).Do(SqlFunc[any]{
+               Ty:    Execf,
+               Query: "create table log123 (msg text)",
+       }).Fin(); e != nil {
+               t.Fatal(e)
+       }
+       conn.Close()
+
+       tx1 := BeginTx[any](db, context.Background(), &sql.TxOptions{}).Do(SqlFunc[any]{
+               Ty:    Execf,
+               Query: "insert into log123 values ('1')",
+       })
+
+       tx2 := BeginTx[any](db, context.Background(), &sql.TxOptions{}).Do(SqlFunc[any]{
+               Ty:    Execf,
+               Query: "insert into log123 values ('2')",
+       })
+
+       if e := tx1.Fin(); e != nil {
+               t.Log(e)
+       }
+       if e := tx2.Fin(); e != nil {
+               t.Log(e)
+       }
+
+       tx1 = BeginTx[any](db, context.Background(), &sql.TxOptions{}).Do(SqlFunc[any]{
+               Ty:    Queryf,
+               Query: "select count(1) as c from log123",
+               AfterQF: func(_ *any, rows *sql.Rows, txE error) (_ *any, stopErr error) {
+                       for rows.Next() {
+                               var row int64
+                               stopErr = rows.Scan(&row)
+                               if row != 2 {
+                                       t.Fatal()
+                               }
+                       }
+                       return
+               },
+       })
+       if e := tx1.Fin(); e != nil {
+               t.Fatal(e)
+       }
+
+       time.Sleep(time.Second)
+}