Skip to content

Commit

Permalink
Merge pull request #1788 from statechannels/fix-cors
Browse files Browse the repository at this point in the history
Enable CORS header on error
  • Loading branch information
lalexgap authored Sep 21, 2023
2 parents cbf2030 + be7537a commit 2e8cde4
Showing 1 changed file with 15 additions and 8 deletions.
23 changes: 15 additions & 8 deletions paymentproxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,14 +112,7 @@ func (p *PaymentProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// It will check the voucher amount against the cost (response size * cost per byte)
// If the voucher amount is less than the cost, it will return a 402 Payment Required error instead of serving the content
func (p *PaymentProxy) handleDestinationResponse(r *http.Response) error {
// Add CORS headers to allow all origins (*).
if r.Header.Get("Access-Control-Allow-Origin") == "" {
// We want to set this header exactly once (the destination server may have already set it)
r.Header.Set("Access-Control-Allow-Origin", "*")
}
r.Header.Set("Access-Control-Allow-Headers", "*")
r.Header.Set("Access-Control-Expose-Headers", "*")

enableCors(r.Header)
// Ignore OPTIONS requests as they are preflight requests
if r.Request.Method == "OPTIONS" {
return nil
Expand Down Expand Up @@ -164,6 +157,7 @@ func (p *PaymentProxy) handleDestinationResponse(r *http.Response) error {

// handleError is responsible for logging the error and returning the appropriate HTTP status code
func (p *PaymentProxy) handleError(w http.ResponseWriter, r *http.Request, err error) {
enableCors(w.Header())
if errors.Is(err, ErrPayment) {
http.Error(w, err.Error(), http.StatusPaymentRequired)
} else {
Expand Down Expand Up @@ -258,3 +252,16 @@ func readBodyLength(b io.ReadCloser) (uint64, error) {

return byteCount, nil
}

// enableCors sets the CORS headers if they are not already set
func enableCors(header http.Header) {
if header.Get("Access-Control-Allow-Origin") == "" {
header.Set("Access-Control-Allow-Origin", "*")
}
if header.Get("Access-Control-Allow-Headers") == "" {
header.Set("Access-Control-Allow-Headers", "*")
}
if header.Get("Access-Control-Expose-Headers") == "" {
header.Set("Access-Control-Expose-Headers", "*")
}
}

0 comments on commit 2e8cde4

Please sign in to comment.