Skip to content

Commit fbc74c8

Browse files
committed
Decorrelates openapi registration from server
1 parent caac676 commit fbc74c8

File tree

11 files changed

+115
-75
lines changed

11 files changed

+115
-75
lines changed

examples/petstore/lib/server_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ func TestPetstoreOpenAPIGeneration(t *testing.T) {
2121
)
2222

2323
server.OutputOpenAPISpec()
24-
err := server.OpenApiSpec.Validate(context.Background())
24+
err := server.OpenAPIzer.OpenAPIDescription().Validate(context.Background())
2525
require.NoError(t, err)
2626

2727
generatedSpec, err := os.ReadFile("testdata/doc/openapi.json")

examples/petstore/lib/testdata/doc/openapi.golden.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
{
2-
"openapi": "3.1.0",
32
"components": {
43
"schemas": {
54
"HTTPError": {
@@ -149,6 +148,7 @@
149148
"title": "OpenAPI",
150149
"version": "0.0.1"
151150
},
151+
"openapi": "3.1.0",
152152
"paths": {
153153
"/pets/": {
154154
"get": {

mux.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ func Register[T, B any](s *Server, route Route[T, B], controller http.Handler, o
112112
route.Path = s.basePath + route.Path
113113

114114
var err error
115-
route.Operation, err = RegisterOpenAPIOperation(s, route)
115+
route.Operation, err = RegisterOpenAPIOperation(s.OpenAPIzer, route)
116116
if err != nil {
117117
slog.Warn("error documenting openapi operation", "error", err)
118118
}

mux_test.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -488,8 +488,8 @@ func TestHideOpenapiRoutes(t *testing.T) {
488488
Get(s, "/test", func(ctx *ContextNoBody) (string, error) { return "", nil })
489489

490490
require.Equal(t, s.DisableOpenapi, true)
491-
require.True(t, s.OpenApiSpec.Paths.Find("/not-hidden") != nil)
492-
require.True(t, s.OpenApiSpec.Paths.Find("/test") == nil)
491+
require.True(t, s.OpenAPIzer.OpenAPIDescription().Paths.Find("/not-hidden") != nil)
492+
require.True(t, s.OpenAPIzer.OpenAPIDescription().Paths.Find("/test") == nil)
493493
})
494494

495495
t.Run("hide group", func(t *testing.T) {
@@ -500,8 +500,8 @@ func TestHideOpenapiRoutes(t *testing.T) {
500500
Get(g, "/test", func(ctx *ContextNoBody) (string, error) { return "", nil })
501501

502502
require.Equal(t, g.DisableOpenapi, true)
503-
require.True(t, s.OpenApiSpec.Paths.Find("/not-hidden") != nil)
504-
require.True(t, s.OpenApiSpec.Paths.Find("/group/test") == nil)
503+
require.True(t, s.OpenAPIzer.OpenAPIDescription().Paths.Find("/not-hidden") != nil)
504+
require.True(t, s.OpenAPIzer.OpenAPIDescription().Paths.Find("/group/test") == nil)
505505
})
506506

507507
t.Run("hide group but not other group", func(t *testing.T) {
@@ -514,8 +514,8 @@ func TestHideOpenapiRoutes(t *testing.T) {
514514

515515
require.Equal(t, true, g.DisableOpenapi)
516516
require.Equal(t, false, g2.DisableOpenapi)
517-
require.True(t, s.OpenApiSpec.Paths.Find("/group/test") == nil)
518-
require.True(t, s.OpenApiSpec.Paths.Find("/group2/test") != nil)
517+
require.True(t, s.OpenAPIzer.OpenAPIDescription().Paths.Find("/group/test") == nil)
518+
require.True(t, s.OpenAPIzer.OpenAPIDescription().Paths.Find("/group2/test") != nil)
519519
})
520520

521521
t.Run("hide group but show sub group", func(t *testing.T) {
@@ -527,8 +527,8 @@ func TestHideOpenapiRoutes(t *testing.T) {
527527
Get(g2, "/test", func(ctx *ContextNoBody) (string, error) { return "test", nil })
528528

529529
require.Equal(t, true, g.DisableOpenapi)
530-
require.True(t, s.OpenApiSpec.Paths.Find("/group/test") == nil)
531-
require.True(t, s.OpenApiSpec.Paths.Find("/group/sub/test") != nil)
530+
require.True(t, s.OpenAPIzer.OpenAPIDescription().Paths.Find("/group/test") == nil)
531+
require.True(t, s.OpenAPIzer.OpenAPIDescription().Paths.Find("/group/sub/test") != nil)
532532
})
533533
}
534534

openapi.go

Lines changed: 60 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,45 @@ import (
1616
"strings"
1717

1818
"github.com/getkin/kin-openapi/openapi3"
19+
"github.com/getkin/kin-openapi/openapi3gen"
1920
)
2021

22+
type OpenAPIzer interface {
23+
OpenAPIDescription() *openapi3.T
24+
Generator() *openapi3gen.Generator
25+
GlobalOpenAPIResponses() *[]openAPIError
26+
}
27+
28+
func NewSpec() *Spec {
29+
desc := NewOpenApiSpec()
30+
return &Spec{
31+
description: &desc,
32+
generator: openapi3gen.NewGenerator(),
33+
globalOpenAPIResponses: &[]openAPIError{},
34+
}
35+
}
36+
37+
// Holds the OpenAPI OpenAPIDescription (OAD) and OpenAPI capabilities.
38+
type Spec struct {
39+
description *openapi3.T
40+
generator *openapi3gen.Generator
41+
globalOpenAPIResponses *[]openAPIError
42+
}
43+
44+
func (d *Spec) OpenAPIDescription() *openapi3.T {
45+
return d.description
46+
}
47+
48+
func (d *Spec) Generator() *openapi3gen.Generator {
49+
return d.generator
50+
}
51+
52+
func (d *Spec) GlobalOpenAPIResponses() *[]openAPIError {
53+
return d.globalOpenAPIResponses
54+
}
55+
56+
var _ OpenAPIzer = &Spec{}
57+
2158
func NewOpenApiSpec() openapi3.T {
2259
info := &openapi3.Info{
2360
Title: "OpenAPI",
@@ -53,11 +90,11 @@ func (s *Server) Show() *Server {
5390
}
5491

5592
func declareAllTagsFromOperations(s *Server) {
56-
for _, pathItem := range s.OpenApiSpec.Paths.Map() {
93+
for _, pathItem := range s.OpenAPIzer.OpenAPIDescription().Paths.Map() {
5794
for _, op := range pathItem.Operations() {
5895
for _, tag := range op.Tags {
59-
if s.OpenApiSpec.Tags.Get(tag) == nil {
60-
s.OpenApiSpec.Tags = append(s.OpenApiSpec.Tags, &openapi3.Tag{
96+
if s.OpenAPIzer.OpenAPIDescription().Tags.Get(tag) == nil {
97+
s.OpenAPIzer.OpenAPIDescription().Tags = append(s.OpenAPIzer.OpenAPIDescription().Tags, &openapi3.Tag{
6198
Name: tag,
6299
})
63100
}
@@ -73,7 +110,7 @@ func (s *Server) OutputOpenAPISpec() openapi3.T {
73110
declareAllTagsFromOperations(s)
74111

75112
// Validate
76-
err := s.OpenApiSpec.Validate(context.Background())
113+
err := s.OpenAPIzer.OpenAPIDescription().Validate(context.Background())
77114
if err != nil {
78115
slog.Error("Error validating spec", "error", err)
79116
}
@@ -95,14 +132,14 @@ func (s *Server) OutputOpenAPISpec() openapi3.T {
95132
}
96133
}
97134

98-
return s.OpenApiSpec
135+
return *s.OpenAPIzer.OpenAPIDescription()
99136
}
100137

101138
func (s *Server) marshalSpec() ([]byte, error) {
102139
if s.OpenAPIConfig.PrettyFormatJson {
103-
return json.MarshalIndent(s.OpenApiSpec, "", " ")
140+
return json.MarshalIndent(s.OpenAPIzer.OpenAPIDescription(), "", " ")
104141
}
105-
return json.Marshal(s.OpenApiSpec)
142+
return json.Marshal(s.OpenAPIzer.OpenAPIDescription())
106143
}
107144

108145
func (s *Server) saveOpenAPIToFile(jsonSpecLocalPath string, jsonSpec []byte) error {
@@ -164,7 +201,7 @@ func validateSwaggerUrl(swaggerUrl string) bool {
164201
}
165202

166203
// RegisterOpenAPIOperation registers an OpenAPI operation.
167-
func RegisterOpenAPIOperation[T, B any](s *Server, route Route[T, B]) (*openapi3.Operation, error) {
204+
func RegisterOpenAPIOperation[T, B any](s OpenAPIzer, route Route[T, B]) (*openapi3.Operation, error) {
168205
if route.Operation == nil {
169206
route.Operation = openapi3.NewOperation()
170207
}
@@ -184,7 +221,7 @@ func RegisterOpenAPIOperation[T, B any](s *Server, route Route[T, B]) (*openapi3
184221
}
185222

186223
// Response - globals
187-
for _, openAPIGlobalResponse := range s.globalOpenAPIResponses {
224+
for _, openAPIGlobalResponse := range *s.GlobalOpenAPIResponses() {
188225
addResponseIfNotSet(s, route.Operation, openAPIGlobalResponse.Code, openAPIGlobalResponse.Description, openAPIGlobalResponse.ErrorType)
189226
}
190227

@@ -228,7 +265,7 @@ func RegisterOpenAPIOperation[T, B any](s *Server, route Route[T, B]) (*openapi3
228265
}
229266
}
230267

231-
s.OpenApiSpec.AddOperation(route.Path, route.Method, route.Operation)
268+
s.OpenAPIDescription().AddOperation(route.Path, route.Method, route.Operation)
232269

233270
return route.Operation, nil
234271
}
@@ -247,10 +284,10 @@ type SchemaTag struct {
247284
Name string
248285
}
249286

250-
func SchemaTagFromType(s *Server, v any) SchemaTag {
287+
func SchemaTagFromType(s OpenAPIzer, v any) SchemaTag {
251288
if v == nil {
252289
// ensure we add unknown-interface to our schemas
253-
schema := s.getOrCreateSchema("unknown-interface", struct{}{})
290+
schema := getOrCreateSchema(s, "unknown-interface", struct{}{})
254291
return SchemaTag{
255292
Name: "unknown-interface",
256293
SchemaRef: openapi3.SchemaRef{
@@ -270,7 +307,7 @@ func SchemaTagFromType(s *Server, v any) SchemaTag {
270307
// If the type is a slice or array type it will dive into the type as well as
271308
// build and openapi3.Schema where Type is array and Ref is set to the proper
272309
// components Schema
273-
func dive(s *Server, t reflect.Type, tag SchemaTag, maxDepth int) SchemaTag {
310+
func dive(s OpenAPIzer, t reflect.Type, tag SchemaTag, maxDepth int) SchemaTag {
274311
if maxDepth == 0 {
275312
return SchemaTag{
276313
Name: "default",
@@ -297,26 +334,26 @@ func dive(s *Server, t reflect.Type, tag SchemaTag, maxDepth int) SchemaTag {
297334
return dive(s, t.Field(0).Type, tag, maxDepth-1)
298335
}
299336
tag.Ref = "#/components/schemas/" + tag.Name
300-
tag.Value = s.getOrCreateSchema(tag.Name, reflect.New(t).Interface())
337+
tag.Value = getOrCreateSchema(s, tag.Name, reflect.New(t).Interface())
301338

302339
return tag
303340
}
304341
}
305342

306343
// getOrCreateSchema is used to get a schema from the OpenAPI spec.
307344
// If the schema does not exist, it will create a new schema and add it to the OpenAPI spec.
308-
func (s *Server) getOrCreateSchema(key string, v any) *openapi3.Schema {
309-
schemaRef, ok := s.OpenApiSpec.Components.Schemas[key]
345+
func getOrCreateSchema(s OpenAPIzer, key string, v any) *openapi3.Schema {
346+
schemaRef, ok := s.OpenAPIDescription().Components.Schemas[key]
310347
if !ok {
311-
schemaRef = s.createSchema(key, v)
348+
schemaRef = createSchema(s, key, v)
312349
}
313350
return schemaRef.Value
314351
}
315352

316353
// createSchema is used to create a new schema and add it to the OpenAPI spec.
317354
// Relies on the openapi3gen package to generate the schema, and adds custom struct tags.
318-
func (s *Server) createSchema(key string, v any) *openapi3.SchemaRef {
319-
schemaRef, err := s.openAPIGenerator.NewSchemaRefForValue(v, s.OpenApiSpec.Components.Schemas)
355+
func createSchema(s OpenAPIzer, key string, v any) *openapi3.SchemaRef {
356+
schemaRef, err := s.Generator().NewSchemaRefForValue(v, s.OpenAPIDescription().Components.Schemas)
320357
if err != nil {
321358
slog.Error("Error generating schema", "key", key, "error", err)
322359
}
@@ -327,9 +364,9 @@ func (s *Server) createSchema(key string, v any) *openapi3.SchemaRef {
327364
schemaRef.Value.Description = descriptionable.Description()
328365
}
329366

330-
s.parseStructTags(reflect.TypeOf(v), schemaRef)
367+
parseStructTags(reflect.TypeOf(v), schemaRef)
331368

332-
s.OpenApiSpec.Components.Schemas[key] = schemaRef
369+
s.OpenAPIDescription().Components.Schemas[key] = schemaRef
333370

334371
return schemaRef
335372
}
@@ -346,7 +383,7 @@ func (s *Server) createSchema(key string, v any) *openapi3.SchemaRef {
346383
// - min=1 => minLength=1 (for strings)
347384
// - max=100 => max=100 (for integers)
348385
// - max=100 => maxLength=100 (for strings)
349-
func (s *Server) parseStructTags(t reflect.Type, schemaRef *openapi3.SchemaRef) {
386+
func parseStructTags(t reflect.Type, schemaRef *openapi3.SchemaRef) {
350387
if t.Kind() == reflect.Ptr {
351388
t = t.Elem()
352389
}
@@ -360,7 +397,7 @@ func (s *Server) parseStructTags(t reflect.Type, schemaRef *openapi3.SchemaRef)
360397

361398
if field.Anonymous {
362399
fieldType := field.Type
363-
s.parseStructTags(fieldType, schemaRef)
400+
parseStructTags(fieldType, schemaRef)
364401
continue
365402
}
366403

openapi_operations.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ func (r Route[ResponseBody, RequestBody]) Param(paramType ParamType, name, descr
6262
}
6363

6464
// Registers a response for the route, only if error for this code is not already set.
65-
func addResponseIfNotSet(s *Server, operation *openapi3.Operation, code int, description string, errorType ...any) {
65+
func addResponseIfNotSet(s OpenAPIzer, operation *openapi3.Operation, code int, description string, errorType ...any) {
6666
var responseSchema SchemaTag
6767

6868
if len(errorType) > 0 {

openapi_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -545,7 +545,7 @@ func TestDeclareCustom200Response(t *testing.T) {
545545
w.Write([]byte("PNG image"))
546546
}, optionReturnsPNG)
547547

548-
openAPIResponse := s.OpenApiSpec.Paths.Find("/image").Get.Responses.Value("200")
548+
openAPIResponse := s.OpenAPIzer.OpenAPIDescription().Paths.Find("/image").Get.Responses.Value("200")
549549
require.Nil(t, openAPIResponse.Value.Content.Get("application/json"))
550550
require.NotNil(t, openAPIResponse.Value.Content.Get("image/png"))
551551
require.Equal(t, "Generated image", *openAPIResponse.Value.Description)

option.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -302,9 +302,9 @@ func OptionAddError(code int, description string, errorType ...any) func(*BaseRo
302302
}
303303

304304
if len(errorType) > 0 {
305-
responseSchema = SchemaTagFromType(r.mainRouter, errorType[0])
305+
responseSchema = SchemaTagFromType(r.mainRouter.OpenAPIzer, errorType[0])
306306
} else {
307-
responseSchema = SchemaTagFromType(r.mainRouter, HTTPError{})
307+
responseSchema = SchemaTagFromType(r.mainRouter.OpenAPIzer, HTTPError{})
308308
}
309309
content := openapi3.NewContentWithSchemaRef(&responseSchema.SchemaRef, []string{"application/json"})
310310

@@ -374,14 +374,14 @@ func OptionDefaultStatusCode(defaultStatusCode int) func(*BaseRoute) {
374374
// })
375375
func OptionSecurity(securityRequirements ...openapi3.SecurityRequirement) func(*BaseRoute) {
376376
return func(r *BaseRoute) {
377-
if r.mainRouter.OpenApiSpec.Components == nil {
377+
if r.mainRouter.OpenAPIzer.OpenAPIDescription().Components == nil {
378378
panic("zero security schemes have been registered with the server")
379379
}
380380

381381
// Validate the security scheme exists in components
382382
for _, req := range securityRequirements {
383383
for schemeName := range req {
384-
if _, exists := r.mainRouter.OpenApiSpec.Components.SecuritySchemes[schemeName]; !exists {
384+
if _, exists := r.mainRouter.OpenAPIzer.OpenAPIDescription().Components.SecuritySchemes[schemeName]; !exists {
385385
panic(fmt.Sprintf("security scheme '%s' not defined in components", schemeName))
386386
}
387387
}

option_test.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -299,8 +299,8 @@ func TestPath(t *testing.T) {
299299

300300
fuego.Get(s, "/test/{id}", helloWorld)
301301

302-
require.Equal(t, "id", s.OpenApiSpec.Paths.Find("/test/{id}").Get.Parameters.GetByInAndName("path", "id").Name)
303-
require.Equal(t, "", s.OpenApiSpec.Paths.Find("/test/{id}").Get.Parameters.GetByInAndName("path", "id").Description)
302+
require.Equal(t, "id", s.OpenAPIzer.OpenAPIDescription().Paths.Find("/test/{id}").Get.Parameters.GetByInAndName("path", "id").Name)
303+
require.Equal(t, "", s.OpenAPIzer.OpenAPIDescription().Paths.Find("/test/{id}").Get.Parameters.GetByInAndName("path", "id").Description)
304304
})
305305

306306
t.Run("Declare explicitly an existing path parameter for the route", func(t *testing.T) {
@@ -310,9 +310,9 @@ func TestPath(t *testing.T) {
310310
fuego.OptionPath("id", "some id", param.Example("123", "123"), param.Nullable()),
311311
)
312312

313-
require.Equal(t, "id", s.OpenApiSpec.Paths.Find("/test/{id}").Get.Parameters.GetByInAndName("path", "id").Name)
314-
require.Equal(t, "some id", s.OpenApiSpec.Paths.Find("/test/{id}").Get.Parameters.GetByInAndName("path", "id").Description)
315-
require.Equal(t, true, s.OpenApiSpec.Paths.Find("/test/{id}").Get.Parameters.GetByInAndName("path", "id").Required, "path parameter is forced to be required")
313+
require.Equal(t, "id", s.OpenAPIzer.OpenAPIDescription().Paths.Find("/test/{id}").Get.Parameters.GetByInAndName("path", "id").Name)
314+
require.Equal(t, "some id", s.OpenAPIzer.OpenAPIDescription().Paths.Find("/test/{id}").Get.Parameters.GetByInAndName("path", "id").Description)
315+
require.Equal(t, true, s.OpenAPIzer.OpenAPIDescription().Paths.Find("/test/{id}").Get.Parameters.GetByInAndName("path", "id").Required, "path parameter is forced to be required")
316316
})
317317

318318
t.Run("Declare explicitly a non-existing path parameter for the route panics", func(t *testing.T) {
@@ -353,7 +353,7 @@ func TestRequestContentType(t *testing.T) {
353353
require.NotNil(t, content.Get("application/json"))
354354
require.Nil(t, content.Get("application/xml"))
355355
require.Equal(t, "#/components/schemas/ReqBody", content.Get("application/json").Schema.Ref)
356-
_, ok := s.OpenApiSpec.Components.RequestBodies["ReqBody"]
356+
_, ok := s.OpenAPIzer.OpenAPIDescription().Components.RequestBodies["ReqBody"]
357357
require.False(t, ok)
358358
})
359359

@@ -369,7 +369,7 @@ func TestRequestContentType(t *testing.T) {
369369
require.Nil(t, content.Get("application/xml"))
370370
require.Equal(t, "#/components/schemas/ReqBody", content.Get("application/json").Schema.Ref)
371371
require.Equal(t, "#/components/schemas/ReqBody", content.Get("my/content-type").Schema.Ref)
372-
_, ok := s.OpenApiSpec.Components.RequestBodies["ReqBody"]
372+
_, ok := s.OpenAPIzer.OpenAPIDescription().Components.RequestBodies["ReqBody"]
373373
require.False(t, ok)
374374
})
375375

@@ -385,7 +385,7 @@ func TestRequestContentType(t *testing.T) {
385385
require.Nil(t, content.Get("application/xml"))
386386
require.NotNil(t, content.Get("my/content-type"))
387387
require.Equal(t, "#/components/schemas/ReqBody", content.Get("my/content-type").Schema.Ref)
388-
_, ok := s.OpenApiSpec.Components.RequestBodies["ReqBody"]
388+
_, ok := s.OpenAPIzer.OpenAPIDescription().Components.RequestBodies["ReqBody"]
389389
require.False(t, ok)
390390
})
391391
}

0 commit comments

Comments
 (0)