Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sec: fix s3 and gcs host checks #512

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions detect_gcs.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@ package getter

import (
"fmt"
"net/url"
"strings"

urlhelper "github.com/hashicorp/go-getter/helper/url"
)

// GCSDetector implements Detector to detect GCS URLs and turn
Expand All @@ -18,8 +19,17 @@ func (d *GCSDetector) Detect(src, _ string) (string, bool, error) {
return "", false, nil
}

if strings.Contains(src, "googleapis.com/") {
return d.detectHTTP(src)
if !strings.HasPrefix(src, "http://") && !strings.HasPrefix(src, "https://") {
src = "https://" + src
}

parsedURL, err := urlhelper.Parse(src)
if err != nil {
return "", false, fmt.Errorf("error parsing GCS URL")
}

if strings.HasSuffix(parsedURL.Host, ".googleapis.com") {
return d.detectHTTP(strings.ReplaceAll(src, "https://", ""))
}

return "", false, nil
Expand All @@ -36,7 +46,7 @@ func (d *GCSDetector) detectHTTP(src string) (string, bool, error) {
bucket := parts[3]
object := strings.Join(parts[4:], "/")

url, err := url.Parse(fmt.Sprintf("https://www.googleapis.com/storage/%s/%s/%s",
url, err := urlhelper.Parse(fmt.Sprintf("https://www.googleapis.com/storage/%s/%s/%s",
version, bucket, object))
if err != nil {
return "", false, fmt.Errorf("error parsing GCS URL: %s", err)
Expand Down
53 changes: 53 additions & 0 deletions detect_gcs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ func TestGCSDetector(t *testing.T) {
"www.googleapis.com/storage/v1/foo/bar.baz",
"gcs::https://www.googleapis.com/storage/v1/foo/bar.baz",
},
{
"www.googleapis.com/storage/v2/foo/bar/toor.baz",
"gcs::https://www.googleapis.com/storage/v2/foo/bar/toor.baz",
},
}

pwd := "/pwd"
Expand All @@ -42,3 +46,52 @@ func TestGCSDetector(t *testing.T) {
}
}
}

// TestGCSDetector_MalformedDetectHTTP runs GCSDetector.Detect over a table of
// valid, empty, and malformed inputs and verifies the detected source string.
//
// Note: Expected is compared only when Detect actually returns an error. A
// case that lists an expected error but receives a nil error is not flagged
// by this loop; only its Output is checked.
func TestGCSDetector_MalformedDetectHTTP(t *testing.T) {
	type detectCase struct {
		Name     string
		Input    string
		Expected string // error text, consulted only when Detect returns an error
		Output   string // source string Detect should return
	}

	table := []detectCase{
		{
			Name:     "valid url",
			Input:    "www.googleapis.com/storage/v1/my-bucket/foo/bar",
			Expected: "",
			Output:   "gcs::https://www.googleapis.com/storage/v1/my-bucket/foo/bar",
		},
		{
			Name:     "empty url",
			Input:    "",
			Expected: "",
			Output:   "",
		},
		{
			Name:     "not valid url length",
			Input:    "www.googleapis.com.invalid/storage/v1/",
			Expected: "URL is not a valid GCS URL",
			Output:   "",
		},
		{
			Name:     "not valid url length",
			Input:    "www.invalid.com/storage/v1",
			Expected: "URL is not a valid GCS URL",
			Output:   "",
		},
	}

	const workDir = "/pwd"
	detector := new(GCSDetector)
	for i := range table {
		c := table[i]
		got, _, err := detector.Detect(c.Input, workDir)
		if err != nil && err.Error() != c.Expected {
			t.Fatalf("expected error %s, got %s for %s", c.Expected, err.Error(), c.Name)
		}

		if got != c.Output {
			t.Fatalf("expected %s, got %s for %s", c.Output, got, c.Name)
		}
	}
}
22 changes: 16 additions & 6 deletions detect_s3.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@ package getter

import (
"fmt"
"net/url"
"strings"

urlhelper "github.com/hashicorp/go-getter/helper/url"
)

// S3Detector implements Detector to detect S3 URLs and turn
Expand All @@ -18,8 +19,17 @@ func (d *S3Detector) Detect(src, _ string) (string, bool, error) {
return "", false, nil
}

if strings.Contains(src, ".amazonaws.com/") {
return d.detectHTTP(src)
if !strings.HasPrefix(src, "http://") && !strings.HasPrefix(src, "https://") {
src = "https://" + src
}

parsedURL, err := urlhelper.Parse(src)
if err != nil {
return "", false, fmt.Errorf("error parsing S3 URL")
}

if strings.HasSuffix(parsedURL.Host, ".amazonaws.com") {
return d.detectHTTP(strings.ReplaceAll(src, "https://", ""))
}

return "", false, nil
Expand Down Expand Up @@ -47,7 +57,7 @@ func (d *S3Detector) detectHTTP(src string) (string, bool, error) {

func (d *S3Detector) detectPathStyle(region string, parts []string) (string, bool, error) {
urlStr := fmt.Sprintf("https://%s.amazonaws.com/%s", region, strings.Join(parts, "/"))
url, err := url.Parse(urlStr)
url, err := urlhelper.Parse(urlStr)
if err != nil {
return "", false, fmt.Errorf("error parsing S3 URL: %s", err)
}
Expand All @@ -57,7 +67,7 @@ func (d *S3Detector) detectPathStyle(region string, parts []string) (string, boo

func (d *S3Detector) detectVhostStyle(region, bucket string, parts []string) (string, bool, error) {
urlStr := fmt.Sprintf("https://%s.amazonaws.com/%s/%s", region, bucket, strings.Join(parts, "/"))
url, err := url.Parse(urlStr)
url, err := urlhelper.Parse(urlStr)
if err != nil {
return "", false, fmt.Errorf("error parsing S3 URL: %s", err)
}
Expand All @@ -67,7 +77,7 @@ func (d *S3Detector) detectVhostStyle(region, bucket string, parts []string) (st

func (d *S3Detector) detectNewVhostStyle(region, bucket string, parts []string) (string, bool, error) {
urlStr := fmt.Sprintf("https://s3.%s.amazonaws.com/%s/%s", region, bucket, strings.Join(parts, "/"))
url, err := url.Parse(urlStr)
url, err := urlhelper.Parse(urlStr)
if err != nil {
return "", false, fmt.Errorf("error parsing S3 URL: %s", err)
}
Expand Down
4 changes: 3 additions & 1 deletion get_gcs.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ func (g *GCSGetter) getObject(ctx context.Context, client *storage.Client, dst,
}

func (g *GCSGetter) parseURL(u *url.URL) (bucket, path, fragment string, err error) {
if strings.Contains(u.Host, "googleapis.com") {
if strings.HasSuffix(u.Host, ".googleapis.com") {
hostParts := strings.Split(u.Host, ".")
if len(hostParts) != 3 {
err = fmt.Errorf("URL is not a valid GCS URL")
Expand All @@ -208,6 +208,8 @@ func (g *GCSGetter) parseURL(u *url.URL) (bucket, path, fragment string, err err
bucket = pathParts[3]
path = pathParts[4]
fragment = u.Fragment
} else {
err = fmt.Errorf("URL is not a valid GCS URL")
}
return
}
Expand Down
56 changes: 56 additions & 0 deletions get_gcs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -233,3 +233,59 @@ func TestGCSGetter_GetFile_OAuthAccessToken(t *testing.T) {
}
assertContents(t, dst, "# Main\n")
}

// Test_GCSGetter_ParseUrl checks that GCSGetter.parseURL returns no error for
// a GCS object URL served from www.googleapis.com.
func Test_GCSGetter_ParseUrl(t *testing.T) {
	type parseCase struct {
		name string
		url  string
	}

	cases := []parseCase{
		{
			name: "valid host",
			url:  "https://www.googleapis.com/storage/v1/hc-go-getter-test/go-getter/foobar",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			parsed, parseErr := url.Parse(tc.url)
			if parseErr != nil {
				t.Fatalf("unexpected error: %s", parseErr)
			}

			getter := new(GCSGetter)
			if _, _, _, err := getter.parseURL(parsed); err != nil {
				t.Fatalf("wasn't expecting error, got %s", err)
			}
		})
	}
}
func Test_GCSGetter_ParseUrl_Malformed(t *testing.T) {
dduzgun-security marked this conversation as resolved.
Show resolved Hide resolved
tests := []struct {
name string
url string
}{
{
name: "invalid host suffix",
url: "https://www.googleapis.com.invalid",
},
{
name: "host suffix with a typo",
url: "https://www.googleapi.com.",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
g := new(GCSGetter)
u, err := url.Parse(tt.url)
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
_, _, _, err = g.parseURL(u)
if err == nil {
t.Fatalf("expected error, got none")
}
if err.Error() != "URL is not a valid GCS URL" {
t.Fatalf("expected error 'URL is not a valid GCS URL', got %s", err.Error())
}
})
}
}
6 changes: 3 additions & 3 deletions get_s3.go
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ func (g *S3Getter) parseUrl(u *url.URL) (region, bucket, path, version string, c
// This just check whether we are dealing with S3 or
// any other S3 compliant service. S3 has a predictable
// url as others do not
if strings.Contains(u.Host, "amazonaws.com") {
if strings.HasSuffix(u.Host, ".amazonaws.com") {
// Amazon S3 supports both virtual-hosted–style and path-style URLs to access a bucket, although path-style is deprecated
// In both cases few older regions supports dash-style region indication (s3-Region) even if AWS discourages their use.
// The same bucket could be reached with:
Expand Down Expand Up @@ -304,7 +304,7 @@ func (g *S3Getter) parseUrl(u *url.URL) (region, bucket, path, version string, c
path = pathParts[1]

}
if len(hostParts) < 3 && len(hostParts) > 5 {
if len(hostParts) < 3 || len(hostParts) > 5 {
err = fmt.Errorf("URL is not a valid S3 URL")
return
}
Expand All @@ -313,7 +313,7 @@ func (g *S3Getter) parseUrl(u *url.URL) (region, bucket, path, version string, c
} else {
pathParts := strings.SplitN(u.Path, "/", 3)
if len(pathParts) != 3 {
err = fmt.Errorf("URL is not a valid S3 compliant URL")
err = fmt.Errorf("URL is not a valid S3 URL")
return
}
bucket = pathParts[1]
Expand Down
8 changes: 8 additions & 0 deletions get_s3_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,14 @@ func Test_S3Getter_ParseUrl_Malformed(t *testing.T) {
name: "vhost-style, dot region indication",
url: "https://bucket.s3.us-east-1.amazonaws.com",
},
{
name: "invalid host parts",
url: "https://invalid.host.parts.lenght.s3.us-east-1.amazonaws.com",
},
{
name: "invalid host suffix",
url: "https://bucket.s3.amazonaws.com.invalid",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand Down
Loading