chore: refine rest validator (#2928)
* chore: refine rest validator * chore: add more tests * chore: reformat code * chore: add comments
This commit is contained in:
@@ -4,6 +4,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
|
|
||||||
"github.com/zeromicro/go-zero/core/mapping"
|
"github.com/zeromicro/go-zero/core/mapping"
|
||||||
"github.com/zeromicro/go-zero/rest/internal/encoding"
|
"github.com/zeromicro/go-zero/rest/internal/encoding"
|
||||||
@@ -23,15 +24,13 @@ const (
|
|||||||
var (
|
var (
|
||||||
formUnmarshaler = mapping.NewUnmarshaler(formKey, mapping.WithStringValues())
|
formUnmarshaler = mapping.NewUnmarshaler(formKey, mapping.WithStringValues())
|
||||||
pathUnmarshaler = mapping.NewUnmarshaler(pathKey, mapping.WithStringValues())
|
pathUnmarshaler = mapping.NewUnmarshaler(pathKey, mapping.WithStringValues())
|
||||||
xValidator Validator
|
validator atomic.Value
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Validator defines the interface for validating the request.
|
||||||
type Validator interface {
|
type Validator interface {
|
||||||
Validate(data interface{}, lang string) error
|
// Validate validates the request and parsed data.
|
||||||
}
|
Validate(r *http.Request, data any) error
|
||||||
|
|
||||||
func SetValidator(validator Validator) {
|
|
||||||
xValidator = validator
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parse parses the request.
|
// Parse parses the request.
|
||||||
@@ -52,9 +51,10 @@ func Parse(r *http.Request, v interface{}) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if xValidator != nil {
|
if val := validator.Load(); val != nil {
|
||||||
return xValidator.Validate(v, r.Header.Get("Accept-Language"))
|
return val.(Validator).Validate(r, v)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -117,6 +117,13 @@ func ParsePath(r *http.Request, v interface{}) error {
|
|||||||
return pathUnmarshaler.Unmarshal(m, v)
|
return pathUnmarshaler.Unmarshal(m, v)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetValidator sets the validator.
|
||||||
|
// The validator is used to validate the request, only called in Parse,
|
||||||
|
// not in ParseHeaders, ParseForm, ParseHeader, ParseJsonBody, ParsePath.
|
||||||
|
func SetValidator(val Validator) {
|
||||||
|
validator.Store(val)
|
||||||
|
}
|
||||||
|
|
||||||
func withJsonBody(r *http.Request) bool {
|
func withJsonBody(r *http.Request) bool {
|
||||||
return r.ContentLength > 0 && strings.Contains(r.Header.Get(header.ContentType), header.ApplicationJson)
|
return r.ContentLength > 0 && strings.Contains(r.Header.Get(header.ContentType), header.ApplicationJson)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,8 +1,10 @@
|
|||||||
package httpx
|
package httpx
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"reflect"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
@@ -207,9 +209,23 @@ func TestParseJsonBody(t *testing.T) {
|
|||||||
r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body))
|
r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body))
|
||||||
r.Header.Set(ContentType, header.JsonContentType)
|
r.Header.Set(ContentType, header.JsonContentType)
|
||||||
|
|
||||||
assert.Nil(t, Parse(r, &v))
|
if assert.NoError(t, Parse(r, &v)) {
|
||||||
assert.Equal(t, "kevin", v.Name)
|
assert.Equal(t, "kevin", v.Name)
|
||||||
assert.Equal(t, 18, v.Age)
|
assert.Equal(t, 18, v.Age)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("bad body", func(t *testing.T) {
|
||||||
|
var v struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Age int `json:"age"`
|
||||||
|
}
|
||||||
|
|
||||||
|
body := `{"name":"kevin", "ag": 18}`
|
||||||
|
r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body))
|
||||||
|
r.Header.Set(ContentType, header.JsonContentType)
|
||||||
|
|
||||||
|
assert.Error(t, Parse(r, &v))
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("hasn't body", func(t *testing.T) {
|
t.Run("hasn't body", func(t *testing.T) {
|
||||||
@@ -308,6 +324,36 @@ func TestParseHeaders_Error(t *testing.T) {
|
|||||||
assert.NotNil(t, Parse(r, &v))
|
assert.NotNil(t, Parse(r, &v))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestParseWithValidator(t *testing.T) {
|
||||||
|
SetValidator(mockValidator{})
|
||||||
|
var v struct {
|
||||||
|
Name string `form:"name"`
|
||||||
|
Age int `form:"age"`
|
||||||
|
Percent float64 `form:"percent,optional"`
|
||||||
|
}
|
||||||
|
|
||||||
|
r, err := http.NewRequest(http.MethodGet, "/a?name=hello&age=18&percent=3.4", http.NoBody)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
if assert.NoError(t, Parse(r, &v)) {
|
||||||
|
assert.Equal(t, "hello", v.Name)
|
||||||
|
assert.Equal(t, 18, v.Age)
|
||||||
|
assert.Equal(t, 3.4, v.Percent)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseWithValidatorWithError(t *testing.T) {
|
||||||
|
SetValidator(mockValidator{})
|
||||||
|
var v struct {
|
||||||
|
Name string `form:"name"`
|
||||||
|
Age int `form:"age"`
|
||||||
|
Percent float64 `form:"percent,optional"`
|
||||||
|
}
|
||||||
|
|
||||||
|
r, err := http.NewRequest(http.MethodGet, "/a?name=world&age=18&percent=3.4", http.NoBody)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Error(t, Parse(r, &v))
|
||||||
|
}
|
||||||
|
|
||||||
func BenchmarkParseRaw(b *testing.B) {
|
func BenchmarkParseRaw(b *testing.B) {
|
||||||
r, err := http.NewRequest(http.MethodGet, "http://hello.com/a?name=hello&age=18&percent=3.4", http.NoBody)
|
r, err := http.NewRequest(http.MethodGet, "http://hello.com/a?name=hello&age=18&percent=3.4", http.NoBody)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -351,3 +397,16 @@ func BenchmarkParseAuto(b *testing.B) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type mockValidator struct{}
|
||||||
|
|
||||||
|
func (m mockValidator) Validate(r *http.Request, data any) error {
|
||||||
|
if r.URL.Path == "/a" {
|
||||||
|
val := reflect.ValueOf(data).Elem().FieldByName("Name").String()
|
||||||
|
if val != "hello" {
|
||||||
|
return errors.New("name is not hello")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user