@@ -3,6 +3,7 @@ package v3
33import (
44 "fmt"
55 "regexp"
6+ "strings"
67
78 semver "github.com/Masterminds/semver/v3"
89 "github.com/spf13/cobra"
@@ -11,8 +12,9 @@ import (
1112)
1213
1314type ctxRepl struct {
14- pkg string
15- replFmt string
15+ pkg string
16+ replFmt string
17+ isDefault bool
1618}
1719
1820func parseMiddlewareImports (content string , reImport * regexp.Regexp ) map [string ]string {
@@ -29,29 +31,7 @@ func parseMiddlewareImports(content string, reImport *regexp.Regexp) map[string]
2931}
3032
3133func MigrateMiddlewareLocals (cmd * cobra.Command , cwd string , _ , _ * semver.Version ) error {
32- ctxMap := map [string ][]ctxRepl {
33- "requestid" : {
34- {pkg : "requestid" , replFmt : "requestid.FromContext(%s)" },
35- },
36- "csrf" : {
37- {pkg : "csrf" , replFmt : "csrf.TokenFromContext(%s)" },
38- },
39- "csrf_handler" : {
40- {pkg : "csrf" , replFmt : "csrf.HandlerFromContext(%s)" },
41- },
42- "session" : {
43- {pkg : "session" , replFmt : "session.FromContext(%s)" },
44- },
45- "username" : {
46- {pkg : "basicauth" , replFmt : "basicauth.UsernameFromContext(%s)" },
47- },
48- "password" : {
49- {pkg : "basicauth" , replFmt : "basicauth.PasswordFromContext(%s)" },
50- },
51- "token" : {
52- {pkg : "keyauth" , replFmt : "keyauth.TokenFromContext(%s)" },
53- },
54- }
34+ ctxMap := map [string ][]ctxRepl {}
5535
5636 extractors := []struct {
5737 pkg string
@@ -76,10 +56,16 @@ func MigrateMiddlewareLocals(cmd *cobra.Command, cwd string, _, _ *semver.Versio
7656 if e .pkg != pkg {
7757 continue
7858 }
79- re := regexp .MustCompile (alias + `\.Config{[^}]*` + e . field + `:\s*"([^"]+)" ` )
80- matches := re . FindAllStringSubmatch (content , - 1 )
59+ reCfg := regexp .MustCompile (regexp . QuoteMeta ( alias ) + `\.Config{` )
60+ matches := reCfg . FindAllStringIndex (content , - 1 )
8161 for _ , m := range matches {
82- ctxMap [m [1 ]] = append (ctxMap [m [1 ]], ctxRepl {pkg : e .pkg , replFmt : e .replFmt })
62+ start := m [0 ]
63+ end := extractBlock (content , m [1 ], '{' , '}' )
64+ cfg := content [start :end ]
65+ reField := regexp .MustCompile (e .field + `:\s*"([^"]+)"` )
66+ for _ , fm := range reField .FindAllStringSubmatch (cfg , - 1 ) {
67+ ctxMap [fm [1 ]] = append (ctxMap [fm [1 ]], ctxRepl {pkg : e .pkg , replFmt : e .replFmt })
68+ }
8369 }
8470 }
8571 }
@@ -90,6 +76,30 @@ func MigrateMiddlewareLocals(cmd *cobra.Command, cwd string, _, _ *semver.Versio
9076 return fmt .Errorf ("failed to gather middleware locals: %w" , err )
9177 }
9278
79+ defaults := map [string ][]ctxRepl {
80+ "requestid" : {{pkg : "requestid" , replFmt : "requestid.FromContext(%s)" , isDefault : true }},
81+ "csrf" : {{pkg : "csrf" , replFmt : "csrf.TokenFromContext(%s)" , isDefault : true }},
82+ "csrf_handler" : {{pkg : "csrf" , replFmt : "csrf.HandlerFromContext(%s)" , isDefault : true }},
83+ "session" : {{pkg : "session" , replFmt : "session.FromContext(%s)" , isDefault : true }},
84+ "username" : {{pkg : "basicauth" , replFmt : "basicauth.UsernameFromContext(%s)" , isDefault : true }},
85+ "password" : {{pkg : "basicauth" , replFmt : "basicauth.PasswordFromContext(%s)" , isDefault : true }},
86+ "token" : {{pkg : "keyauth" , replFmt : "keyauth.TokenFromContext(%s)" , isDefault : true }},
87+ }
88+ for key , repls := range defaults {
89+ for _ , r := range repls {
90+ exists := false
91+ for _ , existing := range ctxMap [key ] {
92+ if existing .pkg == r .pkg {
93+ exists = true
94+ break
95+ }
96+ }
97+ if ! exists {
98+ ctxMap [key ] = append (ctxMap [key ], r )
99+ }
100+ }
101+ }
102+
93103 // second pass: perform replacements and clean up
94104 changed , err := internal .ChangeFileContent (cwd , func (content string ) string {
95105 imports := parseMiddlewareImports (content , reImport )
@@ -99,17 +109,45 @@ func MigrateMiddlewareLocals(cmd *cobra.Command, cwd string, _, _ *semver.Versio
99109 sub := reLocals .FindStringSubmatch (s )
100110 ctx := sub [1 ]
101111 key := sub [2 ]
102- if repls , ok := ctxMap [key ]; ok {
103- if len (repls ) == 1 {
104- return fmt .Sprintf (repls [0 ].replFmt , ctx )
112+ repls , ok := ctxMap [key ]
113+ if ! ok {
114+ return s
115+ }
116+
117+ var custom , defs []ctxRepl
118+ for _ , r := range repls {
119+ if r .isDefault {
120+ defs = append (defs , r )
121+ } else {
122+ custom = append (custom , r )
105123 }
106- for _ , r := range repls {
124+ }
125+
126+ choose := func (r ctxRepl ) string { return fmt .Sprintf (r .replFmt , ctx ) }
127+
128+ if len (custom ) == 1 {
129+ return choose (custom [0 ])
130+ }
131+ if len (custom ) > 1 {
132+ for _ , r := range custom {
107133 for _ , pkg := range imports {
108134 if pkg == r .pkg {
109- return fmt . Sprintf ( r . replFmt , ctx )
135+ return choose ( r )
110136 }
111137 }
112138 }
139+ return s
140+ }
141+
142+ if len (defs ) == 1 {
143+ return choose (defs [0 ])
144+ }
145+ for _ , r := range defs {
146+ for _ , pkg := range imports {
147+ if pkg == r .pkg {
148+ return choose (r )
149+ }
150+ }
113151 }
114152 return s
115153 })
@@ -121,13 +159,32 @@ func MigrateMiddlewareLocals(cmd *cobra.Command, cwd string, _, _ *semver.Versio
121159 content = reComma .ReplaceAllString (content , "$1, $2 := $3, true" )
122160
123161 for alias := range imports {
124- reCfg := regexp .MustCompile (alias + `\.Config{[^}]*}` )
125- content = reCfg .ReplaceAllStringFunc (content , func (cfg string ) string {
162+ reCfg := regexp .MustCompile (regexp .QuoteMeta (alias ) + `\.Config{` )
163+ matches := reCfg .FindAllStringIndex (content , - 1 )
164+ if len (matches ) == 0 {
165+ continue
166+ }
167+ var b strings.Builder
168+ last := 0
169+ for _ , m := range matches {
170+ if _ , err := b .WriteString (content [last :m [0 ]]); err != nil {
171+ return content
172+ }
173+ start := m [0 ]
174+ end := extractBlock (content , m [1 ], '{' , '}' )
175+ cfg := content [start :end ]
126176 cfg = removeConfigField (cfg , "ContextKey" )
127177 cfg = removeConfigField (cfg , "ContextUsername" )
128178 cfg = removeConfigField (cfg , "ContextPassword" )
129- return cfg
130- })
179+ if _ , err := b .WriteString (cfg ); err != nil {
180+ return content
181+ }
182+ last = end
183+ }
184+ if _ , err := b .WriteString (content [last :]); err != nil {
185+ return content
186+ }
187+ content = b .String ()
131188 }
132189
133190 return content
0 commit comments