diff --git a/.buildkite/pipeline.yaml b/.buildkite/pipeline.yaml new file mode 100644 index 0000000..bedb968 --- /dev/null +++ b/.buildkite/pipeline.yaml @@ -0,0 +1,73 @@ +env: + APP_NAME: ${BUILDKITE_PIPELINE_SLUG} + IMAGE_REPO: ghcr.io/theopenlane/${APP_NAME} + SONAR_HOST: "https://sonarcloud.io" +steps: + - group: ":test_tube: tests" + key: "tests" + steps: + - label: ":golangci-lint: lint :lint-roller:" + cancel_on_build_failing: true + key: "lint" + plugins: + - docker#v5.11.0: + image: "registry.hub.docker.com/golangci/golangci-lint:v1.60.3" + command: ["golangci-lint", "run", "-v", "--timeout", "10m", "--config", ".golangci.yaml", "--concurrency", "0"] + environment: + - "GOTOOLCHAIN=auto" + - label: ":golang: go test" + key: "go_test" + retry: + automatic: + - exit_status: "*" + limit: 2 + cancel_on_build_failing: true + plugins: + - docker#v5.11.0: + image: golang:1.23.0 + command: ["go", "test", "-coverprofile=coverage.out", "./..."] + artifact_paths: ["coverage.out"] + - group: ":closed_lock_with_key: Security Checks" + depends_on: "go_test" + key: "security" + steps: + - label: ":closed_lock_with_key: gosec" + key: "gosec" + plugins: + - docker#v5.11.0: + image: "registry.hub.docker.com/securego/gosec:2.20.0" + command: ["-no-fail", "-exclude-generated", "-fmt sonarqube", "-out", "results.txt", "./..."] + environment: + - "GOTOOLCHAIN=auto" + artifact_paths: ["results.txt"] + - label: ":github: upload PR reports" + key: "scan-upload-pr" + if: build.pull_request.id != null + depends_on: ["gosec", "go_test"] + plugins: + - artifacts#v1.9.4: + download: "results.txt" + - artifacts#v1.9.4: + download: "coverage.out" + step: "go_test" + - docker#v5.11.0: + image: "sonarsource/sonar-scanner-cli:5" + environment: + - "SONAR_TOKEN" + - "SONAR_HOST_URL=$SONAR_HOST" + - "SONAR_SCANNER_OPTS=-Dsonar.pullrequest.branch=$BUILDKITE_BRANCH -Dsonar.pullrequest.base=$BUILDKITE_PULL_REQUEST_BASE_BRANCH -Dsonar.pullrequest.key=$BUILDKITE_PULL_REQUEST" + - label: ":github: upload reports" + key: "scan-upload" + if: build.branch == "main" + depends_on: ["gosec", "go_test"] + plugins: + - artifacts#v1.9.4: + download: results.txt + - artifacts#v1.9.4: + download: coverage.out + step: "go_test" + - docker#v5.11.0: + image: "sonarsource/sonar-scanner-cli:5" + environment: + - "SONAR_TOKEN" + - "SONAR_HOST_URL=$SONAR_HOST" diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS new file mode 100644 index 0000000..f906a12 --- /dev/null +++ b/.github/CODEOWNERS @@ -0,0 +1 @@ +* @theopenlane/blacksmiths \ No newline at end of file diff --git a/.github/CONTRIBUTING.md b/.github/CONTRIBUTING.md new file mode 100644 index 0000000..48b8eca --- /dev/null +++ b/.github/CONTRIBUTING.md @@ -0,0 +1,24 @@ +# Contributing + +Given external users will not have write to the branches in this repository, you'll need to follow the forking process to open a PR - [here](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request-from-a-fork) is a guide from github on how to do so. + +Please also read our main [contributing guide](https://github.com/theopenlane/.github/blob/main/CONTRIBUTING.md) in addition to this one; the main guide mostly says that we'd like for you to open an issue first but it's not hard-required, and that we accept all forms of proposed changes given the state of this code base (in it's infancy, still!) + +## Pre-requisites to a PR + +This repository contains a number of code generating functions / utilities which take schema modifications and scaffold out resolvers, graphql API schemas, openAPI specifications, among other things. To ensure you've generated all the necessary dependencies run `task pr`; this will run the entirety of the commands required to safely generate a PR. If for some reason one of the commands fails / encounters an error, you will need to debug the individual steps. It should be decently easy to follow the `Taskfile` in the root of this repository. + +### Pre-Commit Hooks + +We have several `pre-commit` hooks that should be run before pushing a commit. Make sure this is installed: + +```bash +brew install pre-commit +pre-commit install +``` + +You can optionally run against all files: + +```bash +pre-commit run --all-files +``` diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md new file mode 100644 index 0000000..263ebe5 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -0,0 +1,16 @@ +--- +name: Bug report +about: Create a report to help us improve +title: "[Bug]" +labels: bug +assignees: '' + +--- + +**Describe the bug or issue you're encountering** + + +**What are the relevant steps to reproduce, including the version(s) of the relevant software?** + + +**What is the expected behavior?** diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md new file mode 100644 index 0000000..897f8f2 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature_request.md @@ -0,0 +1,14 @@ +--- +name: Feature request +about: Suggest an idea for this project +title: "[Feature Request]" +labels: enhancement +assignees: matoszz + +--- + +**Describe how the feature might make your life easier or solve a problem** + +**Describe the solution you'd like to see with any relevant context** + +**Describe any alternatives you've considered or if there are short-tern vs. long-term options** diff --git a/.github/labeler.yml b/.github/labeler.yml new file mode 100644 index 0000000..6e813dc --- /dev/null +++ b/.github/labeler.yml @@ -0,0 +1,9 @@ +# Add 'bug' label to any PR where the head branch name starts with `bug` or has a `bug` section in the name +bug: + - head-branch: ["^bug", "bug"] +# Add 'enhancement' label to any PR where the head branch name starts with `enhancement` or has a `enhancement` section in the name +enhancement: + - head-branch: ["^enhancement", "enhancement", "^feature", "feature", "^enhance", "enhance", "^feat", "feat"] +# Add 'breaking-change' label to any PR where the head branch name starts with `breaking-change` or has a `breaking-change` section in the name +breaking-change: + - head-branch: ["^breaking-change", "breaking-change"] diff --git a/.github/release.yml b/.github/release.yml new file mode 100644 index 0000000..37df9bc --- /dev/null +++ b/.github/release.yml @@ -0,0 +1,24 @@ +changelog: + exclude: + labels: + - ignore-for-release + authors: [] + categories: + - title: Breaking Changes 🛠 + labels: + - Semver-Major + - breaking-change + - title: New Features 🎉 + labels: + - Semver-Minor + - enhancement + - feature + - title: Bug Fixes 🐛 + labels: + - bug + - title: 👒 Dependencies + labels: + - dependencies + - title: Other Changes + labels: + - "*" diff --git a/.github/workflows/labeler.yaml b/.github/workflows/labeler.yaml new file mode 100644 index 0000000..fc43cb1 --- /dev/null +++ b/.github/workflows/labeler.yaml @@ -0,0 +1,13 @@ +name: "Pull Request Labeler" +on: + - pull_request_target +jobs: + triage: + permissions: + contents: read + pull-requests: write + runs-on: ubuntu-latest + steps: + - uses: actions/labeler@v5 + with: + sync-labels: true diff --git a/.github/workflows/releaser.yml b/.github/workflows/releaser.yml new file mode 100644 index 0000000..9381e0d --- /dev/null +++ b/.github/workflows/releaser.yml @@ -0,0 +1,127 @@ +name: Release +on: + workflow_dispatch: + release: + types: [created] +permissions: + contents: write +jobs: + ldflags_args: + runs-on: ubuntu-latest + outputs: + commit-date: ${{ steps.ldflags.outputs.commit-date }} + commit: ${{ steps.ldflags.outputs.commit }} + version: ${{ steps.ldflags.outputs.version }} + tree-state: ${{ steps.ldflags.outputs.tree-state }} + steps: + - id: checkout + uses: actions/checkout@v4 + with: + fetch-depth: 0 + - id: ldflags + run: | + echo "commit=$GITHUB_SHA" >> $GITHUB_OUTPUT + echo "commit-date=$(git log --date=iso8601-strict -1 --pretty=%ct)" >> $GITHUB_OUTPUT + echo "version=$(git describe --tags --always --dirty | cut -c2-)" >> $GITHUB_OUTPUT + echo "tree-state=$(if git diff --quiet; then echo "clean"; else echo "dirty"; fi)" >> $GITHUB_OUTPUT + release: + name: Build and release + needs: + - ldflags_args + outputs: + hashes: ${{ steps.hash.outputs.hashes }} + permissions: + contents: write # To add assets to a release. + id-token: write # To do keyless signing with cosign + runs-on: macos-latest + steps: + - name: Checkout + uses: actions/checkout@v4 + with: + fetch-depth: 0 + - name: Setup Go + uses: actions/setup-go@v5 + with: + go-version-file: 'go.mod' + cache: true + - name: Install Syft + uses: anchore/sbom-action/download-syft@ab9d16d4b419c9d1a02df5213fa0ebe965ca5a57 # v0.17.1 + - name: Install Cosign + uses: sigstore/cosign-installer@v3.6.0 + - name: Run GoReleaser + id: run-goreleaser + uses: goreleaser/goreleaser-action@v6 + with: + distribution: goreleaser + version: latest + args: release --clean + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }} + VERSION: ${{ needs.ldflags_args.outputs.version }} + COMMIT: ${{ needs.ldflags_args.outputs.commit }} + COMMIT_DATE: ${{ needs.ldflags_args.outputs.commit-date }} + TREE_STATE: ${{ needs.ldflags_args.outputs.tree-state }} + - name: Generate subject + id: hash + env: + ARTIFACTS: "${{ steps.run-goreleaser.outputs.artifacts }}" + run: | + set -euo pipefail + hashes=$(echo $ARTIFACTS | jq --raw-output '.[] | {name, "digest": (.extra.Digest // .extra.Checksum)} | select(.digest) | {digest} + {name} | join(" ") | sub("^sha256:";"")' | base64) + if test "$hashes" = ""; then # goreleaser < v1.13.0 + checksum_file=$(echo "$ARTIFACTS" | jq -r '.[] | select (.type=="Checksum") | .path') + hashes=$(cat $checksum_file | base64) + fi + echo "hashes=$hashes" >> $GITHUB_OUTPUT + provenance: + name: Generate provenance (SLSA3) + needs: + - release + permissions: + actions: read # To read the workflow path. + id-token: write # To sign the provenance. + contents: write # To add assets to a release. + uses: slsa-framework/slsa-github-generator/.github/workflows/generator_generic_slsa3.yml@v2.0.0 + with: + base64-subjects: "${{ needs.release.outputs.hashes }}" + upload-assets: true # upload to a new release + verification: + name: Verify provenance of assets (SLSA3) + needs: + - release + - provenance + runs-on: ubuntu-latest + permissions: read-all + steps: + - name: Install the SLSA verifier + uses: slsa-framework/slsa-verifier/actions/installer@v2.6.0 + - name: Download assets + env: + GH_TOKEN: "${{ secrets.GITHUB_TOKEN }}" + CHECKSUMS: "${{ needs.release.outputs.hashes }}" + ATT_FILE_NAME: "${{ needs.provenance.outputs.provenance-name }}" + run: | + set -euo pipefail + checksums=$(echo "$CHECKSUMS" | base64 -d) + while read -r line; do + fn=$(echo $line | cut -d ' ' -f2) + echo "Downloading $fn" + gh -R "$GITHUB_REPOSITORY" release download "$GITHUB_REF_NAME" -p "$fn" + done <<<"$checksums" + gh -R "$GITHUB_REPOSITORY" release download "$GITHUB_REF_NAME" -p "$ATT_FILE_NAME" + - name: Verify assets + env: + CHECKSUMS: "${{ needs.release.outputs.hashes }}" + PROVENANCE: "${{ needs.provenance.outputs.provenance-name }}" + run: |- + set -euo pipefail + checksums=$(echo "$CHECKSUMS" | base64 -d) + while read -r line; do + fn=$(echo $line | cut -d ' ' -f2) + echo "Verifying SLSA provenance for $fn" + slsa-verifier verify-artifact --provenance-path "$PROVENANCE" \ + --source-uri "github.com/$GITHUB_REPOSITORY" \ + --source-tag "$GITHUB_REF_NAME" \ + "$fn" + done <<<"$checksums" diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..3d7b1eb --- /dev/null +++ b/.gitignore @@ -0,0 +1,63 @@ +# If you prefer the allow list template instead of the deny list, see community template: +# https://github.com/github/gitignore/blob/main/community/Golang/Go.AllowList.gitignore +# +# Binaries for programs and plugins +*.exe +*.exe~ +*.dll +*.so +*.dylib + +# Test binary, built with `go test -c` +*.test + +# Output of the go coverage tool, specifically when used with LiteIDE +*.out + +# Go workspace file +go.work + +# local dev created files +*.db +core +core-cli +server.crt +server.key +private_key.pem +public_key.pem +# Sendgrid test email folders +internal/httpserve/handlers/fixtures/emails/* +fixtures/emails/* +pkg/httpsling/testdata/* + +# Packages +*.7z +*.dmg +*.gz +*.iso +*.jar +*.rar +*.tar + +# Logs +*.log + +# Editor files +.vscode + +# OS Generated Files +.DS_Store* +.AppleDouble +.LSOverride +ehthumbs.db +Icon? +Thumbs.db + +.scannerwork/** +results.txt + +*.mime +*.mim + +# Configs +.task \ No newline at end of file diff --git a/.golangci.yaml b/.golangci.yaml new file mode 100644 index 0000000..5401353 --- /dev/null +++ b/.golangci.yaml @@ -0,0 +1,41 @@ +run: + timeout: 10m + allow-serial-runners: true +linters-settings: + goimports: + local-prefixes: github.com/theopenlane/httpsling + gofumpt: + extra-rules: true + gosec: + exclude-generated: true + revive: + ignore-generated-header: true +linters: + enable: + - bodyclose + - errcheck + - gocritic + - gocyclo + - err113 + - gofmt + - goimports + - mnd + - gosimple + - govet + - gosec + - ineffassign + - misspell + - noctx + - revive + - staticcheck + - stylecheck + - typecheck + - unused + - whitespace + - wsl +issues: + fix: true + exclude-use-default: true + exclude-dirs: + - totp/testing/* + exclude-files: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..e3e69b7 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,19 @@ +default_stages: [pre-commit] +fail_fast: true +default_language_version: + golang: system + +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.6.0 + hooks: + - id: trailing-whitespace + - id: detect-private-key + - repo: https://github.com/google/yamlfmt + rev: v0.13.0 + hooks: + - id: yamlfmt + - repo: https://github.com/crate-ci/typos + rev: v1.24.1 + hooks: + - id: typos diff --git a/.trivyignore b/.trivyignore new file mode 100644 index 0000000..e69de29 diff --git a/.typos.toml b/.typos.toml new file mode 100644 index 0000000..ef3bcd6 --- /dev/null +++ b/.typos.toml @@ -0,0 +1,20 @@ +[files] +extend-exclude = ["go.mod","go.sum"] +ignore-hidden = true +ignore-files = true +ignore-dot = true +ignore-vcs = true +ignore-global = true +ignore-parent = true + +[default] +binary = false +check-filename = true +check-file = true +unicode = true +ignore-hex = true +identifier-leading-digits = false +locale = "en" +extend-ignore-identifiers-re = [] +extend-ignore-words-re = ["(?i)requestor","(?i)encrypter","(?i)seeked"] +extend-ignore-re = ["#\\s*spellchecker:off\\s*\\n.*\\n\\s*#\\s*spellchecker:on"] \ No newline at end of file diff --git a/.yamlfmt b/.yamlfmt new file mode 100644 index 0000000..f6cfc8b --- /dev/null +++ b/.yamlfmt @@ -0,0 +1,4 @@ +exclude: + - config/ +formatter: + retain_line_breaks: true \ No newline at end of file diff --git a/LICENSE b/LICENSE index 261eeb9..2e08388 100644 --- a/LICENSE +++ b/LICENSE @@ -186,7 +186,7 @@ same "printed page" as the copyright notice for easier identification within third-party archives. - Copyright [yyyy] [name of copyright owner] + Copyright 2024, The Open Lane, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/README.md b/README.md new file mode 100644 index 0000000..1d25051 --- /dev/null +++ b/README.md @@ -0,0 +1,869 @@ +# Slinging HTTP + +The `httpsling` library simplifies the way you make HTTP httpsling. It's intended to provide an easy-to-use interface for sending requests and handling responses, reducing the boilerplate code typically associated with the `net/http` package. + +## Overview + +Creating a new HTTP client and making a request should be straightforward: + +```go +package main + +import ( + "github.com/theopenlane/httpsling" + "log" +) + +func main() { + // Create a client using a base URL + client := httpsling.URL("http://mattisthebest.com") + + // Alternatively, create a client with custom configuration + client = httpsling.Create(&httpsling.Config{ + BaseURL: "http://mattisthebest.com", + Timeout: 30 * time.Second, + }) + + // Perform a GET request + resp, err := client.Get("/resource") + if err != nil { + log.Fatal(err) + } + + defer resp.Close() + + log.Println(resp.String()) +} +``` + +## Client + +The `Client` struct is your gateway to making HTTP requests. You can configure it to your needs, setting default headers, cookies, timeout durations, etc. + +```go +client := httpsling.URL("http://mattisthebest.com") + +// Or, with full configuration +client = httpsling.Create(&httpsling.Config{ + BaseURL: "http://mattisthebest.com", + Timeout: 5 * time.Second, + Headers: &http.Header{ + HeaderContentType: []string{ContentTypeJSON}, + }, +}) +``` + +### Initializing the Client + +You can start by creating a `Client` with specific configurations using the `Create` method: + +```go +client := httpsling.Create(&httpsling.Config{ + BaseURL: "https://the.cats.meow.com", + Timeout: 30 * time.Second, + Headers: &http.Header{ + HeaderAuthorization: []string{"Bearer YOUR_ACCESS_TOKEN"}, + HeaderContentType: []string{ContentTypeJSON}, + }, + Cookies: map[string]string{ + "session_token": "YOUR_SESSION_TOKEN", + }, + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + MaxRetries: 3, + RetryStrategy: httpsling.ExponentialBackoffStrategy(1*time.Second, 2, 30*time.Second), + RetryIf: httpsling.DefaultRetryIf, +}) +``` + +This setup creates a `Client` tailored for your API communication, including base URL, request timeout, default headers, and cookies + +### Configuring with Set Methods + +Alternatively, you can use `Set` methods for a more dynamic configuration approach: + +```go +client := httpsling.URL("https://the.cats.meow.com"). + SetDefaultHeader(HeaderAuthorization, "Bearer YOUR_ACCESS_TOKEN"). + SetDefaultHeader(HeaderContentType, ContentTypeJSON). + SetDefaultCookie("session_token", "YOUR_SESSION_TOKEN"). + SetTLSConfig(&tls.Config{InsecureSkipVerify: true}). + SetMaxRetries(3). + SetRetryStrategy(httpsling.ExponentialBackoffStrategy(1*time.Second, 2, 30*time.Second)). + SetRetryIf(httpsling.DefaultRetryIf). + SetProxy("http://localhost:8080") +``` + +### Configuring BaseURL + +Set the base URL for all requests: + +```go +client.SetBaseURL("https://the.cats.meow.com") +``` + +### Setting Headers + +Set default headers for all requests: + +```go +client.SetDefaultHeader(HeaderAuthorization, "Bearer YOUR_ACCESS_TOKEN") +client.SetDefaultHeader(HeaderContentType, ContentTypeJSON) +``` + +Bulk set default headers: + +```go +headers := &http.Header{ + HeaderAuthorization: []string{"Bearer YOUR_ACCESS_TOKEN"}, + HeaderContentType: []string{ContentTypeJSON}, +} +client.SetDefaultHeaders(headers) +``` + +Add or remove a header: + +```go +client.AddDefaultHeader("X-Custom-Header", "Value1") +client.DelDefaultHeader("X-Unneeded-Header") +``` + +### Managing Cookies + +Set default cookies for all requests: + +```go +client.SetDefaultCookie("session_id", "123456") +``` + +Bulk set default cookies: + +```go +cookies := map[string]string{ + "session_id": "123456", + "preferences": "dark_mode=true", +} +client.SetDefaultCookies(cookies) +``` + +Remove a default cookie: + +```go +client.DelDefaultCookie("session_id") +``` + +This approach simplifies managing base URLs, headers, and cookies across all requests made with the client, ensuring consistency. + +### Configuring Timeouts + +Define a global timeout for all requests to prevent indefinitely hanging operations: + +```go +client := httpsling.Create(&httpsling.Config{ + Timeout: 15 * time.Second, +}) +``` + +### TLS Configuration + +Custom TLS configurations can be applied for enhanced security measures, such as loading custom certificates: + +```go +tlsConfig := &tls.Config{InsecureSkipVerify: true} +client.SetTLSConfig(tlsConfig) +``` + +## Requests + +The library provides a `RequestBuilder` to construct and dispatch HTTP httpsling. Here are examples of performing various types of requests, including adding query parameters, setting headers, and attaching a body to your httpsling. + +#### GET Request + +```go +resp, err := client.Get("/path"). + Query("search", "query"). + Header(HeaderAccept, ContentTypeJSON). + Send(context.Background()) +``` + +#### POST Request + +```go +resp, err := client.Post("/path"). + Header(HeaderContentType, ContentTypeJSON). + JSONBody(map[string]interface{}{"key": "value"}). + Send(context.Background()) +``` + +#### PUT Request + +```go +resp, err := client.Put("/stff/{stuff_id}"). + PathParam("stuff_id", "123456"). + JSONBody(map[string]interface{}{"updatedKey": "newValue"}). + Send(context.Background()) +``` + +#### DELETE Request + +```go +resp, err := client.Delete("/stffs/{stuff_id}"). + PathParam("stuff_id", "123456meowmeow"). + Send(context.Background()) +``` + +### Retry Mechanism + +Automatically retry requests on failure with customizable strategies: + +```go +client.SetMaxRetries(3) +client.SetRetryStrategy(httpsling.ExponentialBackoffStrategy(1*time.Second, 2, 30*time.Second)) +client.SetRetryIf(func(req *http.Request, resp *http.Response, err error) bool { + // Only retry for 500 Internal Server Error + return resp.StatusCode == http.StatusInternalServerError +}) +``` + +### Configuring Retry Strategies + +#### Applying a Default Backoff Strategy + +For consistent delay intervals between retries: + +```go +client.SetRetryStrategy(httpsling.DefaultBackoffStrategy(5 * time.Second)) +``` + +#### Utilizing a Linear Backoff Strategy + +To increase delay intervals linearly with each retry attempt: + +```go +client.SetRetryStrategy(httpsling.LinearBackoffStrategy(1 * time.Second)) +``` + +#### Employing an Exponential Backoff Strategy + +For exponential delay increases between attempts, with an option to cap the delay: + +```go +client.SetRetryStrategy(httpsling.ExponentialBackoffStrategy(1*time.Second, 2, 30*time.Second)) +``` + +### Customizing Retry Conditions + +Define when retries should be attempted based on response status codes or errors: + +```go +client.SetRetryIf(func(req *http.Request, resp *http.Response, err error) bool { + return resp.StatusCode == http.StatusInternalServerError || err != nil +}) +``` + +### Setting Maximum Retry Attempts + +To limit the number of retries, use the `SetMaxRetries` method: + +```go +client.SetMaxRetries(3) +``` + +### Proxy Configuration + +Route requests through a proxy server: + +```go +client.SetProxy("http://localhost:8080") +``` + +### Authentication + +Supports various authentication methods: + +- **Basic Auth**: + +```go +client.SetAuth(httpsling.BasicAuth{ + Username: "user", + Password: "pass", +}) +``` + +- **Bearer Token**: + +```go +client.SetAuth(httpsling.BearerAuth{ + Token: "YOUR_ACCESS_TOKEN", +}) +``` + +### Query Parameters + +Add query parameters to your request using `Query`, `Queries`, `QueriesStruct`, or remove them with `DelQuery` + +```go +// Add a single query parameter +request.Query("search", "query") + +// Add multiple query parameters +request.Queries(url.Values{"sort": []string{"date"}, "limit": []string{"10"}}) + +// Add query parameters from a struct +type queryParams struct { + Sort string `url:"sort"` + Limit int `url:"limit"` +} +request.QueriesStruct(queryParams{Sort: "date", Limit: 10}) + +// Remove one or more query parameters +request.DelQuery("sort", "limit") +``` + +### Headers + +Set request headers using `Header`, `Headers`, or related methods + +```go +request.Header(HeaderAuthorization, "Bearer YOUR_ACCESS_TOKEN") +request.Headers(http.Header{HeaderContentType: []string{ContentTypeJSON}}) + +// Convenient methods for common headers +request.ContentType(ContentTypeJSON) +request.Accept(ContentTypeJSON) +request.UserAgent("MyCustomClient/1.0") +request.Referer("https://example.com") +``` + +### Cookies + +Add cookies to your request using `Cookie`, `Cookies`, or remove them with `DelCookie`. + +```go +// Add a single cookie +request.Cookie("session_token", "YOUR_SESSION_TOKEN") + +// Add multiple cookies at once +request.Cookies(map[string]string{ + "session_token": "YOUR_SESSION_TOKEN", + "user_id": "12345", +}) + +// Remove one or more cookies +request.DelCookie("session_token", "user_id") + +``` + +### Body Content + +Specify the request body directly with `Body` or use format-specific methods like `JSONBody`, `XMLBody`, `YAMLBody`, `TextBody`, or `RawBody` for appropriate content types. + +```go +// Setting JSON body +request.JSONBody(map[string]interface{}{"key": "value"}) + +// Setting XML body +request.XMLBody(myXmlStruct) + +// Setting YAML body +request.YAMLBody(myYamlStruct) + +// Setting text body +request.TextBody("plain text content") + +// Setting raw body +request.RawBody([]byte("raw data")) +``` + +### Timeout and Retries + +Configure request-specific timeout and retry strategies: + +```go +request.Timeout(10 * time.Second).MaxRetries(3) +``` + +### Sending Requests + +The `Send(ctx)` method executes the HTTP request built with the Request builder. It requires a `context.Context` argument, allowing you to control request cancellation and timeouts. + +```go +resp, err := request.Send(context.Background()) +if err != nil { + log.Fatalf("Request failed: %v", err) +} +// Process response... +``` +### Advanced Features + +#### Handling Cancellation + +To cancel a request, simply use the context's cancel function. This is particularly useful for long-running requests that you may want to abort if they take too long or if certain conditions are met. + +```go +ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) +defer cancel() // Ensures resources are freed up after the operation completes or times out + +// Cancel the request if it hasn't completed within the timeout +resp, err := request.Send(ctx) +if errors.Is(err, context.Canceled) { + log.Println("Request was canceled") +} +``` + +#### HTTP Client Customization + +Directly customize the underlying `http.Client`: + +```go +customHTTPClient := &http.Client{Timeout: 20 * time.Second} +client.SetHTTPClient(customHTTPClient) +``` + +#### Path Parameters + +To insert or modify path parameters in your URL, use `PathParam` for individual parameters or `PathParams` for multiple. For removal, use `DelPathParam`. + +```go +// Setting a single path parameter +request.PathParam("userId", "123") + +// Setting multiple path parameters at once +request.PathParams(map[string]string{"userId": "123", "postId": "456"}) + +// Removing path parameters +request.DelPathParam("userId", "postId") +``` + +When using `client.Get("/users/{userId}/posts/{postId}")`, replace `{userId}` and `{postId}` with actual values by using `PathParams` or `PathParam`. + +#### Form Data + +For `application/x-www-form-urlencoded` content, utilize `FormField` for individual fields or `FormFields` for multiple. + +```go +// Adding individual form field +request.FormField("name", "John Snow") + +// Setting multiple form fields at once +fields := map[string]interface{}{"name": "John", "age": "30"} +request.FormFields(fields) +``` + +#### File Uploads + +To include files in a `multipart/form-data` request, specify each file's form field name, file name, and content using `File` or add multiple files with `Files`. + +```go +// Adding a single file +file, _ := os.Open("path/to/file") +request.File("profile_picture", "filename.jpg", file) + +// Adding multiple files +request.Files(file1, file2) +``` + +### Authentication + +Apply authentication methods directly to the request: + +```go +request.Auth(httpsling.BasicAuth{ + Username: "user", + Password: "pass", +}) +``` + +## Middleware + +Add custom middleware to process the request or response: + +```go +request.AddMiddleware(func(next httpsling.MiddlewareHandlerFunc) httpsling.MiddlewareHandlerFunc { + return func(req *http.Request) (*http.Response, error) { + // Custom logic before request + resp, err := next(req) + // Custom logic after response + return resp, err + } +}) +``` + +### Understanding Middleware + +Middleware functions wrap around HTTP requests, allowing pre- and post-processing of requests and responses. They can modify requests before they are sent, examine responses, and decide whether to modify them, retry the request, or take other actions. + +### Client-Level Middleware + +Client-level middleware is applied to all requests made by a client. It's ideal for cross-cutting concerns like logging, error handling, and metrics collection. + +**Adding Middleware to a Client:** + +```go +client := httpsling.Create(&httpsling.Config{BaseURL: "https://the.cats.meow.com"}) +client.AddMiddleware(func(next httpsling.MiddlewareHandlerFunc) httpsling.MiddlewareHandlerFunc { + return func(req *http.Request) (*http.Response, error) { + // Pre-request manipulation + fmt.Println("Request URL:", req.URL) + + // Proceed with the request + resp, err := next(req) + + // Post-response manipulation + if err == nil { + fmt.Println("Response status:", resp.Status) + } + + return resp, err + } +}) +``` + +### Request-Level Middleware + +Request-level middleware applies only to individual httpsling. This is useful for request-specific concerns, such as request tracing or modifying the request based on dynamic context. + +**Adding Middleware to a Request:** + +```go +request := client.NewRequestBuilder(MethodGet, "/path").AddMiddleware(func(next httpsling.MiddlewareHandlerFunc) httpsling.MiddlewareHandlerFunc { + return func(req *http.Request) (*http.Response, error) { + // Modify the request here + req.Header.Add("X-Request-ID", "12345") + + // Proceed with the modified request + return next(req) + } +}) +``` + +### Implementing Custom Middleware + +Custom middleware can perform a variety of tasks, such as authentication, logging, and metrics. Here's a simple logging middleware example: + +```go +func loggingMiddleware(next httpsling.MiddlewareHandlerFunc) httpsling.MiddlewareHandlerFunc { + return func(req *http.Request) (*http.Response, error) { + log.Printf("Requesting %s %s", req.Method, req.URL) + resp, err := next(req) + if err != nil { + log.Printf("Request to %s failed: %v", req.URL, err) + } else { + log.Printf("Received %d response from %s", resp.StatusCode, req.URL) + } + return resp, err + } +} +``` + +### Integrating OpenTelemetry Middleware + +OpenTelemetry middleware can be used to collect tracing and metrics for your requests if you're into that sort of thing. Below is an example of how to set up a basic trace for an HTTP request: + +**Implementing OpenTelemetry Middleware:** + +```go +func openTelemetryMiddleware(next httpsling.MiddlewareHandlerFunc) httpsling.MiddlewareHandlerFunc { + return func(req *http.Request) (*http.Response, error) { + ctx, span := otel.Tracer("requests").Start(req.Context(), req.URL.Path) + defer span.End() + + // Add trace ID to request headers if needed + traceID := span.SpanContext().TraceID().String() + req.Header.Set("X-Trace-ID", traceID) + + resp, err := next(req) + + // Set span attributes based on response + if err == nil { + span.SetAttributes(attribute.Int("http.status_code", resp.StatusCode)) + } else { + span.RecordError(err) + } + + return resp, err + } +} +``` + + +## Responses + +Handling responses is necessary in determining the outcome of your HTTP requests - the library has some built-in response code validators and other tasty things. + +```go +type APIResponse struct { + Data string `json:"data"` +} + +var apiResp APIResponse +if err := resp.ScanJSON(&apiResp); err != nil { + log.Fatal(err) +} + +log.Printf("Status Code: %d\n", resp.StatusCode()) +log.Printf("Response Data: %s\n", apiResp.Data) +``` + +### Parsing Response Body + +By leveraging the `Scan`, `ScanJSON`, `ScanXML`, and `ScanYAML` methods, you can decode responses based on the `Content-Type` + +#### JSON Responses + +Given a JSON response, you can unmarshal it directly into a Go struct using either the specific `ScanJSON` method or the generic `Scan` method, which automatically detects the content type: + +```go +var jsonData struct { + Name string `json:"name"` + Age int `json:"age"` +} + +// Unmarshal using ScanJSON +if err := response.ScanJSON(&jsonData); err != nil { + log.Fatalf("Error unmarshalling JSON: %v", err) +} + +// Alternatively, unmarshal using Scan +if err := response.Scan(&jsonData); err != nil { + log.Fatalf("Error unmarshalling response: %v", err) +} +``` + +#### XML Responses + +For XML responses, use `ScanXML` or `Scan` to decode into a Go struct. Here's an example assuming the response contains XML data: + +```go +var xmlData struct { + Name string `xml:"name"` + Age int `xml:"age"` +} + +// Unmarshal using ScanXML +if err := response.ScanXML(&xmlData); err != nil { + log.Fatalf("Error unmarshalling XML: %v", err) +} + +// Alternatively, unmarshal using Scan +if err := response.Scan(&xmlData); err != nil { + log.Fatalf("Error unmarshalling response: %v", err) +} +``` + +#### YAML Responses + +YAML content is similarly straightforward to handle. The `ScanYAML` or `Scan` method decodes the YAML response into the specified Go struct: + +```go +var yamlData struct { + Name string `yaml:"name"` + Age int `yaml:"age"` +} + +// Unmarshal using ScanYAML +if err := response.ScanYAML(&yamlData); err != nil { + log.Fatalf("Error unmarshalling YAML: %v", err) +} + +// Alternatively, unmarshal using Scan +if err := response.Scan(&yamlData); err != nil { + log.Fatalf("Error unmarshalling response: %v", err) +} +``` + +### Storing Response Content + +For saving the response body to a file or streaming it to an `io.Writer`: + +- **Save**: Write the response body to a designated location + + ```go + // Save response to a file + if err := response.Save("downloaded_file.txt"); err != nil { + log.Fatalf("Failed to save file: %v", err) + } + ``` + +### Evaluating Response Success + +To assess whether the HTTP request was successful: + +- **IsSuccess**: Check if the status code signifies a successful response + + ```go + if response.IsSuccess() { + fmt.Println("The request succeeded hot diggity dog") + } + ``` + + +## Enabling Logging + +To turn on logging, you must explicitly initialize and set a `Logger` in the client configuration. Here's how to create and use the `DefaultLogger`, which logs to `os.Stderr` by default, and is configured to log errors only: + +```go +logger := httpsling.NewDefaultLogger(os.Stderr, slog.LevelError) +client := httpsling.Create(&httpsling.Config{ + Logger: logger, +}) +``` + +Or, for an already instantiated client: + +```go +client.SetLogger(httpsling.NewDefaultLogger(os.Stderr, slog.LevelError)) +``` + +### Adjusting Log Levels + +Adjusting the log level is straightforward. After defining your logger, simply set the desired level. This allows you to control the verbosity of the logs based on your requirements. + +```go +logger := httpsling.NewDefaultLogger(os.Stderr, httpsling.LevelError) +logger.SetLevel(httpsling.LevelInfo) // Set to Info level to capture more detailed logs + +client := httpsling.Create(&httpsling.Config{ + Logger: logger, +}) +``` + +The available log levels are: + +- `LevelDebug` +- `LevelInfo` +- `LevelWarn` +- `LevelError` + +### Implementing a Custom Logger + +For more advanced scenarios where you might want to integrate with an existing logging system or format logs differently, implement the `Logger` interface. This requires methods for each level of logging (`Debugf`, `Infof`, `Warnf`, `Errorf`) and a method to set the log level (`SetLevel`). + +Here is a simplified example: + +```go +type MyLogger struct { + // Include your custom logging mechanism here +} + +func (l *MyLogger) Debugf(format string, v ...any) { + // Custom debug logging implementation +} + +func (l *MyLogger) Infof(format string, v ...any) { + // Custom info logging implementation +} + +func (l *MyLogger) Warnf(format string, v ...any) { + // Custom warn logging implementation +} + +func (l *MyLogger) Errorf(format string, v ...any) { + // Custom error logging implementation +} + +func (l *MyLogger) SetLevel(level httpsling.Level) { + // Implement setting the log level in your logger +} + +// Usage +myLogger := &MyLogger{} +myLogger.SetLevel(httpsling.LevelDebug) // Example setting to Debug level + +client := httpsling.Create(&httpsling.Config{ + Logger: myLogger, +}) +``` + +## Stream Callbacks + +Stream callbacks are functions that you define to handle chunks of data as they are received from the server. The Requests library supports three types of stream callbacks: + +- **StreamCallback**: Invoked for each chunk of data received +- **StreamErrCallback**: Invoked when an error occurs during streaming +- **StreamDoneCallback**: Invoked once streaming is completed, regardless of whether it ended due to an error or successfully + +### Configuring Stream Callbacks + +To configure streaming for a request, use the `Stream` method on a `RequestBuilder` instance. This method accepts a `StreamCallback` function, which will be called with each chunk of data received from the server. + +```go +streamCallback := func(data []byte) error { + fmt.Println("Received stream data:", string(data)) + return nil // Return an error if needed to stop streaming +} + +request := client.Get("/stream-endpoint").Stream(streamCallback) +``` + +### Handling Stream Errors + +To handle errors that occur during streaming, set a `StreamErrCallback` using the `StreamErr` method on the `Response` object. + +```go +streamErrCallback := func(err error) { + fmt.Printf("Stream error: %v\n", err) +} + +response, _ := request.Send(context.Background()) +response.StreamErr(streamErrCallback) +``` + +### Completing Stream Processing + +Once streaming is complete, you can use the `StreamDone` method on the `Response` object to set a `StreamDoneCallback`. This callback is invoked after the stream is fully processed, either successfully or due to an error. + +```go +streamDoneCallback := func() { + fmt.Println("Stream processing completed") +} + +response.StreamDone(streamDoneCallback) +``` + +### Example: Consuming an SSE Stream + +The following example demonstrates how to consume a Server-Sent Events (SSE) stream, processing each event as it arrives, handling errors, and performing cleanup once the stream ends. + +```go +// Configure the stream callback to handle data chunks +streamCallback := func(data []byte) error { + fmt.Println("Received stream event:", string(data)) + return nil +} + +// Configure error and done callbacks +streamErrCallback := func(err error) { + fmt.Printf("Error during streaming: %v\n", err) +} + +streamDoneCallback := func() { + fmt.Println("Stream ended") +} + +// Create the streaming request +client := httpsling.Create(&httpsling.Config{BaseURL: "https://example.com"}) +request := client.Get("/events").Stream(streamCallback) + +// Send the request and configure callbacks +response, err := request.Send(context.Background()) +if err != nil { + fmt.Printf("Failed to start streaming: %v\n", err) + return +} + +response.StreamErr(streamErrCallback).StreamDone(streamDoneCallback) +``` + + +## Inspirations + +This library was inspired by and built upon the work of several other HTTP client libraries: + +- [Dghubble/sling](https://github.com/dghubble/sling) +- [Monaco-io/request](https://github.com/monaco-io/request) +- [Go-resty/resty](https://github.com/go-resty/resty) +- [Fiber Client](https://github.com/gofiber/fiber) + +Props to dghubble for a great name with `sling`, which was totally ripped off to make `httpsling` <3. I chose not to use any of these directly because I wanted to have layers of control we may need within our services echosystem. \ No newline at end of file diff --git a/Taskfile.yaml b/Taskfile.yaml new file mode 100644 index 0000000..4a76940 --- /dev/null +++ b/Taskfile.yaml @@ -0,0 +1,48 @@ +version: "3" + +tasks: + default: + silent: true + cmds: + - task --list + + ## Go tasks + go:lint: + desc: runs golangci-lint, the most annoying opinionated linter ever + cmds: + - golangci-lint run --config=.golangci.yaml --verbose --fast --fix + + go:test: + desc: runs and outputs results of created go tests + aliases: ['go:test:psql', 'test:psql'] + env: + TEST_DB_URL: "docker://postgres:16-alpine" + cmds: + - go test -v ./... + + go:fmt: + desc: format all go code + cmds: + - go fmt ./... + + go:tidy: + desc: Runs go mod tidy on the backend + aliases: [tidy] + cmds: + - go mod tidy + + go:all: + aliases: [go] + desc: Runs all go test and lint related tasks + cmds: + - task: go:tidy + - task: go:fmt + - task: go:lint + - task: go:test + + precommit-full: + desc: Lint the project against all files + cmds: + - pre-commit install && pre-commit install-hooks + - pre-commit autoupdate + - pre-commit run --show-diff-on-failure --color=always --all-files diff --git a/auth.go b/auth.go new file mode 100644 index 0000000..0a38480 --- /dev/null +++ b/auth.go @@ -0,0 +1,71 @@ +package httpsling + +import "net/http" + +// AuthType represents the type of authentication method +type AuthType string + +var ( + // BearerAuthType is the Bearer token authentication type using the Authorization header + BearerAuthType AuthType = "Bearer" + // BasicAuthType is the Basic authentication type using a username and password + BasicAuthType AuthType = "Basic" +) + +// AuthMethod defines the interface for applying authentication strategies to httpsling +type AuthMethod interface { + // Apply adds the authentication method to the request + Apply(req *http.Request) + // Valid checks if the authentication method is valid + Valid() bool +} + +// BasicAuth represents HTTP Basic Authentication credentials +type BasicAuth struct { + Username string + Password string +} + +// CustomAuth allows for custom Authorization header values +type CustomAuth struct { + Header string +} + +// BearerAuth represents an OAuth 2.0 Bearer token +type BearerAuth struct { + Token string +} + +// Apply adds the Basic Auth credentials to the request +func (b BasicAuth) Apply(req *http.Request) { + req.SetBasicAuth(b.Username, b.Password) +} + +// Valid checks if the Basic Auth credentials are present +func (b BasicAuth) Valid() bool { + return b.Username != "" && b.Password != "" +} + +// Apply adds the Bearer token to the request's Authorization header +func (b BearerAuth) Apply(req *http.Request) { + if b.Valid() { + req.Header.Set(HeaderAuthorization, "Bearer "+b.Token) + } +} + +// Valid checks if the Bearer token is present +func (b BearerAuth) Valid() bool { + return b.Token != "" +} + +// Apply sets a custom Authorization header value +func (c CustomAuth) Apply(req *http.Request) { + if c.Valid() { + req.Header.Set(HeaderAuthorization, c.Header) + } +} + +// Valid checks if the custom Authorization header value is present +func (c CustomAuth) Valid() bool { + return c.Header != "" +} diff --git a/auth_test.go b/auth_test.go new file mode 100644 index 0000000..b01b775 --- /dev/null +++ b/auth_test.go @@ -0,0 +1,68 @@ +package httpsling + +import ( + "context" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestBasicAuthApply(t *testing.T) { + req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "http://hotdogs.com", nil) + + auth := BasicAuth{ + Username: "user", + Password: "pass", + } + auth.Apply(req) + + assert.Equal(t, "Basic dXNlcjpwYXNz", req.Header.Get("Authorization")) +} + +func TestBasicAuthValid(t *testing.T) { + auth := BasicAuth{ + Username: "user", + Password: "pass", + } + + assert.True(t, auth.Valid()) +} + +func TestBearerAuthApply(t *testing.T) { + req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "http://sarahisgreat.com", nil) + + auth := BearerAuth{ + Token: "token", + } + auth.Apply(req) + + assert.Equal(t, "Bearer token", req.Header.Get("Authorization")) +} + +func TestBearerAuthValid(t *testing.T) { + auth := BearerAuth{ + Token: "token", + } + + assert.True(t, auth.Valid()) +} + +func TestCustomAuthApply(t *testing.T) { + req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "http://meow.com", nil) + + auth := CustomAuth{ + Header: "CustomValue", + } + auth.Apply(req) + + assert.Equal(t, "CustomValue", req.Header.Get("Authorization")) +} + +func TestCustomAuthValid(t *testing.T) { + auth := CustomAuth{ + Header: "CustomValue", + } + + assert.True(t, auth.Valid()) +} diff --git a/client.go b/client.go new file mode 100644 index 0000000..0cf832e --- /dev/null +++ b/client.go @@ -0,0 +1,524 @@ +package httpsling + +import ( + "crypto/tls" + "net/http" + "net/http/cookiejar" + "sync" + "time" +) + +// Client represents an HTTP client and is the main control mechanism for making HTTP requests +type Client struct { + // mu is a mutex to protect the client's configuration + mu sync.RWMutex + // BaseURL is the base URL for all httpsling made by this client + BaseURL string + // Headers are the default headers to be sent with each request + Headers *http.Header + // Cookies are the default cookies to be sent with each request + Cookies []*http.Cookie + // Middlewares are the request/response manipulation middlewares + Middlewares []Middleware + // TLSConfig is the TLS configuration for the client + TLSConfig *tls.Config + // MaxRetries is the maximum number of retry attempts + MaxRetries int + // RetryStrategy is the backoff strategy for retries + RetryStrategy BackoffStrategy + // RetryIf is the custom retry condition function + RetryIf RetryIfFunc + // HTTPClient is the underlying HTTP client + HTTPClient *http.Client + // JSONEncoder is the JSON encoder for the client + JSONEncoder Encoder + // JSONDecoder is the JSON decoder for the client + JSONDecoder Decoder + // XMLEncoder is the XML encoder for the client + XMLEncoder Encoder + // XMLDecoder is the XML decoder for the client + XMLDecoder Decoder + // YAMLEncoder is the YAML encoder for the client + YAMLEncoder Encoder + // YAMLDecoder is the YAML decoder for the client + YAMLDecoder Decoder + // Logger is the logger instance for the client + Logger Logger + // auth is the authentication method for the client + Auth AuthMethod +} + +// Config sets up the initial configuration for the HTTP client - you need to initialize multiple if you want the behaviors to be different +type Config struct { + // The base URL for all httpsling made by this client + BaseURL string + // Default headers to be sent with each request + Headers *http.Header + // Default Cookies to be sent with each request + Cookies map[string]string + // Timeout for httpsling + Timeout time.Duration + // Cookie jar for the client + CookieJar *cookiejar.Jar + // Middlewares for request/response manipulation + Middlewares []Middleware + // TLS configuration for the client + TLSConfig *tls.Config + // Custom transport for the client + Transport http.RoundTripper + // Maximum number of retry attempts + MaxRetries int + // RetryStrategy defines the backoff strategy for retries + RetryStrategy BackoffStrategy + // RetryIf defines the custom retry condition function + RetryIf RetryIfFunc + // Logger instance for the client + Logger Logger +} + +// URL creates a new HTTP client with the given base URL +func URL(baseURL string) *Client { + return Create(&Config{BaseURL: baseURL}) +} + +// Create initializes a new HTTP client with the given configuration +func Create(config *Config) *Client { + cfg, httpClient := setInitialClientDetails(config) + + // Return a new Client instance + client := &Client{ + BaseURL: cfg.BaseURL, + HTTPClient: httpClient, + JSONEncoder: DefaultJSONEncoder, + JSONDecoder: DefaultJSONDecoder, + XMLEncoder: DefaultXMLEncoder, + XMLDecoder: DefaultXMLDecoder, + YAMLEncoder: DefaultYAMLEncoder, + YAMLDecoder: DefaultYAMLDecoder, + TLSConfig: cfg.TLSConfig, + } + + if config != nil { + client.Headers = config.Headers + } + + return finalizeClientChecks(client, cfg, httpClient) +} + +// finalizeClientChecks is a helper function to finalize the client configuration +func finalizeClientChecks(client *Client, config *Config, httpClient *http.Client) *Client { + if client.TLSConfig != nil && httpClient.Transport != nil { + httpTransport := httpClient.Transport.(*http.Transport) + httpTransport.TLSClientConfig = client.TLSConfig + } else if client.TLSConfig != nil { + httpClient.Transport = &http.Transport{ + TLSClientConfig: client.TLSConfig, + } + } + + if config.Middlewares != nil { + client.Middlewares = config.Middlewares + } else { + client.Middlewares = make([]Middleware, 0) + } + + if config.Cookies != nil { + client.SetDefaultCookies(config.Cookies) + } + + if config.MaxRetries != 0 { + client.MaxRetries = config.MaxRetries + } + + if config.RetryStrategy != nil { + client.RetryStrategy = config.RetryStrategy + } else { + client.RetryStrategy = DefaultBackoffStrategy(1 * time.Second) + } + + if config.RetryIf != nil { + client.RetryIf = config.RetryIf + } else { + client.RetryIf = DefaultRetryIf + } + + if config.Logger != nil { + client.Logger = config.Logger + } + + return client +} + +// setInitialClientDetails is a helper function that sets the initial configuration for the client and mostly breaks up how large of a function check the Create function is +func setInitialClientDetails(config *Config) (*Config, *http.Client) { + if config == nil { + config = &Config{} + } + + httpClient := &http.Client{} + + if config.Transport != nil { + httpClient.Transport = config.Transport + } + + if config.Timeout != 0 { + httpClient.Timeout = config.Timeout + } + + if config.CookieJar != nil { + httpClient.Jar = config.CookieJar + } + + return config, httpClient +} + +// SetBaseURL sets the base URL for the client +func (c *Client) SetBaseURL(baseURL string) { + c.mu.Lock() + defer c.mu.Unlock() + + c.BaseURL = baseURL +} + +// AddMiddleware adds a middleware to the client +func (c *Client) AddMiddleware(middlewares ...Middleware) { + c.mu.Lock() + defer c.mu.Unlock() + + c.Middlewares = append(c.Middlewares, middlewares...) +} + +// SetTLSConfig sets the TLS configuration for the client +func (c *Client) SetTLSConfig(config *tls.Config) *Client { + c.mu.Lock() + defer c.mu.Unlock() + + c.TLSConfig = config + + if c.HTTPClient == nil { + c.HTTPClient = &http.Client{} + } + + // Apply the TLS configuration to the existing transport if possible + // If the current transport is not an *http.Transport, replace it + if transport, ok := c.HTTPClient.Transport.(*http.Transport); ok { + transport.TLSClientConfig = config + } else { + c.HTTPClient.Transport = &http.Transport{ + TLSClientConfig: config, + } + } + + return c +} + +// InsecureSkipVerify sets the TLS configuration to skip certificate verification +func (c *Client) InsecureSkipVerify() *Client { + c.mu.Lock() + defer c.mu.Unlock() + + if c.TLSConfig == nil { + c.TLSConfig = &tls.Config{ + MinVersion: tls.VersionTLS12, + } + } + + c.TLSConfig.InsecureSkipVerify = true + + if c.HTTPClient == nil { + c.HTTPClient = &http.Client{} + } + + if transport, ok := c.HTTPClient.Transport.(*http.Transport); ok { + transport.TLSClientConfig = c.TLSConfig + } else { + c.HTTPClient.Transport = &http.Transport{ + TLSClientConfig: c.TLSConfig, + } + } + + return c +} + +// SetHTTPClient sets the HTTP client for the client +func (c *Client) SetHTTPClient(httpClient *http.Client) { + c.mu.Lock() + defer c.mu.Unlock() + + c.HTTPClient = httpClient +} + +// SetDefaultHeaders sets the default headers for the client +func (c *Client) SetDefaultHeaders(headers *http.Header) { + c.mu.Lock() + defer c.mu.Unlock() + + c.Headers = headers +} + +// SetDefaultHeader adds or updates a default header +func (c *Client) SetDefaultHeader(key, value string) { + c.mu.Lock() + defer c.mu.Unlock() + + if c.Headers == nil { + c.Headers = &http.Header{} + } + + c.Headers.Set(key, value) +} + +// AddDefaultHeader adds a default header +func (c *Client) AddDefaultHeader(key, value string) { + c.mu.Lock() + defer c.mu.Unlock() + + if c.Headers == nil { + c.Headers = &http.Header{} + } + + c.Headers.Add(key, value) +} + +// DelDefaultHeader removes a default header +func (c *Client) DelDefaultHeader(key string) { + c.mu.Lock() + defer c.mu.Unlock() + + if c.Headers != nil { // only attempt to delete if initialized + c.Headers.Del(key) + } +} + +// SetDefaultContentType sets the default content type for the client +func (c *Client) SetDefaultContentType(contentType string) { + c.SetDefaultHeader(HeaderContentType, contentType) +} + +// SetDefaultAccept sets the default accept header for the client +func (c *Client) SetDefaultAccept(accept string) { + c.SetDefaultHeader(HeaderAccept, accept) +} + +// SetDefaultUserAgent sets the default user agent for the client +func (c *Client) SetDefaultUserAgent(userAgent string) { + c.SetDefaultHeader(HeaderUserAgent, userAgent) +} + +// SetDefaultReferer sets the default referer for the client +func (c *Client) SetDefaultReferer(referer string) { + c.SetDefaultHeader(HeaderReferer, referer) +} + +// SetDefaultTimeout sets the default timeout for the client +func (c *Client) SetDefaultTimeout(timeout time.Duration) { + c.mu.Lock() + defer c.mu.Unlock() + + c.HTTPClient.Timeout = timeout +} + +// SetDefaultTransport sets the default transport for the client +func (c *Client) SetDefaultTransport(transport http.RoundTripper) { + c.mu.Lock() + defer c.mu.Unlock() + + c.HTTPClient.Transport = transport +} + +// SetDefaultCookieJar sets the default cookie jar for the client +func (c *Client) SetCookieJar(jar *cookiejar.Jar) { + c.mu.Lock() + defer c.mu.Unlock() + + c.HTTPClient.Jar = jar +} + +// SetDefaultCookieJar sets the creates a new cookie jar and sets it for the client +func (c *Client) SetDefaultCookieJar() error { + c.mu.Lock() + defer c.mu.Unlock() + + // Create a new cookie jar + jar, err := cookiejar.New(nil) + if err != nil { + return err + } + + c.HTTPClient.Jar = jar + + return nil +} + +// SetDefaultCookies sets the default cookies for the client +func (c *Client) SetDefaultCookies(cookies map[string]string) { + for name, value := range cookies { + c.SetDefaultCookie(name, value) + } +} + +// SetDefaultCookie sets a default cookie for the client +func (c *Client) SetDefaultCookie(name, value string) { + c.mu.Lock() + defer c.mu.Unlock() + + if c.Cookies == nil { + c.Cookies = make([]*http.Cookie, 0) + } + + c.Cookies = append(c.Cookies, &http.Cookie{Name: name, Value: value}) +} + +// DelDefaultCookie removes a default cookie from the client +func (c *Client) DelDefaultCookie(name string) { + c.mu.Lock() + defer c.mu.Unlock() + + if c.Cookies != nil { // Only attempt to delete if Cookies is initialized + for i, cookie := range c.Cookies { + if cookie.Name == name { + c.Cookies = append(c.Cookies[:i], c.Cookies[i+1:]...) + break + } + } + } +} + +// SetJSONMarshal sets the JSON marshal function for the client's JSONEncoder +func (c *Client) SetJSONMarshal(marshalFunc func(v any) ([]byte, error)) { + c.JSONEncoder = &JSONEncoder{ + MarshalFunc: marshalFunc, + } +} + +// SetJSONUnmarshal sets the JSON unmarshal function for the client's JSONDecoder +func (c *Client) SetJSONUnmarshal(unmarshalFunc func(data []byte, v any) error) { + c.JSONDecoder = &JSONDecoder{ + UnmarshalFunc: unmarshalFunc, + } +} + +// SetXMLMarshal sets the XML marshal function for the client's XMLEncoder +func (c *Client) SetXMLMarshal(marshalFunc func(v any) ([]byte, error)) { + c.XMLEncoder = &XMLEncoder{ + MarshalFunc: marshalFunc, + } +} + +// SetXMLUnmarshal sets the XML unmarshal function for the client's XMLDecoder +func (c *Client) SetXMLUnmarshal(unmarshalFunc func(data []byte, v any) error) { + c.XMLDecoder = &XMLDecoder{ + UnmarshalFunc: unmarshalFunc, + } +} + +// SetYAMLMarshal sets the YAML marshal function for the client's YAMLEncoder +func (c *Client) SetYAMLMarshal(marshalFunc func(v any) ([]byte, error)) { + c.YAMLEncoder = &YAMLEncoder{ + MarshalFunc: marshalFunc, + } +} + +// SetYAMLUnmarshal sets the YAML unmarshal function for the client's YAMLDecoder +func (c *Client) SetYAMLUnmarshal(unmarshalFunc func(data []byte, v any) error) { + c.YAMLDecoder = &YAMLDecoder{ + UnmarshalFunc: unmarshalFunc, + } +} + +// SetMaxRetries sets the maximum number of retry attempts +func (c *Client) SetMaxRetries(maxRetries int) *Client { + c.mu.Lock() + defer c.mu.Unlock() + + c.MaxRetries = maxRetries + + return c +} + +// SetRetryStrategy sets the backoff strategy for retries +func (c *Client) SetRetryStrategy(strategy BackoffStrategy) *Client { + c.mu.Lock() + defer c.mu.Unlock() + + c.RetryStrategy = strategy + + return c +} + +// SetRetryIf sets the custom retry condition function +func (c *Client) SetRetryIf(retryIf RetryIfFunc) *Client { + c.mu.Lock() + defer c.mu.Unlock() + + c.RetryIf = retryIf + + return c +} + +// SetAuth configures an authentication method for the client +func (c *Client) SetAuth(auth AuthMethod) { + if auth.Valid() { + c.Auth = auth + } +} + +// SetLogger sets logger instance in client +func (c *Client) SetLogger(logger Logger) *Client { + c.mu.Lock() + defer c.mu.Unlock() + + c.Logger = logger + + return c +} + +// Get initiates a GET request +func (c *Client) Get(path string) *RequestBuilder { + return c.NewRequestBuilder(http.MethodGet, path) +} + +// Post initiates a POST request +func (c *Client) Post(path string) *RequestBuilder { + return c.NewRequestBuilder(http.MethodPost, path) +} + +// Delete initiates a DELETE request +func (c *Client) Delete(path string) *RequestBuilder { + return c.NewRequestBuilder(http.MethodDelete, path) +} + +// Put initiates a PUT request +func (c *Client) Put(path string) *RequestBuilder { + return c.NewRequestBuilder(http.MethodPut, path) +} + +// Patch initiates a PATCH request +func (c *Client) Patch(path string) *RequestBuilder { + return c.NewRequestBuilder(http.MethodPatch, path) +} + +// Options initiates an OPTIONS request +func (c *Client) Options(path string) *RequestBuilder { + return c.NewRequestBuilder(http.MethodOptions, path) +} + +// Head initiates a HEAD request +func (c *Client) Head(path string) *RequestBuilder { + return c.NewRequestBuilder(http.MethodHead, path) +} + +// Connect initiates a CONNECT request +func (c *Client) Connect(path string) *RequestBuilder { + return c.NewRequestBuilder(http.MethodConnect, path) +} + +// Trace initiates a TRACE request +func (c *Client) Trace(path string) *RequestBuilder { + return c.NewRequestBuilder(http.MethodTrace, path) +} + +// Custom initiates a custom request +func (c *Client) Custom(path, method string) *RequestBuilder { + return c.NewRequestBuilder(method, path) +} diff --git a/client_test.go b/client_test.go new file mode 100644 index 0000000..a9f7a2f --- /dev/null +++ b/client_test.go @@ -0,0 +1,783 @@ +package httpsling + +import ( + "context" + "encoding/base64" + "encoding/json" + "encoding/xml" + "errors" + "fmt" + "net" + "net/http" + "net/http/cookiejar" + "net/http/httptest" + "testing" + "time" + + "github.com/bytedance/sonic" + yaml "github.com/goccy/go-yaml" + "github.com/stretchr/testify/assert" +) + +// startTestHTTPServer starts a test HTTP server that responds to various endpoints for testing purposes +func startTestHTTPServer() *httptest.Server { + handler := http.NewServeMux() + handler.HandleFunc("/test-get", func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "GET response") + }) + + handler.HandleFunc("/test-post", func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "POST response") + }) + + handler.HandleFunc("/test-put", func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "PUT response") + }) + + handler.HandleFunc("/test-delete", func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "DELETE response") + }) + + handler.HandleFunc("/test-patch", func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "PATCH response") + }) + + handler.HandleFunc("/test-status-code", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusCreated) // 201 + fmt.Fprintln(w, `Created`) + }) + + handler.HandleFunc("/test-headers", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Custom-Header", "TestValue") + fmt.Fprintln(w, `Headers test`) + }) + + handler.HandleFunc("/test-cookies", func(w http.ResponseWriter, r *http.Request) { + http.SetCookie(w, &http.Cookie{Name: "test-cookie", Value: "cookie-value"}) + fmt.Fprintln(w, `Cookies test`) + }) + + handler.HandleFunc("/test-body", func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "This is the response body.") + }) + + handler.HandleFunc("/test-empty", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + handler.HandleFunc("/test-json", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set(HeaderContentType, ContentTypeJSON) + fmt.Fprintln(w, `{"message": "This is a JSON response", "status": true}`) + }) + + handler.HandleFunc("/test-xml", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set(HeaderContentType, ContentTypeXML) + fmt.Fprintln(w, `This is an XML responsetrue`) + }) + + handler.HandleFunc("/test-text", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set(HeaderContentType, ContentTypeText) + fmt.Fprintln(w, `This is a text response`) + }) + + handler.HandleFunc("/test-pdf", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set(HeaderContentType, "application/pdf") + fmt.Fprintln(w, `This is a PDF response`) + }) + + handler.HandleFunc("/test-redirect", func(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, "/test-redirected", http.StatusFound) + }) + + handler.HandleFunc("/test-redirected", func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "Redirected") + }) + + handler.HandleFunc("/test-failure", func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + }) + + return httptest.NewServer(handler) +} + +// testRoundTripperFunc type is an adapter to allow the use of ordinary functions as http.RoundTrippers +type testRoundTripperFunc func(*http.Request) (*http.Response, error) + +// RoundTrip executes a single HTTP transaction +func (f testRoundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req) +} + +// TestSetHTTPClient verifies that SetHTTPClient correctly sets a custom http.Client and uses it for subsequent httpsling, specifically checking for cookie modifications +func TestSetHTTPClient(t *testing.T) { + // Create a mock server that inspects incoming httpsling for a specific cookie + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Check for the presence of a specific cookie + cookie, err := r.Cookie("X-Custom-Test-Cookie") + if err != nil || cookie.Value != "true" { + // If the cookie is missing or not as expected, respond with a 400 Bad Request + w.WriteHeader(http.StatusBadRequest) + return + } + // If the cookie is present and correct, respond with a 200 + w.WriteHeader(http.StatusOK) + })) + + defer mockServer.Close() + + // Create a new instance of your Client + client := Create(&Config{ + BaseURL: mockServer.URL, + }) + + // Define a custom transport that adds a custom cookie to all outgoing httpsling + customTransport := testRoundTripperFunc(func(req *http.Request) (*http.Response, error) { + // Add the custom cookie to the request + req.AddCookie(&http.Cookie{Name: "X-Custom-Test-Cookie", Value: "true"}) + // Proceed with the default transport after adding the cookie + return http.DefaultTransport.RoundTrip(req) + }) + + // Set the custom http.Client with the custom transport to your Client + client.SetHTTPClient(&http.Client{ + Transport: customTransport, + }) + + // Send a request using the custom http.Client + resp, err := client.Get("/test").Send(context.Background()) + if err != nil { + t.Fatalf("Failed to send request: %v", err) + } + + defer resp.Close() //nolint: errcheck + + // Verify that the server responded with a 200 OK, indicating the custom cookie was successfully added. + if resp.StatusCode() != http.StatusOK { + t.Errorf("Expected status code 200, got %d. Indicates custom cookie was not recognized by the server.", resp.StatusCode()) + } +} + +func TestClientURL(t *testing.T) { + client := URL("http://localhost:8080") + assert.NotNil(t, client) + assert.Equal(t, "http://localhost:8080", client.BaseURL) +} + +func TestClientGetRequest(t *testing.T) { + server := startTestHTTPServer() + defer server.Close() + + client := Create(&Config{BaseURL: server.URL}) + resp, err := client.Get("/test-get").Send(context.Background()) + + assert.NoError(t, err) + assert.Equal(t, "GET response\n", resp.String()) +} + +func TestClientPostRequest(t *testing.T) { + server := startTestHTTPServer() + defer server.Close() + + client := Create(&Config{BaseURL: server.URL}) + resp, err := client.Post("/test-post").Body(map[string]interface{}{"key": "value"}).Send(context.Background()) + + assert.NoError(t, err) + assert.Equal(t, "POST response\n", resp.String()) +} + +func TestClientPutRequest(t *testing.T) { + server := startTestHTTPServer() + defer server.Close() + + client := Create(&Config{BaseURL: server.URL}) + resp, err := client.Put("/test-put").Body(map[string]interface{}{"key": "value"}).Send(context.Background()) + + assert.NoError(t, err) + assert.Equal(t, "PUT response\n", resp.String()) +} + +func TestClientDeleteRequest(t *testing.T) { + server := startTestHTTPServer() + defer server.Close() + + client := Create(&Config{BaseURL: server.URL}) + resp, err := client.Delete("/test-delete").Send(context.Background()) + + assert.NoError(t, err) + assert.Equal(t, "DELETE response\n", resp.String()) +} + +func TestClientPatchRequest(t *testing.T) { + server := startTestHTTPServer() + defer server.Close() + + client := Create(&Config{BaseURL: server.URL}) + resp, err := client.Patch("/test-patch").Body(map[string]interface{}{"key": "value"}).Send(context.Background()) + + assert.NoError(t, err) + assert.Equal(t, "PATCH response\n", resp.String()) +} + +func TestClientOptionsRequest(t *testing.T) { + server := startTestHTTPServer() + defer server.Close() + + client := Create(&Config{BaseURL: server.URL}) + resp, err := client.Options("/test-get").Send(context.Background()) + + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.RawResponse.StatusCode) +} + +func TestClientHeadRequest(t *testing.T) { + server := startTestHTTPServer() + defer server.Close() + + client := Create(&Config{BaseURL: server.URL}) + resp, err := client.Head("/test-get").Send(context.Background()) + + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.RawResponse.StatusCode) +} + +func TestClientTraceRequest(t *testing.T) { + server := startTestHTTPServer() + defer server.Close() + + client := Create(&Config{BaseURL: server.URL}) + resp, err := client.Trace("/test-get").Send(context.Background()) + + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.RawResponse.StatusCode) +} + +func TestClientCustomMethodRequest(t *testing.T) { + server := startTestHTTPServer() + defer server.Close() + + client := Create(&Config{BaseURL: server.URL}) + resp, err := client.Custom("/test-get", "OPTIONS").Send(context.Background()) + + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.RawResponse.StatusCode) +} + +// testSchema represents the JSON structure for testing +type testSchema struct { + Name string `json:"name"` + Age int `json:"age"` +} + +// TestSetJSONMarshal tests custom JSON marshal functionality +func TestSetJSONMarshal(t *testing.T) { + // Start a mock HTTP server. + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Read body from the request + var received testSchema + err := json.NewDecoder(r.Body).Decode(&received) + assert.NoError(t, err) + assert.Equal(t, "John Snow", received.Name) + assert.Equal(t, 30, received.Age) + })) + + defer server.Close() + + client := Create(&Config{BaseURL: server.URL}) + + client.SetJSONMarshal(sonic.Marshal) + + data := testSchema{ + Name: "John Snow", + Age: 30, + } + + // Send a request with the custom marshaled body + resp, err := client.Post("/").JSONBody(&data).Send(context.Background()) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode()) +} + +// TestSetJSONUnmarshal tests custom JSON unmarshal functionality +func TestSetJSONUnmarshal(t *testing.T) { + // Mock response data + mockResponse := `{"name":"Jane Doe","age":25}` + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set(HeaderContentType, ContentTypeJSON) + fmt.Fprintln(w, mockResponse) + })) + + defer server.Close() + + client := Create(&Config{BaseURL: server.URL}) + + // Set the custom JSON unmarshal function + client.SetJSONUnmarshal(sonic.Unmarshal) + + // Fetch and unmarshal the response + resp, err := client.Get("/").Send(context.Background()) + assert.NoError(t, err) + + var result testSchema + err = resp.Scan(&result) + assert.NoError(t, err) + assert.Equal(t, "Jane Doe", result.Name) + assert.Equal(t, 25, result.Age) +} + +type xmlTestSchema struct { + XMLName xml.Name `xml:"Test"` + Message string `xml:"Message"` + Status bool `xml:"Status"` +} + +func TestSetXMLMarshal(t *testing.T) { + // Mock server to check the received XML + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var received xmlTestSchema + err := xml.NewDecoder(r.Body).Decode(&received) + assert.NoError(t, err) + assert.Equal(t, "Test message", received.Message) + assert.True(t, received.Status) + })) + + defer server.Close() + + // Create your client and set the XML marshal function to use Go's default + client := Create(&Config{BaseURL: server.URL}) + client.SetXMLMarshal(xml.Marshal) + + // Data to marshal and send + data := xmlTestSchema{ + Message: "Test message", + Status: true, + } + + // Marshal and send the data + resp, err := client.Post("/").XMLBody(&data).Send(context.Background()) + assert.NoError(t, err) + assert.NotNil(t, resp) +} + +func TestSetXMLUnmarshal(t *testing.T) { + // Mock server to send XML data + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set(HeaderContentType, ContentTypeXML) + fmt.Fprintln(w, `Response messagetrue`) + })) + defer server.Close() + + // Create your client and set the XML unmarshal function to use go's default + client := Create(&Config{BaseURL: server.URL}) + client.SetXMLUnmarshal(xml.Unmarshal) + + // Fetch and attempt to unmarshal the data + resp, err := client.Get("/").Send(context.Background()) + assert.NoError(t, err) + + var result xmlTestSchema + err = resp.Scan(&result) + assert.NoError(t, err) + assert.Equal(t, "Response message", result.Message) + assert.True(t, result.Status) +} + +func TestSetYAMLMarshal(t *testing.T) { + type yamlTestSchema struct { + Message string `yaml:"message"` + Status bool `yaml:"status"` + } + + // Mock server to check the received YAML + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var received yamlTestSchema + err := yaml.NewDecoder(r.Body).Decode(&received) + assert.NoError(t, err) + assert.Equal(t, "Test message", received.Message) + assert.True(t, received.Status) + })) + + defer server.Close() + + // Create your client and set the YAML marshal function to use goccy/go-yaml's Marshal + client := Create(&Config{BaseURL: server.URL}) + client.SetYAMLMarshal(yaml.Marshal) + + // Data to marshal and send + data := yamlTestSchema{ + Message: "Test message", + Status: true, + } + + // Marshal and send the data + resp, err := client.Post("/").YAMLBody(&data).Send(context.Background()) + assert.NoError(t, err) + assert.NotNil(t, resp) +} + +func TestSetYAMLUnmarshal(t *testing.T) { + // Define a test schema + type yamlTestSchema struct { + Message string `yaml:"message"` + Status bool `yaml:"status"` + } + + // Mock server to send YAML data + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set(HeaderContentType, ContentTypeYAML) + fmt.Fprintln(w, "message: Response message\nstatus: true") + })) + defer server.Close() + + // Create your client and set the YAML unmarshal function to use goccy/go-yaml's Unmarshal + client := Create(&Config{BaseURL: server.URL}) + client.SetYAMLUnmarshal(yaml.Unmarshal) + + // Fetch and attempt to unmarshal the data + resp, err := client.Get("/").Send(context.Background()) + assert.NoError(t, err) + + var result yamlTestSchema + err = resp.Scan(&result) + assert.NoError(t, err) + assert.Equal(t, "Response message", result.Message) + assert.True(t, result.Status) +} + +// TestSetAuth verifies that SetAuth correctly sets the Authorization header for basic authentication +func TestSetAuth(t *testing.T) { + // Expected username and password + expectedUsername := "testuser" + expectedPassword := "testpass" + + // Expected Authorization header value + expectedAuthValue := "Basic " + base64.StdEncoding.EncodeToString([]byte(expectedUsername+":"+expectedPassword)) + + // Create a mock server that checks the Authorization header + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Retrieve the Authorization header from the request. + authHeader := r.Header.Get(HeaderAuthorization) + + // Check if the Authorization header matches the expected value + if authHeader != expectedAuthValue { + // If not, respond with 401 + w.WriteHeader(http.StatusUnauthorized) + return + } + + // If the header is correct, respond with 200 OK + w.WriteHeader(http.StatusOK) + })) + + defer mockServer.Close() + + client := Create(&Config{ + BaseURL: mockServer.URL, + }) + + // Set basic authentication using the SetBasicAuth method + client.SetAuth(BasicAuth{ + Username: expectedUsername, + Password: expectedPassword, + }) + + // Send the request through the client + resp, err := client.Get("/").Send(context.Background()) + if err != nil { + t.Fatalf("Failed to send request: %v", err) + } + defer resp.Close() //nolint: errcheck + + // Check the response status code + if resp.StatusCode() != http.StatusOK { + t.Errorf("Expected status code 200, got %d. Indicates Authorization header was not set correctly.", resp.StatusCode()) + } +} + +func TestSetDefaultHeaders(t *testing.T) { + // Create a mock server to check headers + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("X-Custom-Header") != "HeaderValue" { + t.Error("Default header 'X-Custom-Header' not found or value incorrect") + } + })) + + defer mockServer.Close() + + // Initialize the client and set a default header + client := Create(&Config{BaseURL: mockServer.URL}) + client.SetDefaultHeader("X-Custom-Header", "HeaderValue") + + // Make a request to trigger the header check + _, err := client.Get("/").Send(context.Background()) + if err != nil { + t.Fatalf("Failed to send request: %v", err) + } +} + +func TestDelDefaultHeader(t *testing.T) { + // Mock server to check for the absence of a specific header + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("X-Deleted-Header") != "" { + t.Error("Deleted default header 'X-Deleted-Header' was found in the request") + } + })) + + defer mockServer.Close() + + // Initialize the client, set, and then delete a default header + client := Create(&Config{BaseURL: mockServer.URL}) + client.SetDefaultHeader("X-Deleted-Header", "ShouldBeDeleted") + client.DelDefaultHeader("X-Deleted-Header") + + // Make a request to check for the absence of the deleted header + _, err := client.Get("/").Send(context.Background()) + if err != nil { + t.Fatalf("Failed to send request: %v", err) + } +} + +func TestSetDefaultContentType(t *testing.T) { + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Check the Content-Type header + if r.Header.Get(HeaderContentType) != ContentTypeJSON { + t.Error("Default Content-Type header not set correctly") + } + })) + defer mockServer.Close() + + client := Create(&Config{BaseURL: mockServer.URL}) + client.SetDefaultContentType(ContentTypeJSON) + + _, err := client.Get("/").Send(context.Background()) + if err != nil { + t.Fatalf("Failed to send request: %v", err) + } +} + +func TestSetDefaultAccept(t *testing.T) { + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Check the Accept header + if r.Header.Get(HeaderAccept) != ContentTypeXML { + t.Error("Default Accept header not set correctly") + } + })) + defer mockServer.Close() + + client := Create(&Config{BaseURL: mockServer.URL}) + client.SetDefaultAccept(ContentTypeXML) + + _, err := client.Get("/").Send(context.Background()) + if err != nil { + t.Fatalf("Failed to send request: %v", err) + } +} + +func TestSetDefaultUserAgent(t *testing.T) { + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Check the User-Agent header + if r.Header.Get(HeaderUserAgent) != "MyCustomAgent/1.0" { + t.Error("Default User-Agent header not set correctly") + } + })) + defer mockServer.Close() + + client := Create(&Config{BaseURL: mockServer.URL}) + client.SetDefaultUserAgent("MyCustomAgent/1.0") + + _, err := client.Get("/").Send(context.Background()) + if err != nil { + t.Fatalf("Failed to send request: %v", err) + } +} + +func TestSetDefaultTimeout(t *testing.T) { + // Create a server that delays its response + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(2 * time.Second) // Delay longer than client's timeout + })) + defer mockServer.Close() + + client := Create(&Config{BaseURL: mockServer.URL}) + client.SetDefaultTimeout(1 * time.Second) // Set timeout to 1 second + + _, err := client.Get("/").Send(context.Background()) + if err == nil { + t.Fatal("Expected a timeout error, got nil") + } + + var netErr net.Error + if errors.As(err, &netErr) && netErr.Timeout() { + // Check if the error is a timeout error + } else { + t.Fatalf("Expected a timeout error, got %v", err) + } +} + +func TestSetDefaultCookieJar(t *testing.T) { + jar, _ := cookiejar.New(nil) + + // Initialize the client and set the default cookie jar nom nom nom + client := Create(&Config{}) + client.SetCookieJar(jar) + + // Start a test HTTP server that sets a cookie + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/set-cookie" { + http.SetCookie(w, &http.Cookie{Name: "test", Value: "cookie"}) + return + } + + // Check for the cookie on a different endpoint + cookie, err := r.Cookie("test") + if err != nil { + t.Fatal("Cookie 'test' not found in request, cookie jar not working") + } + + if cookie.Value != "cookie" { + t.Fatalf("Expected cookie 'test' to have value 'cookie', got '%s'", cookie.Value) + } + })) + + defer server.Close() + + // First request to set the cookie + _, err := client.Get(server.URL + "/set-cookie").Send(context.Background()) + if err != nil { + t.Fatalf("Failed to send request: %v", err) + } + + // Second request to check if the cookie is sent back + _, err = client.Get(server.URL + "/check-cookie").Send(context.Background()) + if err != nil { + t.Fatalf("Failed to send second request: %v", err) + } +} + +func TestSetDefaultCookies(t *testing.T) { + // Create a mock server to check cookies + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Check for the presence of specific cookies + sessionCookie, err := r.Cookie("session_id") + if err != nil || sessionCookie.Value != "abcd1234" { + t.Error("Default cookie 'session_id' not found or value incorrect") + } + + authCookie, err := r.Cookie("auth_token") + if err != nil || authCookie.Value != "token1234" { + t.Error("Default cookie 'auth_token' not found or value incorrect") + } + })) + + defer mockServer.Close() + + // Initialize the client and set default cookies + client := Create(&Config{BaseURL: mockServer.URL}) + client.SetDefaultCookies(map[string]string{ + "session_id": "abcd1234", + "auth_token": "token1234", + }) + + // Make a request to trigger the cookie check + _, err := client.Get("/").Send(context.Background()) + if err != nil { + t.Fatalf("Failed to send request: %v", err) + } +} + +func TestDelDefaultCookie(t *testing.T) { + // Mock server to check for absence of a specific cookie + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, err := r.Cookie("session_id") + if err == nil { + t.Error("Deleted default cookie 'session_id' was found in the request") + } + })) + + defer mockServer.Close() + + // Initialize the client, set, and then delete a default cookie + client := Create(&Config{BaseURL: mockServer.URL}) + client.SetDefaultCookie("session_id", "abcd1234") + client.DelDefaultCookie("session_id") + + // Make a request to check for the absence of the deleted cookie + _, err := client.Get("/").Send(context.Background()) + if err != nil { + t.Fatalf("Failed to send request: %v", err) + } +} + +func createTestRetryServer(t *testing.T) *httptest.Server { + var requestCount int + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestCount++ + switch requestCount { + case 1: + w.WriteHeader(http.StatusInternalServerError) // Simulate server error on first attempt + case 2: + w.WriteHeader(http.StatusOK) // Successful on second attempt + default: + t.Fatalf("Unexpected number of httpsling: %d", requestCount) + } + })) + + return server +} + +func TestSetMaxRetriesAndRetryStrategy(t *testing.T) { + server := createTestRetryServer(t) + + defer server.Close() + + client := Create(&Config{BaseURL: server.URL}) + + retryCalled := false + + client.SetMaxRetries(1).SetRetryStrategy(func(attempt int) time.Duration { + retryCalled = true + return 10 * time.Millisecond // Short delay for testing + }) + + // Make a request to the test server + _, err := client.Get("/").Send(context.Background()) + if err != nil { + t.Fatalf("Failed to send request: %v", err) + } + + if !retryCalled { + t.Error("Expected retry strategy to be called, but it wasn't") + } +} + +func TestSetRetryIf(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) // Always return server error + })) + + defer server.Close() + + client := Create(&Config{BaseURL: server.URL}) + + client.SetMaxRetries(2).SetRetryIf(func(req *http.Request, resp *http.Response, err error) bool { + return resp.StatusCode == http.StatusInternalServerError + }) + + retryCount := 0 + + client.SetRetryStrategy(func(int) time.Duration { + retryCount++ + return 10 * time.Millisecond // Short delay for testing + }) + + // Make a request to the test server + _, err := client.Get("/").Send(context.Background()) + if err != nil { + t.Fatalf("Failed to send request: %v", err) + } + + if retryCount != 2 { + t.Errorf("Expected 2 retries, got %d", retryCount) + } +} diff --git a/doc.go b/doc.go new file mode 100644 index 0000000..52a2b9b --- /dev/null +++ b/doc.go @@ -0,0 +1,2 @@ +// Package httpsling is a wrapper for creating and sending http httpsling (e.g. for webhooks, external 3d party integrations) +package httpsling diff --git a/encoders.go b/encoders.go new file mode 100644 index 0000000..0ecee1b --- /dev/null +++ b/encoders.go @@ -0,0 +1,264 @@ +package httpsling + +import ( + "bytes" + "encoding/json" + "encoding/xml" + "io" + + "github.com/goccy/go-yaml" + "github.com/valyala/bytebufferpool" +) + +// Encoder is the interface that wraps the Encode method +type Encoder interface { + // Encode encodes the provided value into a reader + Encode(v any) (io.Reader, error) + // ContentType returns the content type of the encoded data + ContentType() string +} + +// Decoder is the interface that wraps the Decode method +type Decoder interface { + // Decode decodes the data from the reader into the provided value + Decode(r io.Reader, v any) error +} + +// StreamCallback is a callback function that is called when data is received +type StreamCallback func([]byte) error + +// StreamErrCallback is a callback function that is called when an error occurs +type StreamErrCallback func(error) + +// StreamDoneCallback is a callback function that is called when the stream is done +type StreamDoneCallback func() + +// JSONEncoder handles encoding of JSON data +type JSONEncoder struct { + // MarshalFunc is the custom marshal function to use + MarshalFunc func(v any) ([]byte, error) +} + +// JSONDecoder handles decoding of JSON data +type JSONDecoder struct { + UnmarshalFunc func(data []byte, v any) error +} + +// DefaultJSONEncoder instance using the standard json.Marshal function +var DefaultJSONEncoder = &JSONEncoder{ + MarshalFunc: json.Marshal, +} + +// DefaultJSONDecoder instance using the standard json.Unmarshal function +var DefaultJSONDecoder = &JSONDecoder{ + UnmarshalFunc: json.Unmarshal, +} + +var bufferPool bytebufferpool.Pool + +// poolReader wraps bytes.Reader to return the buffer to the pool when closed +type poolReader struct { + // *bytes.Reader is an io.Reader + *bytes.Reader + // poolBuf is a bytebufferpool.ByteBuffer + poolBuf *bytebufferpool.ByteBuffer +} + +// ContentType returns the content type for JSON data +func (e *JSONEncoder) ContentType() string { + return ContentTypeJSONUTF8 +} + +// Encode marshals the provided value into JSON format +func (e *JSONEncoder) Encode(v any) (io.Reader, error) { + var err error + + var data []byte + + if e.MarshalFunc == nil { + data, err = json.Marshal(v) // Fallback to standard JSON marshal if no custom function is provided + } else { + data, err = e.MarshalFunc(v) + } + + if err != nil { + return nil, err + } + + buf := GetBuffer() + + _, err = buf.Write(data) + + if err != nil { + PutBuffer(buf) // Ensure the buffer is returned to the pool in case of an error + return nil, err + } + + // we need to ensure the buffer will be returned to the pool after being read + reader := &poolReader{Reader: bytes.NewReader(buf.B), poolBuf: buf} + + return reader, nil +} + +// Decode reads the data from the reader and unmarshals it into the provided value +func (d *JSONDecoder) Decode(r io.Reader, v any) error { + data, err := io.ReadAll(r) + if err != nil { + return err + } + + if d.UnmarshalFunc != nil { + return d.UnmarshalFunc(data, v) + } + + return json.Unmarshal(data, v) +} + +// GetBuffer retrieves a buffer from the pool +func GetBuffer() *bytebufferpool.ByteBuffer { + return bufferPool.Get() +} + +// PutBuffer returns a buffer to the pool +func PutBuffer(b *bytebufferpool.ByteBuffer) { + bufferPool.Put(b) +} + +func (r *poolReader) Close() error { + PutBuffer(r.poolBuf) + + return nil +} + +// XMLEncoder handles encoding of XML data +type XMLEncoder struct { + MarshalFunc func(v any) ([]byte, error) +} + +// DefaultXMLEncoder instance using the standard xml.Marshal function +var DefaultXMLEncoder = &XMLEncoder{ + MarshalFunc: xml.Marshal, +} + +// XMLDecoder handles decoding of XML data +type XMLDecoder struct { + UnmarshalFunc func(data []byte, v any) error +} + +// DefaultXMLDecoder instance using the standard xml.Unmarshal function +var DefaultXMLDecoder = &XMLDecoder{ + UnmarshalFunc: xml.Unmarshal, +} + +// Encode marshals the provided value into XML format +func (e *XMLEncoder) Encode(v any) (io.Reader, error) { + var err error + + var data []byte + + if e.MarshalFunc != nil { + data, err = e.MarshalFunc(v) + } else { + data, err = xml.Marshal(v) + } + + if err != nil { + return nil, err + } + + buf := GetBuffer() + _, err = buf.Write(data) + + if err != nil { + PutBuffer(buf) + return nil, err + } + + return &poolReader{Reader: bytes.NewReader(buf.B), poolBuf: buf}, nil +} + +// ContentType returns the content type for XML data +func (e *XMLEncoder) ContentType() string { + return ContentTypeXMLUTF8 +} + +// Decode unmarshals the XML data from the reader into the provided value +func (d *XMLDecoder) Decode(r io.Reader, v any) error { + data, err := io.ReadAll(r) + if err != nil { + return err + } + + if d.UnmarshalFunc != nil { + return d.UnmarshalFunc(data, v) + } + + return xml.Unmarshal(data, v) +} + +// YAMLEncoder handles encoding of YAML data +type YAMLEncoder struct { + MarshalFunc func(v any) ([]byte, error) +} + +// DefaultYAMLEncoder instance using the goccy/go-yaml Marshal function +var DefaultYAMLEncoder = &YAMLEncoder{ + MarshalFunc: yaml.Marshal, +} + +// YAMLDecoder handles decoding of YAML data +type YAMLDecoder struct { + UnmarshalFunc func(data []byte, v any) error +} + +// DefaultYAMLDecoder instance using the goccy/go-yaml Unmarshal function +var DefaultYAMLDecoder = &YAMLDecoder{ + UnmarshalFunc: yaml.Unmarshal, +} + +// Encode marshals the provided value into YAML format +func (e *YAMLEncoder) Encode(v any) (io.Reader, error) { + var err error + + var data []byte + + if e.MarshalFunc != nil { + data, err = e.MarshalFunc(v) + } else { + data, err = yaml.Marshal(v) + } + + if err != nil { + return nil, err + } + + buf := GetBuffer() + _, err = buf.Write(data) + + if err != nil { + PutBuffer(buf) + return nil, err + } + + return &poolReader{Reader: bytes.NewReader(buf.B), poolBuf: buf}, nil +} + +// ContentType returns the content type for YAML data +func (e *YAMLEncoder) ContentType() string { + return ContentTypeYAMLUTF8 +} + +// Decode reads the data from the reader and unmarshals it into the provided value +func (d *YAMLDecoder) Decode(r io.Reader, v any) error { + data, err := io.ReadAll(r) + if err != nil { + return err + } + + if d.UnmarshalFunc != nil { + return d.UnmarshalFunc(data, v) + } + + // Fallback to standard YAML unmarshal using goccy/go-yaml + return yaml.Unmarshal(data, v) +} diff --git a/encoders_test.go b/encoders_test.go new file mode 100644 index 0000000..7774a13 --- /dev/null +++ b/encoders_test.go @@ -0,0 +1,127 @@ +package httpsling_test + +import ( + "bytes" + "io" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/theopenlane/httpsling" +) + +func TestJSONEncoderEncode(t *testing.T) { + encoder := &httpsling.JSONEncoder{} + + // Test encoding a struct + data := struct { + Name string `json:"name"` + Age int `json:"age"` + }{ + Name: "John Doe", + Age: 30, + } + + reader, err := encoder.Encode(data) + require.NoError(t, err) + + encodedData, err := io.ReadAll(reader) + require.NoError(t, err) + + expectedData := `{"name":"John Doe","age":30}` + require.Equal(t, expectedData, string(encodedData)) +} + +func TestJSONDecoderDecode(t *testing.T) { + decoder := &httpsling.JSONDecoder{} + + // Test decoding JSON data into a struct + jsonData := `{"name":"John Snow","age":30}` + + var data struct { + Name string `json:"name"` + Age int `json:"age"` + } + + err := decoder.Decode(bytes.NewReader([]byte(jsonData)), &data) + require.NoError(t, err) + + expectedData := struct { + Name string `json:"name"` + Age int `json:"age"` + }{ + Name: "John Snow", + Age: 30, + } + require.Equal(t, expectedData, data) +} + +func TestXMLDecoderDecode(t *testing.T) { + decoder := &httpsling.XMLDecoder{} + + // Test decoding XML data into a struct + xmlData := `John Meow30` + + var data struct { + Name string `xml:"name"` + Age int `xml:"age"` + } + + err := decoder.Decode(bytes.NewReader([]byte(xmlData)), &data) + require.NoError(t, err) + + expectedData := struct { + Name string `xml:"name"` + Age int `xml:"age"` + }{ + Name: "John Meow", + Age: 30, + } + require.Equal(t, expectedData, data) +} + +func TestYAMLEncoderEncode(t *testing.T) { + encoder := &httpsling.YAMLEncoder{} + + // Test encoding a struct + data := struct { + Name string `yaml:"name"` + Age int `yaml:"age"` + }{ + Name: "John Flow", + Age: 30, + } + + reader, err := encoder.Encode(data) + require.NoError(t, err) + + encodedData, err := io.ReadAll(reader) + require.NoError(t, err) + + expectedData := "name: John Flow\nage: 30\n" + require.Equal(t, expectedData, string(encodedData)) +} + +func TestYAMLDecoderDecode(t *testing.T) { + decoder := &httpsling.YAMLDecoder{} + + // Test decoding YAML data into a struct + yamlData := "name: John Show\nage: 30\n" + + var data struct { + Name string `yaml:"name"` + Age int `yaml:"age"` + } + + err := decoder.Decode(bytes.NewReader([]byte(yamlData)), &data) + require.NoError(t, err) + + expectedData := struct { + Name string `yaml:"name"` + Age int `yaml:"age"` + }{ + Name: "John Show", + Age: 30, + } + require.Equal(t, expectedData, data) +} diff --git a/errors.go b/errors.go new file mode 100644 index 0000000..d4902e9 --- /dev/null +++ b/errors.go @@ -0,0 +1,32 @@ +package httpsling + +import ( + "errors" +) + +var ( + // ErrUnsupportedContentType is returned when the content type is unsupported + ErrUnsupportedContentType = errors.New("unsupported content type") + // ErrUnsupportedDataType is returned when the data type is unsupported + ErrUnsupportedDataType = errors.New("unsupported data type") + // ErrEncodingFailed is returned when the encoding fails + ErrEncodingFailed = errors.New("encoding failed") + // ErrRequestCreationFailed is returned when the request cannot be created + ErrRequestCreationFailed = errors.New("failed to create request") + // ErrResponseReadFailed is returned when the response cannot be read + ErrResponseReadFailed = errors.New("failed to read response") + // ErrUnsupportedScheme is returned when the proxy scheme is unsupported + ErrUnsupportedScheme = errors.New("unsupported proxy scheme") + // ErrUnsupportedFormFieldsType is returned when the form fields type is unsupported + ErrUnsupportedFormFieldsType = errors.New("unsupported form fields type") + // ErrNotSupportSaveMethod is returned when the provided type for saving is not supported + ErrNotSupportSaveMethod = errors.New("the provided type for saving is not supported") + // ErrInvalidTransportType is returned when the transport type is invalid + ErrInvalidTransportType = errors.New("invalid transport type") + // ErrResponseNil is returned when the response is nil + ErrResponseNil = errors.New("response is nil") + // ErrFailedToCloseResponseBody is returned when the response body cannot be closed + ErrFailedToCloseResponseBody = errors.New("failed to close response body") + // ErrMapper + ErrMapper = "%w: %v" +) diff --git a/examples/main.go b/examples/main.go new file mode 100644 index 0000000..052b6ae --- /dev/null +++ b/examples/main.go @@ -0,0 +1,40 @@ +package main + +import ( + "context" + "log" + "time" + + "github.com/theopenlane/httpsling" +) + +// Post represents a simple structure to map JSON Placeholder posts +type Post struct { + UserID int `json:"userId"` + ID int `json:"id"` + Title string `json:"title"` + Body string `json:"body"` +} + +func main() { + // Initialize the client with a base URL and a default timeout + client := httpsling.Create(&httpsling.Config{ + BaseURL: "https://jsonplaceholder.typicode.com", + Timeout: 30 * time.Second, + }) + + // Perform a GET request to the /posts endpoint + resp, err := client.Get("/posts/{post_id}").PathParam("post_id", "1").Send(context.Background()) + if err != nil { + log.Fatalf("Failed to make request: %v", err) + } + + // Decode the JSON response into our Post struct + var post Post + if err := resp.ScanJSON(&post); err != nil { + log.Fatalf("Failed to parse response: %v", err) + } + + // Output the result + log.Printf("Post Received: %+v\n", post) +} diff --git a/form.go b/form.go new file mode 100644 index 0000000..4ecb41d --- /dev/null +++ b/form.go @@ -0,0 +1,147 @@ +package httpsling + +import ( + "fmt" + "io" + "net/url" + "strings" + + "github.com/google/go-querystring/query" +) + +// File represents a form file +type File struct { + // Name is the form field name + Name string + // FileName is the file name + FileName string + // Content is the file content + Content io.ReadCloser + // FileMime is the file mime type + FileMime string +} + +// FormEncoder handles encoding of form data +type FormEncoder struct{} + +// DefaultFormEncoder instance +var DefaultFormEncoder = &FormEncoder{} + +// SetContent sets the content of the file +func (f *File) SetContent(content io.ReadCloser) { + f.Content = content +} + +// SetFileName sets the file name +func (f *File) SetFileName(fileName string) { + f.FileName = fileName +} + +// SetName sets the form field name +func (f *File) SetName(name string) { + f.Name = name +} + +// parseFormFields parses the given form fields into url.Values +func parseFormFields(fields any) (url.Values, error) { + switch data := fields.(type) { + case url.Values: + return data, nil + case map[string][]string: + return url.Values(data), nil + case map[string]string: + values := make(url.Values) + + for key, value := range data { + values.Set(key, value) + } + + return values, nil + default: + if values, err := query.Values(fields); err == nil { + return values, nil + } else { + return nil, fmt.Errorf("%w: %v", ErrUnsupportedFormFieldsType, err) + } + } +} + +// parseForm parses the given form data into url.Values and []*File +func parseForm(v any) (url.Values, []*File, error) { + switch data := v.(type) { + case url.Values: + // Directly return url.Values data + return data, nil, nil + case map[string][]string: + // Convert and return map[string][]string data as url.Values + return url.Values(data), nil, nil + case map[string]string: + // Convert and return map[string]string data as url.Values + values := make(url.Values) + for key, value := range data { + values.Set(key, value) + } + + return values, nil, nil + case map[string]any: + // Convert and return map[string]any data as url.Values and []*File + values := make(url.Values) + + files := make([]*File, 0) + + for key, value := range data { + switch v := value.(type) { + case string: + values.Set(key, v) + case []string: + for _, v := range v { + values.Add(key, v) + } + case *File: + v.SetName(key) + files = append(files, v) + default: + return nil, nil, fmt.Errorf("%w: %T", ErrUnsupportedDataType, value) + } + } + + return values, files, nil + default: + // Attempt to use query.Values for encoding struct types + if values, err := query.Values(v); err == nil { + return values, nil, nil + } else { + return nil, nil, fmt.Errorf("%w: %v", ErrUnsupportedFormFieldsType, err) + } + } +} + +// Encode encodes the given value into URL-encoded form data +func (e *FormEncoder) Encode(v any) (io.Reader, error) { + switch data := v.(type) { + case url.Values: + // Directly encode url.Values data. + return strings.NewReader(data.Encode()), nil + case map[string][]string: + // Convert and encode map[string][]string data as url.Values + values := url.Values(data) + return strings.NewReader(values.Encode()), nil + case map[string]string: + // Convert and encode map[string]string data as url.Values + values := make(url.Values) + + for key, value := range data { + values.Set(key, value) + } + + return strings.NewReader(values.Encode()), nil + default: + // Attempt to use query.Values for encoding struct types + if values, err := query.Values(v); err == nil { + return strings.NewReader(values.Encode()), nil + } else { + // Return an error if encoding fails or type is unsupported + return nil, fmt.Errorf("%w: %v", ErrEncodingFailed, err) + } + } +} diff --git a/form_test.go b/form_test.go new file mode 100644 index 0000000..3710703 --- /dev/null +++ b/form_test.go @@ -0,0 +1,152 @@ +package httpsling + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +// startFileUploadServer starts a mock server to test file uploads +func startFileUploadServer() *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + err := r.ParseMultipartForm(10 << 20) // Limit: 10MB + if err != nil { + http.Error(w, "Failed to parse multipart form", http.StatusBadRequest) + + return + } + + // Collect file upload details + uploads := make(map[string][]string) + + for key, files := range r.MultipartForm.File { + for _, fileHeader := range files { + file, err := fileHeader.Open() + if err != nil { + http.Error(w, "Failed to open file", http.StatusInternalServerError) + return + } + + defer file.Close() //nolint: errcheck + + // Read file content (for demonstration; in real tests, might hash or skip) + content, err := io.ReadAll(file) + if err != nil { + http.Error(w, "Failed to read file content", http.StatusInternalServerError) + + return + } + + // Store file details (e.g., filename and a snippet of content for verification) + contentSnippet := string(content) + if len(contentSnippet) > 10 { + contentSnippet = contentSnippet[:10] + "..." + } + + uploads[key] = append(uploads[key], fmt.Sprintf("%s: %s", fileHeader.Filename, contentSnippet)) + } + } + + // Respond with details of the uploaded files in JSON format + w.Header().Set(HeaderContentType, ContentTypeJSON) + + if encoder := json.NewEncoder(w); encoder != nil { + if err = encoder.Encode(uploads); err != nil { + http.Error(w, "Failed to encode response", http.StatusInternalServerError) + } + } else { + http.Error(w, "Failed to create JSON encoder", http.StatusInternalServerError) + } + })) +} + +func TestFiles(t *testing.T) { + server := startFileUploadServer() + defer server.Close() + + client := Create(&Config{BaseURL: server.URL}) + + fileContent1 := strings.NewReader("File content 1") + fileContent2 := strings.NewReader("File content 2") + + resp, err := client.Post("/"). + Files( + &File{Name: "file1", FileName: "test1.txt", Content: io.NopCloser(fileContent1)}, + &File{Name: "file2", FileName: "test2.txt", Content: io.NopCloser(fileContent2)}, + ). + Send(context.Background()) + + assert.NoError(t, err, "No error expected on sending request") + + var uploads map[string][]string + err = resp.ScanJSON(&uploads) + assert.NoError(t, err, "Expect no error on parsing response") + + // Validate the file uploads + assert.Contains(t, uploads, "file1", "file1 should be present in the uploads") + assert.Contains(t, uploads, "file2", "file2 should be present in the uploads") +} +func TestFile(t *testing.T) { + server := startFileUploadServer() // Start the mock file upload server + defer server.Close() + + client := Create(&Config{BaseURL: server.URL}) + + // Simulate a file's content + fileContent := strings.NewReader("This is the file content") + + // Send a request with a single file + resp, err := client.Post("/"). + File("file", "single.txt", io.NopCloser(fileContent)). + Send(context.Background()) + assert.NoError(t, err, "No error expected on sending request") + + // Parse the server's JSON response + var uploads map[string][]string + err = resp.ScanJSON(&uploads) + assert.NoError(t, err, "Expect no error on parsing response") + + // Check if the server received the file correctly + assert.Contains(t, uploads, "file", "The file should be present in the uploads") + assert.Contains(t, uploads["file"][0], "single.txt", "The file name should be correctly received") +} + +func TestDelFile(t *testing.T) { + server := startFileUploadServer() // Start the mock file upload server + + defer server.Close() + + client := Create(&Config{BaseURL: server.URL}) + + // Simulate file contents + fileContent1 := strings.NewReader("File content 1") + fileContent2 := strings.NewReader("File content 2") + + // Prepare the request with two files, then delete one before sending + resp, err := client.Post("/"). + Files( + &File{Name: "file1", FileName: "file1.txt", Content: io.NopCloser(fileContent1)}, + &File{Name: "file2", FileName: "file2.txt", Content: io.NopCloser(fileContent2)}, + ). + DelFile("file1"). // Remove the first file + Send(context.Background()) + assert.NoError(t, err, "No error expected on sending request") + + // Parse the server's JSON response + var uploads map[string][]string + + err = resp.ScanJSON(&uploads) + + assert.NoError(t, err, "Expect no error on parsing response") + + // Validate that only the second file was uploaded + assert.NotContains(t, uploads, "file1", "file1 should have been removed from the uploads") + assert.Contains(t, uploads, "file2", "file2 should be present in the uploads") +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..4e21b53 --- /dev/null +++ b/go.mod @@ -0,0 +1,34 @@ +module github.com/theopenlane/httpsling + +go 1.23.0 + +require ( + github.com/bytedance/sonic v1.12.2 + github.com/goccy/go-yaml v1.12.0 + github.com/google/go-querystring v1.1.0 + github.com/stretchr/testify v1.9.0 + github.com/theopenlane/utils v0.1.1 + github.com/valyala/bytebufferpool v1.0.0 +) + +require ( + github.com/bytedance/sonic/loader v0.2.0 // indirect + github.com/cloudwego/base64x v0.1.4 // indirect + github.com/cloudwego/iasm v0.2.0 // indirect + github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect + github.com/fatih/color v1.17.0 // indirect + github.com/gabriel-vasile/mimetype v1.4.3 // indirect + github.com/go-playground/validator/v10 v10.15.5 // indirect + github.com/google/go-cmp v0.6.0 // indirect + github.com/klauspost/cpuid/v2 v2.2.8 // indirect + github.com/mattn/go-colorable v0.1.13 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect + github.com/theopenlane/echox v0.1.0 // indirect + github.com/twitchyliquid64/golang-asm v0.15.1 // indirect + golang.org/x/arch v0.8.0 // indirect + golang.org/x/crypto v0.26.0 // indirect + golang.org/x/sys v0.23.0 // indirect + golang.org/x/xerrors v0.0.0-20231012003039-104605ab7028 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..6bd3eaf --- /dev/null +++ b/go.sum @@ -0,0 +1,83 @@ +github.com/bytedance/sonic v1.12.2 h1:oaMFuRTpMHYLpCntGca65YWt5ny+wAceDERTkT2L9lg= +github.com/bytedance/sonic v1.12.2/go.mod h1:B8Gt/XvtZ3Fqj+iSKMypzymZxw/FVwgIGKzMzT9r/rk= +github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU= +github.com/bytedance/sonic/loader v0.2.0 h1:zNprn+lsIP06C/IqCHs3gPQIvnvpKbbxyXQP1iU4kWM= +github.com/bytedance/sonic/loader v0.2.0/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU= +github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y= +github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w= +github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg= +github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/fatih/color v1.17.0 h1:GlRw1BRJxkpqUCBKzKOw098ed57fEsKeNjpTe3cSjK4= +github.com/fatih/color v1.17.0/go.mod h1:YZ7TlrGPkiz6ku9fK3TLD/pl3CpsiFyu8N92HLgmosI= +github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0= +github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk= +github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= +github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= +github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= +github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= +github.com/go-playground/validator/v10 v10.15.5 h1:LEBecTWb/1j5TNY1YYG2RcOUN3R7NLylN+x8TTueE24= +github.com/go-playground/validator/v10 v10.15.5/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU= +github.com/goccy/go-yaml v1.12.0 h1:/1WHjnMsI1dlIBQutrvSMGZRQufVO3asrHfTwfACoPM= +github.com/goccy/go-yaml v1.12.0/go.mod h1:wKnAMd44+9JAAnGQpWVEgBzGt3YuTaQ4uXoHvE4m7WU= +github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8= +github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU= +github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= +github.com/klauspost/cpuid/v2 v2.2.8 h1:+StwCXwm9PdpiEkPyzBXIy+M9KUb4ODm0Zarf1kS5BM= +github.com/klauspost/cpuid/v2 v2.2.8/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= +github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M= +github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q= +github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4= +github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/theopenlane/echox v0.1.0 h1:y4Z2shaODCLwXHsHBrY/EkH/2sIuo49xdIfxx7h+Zvg= +github.com/theopenlane/echox v0.1.0/go.mod h1:RaynhPvY9qbLOVlcO7Js1NqZ66+CP9hVBa0c7ehNYA4= +github.com/theopenlane/utils v0.1.1 h1:GoPrIE8tmmC1VGlp+QmVTvrgBlHwe8e8FqLw2IPdgmY= +github.com/theopenlane/utils v0.1.1/go.mod h1:37sJeeuIsmMbMFE2nKglmEQUJenTccxh5WxkJtyuZUw= +github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= +github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= +github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= +github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= +golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc= +golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= +golang.org/x/crypto v0.26.0 h1:RrRspgV4mU+YwB4FYnuBoKsUapNIL5cohGAmSH3azsw= +golang.org/x/crypto v0.26.0/go.mod h1:GY7jblb9wI+FOo5y8/S2oY4zWP07AkOJ4+jxCqdqn54= +golang.org/x/net v0.28.0 h1:a9JDOJc5GMUJ0+UDqmLT86WiEy7iWyIhz8gz8E4e5hE= +golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.23.0 h1:YfKFowiIMvtgl1UERQoTPPToxltDeZfbj4H7dVUCwmM= +golang.org/x/sys v0.23.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.17.0 h1:XtiM5bkSOt+ewxlOE/aE/AKEHibwj/6gvWMl9Rsh0Qc= +golang.org/x/text v0.17.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20231012003039-104605ab7028 h1:+cNy6SZtPcJQH3LJVLOSmiC7MMxXNOb3PU/VUEz+EhU= +golang.org/x/xerrors v0.0.0-20231012003039-104605ab7028/go.mod h1:NDW/Ps6MPRej6fsCIbMTohpP40sJ/P/vI1MoTEGwX90= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYmp6pbG50= diff --git a/headers.go b/headers.go new file mode 100644 index 0000000..ab138df --- /dev/null +++ b/headers.go @@ -0,0 +1,176 @@ +package httpsling + +const ( + // Authentication + HeaderAuthorization = "Authorization" + HeaderProxyAuthenticate = "Proxy-Authenticate" + HeaderProxyAuthorization = "Proxy-Authorization" + HeaderWWWAuthenticate = "WWW-Authenticate" + + // Caching + HeaderAge = "Age" + HeaderCacheControl = "Cache-Control" + HeaderClearSiteData = "Clear-Site-Data" + HeaderExpires = "Expires" + HeaderPragma = "Pragma" + HeaderWarning = "Warning" + + // Client hints + HeaderAcceptCH = "Accept-CH" + HeaderAcceptCHLifetime = "Accept-CH-Lifetime" + HeaderContentDPR = "Content-DPR" + HeaderDPR = "DPR" + HeaderEarlyData = "Early-Data" + HeaderSaveData = "Save-Data" + HeaderViewportWidth = "Viewport-Width" + HeaderWidth = "Width" + + // Conditionals + HeaderETag = "ETag" + HeaderIfMatch = "If-Match" + HeaderIfModifiedSince = "If-Modified-Since" + HeaderIfNoneMatch = "If-None-Match" + HeaderIfUnmodifiedSince = "If-Unmodified-Since" + HeaderLastModified = "Last-Modified" + HeaderVary = "Vary" + + // Connection management + HeaderConnection = "Connection" + HeaderKeepAlive = "Keep-Alive" + HeaderProxyConnection = "Proxy-Connection" + + // Content negotiation + HeaderAccept = "Accept" + HeaderAcceptCharset = "Accept-Charset" + HeaderAcceptEncoding = "Accept-Encoding" + HeaderAcceptLanguage = "Accept-Language" + + // Controls + HeaderCookie = "Cookie" + HeaderExpect = "Expect" + HeaderMaxForwards = "Max-Forwards" + HeaderSetCookie = "Set-Cookie" + + // CORS + HeaderAccessControlAllowCredentials = "Access-Control-Allow-Credentials" + HeaderAccessControlAllowHeaders = "Access-Control-Allow-Headers" + HeaderAccessControlAllowMethods = "Access-Control-Allow-Methods" + HeaderAccessControlAllowOrigin = "Access-Control-Allow-Origin" + HeaderAccessControlExposeHeaders = "Access-Control-Expose-Headers" + HeaderAccessControlMaxAge = "Access-Control-Max-Age" + HeaderAccessControlRequestHeaders = "Access-Control-Request-Headers" + HeaderAccessControlRequestMethod = "Access-Control-Request-Method" + HeaderOrigin = "Origin" + HeaderTimingAllowOrigin = "Timing-Allow-Origin" + HeaderXPermittedCrossDomainPolicies = "X-Permitted-Cross-Domain-Policies" + + // Do Not Track + HeaderDNT = "DNT" + HeaderTk = "Tk" + + // Downloads + HeaderContentDisposition = "Content-Disposition" + + // Message body information + HeaderContentEncoding = "Content-Encoding" + HeaderContentLanguage = "Content-Language" + HeaderContentLength = "Content-Length" + HeaderContentLocation = "Content-Location" + HeaderContentType = "Content-Type" + + // Content Types + ContentTypeForm = "application/x-www-form-urlencoded" // https://datatracker.ietf.org/doc/html/rfc1866 + ContentTypeMultipart = "multipart/form-data" // https://datatracker.ietf.org/doc/html/rfc2388 + ContentTypeJSON = "application/json" // https://datatracker.ietf.org/doc/html/rfc4627 + ContentTypeJSONUTF8 = "application/json;charset=utf-8" // https://datatracker.ietf.org/doc/html/rfc4627 + ContentTypeXML = "application/xml" // https://datatracker.ietf.org/doc/html/rfc3023 + ContentTypeXMLUTF8 = "application/xml;charset=utf-8" + ContentTypeYAML = "application/yaml" // https://www.rfc-editor.org/rfc/rfc9512.html + ContentTypeYAMLUTF8 = "application/yaml;charset=utf-8" + ContentTypeText = "text/plain" + ContentTypeApplicationOctetStream = "application/octet-stream" + + // Proxies + HeaderForwarded = "Forwarded" + HeaderVia = "Via" + HeaderXForwardedFor = "X-Forwarded-For" + HeaderXForwardedHost = "X-Forwarded-Host" + HeaderXForwardedProto = "X-Forwarded-Proto" + + // Redirects + HeaderLocation = "Location" + + // Request context + HeaderFrom = "From" + HeaderHost = "Host" + HeaderReferer = "Referer" + HeaderReferrerPolicy = "Referrer-Policy" + HeaderUserAgent = "User-Agent" + + // Response context + HeaderAllow = "Allow" + HeaderServer = "Server" + + // Range requests. + HeaderAcceptRanges = "Accept-Ranges" + HeaderContentRange = "Content-Range" + HeaderIfRange = "If-Range" + HeaderRange = "Range" + + // Security + HeaderContentSecurityPolicy = "Content-Security-Policy" + HeaderContentSecurityPolicyReportOnly = "Content-Security-Policy-Report-Only" + HeaderCrossOriginResourcePolicy = "Cross-Origin-Resource-Policy" + HeaderExpectCT = "Expect-CT" + HeaderFeaturePolicy = "Feature-Policy" + HeaderPublicKeyPins = "Public-Key-Pins" + HeaderPublicKeyPinsReportOnly = "Public-Key-Pins-Report-Only" + HeaderStrictTransportSecurity = "Strict-Transport-Security" + HeaderUpgradeInsecureRequests = "Upgrade-Insecure-Requests" + HeaderXContentTypeOptions = "X-Content-Type-Options" + HeaderXDownloadOptions = "X-Download-Options" + HeaderXFrameOptions = "X-Frame-Options" + HeaderXPoweredBy = "X-Powered-By" + HeaderXXSSProtection = "X-XSS-Protection" + + // Server-sent event + HeaderLastEventID = "Last-Event-ID" + HeaderNEL = "NEL" + HeaderPingFrom = "Ping-From" + HeaderPingTo = "Ping-To" + HeaderReportTo = "Report-To" + + // Transfer coding + HeaderTE = "TE" + HeaderTrailer = "Trailer" + HeaderTransferEncoding = "Transfer-Encoding" + + // WebSockets + HeaderSecWebSocketAccept = "Sec-WebSocket-Accept" + HeaderSecWebSocketExtensions = "Sec-WebSocket-Extensions" /* #nosec G101 */ + HeaderSecWebSocketKey = "Sec-WebSocket-Key" + HeaderSecWebSocketProtocol = "Sec-WebSocket-Protocol" + HeaderSecWebSocketVersion = "Sec-WebSocket-Version" + + // Other + HeaderAcceptPatch = "Accept-Patch" + HeaderAcceptPushPolicy = "Accept-Push-Policy" + HeaderAcceptSignature = "Accept-Signature" + HeaderAltSvc = "Alt-Svc" + HeaderDate = "Date" + HeaderIndex = "Index" + HeaderLargeAllocation = "Large-Allocation" + HeaderLink = "Link" + HeaderPushPolicy = "Push-Policy" + HeaderRetryAfter = "Retry-After" + HeaderServerTiming = "Server-Timing" + HeaderSignature = "Signature" + HeaderSignedHeaders = "Signed-Headers" + HeaderSourceMap = "SourceMap" + HeaderUpgrade = "Upgrade" + HeaderXDNSPrefetchControl = "X-DNS-Prefetch-Control" + HeaderXPingback = "X-Pingback" + HeaderXRequestedWith = "X-Requested-With" + HeaderXRobotsTag = "X-Robots-Tag" + HeaderXUACompatible = "X-UA-Compatible" +) diff --git a/logger.go b/logger.go new file mode 100644 index 0000000..f3c9713 --- /dev/null +++ b/logger.go @@ -0,0 +1,85 @@ +package httpsling + +import ( + "fmt" + "io" + "log/slog" +) + +// Level is a type that represents the log level +type Level int + +// The levels of logs +const ( + LevelDebug Level = iota + LevelInfo + LevelWarn + LevelError +) + +// Logger is a logger interface that output logs with a format +type Logger interface { + Debugf(format string, v ...any) + Infof(format string, v ...any) + Warnf(format string, v ...any) + Errorf(format string, v ...any) + SetLevel(level Level) +} + +// DefaultLogger is a default logger that uses `slog` as the underlying logger +type DefaultLogger struct { + logger *slog.Logger + level *slog.LevelVar +} + +// Debugf logs a message at the Debug level +func (l *DefaultLogger) Debugf(format string, v ...any) { + l.logger.Debug(fmt.Sprintf(format, v...)) +} + +// Infof logs a message at the Info level +func (l *DefaultLogger) Infof(format string, v ...any) { + l.logger.Info(fmt.Sprintf(format, v...)) +} + +// Warnf logs a message at the Warn level +func (l *DefaultLogger) Warnf(format string, v ...any) { + l.logger.Warn(fmt.Sprintf(format, v...)) +} + +// Errorf logs a message at the Error level +func (l *DefaultLogger) Errorf(format string, v ...any) { + l.logger.Error(fmt.Sprintf(format, v...)) +} + +// SetLevel sets the log level of the logger +func (l *DefaultLogger) SetLevel(level Level) { + switch level { + case LevelDebug: + l.level.Set(slog.LevelDebug) + case LevelInfo: + l.level.Set(slog.LevelInfo) + case LevelWarn: + l.level.Set(slog.LevelWarn) + case LevelError: + l.level.Set(slog.LevelError) + } +} + +// NewDefaultLogger creates a new `DefaultLogger` with the given output and log level +func NewDefaultLogger(output io.Writer, level Level) Logger { + levelVar := &slog.LevelVar{} + + textHandler := slog.NewTextHandler(output, &slog.HandlerOptions{ + Level: levelVar, + }) + + logger := &DefaultLogger{ + logger: slog.New(textHandler), + level: levelVar, + } + + logger.SetLevel(level) + + return logger +} diff --git a/logger_test.go b/logger_test.go new file mode 100644 index 0000000..5c8e66b --- /dev/null +++ b/logger_test.go @@ -0,0 +1,102 @@ +package httpsling + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +type mockLoggerRecorder struct { + Records []string +} + +// Write writes the given bytes to the recorder +func (m *mockLoggerRecorder) Write(p []byte) (n int, err error) { + m.Records = append(m.Records, string(p)) + + return len(p), nil +} + +func TestDefaultLoggerLevels(t *testing.T) { + rec := &mockLoggerRecorder{} + logger := NewDefaultLogger(rec, LevelDebug) + + logger.Debugf("debug %s", "message") + logger.Infof("info %s", "message") + logger.Warnf("warn %s", "message") + logger.Errorf("error %s", "message") + + assert.Len(t, rec.Records, 4, "Should log 4 messages") + assert.Contains(t, rec.Records[0], "debug message", "Debug log message should match") + assert.Contains(t, rec.Records[1], "info message", "Info log message should match") + assert.Contains(t, rec.Records[2], "warn message", "Warn log message should match") + assert.Contains(t, rec.Records[3], "error message", "Error log message should match") +} + +type mockLogger struct { + Infos []string + Errors []string +} + +func (m *mockLogger) Debugf(format string, v ...any) { + m.Infos = append(m.Infos, fmt.Sprintf(format, v...)) +} +func (m *mockLogger) Infof(format string, v ...any) { + m.Infos = append(m.Infos, fmt.Sprintf(format, v...)) +} +func (m *mockLogger) Warnf(format string, v ...any) { + m.Infos = append(m.Infos, fmt.Sprintf(format, v...)) +} +func (m *mockLogger) Errorf(format string, v ...any) { + m.Errors = append(m.Errors, fmt.Sprintf(format, v...)) +} + +func (m *mockLogger) SetLevel(level Level) {} +func TestRetryLogMessage(t *testing.T) { + // Initialize attempt counter + var attempts int + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts++ + if attempts == 1 { + // Fail initially to trigger a retry + w.WriteHeader(http.StatusServiceUnavailable) + } else { + // Succeed in the next attempt + w.WriteHeader(http.StatusOK) + } + })) + + defer server.Close() + + mockLogger := &mockLogger{} + client := Create(&Config{ + BaseURL: server.URL, + Logger: mockLogger, + }).SetMaxRetries(1).SetRetryStrategy(func(attempt int) time.Duration { + return 0 // No delay for testing + }) + + // Making a request that should trigger a retry + _, err := client.Get("/test").Send(context.Background()) + assert.Nil(t, err, "Did not expect an error after retry") + + // Check if the retry log message was recorded + expectedLogMessage := "Retrying request (attempt 1) after backoff" + found := false + + for _, logMsg := range mockLogger.Infos { + if strings.Contains(logMsg, expectedLogMessage) { + found = true + break + } + } + + assert.True(t, found, "Expected retry log message was not recorded") +} diff --git a/middleware.go b/middleware.go new file mode 100644 index 0000000..6dd5323 --- /dev/null +++ b/middleware.go @@ -0,0 +1,9 @@ +package httpsling + +import "net/http" + +// MiddlewareHandlerFunc defines a function that takes an http.Request and returns an http.Response +type MiddlewareHandlerFunc func(req *http.Request) (*http.Response, error) + +// Middleware takes MiddlewareHandlerFunc and wraps around a next function call, which can be another middleware or the final transport layer call +type Middleware func(next MiddlewareHandlerFunc) MiddlewareHandlerFunc diff --git a/middleware_test.go b/middleware_test.go new file mode 100644 index 0000000..d3f8339 --- /dev/null +++ b/middleware_test.go @@ -0,0 +1,224 @@ +package httpsling + +import ( + "bytes" + "context" + "net/http" + "net/http/httptest" + "testing" +) + +// TestMiddleware ensures that the Middleware correctly applies middleware to outgoing httpsling +func TestMiddleware(t *testing.T) { + // Set up a mock server to inspect incoming httpsling + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Check for the custom header added by our middleware + if r.Header.Get("X-Custom-Header") != "true" { + t.Errorf("Expected custom header 'X-Custom-Header' to be 'true', got '%s'", r.Header.Get("X-Custom-Header")) + w.WriteHeader(http.StatusBadRequest) // Indicate a bad request if header is missing + + return + } + + w.WriteHeader(http.StatusOK) // All good if the header is present + })) + + defer mockServer.Close() + + // Define the middleware that adds a custom header + customHeaderMiddleware := func(next MiddlewareHandlerFunc) MiddlewareHandlerFunc { + return func(req *http.Request) (*http.Response, error) { + // Add the custom header + req.Header.Set("X-Custom-Header", "true") + // Proceed with the next middleware or the actual request + return next(req) + } + } + + // Initialize the client with our custom middleware + client := Create(&Config{ + BaseURL: mockServer.URL, // Use our mock server as the base URL + Transport: http.DefaultTransport, // Use the default transport + Middlewares: []Middleware{customHeaderMiddleware}, // Apply our custom header middleware + }) + + // Create an HTTP request object + resp, err := client.Get("/").Send(context.Background()) + if err != nil { + t.Fatalf("Failed to send request: %v", err) + } + defer resp.Close() //nolint: errcheck + + // Check if the server responded with a 200 OK, indicating the middleware applied the header successfully + if resp.StatusCode() != http.StatusOK { + t.Errorf("Expected status code 200, got %d", resp.StatusCode()) + } +} + +func TestNestedMiddleware(t *testing.T) { + var buf bytes.Buffer + + mid0 := func(next MiddlewareHandlerFunc) MiddlewareHandlerFunc { + return func(req *http.Request) (*http.Response, error) { + buf.WriteString("0>>") + + resp, err := next(req) + + buf.WriteString(">>0") + + return resp, err + } + } + + mid1 := func(next MiddlewareHandlerFunc) MiddlewareHandlerFunc { + return func(req *http.Request) (*http.Response, error) { + buf.WriteString("1>>") + + resp, err := next(req) + + buf.WriteString(">>1") + + return resp, err + } + } + + mid2 := func(next MiddlewareHandlerFunc) MiddlewareHandlerFunc { + return func(req *http.Request) (*http.Response, error) { + buf.WriteString("2>>") + + resp, err := next(req) + + buf.WriteString(">>2") + + return resp, err + } + } + + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + buf.WriteString("(served)") + w.WriteHeader(http.StatusOK) + })) + + defer mockServer.Close() + + client := Create(&Config{ + BaseURL: mockServer.URL, + Middlewares: []Middleware{mid0, mid1, mid2}, + }) + + // Create an HTTP request object + resp, err := client.Get("/").Send(context.Background()) + if err != nil { + t.Fatalf("Failed to send request: %v", err) + } + + defer resp.Close() //nolint: errcheck + + expected := "0>>1>>2>>(served)>>2>>1>>0" + if buf.String() != expected { + t.Errorf("Expected sequence %s, got %s", expected, buf.String()) + } +} + +// TestDynamicMiddlewareAddition tests the dynamic addition of middleware to the client +func TestDynamicMiddlewareAddition(t *testing.T) { + // Buffer to track middleware execution order + var executionOrder bytes.Buffer + + // Define middleware functions + loggingMiddleware := func(next MiddlewareHandlerFunc) MiddlewareHandlerFunc { + return func(req *http.Request) (*http.Response, error) { + executionOrder.WriteString("Logging>") + return next(req) + } + } + + authenticationMiddleware := func(next MiddlewareHandlerFunc) MiddlewareHandlerFunc { + return func(req *http.Request) (*http.Response, error) { + executionOrder.WriteString("Auth>") + return next(req) + } + } + + // Set up a mock server + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + executionOrder.WriteString("Handler") + w.WriteHeader(http.StatusOK) + })) + + defer mockServer.Close() + + // Create a new client + client := Create(&Config{ + BaseURL: mockServer.URL, + }) + + // Dynamically add middleware + client.AddMiddleware(loggingMiddleware) + client.AddMiddleware(authenticationMiddleware) + + // Make a request to the mock server + _, err := client.Get("/").Send(context.Background()) + if err != nil { + t.Fatalf("Failed to send request: %v", err) + } + + // Check the order of middleware execution + expectedOrder := "Logging>Auth>Handler" + if executionOrder.String() != expectedOrder { + t.Errorf("Middleware executed in incorrect order. Expected %s, got %s", expectedOrder, executionOrder.String()) + } +} + +// TestRequestMiddlewareAddition tests the addition of middleware at the request level, +// and ensures that both client and request level middlewares are executed in the correct order +func TestRequestMiddlewareAddition(t *testing.T) { + // Buffer to track middleware execution order + var executionOrder bytes.Buffer + + // Define client-level middleware + clientLoggingMiddleware := func(next MiddlewareHandlerFunc) MiddlewareHandlerFunc { + return func(req *http.Request) (*http.Response, error) { + executionOrder.WriteString("ClientLogging>") + return next(req) + } + } + + // Define request-level middleware + requestAuthMiddleware := func(next MiddlewareHandlerFunc) MiddlewareHandlerFunc { + return func(req *http.Request) (*http.Response, error) { + executionOrder.WriteString("RequestAuth>") + return next(req) + } + } + + // Set up a mock server + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + executionOrder.WriteString("Handler") + w.WriteHeader(http.StatusOK) + })) + + defer mockServer.Close() + + // Create a new client with client-level middleware + client := Create(&Config{ + BaseURL: mockServer.URL, + Middlewares: []Middleware{clientLoggingMiddleware}, // Apply client-level middleware + }) + + // Create a request and dynamically add request-level middleware + reqBuilder := client.Get("/") + reqBuilder.AddMiddleware(requestAuthMiddleware) // Apply request-level middleware + + // Make a request to the mock server + _, err := reqBuilder.Send(context.Background()) + if err != nil { + t.Fatalf("Failed to send request: %v", err) + } + + // Check the order of middleware execution + expectedOrder := "ClientLogging>RequestAuth>Handler" + if executionOrder.String() != expectedOrder { + t.Errorf("Middleware executed in incorrect order. Expected %s, got %s", expectedOrder, executionOrder.String()) + } +} diff --git a/proxy.go b/proxy.go new file mode 100644 index 0000000..fbce0bd --- /dev/null +++ b/proxy.go @@ -0,0 +1,69 @@ +package httpsling + +import ( + "net/http" + "net/url" + + "github.com/theopenlane/utils/rout" +) + +// verifyProxy validates the given proxy URL, supporting http, https, and socks5 schemes +func verifyProxy(proxyURL string) (*url.URL, error) { + parsedURL, err := url.Parse(proxyURL) + if err != nil { + return nil, err + } + + // Check if the scheme is supported + switch parsedURL.Scheme { + case "http", "https", "socks5": + return parsedURL, nil + default: + return nil, ErrUnsupportedScheme + } +} + +// SetProxy configures the client to use a proxy. Supports http, https, and socks5 proxies +func (c *Client) SetProxy(proxyURL string) error { + c.mu.Lock() + defer c.mu.Unlock() + + // Validate and parse the proxy URL + validatedProxyURL, err := verifyProxy(proxyURL) + if err != nil { + return err + } + + // Ensure the HTTPClient's Transport is properly initialized + if c.HTTPClient.Transport == nil { + c.HTTPClient.Transport = &http.Transport{} + } + + // Assert the Transport to *http.Transport to access the Proxy field + transport, ok := c.HTTPClient.Transport.(*http.Transport) + if !ok { + return rout.HTTPErrorResponse(err) + } + + transport.Proxy = http.ProxyURL(validatedProxyURL) + + return nil +} + +// RemoveProxy clears any configured proxy, allowing direct connections +func (c *Client) RemoveProxy() { + c.mu.Lock() + defer c.mu.Unlock() + + if c.HTTPClient.Transport == nil { + return + } + + transport, ok := c.HTTPClient.Transport.(*http.Transport) + + if !ok { + return // If it's not *http.Transport, it doesn't have a proxy to remove + } + + transport.Proxy = nil +} diff --git a/proxy_test.go b/proxy_test.go new file mode 100644 index 0000000..e35f344 --- /dev/null +++ b/proxy_test.go @@ -0,0 +1,83 @@ +package httpsling + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" +) + +// createTestServerForProxy creates a simple HTTP server for testing purposes +func createTestServerForProxy() *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) +} + +// TestSetProxyValidProxy tests setting a valid proxy and making a request through it +func TestSetProxyValidProxy(t *testing.T) { + server := createTestServerForProxy() + + defer server.Close() + + proxyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Indicate the request passed through the proxy + w.Header().Set("X-Test-Proxy", "true") + w.WriteHeader(http.StatusOK) + })) + + defer proxyServer.Close() + + client := URL(server.URL) + + err := client.SetProxy(proxyServer.URL) + assert.Nil(t, err, "Setting a valid proxy should not result in an error") + + resp, err := client.Get("/").Send(context.Background()) + assert.Nil(t, err, "Request through a valid proxy should succeed") + assert.NotNil(t, resp, "Response should not be nil") + assert.Equal(t, "true", resp.Header().Get("X-Test-Proxy"), "Request should have passed through the proxy") +} + +// TestSetProxyInvalidProxy tests handling of invalid proxy URLs +func TestSetProxyInvalidProxy(t *testing.T) { + server := createTestServerForProxy() + + defer server.Close() + + client := URL(server.URL) + + invalidProxyURL := "://invalid_url" + err := client.SetProxy(invalidProxyURL) + assert.NotNil(t, err, "Setting an invalid proxy URL should result in an error") +} + +// TestSetProxyRemoveProxy tests removing proxy settings +func TestSetProxyRemoveProxy(t *testing.T) { + server := createTestServerForProxy() + + defer server.Close() + + proxyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Proxy server response + w.WriteHeader(http.StatusOK) + })) + + defer proxyServer.Close() + + client := URL(server.URL) + + // Set then remove the proxy + err := client.SetProxy(proxyServer.URL) + assert.Nil(t, err, "Setting a proxy should not result in an error") + + client.RemoveProxy() + + // Make a request and check it doesn't go through the proxy + resp, err := client.Get("/").Send(context.Background()) + assert.Nil(t, err, "Request after removing proxy should succeed") + assert.NotNil(t, resp, "Response should not be ni.") + assert.NotEqual(t, "true", resp.Header().Get("X-Test-Proxy"), "Request should not have passed through the proxy") +} diff --git a/renovate.json b/renovate.json new file mode 100644 index 0000000..ec944db --- /dev/null +++ b/renovate.json @@ -0,0 +1,11 @@ +{ + "extends": [ + "config:base" + ], + "postUpdateOptions": [ + "gomodTidy" + ], + "labels": [ + "dependencies" + ] +} \ No newline at end of file diff --git a/request.go b/request.go new file mode 100644 index 0000000..5ddf519 --- /dev/null +++ b/request.go @@ -0,0 +1,841 @@ +package httpsling + +import ( + "bytes" + "context" + "fmt" + "io" + "mime/multipart" + "net/http" + "net/url" + "slices" + "strings" + "time" + + "github.com/google/go-querystring/query" + + "github.com/theopenlane/utils/rout" +) + +// RequestBuilder facilitates building and executing HTTP requests +type RequestBuilder struct { + // client is the HTTP client instance + client *Client + // method is the HTTP method for the request + method string + // path is the URL path for the request + path string + // headers contains the request headers + headers *http.Header + // cookies contains the request cookies + cookies []*http.Cookie + // queries contains the request query parameters + queries url.Values + // pathParams contains the request path parameters + pathParams map[string]string + // formFields contains the request form fields + formFields url.Values + // formFiles contains the request form files + formFiles []*File + // boundary is the custom boundary for multipart requests + boundary string + // bodyData is the request body + bodyData interface{} + // timeout is the request timeout + timeout time.Duration + // middlewares contains the request middlewares + middlewares []Middleware + // maxRetries is the maximum number of retry attempts + maxRetries int + // retryStrategy is the backoff strategy for retries + retryStrategy BackoffStrategy + // retryIf is the custom retry condition function + retryIf RetryIfFunc + // auth is the authentication method for the request + auth AuthMethod + // stream is the stream callback for the request + stream StreamCallback + // streamErr is the error callback for the request + streamErr StreamErrCallback + // streamDone is the done callback for the request + streamDone StreamDoneCallback + // BeforeRequest is a hook that can be used to modify the request object + // before the request has been fired. This is useful for adding authentication + // and other functionality not provided in this library + BeforeRequest func(req *http.Request) error +} + +// NewRequestBuilder creates a new RequestBuilder with default settings +func (c *Client) NewRequestBuilder(method, path string) *RequestBuilder { + rb := &RequestBuilder{ + client: c, + method: method, + path: path, + queries: url.Values{}, + headers: &http.Header{}, + } + + if c.Headers != nil { + rb.headers = c.Headers + } + + return rb +} + +// AddMiddleware adds a middleware to the request +func (b *RequestBuilder) AddMiddleware(middlewares ...Middleware) { + if b.middlewares == nil { + b.middlewares = []Middleware{} + } + + b.middlewares = append(b.middlewares, middlewares...) +} + +// Method sets the HTTP method for the request +func (b *RequestBuilder) Method(method string) *RequestBuilder { + b.method = method + + return b +} + +// Path sets the URL path for the request +func (b *RequestBuilder) Path(path string) *RequestBuilder { + b.path = path + + return b +} + +// PathParams sets multiple path params fields and their values at one go in the RequestBuilder instance +func (b *RequestBuilder) PathParams(params map[string]string) *RequestBuilder { + if b.pathParams == nil { + b.pathParams = map[string]string{} + } + + for key, value := range params { + b.pathParams[key] = value + } + + return b +} + +// PathParam sets a single path param field and its value in the RequestBuilder instance +func (b *RequestBuilder) PathParam(key, value string) *RequestBuilder { + if b.pathParams == nil { + b.pathParams = map[string]string{} + } + + b.pathParams[key] = value + + return b +} + +// DelPathParam removes one or more path params fields from the RequestBuilder instance +func (b *RequestBuilder) DelPathParam(key ...string) *RequestBuilder { + if b.pathParams != nil { + for _, k := range key { + delete(b.pathParams, k) + } + } + + return b +} + +// preparePath replaces path parameters in the URL path +func (b *RequestBuilder) preparePath() string { + if b.pathParams == nil { + return b.path + } + + preparedPath := b.path + + for key, value := range b.pathParams { + placeholder := "{" + key + "}" + preparedPath = strings.ReplaceAll(preparedPath, placeholder, url.PathEscape(value)) + } + + return preparedPath +} + +// Queries adds query parameters to the request +func (b *RequestBuilder) Queries(params url.Values) *RequestBuilder { + for key, values := range params { + for _, value := range values { + b.queries.Add(key, value) + } + } + + return b +} + +// Query adds a single query parameter to the request +func (b *RequestBuilder) Query(key, value string) *RequestBuilder { + b.queries.Add(key, value) + + return b +} + +// DelQuery removes one or more query parameters from the request +func (b *RequestBuilder) DelQuery(key ...string) *RequestBuilder { + for _, k := range key { + b.queries.Del(k) + } + + return b +} + +// QueriesStruct adds query parameters to the request based on a struct tagged with url tags +func (b *RequestBuilder) QueriesStruct(queryStruct interface{}) *RequestBuilder { + values, _ := query.Values(queryStruct) // safely ignore error for simplicity + + for key, value := range values { + for _, v := range value { + b.queries.Add(key, v) + } + } + + return b +} + +// Headers set headers to the request +func (b *RequestBuilder) Headers(headers http.Header) *RequestBuilder { + for key, values := range headers { + for _, value := range values { + b.headers.Set(key, value) + } + } + + return b +} + +// Header sets (or replaces) a header in the request +func (b *RequestBuilder) Header(key, value string) *RequestBuilder { + b.headers.Set(key, value) + + return b +} + +// AddHeader adds a header to the request +func (b *RequestBuilder) AddHeader(key, value string) *RequestBuilder { + b.headers.Add(key, value) + + return b +} + +// DelHeader removes one or more headers from the request +func (b *RequestBuilder) DelHeader(key ...string) *RequestBuilder { + for _, k := range key { + b.headers.Del(k) + } + + return b +} + +// Cookies method for map +func (b *RequestBuilder) Cookies(cookies map[string]string) *RequestBuilder { + for key, value := range cookies { + b.Cookie(key, value) + } + + return b +} + +// Cookie adds a cookie to the request +func (b *RequestBuilder) Cookie(key, value string) *RequestBuilder { + if b.cookies == nil { + b.cookies = []*http.Cookie{} + } + + b.cookies = append(b.cookies, &http.Cookie{Name: key, Value: value}) + + return b +} + +// DelCookie removes one or more cookies from the request +func (b *RequestBuilder) DelCookie(key ...string) *RequestBuilder { + if b.cookies != nil { + for i, cookie := range b.cookies { + if slices.Contains(key, cookie.Name) { + b.cookies = append(b.cookies[:i], b.cookies[i+1:]...) + } + } + } + + return b +} + +// ContentType sets the Content-Type header for the request +func (b *RequestBuilder) ContentType(contentType string) *RequestBuilder { + b.headers.Set(HeaderContentType, contentType) + + return b +} + +// Accept sets the Accept header for the request +func (b *RequestBuilder) Accept(accept string) *RequestBuilder { + b.headers.Set(HeaderAccept, accept) + + return b +} + +// UserAgent sets the User-Agent header for the request +func (b *RequestBuilder) UserAgent(userAgent string) *RequestBuilder { + b.headers.Set(HeaderUserAgent, userAgent) + + return b +} + +// Referer sets the Referer header for the request +func (b *RequestBuilder) Referer(referer string) *RequestBuilder { + b.headers.Set(HeaderReferer, referer) + + return b +} + +// Auth applies an authentication method to the request +func (b *RequestBuilder) Auth(auth AuthMethod) *RequestBuilder { + if auth.Valid() { + b.auth = auth + } + + return b +} + +// Form sets form fields and files for the request +func (b *RequestBuilder) Form(v any) *RequestBuilder { + formFields, formFiles, err := parseForm(v) + + if err != nil { + if b.client.Logger != nil { + b.client.Logger.Errorf("Error parsing form: %v", err) + } + + return b + } + + if formFields != nil { + b.formFields = formFields + } + + if formFiles != nil { + b.formFiles = formFiles + } + + return b +} + +// FormFields sets multiple form fields at once +func (b *RequestBuilder) FormFields(fields any) *RequestBuilder { + if b.formFields == nil { + b.formFields = url.Values{} + } + + values, err := parseFormFields(fields) + if err != nil { + if b.client.Logger != nil { + b.client.Logger.Errorf("Error parsing form fields: %v", err) + } + + return b + } + + for key, value := range values { + for _, v := range value { + b.formFields.Add(key, v) + } + } + + return b +} + +// FormField adds or updates a form field +func (b *RequestBuilder) FormField(key, val string) *RequestBuilder { + if b.formFields == nil { + b.formFields = url.Values{} + } + + b.formFields.Add(key, val) + + return b +} + +// DelFormField removes one or more form fields +func (b *RequestBuilder) DelFormField(key ...string) *RequestBuilder { + if b.formFields != nil { + for _, k := range key { + b.formFields.Del(k) + } + } + + return b +} + +// Files sets multiple files at once +func (b *RequestBuilder) Files(files ...*File) *RequestBuilder { + if b.formFiles == nil { + b.formFiles = []*File{} + } + + b.formFiles = append(b.formFiles, files...) + + return b +} + +// File adds a file to the request +func (b *RequestBuilder) File(key, filename string, content io.ReadCloser) *RequestBuilder { + if b.formFiles == nil { + b.formFiles = []*File{} + } + + b.formFiles = append(b.formFiles, &File{ + Name: key, + FileName: filename, + Content: content, + }) + + return b +} + +// DelFile removes one or more files from the request +func (b *RequestBuilder) DelFile(key ...string) *RequestBuilder { + if b.formFiles != nil { + for i, file := range b.formFiles { + if slices.Contains(key, file.Name) { + b.formFiles = append(b.formFiles[:i], b.formFiles[i+1:]...) + } + } + } + + return b +} + +// Body sets the request body +func (b *RequestBuilder) Body(body interface{}) *RequestBuilder { + b.bodyData = body + + return b +} + +// JSONBody sets the request body as JSON +func (b *RequestBuilder) JSONBody(v interface{}) *RequestBuilder { + b.bodyData = v + b.headers.Set(HeaderContentType, ContentTypeJSON) + + return b +} + +// XMLBody sets the request body as XML +func (b *RequestBuilder) XMLBody(v interface{}) *RequestBuilder { + b.bodyData = v + b.headers.Set(HeaderContentType, ContentTypeXML) + + return b +} + +// YAMLBody sets the request body as YAML +func (b *RequestBuilder) YAMLBody(v interface{}) *RequestBuilder { + b.bodyData = v + b.headers.Set(HeaderContentType, ContentTypeYAML) + + return b +} + +// TextBody sets the request body as plain text +func (b *RequestBuilder) TextBody(v string) *RequestBuilder { + b.bodyData = v + b.headers.Set(HeaderContentType, ContentTypeText) + + return b +} + +// RawBody sets the request body as raw bytes +func (b *RequestBuilder) RawBody(v []byte) *RequestBuilder { + b.bodyData = v + + return b +} + +// Timeout sets the request timeout +func (b *RequestBuilder) Timeout(timeout time.Duration) *RequestBuilder { + b.timeout = timeout + + return b +} + +// MaxRetries sets the maximum number of retry attempts +func (b *RequestBuilder) MaxRetries(maxRetries int) *RequestBuilder { + b.maxRetries = maxRetries + + return b +} + +// RetryStrategy sets the backoff strategy for retries +func (b *RequestBuilder) RetryStrategy(strategy BackoffStrategy) *RequestBuilder { + b.retryStrategy = strategy + + return b +} + +// RetryIf sets the custom retry condition function +func (b *RequestBuilder) RetryIf(retryIf RetryIfFunc) *RequestBuilder { + b.retryIf = retryIf + + return b +} + +func (b *RequestBuilder) do(ctx context.Context, req *http.Request) (*http.Response, error) { + finalHandler := MiddlewareHandlerFunc(func(req *http.Request) (*http.Response, error) { + var maxRetries = b.client.MaxRetries + if b.maxRetries > 0 { + maxRetries = b.maxRetries + } + + var retryStrategy = b.client.RetryStrategy + + if b.retryStrategy != nil { + retryStrategy = b.retryStrategy + } + + var retryIf = b.client.RetryIf + + if b.retryIf != nil { + retryIf = b.retryIf + } + + if maxRetries < 1 { + return b.client.HTTPClient.Do(req) + } + + var lastErr error + + var resp *http.Response + + for attempt := 0; attempt <= maxRetries; attempt++ { + resp, lastErr = b.client.HTTPClient.Do(req) + + shouldRetry := lastErr != nil || (resp != nil && retryIf != nil && retryIf(req, resp, lastErr)) + if !shouldRetry || attempt == maxRetries { + if lastErr != nil { + if b.client.Logger != nil { + b.client.Logger.Errorf("Error after %d attempts: %v", attempt+1, lastErr) + } + } + + break + } + + if resp != nil { + if err := resp.Body.Close(); err != nil { + if b.client.Logger != nil { + b.client.Logger.Errorf("Error closing response body: %v", err) + } + } + } + + if b.client.Logger != nil { + b.client.Logger.Infof("Retrying request (attempt %d) after backoff", attempt+1) + } + + // Logging context cancellation as an error condition + select { + case <-ctx.Done(): + if b.client.Logger != nil { + b.client.Logger.Errorf("Request canceled or timed out: %v", ctx.Err()) + } + + return nil, ctx.Err() + + case <-time.After(retryStrategy(attempt)): + } + } + + return resp, lastErr + }) + + if b.middlewares != nil { + for i := len(b.middlewares) - 1; i >= 0; i-- { + finalHandler = b.middlewares[i](finalHandler) + } + } + + if b.client.Middlewares != nil { + for i := len(b.client.Middlewares) - 1; i >= 0; i-- { + finalHandler = b.client.Middlewares[i](finalHandler) + } + } + + return finalHandler(req) +} + +// Stream sets the stream callback for the request +func (b *RequestBuilder) Stream(callback StreamCallback) *RequestBuilder { + b.stream = callback + + return b +} + +// StreamErr sets the error callback for the request. +func (b *RequestBuilder) StreamErr(callback StreamErrCallback) *RequestBuilder { + b.streamErr = callback + + return b +} + +// StreamDone sets the done callback for the request. +func (b *RequestBuilder) StreamDone(callback StreamDoneCallback) *RequestBuilder { + b.streamDone = callback + + return b +} + +func (b *RequestBuilder) setContentType() (io.Reader, string, error) { + var body io.Reader + + var contentType string + + var err error + + switch { + case len(b.formFiles) > 0: + // If the request includes files, indicating multipart/form-data encoding is required + body, contentType, err = b.prepareMultipartBody() + + case len(b.formFields) > 0: + // For form fields without files, use application/x-www-form-urlencoded encoding + body, contentType = b.prepareFormFieldsBody() + + case b.bodyData != nil: + // Fallback to handling as per original logic for JSON, XML, etc + body, contentType, err = b.prepareBodyBasedOnContentType() + } + + if err != nil { + if b.client.Logger != nil { + // surface to the client logger as well + b.client.Logger.Errorf("Error preparing request body: %v", err) + } + + return nil, contentType, err + } + + if contentType != "" { + b.headers.Set(HeaderContentType, contentType) + } + + return body, contentType, nil +} + +func (b *RequestBuilder) requestChecks(req *http.Request) *http.Request { + // apply the authentication method to the request + if b.auth != nil { + b.auth.Apply(req) + } else if b.client.Auth != nil { + b.client.Auth.Apply(req) + } + + // set the headers from the client + if b.client.Headers != nil { + for key := range *b.client.Headers { + values := (*b.client.Headers)[key] + for _, value := range values { + req.Header.Set(key, value) + } + } + } + // set the headers from the request builder + if b.headers != nil { + for key := range *b.headers { + values := (*b.headers)[key] + for _, value := range values { + req.Header.Set(key, value) + } + } + } + + // merge cookies from the client + if b.client.Cookies != nil { + for _, cookie := range b.client.Cookies { + req.AddCookie(cookie) + } + } + // merge cookies from the request builder + if b.cookies != nil { + for _, cookie := range b.cookies { + req.AddCookie(cookie) + } + } + + return req +} + +// Send executes the HTTP request +func (b *RequestBuilder) Send(ctx context.Context) (*Response, error) { + body, _, err := b.setContentType() + if err != nil { + return nil, err + } + + parsedURL, err := url.Parse(b.client.BaseURL + b.preparePath()) + if err != nil { + if b.client.Logger != nil { + // surface the error to the client logger as well + b.client.Logger.Errorf("Error parsing URL: %v", err) + } + + return nil, err + } + + query := parsedURL.Query() + + for key, values := range b.queries { + for _, value := range values { + query.Set(key, value) + } + } + + parsedURL.RawQuery = query.Encode() + + var cancel context.CancelFunc + + if _, ok := ctx.Deadline(); !ok { + if b.timeout > 0 { + ctx, cancel = context.WithTimeout(ctx, b.timeout) + defer cancel() + } + } + + req, err := http.NewRequestWithContext(ctx, b.method, parsedURL.String(), body) + if err != nil { + if b.client.Logger != nil { + b.client.Logger.Errorf("Error creating request: %v", err) + } + + return nil, rout.HTTPErrorResponse(err) + } + + req = b.requestChecks(req) + + // Execute the HTTP request + resp, err := b.do(ctx, req) + if err != nil { + if b.client.Logger != nil { + b.client.Logger.Errorf("Error executing request: %v", err) + } + + if resp != nil { + _ = resp.Body.Close() + } + + return nil, err + } + + if resp == nil { + if b.client.Logger != nil { + b.client.Logger.Errorf("Response is nil") + } + + return nil, fmt.Errorf("%w: %v", ErrResponseNil, err) + } + + // Wrap and return the response + return NewResponse(ctx, resp, b.client, b.stream, b.streamErr, b.streamDone) +} + +func (b *RequestBuilder) prepareMultipartBody() (io.Reader, string, error) { + var buf bytes.Buffer + writer := multipart.NewWriter(&buf) + + // if a custom boundary is set, use it + if b.boundary != "" { + if err := writer.SetBoundary(b.boundary); err != nil { + return nil, "", rout.HTTPErrorResponse(err) + } + } + + // add form fields + for key, vals := range b.formFields { + for _, val := range vals { + if err := writer.WriteField(key, val); err != nil { + return nil, "", rout.HTTPErrorResponse(err) + } + } + } + + // add form files + for _, file := range b.formFiles { + // create a new multipart part for the file + part, err := writer.CreateFormFile(file.Name, file.FileName) + + if err != nil { + return nil, "", rout.HTTPErrorResponse(err) + } + // copy the file content to the part + if _, err = io.Copy(part, file.Content); err != nil { + return nil, "", rout.HTTPErrorResponse(err) + } + + // close the file content if it's a closer + if closer, ok := file.Content.(io.Closer); ok { + if err = closer.Close(); err != nil { + return nil, "", rout.HTTPErrorResponse(err) + } + } + } + + // close the multipart writer + if err := writer.Close(); err != nil { + return nil, "", rout.HTTPErrorResponse(err) + } + + return &buf, writer.FormDataContentType(), nil +} + +func (b *RequestBuilder) prepareFormFieldsBody() (io.Reader, string) { + data := b.formFields.Encode() + + return strings.NewReader(data), ContentTypeForm +} + +func (b *RequestBuilder) prepareBodyBasedOnContentType() (io.Reader, string, error) { + contentType := b.headers.Get(HeaderContentType) + + if contentType == "" && b.bodyData != nil { + switch b.bodyData.(type) { + case url.Values, map[string][]string, map[string]string: + contentType = ContentTypeForm + case map[string]interface{}, []interface{}, struct{}: + contentType = ContentTypeJSONUTF8 + case string, []byte: + contentType = ContentTypeText + } + + b.headers.Set(HeaderContentType, contentType) + } + + var body io.Reader + + var err error + + switch contentType { + case ContentTypeJSON, ContentTypeJSONUTF8: + body, err = b.client.JSONEncoder.Encode(b.bodyData) + case ContentTypeXML: + body, err = b.client.XMLEncoder.Encode(b.bodyData) + case ContentTypeYAML: + body, err = b.client.YAMLEncoder.Encode(b.bodyData) + case ContentTypeForm: + body, err = DefaultFormEncoder.Encode(b.bodyData) + case ContentTypeText, ContentTypeApplicationOctetStream: + switch data := b.bodyData.(type) { + case string: + body = strings.NewReader(data) + case []byte: + body = bytes.NewReader(data) + default: + err = fmt.Errorf("%w: %s", ErrUnsupportedContentType, contentType) + } + default: + err = fmt.Errorf("%w: %s", ErrUnsupportedContentType, contentType) + } + + return body, contentType, err +} diff --git a/request_test.go b/request_test.go new file mode 100644 index 0000000..21e9111 --- /dev/null +++ b/request_test.go @@ -0,0 +1,889 @@ +package httpsling + +import ( + "context" + "encoding/base64" + "encoding/json" + "encoding/xml" + "errors" + "fmt" + "io" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestRequestCancellation(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(2 * time.Second) // Simulate a long-running operation + fmt.Fprintln(w, "This response may never be sent") + })) + + defer server.Close() + + client := Create(&Config{BaseURL: server.URL}) + + ctx, cancel := context.WithCancel(context.Background()) + + defer cancel() // Ensure resources are cleaned up + + // Cancel the request after 1 second + go func() { + time.Sleep(1 * time.Second) + cancel() + }() + + // Attempt to make a request that will be canceled + _, err := client.Get("/").Send(ctx) + if err == nil { + t.Errorf("Expected an error due to cancellation, but got none") + } + + // Check if the error is due to context cancellation + if !errors.Is(err, context.Canceled) { + t.Errorf("Expected context.Canceled error, got %v", err) + } +} + +// TestSendMethodQuery checks the Send method for handling query parameters. +func TestSendMethodQuery(t *testing.T) { + // Start a test HTTP server. + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Respond with the full URL received, including query parameters. + fmt.Fprintln(w, r.URL.String()) + })) + defer server.Close() + + // Define a client with the test server's URL. + client := Create(&Config{BaseURL: server.URL}) + + tests := []struct { + name string + url string // URL to request, may include query params + additionalQPS map[string]string // Query params added via Query method + expectedURL string // Expected URL path and query received by the server + }{ + { + name: "URL only", + url: "/test?param1=value1", + expectedURL: "/test?param1=value1", + }, + { + name: "Method only", + url: "/test", + additionalQPS: map[string]string{"param2": "value2"}, + expectedURL: "/test?param2=value2", + }, + { + name: "URL and Method", + url: "/test?param1=value1", + additionalQPS: map[string]string{"param2": "value2"}, + expectedURL: "/test?param1=value1¶m2=value2", + }, + { + name: "Method overwrites URL", + url: "/test?param1=value1", + additionalQPS: map[string]string{"param1": "value2"}, + expectedURL: "/test?param1=value2", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Create a new RequestBuilder for each test case. + rb := client.NewRequestBuilder(http.MethodGet, tc.url) + + // If there are additional query params defined, add them. + if tc.additionalQPS != nil { + for key, value := range tc.additionalQPS { + rb.Queries(map[string][]string{key: {value}}) + } + } + + // Send the request. + resp, err := rb.Send(context.Background()) + assert.NoError(t, err) + + // Read the response body. + bodyBytes, err := io.ReadAll(resp.RawResponse.Body) + assert.NoError(t, err) + + body := string(bodyBytes) + + // The body should contain the expected URL path and query. + assert.Contains(t, body, tc.expectedURL, "The server did not receive the expected URL.") + }) + } +} + +type testAddress struct { + Postcode string `url:"postcode"` + City string `url:"city"` +} + +type testQueryStruct struct { + Name string `url:"name"` + Occupation string `url:"occupation,omitempty"` + Age int `url:"age"` + IsActive bool `url:"is_active,int"` + Tags []string `url:"tags,comma"` + Address testAddress `url:"addr"` +} + +func TestQueryStructWithClient(t *testing.T) { + // Start a test HTTP server that JSON-encodes and echoes back the query parameters received + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + queryParams := r.URL.Query() + + w.Header().Set(HeaderContentType, ContentTypeJSON) + + if encoder := json.NewEncoder(w); encoder != nil { + if err := encoder.Encode(queryParams); err != nil { + http.Error(w, "Failed to encode response", http.StatusInternalServerError) + } + } else { + http.Error(w, "Failed to create JSON encoder", http.StatusInternalServerError) + } + })) + + defer server.Close() + + // Create an instance of the client, pointing to the test server + client := Create(&Config{BaseURL: server.URL}) + + // Define the struct to be used for query parameters + exampleStruct := testQueryStruct{ + Name: "John Snow", + Occupation: "Developer", + Age: 30, + IsActive: true, + Tags: []string{"go", "programming"}, + Address: testAddress{ + Postcode: "1234", + City: "GoCity", + }, + } + + // Send a request to the server using the client and the struct for query parameters + resp, err := client.NewRequestBuilder(http.MethodGet, "/").QueriesStruct(exampleStruct).Send(context.Background()) + assert.NoError(t, err) + + // Read and verify the response + var response map[string][]string + err = resp.ScanJSON(&response) + assert.NoError(t, err) + + // Now we can assert the values directly + assert.Contains(t, response, "name") + assert.Equal(t, []string{"John Snow"}, response["name"]) + assert.Contains(t, response, "occupation") + assert.Equal(t, []string{"Developer"}, response["occupation"]) + assert.Contains(t, response, "age") + assert.Equal(t, []string{"30"}, response["age"]) + assert.Contains(t, response, "is_active") + assert.Equal(t, []string{"1"}, response["is_active"]) + assert.Contains(t, response, "tags") + assert.Equal(t, []string{"go,programming"}, response["tags"]) + + err = resp.Close() + assert.NoError(t, err) +} + +func TestHeaderManipulationMethods(t *testing.T) { + // Start a test HTTP server that checks received headers + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Check headers + assert.Equal(t, ContentTypeJSON, r.Header.Get(HeaderContentType)) + assert.Equal(t, "Bearer token", r.Header.Get(HeaderAuthorization)) + assert.Empty(t, r.Header.Get("X-Deprecated-Header")) + + fmt.Fprintln(w, "Headers received") + })) + + defer server.Close() + + // Create an instance of the client, pointing to the test server + rq := Create(&Config{BaseURL: server.URL}).Get("/test-headers") + rq.Headers(http.Header{HeaderContentType: []string{ContentTypeJSON}}) + rq.AddHeader(HeaderAuthorization, "Bearer token") + rq.Header("X-Modified-Header", "NewValue") + rq.AddHeader("X-Deprecated-Header", "OldValue") + rq.DelHeader("X-Deprecated-Header") + + // Send the request + resp, err := rq.Send(context.Background()) + assert.NoError(t, err) + + // Read and verify the response + responseBody, err := io.ReadAll(resp.RawResponse.Body) + assert.NoError(t, err) + assert.Contains(t, string(responseBody), "Headers received") +} + +func TestUserAgentMethod(t *testing.T) { + // Start a test HTTP server that checks received User-Agent header + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Check User-Agent header + assert.Equal(t, "MyCustomUserAgent", r.Header.Get(HeaderUserAgent)) + + fmt.Fprintln(w, "User-Agent received") + })) + defer server.Close() + + // Create an instance of the client, pointing to the test server + client := Create(&Config{BaseURL: server.URL}) + rq := client.Get("/test-user-agent") + + // Set the User-Agent header using the UserAgent method + rq.UserAgent("MyCustomUserAgent") + + // Send the request + resp, err := rq.Send(context.Background()) + assert.NoError(t, err) + + // Read and verify the response + responseBody, err := io.ReadAll(resp.RawResponse.Body) + assert.NoError(t, err) + assert.Contains(t, string(responseBody), "User-Agent received") +} + +func TestContentTypeMethod(t *testing.T) { + // Start a test HTTP server that checks received Content-Type header + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Check Content-Type header + assert.Equal(t, ContentTypeJSON, r.Header.Get(HeaderContentType)) + + fmt.Fprintln(w, "Content-Type received") + })) + + defer server.Close() + + // Create an instance of the client, pointing to the test server + client := Create(&Config{BaseURL: server.URL}) + rq := client.Get("/test-content-type") + + // Set the Content-Type header using the ContentType method + rq.ContentType(ContentTypeJSON) + + // Send the request + resp, err := rq.Send(context.Background()) + assert.NoError(t, err) + + // Read and verify the response + responseBody, err := io.ReadAll(resp.RawResponse.Body) + assert.NoError(t, err) + assert.Contains(t, string(responseBody), "Content-Type received") +} + +func TestAcceptMethod(t *testing.T) { + // Start a test HTTP server that checks received Accept header + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Check Accept header + assert.Equal(t, ContentTypeXML, r.Header.Get(HeaderAccept)) + + fmt.Fprintln(w, "Accept received") + })) + + defer server.Close() + + // Create an instance of the client, pointing to the test server + client := Create(&Config{BaseURL: server.URL}) + rq := client.Get("/test-accept") + + // Set the Accept header using the Accept method + rq.Accept(ContentTypeXML) + + // Send the request + resp, err := rq.Send(context.Background()) + assert.NoError(t, err) + + // Read and verify the response + responseBody, err := io.ReadAll(resp.RawResponse.Body) + assert.NoError(t, err) + assert.Contains(t, string(responseBody), "Accept received") +} + +func TestRefererMethod(t *testing.T) { + // Start a test HTTP server that checks received Referer header + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Check Referer header + assert.Equal(t, "https://example.com", r.Header.Get(HeaderReferer)) + + fmt.Fprintln(w, "Referer received") + })) + + defer server.Close() + + // Create an instance of the client, pointing to the test server + client := Create(&Config{BaseURL: server.URL}) + rq := client.Get("/test-referer") + + // Set the Referer header + rq.Referer("https://example.com") + + // Send the request + resp, err := rq.Send(context.Background()) + assert.NoError(t, err) + + // Read and verify the response + responseBody, err := io.ReadAll(resp.RawResponse.Body) + assert.NoError(t, err) + assert.Contains(t, string(responseBody), "Referer received") +} + +func TestCookieManipulationMethods(t *testing.T) { + // Start a test HTTP server that checks received cookies + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Check cookies + cookie1, err1 := r.Cookie("SessionID") + assert.NoError(t, err1) + assert.Equal(t, "12345", cookie1.Value) + + cookie2, err2 := r.Cookie("AuthToken") + assert.NoError(t, err2) + assert.Equal(t, "abcdef", cookie2.Value) + + // Ensure the deleted cookie is not present + _, err3 := r.Cookie("DeletedCookie") + assert.Error(t, err3) // We expect an error because the cookie should not be present + + fmt.Fprintln(w, "Cookies received") + })) + + defer server.Close() + + // Create an instance of the client, pointing to the test server + rq := Create(&Config{BaseURL: server.URL}).Get("/test-cookies") + // Using SetCookies to set multiple cookies at once + rq.Cookies(map[string]string{ + "SessionID": "12345", + "AuthToken": "abcdef", + "DeletedCookie": "should-be-deleted", + }) + // Demonstrate individual cookie manipulation + rq.Cookie("SingleCookie", "single-value") + // Removing a previously set cookie + rq.DelCookie("DeletedCookie") + + // Send the request + resp, err := rq.Send(context.Background()) + assert.NoError(t, err) + + // Read and verify the response + responseBody, err := io.ReadAll(resp.RawResponse.Body) + assert.NoError(t, err) + assert.Contains(t, string(responseBody), "Cookies received") +} + +func TestPathParameterMethods(t *testing.T) { + // Start a test HTTP server that checks the received path for correctness + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Check if the path is as expected + expectedPath := "/users/johnDoe/posts/123" + if r.URL.Path != expectedPath { + t.Errorf("expected path %s, got %s", expectedPath, r.URL.Path) + } + + fmt.Fprintln(w, "Path parameters received correctly") + })) + + defer server.Close() + + // Create an instance of the client, pointing to the test server + client := Create(&Config{BaseURL: server.URL}) + rq := client.Get("/users/{userId}/posts/{postId}") + + // Using PathParams to set multiple path params at once + rq.PathParams(map[string]string{ + "postId": "123", + }) + + // Demonstrate individual path parameter manipulation + rq.PathParam("userId", "johnDoe").PathParam("hello", "world") + rq.DelPathParam("hello") + + // Send the request + resp, err := rq.Send(context.Background()) + assert.NoError(t, err) + + // Read and verify the response + responseBody, err := io.ReadAll(resp.RawResponse.Body) + assert.NoError(t, err) + assert.Contains(t, string(responseBody), "Path parameters received correctly") +} + +func startEchoServer() *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + bodyBytes, _ := io.ReadAll(r.Body) + + w.Header().Set(HeaderContentType, ContentTypeJSON) + + response := map[string]string{ + "body": string(bodyBytes), + "contentType": r.Header.Get(HeaderContentType), + } + + if encoder := json.NewEncoder(w); encoder != nil { + if err := encoder.Encode(response); err != nil { + http.Error(w, "Failed to encode response", http.StatusInternalServerError) + } + } else { + http.Error(w, "Failed to create JSON encoder", http.StatusInternalServerError) + } + })) +} + +func TestFormFields(t *testing.T) { + server := startEchoServer() // Starts a mock HTTP server that echoes back received requests + defer server.Close() + + client := Create(&Config{BaseURL: server.URL}) + + // Example form data using a map + formData := map[string]string{ + "name": "Jane Doe", + "age": "32", + } + + resp, err := client.Post("/"). + FormFields(formData). // Using FormFields to set form data + Send(context.Background()) + assert.NoError(t, err, "Request should not fail") + + var response map[string]string + err = resp.Scan(&response) + assert.NoError(t, err, "Response should be parsed without error") + + // Validates that the form data was correctly encoded and sent in the request body + expectedEncodedFormData := url.Values{"name": {"Jane Doe"}, "age": {"32"}}.Encode() + + assert.Equal(t, expectedEncodedFormData, response["body"], "The body content should match the encoded form data") + assert.Equal(t, ContentTypeForm, response["contentType"], "The content type should be application/x-www-form-urlencoded") +} + +func TestFormField(t *testing.T) { + server := startEchoServer() // Simulated HTTP server + defer server.Close() + + client := Create(&Config{BaseURL: server.URL}) + + resp, err := client.Post("/"). + FormField("name", "John Snow"). // Adding a single form field + Send(context.Background()) + assert.NoError(t, err, "No error expected on sending request") + + var response map[string]string + err = resp.Scan(&response) + assert.NoError(t, err, "Parsing response should not error") + + // Validate that the single form field was correctly encoded and sent + expectedEncodedFormData := "name=John+Snow" + assert.Equal(t, expectedEncodedFormData, response["body"], "The body content should match the single form field") +} + +func TestDelFormField(t *testing.T) { + server := startEchoServer() // Setup mock server + defer server.Close() + + client := Create(&Config{BaseURL: server.URL}) + + // Set initial form fields + initialFormData := map[string]string{ + "name": "Jane Doe", + "age": "32", + } + + // Delete the "age" field before sending + resp, err := client.Post("/"). + FormFields(initialFormData). + DelFormField("age"). // Removing an existing form field + Send(context.Background()) + assert.NoError(t, err, "Expect no error on request send") + + var response map[string]string + err = resp.Scan(&response) + assert.NoError(t, err, "Expect no error on response parse") + + // Validates that the "age" field was correctly removed + expectedEncodedFormData := "name=Jane+Doe" + assert.Equal(t, expectedEncodedFormData, response["body"], "The body should match after deleting a field") +} + +func TestBody(t *testing.T) { + server := startEchoServer() + defer server.Close() + + client := Create(&Config{BaseURL: server.URL}) + + // Example body data + bodyData := url.Values{"key": []string{"value"}} + encodedData := bodyData.Encode() + + resp, err := client.Post("/"). + Body(bodyData). + ContentType(ContentTypeForm). + Send(context.Background()) + + assert.NoError(t, err) + + var response map[string]string + err = resp.Scan(&response) + assert.NoError(t, err) + + // Asserts + assert.Equal(t, encodedData, response["body"], "The body content should match.") + assert.Equal(t, ContentTypeForm, response["contentType"], "The content type should be set correctly.") +} + +func TestJSONBody(t *testing.T) { + server := startEchoServer() + defer server.Close() + + client := Create(&Config{BaseURL: server.URL}) + + // Example JSON data + jsonData := map[string]interface{}{"name": "John Snow", "age": 30} + jsonDataStr, _ := json.Marshal(jsonData) + + resp, err := client.Post("/"). + JSONBody(jsonData). + Send(context.Background()) + assert.NoError(t, err) + + var response map[string]string + err = resp.Scan(&response) + assert.NoError(t, err) + + // Asserts + assert.JSONEq(t, string(jsonDataStr), response["body"], "The body content should match.") + assert.Equal(t, ContentTypeJSON, response["contentType"], "The content type should be set to application/json.") +} + +func TestXMLBody(t *testing.T) { + server := startEchoServer() + defer server.Close() + + client := Create(&Config{BaseURL: server.URL}) + + // Example XML data + xmlData := struct { + XMLName xml.Name `xml:"Person"` + Name string `xml:"Name"` + Age int `xml:"Age"` + }{Name: "Jane Doe", Age: 32} + xmlDataStr, _ := xml.Marshal(xmlData) + + resp, err := client.Post("/"). + XMLBody(xmlData). + Send(context.Background()) + assert.NoError(t, err) + + var response map[string]string + err = resp.Scan(&response) + assert.NoError(t, err) + + // Asserts + assert.Equal(t, string(xmlDataStr), strings.TrimSpace(response["body"]), "The body content should match.") + assert.Equal(t, ContentTypeXML, response["contentType"], "The content type should be set to application/xml.") +} + +func TestFormWithUrlValues(t *testing.T) { + server := startEchoServer() + defer server.Close() + + client := Create(&Config{BaseURL: server.URL}) + + // Example form data + formData := url.Values{ + "name": []string{"Jane Doe"}, + "age": []string{"32"}, + } + + resp, err := client.Post("/"). + Form(formData). + Send(context.Background()) + assert.NoError(t, err) + + var response map[string]string + err = resp.Scan(&response) + assert.NoError(t, err) + + // Asserts + assert.Equal(t, formData.Encode(), response["body"], "The body content should match.") + assert.Equal(t, ContentTypeForm, response["contentType"], "The content type should be set correctly.") +} + +func TestTextBody(t *testing.T) { + server := startEchoServer() + defer server.Close() + + client := Create(&Config{BaseURL: server.URL}) + + // Example text data + textData := "This is a plain text body." + + resp, err := client.Post("/"). + TextBody(textData). + Send(context.Background()) + assert.NoError(t, err) + + var response map[string]string + err = resp.Scan(&response) + assert.NoError(t, err) + + // Asserts + assert.Equal(t, textData, response["body"], "The body content should match.") + assert.Equal(t, ContentTypeText, response["contentType"], "The content type should be set to text/plain.") +} + +func TestRawBody(t *testing.T) { + server := startEchoServer() + defer server.Close() + + client := Create(&Config{BaseURL: server.URL}) + + // Example raw data + rawData := []byte("This is raw byte data.") + + resp, err := client.Post("/"). + RawBody(rawData). + ContentType("application/octet-stream"). // Explicitly set content type + Send(context.Background()) + assert.NoError(t, err) + + var response map[string]string + err = resp.Scan(&response) + assert.NoError(t, err) + + // Asserts + assert.Equal(t, string(rawData), response["body"], "The body content should match.") + assert.Equal(t, "application/octet-stream", response["contentType"], "The content type should be set to application/octet-stream.") +} + +func TestRequestLevelRetries(t *testing.T) { + var requestCount int32 + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + count := atomic.AddInt32(&requestCount, 1) + if count == 1 { + // Simulate a server error on the first request + w.WriteHeader(http.StatusInternalServerError) + } else { + // Succeed on subsequent attempts + w.WriteHeader(http.StatusOK) + fmt.Fprintln(w, "Success") + } + })) + + defer server.Close() + + // Set up a request builder with retry configuration + client := Create(&Config{BaseURL: server.URL}) + rq := client.Get("/") + rq.MaxRetries(2) // Allow up to 2 retries + rq.RetryStrategy(func(attempt int) time.Duration { return 10 * time.Millisecond }) + rq.RetryIf(func(req *http.Request, resp *http.Response, err error) bool { + // Retry on server error + return resp.StatusCode == http.StatusInternalServerError + }) + + // Send the request + _, err := rq.Send(context.Background()) + if err != nil { + t.Fatalf("Request failed: %v", err) + } + + // Verify that the retry logic was applied + expectedAttempts := int32(2) + if requestCount != expectedAttempts { + t.Errorf("Expected %d attempts, got %d", expectedAttempts, requestCount) + } +} + +func TestFormWithNil(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set(HeaderContentType, ContentTypeJSON) + // Ensure a valid JSON response is sent back for all scenarios + response := map[string]interface{}{ + "status": "received", + "body": "empty or nil form", + } + if encoder := json.NewEncoder(w); encoder != nil { + if err := encoder.Encode(response); err != nil { + http.Error(w, "Failed to encode response", http.StatusInternalServerError) + } + } else { + http.Error(w, "Failed to create JSON encoder", http.StatusInternalServerError) + } + })) + + defer server.Close() + + client := Create(&Config{BaseURL: server.URL}) + + resp, err := client.Post("/").Form(nil).Send(context.Background()) + assert.NoError(t, err, "No error expected on sending request with nil form") + + var response map[string]interface{} + err = resp.ScanJSON(&response) + assert.NoError(t, err, "Expect no error on parsing response") + + // Assert form is correctly received + assert.Contains(t, response, "status", "Status should be present") + assert.Contains(t, response, "body", "Body should be present") +} + +func startFormHandlingServer() *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set(HeaderContentType, ContentTypeJSON) + + err := r.ParseMultipartForm(32 << 20) // limit to 32MB + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + fields := make(map[string][]string) + files := make(map[string][]string) + + if r.MultipartForm != nil { + for key, values := range r.MultipartForm.Value { + fields[key] = values + } + + for key, fileHeaders := range r.MultipartForm.File { + for _, fileHeader := range fileHeaders { + files[key] = append(files[key], fileHeader.Filename) + } + } + } + + response := map[string]interface{}{ + "fields": fields, + "files": files, + } + + if encoder := json.NewEncoder(w); encoder != nil { + if err := encoder.Encode(response); err != nil { + http.Error(w, "Failed to encode response", http.StatusInternalServerError) + } + } else { + http.Error(w, "Failed to create JSON encoder", http.StatusInternalServerError) + } + })) +} + +func TestFormWithFiles(t *testing.T) { + server := startFormHandlingServer() + + defer server.Close() + + client := Create(&Config{BaseURL: server.URL}) + + fileContent1 := strings.NewReader("File content 1") + fileContent2 := strings.NewReader("File content 2") + + formData := map[string]any{ + "file1": &File{Name: "file1", FileName: "file1.txt", Content: io.NopCloser(fileContent1)}, + "file2": &File{Name: "file2", FileName: "file2.txt", Content: io.NopCloser(fileContent2)}, + } + + resp, err := client.Post("/").Form(formData).Send(context.Background()) + assert.NoError(t, err, "No error expected on sending request with files") + + var response map[string]interface{} + err = resp.ScanJSON(&response) + assert.NoError(t, err, "Expect no error on parsing response") + + // Assert files are correctly received + assert.Contains(t, response["files"].(map[string]interface{}), "file1", "File1 should be present") + assert.Contains(t, response["files"].(map[string]interface{}), "file2", "File2 should be present") +} + +func TestFormWithMixedFilesAndFields(t *testing.T) { + server := startFormHandlingServer() + defer server.Close() + + client := Create(&Config{BaseURL: server.URL}) + + fileContent := strings.NewReader("File content 1") + + formData := map[string]any{ + "name": "John Snow", + "age": "30", + "file": &File{Name: "file", FileName: "file.txt", Content: io.NopCloser(fileContent)}, + } + + resp, err := client.Post("/").Form(formData).Send(context.Background()) + assert.NoError(t, err, "No error expected on sending request with mixed form data") + + var response map[string]interface{} + err = resp.Scan(&response) + assert.NoError(t, err, "Expect no error on parsing response") + + // Assert fields and files are correctly received + fields := response["fields"].(map[string]interface{}) + assert.Contains(t, fields, "name", "Name should be present") + assert.Contains(t, fields, "age", "Age should be present") + + files := response["files"].(map[string]interface{}) + assert.Contains(t, files, "file", "File should be present") +} + +// TestAuthRequest verifies that the Auth method correctly applies basic authentication to a request +func TestAuthRequest(t *testing.T) { + // Expected username and password for basic authentication + expectedUsername := "testuser" + expectedPassword := "testpass" + + // Encode the username and password into the expected format for the Authorization header + expectedAuthValue := "Basic " + base64.StdEncoding.EncodeToString([]byte(expectedUsername+":"+expectedPassword)) + + // Set up a mock server to handle the request. This server checks the Authorization header + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Retrieve the Authorization header from the incoming request + authHeader := r.Header.Get(HeaderAuthorization) + + // Compare the Authorization header to the expected value + if authHeader != expectedAuthValue { + // If they don't match, respond with 401 Unauthorized to indicate a failed authentication attempt + w.WriteHeader(http.StatusUnauthorized) + return + } + + // If the Authorization header is correct, respond with 200 OK to indicate successful authentication + w.WriteHeader(http.StatusOK) + })) + + defer mockServer.Close() // Ensure the server is shut down at the end of the test + + // Initialize the HTTP client with the base URL set to the mock server's URL + client := Create(&Config{ + BaseURL: mockServer.URL, + }) + + // Create a request to the mock server with basic authentication credentials + resp, err := client.Get("/").Auth(BasicAuth{ + Username: expectedUsername, + Password: expectedPassword, + }).Send(context.Background()) + + if err != nil { + // If there's an error sending the request, fail the test + t.Fatalf("Failed to send request: %v", err) + } + + defer resp.Close() //nolint: errcheck + + // Check if the response status code is 200 OK, which indicates successful authentication + if resp.StatusCode() != http.StatusOK { + // If the status code is not 200, it indicates the Authorization header was not set correctly + t.Errorf("Expected status code 200, got %d. Indicates Authorization header was not set correctly", resp.StatusCode()) + } +} diff --git a/response.go b/response.go new file mode 100644 index 0000000..1ea5e66 --- /dev/null +++ b/response.go @@ -0,0 +1,295 @@ +package httpsling + +import ( + "bufio" + "bytes" + "context" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "os" + "path/filepath" + "strings" + + "github.com/theopenlane/utils/rout" +) + +// Response represents an HTTP response +type Response struct { + // stream is the callback function for streaming responses + stream StreamCallback + // streamErr is the callback function for streaming errors + streamErr StreamErrCallback + // streamDone is the callback function for when the stream is done + streamDone StreamDoneCallback + // RawResponse is the original HTTP response + RawResponse *http.Response + // BodyBytes is the response body as a juicy byte slice + BodyBytes []byte + // Context is the request context + Context context.Context + // Client is the HTTP client + Client *Client +} + +// NewResponse creates a new wrapped response object leveraging the buffer pool +func NewResponse(ctx context.Context, resp *http.Response, client *Client, stream StreamCallback, streamErr StreamErrCallback, streamDone StreamDoneCallback) (*Response, error) { + response := &Response{ + RawResponse: resp, + Context: ctx, + BodyBytes: nil, + stream: stream, + streamErr: streamErr, + streamDone: streamDone, + Client: client, + } + + if response.stream != nil { + go response.handleStream() + } else if err := response.handleNonStream(); err != nil { + return nil, err + } + + return response, nil +} + +var maxStreamBufferSize = 512 * 1024 + +// handleStream processes the HTTP response as a stream +func (r *Response) handleStream() { + defer func() { + if err := r.RawResponse.Body.Close(); err != nil { + r.Client.Logger.Errorf("failed to close response body: %v", err) + } + }() + + scanner := bufio.NewScanner(r.RawResponse.Body) + + scanBuf := make([]byte, 0, maxStreamBufferSize) + + scanner.Buffer(scanBuf, maxStreamBufferSize) + + for scanner.Scan() { + if err := r.stream(scanner.Bytes()); err != nil { + break + } + } + + if err := scanner.Err(); err != nil && r.streamErr != nil { + r.streamErr(err) + } + + if r.streamDone != nil { + r.streamDone() + } +} + +// handleNonStream reads the HTTP response body into a buffer for non-streaming responses +func (r *Response) handleNonStream() error { + buf := GetBuffer() + defer PutBuffer(buf) + + _, err := buf.ReadFrom(r.RawResponse.Body) + if err != nil { + return fmt.Errorf("%w: %v", ErrResponseReadFailed, err) + } + + _ = r.RawResponse.Body.Close() + + r.RawResponse.Body = io.NopCloser(bytes.NewReader(buf.B)) + r.BodyBytes = buf.B + + return nil +} + +// StatusCode returns the HTTP status code of the response +func (r *Response) StatusCode() int { + return r.RawResponse.StatusCode +} + +// Status returns the status string of the response +func (r *Response) Status() string { + return r.RawResponse.Status +} + +// Header returns the response headers +func (r *Response) Header() http.Header { + return r.RawResponse.Header +} + +// Cookies parses and returns the cookies set in the response +func (r *Response) Cookies() []*http.Cookie { + return r.RawResponse.Cookies() +} + +// Location returns the URL redirected address +func (r *Response) Location() (*url.URL, error) { + return r.RawResponse.Location() +} + +// URL returns the request URL that elicited the response +func (r *Response) URL() *url.URL { + return r.RawResponse.Request.URL +} + +// ContentType returns the value of the HeaderContentType header +func (r *Response) ContentType() string { + return r.Header().Get(HeaderContentType) +} + +// IsContentType Checks if the response Content-Type header matches a given content type +func (r *Response) IsContentType(contentType string) bool { + return strings.Contains(r.ContentType(), contentType) +} + +// IsJSON checks if the response Content-Type indicates JSON +func (r *Response) IsJSON() bool { + return r.IsContentType(ContentTypeJSON) +} + +// IsXML checks if the response Content-Type indicates XML +func (r *Response) IsXML() bool { + return r.IsContentType(ContentTypeXML) +} + +// IsYAML checks if the response Content-Type indicates YAML +func (r *Response) IsYAML() bool { + return r.IsContentType(ContentTypeYAML) +} + +// ContentLength returns the length of the response body +func (r *Response) ContentLength() int { + if r.BodyBytes == nil { + return 0 + } + + return len(r.BodyBytes) +} + +// IsEmpty checks if the response body is empty +func (r *Response) IsEmpty() bool { + return r.ContentLength() == 0 +} + +// IsSuccess checks if the response status code indicates success +func (r *Response) IsSuccess() bool { + code := r.StatusCode() + + return code >= http.StatusOK && code <= http.StatusIMUsed +} + +// Body returns the response body as a juicy byte slice +func (r *Response) Body() []byte { + return r.BodyBytes +} + +// String returns the response body as a string +func (r *Response) String() string { + return string(r.BodyBytes) +} + +// Scan attempts to unmarshal the response body based on its content type +func (r *Response) Scan(v interface{}) error { + switch { + case r.IsJSON(): + return r.ScanJSON(v) + case r.IsXML(): + return r.ScanXML(v) + case r.IsYAML(): + return r.ScanYAML(v) + } + + return fmt.Errorf("%w: %s", ErrUnsupportedContentType, r.ContentType()) +} + +// ScanJSON unmarshals the response body into a struct via JSON decoding +func (r *Response) ScanJSON(v interface{}) error { + if r.BodyBytes == nil { + return nil + } + + return r.Client.JSONDecoder.Decode(bytes.NewReader(r.BodyBytes), v) +} + +// ScanXML unmarshals the response body into a struct via XML decoding +func (r *Response) ScanXML(v interface{}) error { + if r.BodyBytes == nil { + return nil + } + + return r.Client.XMLDecoder.Decode(bytes.NewReader(r.BodyBytes), v) +} + +// ScanYAML unmarshals the response body into a struct via YAML decoding +func (r *Response) ScanYAML(v interface{}) error { + if r.BodyBytes == nil { + return nil + } + + return r.Client.YAMLDecoder.Decode(bytes.NewReader(r.BodyBytes), v) +} + +const dirPermissions = 0755 + +// Save saves the response body to a file or io.Writer +func (r *Response) Save(v any) error { + switch p := v.(type) { + case string: + file := filepath.Clean(p) + dir := filepath.Dir(file) + + // Create the directory if it doesn't exist + if _, err := os.Stat(dir); err != nil { + if !errors.Is(err, os.ErrNotExist) { + return rout.HTTPErrorResponse(err) + } + + if err = os.MkdirAll(dir, dirPermissions); err != nil { + return rout.HTTPErrorResponse(err) + } + } + + // Create and open the file for writing + outFile, err := os.Create(file) + if err != nil { + return rout.HTTPErrorResponse(err) + } + + defer func() { + if err := outFile.Close(); err != nil { + r.Client.Logger.Errorf("failed to close file: %v", err) + } + }() + + // Write the response body to the file + _, err = io.Copy(outFile, bytes.NewReader(r.Body())) + if err != nil { + return rout.HTTPErrorResponse(err) + } + + return nil + case io.Writer: + // Write the response body directly to the provided io.Writer + _, err := io.Copy(p, bytes.NewReader(r.Body())) + if err != nil { + return rout.HTTPErrorResponse(err) + } + + if pc, ok := p.(io.WriteCloser); ok { + if err := pc.Close(); err != nil { + r.Client.Logger.Errorf("failed to close io.Writer: %v", err) + } + } + + return nil + default: + return ErrNotSupportSaveMethod + } +} + +// Close closes the response body +func (r *Response) Close() error { + return r.RawResponse.Body.Close() +} diff --git a/response_test.go b/response_test.go new file mode 100644 index 0000000..bc3d755 --- /dev/null +++ b/response_test.go @@ -0,0 +1,424 @@ +package httpsling + +import ( + "bytes" + "context" + "encoding/xml" + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "os" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestResponseContentType(t *testing.T) { + server := startTestHTTPServer() + + defer server.Close() + + tests := []struct { + url string + contentType string + expected bool + }{ + {"/test-json", ContentTypeJSON, true}, + {"/test-xml", ContentTypeXML, true}, + {"/test-text", ContentTypeText, true}, + {"/test-text", ContentTypeJSON, false}, + {"/test-json", ContentTypeText, false}, + } + + for _, tt := range tests { + t.Run(fmt.Sprintf("ContentType is %s", tt.contentType), func(t *testing.T) { + client := Create(&Config{BaseURL: server.URL}) + resp, err := client.Get(tt.url).Send(context.Background()) + assert.NoError(t, err) + assert.Equal(t, tt.expected, resp.IsContentType(tt.contentType)) + }) + } +} + +func TestResponseStatusAndStatusCode(t *testing.T) { + server := startTestHTTPServer() + + defer server.Close() + + client := Create(&Config{BaseURL: server.URL}) + resp, err := client.Get("/test-status-code").Send(context.Background()) + + assert.NoError(t, err) + assert.Equal(t, 201, resp.StatusCode()) + assert.Contains(t, resp.Status(), "201 Created") +} + +func TestResponseHeaderAndCookies(t *testing.T) { + server := startTestHTTPServer() + + defer server.Close() + + client := Create(&Config{BaseURL: server.URL}) + + t.Run("Test Headers", func(t *testing.T) { + resp, err := client.Get("/test-headers").Send(context.Background()) + assert.NoError(t, err) + assert.Equal(t, "TestValue", resp.Header().Get("X-Custom-Header")) + }) + + t.Run("Test Cookies", func(t *testing.T) { + resp, err := client.Get("/test-cookies").Send(context.Background()) + assert.NoError(t, err) + + cookies := resp.Cookies() + + assert.Equal(t, 1, len(cookies)) + assert.Equal(t, "test-cookie", cookies[0].Name) + assert.Equal(t, "cookie-value", cookies[0].Value) + }) +} + +func TestResponseContentLengthAndIsEmpty(t *testing.T) { + server := startTestHTTPServer() + + defer server.Close() + + client := Create(&Config{BaseURL: server.URL}) + + t.Run("Non-empty response", func(t *testing.T) { + resp, err := client.Get("/test-content-type?ct=text/plain").Send(context.Background()) + assert.NoError(t, err) + assert.False(t, resp.IsEmpty()) + assert.Greater(t, resp.ContentLength(), 0) + }) + + t.Run("Empty response", func(t *testing.T) { + resp, err := client.Get("/test-empty").Send(context.Background()) + assert.NoError(t, err) + assert.True(t, resp.IsEmpty()) + assert.Equal(t, 0, resp.ContentLength()) + }) +} + +func TestResponseIsSuccess(t *testing.T) { + server := startTestHTTPServer() + + defer server.Close() + + client := Create(&Config{BaseURL: server.URL}) + resp, err := client.Get("/test-status-code").Send(context.Background()) // This endpoint sets status 201 + assert.NoError(t, err) + + assert.True(t, resp.IsSuccess()) +} + +func TestResponseIsSuccessForFailure(t *testing.T) { + server := startTestHTTPServer() + defer server.Close() + + client := Create(&Config{BaseURL: server.URL}) + resp, err := client.Get("/test-failure").Send(context.Background()) // This endpoint sets status 500 + assert.NoError(t, err) + + assert.False(t, resp.IsSuccess()) +} + +func TestResponseAfterRedirect(t *testing.T) { + server := startTestHTTPServer() + + defer server.Close() + + client := Create(&Config{BaseURL: server.URL}) + resp, err := client.Get("/test-redirect").Send(context.Background()) + assert.NoError(t, err) + + bodyStr := resp.String() + expectedContent := "Redirected\n" + assert.Contains(t, bodyStr, expectedContent, "The response content should be 'Redirected'") +} + +func TestResponseBodyAndString(t *testing.T) { + server := startTestHTTPServer() + + defer server.Close() + + client := Create(&Config{BaseURL: server.URL}) + resp, err := client.Get("/test-body").Send(context.Background()) + assert.NoError(t, err) + + bodyStr := resp.String() + assert.Contains(t, bodyStr, "This is the response body.") + + bodyBytes := resp.Body() + assert.Contains(t, string(bodyBytes), "This is the response body.") +} + +func TestResponseScanJSON(t *testing.T) { + type jsonTestResponse struct { + Message string `json:"message"` + Status bool `json:"status"` + } + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set(HeaderContentType, ContentTypeJSON) + fmt.Fprintln(w, `{"message": "Call your buddy JSON", "status": true}`) + })) + + defer server.Close() + + client := Create(&Config{BaseURL: server.URL}) + resp, err := client.Get("/test-json").Send(context.Background()) + assert.NoError(t, err) + + var jsonResponse jsonTestResponse + err = resp.Scan(&jsonResponse) + + assert.NoError(t, err) + assert.Equal(t, "Call your buddy JSON", jsonResponse.Message) + assert.True(t, jsonResponse.Status) +} + +func TestResponseScanXML(t *testing.T) { + type xmlTestResponse struct { + XMLName xml.Name `xml:"Response"` + Message string `xml:"Message"` + Status bool `xml:"Status"` + } + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set(HeaderContentType, ContentTypeXML) + fmt.Fprintln(w, `XML is terribletrue`) + })) + + defer server.Close() + + client := Create(&Config{BaseURL: server.URL}) + resp, err := client.Get("/test-xml").Send(context.Background()) + assert.NoError(t, err) + + var xmlResponse xmlTestResponse + err = resp.Scan(&xmlResponse) + + assert.NoError(t, err) + assert.Equal(t, "XML is terrible", xmlResponse.Message) + assert.True(t, xmlResponse.Status) +} + +func TestResponseScanYAML(t *testing.T) { + type yamlTestResponse struct { + Message string `yaml:"message"` + Status bool `yaml:"status"` + } + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + yml := `--- +message: My YAML is better than your YAML +status: true +` + + w.Header().Set(HeaderContentType, ContentTypeYAML) + fmt.Fprint(w, yml) + })) + + defer server.Close() + + client := Create(&Config{BaseURL: server.URL}) + resp, err := client.Get("/test-yaml").Send(context.Background()) + assert.NoError(t, err) + + var yamlResponse yamlTestResponse + err = resp.Scan(&yamlResponse) + + assert.NoError(t, err) + assert.Equal(t, "My YAML is better than your YAML", yamlResponse.Message) + assert.True(t, yamlResponse.Status) +} + +func TestResponseScanUnsupportedContentType(t *testing.T) { + server := startTestHTTPServer() + + defer server.Close() + + client := Create(&Config{BaseURL: server.URL}) + resp, err := client.Get("/test-pdf").Send(context.Background()) + assert.NoError(t, err) + + var dummyResponse struct{} + err = resp.Scan(&dummyResponse) + + assert.Error(t, err, "expected an error for unsupported content type") + assert.ErrorIs(t, err, ErrUnsupportedContentType) +} + +func TestResponseClose(t *testing.T) { + server := startTestHTTPServer() + + defer server.Close() + + client := Create(&Config{BaseURL: server.URL}) + resp, err := client.Get("/test-get").Send(context.Background()) + assert.NoError(t, err) + + err = resp.Close() + assert.NoError(t, err, "expected no error when closing the response") +} + +func TestResponseURL(t *testing.T) { + server := startTestHTTPServer() + defer server.Close() + + tests := []struct { + name string + path string // Path to append to the base URL + expected string // Expected final URL (for comparison) + }{ + { + name: "Base URL", + path: "", + expected: server.URL, + }, + { + name: "Path Parameter", + path: "/path-param", + expected: server.URL + "/path-param", + }, + { + name: "Query Parameter", + path: "/query?param=value", + expected: server.URL + "/query?param=value", + }, + { + name: "Hash Fragment", + path: "/hash#fragment", + expected: server.URL + "/hash#fragment", + }, + { + name: "Complex URL", + path: "/complex/path?param=value#fragment", + expected: server.URL + "/complex/path?param=value#fragment", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + client := Create(&Config{BaseURL: server.URL}) + resp, err := client.Get(tc.path).Send(context.Background()) + assert.NoError(t, err) + + expectedURL, _ := url.Parse(tc.expected) + + assert.Equal(t, expectedURL.String(), resp.URL().String(), "The response URL should match the expected URL") + }) + } +} + +func TestResponseSaveToFile(t *testing.T) { + // Setup a test server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, "Where is the money, Lebowski") + })) + + defer server.Close() + + // Create client and send request + client := Create(&Config{BaseURL: server.URL}) + + resp, err := client.Get("/").Send(context.Background()) + if err != nil { + t.Fatalf("Failed to send request: %v", err) + } + + // Define the path where to save the response body + filePath := "./testdata/sample_response.txt" + + err = resp.Save(filePath) + if err != nil { + t.Fatalf("Failed to save response to file: %v", err) + } + + // Read the saved file + savedData, err := os.ReadFile(filePath) + if err != nil { + t.Fatalf("Failed to read saved file: %v", err) + } + + // Verify the file content + expected := "Where is the money, Lebowski" + if string(savedData) != expected { + t.Errorf("Expected file content %q, got %q", expected, string(savedData)) + } + + // Clean up the saved file + err = os.Remove(filePath) + if err != nil { + t.Fatalf("Failed to remove saved file: %v", err) + } +} + +func TestResponseSaveToWriter(t *testing.T) { + // Setup a test server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, "You know nothing, John Snow") + })) + defer server.Close() + + // Create client and send request + client := Create(&Config{BaseURL: server.URL}) + + resp, err := client.Get("/").Send(context.Background()) + if err != nil { + t.Fatalf("Failed to send request: %v", err) + } + + // Use bytes.Buffer as the writer + var buffer bytes.Buffer + + err = resp.Save(&buffer) + if err != nil { + t.Fatalf("Failed to save response to buffer: %v", err) + } + + // Verify the buffer content + expected := "You know nothing, John Snow" + if buffer.String() != expected { + t.Errorf("Expected buffer content %q, got %q", expected, buffer.String()) + } +} + +func TestStream(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set(HeaderContentType, "text/event-stream") + + for i := 0; i < 3; i++ { + fmt.Fprintf(w, "data: Message %d\n", i) + w.(http.Flusher).Flush() + time.Sleep(100 * time.Millisecond) + } + })) + + defer server.Close() + + doneCh := make(chan struct{}) + dataReceived := make([]string, 0) + + client := Create(&Config{BaseURL: server.URL}) + _, err := client.Get("/").Stream(func(data []byte) error { + dataReceived = append(dataReceived, string(data)) + + assert.Contains(t, string(data), "data: Message") + return nil + }).StreamErr(func(err error) { + assert.NoError(t, err) + }).StreamDone(func() { + assert.Equal(t, 3, len(dataReceived)) + close(doneCh) + }).Send(context.Background()) + assert.NoError(t, err) + + time.Sleep(1 * time.Second) + <-doneCh + + assert.Equal(t, 3, len(dataReceived)) +} diff --git a/retry.go b/retry.go new file mode 100644 index 0000000..ffbc26b --- /dev/null +++ b/retry.go @@ -0,0 +1,55 @@ +package httpsling + +import ( + "math" + "net/http" + "time" +) + +// RetryConfig defines the configuration for retrying requests +type RetryConfig struct { + // MaxRetries is the maximum number of retry attempts + MaxRetries int + // Strategy is the backoff strategy function + Strategy BackoffStrategy + // RetryIf is the custom retry condition function + RetryIf RetryIfFunc +} + +// RetryIfFunc defines the function signature for retry conditions +type RetryIfFunc func(req *http.Request, resp *http.Response, err error) bool + +// BackoffStrategy defines a function that returns the delay before the next retry +type BackoffStrategy func(attempt int) time.Duration + +// DefaultBackoffStrategy provides a simple constant delay between retries +func DefaultBackoffStrategy(delay time.Duration) func(int) time.Duration { + return func(attempt int) time.Duration { + return delay + } +} + +// LinearBackoffStrategy increases the delay linearly with each retry attempt +func LinearBackoffStrategy(initialInterval time.Duration) func(int) time.Duration { + return func(attempt int) time.Duration { + return initialInterval * time.Duration(attempt+1) + } +} + +// ExponentialBackoffStrategy increases the delay exponentially with each retry attempt +func ExponentialBackoffStrategy(initialInterval time.Duration, multiplier float64, maxBackoffTime time.Duration) func(int) time.Duration { + return func(attempt int) time.Duration { + delay := initialInterval * time.Duration(math.Pow(multiplier, float64(attempt))) + + if delay > maxBackoffTime { + return maxBackoffTime + } + + return delay + } +} + +// DefaultRetryIf is a simple retry condition that retries on 5xx status codes +func DefaultRetryIf(req *http.Request, resp *http.Response, err error) bool { + return resp.StatusCode >= http.StatusInternalServerError || err != nil +} diff --git a/sonar-project.properties b/sonar-project.properties new file mode 100644 index 0000000..b8bc80c --- /dev/null +++ b/sonar-project.properties @@ -0,0 +1,16 @@ +sonar.projectKey=theopenlane_httpsling +sonar.organization=theopenlane + +sonar.projectName=httpsling +sonar.projectVersion=1.0 + +sonar.sources=. + +sonar.exclusions=**/*_test.go,**/vendor/** +sonar.tests=. +sonar.test.inclusions=**/*_test.go +sonar.test.exclusions=**/vendor/** + +sonar.sourceEncoding=UTF-8 +sonar.go.coverage.reportPaths=coverage.out +sonar.externalIssuesReportPaths=results.txt \ No newline at end of file