"database/sql"
"errors"
"fmt"
-
- _ "github.com/mattn/go-sqlite3"
)
const (
Queryf
)
+type CanTx interface {
+ BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
+}
+
type SqlTx[T any] struct {
tx *sql.Tx
dataP *T
AfterQF func(dataP *T, rows *sql.Rows, txE error) (dataPR *T, stopErr error)
}
-func BeginTx[T any](db *sql.DB, ctx context.Context, opts *sql.TxOptions) *SqlTx[T] {
+func BeginTx[T any](canTx CanTx, ctx context.Context, opts *sql.TxOptions) *SqlTx[T] {
var sqlTX = SqlTx[T]{}
- if tx, e := db.BeginTx(ctx, opts); e != nil {
+ if tx, e := canTx.BeginTx(ctx, opts); e != nil {
sqlTX.err = e
} else {
sqlTX.tx = tx
}
+
return &sqlTX
}
return t
}
+ if sqlf.Ctx == nil {
+ sqlf.Ctx = context.Background()
+ }
+
switch sqlf.Ty {
case Execf:
if sqlf.BeforeEF != nil {
}
if res, err := t.tx.ExecContext(sqlf.Ctx, sqlf.Query, sqlf.Args...); err != nil {
if !sqlf.SkipSqlErr {
- t.err = errors.Join(t.err, fmt.Errorf("%s >> %s", sqlf.Query, err))
+ t.err = errors.Join(t.err, fmt.Errorf("%s; %s >> %s", sqlf.Query, sqlf.Args, err))
}
} else if sqlf.AfterEF != nil {
if datap, err := sqlf.AfterEF(t.dataP, res, t.err); err != nil {
- t.err = errors.Join(t.err, fmt.Errorf("%s >> %s", sqlf.Query, err))
+ t.err = errors.Join(t.err, fmt.Errorf("%s; %s >> %s", sqlf.Query, sqlf.Args, err))
} else {
t.dataP = datap
}
case Queryf:
if sqlf.BeforeQF != nil {
if datap, err := sqlf.BeforeQF(t.dataP, &sqlf, t.err); err != nil {
- t.err = errors.Join(t.err, fmt.Errorf("%s >> %s", sqlf.Query, err))
+ t.err = errors.Join(t.err, fmt.Errorf("%s; %s >> %s", sqlf.Query, sqlf.Args, err))
} else {
t.dataP = datap
}
}
if res, err := t.tx.QueryContext(sqlf.Ctx, sqlf.Query, sqlf.Args...); err != nil {
if !sqlf.SkipSqlErr {
- t.err = errors.Join(t.err, fmt.Errorf("%s >> %s", sqlf.Query, err))
+ t.err = errors.Join(t.err, fmt.Errorf("%s; %s >> %s", sqlf.Query, sqlf.Args, err))
}
} else if sqlf.AfterQF != nil {
if datap, err := sqlf.AfterQF(t.dataP, res, t.err); err != nil {
- t.err = errors.Join(t.err, fmt.Errorf("%s >> %s", sqlf.Query, err))
+ t.err = errors.Join(t.err, fmt.Errorf("%s; %s >> %s", sqlf.Query, sqlf.Args, err))
} else {
t.dataP = datap
}
"context"
"database/sql"
"errors"
+ "sync"
"testing"
"time"
t.Fatal(e)
}
}
+
+func TestMain2(t *testing.T) {
+ // connect
+ db, err := sql.Open("sqlite3", ":memory:")
+ if err != nil {
+ t.Fatal(err)
+ }
+ db.SetMaxOpenConns(1)
+ defer db.Close()
+
+ conn, _ := db.Conn(context.Background())
+ if e := BeginTx[any](conn, context.Background(), &sql.TxOptions{}).Do(SqlFunc[any]{
+ Ty: Execf,
+ Query: "create table log123 (msg text)",
+ }).Fin(); e != nil {
+ t.Fatal(e)
+ }
+ conn.Close()
+
+ var res = make(chan string, 101)
+ var wg sync.WaitGroup
+ wg.Add(100)
+
+ for i := 0; i < 100; i++ {
+ go func() {
+ x := BeginTx[any](db, context.Background(), &sql.TxOptions{})
+ x.Do(SqlFunc[any]{
+ Ty: Execf,
+ Query: "insert into log123 values (?)",
+ Args: []any{"1"},
+ })
+ if e := x.Fin(); e != nil {
+ res <- e.Error()
+ }
+ wg.Done()
+ }()
+ }
+
+ wg.Wait()
+ for len(res) > 0 {
+ t.Fatal(<-res)
+ }
+}