add more tests for rest (#462)
This commit is contained in:
169
rest/internal/security/contentsecurity_test.go
Normal file
169
rest/internal/security/contentsecurity_test.go
Normal file
@@ -0,0 +1,169 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/md5"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/tal-tech/go-zero/core/codec"
|
||||
"github.com/tal-tech/go-zero/core/fs"
|
||||
"github.com/tal-tech/go-zero/rest/httpx"
|
||||
)
|
||||
|
||||
const (
|
||||
pubKey = `-----BEGIN PUBLIC KEY-----
|
||||
MIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQCyeDYV2ieOtNDi6tuNtAbmUjN9
|
||||
pTHluAU5yiKEz8826QohcxqUKP3hybZBcm60p+rUxMAJFBJ8Dt+UJ6sEMzrf1rOF
|
||||
YOImVvORkXjpFU7sCJkhnLMs/kxtRzcZJG6ADUlG4GDCNcZpY/qELEvwgm2kCcHi
|
||||
tGC2mO8opFFFHTR0aQIDAQAB
|
||||
-----END PUBLIC KEY-----`
|
||||
priKey = `-----BEGIN RSA PRIVATE KEY-----
|
||||
MIICXQIBAAKBgQCyeDYV2ieOtNDi6tuNtAbmUjN9pTHluAU5yiKEz8826QohcxqU
|
||||
KP3hybZBcm60p+rUxMAJFBJ8Dt+UJ6sEMzrf1rOFYOImVvORkXjpFU7sCJkhnLMs
|
||||
/kxtRzcZJG6ADUlG4GDCNcZpY/qELEvwgm2kCcHitGC2mO8opFFFHTR0aQIDAQAB
|
||||
AoGAcENv+jT9VyZkk6karLuG75DbtPiaN5+XIfAF4Ld76FWVOs9V88cJVON20xpx
|
||||
ixBphqexCMToj8MnXuHJEN5M9H15XXx/9IuiMm3FOw0i6o0+4V8XwHr47siT6T+r
|
||||
HuZEyXER/2qrm0nxyC17TXtd/+TtpfQWSbivl6xcAEo9RRECQQDj6OR6AbMQAIDn
|
||||
v+AhP/y7duDZimWJIuMwhigA1T2qDbtOoAEcjv3DB1dAswJ7clcnkxI9a6/0RDF9
|
||||
0IEHUcX9AkEAyHdcegWiayEnbatxWcNWm1/5jFnCN+GTRRFrOhBCyFr2ZdjFV4T+
|
||||
acGtG6omXWaZJy1GZz6pybOGy93NwLB93QJARKMJ0/iZDbOpHqI5hKn5mhd2Je25
|
||||
IHDCTQXKHF4cAQ+7njUvwIMLx2V5kIGYuMa5mrB/KMI6rmyvHv3hLewhnQJBAMMb
|
||||
cPUOENMllINnzk2oEd3tXiscnSvYL4aUeoErnGP2LERZ40/YD+mMZ9g6FVboaX04
|
||||
0oHf+k5mnXZD7WJyJD0CQQDJ2HyFbNaUUHK+lcifCibfzKTgmnNh9ZpePFumgJzI
|
||||
EfFE5H+nzsbbry2XgJbWzRNvuFTOLWn4zM+aFyy9WvbO
|
||||
-----END RSA PRIVATE KEY-----`
|
||||
body = "hello world!"
|
||||
)
|
||||
|
||||
var key = []byte("q4t7w!z%C*F-JaNdRgUjXn2r5u8x/A?D")
|
||||
|
||||
func TestContentSecurity(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
mode string
|
||||
extraKey string
|
||||
extraSecret string
|
||||
extraTime string
|
||||
err error
|
||||
code int
|
||||
}{
|
||||
{
|
||||
name: "encrypted",
|
||||
mode: "1",
|
||||
},
|
||||
{
|
||||
name: "unencrypted",
|
||||
mode: "0",
|
||||
},
|
||||
{
|
||||
name: "bad content type",
|
||||
mode: "a",
|
||||
err: ErrInvalidContentType,
|
||||
},
|
||||
{
|
||||
name: "bad secret",
|
||||
mode: "1",
|
||||
extraSecret: "any",
|
||||
err: ErrInvalidSecret,
|
||||
},
|
||||
{
|
||||
name: "bad key",
|
||||
mode: "1",
|
||||
extraKey: "any",
|
||||
err: ErrInvalidKey,
|
||||
},
|
||||
{
|
||||
name: "bad time",
|
||||
mode: "1",
|
||||
extraTime: "any",
|
||||
code: httpx.CodeSignatureInvalidHeader,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
test := test
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
r, err := http.NewRequest(http.MethodPost, "http://localhost:3333/a/b?c=first&d=second",
|
||||
strings.NewReader(body))
|
||||
assert.Nil(t, err)
|
||||
|
||||
timestamp := time.Now().Unix()
|
||||
sha := sha256.New()
|
||||
sha.Write([]byte(body))
|
||||
bodySign := fmt.Sprintf("%x", sha.Sum(nil))
|
||||
contentOfSign := strings.Join([]string{
|
||||
strconv.FormatInt(timestamp, 10),
|
||||
http.MethodPost,
|
||||
r.URL.Path,
|
||||
r.URL.RawQuery,
|
||||
bodySign,
|
||||
}, "\n")
|
||||
sign := hs256(key, contentOfSign)
|
||||
content := strings.Join([]string{
|
||||
"version=v1",
|
||||
"type=" + test.mode,
|
||||
fmt.Sprintf("key=%s", base64.StdEncoding.EncodeToString(key)) + test.extraKey,
|
||||
"time=" + strconv.FormatInt(timestamp, 10) + test.extraTime,
|
||||
}, "; ")
|
||||
|
||||
encrypter, err := codec.NewRsaEncrypter([]byte(pubKey))
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
output, err := encrypter.Encrypt([]byte(content))
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
encryptedContent := base64.StdEncoding.EncodeToString(output)
|
||||
r.Header.Set("X-Content-Security", strings.Join([]string{
|
||||
fmt.Sprintf("key=%s", fingerprint(pubKey)),
|
||||
"secret=" + encryptedContent + test.extraSecret,
|
||||
"signature=" + sign,
|
||||
}, "; "))
|
||||
|
||||
file, err := fs.TempFilenameWithText(priKey)
|
||||
assert.Nil(t, err)
|
||||
defer os.Remove(file)
|
||||
|
||||
dec, err := codec.NewRsaDecrypter(file)
|
||||
assert.Nil(t, err)
|
||||
|
||||
header, err := ParseContentSecurity(map[string]codec.RsaDecrypter{
|
||||
fingerprint(pubKey): dec,
|
||||
}, r)
|
||||
assert.Equal(t, test.err, err)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
assert.Equal(t, test.code, VerifySignature(r, header, time.Minute))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func fingerprint(key string) string {
|
||||
h := md5.New()
|
||||
io.WriteString(h, key)
|
||||
return base64.StdEncoding.EncodeToString(h.Sum(nil))
|
||||
}
|
||||
|
||||
func hs256(key []byte, body string) string {
|
||||
h := hmac.New(sha256.New, key)
|
||||
io.WriteString(h, body)
|
||||
return base64.StdEncoding.EncodeToString(h.Sum(nil))
|
||||
}
|
||||
Reference in New Issue
Block a user