@@ -16,8 +16,45 @@ import (
16
16
"strings"
17
17
18
18
"github.com/getkin/kin-openapi/openapi3"
19
+ "github.com/getkin/kin-openapi/openapi3gen"
19
20
)
20
21
22
+ type OpenAPIzer interface {
23
+ OpenAPIDescription () * openapi3.T
24
+ Generator () * openapi3gen.Generator
25
+ GlobalOpenAPIResponses () * []openAPIError
26
+ }
27
+
28
+ func NewSpec () * Spec {
29
+ desc := NewOpenApiSpec ()
30
+ return & Spec {
31
+ description : & desc ,
32
+ generator : openapi3gen .NewGenerator (),
33
+ globalOpenAPIResponses : & []openAPIError {},
34
+ }
35
+ }
36
+
37
+ // Holds the OpenAPI OpenAPIDescription (OAD) and OpenAPI capabilities.
38
+ type Spec struct {
39
+ description * openapi3.T
40
+ generator * openapi3gen.Generator
41
+ globalOpenAPIResponses * []openAPIError
42
+ }
43
+
44
+ func (d * Spec ) OpenAPIDescription () * openapi3.T {
45
+ return d .description
46
+ }
47
+
48
+ func (d * Spec ) Generator () * openapi3gen.Generator {
49
+ return d .generator
50
+ }
51
+
52
+ func (d * Spec ) GlobalOpenAPIResponses () * []openAPIError {
53
+ return d .globalOpenAPIResponses
54
+ }
55
+
56
+ var _ OpenAPIzer = & Spec {}
57
+
21
58
func NewOpenApiSpec () openapi3.T {
22
59
info := & openapi3.Info {
23
60
Title : "OpenAPI" ,
@@ -53,11 +90,11 @@ func (s *Server) Show() *Server {
53
90
}
54
91
55
92
func declareAllTagsFromOperations (s * Server ) {
56
- for _ , pathItem := range s .OpenApiSpec .Paths .Map () {
93
+ for _ , pathItem := range s .OpenAPIzer . OpenAPIDescription () .Paths .Map () {
57
94
for _ , op := range pathItem .Operations () {
58
95
for _ , tag := range op .Tags {
59
- if s .OpenApiSpec .Tags .Get (tag ) == nil {
60
- s .OpenApiSpec . Tags = append (s .OpenApiSpec .Tags , & openapi3.Tag {
96
+ if s .OpenAPIzer . OpenAPIDescription () .Tags .Get (tag ) == nil {
97
+ s .OpenAPIzer . OpenAPIDescription (). Tags = append (s .OpenAPIzer . OpenAPIDescription () .Tags , & openapi3.Tag {
61
98
Name : tag ,
62
99
})
63
100
}
@@ -73,7 +110,7 @@ func (s *Server) OutputOpenAPISpec() openapi3.T {
73
110
declareAllTagsFromOperations (s )
74
111
75
112
// Validate
76
- err := s .OpenApiSpec .Validate (context .Background ())
113
+ err := s .OpenAPIzer . OpenAPIDescription () .Validate (context .Background ())
77
114
if err != nil {
78
115
slog .Error ("Error validating spec" , "error" , err )
79
116
}
@@ -95,14 +132,14 @@ func (s *Server) OutputOpenAPISpec() openapi3.T {
95
132
}
96
133
}
97
134
98
- return s . OpenApiSpec
135
+ return * s . OpenAPIzer . OpenAPIDescription ()
99
136
}
100
137
101
138
func (s * Server ) marshalSpec () ([]byte , error ) {
102
139
if s .OpenAPIConfig .PrettyFormatJson {
103
- return json .MarshalIndent (s .OpenApiSpec , "" , " " )
140
+ return json .MarshalIndent (s .OpenAPIzer . OpenAPIDescription () , "" , " " )
104
141
}
105
- return json .Marshal (s .OpenApiSpec )
142
+ return json .Marshal (s .OpenAPIzer . OpenAPIDescription () )
106
143
}
107
144
108
145
func (s * Server ) saveOpenAPIToFile (jsonSpecLocalPath string , jsonSpec []byte ) error {
@@ -164,7 +201,7 @@ func validateSwaggerUrl(swaggerUrl string) bool {
164
201
}
165
202
166
203
// RegisterOpenAPIOperation registers an OpenAPI operation.
167
- func RegisterOpenAPIOperation [T , B any ](s * Server , route Route [T , B ]) (* openapi3.Operation , error ) {
204
+ func RegisterOpenAPIOperation [T , B any ](s OpenAPIzer , route Route [T , B ]) (* openapi3.Operation , error ) {
168
205
if route .Operation == nil {
169
206
route .Operation = openapi3 .NewOperation ()
170
207
}
@@ -184,7 +221,7 @@ func RegisterOpenAPIOperation[T, B any](s *Server, route Route[T, B]) (*openapi3
184
221
}
185
222
186
223
// Response - globals
187
- for _ , openAPIGlobalResponse := range s . globalOpenAPIResponses {
224
+ for _ , openAPIGlobalResponse := range * s . GlobalOpenAPIResponses () {
188
225
addResponseIfNotSet (s , route .Operation , openAPIGlobalResponse .Code , openAPIGlobalResponse .Description , openAPIGlobalResponse .ErrorType )
189
226
}
190
227
@@ -228,7 +265,7 @@ func RegisterOpenAPIOperation[T, B any](s *Server, route Route[T, B]) (*openapi3
228
265
}
229
266
}
230
267
231
- s .OpenApiSpec .AddOperation (route .Path , route .Method , route .Operation )
268
+ s .OpenAPIDescription () .AddOperation (route .Path , route .Method , route .Operation )
232
269
233
270
return route .Operation , nil
234
271
}
@@ -247,10 +284,10 @@ type SchemaTag struct {
247
284
Name string
248
285
}
249
286
250
- func SchemaTagFromType (s * Server , v any ) SchemaTag {
287
+ func SchemaTagFromType (s OpenAPIzer , v any ) SchemaTag {
251
288
if v == nil {
252
289
// ensure we add unknown-interface to our schemas
253
- schema := s . getOrCreateSchema ("unknown-interface" , struct {}{})
290
+ schema := getOrCreateSchema (s , "unknown-interface" , struct {}{})
254
291
return SchemaTag {
255
292
Name : "unknown-interface" ,
256
293
SchemaRef : openapi3.SchemaRef {
@@ -270,7 +307,7 @@ func SchemaTagFromType(s *Server, v any) SchemaTag {
270
307
// If the type is a slice or array type it will dive into the type as well as
271
308
// build and openapi3.Schema where Type is array and Ref is set to the proper
272
309
// components Schema
273
- func dive (s * Server , t reflect.Type , tag SchemaTag , maxDepth int ) SchemaTag {
310
+ func dive (s OpenAPIzer , t reflect.Type , tag SchemaTag , maxDepth int ) SchemaTag {
274
311
if maxDepth == 0 {
275
312
return SchemaTag {
276
313
Name : "default" ,
@@ -297,26 +334,26 @@ func dive(s *Server, t reflect.Type, tag SchemaTag, maxDepth int) SchemaTag {
297
334
return dive (s , t .Field (0 ).Type , tag , maxDepth - 1 )
298
335
}
299
336
tag .Ref = "#/components/schemas/" + tag .Name
300
- tag .Value = s . getOrCreateSchema (tag .Name , reflect .New (t ).Interface ())
337
+ tag .Value = getOrCreateSchema (s , tag .Name , reflect .New (t ).Interface ())
301
338
302
339
return tag
303
340
}
304
341
}
305
342
306
343
// getOrCreateSchema is used to get a schema from the OpenAPI spec.
307
344
// If the schema does not exist, it will create a new schema and add it to the OpenAPI spec.
308
- func (s * Server ) getOrCreateSchema ( key string , v any ) * openapi3.Schema {
309
- schemaRef , ok := s .OpenApiSpec .Components .Schemas [key ]
345
+ func getOrCreateSchema (s OpenAPIzer , key string , v any ) * openapi3.Schema {
346
+ schemaRef , ok := s .OpenAPIDescription () .Components .Schemas [key ]
310
347
if ! ok {
311
- schemaRef = s . createSchema (key , v )
348
+ schemaRef = createSchema (s , key , v )
312
349
}
313
350
return schemaRef .Value
314
351
}
315
352
316
353
// createSchema is used to create a new schema and add it to the OpenAPI spec.
317
354
// Relies on the openapi3gen package to generate the schema, and adds custom struct tags.
318
- func (s * Server ) createSchema ( key string , v any ) * openapi3.SchemaRef {
319
- schemaRef , err := s .openAPIGenerator .NewSchemaRefForValue (v , s .OpenApiSpec .Components .Schemas )
355
+ func createSchema (s OpenAPIzer , key string , v any ) * openapi3.SchemaRef {
356
+ schemaRef , err := s .Generator () .NewSchemaRefForValue (v , s .OpenAPIDescription () .Components .Schemas )
320
357
if err != nil {
321
358
slog .Error ("Error generating schema" , "key" , key , "error" , err )
322
359
}
@@ -327,9 +364,9 @@ func (s *Server) createSchema(key string, v any) *openapi3.SchemaRef {
327
364
schemaRef .Value .Description = descriptionable .Description ()
328
365
}
329
366
330
- s . parseStructTags (reflect .TypeOf (v ), schemaRef )
367
+ parseStructTags (reflect .TypeOf (v ), schemaRef )
331
368
332
- s .OpenApiSpec .Components .Schemas [key ] = schemaRef
369
+ s .OpenAPIDescription () .Components .Schemas [key ] = schemaRef
333
370
334
371
return schemaRef
335
372
}
@@ -346,7 +383,7 @@ func (s *Server) createSchema(key string, v any) *openapi3.SchemaRef {
346
383
// - min=1 => minLength=1 (for strings)
347
384
// - max=100 => max=100 (for integers)
348
385
// - max=100 => maxLength=100 (for strings)
349
- func ( s * Server ) parseStructTags (t reflect.Type , schemaRef * openapi3.SchemaRef ) {
386
+ func parseStructTags (t reflect.Type , schemaRef * openapi3.SchemaRef ) {
350
387
if t .Kind () == reflect .Ptr {
351
388
t = t .Elem ()
352
389
}
@@ -360,7 +397,7 @@ func (s *Server) parseStructTags(t reflect.Type, schemaRef *openapi3.SchemaRef)
360
397
361
398
if field .Anonymous {
362
399
fieldType := field .Type
363
- s . parseStructTags (fieldType , schemaRef )
400
+ parseStructTags (fieldType , schemaRef )
364
401
continue
365
402
}
366
403
0 commit comments