Skip to content

Commit

Permalink
Update for newer go idioms
Browse files Browse the repository at this point in the history
  • Loading branch information
simon-engledew committed Nov 26, 2023
1 parent 9685220 commit 32d8bf6
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 42 deletions.
31 changes: 25 additions & 6 deletions handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,23 +16,43 @@ import (
// without being parsed.
type EditorFunc func(raw []byte, token *html.Token) (data []byte, done bool)

var headTag = []byte("head")
func Attrs(replacements map[string]string) EditorFunc {
return func(raw []byte, token *html.Token) ([]byte, bool) {
if token.Type == html.StartTagToken {
var found bool
for n := range token.Attr {
if val, ok := replacements[token.Attr[n].Key]; ok {
token.Attr[n].Val = val
found = true
}
}
if found {
return []byte(token.String()), false
}
}
return raw, false
}
}

// AfterHead returns an EditorFunc that will inject data after the first <head> tag.
func AfterHead(data string) EditorFunc {
func AfterTag(tag, data string, once bool) EditorFunc {
return func(raw []byte, token *html.Token) ([]byte, bool) {
if token.Type == html.StartTagToken {
if token.Data == "head" {
if token.Data == tag {
combined := make([]byte, 0, len(raw)+len(data))
combined = append(combined, raw...)
combined = append(combined, data...)
return combined, true
return combined, once
}
}
return raw, false
}
}

// AfterHead returns an EditorFunc that will inject data after the first <head> tag.
func AfterHead(data string) EditorFunc {
return AfterTag("head", data, true)
}

// Handle will rewriteFn any text/html documents that are served by next.
//
// On each request, Handle will call processRequest to provide an EditorFunc.
Expand All @@ -45,7 +65,6 @@ func Handle(next http.Handler, processRequest func(r *http.Request) (EditorFunc,
r.Header.Set("Accept-Encoding", "identity")

fn, err := processRequest(r)

if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
Expand Down
57 changes: 24 additions & 33 deletions response_editor.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,63 +8,54 @@ import (
)

type ResponseEditor struct {
rewriteFn EditorFunc
writeOnce sync.Once
writeHeaderOnce sync.Once
target http.ResponseWriter
body io.WriteCloser
statusCode int
http.ResponseWriter
rewriteFn EditorFunc
writeOnce sync.Once
closeOnce sync.Once
body io.WriteCloser
}

func (r *ResponseEditor) Unwrap() http.ResponseWriter {
return r.ResponseWriter
}

var _ io.WriteCloser = &ResponseEditor{}

// NewResponseEditor will return a ResponseEditor that inspects the http response
// and rewrites the HTML document before passing it to w.
func NewResponseEditor(w http.ResponseWriter, rewriteFn EditorFunc) *ResponseEditor {
return &ResponseEditor{
target: w,
rewriteFn: rewriteFn,
statusCode: http.StatusOK,
ResponseWriter: w,
rewriteFn: rewriteFn,
}
}

func (r *ResponseEditor) Header() http.Header {
return r.target.Header()
}

func (r *ResponseEditor) Write(p []byte) (int, error) {
r.writeOnce.Do(func() {
header := r.target.Header()
header := r.ResponseWriter.Header()

// TODO: handle content encoding

if strings.HasPrefix(header.Get("Content-Type"), "text/html") {
header.Set("Transfer-Encoding", "chunked")
header.Del("Content-Length")

r.body = NewTokenEditor(r.target, r.rewriteFn)
r.body = NewTokenEditor(r.ResponseWriter, r.rewriteFn)
}

r.writeHeaderOnce.Do(func() {
r.target.WriteHeader(r.statusCode)
})
})

if r.body != nil {
return r.body.Write(p)
}
return r.target.Write(p)
}

func (r *ResponseEditor) WriteHeader(statusCode int) {
r.statusCode = statusCode
return r.ResponseWriter.Write(p)
}

func (r *ResponseEditor) Close() error {
if r.body != nil {
return r.body.Close()
} else {
r.writeHeaderOnce.Do(func() {
r.target.WriteHeader(r.statusCode)
})
}
return nil
func (r *ResponseEditor) Close() (err error) {
r.closeOnce.Do(func() {
if r.body != nil {
err = r.body.Close()
}
})

return
}
3 changes: 2 additions & 1 deletion scanner.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package rewritehtml

import (
"errors"
"golang.org/x/net/html"
"io"
)
Expand Down Expand Up @@ -74,7 +75,7 @@ func (s *Scanner) Next(atEOF bool) (raw []byte, token *html.Token, err error) {
if tt == html.ErrorToken {
nextErr := s.tokenizer.Err()

if nextErr == io.ErrNoProgress {
if errors.Is(nextErr, io.ErrNoProgress) {
s.Concat(nil)
if atEOF {
// recreate tokenizer
Expand Down
7 changes: 5 additions & 2 deletions token_editor.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package rewritehtml

import "io"
import (
"errors"
"io"
)

type TokenEditor struct {
target io.Writer
Expand All @@ -22,7 +25,7 @@ func NewTokenEditor(w io.Writer, rewriteFn EditorFunc) *TokenEditor {
func (i *TokenEditor) doWrite(atEOF bool) error {
for !i.done {
raw, token, err := i.scanner.Next(atEOF)
if !atEOF && err == io.ErrNoProgress {
if !atEOF && errors.Is(err, io.ErrNoProgress) {
break
}
if err != nil {
Expand Down

0 comments on commit 32d8bf6

Please sign in to comment.