From 86b19965ec5c6f474c66c3633414ebc2c7375ae4 Mon Sep 17 00:00:00 2001 From: qydysky Date: Tue, 25 Apr 2023 23:52:48 +0800 Subject: [PATCH] Fix --- sqlite/Sqlite.go | 132 +++++++++++++++++++++--------------------- sqlite/Sqlite_test.go | 63 +++++++++++++++++++- 2 files changed, 127 insertions(+), 68 deletions(-) diff --git a/sqlite/Sqlite.go b/sqlite/Sqlite.go index 5476114..fa0ffde 100644 --- a/sqlite/Sqlite.go +++ b/sqlite/Sqlite.go @@ -17,9 +17,11 @@ type CanTx interface { } type SqlTx[T any] struct { - tx *sql.Tx - dataP *T - err error + canTx CanTx + ctx context.Context + opts *sql.TxOptions + sqlFuncs []*SqlFunc[T] + dataP *T } type SqlFunc[T any] struct { @@ -35,80 +37,80 @@ type SqlFunc[T any] struct { } func BeginTx[T any](canTx CanTx, ctx context.Context, opts *sql.TxOptions) *SqlTx[T] { - var sqlTX = SqlTx[T]{} - - if tx, err := canTx.BeginTx(ctx, opts); err != nil { - sqlTX.err = fmt.Errorf("BeginTx; [] >> %s", err) - } else { - sqlTX.tx = tx + return &SqlTx[T]{ + canTx: canTx, + ctx: ctx, + opts: opts, } - - return &sqlTX } func (t *SqlTx[T]) Do(sqlf SqlFunc[T]) *SqlTx[T] { - if t.err != nil { - return t - } - - if sqlf.Ctx == nil { - sqlf.Ctx = context.Background() - } + t.sqlFuncs = append(t.sqlFuncs, &sqlf) + return t +} - switch sqlf.Ty { - case Execf: - if sqlf.BeforeEF != nil { - if datap, err := sqlf.BeforeEF(t.dataP, &sqlf, t.err); err != nil { - t.err = errors.Join(t.err, fmt.Errorf("%s >> %s", sqlf.Query, err)) - } else { - t.dataP = datap - } - } - if res, err := t.tx.ExecContext(sqlf.Ctx, sqlf.Query, sqlf.Args...); err != nil { - if !sqlf.SkipSqlErr { - t.err = errors.Join(t.err, fmt.Errorf("%s; %s >> %s", sqlf.Query, sqlf.Args, err)) - } - } else if sqlf.AfterEF != nil { - if datap, err := sqlf.AfterEF(t.dataP, res, t.err); err != nil { - t.err = errors.Join(t.err, fmt.Errorf("%s; %s >> %s", sqlf.Query, sqlf.Args, err)) - } else { - t.dataP = datap - } - } - case Queryf: - if sqlf.BeforeQF != nil { - if datap, err := sqlf.BeforeQF(t.dataP, &sqlf, t.err); err != nil { - t.err = errors.Join(t.err, fmt.Errorf("%s; %s >> %s", sqlf.Query, sqlf.Args, err)) - } else { - t.dataP = datap - } - } - if res, err := t.tx.QueryContext(sqlf.Ctx, sqlf.Query, sqlf.Args...); err != nil { - if !sqlf.SkipSqlErr { - t.err = errors.Join(t.err, fmt.Errorf("%s; %s >> %s", sqlf.Query, sqlf.Args, err)) +func (t *SqlTx[T]) Fin() (e error) { + tx, err := t.canTx.BeginTx(t.ctx, t.opts) + if err != nil { + e = fmt.Errorf("BeginTx; [] >> %s", err) + } else { + for i := 0; i < len(t.sqlFuncs); i++ { + sqlf := t.sqlFuncs[i] + if sqlf.Ctx == nil { + sqlf.Ctx = t.ctx } - } else if sqlf.AfterQF != nil { - if datap, err := sqlf.AfterQF(t.dataP, res, t.err); err != nil { - t.err = errors.Join(t.err, fmt.Errorf("%s; %s >> %s", sqlf.Query, sqlf.Args, err)) - } else { - t.dataP = datap + switch sqlf.Ty { + case Execf: + if sqlf.BeforeEF != nil { + if datap, err := sqlf.BeforeEF(t.dataP, sqlf, e); err != nil { + e = errors.Join(e, fmt.Errorf("%s >> %s", sqlf.Query, err)) + } else { + t.dataP = datap + } + } + if res, err := tx.ExecContext(sqlf.Ctx, sqlf.Query, sqlf.Args...); err != nil { + if !sqlf.SkipSqlErr { + e = errors.Join(e, fmt.Errorf("%s; %s >> %s", sqlf.Query, sqlf.Args, err)) + } + } else if sqlf.AfterEF != nil { + if datap, err := sqlf.AfterEF(t.dataP, res, e); err != nil { + e = errors.Join(e, fmt.Errorf("%s; %s >> %s", sqlf.Query, sqlf.Args, err)) + } else { + t.dataP = datap + } + } + case Queryf: + if sqlf.BeforeQF != nil { + if datap, err := sqlf.BeforeQF(t.dataP, sqlf, e); err != nil { + e = errors.Join(e, fmt.Errorf("%s; %s >> %s", sqlf.Query, sqlf.Args, err)) + } else { + t.dataP = datap + } + } + if res, err := tx.QueryContext(sqlf.Ctx, sqlf.Query, sqlf.Args...); err != nil { + if !sqlf.SkipSqlErr { + e = errors.Join(e, fmt.Errorf("%s; %s >> %s", sqlf.Query, sqlf.Args, err)) + } + } else if sqlf.AfterQF != nil { + if datap, err := sqlf.AfterQF(t.dataP, res, e); err != nil { + e = errors.Join(e, fmt.Errorf("%s; %s >> %s", sqlf.Query, sqlf.Args, err)) + } else { + t.dataP = datap + } + } } } } - return t -} - -func (t *SqlTx[T]) Fin() error { - if t.err != nil { - if t.tx != nil { - if err := t.tx.Rollback(); err != nil { - t.err = errors.Join(t.err, fmt.Errorf("Rollback; [] >> %s", err)) + if e != nil { + if tx != nil { + if err := tx.Rollback(); err != nil { + e = errors.Join(e, fmt.Errorf("Rollback; [] >> %s", err)) } } } else { - if err := t.tx.Commit(); err != nil { - t.err = errors.Join(t.err, fmt.Errorf("Commit; [] >> %s", err)) + if err := tx.Commit(); err != nil { + e = errors.Join(e, fmt.Errorf("Commit; [] >> %s", err)) } } - return t.err + return e } diff --git a/sqlite/Sqlite_test.go b/sqlite/Sqlite_test.go index d4fea56..3fabf54 100644 --- a/sqlite/Sqlite_test.go +++ b/sqlite/Sqlite_test.go @@ -8,12 +8,13 @@ import ( "testing" "time" - _ "github.com/mattn/go-sqlite3" + file "github.com/qydysky/part/file" + _ "modernc.org/sqlite" ) func TestMain(t *testing.T) { // connect - db, err := sql.Open("sqlite3", ":memory:") + db, err := sql.Open("sqlite", ":memory:") if err != nil { t.Fatal(err) } @@ -111,7 +112,7 @@ func TestMain(t *testing.T) { func TestMain2(t *testing.T) { // connect - db, err := sql.Open("sqlite3", ":memory:") + db, err := sql.Open("sqlite", ":memory:") if err != nil { t.Fatal(err) } @@ -151,3 +152,59 @@ func TestMain2(t *testing.T) { t.Fatal(<-res) } } + +func TestMain3(t *testing.T) { + // connect + db, err := sql.Open("sqlite", "test.sqlite3") + if err != nil { + t.Fatal(err) + } + defer db.Close() + defer file.New("test.sqlite3", 0, true).Delete() + + conn, _ := db.Conn(context.Background()) + if e := BeginTx[any](conn, context.Background(), &sql.TxOptions{}).Do(SqlFunc[any]{ + Ty: Execf, + Query: "create table log123 (msg text)", + }).Fin(); e != nil { + t.Fatal(e) + } + conn.Close() + + tx1 := BeginTx[any](db, context.Background(), &sql.TxOptions{}).Do(SqlFunc[any]{ + Ty: Execf, + Query: "insert into log123 values ('1')", + }) + + tx2 := BeginTx[any](db, context.Background(), &sql.TxOptions{}).Do(SqlFunc[any]{ + Ty: Execf, + Query: "insert into log123 values ('2')", + }) + + if e := tx1.Fin(); e != nil { + t.Log(e) + } + if e := tx2.Fin(); e != nil { + t.Log(e) + } + + tx1 = BeginTx[any](db, context.Background(), &sql.TxOptions{}).Do(SqlFunc[any]{ + Ty: Queryf, + Query: "select count(1) as c from log123", + AfterQF: func(_ *any, rows *sql.Rows, txE error) (_ *any, stopErr error) { + for rows.Next() { + var row int64 + stopErr = rows.Scan(&row) + if row != 2 { + t.Fatal() + } + } + return + }, + }) + if e := tx1.Fin(); e != nil { + t.Fatal(e) + } + + time.Sleep(time.Second) +} -- 2.39.2