From 8e4660c16bc9b46a46851c715fbfa161f55e4b68 Mon Sep 17 00:00:00 2001 From: qydysky Date: Sun, 18 Jun 2023 13:01:56 +0000 Subject: [PATCH] sql support postgresql --- go.mod | 1 + go.sum | 2 ++ sql/Sql.go | 68 ++++++++++++++++++++++++++++++++++++------------- sql/Sql_test.go | 57 +++++++++++++++++++++++++++++++++++++++-- 4 files changed, 109 insertions(+), 19 deletions(-) diff --git a/go.mod b/go.mod index ba4797a..5179b44 100644 --- a/go.mod +++ b/go.mod @@ -37,6 +37,7 @@ require ( require ( github.com/dustin/go-humanize v1.0.1 github.com/go-ole/go-ole v1.2.6 // indirect + github.com/lib/pq v1.10.9 github.com/stretchr/testify v1.8.2 // indirect golang.org/x/net v0.10.0 // indirect golang.org/x/sys v0.8.0 // indirect diff --git a/go.sum b/go.sum index 976bffd..9d742ef 100644 --- a/go.sum +++ b/go.sum @@ -17,6 +17,8 @@ github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 h1:Z9n2FFNU github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51/go.mod h1:CzGEWj7cYgsdH8dAjBGEr58BoE7ScuLd+fwFZ44+/x8= github.com/klauspost/compress v1.16.5 h1:IFV2oUNUzZaz+XyusxpLzpzS8Pt5rh0Z16For/djlyI= github.com/klauspost/compress v1.16.5/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE= +github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= +github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-sqlite3 v1.14.16 h1:yOQRA0RpS5PFz/oikGwBEqvAWhWg5ufRz4ETLjwpU1Y= diff --git a/sql/Sql.go b/sql/Sql.go index b3dfbc6..15b4784 100644 --- a/sql/Sql.go +++ b/sql/Sql.go @@ -59,14 +59,20 @@ func (t *SqlTx[T]) Do(sqlf SqlFunc[T]) *SqlTx[T] { return t } -func (t *SqlTx[T]) DoPlaceHolder(sqlf SqlFunc[T], ptr any) *SqlTx[T] { +func (t *SqlTx[T]) DoPlaceHolder(sqlf SqlFunc[T], ptr any, replaceF ...func(index int, holder string) (replaceTo string)) *SqlTx[T] { dataR := reflect.ValueOf(ptr).Elem() + index := 0 for i := 0; i < dataR.NumField(); i++ { field := dataR.Field(i) if field.IsValid() && field.CanSet() { replaceS := "{" + dataR.Type().Field(i).Name + "}" if strings.Contains(sqlf.Query, replaceS) { - sqlf.Query = strings.ReplaceAll(sqlf.Query, replaceS, "?") + if len(replaceF) == 0 { + sqlf.Query = strings.ReplaceAll(sqlf.Query, replaceS, "?") + } else { + sqlf.Query = strings.ReplaceAll(sqlf.Query, replaceS, replaceF[0](index, replaceS)) + index += 1 + } sqlf.Args = append(sqlf.Args, field.Interface()) } } @@ -100,9 +106,12 @@ func (t *SqlTx[T]) Fin() (dataP *T, e error) { return nil, fmt.Errorf("BeginTx; [] >> fin") } + var hasErr bool + tx, err := t.canTx.BeginTx(t.ctx, t.opts) if err != nil { e = fmt.Errorf("BeginTx; [] >> %s", err) + hasErr = true } else { for i := 0; i < len(t.sqlFuncs); i++ { sqlf := t.sqlFuncs[i] @@ -110,6 +119,7 @@ func (t *SqlTx[T]) Fin() (dataP *T, e error) { if sqlf.beforeF != nil { if datap, err := sqlf.beforeF(t.dataP, sqlf, e); err != nil { e = errors.Join(e, fmt.Errorf("%s; >> %s", sqlf.Query, err)) + hasErr = true } else { t.dataP = datap } @@ -133,11 +143,13 @@ func (t *SqlTx[T]) Fin() (dataP *T, e error) { switch sqlf.Ty { case Execf: if res, err := tx.ExecContext(sqlf.Ctx, sqlf.Query, sqlf.Args...); err != nil { + hasErr = true if !sqlf.SkipSqlErr { e = errors.Join(e, fmt.Errorf("%s; %s >> %s", sqlf.Query, sqlf.Args, err)) } } else if sqlf.afterEF != nil { if datap, err := sqlf.afterEF(t.dataP, res, e); err != nil { + hasErr = true e = errors.Join(e, fmt.Errorf("%s; %s >> %s", sqlf.Query, sqlf.Args, err)) } else { t.dataP = datap @@ -145,11 +157,13 @@ func (t *SqlTx[T]) Fin() (dataP *T, e error) { } case Queryf: if res, err := tx.QueryContext(sqlf.Ctx, sqlf.Query, sqlf.Args...); err != nil { + hasErr = true if !sqlf.SkipSqlErr { e = errors.Join(e, fmt.Errorf("%s; %s >> %s", sqlf.Query, sqlf.Args, err)) } } else if sqlf.afterQF != nil { if datap, err := sqlf.afterQF(t.dataP, res, e); err != nil { + hasErr = true e = errors.Join(e, fmt.Errorf("%s; %s >> %s", sqlf.Query, sqlf.Args, err)) } else { t.dataP = datap @@ -158,7 +172,7 @@ func (t *SqlTx[T]) Fin() (dataP *T, e error) { } } } - if e != nil { + if hasErr { if tx != nil { if err := tx.Rollback(); err != nil { e = errors.Join(e, fmt.Errorf("Rollback; [] >> %s", err)) @@ -177,7 +191,7 @@ func IsFin[T any](t *SqlTx[T]) bool { return t == nil || t.fin } -func DealRows[T any](rows *sql.Rows, createF func() T) (*[]T, error) { +func DealRows[T any](rows *sql.Rows, newT func() T) (*[]T, error) { rowNames, err := rows.Columns() if err != nil { return nil, err @@ -195,30 +209,50 @@ func DealRows[T any](rows *sql.Rows, createF func() T) (*[]T, error) { return nil, err } - var stu = createF() + var ( + stu = newT() + refV = reflect.ValueOf(&stu).Elem() + refT = reflect.TypeOf(&stu).Elem() + FieldMap = make(map[string]*reflect.Value) + ) + + for NumField := refV.NumField() - 1; NumField >= 0; NumField-- { + field := refV.Field(NumField) + fieldT := refT.Field(NumField) + fieldTName := fieldT.Name + if value, ok := fieldT.Tag.Lookup("sql"); ok { + fieldTName = value + } + if !field.IsValid() { + continue + } + if !field.CanSet() { + FieldMap[strings.ToUpper(fieldTName)] = nil + continue + } + FieldMap[strings.ToUpper(fieldTName)] = &field + } + for i := 0; i < len(rowNames); i++ { - v := reflect.ValueOf(&stu).Elem().FieldByName(rowNames[i]) - if v.IsValid() { - refT := reflect.TypeOf(&stu).Elem() - if v.CanSet() { - val := reflect.ValueOf(*rowP[i].(*any)) - if reflect.TypeOf(*rowP[i].(*any)).ConvertibleTo(v.Type()) { - v.Set(val) - } else { - return nil, fmt.Errorf("DealRows:KindNotMatch:[sql] %v !> [%s.%s] %v", val.Kind(), refT.Name(), rowNames[i], v.Type()) - } + if field, ok := FieldMap[strings.ToUpper(rowNames[i])]; ok { + if field == nil { + return nil, fmt.Errorf("DealRows:%s.%s CanSet:false", refT.Name(), rowNames[i]) + } + val := reflect.ValueOf(*rowP[i].(*any)) + if reflect.TypeOf(*rowP[i].(*any)).ConvertibleTo(field.Type()) { + field.Set(val) } else { - return nil, fmt.Errorf("DealRows:%s.%s CanSet:%v", refT.Name(), rowNames[i], v.CanSet()) + return nil, fmt.Errorf("DealRows:KindNotMatch:[sql] %v !> [%s.%s] %v", val.Kind(), refT.Name(), rowNames[i], field.Type()) } } } res = append(res, stu) - } return &res, nil } +// for mysql,oracle not postgresql func SimpleQ[T any](canTx CanTx, query string, ptr *T) (*[]T, error) { tx := BeginTx[[]T](canTx, context.Background()) tx.DoPlaceHolder(SqlFunc[[]T]{Query: query}, ptr) diff --git a/sql/Sql_test.go b/sql/Sql_test.go index 8b96d92..2c92aeb 100644 --- a/sql/Sql_test.go +++ b/sql/Sql_test.go @@ -4,10 +4,12 @@ import ( "context" "database/sql" "errors" + "fmt" "sync" "testing" "time" + _ "github.com/lib/pq" file "github.com/qydysky/part/file" _ "modernc.org/sqlite" ) @@ -157,7 +159,9 @@ func TestMain3(t *testing.T) { t.Fatal(err) } defer db.Close() - defer file.New("test.sqlite3", 0, true).Delete() + defer func() { + _ = file.New("test.sqlite3", 0, true).Delete() + }() { tx := BeginTx[any](db, context.Background()) @@ -220,7 +224,9 @@ func TestMain4(t *testing.T) { t.Fatal(err) } defer db.Close() - defer file.New("test.sqlite3", 0, true).Delete() + defer func() { + _ = file.New("test.sqlite3", 0, true).Delete() + }() conn, _ := db.Conn(context.Background()) if _, e := BeginTx[any](conn, context.Background(), &sql.TxOptions{}).Do(SqlFunc[any]{ @@ -244,3 +250,50 @@ func TestMain4(t *testing.T) { t.Fatal() } } + +func Local_TestPostgresql(t *testing.T) { + // connect + db, err := sql.Open("postgres", "postgres://postgres:qydysky@192.168.31.103:5432/postgres?sslmode=disable") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + type test1 struct { + Created string `sql:"sss"` + } + + if _, e := BeginTx[any](db, context.Background(), &sql.TxOptions{}).Do(SqlFunc[any]{ + Query: "create table test (created varchar(20))", + SkipSqlErr: true, + }).Fin(); e != nil { + t.Fatal(e) + } + + if _, e := BeginTx[any](db, context.Background(), &sql.TxOptions{}).DoPlaceHolder(SqlFunc[any]{ + Query: "insert into test (created) values ({Created})", + }, &test1{"1"}, func(index int, holder string) (replaceTo string) { + return fmt.Sprintf("$%d", index+1) + }).Fin(); e != nil { + t.Fatal(e) + } + + if _, e := BeginTx[any](db, context.Background(), &sql.TxOptions{}).Do(SqlFunc[any]{ + Query: "select created as sss from test", + afterQF: func(_ *any, rows *sql.Rows, txE error) (dataPR *any, stopErr error) { + if rowsP, e := DealRows[test1](rows, func() test1 { return test1{} }); e != nil { + return nil, e + } else { + if len(*rowsP) != 1 { + return nil, errors.New("no match") + } + if (*rowsP)[0].Created != "1" { + return nil, errors.New("no match") + } + } + return nil, nil + }, + }).Fin(); e != nil { + t.Fatal(e) + } +} -- 2.39.2