From: qydysky Date: Tue, 25 Apr 2023 12:48:56 +0000 (+0800) Subject: Fix X-Git-Tag: v0.24.17~1 X-Git-Url: http://127.0.0.1:8081/?a=commitdiff_plain;h=71e3b6caf04f78c8062d7829cf997a4030133949;p=part%2F.git Fix --- diff --git a/sqlite/Sqlite.go b/sqlite/Sqlite.go index 28b6b35..baf0ec4 100644 --- a/sqlite/Sqlite.go +++ b/sqlite/Sqlite.go @@ -5,8 +5,6 @@ import ( "database/sql" "errors" "fmt" - - _ "github.com/mattn/go-sqlite3" ) const ( @@ -14,6 +12,10 @@ const ( Queryf ) +type CanTx interface { + BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) +} + type SqlTx[T any] struct { tx *sql.Tx dataP *T @@ -32,14 +34,15 @@ type SqlFunc[T any] struct { AfterQF func(dataP *T, rows *sql.Rows, txE error) (dataPR *T, stopErr error) } -func BeginTx[T any](db *sql.DB, ctx context.Context, opts *sql.TxOptions) *SqlTx[T] { +func BeginTx[T any](canTx CanTx, ctx context.Context, opts *sql.TxOptions) *SqlTx[T] { var sqlTX = SqlTx[T]{} - if tx, e := db.BeginTx(ctx, opts); e != nil { + if tx, e := canTx.BeginTx(ctx, opts); e != nil { sqlTX.err = e } else { sqlTX.tx = tx } + return &sqlTX } @@ -48,6 +51,10 @@ func (t *SqlTx[T]) Do(sqlf SqlFunc[T]) *SqlTx[T] { return t } + if sqlf.Ctx == nil { + sqlf.Ctx = context.Background() + } + switch sqlf.Ty { case Execf: if sqlf.BeforeEF != nil { @@ -59,11 +66,11 @@ func (t *SqlTx[T]) Do(sqlf SqlFunc[T]) *SqlTx[T] { } if res, err := t.tx.ExecContext(sqlf.Ctx, sqlf.Query, sqlf.Args...); err != nil { if !sqlf.SkipSqlErr { - t.err = errors.Join(t.err, fmt.Errorf("%s >> %s", sqlf.Query, err)) + t.err = errors.Join(t.err, fmt.Errorf("%s; %s >> %s", sqlf.Query, sqlf.Args, err)) } } else if sqlf.AfterEF != nil { if datap, err := sqlf.AfterEF(t.dataP, res, t.err); err != nil { - t.err = errors.Join(t.err, fmt.Errorf("%s >> %s", sqlf.Query, err)) + t.err = errors.Join(t.err, fmt.Errorf("%s; %s >> %s", sqlf.Query, sqlf.Args, err)) } else { t.dataP = datap } @@ -71,18 +78,18 @@ func (t *SqlTx[T]) Do(sqlf SqlFunc[T]) *SqlTx[T] { case Queryf: if sqlf.BeforeQF != nil { if datap, err := sqlf.BeforeQF(t.dataP, &sqlf, t.err); err != nil { - t.err = errors.Join(t.err, fmt.Errorf("%s >> %s", sqlf.Query, err)) + t.err = errors.Join(t.err, fmt.Errorf("%s; %s >> %s", sqlf.Query, sqlf.Args, err)) } else { t.dataP = datap } } if res, err := t.tx.QueryContext(sqlf.Ctx, sqlf.Query, sqlf.Args...); err != nil { if !sqlf.SkipSqlErr { - t.err = errors.Join(t.err, fmt.Errorf("%s >> %s", sqlf.Query, err)) + t.err = errors.Join(t.err, fmt.Errorf("%s; %s >> %s", sqlf.Query, sqlf.Args, err)) } } else if sqlf.AfterQF != nil { if datap, err := sqlf.AfterQF(t.dataP, res, t.err); err != nil { - t.err = errors.Join(t.err, fmt.Errorf("%s >> %s", sqlf.Query, err)) + t.err = errors.Join(t.err, fmt.Errorf("%s; %s >> %s", sqlf.Query, sqlf.Args, err)) } else { t.dataP = datap } diff --git a/sqlite/Sqlite_test.go b/sqlite/Sqlite_test.go index 26c6da7..d4fea56 100644 --- a/sqlite/Sqlite_test.go +++ b/sqlite/Sqlite_test.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "errors" + "sync" "testing" "time" @@ -107,3 +108,46 @@ func TestMain(t *testing.T) { t.Fatal(e) } } + +func TestMain2(t *testing.T) { + // connect + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + t.Fatal(err) + } + db.SetMaxOpenConns(1) + defer db.Close() + + conn, _ := db.Conn(context.Background()) + if e := BeginTx[any](conn, context.Background(), &sql.TxOptions{}).Do(SqlFunc[any]{ + Ty: Execf, + Query: "create table log123 (msg text)", + }).Fin(); e != nil { + t.Fatal(e) + } + conn.Close() + + var res = make(chan string, 101) + var wg sync.WaitGroup + wg.Add(100) + + for i := 0; i < 100; i++ { + go func() { + x := BeginTx[any](db, context.Background(), &sql.TxOptions{}) + x.Do(SqlFunc[any]{ + Ty: Execf, + Query: "insert into log123 values (?)", + Args: []any{"1"}, + }) + if e := x.Fin(); e != nil { + res <- e.Error() + } + wg.Done() + }() + } + + wg.Wait() + for len(res) > 0 { + t.Fatal(<-res) + } +}