From: qydysky Date: Tue, 2 May 2023 08:56:03 +0000 (+0800) Subject: Improve X-Git-Tag: v0.25.2 X-Git-Url: http://127.0.0.1:8081/?a=commitdiff_plain;h=d6f971b72692663e59a93effb3d1063be8f43080;p=part%2F.git Improve --- diff --git a/sql/Sql.go b/sql/Sql.go index 4618270..b3dfbc6 100644 --- a/sql/Sql.go +++ b/sql/Sql.go @@ -6,10 +6,12 @@ import ( "errors" "fmt" "reflect" + "strings" ) const ( - Execf = iota + null = iota + Execf Queryf ) @@ -17,6 +19,10 @@ type CanTx interface { BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) } +type BeforeF[T any] func(dataP *T, sqlf *SqlFunc[T], txE error) (dataPR *T, stopErr error) +type AfterEF[T any] func(dataP *T, result sql.Result, txE error) (dataPR *T, stopErr error) +type AfterQF[T any] func(dataP *T, rows *sql.Rows, txE error) (dataPR *T, stopErr error) + type SqlTx[T any] struct { canTx CanTx ctx context.Context @@ -32,18 +38,20 @@ type SqlFunc[T any] struct { Query string Args []any SkipSqlErr bool - BeforeEF func(dataP *T, sqlf *SqlFunc[T], txE error) (dataPR *T, stopErr error) - BeforeQF func(dataP *T, sqlf *SqlFunc[T], txE error) (dataPR *T, stopErr error) - AfterEF func(dataP *T, result sql.Result, txE error) (dataPR *T, stopErr error) - AfterQF func(dataP *T, rows *sql.Rows, txE error) (dataPR *T, stopErr error) + beforeF BeforeF[T] + afterEF AfterEF[T] + afterQF AfterQF[T] } -func BeginTx[T any](canTx CanTx, ctx context.Context, opts *sql.TxOptions) *SqlTx[T] { - return &SqlTx[T]{ +func BeginTx[T any](canTx CanTx, ctx context.Context, opts ...*sql.TxOptions) *SqlTx[T] { + var tx = SqlTx[T]{ canTx: canTx, ctx: ctx, - opts: opts, } + if len(opts) > 0 { + tx.opts = opts[0] + } + return &tx } func (t *SqlTx[T]) Do(sqlf SqlFunc[T]) *SqlTx[T] { @@ -51,9 +59,45 @@ func (t *SqlTx[T]) Do(sqlf SqlFunc[T]) *SqlTx[T] { return t } -func (t *SqlTx[T]) Fin() (e error) { +func (t *SqlTx[T]) DoPlaceHolder(sqlf SqlFunc[T], ptr any) *SqlTx[T] { + dataR := reflect.ValueOf(ptr).Elem() + for i := 0; i < dataR.NumField(); i++ { + field := dataR.Field(i) + if field.IsValid() && field.CanSet() { + replaceS := "{" + dataR.Type().Field(i).Name + "}" + if strings.Contains(sqlf.Query, replaceS) { + sqlf.Query = strings.ReplaceAll(sqlf.Query, replaceS, "?") + sqlf.Args = append(sqlf.Args, field.Interface()) + } + } + } + return t.Do(sqlf) +} + +func (t *SqlTx[T]) BeforeF(f BeforeF[T]) *SqlTx[T] { + if len(t.sqlFuncs) > 0 { + t.sqlFuncs[len(t.sqlFuncs)-1].beforeF = f + } + return t +} + +func (t *SqlTx[T]) AfterEF(f AfterEF[T]) *SqlTx[T] { + if len(t.sqlFuncs) > 0 { + t.sqlFuncs[len(t.sqlFuncs)-1].afterEF = f + } + return t +} + +func (t *SqlTx[T]) AfterQF(f AfterQF[T]) *SqlTx[T] { + if len(t.sqlFuncs) > 0 { + t.sqlFuncs[len(t.sqlFuncs)-1].afterQF = f + } + return t +} + +func (t *SqlTx[T]) Fin() (dataP *T, e error) { if t.fin { - return fmt.Errorf("BeginTx; [] >> fin") + return nil, fmt.Errorf("BeginTx; [] >> fin") } tx, err := t.canTx.BeginTx(t.ctx, t.opts) @@ -62,43 +106,50 @@ func (t *SqlTx[T]) Fin() (e error) { } else { for i := 0; i < len(t.sqlFuncs); i++ { sqlf := t.sqlFuncs[i] + + if sqlf.beforeF != nil { + if datap, err := sqlf.beforeF(t.dataP, sqlf, e); err != nil { + e = errors.Join(e, fmt.Errorf("%s; >> %s", sqlf.Query, err)) + } else { + t.dataP = datap + } + } + + if strings.TrimSpace(sqlf.Query) == "" { + continue + } + if sqlf.Ctx == nil { sqlf.Ctx = t.ctx } + + if sqlf.Ty == null { + sqlf.Ty = Execf + if uquery := strings.ToUpper(strings.TrimSpace(sqlf.Query)); strings.HasPrefix(uquery, "SELECT") { + sqlf.Ty = Queryf + } + } + switch sqlf.Ty { case Execf: - if sqlf.BeforeEF != nil { - if datap, err := sqlf.BeforeEF(t.dataP, sqlf, e); err != nil { - e = errors.Join(e, fmt.Errorf("%s >> %s", sqlf.Query, err)) - } else { - t.dataP = datap - } - } if res, err := tx.ExecContext(sqlf.Ctx, sqlf.Query, sqlf.Args...); err != nil { if !sqlf.SkipSqlErr { e = errors.Join(e, fmt.Errorf("%s; %s >> %s", sqlf.Query, sqlf.Args, err)) } - } else if sqlf.AfterEF != nil { - if datap, err := sqlf.AfterEF(t.dataP, res, e); err != nil { + } else if sqlf.afterEF != nil { + if datap, err := sqlf.afterEF(t.dataP, res, e); err != nil { e = errors.Join(e, fmt.Errorf("%s; %s >> %s", sqlf.Query, sqlf.Args, err)) } else { t.dataP = datap } } case Queryf: - if sqlf.BeforeQF != nil { - if datap, err := sqlf.BeforeQF(t.dataP, sqlf, e); err != nil { - e = errors.Join(e, fmt.Errorf("%s; %s >> %s", sqlf.Query, sqlf.Args, err)) - } else { - t.dataP = datap - } - } if res, err := tx.QueryContext(sqlf.Ctx, sqlf.Query, sqlf.Args...); err != nil { if !sqlf.SkipSqlErr { e = errors.Join(e, fmt.Errorf("%s; %s >> %s", sqlf.Query, sqlf.Args, err)) } - } else if sqlf.AfterQF != nil { - if datap, err := sqlf.AfterQF(t.dataP, res, e); err != nil { + } else if sqlf.afterQF != nil { + if datap, err := sqlf.afterQF(t.dataP, res, e); err != nil { e = errors.Join(e, fmt.Errorf("%s; %s >> %s", sqlf.Query, sqlf.Args, err)) } else { t.dataP = datap @@ -119,14 +170,14 @@ func (t *SqlTx[T]) Fin() (e error) { } } t.fin = true - return e + return t.dataP, e } func IsFin[T any](t *SqlTx[T]) bool { return t == nil || t.fin } -func DealRows[T any](rows *sql.Rows, createF func() T) ([]T, error) { +func DealRows[T any](rows *sql.Rows, createF func() T) (*[]T, error) { rowNames, err := rows.Columns() if err != nil { return nil, err @@ -144,26 +195,38 @@ func DealRows[T any](rows *sql.Rows, createF func() T) ([]T, error) { return nil, err } - var rowT = createF() - refV := reflect.ValueOf(&rowT) + var stu = createF() for i := 0; i < len(rowNames); i++ { - v := refV.Elem().FieldByName(rowNames[i]) + v := reflect.ValueOf(&stu).Elem().FieldByName(rowNames[i]) if v.IsValid() { + refT := reflect.TypeOf(&stu).Elem() if v.CanSet() { val := reflect.ValueOf(*rowP[i].(*any)) - if val.Kind() == v.Kind() { + if reflect.TypeOf(*rowP[i].(*any)).ConvertibleTo(v.Type()) { v.Set(val) } else { - return nil, fmt.Errorf("reflectFail:%s KindNotMatch:%v !> %v", rowNames[i], val.Kind(), v.Kind()) + return nil, fmt.Errorf("DealRows:KindNotMatch:[sql] %v !> [%s.%s] %v", val.Kind(), refT.Name(), rowNames[i], v.Type()) } } else { - return nil, fmt.Errorf("reflectFail:%s CanSet:%v", rowNames[i], v.CanSet()) + return nil, fmt.Errorf("DealRows:%s.%s CanSet:%v", refT.Name(), rowNames[i], v.CanSet()) } } } - res = append(res, rowT) + res = append(res, stu) } - return res, nil + return &res, nil +} + +func SimpleQ[T any](canTx CanTx, query string, ptr *T) (*[]T, error) { + tx := BeginTx[[]T](canTx, context.Background()) + tx.DoPlaceHolder(SqlFunc[[]T]{Query: query}, ptr) + tx.AfterQF(func(_ *[]T, rows *sql.Rows, txE error) (dataPR *[]T, stopErr error) { + if txE != nil { + return nil, txE + } + return DealRows(rows, func() T { return *ptr }) + }) + return tx.Fin() } diff --git a/sql/Sql_test.go b/sql/Sql_test.go index c1706ad..8b96d92 100644 --- a/sql/Sql_test.go +++ b/sql/Sql_test.go @@ -55,57 +55,54 @@ func TestMain(t *testing.T) { Ty: Queryf, Ctx: ctx, Query: "select msg from log", - AfterQF: func(dataP *[]string, rows *sql.Rows, err error) (dataPR *[]string, stopErr error) { - names := make([]string, 0) - for rows.Next() { - var name string - if err := rows.Scan(&name); err != nil { - return nil, err - } - names = append(names, name) + }).AfterQF(func(dataP *[]string, rows *sql.Rows, err error) (dataPR *[]string, stopErr error) { + names := make([]string, 0) + for rows.Next() { + var name string + if err := rows.Scan(&name); err != nil { + return nil, err } - rows.Close() + names = append(names, name) + } + rows.Close() - if len(names) != 1 || dateTime != names[0] { - return nil, errors.New("no") - } + if len(names) != 1 || dateTime != names[0] { + return nil, errors.New("no") + } - return &names, nil - }, + return &names, nil }) tx = tx.Do(SqlFunc[[]string]{ Ty: Execf, Ctx: ctx, - BeforeEF: func(dataP *[]string, sqlf *SqlFunc[[]string], txE error) (dataPR *[]string, stopErr error) { - sqlf.Query = "insert into log2 values (?)" - sqlf.Args = append(sqlf.Args, (*dataP)[0]) - return dataP, nil - }, + }).BeforeF(func(dataP *[]string, sqlf *SqlFunc[[]string], txE error) (dataPR *[]string, stopErr error) { + sqlf.Query = "insert into log2 values (?)" + sqlf.Args = append(sqlf.Args, (*dataP)[0]) + return dataP, nil }) tx = tx.Do(SqlFunc[[]string]{ Ty: Queryf, Ctx: ctx, Query: "select msg from log2", - AfterQF: func(dataP *[]string, rows *sql.Rows, err error) (dataPR *[]string, stopErr error) { - names := make([]string, 0) - for rows.Next() { - var name string - if err := rows.Scan(&name); err != nil { - return nil, err - } - names = append(names, name) + }).AfterQF(func(dataP *[]string, rows *sql.Rows, err error) (dataPR *[]string, stopErr error) { + names := make([]string, 0) + for rows.Next() { + var name string + if err := rows.Scan(&name); err != nil { + return nil, err } - rows.Close() + names = append(names, name) + } + rows.Close() - if len(names) != 1 || dateTime != names[0] { - return nil, errors.New("no2") - } + if len(names) != 1 || dateTime != names[0] { + return nil, errors.New("no2") + } - return &names, nil - }, + return &names, nil }) - if e := tx.Fin(); e != nil { + if _, e := tx.Fin(); e != nil { t.Fatal(e) } } @@ -120,7 +117,7 @@ func TestMain2(t *testing.T) { defer db.Close() conn, _ := db.Conn(context.Background()) - if e := BeginTx[any](conn, context.Background(), &sql.TxOptions{}).Do(SqlFunc[any]{ + if _, e := BeginTx[any](conn, context.Background(), &sql.TxOptions{}).Do(SqlFunc[any]{ Ty: Execf, Query: "create table log123 (msg text)", }).Fin(); e != nil { @@ -140,7 +137,7 @@ func TestMain2(t *testing.T) { Query: "insert into log123 values (?)", Args: []any{"1"}, }) - if e := x.Fin(); e != nil { + if _, e := x.Fin(); e != nil { res <- e.Error() } wg.Done() @@ -162,56 +159,57 @@ func TestMain3(t *testing.T) { defer db.Close() defer file.New("test.sqlite3", 0, true).Delete() - conn, _ := db.Conn(context.Background()) - if e := BeginTx[any](conn, context.Background(), &sql.TxOptions{}).Do(SqlFunc[any]{ - Ty: Execf, - Query: "create table log123 (msg text,msg2 text)", - }).Fin(); e != nil { - t.Fatal(e) + { + tx := BeginTx[any](db, context.Background()) + tx.Do(SqlFunc[any]{Query: "create table log123 (msg INT,msg2 text)"}) + if _, e := tx.Fin(); e != nil { + t.Fatal(e) + } } - conn.Close() - - tx1 := BeginTx[any](db, context.Background(), &sql.TxOptions{}).Do(SqlFunc[any]{ - Ty: Execf, - Query: "insert into log123 values ('1','a')", - }) - tx2 := BeginTx[any](db, context.Background(), &sql.TxOptions{}).Do(SqlFunc[any]{ - Ty: Execf, - Query: "insert into log123 values ('2','b')", - }) + type logg struct { + Msg int64 + Msg2 string + } - if e := tx1.Fin(); e != nil { - t.Log(e) + insertLog123 := SqlFunc[any]{Query: "insert into log123 values ({Msg},{Msg2})"} + { + tx := BeginTx[any](db, context.Background()) + tx.DoPlaceHolder(insertLog123, &logg{Msg: 1, Msg2: "a"}) + tx.DoPlaceHolder(insertLog123, &logg{Msg: 2, Msg2: "b"}) + if _, e := tx.Fin(); e != nil { + t.Log(e) + } + if _, err := SimpleQ(db, "insert into log123 values ({Msg},{Msg2})", &logg{Msg: 3, Msg2: "b"}); err != nil { + t.Fatal(err) + } } - if e := tx2.Fin(); e != nil { - t.Log(e) + { + selectLog123 := SqlFunc[[]logg]{Query: "select msg as Msg, msg2 as Msg2 from log123 where msg = {Msg}"} + tx := BeginTx[[]logg](db, context.Background()) + tx.DoPlaceHolder(selectLog123, &logg{Msg: 2, Msg2: "b"}) + tx.AfterQF(func(_ *[]logg, rows *sql.Rows, txE error) (dataPR *[]logg, stopErr error) { + return DealRows(rows, func() logg { return logg{} }) + }) + if v, e := tx.Fin(); e != nil { + t.Fatal(e) + } else { + if (*v)[0].Msg2 != "b" || (*v)[0].Msg != 2 { + t.Fatal() + } + } } - - tx1 = BeginTx[any](db, context.Background(), &sql.TxOptions{}).Do(SqlFunc[any]{ - Ty: Queryf, - Query: "select 1 as Msg, msg2 as Msg2 from log123", - AfterQF: func(_ *any, rows *sql.Rows, txE error) (_ *any, stopErr error) { - type logg struct { - Msg int64 - Msg2 string + { + if v, err := SimpleQ(db, "select msg as Msg, msg2 as Msg2 from log123 where msg2 = {Msg2}", &logg{Msg2: "b"}); err != nil { + t.Fatal(err) + } else { + if (*v)[0].Msg2 != "b" || (*v)[0].Msg != 2 { + t.Fatal() } - - if v, err := DealRows(rows, func() logg { return logg{} }); err != nil { - return nil, err - } else { - if v[0].Msg2 != "a" { - t.Fatal() - } - if v[1].Msg2 != "b" { - t.Fatal() - } + if (*v)[1].Msg2 != "b" || (*v)[1].Msg != 3 { + t.Fatal() } - return - }, - }) - if e := tx1.Fin(); e != nil { - t.Fatal(e) + } } } @@ -225,7 +223,7 @@ func TestMain4(t *testing.T) { defer file.New("test.sqlite3", 0, true).Delete() conn, _ := db.Conn(context.Background()) - if e := BeginTx[any](conn, context.Background(), &sql.TxOptions{}).Do(SqlFunc[any]{ + if _, e := BeginTx[any](conn, context.Background(), &sql.TxOptions{}).Do(SqlFunc[any]{ Ty: Execf, Query: "create table log123 (msg text)", }).Fin(); e != nil { @@ -238,7 +236,7 @@ func TestMain4(t *testing.T) { Query: "insert into log123 values ('1')", }) - if e := tx1.Fin(); e != nil { + if _, e := tx1.Fin(); e != nil { t.Log(e) }