diff --git a/tools/goctl/api/spec/fn.go b/tools/goctl/api/spec/fn.go index 6ca97d99..8821b3a7 100644 --- a/tools/goctl/api/spec/fn.go +++ b/tools/goctl/api/spec/fn.go @@ -13,10 +13,11 @@ const ( bodyTagKey = "json" formTagKey = "form" pathTagKey = "path" + headerTagKey = "header" defaultSummaryKey = "summary" ) -var definedKeys = []string{bodyTagKey, formTagKey, pathTagKey} +var definedKeys = []string{bodyTagKey, formTagKey, pathTagKey, headerTagKey} func (s Service) JoinPrefix() Service { var groups []Group @@ -138,6 +139,21 @@ func (m Member) IsFormMember() bool { return false } +// IsTagMember returns true if contains given tag +func (m Member) IsTagMember(tagKey string) bool { + if m.IsInline { + return true + } + + tags := m.Tags() + for _, tag := range tags { + if tag.Key == tagKey { + return true + } + } + return false +} + // GetBodyMembers returns all json fields func (t DefineStruct) GetBodyMembers() []Member { var result []Member @@ -171,6 +187,17 @@ func (t DefineStruct) GetNonBodyMembers() []Member { return result } +// GetTagMembers returns all given key fields +func (t DefineStruct) GetTagMembers(tagKey string) []Member { + var result []Member + for _, member := range t.Members { + if member.IsTagMember(tagKey) { + result = append(result, member) + } + } + return result +} + // JoinedDoc joins comments and summary value in AtDoc func (r Route) JoinedDoc() string { doc := r.AtDoc.Text diff --git a/tools/goctl/api/tsgen/genpacket.go b/tools/goctl/api/tsgen/genpacket.go index 99296ad4..73ec4d79 100644 --- a/tools/goctl/api/tsgen/genpacket.go +++ b/tools/goctl/api/tsgen/genpacket.go @@ -105,20 +105,43 @@ func paramsForRoute(route spec.Route) string { } hasParams := pathHasParams(route) hasBody := hasRequestBody(route) + hasHeader := hasRequestHeader(route) + hasPath := hasRequestPath(route) rt, err := goTypeToTs(route.RequestType, true) if err != nil { fmt.Println(err.Error()) return "" } - - if hasParams && hasBody { - return fmt.Sprintf("params: %s, req: %s", rt+"Params", rt) - } else if hasParams { - return fmt.Sprintf("params: %s", rt+"Params") - } else if hasBody { - return fmt.Sprintf("req: %s", rt) + var params []string + if hasParams { + params = append(params, fmt.Sprintf("params: %s", rt+"Params")) } - return "" + if hasBody { + params = append(params, fmt.Sprintf("req: %s", rt)) + } + if hasHeader { + params = append(params, fmt.Sprintf("headers: %s", rt+"Headers")) + } + if hasPath { + ds, ok := route.RequestType.(spec.DefineStruct) + if !ok { + fmt.Printf("invalid route.RequestType: {%v}\n", route.RequestType) + } + members := ds.GetTagMembers(pathTagKey) + for _, member := range members { + tags := member.Tags() + + if len(tags) > 0 && tags[0].Key == pathTagKey { + valueType, err := goTypeToTs(member.Type, false) + if err != nil { + fmt.Println(err.Error()) + return "" + } + params = append(params, fmt.Sprintf("%s: %s", tags[0].Name, valueType)) + } + } + } + return strings.Join(params, ", ") } func commentForRoute(route spec.Route) string { @@ -128,13 +151,15 @@ func commentForRoute(route spec.Route) string { builder.WriteString("\n * @description " + comment) hasParams := pathHasParams(route) hasBody := hasRequestBody(route) - if hasParams && hasBody { + hasHeader := hasRequestHeader(route) + if hasParams { builder.WriteString("\n * @param params") + } + if hasBody { builder.WriteString("\n * @param req") - } else if hasParams { - builder.WriteString("\n * @param params") - } else if hasBody { - builder.WriteString("\n * @param req") + } + if hasHeader { + builder.WriteString("\n * @param headers") } builder.WriteString("\n */") return builder.String() @@ -143,26 +168,42 @@ func commentForRoute(route spec.Route) string { func callParamsForRoute(route spec.Route, group spec.Group) string { hasParams := pathHasParams(route) hasBody := hasRequestBody(route) - if hasParams && hasBody { - return fmt.Sprintf("%s, %s, %s", pathForRoute(route, group), "params", "req") - } else if hasParams { - return fmt.Sprintf("%s, %s", pathForRoute(route, group), "params") - } else if hasBody { - return fmt.Sprintf("%s, %s", pathForRoute(route, group), "req") + hasHeader := hasRequestHeader(route) + + var params = []string{pathForRoute(route, group)} + if hasParams { + params = append(params, "params") + } + if hasBody { + params = append(params, "req") + } + if hasHeader { + params = append(params, "headers") } - return pathForRoute(route, group) + return strings.Join(params, ", ") } func pathForRoute(route spec.Route, group spec.Group) string { prefix := group.GetAnnotation(pathPrefix) + + routePath := route.Path + if strings.Contains(routePath, ":") { + pathSlice := strings.Split(routePath, "/") + for i, part := range pathSlice { + if strings.Contains(part, ":") { + pathSlice[i] = fmt.Sprintf("${%s}", part[1:]) + } + } + routePath = strings.Join(pathSlice, "/") + } if len(prefix) == 0 { - return "\"" + route.Path + "\"" + return "`" + routePath + "`" } prefix = strings.TrimPrefix(prefix, `"`) prefix = strings.TrimSuffix(prefix, `"`) - return fmt.Sprintf(`"%s/%s"`, prefix, strings.TrimPrefix(route.Path, "/")) + return fmt.Sprintf("`%s/%s`", prefix, strings.TrimPrefix(routePath, "/")) } func pathHasParams(route spec.Route) bool { @@ -182,3 +223,21 @@ func hasRequestBody(route spec.Route) bool { return len(route.RequestTypeName()) > 0 && len(ds.GetBodyMembers()) > 0 } + +func hasRequestPath(route spec.Route) bool { + ds, ok := route.RequestType.(spec.DefineStruct) + if !ok { + return false + } + + return len(route.RequestTypeName()) > 0 && len(ds.GetTagMembers(pathTagKey)) > 0 +} + +func hasRequestHeader(route spec.Route) bool { + ds, ok := route.RequestType.(spec.DefineStruct) + if !ok { + return false + } + + return len(route.RequestTypeName()) > 0 && len(ds.GetTagMembers(headerTagKey)) > 0 +} diff --git a/tools/goctl/api/tsgen/util.go b/tools/goctl/api/tsgen/util.go index 5c608b51..27c230d9 100644 --- a/tools/goctl/api/tsgen/util.go +++ b/tools/goctl/api/tsgen/util.go @@ -11,6 +11,12 @@ import ( "github.com/zeromicro/go-zero/tools/goctl/util" ) +const ( + formTagKey = "form" + pathTagKey = "path" + headerTagKey = "header" +) + func writeProperty(writer io.Writer, member spec.Member, indent int) error { writeIndent(writer, indent) ty, err := goTypeToTs(member.Type, false) @@ -129,13 +135,21 @@ func genParamsTypesIfNeed(writer io.Writer, tp spec.Type) error { if len(members) == 0 { return nil } - fmt.Fprintf(writer, "\n") + fmt.Fprintf(writer, "export interface %sParams {\n", util.Title(tp.Name())) - if err := writeMembers(writer, tp, true); err != nil { + if err := writeTagMembers(writer, tp, formTagKey); err != nil { return err } - fmt.Fprintf(writer, "}\n") + + if len(definedType.GetTagMembers(headerTagKey)) > 0 { + fmt.Fprintf(writer, "export interface %sHeaders {\n", util.Title(tp.Name())) + if err := writeTagMembers(writer, tp, headerTagKey); err != nil { + return err + } + fmt.Fprintf(writer, "}\n") + } + return nil } @@ -168,3 +182,30 @@ func writeMembers(writer io.Writer, tp spec.Type, isParam bool) error { } return nil } + +func writeTagMembers(writer io.Writer, tp spec.Type, tagKey string) error { + definedType, ok := tp.(spec.DefineStruct) + if !ok { + pointType, ok := tp.(spec.PointerType) + if ok { + return writeTagMembers(writer, pointType.Type, tagKey) + } + + return fmt.Errorf("type %s not supported", tp.Name()) + } + + members := definedType.GetTagMembers(tagKey) + for _, member := range members { + if member.IsInline { + if err := writeTagMembers(writer, member.Type, tagKey); err != nil { + return err + } + continue + } + + if err := writeProperty(writer, member, 1); err != nil { + return apiutil.WrapErr(err, " type "+tp.Name()) + } + } + return nil +}