]> 127.0.0.1 Git - part/.git/commitdiff
Fix
authorqydysky <qydysky@foxmail.com>
Tue, 25 Apr 2023 12:48:56 +0000 (20:48 +0800)
committerqydysky <qydysky@foxmail.com>
Tue, 25 Apr 2023 12:48:56 +0000 (20:48 +0800)
sqlite/Sqlite.go
sqlite/Sqlite_test.go

index 28b6b356f203422cea0bb92bd9ee3f9880718d68..baf0ec4ef0a623356c7b575edcfd228c698d159d 100644 (file)
@@ -5,8 +5,6 @@ import (
        "database/sql"
        "errors"
        "fmt"
-
-       _ "github.com/mattn/go-sqlite3"
 )
 
 const (
@@ -14,6 +12,10 @@ const (
        Queryf
 )
 
+type CanTx interface {
+       BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
+}
+
 type SqlTx[T any] struct {
        tx    *sql.Tx
        dataP *T
@@ -32,14 +34,15 @@ type SqlFunc[T any] struct {
        AfterQF    func(dataP *T, rows *sql.Rows, txE error) (dataPR *T, stopErr error)
 }
 
-func BeginTx[T any](db *sql.DB, ctx context.Context, opts *sql.TxOptions) *SqlTx[T] {
+func BeginTx[T any](canTx CanTx, ctx context.Context, opts *sql.TxOptions) *SqlTx[T] {
        var sqlTX = SqlTx[T]{}
 
-       if tx, e := db.BeginTx(ctx, opts); e != nil {
+       if tx, e := canTx.BeginTx(ctx, opts); e != nil {
                sqlTX.err = e
        } else {
                sqlTX.tx = tx
        }
+
        return &sqlTX
 }
 
@@ -48,6 +51,10 @@ func (t *SqlTx[T]) Do(sqlf SqlFunc[T]) *SqlTx[T] {
                return t
        }
 
+       if sqlf.Ctx == nil {
+               sqlf.Ctx = context.Background()
+       }
+
        switch sqlf.Ty {
        case Execf:
                if sqlf.BeforeEF != nil {
@@ -59,11 +66,11 @@ func (t *SqlTx[T]) Do(sqlf SqlFunc[T]) *SqlTx[T] {
                }
                if res, err := t.tx.ExecContext(sqlf.Ctx, sqlf.Query, sqlf.Args...); err != nil {
                        if !sqlf.SkipSqlErr {
-                               t.err = errors.Join(t.err, fmt.Errorf("%s >> %s", sqlf.Query, err))
+                               t.err = errors.Join(t.err, fmt.Errorf("%s; %s >> %s", sqlf.Query, sqlf.Args, err))
                        }
                } else if sqlf.AfterEF != nil {
                        if datap, err := sqlf.AfterEF(t.dataP, res, t.err); err != nil {
-                               t.err = errors.Join(t.err, fmt.Errorf("%s >> %s", sqlf.Query, err))
+                               t.err = errors.Join(t.err, fmt.Errorf("%s; %s >> %s", sqlf.Query, sqlf.Args, err))
                        } else {
                                t.dataP = datap
                        }
@@ -71,18 +78,18 @@ func (t *SqlTx[T]) Do(sqlf SqlFunc[T]) *SqlTx[T] {
        case Queryf:
                if sqlf.BeforeQF != nil {
                        if datap, err := sqlf.BeforeQF(t.dataP, &sqlf, t.err); err != nil {
-                               t.err = errors.Join(t.err, fmt.Errorf("%s >> %s", sqlf.Query, err))
+                               t.err = errors.Join(t.err, fmt.Errorf("%s; %s >> %s", sqlf.Query, sqlf.Args, err))
                        } else {
                                t.dataP = datap
                        }
                }
                if res, err := t.tx.QueryContext(sqlf.Ctx, sqlf.Query, sqlf.Args...); err != nil {
                        if !sqlf.SkipSqlErr {
-                               t.err = errors.Join(t.err, fmt.Errorf("%s >> %s", sqlf.Query, err))
+                               t.err = errors.Join(t.err, fmt.Errorf("%s; %s >> %s", sqlf.Query, sqlf.Args, err))
                        }
                } else if sqlf.AfterQF != nil {
                        if datap, err := sqlf.AfterQF(t.dataP, res, t.err); err != nil {
-                               t.err = errors.Join(t.err, fmt.Errorf("%s >> %s", sqlf.Query, err))
+                               t.err = errors.Join(t.err, fmt.Errorf("%s; %s >> %s", sqlf.Query, sqlf.Args, err))
                        } else {
                                t.dataP = datap
                        }
index 26c6da7ba22f9b2707b1e65e3ba4aa9a586d7b49..d4fea56265acdad80085175884ae034e3d1417c6 100644 (file)
@@ -4,6 +4,7 @@ import (
        "context"
        "database/sql"
        "errors"
+       "sync"
        "testing"
        "time"
 
@@ -107,3 +108,46 @@ func TestMain(t *testing.T) {
                t.Fatal(e)
        }
 }
+
+func TestMain2(t *testing.T) {
+       // connect
+       db, err := sql.Open("sqlite3", ":memory:")
+       if err != nil {
+               t.Fatal(err)
+       }
+       db.SetMaxOpenConns(1)
+       defer db.Close()
+
+       conn, _ := db.Conn(context.Background())
+       if e := BeginTx[any](conn, context.Background(), &sql.TxOptions{}).Do(SqlFunc[any]{
+               Ty:    Execf,
+               Query: "create table log123 (msg text)",
+       }).Fin(); e != nil {
+               t.Fatal(e)
+       }
+       conn.Close()
+
+       var res = make(chan string, 101)
+       var wg sync.WaitGroup
+       wg.Add(100)
+
+       for i := 0; i < 100; i++ {
+               go func() {
+                       x := BeginTx[any](db, context.Background(), &sql.TxOptions{})
+                       x.Do(SqlFunc[any]{
+                               Ty:    Execf,
+                               Query: "insert into log123 values (?)",
+                               Args:  []any{"1"},
+                       })
+                       if e := x.Fin(); e != nil {
+                               res <- e.Error()
+                       }
+                       wg.Done()
+               }()
+       }
+
+       wg.Wait()
+       for len(res) > 0 {
+               t.Fatal(<-res)
+       }
+}