Compare commits
23 Commits
tools/goct
...
v1.5.6
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
421e6617b1 | ||
|
|
0ee7a271d3 | ||
|
|
af022b9655 | ||
|
|
98d46261d9 | ||
|
|
4222fd97bc | ||
|
|
814852f0b8 | ||
|
|
ded2888759 | ||
|
|
18d66a795d | ||
|
|
4211672bfd | ||
|
|
68df0c3620 | ||
|
|
5e435b6a76 | ||
|
|
0dcede6457 | ||
|
|
cc21f5fae2 | ||
|
|
b22ad50d59 | ||
|
|
974252980c | ||
|
|
8d83986d27 | ||
|
|
6821b0a7dd | ||
|
|
1ba1724c65 | ||
|
|
ca5a7df5b0 | ||
|
|
69a3024853 | ||
|
|
fd3abf3717 | ||
|
|
99b3750d10 | ||
|
|
33f6d7ebb8 |
@@ -29,6 +29,8 @@ func NewSafeMap() *SafeMap {
|
||||
// Del deletes the value with the given key from m.
|
||||
func (m *SafeMap) Del(key any) {
|
||||
m.lock.Lock()
|
||||
defer m.lock.Unlock()
|
||||
|
||||
if _, ok := m.dirtyOld[key]; ok {
|
||||
delete(m.dirtyOld, key)
|
||||
m.deletionOld++
|
||||
@@ -52,7 +54,6 @@ func (m *SafeMap) Del(key any) {
|
||||
m.dirtyNew = make(map[any]any)
|
||||
m.deletionNew = 0
|
||||
}
|
||||
m.lock.Unlock()
|
||||
}
|
||||
|
||||
// Get gets the value with the given key from m.
|
||||
@@ -89,6 +90,8 @@ func (m *SafeMap) Range(f func(key, val any) bool) {
|
||||
// Set sets the value into m with the given key.
|
||||
func (m *SafeMap) Set(key, value any) {
|
||||
m.lock.Lock()
|
||||
defer m.lock.Unlock()
|
||||
|
||||
if m.deletionOld <= maxDeletion {
|
||||
if _, ok := m.dirtyNew[key]; ok {
|
||||
delete(m.dirtyNew, key)
|
||||
@@ -102,7 +105,6 @@ func (m *SafeMap) Set(key, value any) {
|
||||
}
|
||||
m.dirtyNew[key] = value
|
||||
}
|
||||
m.lock.Unlock()
|
||||
}
|
||||
|
||||
// Size returns the size of m.
|
||||
|
||||
@@ -147,3 +147,65 @@ func TestSafeMap_Range(t *testing.T) {
|
||||
assert.Equal(t, m.dirtyNew, newMap.dirtyNew)
|
||||
assert.Equal(t, m.dirtyOld, newMap.dirtyOld)
|
||||
}
|
||||
|
||||
func TestSetManyTimes(t *testing.T) {
|
||||
const iteration = maxDeletion * 2
|
||||
m := NewSafeMap()
|
||||
for i := 0; i < iteration; i++ {
|
||||
m.Set(i, i)
|
||||
if i%3 == 0 {
|
||||
m.Del(i / 2)
|
||||
}
|
||||
}
|
||||
var count int
|
||||
m.Range(func(k, v any) bool {
|
||||
count++
|
||||
return count < maxDeletion/2
|
||||
})
|
||||
assert.Equal(t, maxDeletion/2, count)
|
||||
for i := 0; i < iteration; i++ {
|
||||
m.Set(i, i)
|
||||
if i%3 == 0 {
|
||||
m.Del(i / 2)
|
||||
}
|
||||
}
|
||||
for i := 0; i < iteration; i++ {
|
||||
m.Set(i, i)
|
||||
if i%3 == 0 {
|
||||
m.Del(i / 2)
|
||||
}
|
||||
}
|
||||
for i := 0; i < iteration; i++ {
|
||||
m.Set(i, i)
|
||||
if i%3 == 0 {
|
||||
m.Del(i / 2)
|
||||
}
|
||||
}
|
||||
|
||||
count = 0
|
||||
m.Range(func(k, v any) bool {
|
||||
count++
|
||||
return count < maxDeletion
|
||||
})
|
||||
assert.Equal(t, maxDeletion, count)
|
||||
}
|
||||
|
||||
func TestSetManyTimesNew(t *testing.T) {
|
||||
m := NewSafeMap()
|
||||
for i := 0; i < maxDeletion*3; i++ {
|
||||
m.Set(i, i)
|
||||
}
|
||||
for i := 0; i < maxDeletion*2; i++ {
|
||||
m.Del(i)
|
||||
}
|
||||
for i := 0; i < maxDeletion*3; i++ {
|
||||
m.Set(i+maxDeletion*3, i+maxDeletion*3)
|
||||
}
|
||||
for i := 0; i < maxDeletion*2; i++ {
|
||||
m.Del(i + maxDeletion*2)
|
||||
}
|
||||
for i := 0; i < maxDeletion-copyThreshold+1; i++ {
|
||||
m.Del(i + maxDeletion*2)
|
||||
}
|
||||
assert.Equal(t, 0, len(m.dirtyNew))
|
||||
}
|
||||
|
||||
12
core/iox/nopcloser_test.go
Normal file
12
core/iox/nopcloser_test.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package iox
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestNopCloser(t *testing.T) {
|
||||
closer := NopCloser(nil)
|
||||
assert.NoError(t, closer.Close())
|
||||
}
|
||||
@@ -35,6 +35,16 @@ func KeepSpace() TextReadOption {
|
||||
}
|
||||
}
|
||||
|
||||
// LimitDupReadCloser returns two io.ReadCloser that read from the first will be written to the second.
|
||||
// But the second io.ReadCloser is limited to up to n bytes.
|
||||
// The first returned reader needs to be read first, because the content
|
||||
// read from it will be written to the underlying buffer of the second reader.
|
||||
func LimitDupReadCloser(reader io.ReadCloser, n int64) (io.ReadCloser, io.ReadCloser) {
|
||||
var buf bytes.Buffer
|
||||
tee := LimitTeeReader(reader, &buf, n)
|
||||
return io.NopCloser(tee), io.NopCloser(&buf)
|
||||
}
|
||||
|
||||
// ReadBytes reads exactly the bytes with the length of len(buf)
|
||||
func ReadBytes(reader io.Reader, buf []byte) error {
|
||||
var got int
|
||||
|
||||
@@ -51,6 +51,11 @@ b`,
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadTextError(t *testing.T) {
|
||||
_, err := ReadText("not-exist")
|
||||
assert.NotNil(t, err)
|
||||
}
|
||||
|
||||
func TestReadTextLines(t *testing.T) {
|
||||
text := `1
|
||||
|
||||
@@ -94,6 +99,11 @@ func TestReadTextLines(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadTextLinesError(t *testing.T) {
|
||||
_, err := ReadTextLines("not-exist")
|
||||
assert.NotNil(t, err)
|
||||
}
|
||||
|
||||
func TestDupReadCloser(t *testing.T) {
|
||||
input := "hello"
|
||||
reader := io.NopCloser(bytes.NewBufferString(input))
|
||||
@@ -108,6 +118,29 @@ func TestDupReadCloser(t *testing.T) {
|
||||
verify(r2)
|
||||
}
|
||||
|
||||
func TestLimitDupReadCloser(t *testing.T) {
|
||||
input := "hello world"
|
||||
limitBytes := int64(4)
|
||||
reader := io.NopCloser(bytes.NewBufferString(input))
|
||||
r1, r2 := LimitDupReadCloser(reader, limitBytes)
|
||||
verify := func(r io.Reader) {
|
||||
output, err := io.ReadAll(r)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, input, string(output))
|
||||
}
|
||||
verifyLimit := func(r io.Reader, limit int64) {
|
||||
output, err := io.ReadAll(r)
|
||||
if limit < int64(len(input)) {
|
||||
input = input[:limit]
|
||||
}
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, input, string(output))
|
||||
}
|
||||
|
||||
verify(r1)
|
||||
verifyLimit(r2, limitBytes)
|
||||
}
|
||||
|
||||
func TestReadBytes(t *testing.T) {
|
||||
reader := io.NopCloser(bytes.NewBufferString("helloworld"))
|
||||
buf := make([]byte, 5)
|
||||
|
||||
35
core/iox/tee.go
Normal file
35
core/iox/tee.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package iox
|
||||
|
||||
import "io"
|
||||
|
||||
// LimitTeeReader returns a Reader that writes up to n bytes to w what it reads from r.
|
||||
// First n bytes reads from r performed through it are matched with
|
||||
// corresponding writes to w. There is no internal buffering -
|
||||
// the write must complete before the first n bytes read completes.
|
||||
// Any error encountered while writing is reported as a read error.
|
||||
func LimitTeeReader(r io.Reader, w io.Writer, n int64) io.Reader {
|
||||
return &limitTeeReader{r, w, n}
|
||||
}
|
||||
|
||||
type limitTeeReader struct {
|
||||
r io.Reader
|
||||
w io.Writer
|
||||
n int64 // limit bytes remaining
|
||||
}
|
||||
|
||||
func (t *limitTeeReader) Read(p []byte) (n int, err error) {
|
||||
n, err = t.r.Read(p)
|
||||
if n > 0 && t.n > 0 {
|
||||
limit := int64(n)
|
||||
if limit > t.n {
|
||||
limit = t.n
|
||||
}
|
||||
if n, err := t.w.Write(p[:limit]); err != nil {
|
||||
return n, err
|
||||
}
|
||||
|
||||
t.n -= limit
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
40
core/iox/tee_test.go
Normal file
40
core/iox/tee_test.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package iox
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestLimitTeeReader(t *testing.T) {
|
||||
limit := int64(4)
|
||||
src := []byte("hello, world")
|
||||
dst := make([]byte, len(src))
|
||||
rb := bytes.NewBuffer(src)
|
||||
wb := new(bytes.Buffer)
|
||||
r := LimitTeeReader(rb, wb, limit)
|
||||
if n, err := io.ReadFull(r, dst); err != nil || n != len(src) {
|
||||
t.Fatalf("ReadFull(r, dst) = %d, %v; want %d, nil", n, err, len(src))
|
||||
}
|
||||
if !bytes.Equal(dst, src) {
|
||||
t.Errorf("bytes read = %q want %q", dst, src)
|
||||
}
|
||||
if !bytes.Equal(wb.Bytes(), src[:limit]) {
|
||||
t.Errorf("bytes written = %q want %q", wb.Bytes(), src)
|
||||
}
|
||||
|
||||
n, err := r.Read(dst)
|
||||
assert.Equal(t, 0, n)
|
||||
assert.Equal(t, io.EOF, err)
|
||||
|
||||
rb = bytes.NewBuffer(src)
|
||||
pr, pw := io.Pipe()
|
||||
if assert.NoError(t, pr.Close()) {
|
||||
r = LimitTeeReader(rb, pw, limit)
|
||||
n, err := io.ReadFull(r, dst)
|
||||
assert.Equal(t, 0, n)
|
||||
assert.Equal(t, io.ErrClosedPipe, err)
|
||||
}
|
||||
}
|
||||
@@ -2,6 +2,7 @@ package iox
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
"os"
|
||||
)
|
||||
@@ -26,7 +27,7 @@ func CountLines(file string) (int, error) {
|
||||
count += bytes.Count(buf[:c], lineSep)
|
||||
|
||||
switch {
|
||||
case err == io.EOF:
|
||||
case errors.Is(err, io.EOF):
|
||||
if noEol {
|
||||
count++
|
||||
}
|
||||
|
||||
@@ -24,3 +24,8 @@ func TestCountLines(t *testing.T) {
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 4, lines)
|
||||
}
|
||||
|
||||
func TestCountLinesError(t *testing.T) {
|
||||
_, err := CountLines("not-exist")
|
||||
assert.NotNil(t, err)
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package iox
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
"testing/iotest"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
@@ -22,3 +23,10 @@ func TestScanner(t *testing.T) {
|
||||
}
|
||||
assert.EqualValues(t, []string{"1", "2", "3", "4"}, lines)
|
||||
}
|
||||
|
||||
func TestBadScanner(t *testing.T) {
|
||||
scanner := NewTextLineScanner(iotest.ErrReader(iotest.ErrTimeout))
|
||||
assert.False(t, scanner.Scan())
|
||||
_, err := scanner.Line()
|
||||
assert.ErrorIs(t, err, iotest.ErrTimeout)
|
||||
}
|
||||
|
||||
@@ -298,6 +298,7 @@ func (l *RotateLogger) initialize() error {
|
||||
if l.fp, err = os.OpenFile(l.filename, os.O_APPEND|os.O_WRONLY, defaultFileMode); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
l.currentSize = fileInfo.Size()
|
||||
}
|
||||
|
||||
@@ -381,9 +382,17 @@ func (l *RotateLogger) startWorker() {
|
||||
case event := <-l.channel:
|
||||
l.write(event)
|
||||
case <-l.done:
|
||||
// avoid losing logs before closing.
|
||||
for {
|
||||
select {
|
||||
case event := <-l.channel:
|
||||
l.write(event)
|
||||
default:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
|
||||
@@ -206,6 +206,27 @@ func TestRotateLoggerClose(t *testing.T) {
|
||||
_, err := logger.Write([]byte("foo"))
|
||||
assert.ErrorIs(t, err, ErrLogFileClosed)
|
||||
})
|
||||
|
||||
t.Run("close without losing logs", func(t *testing.T) {
|
||||
text := "foo"
|
||||
filename, err := fs.TempFilenameWithText(text)
|
||||
assert.Nil(t, err)
|
||||
if len(filename) > 0 {
|
||||
defer os.Remove(filename)
|
||||
}
|
||||
logger, err := NewLogger(filename, new(DailyRotateRule), false)
|
||||
assert.Nil(t, err)
|
||||
msg := []byte("foo")
|
||||
n := 100
|
||||
for i := 0; i < n; i++ {
|
||||
_, err = logger.Write(msg)
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
assert.Nil(t, logger.Close())
|
||||
bs, err := os.ReadFile(filename)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, len(msg)*n+len(text), len(bs))
|
||||
})
|
||||
}
|
||||
|
||||
func TestRotateLoggerGetBackupFilename(t *testing.T) {
|
||||
@@ -496,6 +517,21 @@ func TestGzipFile(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestRotateLogger_WithExistingFile(t *testing.T) {
|
||||
const body = "foo"
|
||||
filename, err := fs.TempFilenameWithText(body)
|
||||
assert.Nil(t, err)
|
||||
if len(filename) > 0 {
|
||||
defer os.Remove(filename)
|
||||
}
|
||||
|
||||
rule := NewSizeLimitRotateRule(filename, "-", 1, 100, 3, false)
|
||||
logger, err := NewLogger(filename, rule, false)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int64(len(body)), logger.currentSize)
|
||||
assert.Nil(t, logger.Close())
|
||||
}
|
||||
|
||||
func BenchmarkRotateLogger(b *testing.B) {
|
||||
filename := "./test.log"
|
||||
filename2 := "./test2.log"
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -609,25 +610,23 @@ func (u *Unmarshaler) processFieldPrimitiveWithJSONNumber(fieldType reflect.Type
|
||||
target := reflect.New(Deref(fieldType)).Elem()
|
||||
|
||||
switch typeKind {
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
iValue, err := v.Int64()
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
|
||||
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
if err := setValueFromString(typeKind, target, v.String()); err != nil {
|
||||
return err
|
||||
}
|
||||
case reflect.Float32:
|
||||
fValue, err := v.Float64()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
target.SetInt(iValue)
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
iValue, err := v.Int64()
|
||||
if err != nil {
|
||||
return err
|
||||
if fValue > math.MaxFloat32 {
|
||||
return float32OverflowError(v.String())
|
||||
}
|
||||
|
||||
if iValue < 0 {
|
||||
return fmt.Errorf("unmarshal %q with bad value %q", fullName, v.String())
|
||||
}
|
||||
|
||||
target.SetUint(uint64(iValue))
|
||||
case reflect.Float32, reflect.Float64:
|
||||
target.SetFloat(fValue)
|
||||
case reflect.Float64:
|
||||
fValue, err := v.Float64()
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
@@ -569,6 +569,468 @@ func TestUnmarshalIntWithString(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalInt8WithOverflow(t *testing.T) {
|
||||
t.Run("int8 from string", func(t *testing.T) {
|
||||
type inner struct {
|
||||
Value int8 `key:"int,string"`
|
||||
}
|
||||
|
||||
m := map[string]any{
|
||||
"int": "8589934592", // overflow
|
||||
}
|
||||
|
||||
var in inner
|
||||
assert.Error(t, UnmarshalKey(m, &in))
|
||||
})
|
||||
|
||||
t.Run("int8 from json.Number", func(t *testing.T) {
|
||||
type inner struct {
|
||||
Value int8 `key:"int"`
|
||||
}
|
||||
|
||||
m := map[string]any{
|
||||
"int": json.Number("8589934592"), // overflow
|
||||
}
|
||||
|
||||
var in inner
|
||||
assert.Error(t, UnmarshalKey(m, &in))
|
||||
})
|
||||
|
||||
t.Run("int8 from json.Number", func(t *testing.T) {
|
||||
type inner struct {
|
||||
Value int8 `key:"int"`
|
||||
}
|
||||
|
||||
m := map[string]any{
|
||||
"int": json.Number("-8589934592"), // overflow
|
||||
}
|
||||
|
||||
var in inner
|
||||
assert.Error(t, UnmarshalKey(m, &in))
|
||||
})
|
||||
|
||||
t.Run("int8 from int64", func(t *testing.T) {
|
||||
type inner struct {
|
||||
Value int8 `key:"int"`
|
||||
}
|
||||
|
||||
m := map[string]any{
|
||||
"int": int64(1) << 36, // overflow
|
||||
}
|
||||
|
||||
var in inner
|
||||
assert.Error(t, UnmarshalKey(m, &in))
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalInt16WithOverflow(t *testing.T) {
|
||||
t.Run("int16 from string", func(t *testing.T) {
|
||||
type inner struct {
|
||||
Value int16 `key:"int,string"`
|
||||
}
|
||||
|
||||
m := map[string]any{
|
||||
"int": "8589934592", // overflow
|
||||
}
|
||||
|
||||
var in inner
|
||||
assert.Error(t, UnmarshalKey(m, &in))
|
||||
})
|
||||
|
||||
t.Run("int16 from json.Number", func(t *testing.T) {
|
||||
type inner struct {
|
||||
Value int16 `key:"int"`
|
||||
}
|
||||
|
||||
m := map[string]any{
|
||||
"int": json.Number("8589934592"), // overflow
|
||||
}
|
||||
|
||||
var in inner
|
||||
assert.Error(t, UnmarshalKey(m, &in))
|
||||
})
|
||||
|
||||
t.Run("int16 from json.Number", func(t *testing.T) {
|
||||
type inner struct {
|
||||
Value int16 `key:"int"`
|
||||
}
|
||||
|
||||
m := map[string]any{
|
||||
"int": json.Number("-8589934592"), // overflow
|
||||
}
|
||||
|
||||
var in inner
|
||||
assert.Error(t, UnmarshalKey(m, &in))
|
||||
})
|
||||
|
||||
t.Run("int16 from int64", func(t *testing.T) {
|
||||
type inner struct {
|
||||
Value int16 `key:"int"`
|
||||
}
|
||||
|
||||
m := map[string]any{
|
||||
"int": int64(1) << 36, // overflow
|
||||
}
|
||||
|
||||
var in inner
|
||||
assert.Error(t, UnmarshalKey(m, &in))
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalInt32WithOverflow(t *testing.T) {
|
||||
t.Run("int32 from string", func(t *testing.T) {
|
||||
type inner struct {
|
||||
Value int32 `key:"int,string"`
|
||||
}
|
||||
|
||||
m := map[string]any{
|
||||
"int": "8589934592", // overflow
|
||||
}
|
||||
|
||||
var in inner
|
||||
assert.Error(t, UnmarshalKey(m, &in))
|
||||
})
|
||||
|
||||
t.Run("int32 from json.Number", func(t *testing.T) {
|
||||
type inner struct {
|
||||
Value int32 `key:"int"`
|
||||
}
|
||||
|
||||
m := map[string]any{
|
||||
"int": json.Number("8589934592"), // overflow
|
||||
}
|
||||
|
||||
var in inner
|
||||
assert.Error(t, UnmarshalKey(m, &in))
|
||||
})
|
||||
|
||||
t.Run("int32 from json.Number", func(t *testing.T) {
|
||||
type inner struct {
|
||||
Value int32 `key:"int"`
|
||||
}
|
||||
|
||||
m := map[string]any{
|
||||
"int": json.Number("-8589934592"), // overflow
|
||||
}
|
||||
|
||||
var in inner
|
||||
assert.Error(t, UnmarshalKey(m, &in))
|
||||
})
|
||||
|
||||
t.Run("int32 from int64", func(t *testing.T) {
|
||||
type inner struct {
|
||||
Value int32 `key:"int"`
|
||||
}
|
||||
|
||||
m := map[string]any{
|
||||
"int": int64(1) << 36, // overflow
|
||||
}
|
||||
|
||||
var in inner
|
||||
assert.Error(t, UnmarshalKey(m, &in))
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalInt64WithOverflow(t *testing.T) {
|
||||
t.Run("int64 from string", func(t *testing.T) {
|
||||
type inner struct {
|
||||
Value int64 `key:"int,string"`
|
||||
}
|
||||
|
||||
m := map[string]any{
|
||||
"int": "18446744073709551616", // overflow, 1 << 64
|
||||
}
|
||||
|
||||
var in inner
|
||||
assert.Error(t, UnmarshalKey(m, &in))
|
||||
})
|
||||
|
||||
t.Run("int64 from json.Number", func(t *testing.T) {
|
||||
type inner struct {
|
||||
Value int64 `key:"int,string"`
|
||||
}
|
||||
|
||||
m := map[string]any{
|
||||
"int": json.Number("18446744073709551616"), // overflow, 1 << 64
|
||||
}
|
||||
|
||||
var in inner
|
||||
assert.Error(t, UnmarshalKey(m, &in))
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalUint8WithOverflow(t *testing.T) {
|
||||
t.Run("uint8 from string", func(t *testing.T) {
|
||||
type inner struct {
|
||||
Value uint8 `key:"int,string"`
|
||||
}
|
||||
|
||||
m := map[string]any{
|
||||
"int": "8589934592", // overflow
|
||||
}
|
||||
|
||||
var in inner
|
||||
assert.Error(t, UnmarshalKey(m, &in))
|
||||
})
|
||||
|
||||
t.Run("uint8 from json.Number", func(t *testing.T) {
|
||||
type inner struct {
|
||||
Value uint8 `key:"int"`
|
||||
}
|
||||
|
||||
m := map[string]any{
|
||||
"int": json.Number("8589934592"), // overflow
|
||||
}
|
||||
|
||||
var in inner
|
||||
assert.Error(t, UnmarshalKey(m, &in))
|
||||
})
|
||||
|
||||
t.Run("uint8 from json.Number with negative", func(t *testing.T) {
|
||||
type inner struct {
|
||||
Value uint8 `key:"int"`
|
||||
}
|
||||
|
||||
m := map[string]any{
|
||||
"int": json.Number("-1"), // overflow
|
||||
}
|
||||
|
||||
var in inner
|
||||
assert.Error(t, UnmarshalKey(m, &in))
|
||||
})
|
||||
|
||||
t.Run("uint8 from int64", func(t *testing.T) {
|
||||
type inner struct {
|
||||
Value uint8 `key:"int"`
|
||||
}
|
||||
|
||||
m := map[string]any{
|
||||
"int": int64(1) << 36, // overflow
|
||||
}
|
||||
|
||||
var in inner
|
||||
assert.Error(t, UnmarshalKey(m, &in))
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalUint16WithOverflow(t *testing.T) {
|
||||
t.Run("uint16 from string", func(t *testing.T) {
|
||||
type inner struct {
|
||||
Value uint16 `key:"int,string"`
|
||||
}
|
||||
|
||||
m := map[string]any{
|
||||
"int": "8589934592", // overflow
|
||||
}
|
||||
|
||||
var in inner
|
||||
assert.Error(t, UnmarshalKey(m, &in))
|
||||
})
|
||||
|
||||
t.Run("uint16 from json.Number", func(t *testing.T) {
|
||||
type inner struct {
|
||||
Value uint16 `key:"int"`
|
||||
}
|
||||
|
||||
m := map[string]any{
|
||||
"int": json.Number("8589934592"), // overflow
|
||||
}
|
||||
|
||||
var in inner
|
||||
assert.Error(t, UnmarshalKey(m, &in))
|
||||
})
|
||||
|
||||
t.Run("uint16 from json.Number with negative", func(t *testing.T) {
|
||||
type inner struct {
|
||||
Value uint16 `key:"int"`
|
||||
}
|
||||
|
||||
m := map[string]any{
|
||||
"int": json.Number("-1"), // overflow
|
||||
}
|
||||
|
||||
var in inner
|
||||
assert.Error(t, UnmarshalKey(m, &in))
|
||||
})
|
||||
|
||||
t.Run("uint16 from int64", func(t *testing.T) {
|
||||
type inner struct {
|
||||
Value uint16 `key:"int"`
|
||||
}
|
||||
|
||||
m := map[string]any{
|
||||
"int": int64(1) << 36, // overflow
|
||||
}
|
||||
|
||||
var in inner
|
||||
assert.Error(t, UnmarshalKey(m, &in))
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalUint32WithOverflow(t *testing.T) {
|
||||
t.Run("uint32 from string", func(t *testing.T) {
|
||||
type inner struct {
|
||||
Value uint32 `key:"int,string"`
|
||||
}
|
||||
|
||||
m := map[string]any{
|
||||
"int": "8589934592", // overflow
|
||||
}
|
||||
|
||||
var in inner
|
||||
assert.Error(t, UnmarshalKey(m, &in))
|
||||
})
|
||||
|
||||
t.Run("uint32 from json.Number", func(t *testing.T) {
|
||||
type inner struct {
|
||||
Value uint32 `key:"int"`
|
||||
}
|
||||
|
||||
m := map[string]any{
|
||||
"int": json.Number("8589934592"), // overflow
|
||||
}
|
||||
|
||||
var in inner
|
||||
assert.Error(t, UnmarshalKey(m, &in))
|
||||
})
|
||||
|
||||
t.Run("uint32 from json.Number with negative", func(t *testing.T) {
|
||||
type inner struct {
|
||||
Value uint32 `key:"int"`
|
||||
}
|
||||
|
||||
m := map[string]any{
|
||||
"int": json.Number("-1"), // overflow
|
||||
}
|
||||
|
||||
var in inner
|
||||
assert.Error(t, UnmarshalKey(m, &in))
|
||||
})
|
||||
|
||||
t.Run("uint32 from int64", func(t *testing.T) {
|
||||
type inner struct {
|
||||
Value uint32 `key:"int"`
|
||||
}
|
||||
|
||||
m := map[string]any{
|
||||
"int": int64(1) << 36, // overflow
|
||||
}
|
||||
|
||||
var in inner
|
||||
assert.Error(t, UnmarshalKey(m, &in))
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalUint64WithOverflow(t *testing.T) {
|
||||
t.Run("uint64 from string", func(t *testing.T) {
|
||||
type inner struct {
|
||||
Value uint64 `key:"int,string"`
|
||||
}
|
||||
|
||||
m := map[string]any{
|
||||
"int": "18446744073709551616", // overflow, 1 << 64
|
||||
}
|
||||
|
||||
var in inner
|
||||
assert.Error(t, UnmarshalKey(m, &in))
|
||||
})
|
||||
|
||||
t.Run("uint64 from json.Number", func(t *testing.T) {
|
||||
type inner struct {
|
||||
Value uint64 `key:"int,string"`
|
||||
}
|
||||
|
||||
m := map[string]any{
|
||||
"int": json.Number("18446744073709551616"), // overflow, 1 << 64
|
||||
}
|
||||
|
||||
var in inner
|
||||
assert.Error(t, UnmarshalKey(m, &in))
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalFloat32WithOverflow(t *testing.T) {
|
||||
t.Run("float32 from string greater than float64", func(t *testing.T) {
|
||||
type inner struct {
|
||||
Value float32 `key:"float,string"`
|
||||
}
|
||||
|
||||
m := map[string]any{
|
||||
"float": "1.79769313486231570814527423731704356798070e+309", // overflow
|
||||
}
|
||||
|
||||
var in inner
|
||||
assert.Error(t, UnmarshalKey(m, &in))
|
||||
})
|
||||
|
||||
t.Run("float32 from string greater than float32", func(t *testing.T) {
|
||||
type inner struct {
|
||||
Value float32 `key:"float,string"`
|
||||
}
|
||||
|
||||
m := map[string]any{
|
||||
"float": "1.79769313486231570814527423731704356798070e+300", // overflow
|
||||
}
|
||||
|
||||
var in inner
|
||||
assert.Error(t, UnmarshalKey(m, &in))
|
||||
})
|
||||
|
||||
t.Run("float32 from json.Number greater than float64", func(t *testing.T) {
|
||||
type inner struct {
|
||||
Value float32 `key:"float"`
|
||||
}
|
||||
|
||||
m := map[string]any{
|
||||
"float": json.Number("1.79769313486231570814527423731704356798070e+309"), // overflow
|
||||
}
|
||||
|
||||
var in inner
|
||||
assert.Error(t, UnmarshalKey(m, &in))
|
||||
})
|
||||
|
||||
t.Run("float32 from json.Number greater than float32", func(t *testing.T) {
|
||||
type inner struct {
|
||||
Value float32 `key:"float"`
|
||||
}
|
||||
|
||||
m := map[string]any{
|
||||
"float": json.Number("1.79769313486231570814527423731704356798070e+300"), // overflow
|
||||
}
|
||||
|
||||
var in inner
|
||||
assert.Error(t, UnmarshalKey(m, &in))
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalFloat64WithOverflow(t *testing.T) {
|
||||
t.Run("float64 from string greater than float64", func(t *testing.T) {
|
||||
type inner struct {
|
||||
Value float64 `key:"float,string"`
|
||||
}
|
||||
|
||||
m := map[string]any{
|
||||
"float": "1.79769313486231570814527423731704356798070e+309", // overflow
|
||||
}
|
||||
|
||||
var in inner
|
||||
assert.Error(t, UnmarshalKey(m, &in))
|
||||
})
|
||||
|
||||
t.Run("float32 from json.Number greater than float64", func(t *testing.T) {
|
||||
type inner struct {
|
||||
Value float64 `key:"float"`
|
||||
}
|
||||
|
||||
m := map[string]any{
|
||||
"float": json.Number("1.79769313486231570814527423731704356798070e+309"), // overflow
|
||||
}
|
||||
|
||||
var in inner
|
||||
assert.Error(t, UnmarshalKey(m, &in))
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalBoolSliceRequired(t *testing.T) {
|
||||
type inner struct {
|
||||
Bools []bool `key:"bools"`
|
||||
@@ -795,16 +1257,20 @@ func TestUnmarshalFloat(t *testing.T) {
|
||||
type inner struct {
|
||||
Float32 float32 `key:"float32"`
|
||||
Float32Str float32 `key:"float32str,string"`
|
||||
Float32Num float32 `key:"float32num"`
|
||||
Float64 float64 `key:"float64"`
|
||||
Float64Str float64 `key:"float64str,string"`
|
||||
Float64Num float64 `key:"float64num"`
|
||||
DefaultFloat float32 `key:"defaultfloat,default=5.5"`
|
||||
Optional float32 `key:",optional"`
|
||||
}
|
||||
m := map[string]any{
|
||||
"float32": float32(1.5),
|
||||
"float32str": "2.5",
|
||||
"float64": float64(3.5),
|
||||
"float32num": json.Number("2.6"),
|
||||
"float64": 3.5,
|
||||
"float64str": "4.5",
|
||||
"float64num": json.Number("4.6"),
|
||||
}
|
||||
|
||||
var in inner
|
||||
@@ -812,8 +1278,10 @@ func TestUnmarshalFloat(t *testing.T) {
|
||||
if ast.NoError(UnmarshalKey(m, &in)) {
|
||||
ast.Equal(float32(1.5), in.Float32)
|
||||
ast.Equal(float32(2.5), in.Float32Str)
|
||||
ast.Equal(float32(2.6), in.Float32Num)
|
||||
ast.Equal(3.5, in.Float64)
|
||||
ast.Equal(4.5, in.Float64Str)
|
||||
ast.Equal(4.6, in.Float64Num)
|
||||
ast.Equal(float32(5.5), in.DefaultFloat)
|
||||
}
|
||||
}
|
||||
@@ -5206,15 +5674,13 @@ func TestUnmarshalWithIgnoreFields(t *testing.T) {
|
||||
assert.Equal(t, 0, bar1.IgnoreInt)
|
||||
}
|
||||
|
||||
var bar2 Bar1
|
||||
var bar2 Bar2
|
||||
if assert.NoError(t, unmarshaler.Unmarshal(map[string]any{
|
||||
"Value": "foo",
|
||||
"IgnoreString": "any",
|
||||
"IgnoreInt": 2,
|
||||
}, &bar2)) {
|
||||
assert.Empty(t, bar2.Value)
|
||||
assert.Empty(t, bar2.IgnoreString)
|
||||
assert.Equal(t, 0, bar2.IgnoreInt)
|
||||
assert.Nil(t, bar2.Foo)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -42,6 +42,10 @@ var (
|
||||
)
|
||||
|
||||
type (
|
||||
integer interface {
|
||||
~int | ~int8 | ~int16 | ~int32 | ~int64 | ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64
|
||||
}
|
||||
|
||||
optionsCacheValue struct {
|
||||
key string
|
||||
options *fieldOptions
|
||||
@@ -103,21 +107,32 @@ func convertTypeFromString(kind reflect.Kind, str string) (any, error) {
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
intValue, err := strconv.ParseInt(str, 10, 64)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("the value %q cannot be parsed as int", str)
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return intValue, nil
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
uintValue, err := strconv.ParseUint(str, 10, 64)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("the value %q cannot be parsed as uint", str)
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return uintValue, nil
|
||||
case reflect.Float32, reflect.Float64:
|
||||
case reflect.Float32:
|
||||
floatValue, err := strconv.ParseFloat(str, 64)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("the value %q cannot be parsed as float", str)
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if floatValue > math.MaxFloat32 {
|
||||
return 0, float32OverflowError(str)
|
||||
}
|
||||
|
||||
return floatValue, nil
|
||||
case reflect.Float64:
|
||||
floatValue, err := strconv.ParseFloat(str, 64)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return floatValue, nil
|
||||
@@ -215,6 +230,10 @@ func implicitValueRequiredStruct(tag string, tp reflect.Type) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func intOverflowError[T integer](v T, kind reflect.Kind) error {
|
||||
return fmt.Errorf("parsing \"%d\" as %s: value out of range", v, kind.String())
|
||||
}
|
||||
|
||||
func isLeftInclude(b byte) (bool, error) {
|
||||
switch b {
|
||||
case '[':
|
||||
@@ -237,6 +256,10 @@ func isRightInclude(b byte) (bool, error) {
|
||||
}
|
||||
}
|
||||
|
||||
func float32OverflowError(str string) error {
|
||||
return fmt.Errorf("parsing %q as float32: value out of range", str)
|
||||
}
|
||||
|
||||
func maybeNewValue(fieldType reflect.Type, value reflect.Value) {
|
||||
if fieldType.Kind() == reflect.Ptr && value.IsNil() {
|
||||
value.Set(reflect.New(value.Type().Elem()))
|
||||
@@ -482,22 +505,61 @@ func parseSegments(val string) []string {
|
||||
return segments
|
||||
}
|
||||
|
||||
func setIntValue(value reflect.Value, v any, min, max int64) error {
|
||||
iv := v.(int64)
|
||||
if iv < min || iv > max {
|
||||
return intOverflowError(iv, value.Kind())
|
||||
}
|
||||
|
||||
value.SetInt(iv)
|
||||
return nil
|
||||
}
|
||||
|
||||
func setMatchedPrimitiveValue(kind reflect.Kind, value reflect.Value, v any) error {
|
||||
switch kind {
|
||||
case reflect.Bool:
|
||||
value.SetBool(v.(bool))
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
return nil
|
||||
case reflect.Int: // int depends on int size, 32 or 64
|
||||
return setIntValue(value, v, math.MinInt, math.MaxInt)
|
||||
case reflect.Int8:
|
||||
return setIntValue(value, v, math.MinInt8, math.MaxInt8)
|
||||
case reflect.Int16:
|
||||
return setIntValue(value, v, math.MinInt16, math.MaxInt16)
|
||||
case reflect.Int32:
|
||||
return setIntValue(value, v, math.MinInt32, math.MaxInt32)
|
||||
case reflect.Int64:
|
||||
value.SetInt(v.(int64))
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
return nil
|
||||
case reflect.Uint: // uint depends on int size, 32 or 64
|
||||
return setUintValue(value, v, math.MaxUint)
|
||||
case reflect.Uint8:
|
||||
return setUintValue(value, v, math.MaxUint8)
|
||||
case reflect.Uint16:
|
||||
return setUintValue(value, v, math.MaxUint16)
|
||||
case reflect.Uint32:
|
||||
return setUintValue(value, v, math.MaxUint32)
|
||||
case reflect.Uint64:
|
||||
value.SetUint(v.(uint64))
|
||||
return nil
|
||||
case reflect.Float32, reflect.Float64:
|
||||
value.SetFloat(v.(float64))
|
||||
return nil
|
||||
case reflect.String:
|
||||
value.SetString(v.(string))
|
||||
return nil
|
||||
default:
|
||||
return errUnsupportedType
|
||||
}
|
||||
}
|
||||
|
||||
func setUintValue(value reflect.Value, v any, boundary uint64) error {
|
||||
iv := v.(uint64)
|
||||
if iv > boundary {
|
||||
return intOverflowError(iv, value.Kind())
|
||||
}
|
||||
|
||||
value.SetUint(iv)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -575,7 +637,8 @@ func usingDifferentKeys(key string, field reflect.StructField) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func validateAndSetValue(kind reflect.Kind, value reflect.Value, str string, opts *fieldOptionsWithContext) error {
|
||||
func validateAndSetValue(kind reflect.Kind, value reflect.Value, str string,
|
||||
opts *fieldOptionsWithContext) error {
|
||||
if !value.CanSet() {
|
||||
return errValueNotSettable
|
||||
}
|
||||
|
||||
@@ -3,6 +3,8 @@ package redis
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/alicebob/miniredis/v2"
|
||||
red "github.com/go-redis/redis/v8"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
@@ -41,3 +43,17 @@ func TestSplitClusterAddrs(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetCluster(t *testing.T) {
|
||||
r := miniredis.RunT(t)
|
||||
defer r.Close()
|
||||
c, err := getCluster(&Redis{
|
||||
Addr: r.Addr(),
|
||||
Type: ClusterType,
|
||||
tls: true,
|
||||
hooks: []red.Hook{durationHook},
|
||||
})
|
||||
if assert.NoError(t, err) {
|
||||
assert.NotNil(t, c)
|
||||
}
|
||||
}
|
||||
|
||||
22
go.mod
22
go.mod
@@ -6,7 +6,7 @@ require (
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.0
|
||||
github.com/alicebob/miniredis/v2 v2.30.5
|
||||
github.com/fatih/color v1.15.0
|
||||
github.com/fullstorydev/grpcurl v1.8.7
|
||||
github.com/fullstorydev/grpcurl v1.8.8
|
||||
github.com/go-redis/redis/v8 v8.11.5
|
||||
github.com/go-sql-driver/mysql v1.7.1
|
||||
github.com/golang-jwt/jwt/v4 v4.5.0
|
||||
@@ -33,11 +33,11 @@ require (
|
||||
go.opentelemetry.io/otel/trace v1.14.0
|
||||
go.uber.org/automaxprocs v1.5.3
|
||||
go.uber.org/goleak v1.2.1
|
||||
golang.org/x/net v0.14.0
|
||||
golang.org/x/sys v0.11.0
|
||||
golang.org/x/net v0.15.0
|
||||
golang.org/x/sys v0.12.0
|
||||
golang.org/x/time v0.3.0
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20230525234035-dd9d682886f9
|
||||
google.golang.org/grpc v1.57.0
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20230726155614-23370e0ffb3e
|
||||
google.golang.org/grpc v1.58.2
|
||||
google.golang.org/protobuf v1.31.0
|
||||
gopkg.in/cheggaaa/pb.v1 v1.0.28
|
||||
gopkg.in/h2non/gock.v1 v1.1.2
|
||||
@@ -103,14 +103,14 @@ require (
|
||||
go.uber.org/atomic v1.10.0 // indirect
|
||||
go.uber.org/multierr v1.9.0 // indirect
|
||||
go.uber.org/zap v1.24.0 // indirect
|
||||
golang.org/x/crypto v0.12.0 // indirect
|
||||
golang.org/x/oauth2 v0.7.0 // indirect
|
||||
golang.org/x/crypto v0.13.0 // indirect
|
||||
golang.org/x/oauth2 v0.10.0 // indirect
|
||||
golang.org/x/sync v0.3.0 // indirect
|
||||
golang.org/x/term v0.11.0 // indirect
|
||||
golang.org/x/text v0.12.0 // indirect
|
||||
golang.org/x/term v0.12.0 // indirect
|
||||
golang.org/x/text v0.13.0 // indirect
|
||||
google.golang.org/appengine v1.6.7 // indirect
|
||||
google.golang.org/genproto v0.0.0-20230526161137-0005af68ea54 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20230525234030-28d5490b6b19 // indirect
|
||||
google.golang.org/genproto v0.0.0-20230803162519-f966b187b2e5 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20230913181813-007df8e322eb // indirect
|
||||
gopkg.in/inf.v0 v0.9.1 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
k8s.io/klog/v2 v2.90.1 // indirect
|
||||
|
||||
@@ -9,36 +9,27 @@ import (
|
||||
"gopkg.in/yaml.v2"
|
||||
)
|
||||
|
||||
// TomlToJson converts TOML data into its JSON representation.
|
||||
func TomlToJson(data []byte) ([]byte, error) {
|
||||
var val any
|
||||
if err := toml.NewDecoder(bytes.NewReader(data)).Decode(&val); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
if err := json.NewEncoder(&buf).Encode(val); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return buf.Bytes(), nil
|
||||
return encodeToJSON(val)
|
||||
}
|
||||
|
||||
// YamlToJson converts YAML data into its JSON representation.
|
||||
func YamlToJson(data []byte) ([]byte, error) {
|
||||
var val any
|
||||
if err := yaml.Unmarshal(data, &val); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
val = toStringKeyMap(val)
|
||||
|
||||
var buf bytes.Buffer
|
||||
if err := json.NewEncoder(&buf).Encode(val); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return buf.Bytes(), nil
|
||||
return encodeToJSON(toStringKeyMap(val))
|
||||
}
|
||||
|
||||
// convertKeyToString ensures all keys of the map are of type string.
|
||||
func convertKeyToString(in map[any]any) map[string]any {
|
||||
res := make(map[string]any)
|
||||
for k, v := range in {
|
||||
@@ -47,10 +38,12 @@ func convertKeyToString(in map[any]any) map[string]any {
|
||||
return res
|
||||
}
|
||||
|
||||
// convertNumberToJsonNumber converts numbers into json.Number type for compatibility.
|
||||
func convertNumberToJsonNumber(in any) json.Number {
|
||||
return json.Number(lang.Repr(in))
|
||||
}
|
||||
|
||||
// convertSlice processes slice items to ensure key compatibility.
|
||||
func convertSlice(in []any) []any {
|
||||
res := make([]any, len(in))
|
||||
for i, v := range in {
|
||||
@@ -59,6 +52,17 @@ func convertSlice(in []any) []any {
|
||||
return res
|
||||
}
|
||||
|
||||
// encodeToJSON encodes the given value into its JSON representation.
|
||||
func encodeToJSON(val any) ([]byte, error) {
|
||||
var buf bytes.Buffer
|
||||
if err := json.NewEncoder(&buf).Encode(val); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
// toStringKeyMap processes the data to ensure that all map keys are of type string.
|
||||
func toStringKeyMap(v any) any {
|
||||
switch v := v.(type) {
|
||||
case []any:
|
||||
|
||||
@@ -132,7 +132,7 @@ func (w *cryptionResponseWriter) flush(key []byte) {
|
||||
body := base64.StdEncoding.EncodeToString(content)
|
||||
if n, err := io.WriteString(w.ResponseWriter, body); err != nil {
|
||||
logx.Errorf("write response failed, error: %s", err)
|
||||
} else if n < len(content) {
|
||||
logx.Errorf("actual bytes: %d, written bytes: %d", len(content), n)
|
||||
} else if n < len(body) {
|
||||
logx.Errorf("actual bytes: %d, written bytes: %d", len(body), n)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,15 +2,18 @@ package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"testing/iotest"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zeromicro/go-zero/core/codec"
|
||||
"github.com/zeromicro/go-zero/core/logx/logtest"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -37,6 +40,19 @@ func TestCryptionHandlerGet(t *testing.T) {
|
||||
assert.Equal(t, base64.StdEncoding.EncodeToString(expect), recorder.Body.String())
|
||||
}
|
||||
|
||||
func TestCryptionHandlerGet_badKey(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/any", http.NoBody)
|
||||
handler := CryptionHandler(append(aesKey, aesKey...))(http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
_, err := w.Write([]byte(respText))
|
||||
w.Header().Set("X-Test", "test")
|
||||
assert.Nil(t, err)
|
||||
}))
|
||||
recorder := httptest.NewRecorder()
|
||||
handler.ServeHTTP(recorder, req)
|
||||
assert.Equal(t, http.StatusInternalServerError, recorder.Code)
|
||||
}
|
||||
|
||||
func TestCryptionHandlerPost(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
enc, err := codec.EcbEncrypt(aesKey, []byte(reqText))
|
||||
@@ -120,10 +136,110 @@ func TestCryptionHandler_ContentTooLong(t *testing.T) {
|
||||
defer svr.Close()
|
||||
|
||||
body := make([]byte, maxBytes+1)
|
||||
rand.Read(body)
|
||||
_, err := rand.Read(body)
|
||||
assert.NoError(t, err)
|
||||
req, err := http.NewRequest(http.MethodPost, svr.URL, bytes.NewReader(body))
|
||||
assert.Nil(t, err)
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
|
||||
}
|
||||
|
||||
func TestCryptionHandler_BadBody(t *testing.T) {
|
||||
req, err := http.NewRequest(http.MethodPost, "/foo", iotest.ErrReader(io.ErrUnexpectedEOF))
|
||||
assert.Nil(t, err)
|
||||
err = decryptBody(maxBytes, aesKey, req)
|
||||
assert.ErrorIs(t, err, io.ErrUnexpectedEOF)
|
||||
}
|
||||
|
||||
func TestCryptionHandler_BadKey(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
enc, err := codec.EcbEncrypt(aesKey, []byte(reqText))
|
||||
assert.Nil(t, err)
|
||||
buf.WriteString(base64.StdEncoding.EncodeToString(enc))
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/any", &buf)
|
||||
err = decryptBody(maxBytes, append(aesKey, aesKey...), req)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestCryptionResponseWriter_Flush(t *testing.T) {
|
||||
body := []byte("hello, world!")
|
||||
|
||||
t.Run("half", func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
f := flushableResponseWriter{
|
||||
writer: &halfWriter{recorder},
|
||||
}
|
||||
w := newCryptionResponseWriter(f)
|
||||
_, err := w.Write(body)
|
||||
assert.NoError(t, err)
|
||||
w.flush(aesKey)
|
||||
b, err := io.ReadAll(recorder.Body)
|
||||
assert.NoError(t, err)
|
||||
expected, err := codec.EcbEncrypt(aesKey, body)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, strings.HasPrefix(base64.StdEncoding.EncodeToString(expected), string(b)))
|
||||
assert.True(t, len(string(b)) < len(base64.StdEncoding.EncodeToString(expected)))
|
||||
})
|
||||
|
||||
t.Run("full", func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
f := flushableResponseWriter{
|
||||
writer: recorder,
|
||||
}
|
||||
w := newCryptionResponseWriter(f)
|
||||
_, err := w.Write(body)
|
||||
assert.NoError(t, err)
|
||||
w.flush(aesKey)
|
||||
b, err := io.ReadAll(recorder.Body)
|
||||
assert.NoError(t, err)
|
||||
expected, err := codec.EcbEncrypt(aesKey, body)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, base64.StdEncoding.EncodeToString(expected), string(b))
|
||||
})
|
||||
|
||||
t.Run("bad writer", func(t *testing.T) {
|
||||
buf := logtest.NewCollector(t)
|
||||
f := flushableResponseWriter{
|
||||
writer: new(badWriter),
|
||||
}
|
||||
w := newCryptionResponseWriter(f)
|
||||
_, err := w.Write(body)
|
||||
assert.NoError(t, err)
|
||||
w.flush(aesKey)
|
||||
assert.True(t, strings.Contains(buf.Content(), io.ErrClosedPipe.Error()))
|
||||
})
|
||||
}
|
||||
|
||||
type flushableResponseWriter struct {
|
||||
writer io.Writer
|
||||
}
|
||||
|
||||
func (m flushableResponseWriter) Header() http.Header {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (m flushableResponseWriter) Write(p []byte) (int, error) {
|
||||
return m.writer.Write(p)
|
||||
}
|
||||
|
||||
func (m flushableResponseWriter) WriteHeader(statusCode int) {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
type halfWriter struct {
|
||||
w io.Writer
|
||||
}
|
||||
|
||||
func (t *halfWriter) Write(p []byte) (n int, err error) {
|
||||
n = len(p) >> 1
|
||||
return t.w.Write(p[0:n])
|
||||
}
|
||||
|
||||
type badWriter struct {
|
||||
}
|
||||
|
||||
func (b *badWriter) Write(p []byte) (n int, err error) {
|
||||
return 0, io.ErrClosedPipe
|
||||
}
|
||||
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/color"
|
||||
@@ -39,7 +38,7 @@ func LogHandler(next http.Handler) http.Handler {
|
||||
lrw := response.NewWithCodeResponseWriter(w)
|
||||
|
||||
var dup io.ReadCloser
|
||||
r.Body, dup = iox.DupReadCloser(r.Body)
|
||||
r.Body, dup = iox.LimitDupReadCloser(r.Body, limitBodyBytes)
|
||||
next.ServeHTTP(lrw, r.WithContext(internal.WithLogCollector(r.Context(), logs)))
|
||||
r.Body = dup
|
||||
logBrief(r, lrw.Code, timer, logs)
|
||||
@@ -136,14 +135,7 @@ func logBrief(r *http.Request, code int, timer *utils.ElapsedTimer, logs *intern
|
||||
|
||||
ok := isOkResponse(code)
|
||||
if !ok {
|
||||
fullReq := dumpRequest(r)
|
||||
limitReader := io.LimitReader(strings.NewReader(fullReq), limitBodyBytes)
|
||||
body, err := io.ReadAll(limitReader)
|
||||
if err != nil {
|
||||
buf.WriteString(fmt.Sprintf("\n%s", fullReq))
|
||||
} else {
|
||||
buf.WriteString(fmt.Sprintf("\n%s", string(body)))
|
||||
}
|
||||
buf.WriteString(fmt.Sprintf("\n%s", dumpRequest(r)))
|
||||
}
|
||||
|
||||
body := logs.Flush()
|
||||
|
||||
@@ -16,8 +16,8 @@ require (
|
||||
github.com/zeromicro/antlr v0.0.1
|
||||
github.com/zeromicro/ddl-parser v1.0.5
|
||||
github.com/zeromicro/go-zero v1.5.5
|
||||
golang.org/x/text v0.12.0
|
||||
google.golang.org/grpc v1.57.0
|
||||
golang.org/x/text v0.13.0
|
||||
google.golang.org/grpc v1.58.2
|
||||
google.golang.org/protobuf v1.31.0
|
||||
)
|
||||
|
||||
@@ -92,14 +92,14 @@ require (
|
||||
go.uber.org/zap v1.24.0 // indirect
|
||||
golang.org/x/crypto v0.12.0 // indirect
|
||||
golang.org/x/net v0.14.0 // indirect
|
||||
golang.org/x/oauth2 v0.7.0 // indirect
|
||||
golang.org/x/oauth2 v0.10.0 // indirect
|
||||
golang.org/x/sys v0.11.0 // indirect
|
||||
golang.org/x/term v0.11.0 // indirect
|
||||
golang.org/x/time v0.3.0 // indirect
|
||||
google.golang.org/appengine v1.6.7 // indirect
|
||||
google.golang.org/genproto v0.0.0-20230526161137-0005af68ea54 // indirect
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20230525234035-dd9d682886f9 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20230525234030-28d5490b6b19 // indirect
|
||||
google.golang.org/genproto v0.0.0-20230711160842-782d3b101e98 // indirect
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20230711160842-782d3b101e98 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20230711160842-782d3b101e98 // indirect
|
||||
gopkg.in/inf.v0 v0.9.1 // indirect
|
||||
gopkg.in/yaml.v2 v2.4.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
|
||||
@@ -415,8 +415,8 @@ golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4Iltr
|
||||
golang.org/x/oauth2 v0.0.0-20191202225959-858c2ad4c8b6/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||
golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||
golang.org/x/oauth2 v0.0.0-20211104180415-d3ed0bb246c8/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A=
|
||||
golang.org/x/oauth2 v0.7.0 h1:qe6s0zUXlPX80/dITx3440hWZ7GwMwgDDyrSGTPJG/g=
|
||||
golang.org/x/oauth2 v0.7.0/go.mod h1:hPLQkd9LyjfXTiRohC/41GhcFqxisoUQ99sCUOHO9x4=
|
||||
golang.org/x/oauth2 v0.10.0 h1:zHCpF2Khkwy4mMB4bv0U37YtJdTGW8jI0glAApi0Kh8=
|
||||
golang.org/x/oauth2 v0.10.0/go.mod h1:kTpgurOux7LqtuxjuyZa4Gj2gdezIt/jQtGnNFfypQI=
|
||||
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
@@ -470,8 +470,8 @@ golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3
|
||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.12.0 h1:k+n5B8goJNdU7hSvEtMUz3d1Q6D/XW4COJSJR6fN0mc=
|
||||
golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
||||
golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k=
|
||||
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
||||
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
@@ -580,12 +580,12 @@ google.golang.org/genproto v0.0.0-20200804131852-c06518451d9c/go.mod h1:FWY/as6D
|
||||
google.golang.org/genproto v0.0.0-20200825200019-8632dd797987/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no=
|
||||
google.golang.org/genproto v0.0.0-20201019141844-1ed22bb0c154/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no=
|
||||
google.golang.org/genproto v0.0.0-20211118181313-81c1377c94b1/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc=
|
||||
google.golang.org/genproto v0.0.0-20230526161137-0005af68ea54 h1:9NWlQfY2ePejTmfwUH1OWwmznFa+0kKcHGPDvcPza9M=
|
||||
google.golang.org/genproto v0.0.0-20230526161137-0005af68ea54/go.mod h1:zqTuNwFlFRsw5zIts5VnzLQxSRqh+CGOTVMlYbY0Eyk=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20230525234035-dd9d682886f9 h1:m8v1xLLLzMe1m5P+gCTF8nJB9epwZQUBERm20Oy1poQ=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20230525234035-dd9d682886f9/go.mod h1:vHYtlOoi6TsQ3Uk2yxR7NI5z8uoV+3pZtR4jmHIkRig=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20230525234030-28d5490b6b19 h1:0nDDozoAU19Qb2HwhXadU8OcsiO/09cnTqhUtq2MEOM=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20230525234030-28d5490b6b19/go.mod h1:66JfowdXAEgad5O9NnYcsNPLCPZJD++2L9X0PCMODrA=
|
||||
google.golang.org/genproto v0.0.0-20230711160842-782d3b101e98 h1:Z0hjGZePRE0ZBWotvtrwxFNrNE9CUAGtplaDK5NNI/g=
|
||||
google.golang.org/genproto v0.0.0-20230711160842-782d3b101e98/go.mod h1:S7mY02OqCJTD0E1OiQy1F72PWFB4bZJ87cAtLPYgDR0=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20230711160842-782d3b101e98 h1:FmF5cCW94Ij59cfpoLiwTgodWmm60eEV0CjlsVg2fuw=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20230711160842-782d3b101e98/go.mod h1:rsr7RhLuwsDKL7RmgDDCUc6yaGr1iqceVb5Wv6f6YvQ=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20230711160842-782d3b101e98 h1:bVf09lpb+OJbByTj913DRJioFFAjf/ZGxEz7MajTp2U=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20230711160842-782d3b101e98/go.mod h1:TUfxEVdsvPg18p6AslUXFoLdpED4oBnGwyqk3dV1XzM=
|
||||
google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
|
||||
google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38=
|
||||
google.golang.org/grpc v1.21.1/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM=
|
||||
@@ -602,8 +602,8 @@ google.golang.org/grpc v1.33.1/go.mod h1:fr5YgcSWrqhRRxogOsw7RzIpsmvOZ6IcH4kBYTp
|
||||
google.golang.org/grpc v1.36.0/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAGRRjU=
|
||||
google.golang.org/grpc v1.40.0/go.mod h1:ogyxbiOoUXAkP+4+xa6PZSE9DZgIHtSpzjDTB9KAK34=
|
||||
google.golang.org/grpc v1.42.0/go.mod h1:k+4IHHFw41K8+bbowsex27ge2rCb65oeWqe4jJ590SU=
|
||||
google.golang.org/grpc v1.57.0 h1:kfzNeI/klCGD2YPMUlaGNT3pxvYfga7smW3Vth8Zsiw=
|
||||
google.golang.org/grpc v1.57.0/go.mod h1:Sd+9RMTACXwmub0zcNY2c4arhtrbBYD1AUHI/dt16Mo=
|
||||
google.golang.org/grpc v1.58.2 h1:SXUpjxeVF3FKrTYQI4f4KvbGD5u2xccdYdurwowix5I=
|
||||
google.golang.org/grpc v1.58.2/go.mod h1:tgX3ZQDlNJGU96V6yHh1T/JeoBQ2TXdr43YbYSsCJk0=
|
||||
google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8=
|
||||
google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0=
|
||||
google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM=
|
||||
|
||||
@@ -72,7 +72,8 @@ from (
|
||||
t.typname AS type,
|
||||
a.atttypmod AS lengthvar,
|
||||
a.attnotnull AS not_null,
|
||||
b.description AS comment
|
||||
b.description AS comment,
|
||||
(c.relnamespace::regnamespace)::varchar AS schema_name
|
||||
FROM pg_class c,
|
||||
pg_attribute a
|
||||
LEFT OUTER JOIN pg_description b ON a.attrelid = b.objoid AND a.attnum = b.objsubid,
|
||||
@@ -81,10 +82,11 @@ from (
|
||||
and a.attnum > 0
|
||||
and a.attrelid = c.oid
|
||||
and a.atttypid = t.oid
|
||||
GROUP BY a.attnum, c.relname, a.attname, t.typname, a.atttypmod, a.attnotnull, b.description
|
||||
GROUP BY a.attnum, c.relname, a.attname, t.typname, a.atttypmod, a.attnotnull, b.description, c.relnamespace::regnamespace
|
||||
ORDER BY a.attnum) AS t
|
||||
left join information_schema.columns AS c on t.relname = c.table_name
|
||||
and t.field = c.column_name and c.table_schema = $2`
|
||||
left join information_schema.columns AS c on t.relname = c.table_name and t.schema_name = c.table_schema
|
||||
and t.field = c.column_name
|
||||
where c.table_schema = $2`
|
||||
|
||||
var reply []*PostgreColumn
|
||||
err := m.conn.QueryRowsPartial(&reply, querySql, table, schema)
|
||||
|
||||
@@ -147,3 +147,6 @@ func (m mockClientConn) UpdateAddresses(addresses []resolver.Address) {
|
||||
|
||||
func (m mockClientConn) Connect() {
|
||||
}
|
||||
|
||||
func (m mockClientConn) Shutdown() {
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user