}
type SqlTx[T any] struct {
- tx *sql.Tx
- dataP *T
- err error
+ canTx CanTx
+ ctx context.Context
+ opts *sql.TxOptions
+ sqlFuncs []*SqlFunc[T]
+ dataP *T
}
type SqlFunc[T any] struct {
}
func BeginTx[T any](canTx CanTx, ctx context.Context, opts *sql.TxOptions) *SqlTx[T] {
- var sqlTX = SqlTx[T]{}
-
- if tx, err := canTx.BeginTx(ctx, opts); err != nil {
- sqlTX.err = fmt.Errorf("BeginTx; [] >> %s", err)
- } else {
- sqlTX.tx = tx
+ return &SqlTx[T]{
+ canTx: canTx,
+ ctx: ctx,
+ opts: opts,
}
-
- return &sqlTX
}
func (t *SqlTx[T]) Do(sqlf SqlFunc[T]) *SqlTx[T] {
- if t.err != nil {
- return t
- }
-
- if sqlf.Ctx == nil {
- sqlf.Ctx = context.Background()
- }
+ t.sqlFuncs = append(t.sqlFuncs, &sqlf)
+ return t
+}
- switch sqlf.Ty {
- case Execf:
- if sqlf.BeforeEF != nil {
- if datap, err := sqlf.BeforeEF(t.dataP, &sqlf, t.err); err != nil {
- t.err = errors.Join(t.err, fmt.Errorf("%s >> %s", sqlf.Query, err))
- } else {
- t.dataP = datap
- }
- }
- if res, err := t.tx.ExecContext(sqlf.Ctx, sqlf.Query, sqlf.Args...); err != nil {
- if !sqlf.SkipSqlErr {
- t.err = errors.Join(t.err, fmt.Errorf("%s; %s >> %s", sqlf.Query, sqlf.Args, err))
- }
- } else if sqlf.AfterEF != nil {
- if datap, err := sqlf.AfterEF(t.dataP, res, t.err); err != nil {
- t.err = errors.Join(t.err, fmt.Errorf("%s; %s >> %s", sqlf.Query, sqlf.Args, err))
- } else {
- t.dataP = datap
- }
- }
- case Queryf:
- if sqlf.BeforeQF != nil {
- if datap, err := sqlf.BeforeQF(t.dataP, &sqlf, t.err); err != nil {
- t.err = errors.Join(t.err, fmt.Errorf("%s; %s >> %s", sqlf.Query, sqlf.Args, err))
- } else {
- t.dataP = datap
- }
- }
- if res, err := t.tx.QueryContext(sqlf.Ctx, sqlf.Query, sqlf.Args...); err != nil {
- if !sqlf.SkipSqlErr {
- t.err = errors.Join(t.err, fmt.Errorf("%s; %s >> %s", sqlf.Query, sqlf.Args, err))
+func (t *SqlTx[T]) Fin() (e error) {
+ tx, err := t.canTx.BeginTx(t.ctx, t.opts)
+ if err != nil {
+ e = fmt.Errorf("BeginTx; [] >> %s", err)
+ } else {
+ for i := 0; i < len(t.sqlFuncs); i++ {
+ sqlf := t.sqlFuncs[i]
+ if sqlf.Ctx == nil {
+ sqlf.Ctx = t.ctx
}
- } else if sqlf.AfterQF != nil {
- if datap, err := sqlf.AfterQF(t.dataP, res, t.err); err != nil {
- t.err = errors.Join(t.err, fmt.Errorf("%s; %s >> %s", sqlf.Query, sqlf.Args, err))
- } else {
- t.dataP = datap
+ switch sqlf.Ty {
+ case Execf:
+ if sqlf.BeforeEF != nil {
+ if datap, err := sqlf.BeforeEF(t.dataP, sqlf, e); err != nil {
+ e = errors.Join(e, fmt.Errorf("%s >> %s", sqlf.Query, err))
+ } else {
+ t.dataP = datap
+ }
+ }
+ if res, err := tx.ExecContext(sqlf.Ctx, sqlf.Query, sqlf.Args...); err != nil {
+ if !sqlf.SkipSqlErr {
+ e = errors.Join(e, fmt.Errorf("%s; %s >> %s", sqlf.Query, sqlf.Args, err))
+ }
+ } else if sqlf.AfterEF != nil {
+ if datap, err := sqlf.AfterEF(t.dataP, res, e); err != nil {
+ e = errors.Join(e, fmt.Errorf("%s; %s >> %s", sqlf.Query, sqlf.Args, err))
+ } else {
+ t.dataP = datap
+ }
+ }
+ case Queryf:
+ if sqlf.BeforeQF != nil {
+ if datap, err := sqlf.BeforeQF(t.dataP, sqlf, e); err != nil {
+ e = errors.Join(e, fmt.Errorf("%s; %s >> %s", sqlf.Query, sqlf.Args, err))
+ } else {
+ t.dataP = datap
+ }
+ }
+ if res, err := tx.QueryContext(sqlf.Ctx, sqlf.Query, sqlf.Args...); err != nil {
+ if !sqlf.SkipSqlErr {
+ e = errors.Join(e, fmt.Errorf("%s; %s >> %s", sqlf.Query, sqlf.Args, err))
+ }
+ } else if sqlf.AfterQF != nil {
+ if datap, err := sqlf.AfterQF(t.dataP, res, e); err != nil {
+ e = errors.Join(e, fmt.Errorf("%s; %s >> %s", sqlf.Query, sqlf.Args, err))
+ } else {
+ t.dataP = datap
+ }
+ }
}
}
}
- return t
-}
-
-func (t *SqlTx[T]) Fin() error {
- if t.err != nil {
- if t.tx != nil {
- if err := t.tx.Rollback(); err != nil {
- t.err = errors.Join(t.err, fmt.Errorf("Rollback; [] >> %s", err))
+ if e != nil {
+ if tx != nil {
+ if err := tx.Rollback(); err != nil {
+ e = errors.Join(e, fmt.Errorf("Rollback; [] >> %s", err))
}
}
} else {
- if err := t.tx.Commit(); err != nil {
- t.err = errors.Join(t.err, fmt.Errorf("Commit; [] >> %s", err))
+ if err := tx.Commit(); err != nil {
+ e = errors.Join(e, fmt.Errorf("Commit; [] >> %s", err))
}
}
- return t.err
+ return e
}
"testing"
"time"
- _ "github.com/mattn/go-sqlite3"
+ file "github.com/qydysky/part/file"
+ _ "modernc.org/sqlite"
)
func TestMain(t *testing.T) {
// connect
- db, err := sql.Open("sqlite3", ":memory:")
+ db, err := sql.Open("sqlite", ":memory:")
if err != nil {
t.Fatal(err)
}
func TestMain2(t *testing.T) {
// connect
- db, err := sql.Open("sqlite3", ":memory:")
+ db, err := sql.Open("sqlite", ":memory:")
if err != nil {
t.Fatal(err)
}
t.Fatal(<-res)
}
}
+
+func TestMain3(t *testing.T) {
+ // connect
+ db, err := sql.Open("sqlite", "test.sqlite3")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer db.Close()
+ defer file.New("test.sqlite3", 0, true).Delete()
+
+ conn, _ := db.Conn(context.Background())
+ if e := BeginTx[any](conn, context.Background(), &sql.TxOptions{}).Do(SqlFunc[any]{
+ Ty: Execf,
+ Query: "create table log123 (msg text)",
+ }).Fin(); e != nil {
+ t.Fatal(e)
+ }
+ conn.Close()
+
+ tx1 := BeginTx[any](db, context.Background(), &sql.TxOptions{}).Do(SqlFunc[any]{
+ Ty: Execf,
+ Query: "insert into log123 values ('1')",
+ })
+
+ tx2 := BeginTx[any](db, context.Background(), &sql.TxOptions{}).Do(SqlFunc[any]{
+ Ty: Execf,
+ Query: "insert into log123 values ('2')",
+ })
+
+ if e := tx1.Fin(); e != nil {
+ t.Log(e)
+ }
+ if e := tx2.Fin(); e != nil {
+ t.Log(e)
+ }
+
+ tx1 = BeginTx[any](db, context.Background(), &sql.TxOptions{}).Do(SqlFunc[any]{
+ Ty: Queryf,
+ Query: "select count(1) as c from log123",
+ AfterQF: func(_ *any, rows *sql.Rows, txE error) (_ *any, stopErr error) {
+ for rows.Next() {
+ var row int64
+ stopErr = rows.Scan(&row)
+ if row != 2 {
+ t.Fatal()
+ }
+ }
+ return
+ },
+ })
+ if e := tx1.Fin(); e != nil {
+ t.Fatal(e)
+ }
+
+ time.Sleep(time.Second)
+}