diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..ab85d05 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,143 @@ +name: CI + +on: + push: + branches: [main] + pull_request: + branches: [main] + +env: + CARGO_TERM_COLOR: always + RUST_BACKTRACE: 1 + +jobs: + test: + name: Test + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Install Rust + uses: dtolnay/rust-toolchain@stable + + - name: Cache cargo registry + uses: actions/cache@v4 + with: + path: | + ~/.cargo/registry + ~/.cargo/git + target + key: ${{ runner.os }}-cargo-test-${{ hashFiles('**/Cargo.lock') }} + restore-keys: | + ${{ runner.os }}-cargo-test- + + - name: Run tests + run: cargo test --workspace + + - name: Run tests with all features + run: cargo test --workspace --all-features + + lint: + name: Lint + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Install Rust + uses: dtolnay/rust-toolchain@stable + with: + components: rustfmt, clippy + + - name: Cache cargo registry + uses: actions/cache@v4 + with: + path: | + ~/.cargo/registry + ~/.cargo/git + target + key: ${{ runner.os }}-cargo-lint-${{ hashFiles('**/Cargo.lock') }} + restore-keys: | + ${{ runner.os }}-cargo-lint- + + - name: Check formatting + run: cargo fmt --all -- --check + + - name: Run clippy + run: cargo clippy --workspace --all-features -- -D warnings + + build: + name: Build + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Install Rust + uses: dtolnay/rust-toolchain@stable + + - name: Cache cargo registry + uses: actions/cache@v4 + with: + path: | + ~/.cargo/registry + ~/.cargo/git + target + key: ${{ runner.os }}-cargo-build-${{ hashFiles('**/Cargo.lock') }} + restore-keys: | + ${{ runner.os }}-cargo-build- + + - name: Build + run: cargo build --workspace + + - name: Build with all features + run: cargo build --workspace --all-features + + - name: Build release + run: cargo build --workspace --release + + docs: + name: Documentation + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Install Rust + uses: dtolnay/rust-toolchain@stable + + - name: Cache cargo registry + uses: actions/cache@v4 + with: + path: | + ~/.cargo/registry + ~/.cargo/git + target + key: ${{ runner.os }}-cargo-docs-${{ hashFiles('**/Cargo.lock') }} + restore-keys: | + ${{ runner.os }}-cargo-docs- + + - name: Build documentation + run: cargo doc --workspace --all-features --no-deps + env: + RUSTDOCFLAGS: -D warnings + + msrv: + name: MSRV (1.75) + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Install Rust 1.75 + uses: dtolnay/rust-action@1.75 + + - name: Cache cargo registry + uses: actions/cache@v4 + with: + path: | + ~/.cargo/registry + ~/.cargo/git + target + key: ${{ runner.os }}-cargo-msrv-${{ hashFiles('**/Cargo.lock') }} + restore-keys: | + ${{ runner.os }}-cargo-msrv- + + - name: Check MSRV + run: cargo check --workspace --all-features diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..12639f0 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,89 @@ +# Changelog + +All notable changes to this project will be documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), +and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). + +## [Unreleased] + +## [0.1.2] - 2024-12-31 + +### Added +- `skip_paths` method for JwtLayer to exclude paths from JWT validation +- `docs_with_auth` method for Basic Auth protected Swagger UI +- `docs_with_auth_and_info` method for customized protected docs + +### Changed +- auth-api example now demonstrates protected docs with Basic Auth +- JWT middleware can now skip validation for public endpoints + +## [0.1.1] - 2024-12-31 + +### Added + +#### Phase 4: Ergonomics & v1.0 Preparation +- Body size limit middleware with configurable limits +- `.body_limit(size)` builder method on RustApi (default: 1MB) +- 413 Payload Too Large response for oversized requests +- Production error masking (`RUSTAPI_ENV=production`) +- Development error details (`RUSTAPI_ENV=development`) +- Unique error IDs (`err_{uuid}`) for log correlation +- Enhanced tracing layer with request_id, status, and duration +- Custom span field support via `.with_field(key, value)` +- Prometheus metrics middleware (feature-gated) +- `http_requests_total` counter with method, path, status labels +- `http_request_duration_seconds` histogram +- `rustapi_info` gauge with version information +- `/metrics` endpoint handler +- TestClient for integration testing without network binding +- TestRequest builder with method, header, and body support +- TestResponse with assertion helpers +- `RUSTAPI_DEBUG=1` macro expansion output support +- Improved route path validation at compile time +- Enhanced route conflict detection messages + +### Changed +- Error responses now include `error_id` field +- TracingLayer enhanced with additional span fields + +## [0.1.0] - 2024-12-01 + +### Added + +#### Phase 1: MVP Core +- Core HTTP server built on tokio and hyper 1.0 +- Radix-tree based routing with matchit +- Request extractors: `Json`, `Query`, `Path` +- Response types with automatic serialization +- Async handler support +- Basic error handling with `ApiError` +- `#[rustapi::get]`, `#[rustapi::post]` route macros +- `#[rustapi::main]` async main macro + +#### Phase 2: Validation & OpenAPI +- Automatic OpenAPI spec generation +- Swagger UI at `/docs` endpoint +- Request validation with validator crate +- `#[validate]` attribute support +- 422 Unprocessable Entity for validation errors +- `#[rustapi::tag]` and `#[rustapi::summary]` macros +- Schema derivation for request/response types + +#### Phase 3: Batteries Included +- JWT authentication middleware (`jwt` feature) +- `AuthUser` extractor for authenticated routes +- CORS middleware with builder pattern (`cors` feature) +- IP-based rate limiting (`rate-limit` feature) +- Configuration management with `.env` support (`config` feature) +- Cookie parsing extractor (`cookies` feature) +- SQLx error conversion (`sqlx` feature) +- Request ID middleware +- Middleware layer trait for custom middleware +- `extras` meta-feature for common optional features +- `full` feature for all optional features + +[Unreleased]: https://github.com/Tuntii/RustAPI/compare/v0.1.2...HEAD +[0.1.2]: https://github.com/Tuntii/RustAPI/compare/v0.1.1...v0.1.2 +[0.1.1]: https://github.com/Tuntii/RustAPI/compare/v0.1.0...v0.1.1 +[0.1.0]: https://github.com/Tuntii/RustAPI/releases/tag/v0.1.0 diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..a4acfa3 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,141 @@ +# Contributing to RustAPI + +Thank you for your interest in contributing to RustAPI! This document provides guidelines and information for contributors. + +## Code of Conduct + +By participating in this project, you agree to maintain a respectful and inclusive environment for everyone. + +## Getting Started + +1. Fork the repository +2. Clone your fork: `git clone https://github.com/Tuntii/RustAPI.git` +3. Create a new branch: `git checkout -b feature/your-feature-name` +4. Make your changes +5. Run tests: `cargo test --workspace` +6. Submit a pull request + +## Development Setup + +### Prerequisites + +- Rust 1.75 or later +- Cargo (comes with Rust) + +### Building + +```bash +# Build all crates +cargo build --workspace + +# Build with all features +cargo build --workspace --all-features +``` + +### Running Tests + +```bash +# Run all tests +cargo test --workspace + +# Run tests with all features +cargo test --workspace --all-features + +# Run a specific crate's tests +cargo test -p rustapi-core +``` + +## Code Style + +### Formatting + +All code must be formatted with `rustfmt`: + +```bash +cargo fmt --all +``` + +### Linting + +All code must pass `clippy` checks: + +```bash +cargo clippy --workspace --all-features -- -D warnings +``` + +### Documentation + +- All public APIs must have rustdoc documentation +- Include code examples in doc comments where appropriate +- Doc examples must compile and run + +## Pull Request Process + +1. **Create a descriptive PR title** following conventional commits: + - `feat:` for new features + - `fix:` for bug fixes + - `docs:` for documentation changes + - `refactor:` for code refactoring + - `test:` for test additions/changes + - `chore:` for maintenance tasks + +2. **Fill out the PR template** with: + - Description of changes + - Related issue numbers + - Testing performed + +3. **Ensure all checks pass**: + - All tests pass + - Code is formatted (`cargo fmt`) + - No clippy warnings (`cargo clippy`) + - Documentation builds + +4. **Request review** from maintainers + +5. **Address feedback** promptly and push updates + +## Commit Guidelines + +- Write clear, concise commit messages +- Use present tense ("Add feature" not "Added feature") +- Reference issues when applicable (`Fixes #123`) + +## Project Structure + +``` +RustAPI/ +├── crates/ +│ ├── rustapi-rs/ # Public-facing crate (re-exports) +│ ├── rustapi-core/ # Core HTTP engine and routing +│ ├── rustapi-macros/ # Procedural macros +│ ├── rustapi-validate/ # Validation integration +│ ├── rustapi-openapi/ # OpenAPI/Swagger support +│ └── rustapi-extras/ # Optional features (JWT, CORS, etc.) +├── examples/ # Example applications +├── benches/ # Benchmarks +└── scripts/ # Build and publish scripts +``` + +## Adding New Features + +1. Discuss the feature in an issue first +2. Follow the existing architecture patterns +3. Add tests for new functionality +4. Update documentation +5. Add examples if applicable + +## Reporting Issues + +When reporting issues, please include: + +- Rust version (`rustc --version`) +- RustAPI version +- Minimal reproduction code +- Expected vs actual behavior +- Error messages (if any) + +## Questions? + +Feel free to open an issue for questions or join discussions in existing issues. + +Thank you for contributing to RustAPI! diff --git a/Cargo.lock b/Cargo.lock index 7b01f4b..3980581 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -66,6 +66,17 @@ version = "1.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" +[[package]] +name = "auth-api" +version = "0.1.0" +dependencies = [ + "rustapi-rs", + "serde", + "tokio", + "utoipa", + "validator", +] + [[package]] name = "autocfg" version = "1.5.0" @@ -289,6 +300,17 @@ version = "0.8.21" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" +[[package]] +name = "crud-api" +version = "0.1.0" +dependencies = [ + "rustapi-rs", + "serde", + "tokio", + "utoipa", + "validator", +] + [[package]] name = "crypto-common" version = "0.1.7" @@ -1435,6 +1457,21 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "prometheus" +version = "0.13.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d33c28a30771f7f96db69893f78b857f7450d7e0237e9c8fc6427a81bae7ed1" +dependencies = [ + "cfg-if", + "fnv", + "lazy_static", + "memchr", + "parking_lot", + "protobuf", + "thiserror 1.0.69", +] + [[package]] name = "proptest" version = "1.9.0" @@ -1454,6 +1491,12 @@ dependencies = [ "unarray", ] +[[package]] +name = "protobuf" +version = "2.28.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "106dd99e98437432fed6519dedecfade6a06a73bb7b2a1e019fdd2bee5778d94" + [[package]] name = "quick-error" version = "1.2.3" @@ -1626,8 +1669,9 @@ dependencies = [ [[package]] name = "rustapi-core" -version = "0.1.1" +version = "0.1.2" dependencies = [ + "base64 0.22.1", "bytes", "cookie", "futures-util", @@ -1638,6 +1682,7 @@ dependencies = [ "inventory", "matchit", "pin-project-lite", + "prometheus", "proptest", "rustapi-openapi", "rustapi-validate", @@ -1652,11 +1697,12 @@ dependencies = [ "tower-service", "tracing", "tracing-subscriber", + "uuid", ] [[package]] name = "rustapi-extras" -version = "0.1.1" +version = "0.1.2" dependencies = [ "bytes", "cookie", @@ -1669,6 +1715,7 @@ dependencies = [ "jsonwebtoken", "proptest", "rustapi-core", + "rustapi-openapi", "serde", "serde_json", "sqlx", @@ -1679,7 +1726,7 @@ dependencies = [ [[package]] name = "rustapi-macros" -version = "0.1.1" +version = "0.1.2" dependencies = [ "proc-macro2", "quote", @@ -1688,7 +1735,7 @@ dependencies = [ [[package]] name = "rustapi-openapi" -version = "0.1.1" +version = "0.1.2" dependencies = [ "bytes", "http", @@ -1700,7 +1747,7 @@ dependencies = [ [[package]] name = "rustapi-rs" -version = "0.1.1" +version = "0.1.2" dependencies = [ "rustapi-core", "rustapi-extras", @@ -1715,7 +1762,7 @@ dependencies = [ [[package]] name = "rustapi-validate" -version = "0.1.1" +version = "0.1.2" dependencies = [ "http", "serde", diff --git a/Cargo.toml b/Cargo.toml index c0c81c8..72017c6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,10 +9,12 @@ members = [ "crates/rustapi-extras", "examples/hello-world", "examples/sqlx-crud", + "examples/crud-api", + "examples/auth-api", ] [workspace.package] -version = "0.1.1" +version = "0.1.2" edition = "2021" authors = ["RustAPI Contributors"] license = "MIT OR Apache-2.0" @@ -60,12 +62,18 @@ inventory = "0.3" # Validation validator = { version = "0.18", features = ["derive"] } +# UUID +uuid = { version = "1.6", features = ["v4"] } + +# Metrics +prometheus = "0.13" + # OpenAPI utoipa = { version = "4.2" } # Internal crates -rustapi-core = { path = "crates/rustapi-core", version = "0.1.1" } -rustapi-macros = { path = "crates/rustapi-macros", version = "0.1.1" } -rustapi-validate = { path = "crates/rustapi-validate", version = "0.1.1" } -rustapi-openapi = { path = "crates/rustapi-openapi", version = "0.1.1" } -rustapi-extras = { path = "crates/rustapi-extras", version = "0.1.1" } +rustapi-core = { path = "crates/rustapi-core", version = "0.1.2" } +rustapi-macros = { path = "crates/rustapi-macros", version = "0.1.2" } +rustapi-validate = { path = "crates/rustapi-validate", version = "0.1.2" } +rustapi-openapi = { path = "crates/rustapi-openapi", version = "0.1.2" } +rustapi-extras = { path = "crates/rustapi-extras", version = "0.1.2" } diff --git a/LICENSE-APACHE b/LICENSE-APACHE new file mode 100644 index 0000000..7b413fa --- /dev/null +++ b/LICENSE-APACHE @@ -0,0 +1,190 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + +TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + +1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to the Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + +2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + +3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + +4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + +5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + +6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + +7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + +8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + +9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + +END OF TERMS AND CONDITIONS + +Copyright 2024 RustAPI Contributors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. diff --git a/LICENSE-MIT b/LICENSE-MIT new file mode 100644 index 0000000..4ab2d90 --- /dev/null +++ b/LICENSE-MIT @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2024 RustAPI Contributors + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/crates/rustapi-core/Cargo.toml b/crates/rustapi-core/Cargo.toml index b52eca0..9c93246 100644 --- a/crates/rustapi-core/Cargo.toml +++ b/crates/rustapi-core/Cargo.toml @@ -40,6 +40,8 @@ thiserror = { workspace = true } tracing = { workspace = true } tracing-subscriber = { workspace = true } inventory = { workspace = true } +uuid = { workspace = true } +base64 = "0.22" # Cookies (optional) cookie = { version = "0.18", optional = true } @@ -47,6 +49,9 @@ cookie = { version = "0.18", optional = true } # Validation rustapi-validate = { workspace = true } +# Metrics (optional) +prometheus = { workspace = true, optional = true } + # SQLx (optional) sqlx = { version = "0.8", optional = true, default-features = false } @@ -62,3 +67,4 @@ swagger-ui = ["rustapi-openapi/swagger-ui"] test-utils = [] cookies = ["dep:cookie"] sqlx = ["dep:sqlx"] +metrics = ["dep:prometheus"] diff --git a/crates/rustapi-core/proptest-regressions/error.txt b/crates/rustapi-core/proptest-regressions/error.txt new file mode 100644 index 0000000..8751fc6 --- /dev/null +++ b/crates/rustapi-core/proptest-regressions/error.txt @@ -0,0 +1,7 @@ +# Seeds for failure cases proptest has generated in the past. It is +# automatically read and these particular cases re-run before any +# novel cases are generated. +# +# It is recommended to check this file in to source control so that +# everyone who runs the test benefits from these saved cases. +cc 8650e551b5fa0957402d6dd6c864be825443b29eb5b8972df6ef70e9ee040b20 # shrinks to sensitive_message = "-", internal_details = "0", status_code = 500 diff --git a/crates/rustapi-core/proptest-regressions/path_validation.txt b/crates/rustapi-core/proptest-regressions/path_validation.txt new file mode 100644 index 0000000..df2d723 --- /dev/null +++ b/crates/rustapi-core/proptest-regressions/path_validation.txt @@ -0,0 +1,11 @@ +# Seeds for failure cases proptest has generated in the past. It is +# automatically read and these particular cases re-run before any +# novel cases are generated. +# +# It is recommended to check this file in to source control so that +# everyone who runs the test benefits from these saved cases. +cc 458c473729cae03b500e62a0348b15a6867e59812d9fc17c7628d2cd572019e8 # shrinks to prefix = "/", suffix = "" +cc ca5cfb24132bd9cfe692bdfb5cf121ea311b2c1462bd7afa8017a61d457636b2 # shrinks to prefix = "//", digit = "0", rest = "" +cc ba59f10f0210fd4599979fbdddd45e811cf8f38e1196156b3c2ff2bcee1f804a # shrinks to segments = [], params = ["a"] +cc 84630321a32cc432deb568fcc306a124f4a3be76c0ca95fda8d96d1b538e1ad5 # shrinks to prefix = "//", param_start = "a" +cc 1a33534f1eaf73d5c4f32d498ad9071bddbdb6814e602c79c7fedcd6d89071da # shrinks to prefix = "/", suffix = "" diff --git a/crates/rustapi-core/src/app.rs b/crates/rustapi-core/src/app.rs index 06c69d4..ccdf8f8 100644 --- a/crates/rustapi-core/src/app.rs +++ b/crates/rustapi-core/src/app.rs @@ -1,7 +1,7 @@ //! RustApi application builder use crate::error::Result; -use crate::middleware::{LayerStack, MiddlewareLayer}; +use crate::middleware::{BodyLimitLayer, LayerStack, MiddlewareLayer, DEFAULT_BODY_LIMIT}; use crate::router::{MethodRouter, Router}; use crate::server::Server; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter}; @@ -27,6 +27,7 @@ pub struct RustApi { router: Router, openapi_spec: rustapi_openapi::OpenApiSpec, layers: LayerStack, + body_limit: Option, } impl RustApi { @@ -48,9 +49,52 @@ impl RustApi { .register::() .register::(), layers: LayerStack::new(), + body_limit: Some(DEFAULT_BODY_LIMIT), // Default 1MB limit } } + /// Set the global body size limit for request bodies + /// + /// This protects against denial-of-service attacks via large payloads. + /// The default limit is 1MB (1024 * 1024 bytes). + /// + /// # Arguments + /// + /// * `limit` - Maximum body size in bytes + /// + /// # Example + /// + /// ```rust,ignore + /// use rustapi_rs::prelude::*; + /// + /// RustApi::new() + /// .body_limit(5 * 1024 * 1024) // 5MB limit + /// .route("/upload", post(upload_handler)) + /// .run("127.0.0.1:8080") + /// .await + /// ``` + pub fn body_limit(mut self, limit: usize) -> Self { + self.body_limit = Some(limit); + self + } + + /// Disable the body size limit + /// + /// Warning: This removes protection against large payload attacks. + /// Only use this if you have other mechanisms to limit request sizes. + /// + /// # Example + /// + /// ```rust,ignore + /// RustApi::new() + /// .no_body_limit() // Disable body size limit + /// .route("/upload", post(upload_handler)) + /// ``` + pub fn no_body_limit(mut self) -> Self { + self.body_limit = None; + self + } + /// Add a middleware layer to the application /// /// Layers are executed in the order they are added (outermost first). @@ -157,7 +201,7 @@ impl RustApi { self.route(path, method_router) } - /// Mount a route created with #[rustapi::get], #[rustapi::post], etc. + /// Mount a route created with `#[rustapi::get]`, `#[rustapi::post]`, etc. /// /// # Example /// @@ -253,12 +297,12 @@ impl RustApi { /// /// # Example /// - /// ```rust,ignore - // / RustApi::new() - // / .route("/users", get(list_users)) - // / .docs("/docs") // Swagger UI at /docs, spec at /docs/openapi.json - // / .run("127.0.0.1:8080") - // / .await + /// ```text + /// RustApi::new() + /// .route("/users", get(list_users)) + /// .docs("/docs") // Swagger UI at /docs, spec at /docs/openapi.json + /// .run("127.0.0.1:8080") + /// .await /// ``` #[cfg(feature = "swagger-ui")] pub fn docs(self, path: &str) -> Self { @@ -326,6 +370,129 @@ impl RustApi { .route(path, get(docs_handler)) } + /// Enable Swagger UI documentation with Basic Auth protection + /// + /// When username and password are provided, the docs endpoint will require + /// Basic Authentication. This is useful for protecting API documentation + /// in production environments. + /// + /// # Example + /// + /// ```rust,ignore + /// RustApi::new() + /// .route("/users", get(list_users)) + /// .docs_with_auth("/docs", "admin", "secret123") + /// .run("127.0.0.1:8080") + /// .await + /// ``` + #[cfg(feature = "swagger-ui")] + pub fn docs_with_auth(self, path: &str, username: &str, password: &str) -> Self { + let title = self.openapi_spec.info.title.clone(); + let version = self.openapi_spec.info.version.clone(); + let description = self.openapi_spec.info.description.clone(); + + self.docs_with_auth_and_info( + path, + username, + password, + &title, + &version, + description.as_deref(), + ) + } + + /// Enable Swagger UI documentation with Basic Auth and custom API info + /// + /// # Example + /// + /// ```rust,ignore + /// RustApi::new() + /// .docs_with_auth_and_info( + /// "/docs", + /// "admin", + /// "secret", + /// "My API", + /// "2.0.0", + /// Some("Protected API documentation") + /// ) + /// ``` + #[cfg(feature = "swagger-ui")] + pub fn docs_with_auth_and_info( + mut self, + path: &str, + username: &str, + password: &str, + title: &str, + version: &str, + description: Option<&str>, + ) -> Self { + use crate::router::MethodRouter; + use base64::{engine::general_purpose::STANDARD, Engine}; + use std::collections::HashMap; + + // Update spec info + self.openapi_spec.info.title = title.to_string(); + self.openapi_spec.info.version = version.to_string(); + if let Some(desc) = description { + self.openapi_spec.info.description = Some(desc.to_string()); + } + + let path = path.trim_end_matches('/'); + let openapi_path = format!("{}/openapi.json", path); + + // Create expected auth header value + let credentials = format!("{}:{}", username, password); + let encoded = STANDARD.encode(credentials.as_bytes()); + let expected_auth = format!("Basic {}", encoded); + + // Clone values for closures + let spec_json = + serde_json::to_string_pretty(&self.openapi_spec.to_json()).unwrap_or_default(); + let openapi_url = openapi_path.clone(); + let expected_auth_spec = expected_auth.clone(); + let expected_auth_docs = expected_auth; + + // Create spec handler with auth check + let spec_handler: crate::handler::BoxedHandler = std::sync::Arc::new(move |req: crate::Request| { + let json = spec_json.clone(); + let expected = expected_auth_spec.clone(); + Box::pin(async move { + if !check_basic_auth(&req, &expected) { + return unauthorized_response(); + } + http::Response::builder() + .status(http::StatusCode::OK) + .header(http::header::CONTENT_TYPE, "application/json") + .body(http_body_util::Full::new(bytes::Bytes::from(json))) + .unwrap() + }) as std::pin::Pin + Send>> + }); + + // Create docs handler with auth check + let docs_handler: crate::handler::BoxedHandler = std::sync::Arc::new(move |req: crate::Request| { + let url = openapi_url.clone(); + let expected = expected_auth_docs.clone(); + Box::pin(async move { + if !check_basic_auth(&req, &expected) { + return unauthorized_response(); + } + rustapi_openapi::swagger_ui_html(&url) + }) as std::pin::Pin + Send>> + }); + + // Create method routers with boxed handlers + let mut spec_handlers = HashMap::new(); + spec_handlers.insert(http::Method::GET, spec_handler); + let spec_router = MethodRouter::from_boxed(spec_handlers); + + let mut docs_handlers = HashMap::new(); + docs_handlers.insert(http::Method::GET, docs_handler); + let docs_router = MethodRouter::from_boxed(docs_handlers); + + self.route(&openapi_path, spec_router) + .route(path, docs_router) + } + /// Run the server /// /// # Example @@ -336,7 +503,13 @@ impl RustApi { /// .run("127.0.0.1:8080") /// .await /// ``` - pub async fn run(self, addr: &str) -> Result<(), Box> { + pub async fn run(mut self, addr: &str) -> Result<(), Box> { + // Apply body limit layer if configured (should be first in the chain) + if let Some(limit) = self.body_limit { + // Prepend body limit layer so it's the first to process requests + self.layers.prepend(Box::new(BodyLimitLayer::new(limit))); + } + let server = Server::new(self.router, self.layers); server.run(addr).await } @@ -357,3 +530,24 @@ impl Default for RustApi { Self::new() } } + +/// Check Basic Auth header against expected credentials +#[cfg(feature = "swagger-ui")] +fn check_basic_auth(req: &crate::Request, expected: &str) -> bool { + req.headers() + .get(http::header::AUTHORIZATION) + .and_then(|v| v.to_str().ok()) + .map(|auth| auth == expected) + .unwrap_or(false) +} + +/// Create 401 Unauthorized response with WWW-Authenticate header +#[cfg(feature = "swagger-ui")] +fn unauthorized_response() -> crate::Response { + http::Response::builder() + .status(http::StatusCode::UNAUTHORIZED) + .header(http::header::WWW_AUTHENTICATE, "Basic realm=\"API Documentation\"") + .header(http::header::CONTENT_TYPE, "text/plain") + .body(http_body_util::Full::new(bytes::Bytes::from("Unauthorized"))) + .unwrap() +} diff --git a/crates/rustapi-core/src/error.rs b/crates/rustapi-core/src/error.rs index f95c69f..49d0566 100644 --- a/crates/rustapi-core/src/error.rs +++ b/crates/rustapi-core/src/error.rs @@ -1,15 +1,206 @@ //! Error types for RustAPI +//! +//! This module provides structured error handling with environment-aware +//! error masking for production safety. +//! +//! # Error Response Format +//! +//! All errors are returned as JSON with a consistent structure: +//! +//! ```json +//! { +//! "error": { +//! "type": "not_found", +//! "message": "User not found", +//! "fields": null +//! }, +//! "error_id": "err_a1b2c3d4e5f6" +//! } +//! ``` +//! +//! # Environment-Aware Error Masking +//! +//! In production mode (`RUSTAPI_ENV=production`), internal server errors (5xx) +//! are masked to prevent information leakage: +//! +//! - **Production**: Generic "An internal error occurred" message +//! - **Development**: Full error details for debugging +//! +//! Validation errors always include field details regardless of environment. +//! +//! # Example +//! +//! ```rust,ignore +//! use rustapi_core::{ApiError, Result}; +//! use http::StatusCode; +//! +//! async fn get_user(id: i64) -> Result> { +//! let user = db.find_user(id) +//! .ok_or_else(|| ApiError::not_found("User not found"))?; +//! Ok(Json(user)) +//! } +//! +//! // Create custom errors +//! let error = ApiError::new(StatusCode::CONFLICT, "duplicate", "Email already exists"); +//! +//! // Convenience constructors +//! let bad_request = ApiError::bad_request("Invalid input"); +//! let unauthorized = ApiError::unauthorized("Invalid token"); +//! let forbidden = ApiError::forbidden("Access denied"); +//! let not_found = ApiError::not_found("Resource not found"); +//! let internal = ApiError::internal("Something went wrong"); +//! ``` +//! +//! # Error ID Correlation +//! +//! Every error response includes a unique `error_id` (format: `err_{uuid}`) that +//! appears in both the response and server logs, enabling easy correlation for +//! debugging. use http::StatusCode; use serde::Serialize; use std::fmt; +use std::sync::OnceLock; +use uuid::Uuid; /// Result type alias for RustAPI operations pub type Result = std::result::Result; +/// Environment configuration for error handling behavior +/// +/// Controls whether internal error details are exposed in API responses. +/// In production, internal details are masked to prevent information leakage. +/// In development, full error details are shown for debugging. +/// +/// # Example +/// +/// ``` +/// use rustapi_core::Environment; +/// +/// let dev = Environment::Development; +/// assert!(dev.is_development()); +/// assert!(!dev.is_production()); +/// +/// let prod = Environment::Production; +/// assert!(prod.is_production()); +/// assert!(!prod.is_development()); +/// ``` +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum Environment { + /// Development mode - shows full error details in responses + #[default] + Development, + /// Production mode - masks internal error details in responses + Production, +} + +impl Environment { + /// Detect environment from `RUSTAPI_ENV` environment variable + /// + /// Returns `Production` if `RUSTAPI_ENV` is set to "production" or "prod" (case-insensitive). + /// Returns `Development` for all other values or if the variable is not set. + /// + /// # Example + /// + /// ```bash + /// # Production mode + /// RUSTAPI_ENV=production cargo run + /// RUSTAPI_ENV=prod cargo run + /// + /// # Development mode (default) + /// RUSTAPI_ENV=development cargo run + /// cargo run # No env var set + /// ``` + pub fn from_env() -> Self { + match std::env::var("RUSTAPI_ENV") + .map(|s| s.to_lowercase()) + .as_deref() + { + Ok("production") | Ok("prod") => Environment::Production, + _ => Environment::Development, + } + } + + /// Check if this is production environment + pub fn is_production(&self) -> bool { + matches!(self, Environment::Production) + } + + /// Check if this is development environment + pub fn is_development(&self) -> bool { + matches!(self, Environment::Development) + } +} + +impl fmt::Display for Environment { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Environment::Development => write!(f, "development"), + Environment::Production => write!(f, "production"), + } + } +} + +/// Global environment setting, cached on first access +static ENVIRONMENT: OnceLock = OnceLock::new(); + +/// Get the current environment (cached) +/// +/// This function caches the environment on first call for performance. +/// The environment is detected from the `RUSTAPI_ENV` environment variable. +pub fn get_environment() -> Environment { + *ENVIRONMENT.get_or_init(Environment::from_env) +} + +/// Set the environment explicitly (for testing purposes) +/// +/// Note: This only works if the environment hasn't been accessed yet. +/// Returns `Ok(())` if successful, `Err(env)` if already set. +#[cfg(test)] +pub fn set_environment_for_test(env: Environment) -> Result<(), Environment> { + ENVIRONMENT.set(env) +} + +/// Generate a unique error ID using UUID v4 format +/// +/// Returns a string in the format `err_{uuid}` where uuid is a 32-character +/// hexadecimal string (UUID v4 simple format). +/// +/// # Example +/// +/// ```rust,ignore +/// use rustapi_core::error::generate_error_id; +/// +/// let id = generate_error_id(); +/// assert!(id.starts_with("err_")); +/// assert_eq!(id.len(), 36); // "err_" (4) + uuid (32) +/// ``` +pub fn generate_error_id() -> String { + format!("err_{}", Uuid::new_v4().simple()) +} + /// Standard API error type /// /// Provides structured error responses following a consistent JSON format. +/// +/// # Example +/// +/// ``` +/// use rustapi_core::ApiError; +/// use http::StatusCode; +/// +/// // Create a custom error +/// let error = ApiError::new(StatusCode::CONFLICT, "duplicate", "Email already exists"); +/// assert_eq!(error.status, StatusCode::CONFLICT); +/// assert_eq!(error.error_type, "duplicate"); +/// +/// // Use convenience constructors +/// let not_found = ApiError::not_found("User not found"); +/// assert_eq!(not_found.status, StatusCode::NOT_FOUND); +/// +/// let bad_request = ApiError::bad_request("Invalid input"); +/// assert_eq!(bad_request.status, StatusCode::BAD_REQUEST); +/// ``` #[derive(Debug, Clone)] pub struct ApiError { /// HTTP status code @@ -105,14 +296,16 @@ impl std::error::Error for ApiError {} /// JSON representation of API error response #[derive(Serialize)] -pub(crate) struct ErrorResponse { +pub struct ErrorResponse { pub error: ErrorBody, + /// Unique error ID for log correlation (format: err_{uuid}) + pub error_id: String, #[serde(skip_serializing_if = "Option::is_none")] pub request_id: Option, } #[derive(Serialize)] -pub(crate) struct ErrorBody { +pub struct ErrorBody { #[serde(rename = "type")] pub error_type: String, pub message: String, @@ -120,19 +313,87 @@ pub(crate) struct ErrorBody { pub fields: Option>, } -impl From for ErrorResponse { - fn from(err: ApiError) -> Self { +impl ErrorResponse { + /// Create an ErrorResponse from an ApiError with environment-aware masking + /// + /// In production mode: + /// - Internal server errors (5xx) show generic messages + /// - Validation errors always include field details + /// - Client errors (4xx) show their messages + /// + /// In development mode: + /// - All error details are shown + pub fn from_api_error(err: ApiError, env: Environment) -> Self { + let error_id = generate_error_id(); + + // Always log the full error details with error_id for correlation + if err.status.is_server_error() { + tracing::error!( + error_id = %error_id, + error_type = %err.error_type, + message = %err.message, + status = %err.status.as_u16(), + internal = ?err.internal, + environment = %env, + "Server error occurred" + ); + } else if err.status.is_client_error() { + tracing::warn!( + error_id = %error_id, + error_type = %err.error_type, + message = %err.message, + status = %err.status.as_u16(), + environment = %env, + "Client error occurred" + ); + } else { + tracing::info!( + error_id = %error_id, + error_type = %err.error_type, + message = %err.message, + status = %err.status.as_u16(), + environment = %env, + "Error response generated" + ); + } + + // Determine the message and fields based on environment and error type + let (message, fields) = if env.is_production() && err.status.is_server_error() { + // In production, mask internal server error details + // But preserve validation error fields (they're always shown per requirement 3.5) + let masked_message = "An internal error occurred".to_string(); + // Validation errors keep their fields even in production + let fields = if err.error_type == "validation_error" { + err.fields + } else { + None + }; + (masked_message, fields) + } else { + // In development or for non-5xx errors, show full details + (err.message, err.fields) + }; + Self { error: ErrorBody { error_type: err.error_type, - message: err.message, - fields: err.fields, + message, + fields, }, - request_id: None, // TODO: inject from request context + error_id, + request_id: None, } } } +impl From for ErrorResponse { + fn from(err: ApiError) -> Self { + // Use the cached environment + let env = get_environment(); + Self::from_api_error(err, env) + } +} + // Conversion from common error types impl From for ApiError { fn from(err: serde_json::Error) -> Self { @@ -289,3 +550,516 @@ impl From for ApiError { } } + + +#[cfg(test)] +mod tests { + use super::*; + use proptest::prelude::*; + use std::collections::HashSet; + + // **Feature: phase4-ergonomics-v1, Property 6: Error ID Uniqueness** + // + // For any sequence of N errors generated by the system, all N error IDs + // should be unique. The error ID should appear in both the HTTP response + // and the corresponding log entry. + // + // **Validates: Requirements 3.3** + proptest! { + #![proptest_config(ProptestConfig::with_cases(100))] + + #[test] + fn prop_error_id_uniqueness( + // Generate a random number of errors between 10 and 200 + num_errors in 10usize..200, + ) { + // Generate N error IDs + let error_ids: Vec = (0..num_errors) + .map(|_| generate_error_id()) + .collect(); + + // Collect into a HashSet to check uniqueness + let unique_ids: HashSet<&String> = error_ids.iter().collect(); + + // All IDs should be unique + prop_assert_eq!( + unique_ids.len(), + error_ids.len(), + "Generated {} error IDs but only {} were unique", + error_ids.len(), + unique_ids.len() + ); + + // All IDs should follow the format err_{uuid} + for id in &error_ids { + prop_assert!( + id.starts_with("err_"), + "Error ID '{}' does not start with 'err_'", + id + ); + + // The UUID part should be 32 hex characters (simple format) + let uuid_part = &id[4..]; + prop_assert_eq!( + uuid_part.len(), + 32, + "UUID part '{}' should be 32 characters, got {}", + uuid_part, + uuid_part.len() + ); + + // All characters should be valid hex + prop_assert!( + uuid_part.chars().all(|c| c.is_ascii_hexdigit()), + "UUID part '{}' contains non-hex characters", + uuid_part + ); + } + } + } + + // **Feature: phase4-ergonomics-v1, Property 6: Error ID in Response** + // + // For any ApiError converted to ErrorResponse, the error_id field should + // be present and follow the correct format. + // + // **Validates: Requirements 3.3** + proptest! { + #![proptest_config(ProptestConfig::with_cases(100))] + + #[test] + fn prop_error_response_contains_error_id( + error_type in "[a-z_]{1,20}", + message in "[a-zA-Z0-9 ]{1,100}", + ) { + let api_error = ApiError::new(StatusCode::INTERNAL_SERVER_ERROR, error_type, message); + let error_response = ErrorResponse::from(api_error); + + // error_id should be present and follow format + prop_assert!( + error_response.error_id.starts_with("err_"), + "Error ID '{}' does not start with 'err_'", + error_response.error_id + ); + + let uuid_part = &error_response.error_id[4..]; + prop_assert_eq!(uuid_part.len(), 32); + prop_assert!(uuid_part.chars().all(|c| c.is_ascii_hexdigit())); + } + } + + #[test] + fn test_error_id_format() { + let error_id = generate_error_id(); + + // Should start with "err_" + assert!(error_id.starts_with("err_")); + + // Total length should be 4 (prefix) + 32 (uuid simple format) = 36 + assert_eq!(error_id.len(), 36); + + // UUID part should be valid hex + let uuid_part = &error_id[4..]; + assert!(uuid_part.chars().all(|c| c.is_ascii_hexdigit())); + } + + #[test] + fn test_error_response_includes_error_id() { + let api_error = ApiError::bad_request("test error"); + let error_response = ErrorResponse::from(api_error); + + // error_id should be present + assert!(error_response.error_id.starts_with("err_")); + assert_eq!(error_response.error_id.len(), 36); + } + + #[test] + fn test_error_id_in_json_serialization() { + let api_error = ApiError::internal("test error"); + let error_response = ErrorResponse::from(api_error); + + let json = serde_json::to_string(&error_response).unwrap(); + + // JSON should contain error_id field + assert!(json.contains("\"error_id\":")); + assert!(json.contains("err_")); + } + + #[test] + fn test_multiple_error_ids_are_unique() { + let ids: Vec = (0..1000).map(|_| generate_error_id()).collect(); + let unique: HashSet<_> = ids.iter().collect(); + + assert_eq!(ids.len(), unique.len(), "All error IDs should be unique"); + } + + // **Feature: phase4-ergonomics-v1, Property 4: Production Error Masking** + // + // For any internal error (5xx) when RUSTAPI_ENV=production, the response body + // should contain only a generic error message and error ID, without stack traces, + // internal details, or sensitive information. + // + // **Validates: Requirements 3.1** + proptest! { + #![proptest_config(ProptestConfig::with_cases(100))] + + #[test] + fn prop_production_error_masking( + // Generate random error messages that could contain sensitive info + // Use longer strings to avoid false positives where short strings appear in masked message + sensitive_message in "[a-zA-Z0-9_]{10,200}", + internal_details in "[a-zA-Z0-9_]{10,200}", + // Generate random 5xx status codes + status_code in prop::sample::select(vec![500u16, 501, 502, 503, 504, 505]), + ) { + // Create an internal error with potentially sensitive details + let api_error = ApiError::new( + StatusCode::from_u16(status_code).unwrap(), + "internal_error", + sensitive_message.clone() + ).with_internal(internal_details.clone()); + + // Convert to ErrorResponse in production mode + let error_response = ErrorResponse::from_api_error(api_error, Environment::Production); + + // The message should be masked to a generic message + prop_assert_eq!( + &error_response.error.message, + "An internal error occurred", + "Production 5xx error should have masked message, got: {}", + &error_response.error.message + ); + + // The original sensitive message should NOT appear in the response + // (only check if the message is long enough to be meaningful) + if sensitive_message.len() >= 10 { + prop_assert!( + !error_response.error.message.contains(&sensitive_message), + "Production error response should not contain original message" + ); + } + + // Internal details should NOT appear anywhere in the serialized response + let json = serde_json::to_string(&error_response).unwrap(); + if internal_details.len() >= 10 { + prop_assert!( + !json.contains(&internal_details), + "Production error response should not contain internal details" + ); + } + + // Error ID should still be present + prop_assert!( + error_response.error_id.starts_with("err_"), + "Error ID should be present in production error response" + ); + } + } + + // **Feature: phase4-ergonomics-v1, Property 5: Development Error Details** + // + // For any error when RUSTAPI_ENV=development, the response body should contain + // detailed error information including the original error message and any + // available context. + // + // **Validates: Requirements 3.2** + proptest! { + #![proptest_config(ProptestConfig::with_cases(100))] + + #[test] + fn prop_development_error_details( + // Generate random error messages + error_message in "[a-zA-Z0-9 ]{1,100}", + error_type in "[a-z_]{1,20}", + // Generate random status codes (both 4xx and 5xx) + status_code in prop::sample::select(vec![400u16, 401, 403, 404, 500, 502, 503]), + ) { + // Create an error with details + let api_error = ApiError::new( + StatusCode::from_u16(status_code).unwrap(), + error_type.clone(), + error_message.clone() + ); + + // Convert to ErrorResponse in development mode + let error_response = ErrorResponse::from_api_error(api_error, Environment::Development); + + // The original message should be preserved + prop_assert_eq!( + error_response.error.message, + error_message, + "Development error should preserve original message" + ); + + // The error type should be preserved + prop_assert_eq!( + error_response.error.error_type, + error_type, + "Development error should preserve error type" + ); + + // Error ID should be present + prop_assert!( + error_response.error_id.starts_with("err_"), + "Error ID should be present in development error response" + ); + } + } + + // **Feature: phase4-ergonomics-v1, Property 7: Validation Error Field Details** + // + // For any validation error in any environment (production or development), + // the response should include field-level error details with field name, + // error code, and message. + // + // **Validates: Requirements 3.5** + proptest! { + #![proptest_config(ProptestConfig::with_cases(100))] + + #[test] + fn prop_validation_error_field_details( + // Generate random field errors + field_name in "[a-z_]{1,20}", + field_code in "[a-z_]{1,15}", + field_message in "[a-zA-Z0-9 ]{1,50}", + // Test in both environments + is_production in proptest::bool::ANY, + ) { + let env = if is_production { + Environment::Production + } else { + Environment::Development + }; + + // Create a validation error with field details + let field_error = FieldError { + field: field_name.clone(), + code: field_code.clone(), + message: field_message.clone(), + }; + let api_error = ApiError::validation(vec![field_error]); + + // Convert to ErrorResponse + let error_response = ErrorResponse::from_api_error(api_error, env); + + // Fields should always be present for validation errors + prop_assert!( + error_response.error.fields.is_some(), + "Validation error should always include fields in {} mode", + env + ); + + let fields = error_response.error.fields.as_ref().unwrap(); + prop_assert_eq!( + fields.len(), + 1, + "Should have exactly one field error" + ); + + let field = &fields[0]; + + // Field name should be preserved + prop_assert_eq!( + &field.field, + &field_name, + "Field name should be preserved in {} mode", + env + ); + + // Field code should be preserved + prop_assert_eq!( + &field.code, + &field_code, + "Field code should be preserved in {} mode", + env + ); + + // Field message should be preserved + prop_assert_eq!( + &field.message, + &field_message, + "Field message should be preserved in {} mode", + env + ); + + // Verify JSON serialization includes all field details + let json = serde_json::to_string(&error_response).unwrap(); + prop_assert!( + json.contains(&field_name), + "JSON should contain field name in {} mode", + env + ); + prop_assert!( + json.contains(&field_code), + "JSON should contain field code in {} mode", + env + ); + prop_assert!( + json.contains(&field_message), + "JSON should contain field message in {} mode", + env + ); + } + } + + // Unit tests for Environment enum + // Note: These tests verify the Environment::from_env() logic by testing the parsing + // directly rather than modifying global environment variables (which causes race conditions + // in parallel test execution). + + #[test] + fn test_environment_from_env_production() { + // Test the parsing logic directly by simulating what from_env() does + // This avoids race conditions with parallel tests + + // Test "production" variants + assert!(matches!( + match "production".to_lowercase().as_str() { + "production" | "prod" => Environment::Production, + _ => Environment::Development, + }, + Environment::Production + )); + + assert!(matches!( + match "prod".to_lowercase().as_str() { + "production" | "prod" => Environment::Production, + _ => Environment::Development, + }, + Environment::Production + )); + + assert!(matches!( + match "PRODUCTION".to_lowercase().as_str() { + "production" | "prod" => Environment::Production, + _ => Environment::Development, + }, + Environment::Production + )); + + assert!(matches!( + match "PROD".to_lowercase().as_str() { + "production" | "prod" => Environment::Production, + _ => Environment::Development, + }, + Environment::Production + )); + } + + #[test] + fn test_environment_from_env_development() { + // Test the parsing logic directly by simulating what from_env() does + // This avoids race conditions with parallel tests + + // Test "development" and other variants that should default to Development + assert!(matches!( + match "development".to_lowercase().as_str() { + "production" | "prod" => Environment::Production, + _ => Environment::Development, + }, + Environment::Development + )); + + assert!(matches!( + match "dev".to_lowercase().as_str() { + "production" | "prod" => Environment::Production, + _ => Environment::Development, + }, + Environment::Development + )); + + assert!(matches!( + match "test".to_lowercase().as_str() { + "production" | "prod" => Environment::Production, + _ => Environment::Development, + }, + Environment::Development + )); + + assert!(matches!( + match "anything_else".to_lowercase().as_str() { + "production" | "prod" => Environment::Production, + _ => Environment::Development, + }, + Environment::Development + )); + } + + #[test] + fn test_environment_default_is_development() { + // Test that the default is Development + assert_eq!(Environment::default(), Environment::Development); + } + + #[test] + fn test_environment_display() { + assert_eq!(format!("{}", Environment::Development), "development"); + assert_eq!(format!("{}", Environment::Production), "production"); + } + + #[test] + fn test_environment_is_methods() { + assert!(Environment::Production.is_production()); + assert!(!Environment::Production.is_development()); + assert!(Environment::Development.is_development()); + assert!(!Environment::Development.is_production()); + } + + #[test] + fn test_production_masks_5xx_errors() { + let error = ApiError::internal("Sensitive database connection string: postgres://user:pass@host"); + let response = ErrorResponse::from_api_error(error, Environment::Production); + + assert_eq!(response.error.message, "An internal error occurred"); + assert!(!response.error.message.contains("postgres")); + } + + #[test] + fn test_production_shows_4xx_errors() { + let error = ApiError::bad_request("Invalid email format"); + let response = ErrorResponse::from_api_error(error, Environment::Production); + + // 4xx errors should show their message even in production + assert_eq!(response.error.message, "Invalid email format"); + } + + #[test] + fn test_development_shows_all_errors() { + let error = ApiError::internal("Detailed error: connection refused to 192.168.1.1:5432"); + let response = ErrorResponse::from_api_error(error, Environment::Development); + + assert_eq!(response.error.message, "Detailed error: connection refused to 192.168.1.1:5432"); + } + + #[test] + fn test_validation_errors_always_show_fields() { + let fields = vec![ + FieldError { + field: "email".to_string(), + code: "invalid_format".to_string(), + message: "Invalid email format".to_string(), + }, + FieldError { + field: "age".to_string(), + code: "min".to_string(), + message: "Must be at least 18".to_string(), + }, + ]; + + let error = ApiError::validation(fields.clone()); + + // Test in production + let prod_response = ErrorResponse::from_api_error(error.clone(), Environment::Production); + assert!(prod_response.error.fields.is_some()); + let prod_fields = prod_response.error.fields.unwrap(); + assert_eq!(prod_fields.len(), 2); + assert_eq!(prod_fields[0].field, "email"); + assert_eq!(prod_fields[1].field, "age"); + + // Test in development + let dev_response = ErrorResponse::from_api_error(error, Environment::Development); + assert!(dev_response.error.fields.is_some()); + let dev_fields = dev_response.error.fields.unwrap(); + assert_eq!(dev_fields.len(), 2); + } +} diff --git a/crates/rustapi-core/src/extract.rs b/crates/rustapi-core/src/extract.rs index 3779cf5..f563264 100644 --- a/crates/rustapi-core/src/extract.rs +++ b/crates/rustapi-core/src/extract.rs @@ -1,6 +1,58 @@ //! Extractors for RustAPI //! -//! Extractors automatically parse and validate data from incoming requests. +//! Extractors automatically parse and validate data from incoming HTTP requests. +//! They implement the [`FromRequest`] or [`FromRequestParts`] traits and can be +//! used as handler function parameters. +//! +//! # Available Extractors +//! +//! | Extractor | Description | Consumes Body | +//! |-----------|-------------|---------------| +//! | [`Json`] | Parse JSON request body | Yes | +//! | [`ValidatedJson`] | Parse and validate JSON body | Yes | +//! | [`Query`] | Parse query string parameters | No | +//! | [`Path`] | Extract path parameters | No | +//! | [`State`] | Access shared application state | No | +//! | [`Body`] | Raw request body bytes | Yes | +//! | [`Headers`] | Access all request headers | No | +//! | [`HeaderValue`] | Extract a specific header | No | +//! | [`Extension`] | Access middleware-injected data | No | +//! | [`ClientIp`] | Extract client IP address | No | +//! | [`Cookies`] | Parse request cookies (requires `cookies` feature) | No | +//! +//! # Example +//! +//! ```rust,ignore +//! use rustapi_core::{Json, Query, Path, State}; +//! use serde::{Deserialize, Serialize}; +//! +//! #[derive(Deserialize)] +//! struct CreateUser { +//! name: String, +//! email: String, +//! } +//! +//! #[derive(Deserialize)] +//! struct Pagination { +//! page: Option, +//! limit: Option, +//! } +//! +//! // Multiple extractors can be combined +//! async fn create_user( +//! State(db): State, +//! Query(pagination): Query, +//! Json(body): Json, +//! ) -> impl IntoResponse { +//! // Use db, pagination, and body... +//! } +//! ``` +//! +//! # Extractor Order +//! +//! When using multiple extractors, body-consuming extractors (like `Json` or `Body`) +//! must come last since they consume the request body. Non-body extractors can be +//! in any order. use crate::error::{ApiError, Result}; use crate::request::Request; diff --git a/crates/rustapi-core/src/handler.rs b/crates/rustapi-core/src/handler.rs index a7cea12..03f76a3 100644 --- a/crates/rustapi-core/src/handler.rs +++ b/crates/rustapi-core/src/handler.rs @@ -1,4 +1,58 @@ //! Handler trait and utilities +//! +//! This module provides the [`Handler`] trait and related types for defining +//! HTTP request handlers in RustAPI. +//! +//! # Handler Functions +//! +//! Any async function that takes extractors as parameters and returns a type +//! implementing [`IntoResponse`] can be used as a handler: +//! +//! ```rust,ignore +//! use rustapi_core::{Json, Path, IntoResponse}; +//! use serde::{Deserialize, Serialize}; +//! +//! // No parameters +//! async fn hello() -> &'static str { +//! "Hello, World!" +//! } +//! +//! // With extractors +//! async fn get_user(Path(id): Path) -> Json { +//! Json(User { id, name: "Alice".to_string() }) +//! } +//! +//! // Multiple extractors (up to 5 supported) +//! async fn create_user( +//! State(db): State, +//! Json(body): Json, +//! ) -> Result, ApiError> { +//! // ... +//! } +//! ``` +//! +//! # Route Helpers +//! +//! The module provides helper functions for creating routes with metadata: +//! +//! ```rust,ignore +//! use rustapi_core::handler::{get_route, post_route}; +//! +//! let get = get_route("/users", list_users); +//! let post = post_route("/users", create_user); +//! ``` +//! +//! # Macro-Based Routing +//! +//! For more ergonomic routing, use the `#[rustapi::get]`, `#[rustapi::post]`, etc. +//! macros from `rustapi-macros`: +//! +//! ```rust,ignore +//! #[rustapi::get("/users/{id}")] +//! async fn get_user(Path(id): Path) -> Json { +//! // ... +//! } +//! ``` use crate::extract::FromRequest; use crate::request::Request; @@ -275,7 +329,7 @@ where }) } -/// Trait for handlers with route metadata (generated by #[rustapi::get], etc.) +/// Trait for handlers with route metadata (generated by `#[rustapi::get]`, etc.) /// /// This trait provides the path and method information for a handler, /// allowing `.mount(handler)` to automatically register the route. diff --git a/crates/rustapi-core/src/lib.rs b/crates/rustapi-core/src/lib.rs index 6d2069f..b4fa514 100644 --- a/crates/rustapi-core/src/lib.rs +++ b/crates/rustapi-core/src/lib.rs @@ -2,23 +2,70 @@ //! //! Core library providing the foundational types and traits for RustAPI. //! -//! This crate is not meant to be used directly. Use `rustapi-rs` instead. +//! This crate provides the essential building blocks for the RustAPI web framework: +//! +//! - **Application Builder**: [`RustApi`] - The main entry point for building web applications +//! - **Routing**: [`Router`], [`get`], [`post`], [`put`], [`patch`], [`delete`] - HTTP routing primitives +//! - **Extractors**: [`Json`], [`Query`], [`Path`], [`State`], [`Body`], [`Headers`] - Request data extraction +//! - **Responses**: [`IntoResponse`], [`Created`], [`NoContent`], [`Html`], [`Redirect`] - Response types +//! - **Middleware**: [`BodyLimitLayer`], [`RequestIdLayer`], [`TracingLayer`] - Request processing layers +//! - **Error Handling**: [`ApiError`], [`Result`] - Structured error responses +//! - **Testing**: `TestClient` - Integration testing without network binding (requires `test-utils` feature) +//! +//! ## Quick Start +//! +//! ```rust,ignore +//! use rustapi_core::{RustApi, get, Json}; +//! use serde::Serialize; +//! +//! #[derive(Serialize)] +//! struct Message { +//! text: String, +//! } +//! +//! async fn hello() -> Json { +//! Json(Message { text: "Hello, World!".to_string() }) +//! } +//! +//! #[tokio::main] +//! async fn main() -> Result<(), Box> { +//! RustApi::new() +//! .route("/", get(hello)) +//! .run("127.0.0.1:8080") +//! .await +//! } +//! ``` +//! +//! ## Feature Flags +//! +//! - `metrics` - Enable Prometheus metrics middleware +//! - `cookies` - Enable cookie parsing extractor +//! - `test-utils` - Enable testing utilities like `TestClient` +//! - `swagger-ui` - Enable Swagger UI documentation endpoint +//! +//! ## Note +//! +//! This crate is typically not used directly. Use `rustapi-rs` instead for the +//! full framework experience with all features and re-exports. mod app; mod error; mod extract; mod handler; pub mod middleware; +pub mod path_validation; mod request; mod response; mod router; mod server; pub mod sse; pub mod stream; +#[cfg(any(test, feature = "test-utils"))] +mod test_client; // Public API pub use app::RustApi; -pub use error::{ApiError, Result}; +pub use error::{ApiError, Environment, Result, get_environment}; pub use extract::{Body, ClientIp, Extension, FromRequest, FromRequestParts, HeaderValue, Headers, Json, Path, Query, State, ValidatedJson}; #[cfg(feature = "cookies")] pub use extract::Cookies; @@ -26,9 +73,13 @@ pub use handler::{ Handler, HandlerService, Route, RouteHandler, get_route, post_route, put_route, patch_route, delete_route, }; -pub use middleware::{RequestId, RequestIdLayer, TracingLayer}; +pub use middleware::{BodyLimitLayer, RequestId, RequestIdLayer, TracingLayer, DEFAULT_BODY_LIMIT}; +#[cfg(feature = "metrics")] +pub use middleware::{MetricsLayer, MetricsResponse}; pub use request::Request; pub use response::{Created, Html, IntoResponse, NoContent, Redirect, Response, WithStatus}; pub use router::{delete, get, patch, post, put, MethodRouter, Router}; pub use sse::{Sse, SseEvent}; pub use stream::StreamBody; +#[cfg(any(test, feature = "test-utils"))] +pub use test_client::{TestClient, TestRequest, TestResponse}; diff --git a/crates/rustapi-core/src/middleware/body_limit.rs b/crates/rustapi-core/src/middleware/body_limit.rs new file mode 100644 index 0000000..d240c5f --- /dev/null +++ b/crates/rustapi-core/src/middleware/body_limit.rs @@ -0,0 +1,326 @@ +//! Body size limit middleware for RustAPI +//! +//! This module provides middleware to enforce request body size limits, +//! protecting against denial-of-service attacks via large payloads. +//! +//! # Example +//! +//! ```rust,ignore +//! use rustapi_rs::prelude::*; +//! use rustapi_core::middleware::BodyLimitLayer; +//! +//! RustApi::new() +//! .layer(BodyLimitLayer::new(1024 * 1024)) // 1MB limit +//! .route("/upload", post(upload_handler)) +//! .run("127.0.0.1:8080") +//! .await +//! ``` + +use crate::error::ApiError; +use crate::request::Request; +use crate::response::{IntoResponse, Response}; +use super::{BoxedNext, MiddlewareLayer}; +use http::StatusCode; +use std::future::Future; +use std::pin::Pin; + +/// Default body size limit: 1MB +pub const DEFAULT_BODY_LIMIT: usize = 1024 * 1024; + +/// Body size limit middleware layer +/// +/// Enforces a maximum size for request bodies. When a request body exceeds +/// the configured limit, a 413 Payload Too Large response is returned. +#[derive(Clone)] +pub struct BodyLimitLayer { + limit: usize, +} + +impl BodyLimitLayer { + /// Create a new body limit layer with the specified limit in bytes + /// + /// # Arguments + /// + /// * `limit` - Maximum body size in bytes + /// + /// # Example + /// + /// ```rust,ignore + /// // 2MB limit + /// let layer = BodyLimitLayer::new(2 * 1024 * 1024); + /// ``` + pub fn new(limit: usize) -> Self { + Self { limit } + } + + /// Create a body limit layer with the default limit (1MB) + pub fn default_limit() -> Self { + Self::new(DEFAULT_BODY_LIMIT) + } + + /// Get the configured limit + pub fn limit(&self) -> usize { + self.limit + } +} + +impl Default for BodyLimitLayer { + fn default() -> Self { + Self::default_limit() + } +} + +impl MiddlewareLayer for BodyLimitLayer { + fn call( + &self, + req: Request, + next: BoxedNext, + ) -> Pin + Send + 'static>> { + let limit = self.limit; + + Box::pin(async move { + // Check Content-Length header first if available + if let Some(content_length) = req.headers().get(http::header::CONTENT_LENGTH) { + if let Ok(length_str) = content_length.to_str() { + if let Ok(length) = length_str.parse::() { + if length > limit { + return ApiError::new( + StatusCode::PAYLOAD_TOO_LARGE, + "payload_too_large", + format!("Request body exceeds limit of {} bytes", limit), + ) + .into_response(); + } + } + } + } + + // Also check actual body size (for cases without Content-Length or streaming) + // The body has already been read at this point in the pipeline + if let Some(body) = &req.body { + if body.len() > limit { + return ApiError::new( + StatusCode::PAYLOAD_TOO_LARGE, + "payload_too_large", + format!("Request body exceeds limit of {} bytes", limit), + ) + .into_response(); + } + } + + // Body is within limits, continue to next middleware/handler + next(req).await + }) + } + + fn clone_box(&self) -> Box { + Box::new(self.clone()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::request::Request; + use bytes::Bytes; + use http::{Extensions, Method}; + use proptest::prelude::*; + use std::collections::HashMap; + use std::sync::Arc; + + /// Create a test request with the given body + fn create_test_request_with_body(body: Bytes) -> Request { + let uri: http::Uri = "/test".parse().unwrap(); + let mut builder = http::Request::builder().method(Method::POST).uri(uri); + + // Set Content-Length header + builder = builder.header(http::header::CONTENT_LENGTH, body.len().to_string()); + + let req = builder.body(()).unwrap(); + let (parts, _) = req.into_parts(); + + Request::new(parts, body, Arc::new(Extensions::new()), HashMap::new()) + } + + /// Create a test request without Content-Length header + fn create_test_request_without_content_length(body: Bytes) -> Request { + let uri: http::Uri = "/test".parse().unwrap(); + let builder = http::Request::builder().method(Method::POST).uri(uri); + + let req = builder.body(()).unwrap(); + let (parts, _) = req.into_parts(); + + Request::new(parts, body, Arc::new(Extensions::new()), HashMap::new()) + } + + /// Create a simple handler that returns 200 OK + fn ok_handler() -> BoxedNext { + Arc::new(|_req: Request| { + Box::pin(async { + http::Response::builder() + .status(StatusCode::OK) + .body(http_body_util::Full::new(Bytes::from("ok"))) + .unwrap() + }) as Pin + Send + 'static>> + }) + } + + + // **Feature: phase4-ergonomics-v1, Property 3: Body Size Limit Enforcement** + // + // For any configured body size limit L and any request body B where size(B) > L, + // the system should return a 413 Payload Too Large response. + // + // **Validates: Requirements 2.2, 2.3, 2.4, 2.5** + proptest! { + #![proptest_config(ProptestConfig::with_cases(100))] + + #[test] + fn prop_body_size_limit_enforcement( + // Generate limit between 1 and 10KB for testing + limit in 1usize..10240usize, + // Generate body size relative to limit + body_size_factor in 0.5f64..2.0f64, + ) { + let rt = tokio::runtime::Runtime::new().unwrap(); + rt.block_on(async { + let body_size = ((limit as f64) * body_size_factor) as usize; + let body = Bytes::from(vec![b'x'; body_size]); + let request = create_test_request_with_body(body.clone()); + + let layer = BodyLimitLayer::new(limit); + let handler = ok_handler(); + + let response = layer.call(request, handler).await; + + if body_size > limit { + // Body exceeds limit - should return 413 + prop_assert_eq!( + response.status(), + StatusCode::PAYLOAD_TOO_LARGE, + "Expected 413 for body size {} > limit {}", + body_size, + limit + ); + } else { + // Body within limit - should return 200 + prop_assert_eq!( + response.status(), + StatusCode::OK, + "Expected 200 for body size {} <= limit {}", + body_size, + limit + ); + } + + Ok(()) + })?; + } + + #[test] + fn prop_body_limit_without_content_length_header( + limit in 1usize..10240usize, + body_size_factor in 0.5f64..2.0f64, + ) { + let rt = tokio::runtime::Runtime::new().unwrap(); + rt.block_on(async { + let body_size = ((limit as f64) * body_size_factor) as usize; + let body = Bytes::from(vec![b'x'; body_size]); + // Create request without Content-Length header + let request = create_test_request_without_content_length(body.clone()); + + let layer = BodyLimitLayer::new(limit); + let handler = ok_handler(); + + let response = layer.call(request, handler).await; + + if body_size > limit { + // Body exceeds limit - should return 413 + prop_assert_eq!( + response.status(), + StatusCode::PAYLOAD_TOO_LARGE, + "Expected 413 for body size {} > limit {} (no Content-Length)", + body_size, + limit + ); + } else { + // Body within limit - should return 200 + prop_assert_eq!( + response.status(), + StatusCode::OK, + "Expected 200 for body size {} <= limit {} (no Content-Length)", + body_size, + limit + ); + } + + Ok(()) + })?; + } + } + + #[tokio::test] + async fn test_body_at_exact_limit() { + let limit = 100; + let body = Bytes::from(vec![b'x'; limit]); + let request = create_test_request_with_body(body); + + let layer = BodyLimitLayer::new(limit); + let handler = ok_handler(); + + let response = layer.call(request, handler).await; + assert_eq!(response.status(), StatusCode::OK); + } + + #[tokio::test] + async fn test_body_one_byte_over_limit() { + let limit = 100; + let body = Bytes::from(vec![b'x'; limit + 1]); + let request = create_test_request_with_body(body); + + let layer = BodyLimitLayer::new(limit); + let handler = ok_handler(); + + let response = layer.call(request, handler).await; + assert_eq!(response.status(), StatusCode::PAYLOAD_TOO_LARGE); + } + + #[tokio::test] + async fn test_body_one_byte_under_limit() { + let limit = 100; + let body = Bytes::from(vec![b'x'; limit - 1]); + let request = create_test_request_with_body(body); + + let layer = BodyLimitLayer::new(limit); + let handler = ok_handler(); + + let response = layer.call(request, handler).await; + assert_eq!(response.status(), StatusCode::OK); + } + + #[tokio::test] + async fn test_empty_body() { + let limit = 100; + let body = Bytes::new(); + let request = create_test_request_with_body(body); + + let layer = BodyLimitLayer::new(limit); + let handler = ok_handler(); + + let response = layer.call(request, handler).await; + assert_eq!(response.status(), StatusCode::OK); + } + + #[tokio::test] + async fn test_default_limit() { + let layer = BodyLimitLayer::default(); + assert_eq!(layer.limit(), DEFAULT_BODY_LIMIT); + } + + #[test] + fn test_clone() { + let layer = BodyLimitLayer::new(1024); + let cloned = layer.clone(); + assert_eq!(layer.limit(), cloned.limit()); + } +} diff --git a/crates/rustapi-core/src/middleware/layer.rs b/crates/rustapi-core/src/middleware/layer.rs index 81813e3..fb1b25c 100644 --- a/crates/rustapi-core/src/middleware/layer.rs +++ b/crates/rustapi-core/src/middleware/layer.rs @@ -66,6 +66,13 @@ impl LayerStack { self.layers.push(layer); } + /// Add a middleware layer to the beginning of the stack + /// + /// This layer will be executed first (outermost). + pub fn prepend(&mut self, layer: Box) { + self.layers.insert(0, layer); + } + /// Check if the stack is empty pub fn is_empty(&self) -> bool { self.layers.is_empty() diff --git a/crates/rustapi-core/src/middleware/metrics.rs b/crates/rustapi-core/src/middleware/metrics.rs new file mode 100644 index 0000000..4c3ab45 --- /dev/null +++ b/crates/rustapi-core/src/middleware/metrics.rs @@ -0,0 +1,592 @@ +//! Prometheus Metrics middleware +//! +//! Provides HTTP request metrics collection and a `/metrics` endpoint for Prometheus scraping. +//! +//! This module is only available when the `metrics` feature is enabled. +//! +//! # Metrics Collected +//! +//! - `http_requests_total` - Counter with labels: method, path, status +//! - `http_request_duration_seconds` - Histogram with labels: method, path +//! - `rustapi_info` - Gauge with label: version +//! +//! # Example +//! +//! ```rust,ignore +//! use rustapi_core::middleware::MetricsLayer; +//! +//! let metrics = MetricsLayer::new(); +//! +//! RustApi::new() +//! .layer(metrics.clone()) +//! .route("/metrics", get(metrics.handler())) +//! .run("127.0.0.1:8080") +//! .await +//! ``` + +use super::layer::{BoxedNext, MiddlewareLayer}; +use crate::request::Request; +use crate::response::Response; +use bytes::Bytes; +use prometheus::{ + Encoder, GaugeVec, HistogramOpts, HistogramVec, IntCounterVec, Opts, Registry, TextEncoder, +}; +use std::future::Future; +use std::pin::Pin; +use std::sync::Arc; +use std::time::Instant; + +/// Default histogram buckets for request duration (in seconds) +const DEFAULT_BUCKETS: &[f64] = &[0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0]; + +/// Prometheus metrics middleware layer +/// +/// Collects HTTP request metrics and provides a handler for the `/metrics` endpoint. +/// +/// # Metrics +/// +/// - `http_requests_total{method, path, status}` - Total number of HTTP requests +/// - `http_request_duration_seconds{method, path}` - HTTP request duration histogram +/// - `rustapi_info{version}` - RustAPI version information gauge +#[derive(Clone)] +pub struct MetricsLayer { + inner: Arc, +} + +struct MetricsInner { + registry: Registry, + requests_total: IntCounterVec, + request_duration: HistogramVec, + #[allow(dead_code)] + info_gauge: GaugeVec, +} + +impl MetricsLayer { + /// Create a new MetricsLayer with default configuration + /// + /// This creates a new Prometheus registry and registers the default metrics. + pub fn new() -> Self { + let registry = Registry::new(); + Self::with_registry(registry) + } + + /// Create a new MetricsLayer with a custom registry + /// + /// Use this if you want to share a registry with other metrics collectors. + pub fn with_registry(registry: Registry) -> Self { + // Create http_requests_total counter + let requests_total = IntCounterVec::new( + Opts::new("http_requests_total", "Total number of HTTP requests"), + &["method", "path", "status"], + ) + .expect("Failed to create http_requests_total metric"); + + // Create http_request_duration_seconds histogram + let request_duration = HistogramVec::new( + HistogramOpts::new( + "http_request_duration_seconds", + "HTTP request duration in seconds", + ) + .buckets(DEFAULT_BUCKETS.to_vec()), + &["method", "path"], + ) + .expect("Failed to create http_request_duration_seconds metric"); + + // Create rustapi_info gauge + let info_gauge = GaugeVec::new( + Opts::new("rustapi_info", "RustAPI version information"), + &["version"], + ) + .expect("Failed to create rustapi_info metric"); + + // Register metrics + registry + .register(Box::new(requests_total.clone())) + .expect("Failed to register http_requests_total"); + registry + .register(Box::new(request_duration.clone())) + .expect("Failed to register http_request_duration_seconds"); + registry + .register(Box::new(info_gauge.clone())) + .expect("Failed to register rustapi_info"); + + // Set version info + let version = env!("CARGO_PKG_VERSION"); + info_gauge.with_label_values(&[version]).set(1.0); + + Self { + inner: Arc::new(MetricsInner { + registry, + requests_total, + request_duration, + info_gauge, + }), + } + } + + /// Get the Prometheus registry + /// + /// Use this to register additional custom metrics. + pub fn registry(&self) -> &Registry { + &self.inner.registry + } + + /// Create a handler function for the `/metrics` endpoint + /// + /// Returns metrics in Prometheus text format. + /// + /// # Example + /// + /// ```rust,ignore + /// let metrics = MetricsLayer::new(); + /// app.route("/metrics", get(metrics.handler())); + /// ``` + pub fn handler(&self) -> impl Fn() -> MetricsResponse + Clone + Send + Sync + 'static { + let registry = self.inner.registry.clone(); + move || { + let encoder = TextEncoder::new(); + let metric_families = registry.gather(); + let mut buffer = Vec::new(); + encoder + .encode(&metric_families, &mut buffer) + .expect("Failed to encode metrics"); + MetricsResponse(buffer) + } + } + + /// Record a request with the given method, path, status, and duration + fn record_request(&self, method: &str, path: &str, status: u16, duration_secs: f64) { + // Normalize path to avoid high cardinality + let normalized_path = normalize_path(path); + + // Increment request counter + self.inner + .requests_total + .with_label_values(&[method, &normalized_path, &status.to_string()]) + .inc(); + + // Record duration + self.inner + .request_duration + .with_label_values(&[method, &normalized_path]) + .observe(duration_secs); + } +} + +impl Default for MetricsLayer { + fn default() -> Self { + Self::new() + } +} + +impl MiddlewareLayer for MetricsLayer { + fn call( + &self, + req: Request, + next: BoxedNext, + ) -> Pin + Send + 'static>> { + let method = req.method().to_string(); + let path = req.uri().path().to_string(); + let metrics = self.clone(); + + Box::pin(async move { + let start = Instant::now(); + + // Call the next handler + let response = next(req).await; + + // Record metrics + let duration = start.elapsed().as_secs_f64(); + let status = response.status().as_u16(); + metrics.record_request(&method, &path, status, duration); + + response + }) + } + + fn clone_box(&self) -> Box { + Box::new(self.clone()) + } +} + +/// Response type for the metrics endpoint +pub struct MetricsResponse(Vec); + +impl crate::response::IntoResponse for MetricsResponse { + fn into_response(self) -> Response { + http::Response::builder() + .status(http::StatusCode::OK) + .header( + http::header::CONTENT_TYPE, + "text/plain; version=0.0.4; charset=utf-8", + ) + .body(http_body_util::Full::new(Bytes::from(self.0))) + .unwrap() + } +} + +/// Normalize a path to reduce cardinality +/// +/// This replaces path segments that look like IDs (UUIDs, numbers) with placeholders. +fn normalize_path(path: &str) -> String { + let segments: Vec<&str> = path.split('/').collect(); + let normalized: Vec = segments + .into_iter() + .map(|segment| { + if segment.is_empty() { + String::new() + } else if is_id_like(segment) { + ":id".to_string() + } else { + segment.to_string() + } + }) + .collect(); + normalized.join("/") +} + +/// Check if a path segment looks like an ID +fn is_id_like(segment: &str) -> bool { + // Check for UUID format + if segment.len() == 36 && segment.chars().filter(|c| *c == '-').count() == 4 { + return true; + } + + // Check for numeric ID + if segment.chars().all(|c| c.is_ascii_digit()) && !segment.is_empty() { + return true; + } + + // Check for hex string (common for IDs) + if segment.len() >= 8 && segment.chars().all(|c| c.is_ascii_hexdigit()) { + return true; + } + + false +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::middleware::layer::{BoxedNext, LayerStack}; + use http::{Extensions, Method, StatusCode}; + use proptest::prelude::*; + use proptest::test_runner::TestCaseError; + use std::collections::HashMap; + use std::sync::Arc; + + /// Create a test request with the given method and path + fn create_test_request(method: Method, path: &str) -> crate::request::Request { + let uri: http::Uri = path.parse().unwrap(); + let builder = http::Request::builder().method(method).uri(uri); + + let req = builder.body(()).unwrap(); + let (parts, _) = req.into_parts(); + + crate::request::Request::new( + parts, + Bytes::new(), + Arc::new(Extensions::new()), + HashMap::new(), + ) + } + + #[test] + fn test_metrics_layer_creation() { + let metrics = MetricsLayer::new(); + assert!(metrics.registry().gather().len() > 0); + } + + #[test] + fn test_metrics_handler_returns_prometheus_format() { + let metrics = MetricsLayer::new(); + let handler = metrics.handler(); + let response = handler(); + + // Convert to response and check content type + let http_response = crate::response::IntoResponse::into_response(response); + assert_eq!(http_response.status(), StatusCode::OK); + + let content_type = http_response + .headers() + .get(http::header::CONTENT_TYPE) + .unwrap(); + assert!(content_type + .to_str() + .unwrap() + .contains("text/plain")); + } + + #[test] + fn test_normalize_path_with_uuid() { + let path = "/users/550e8400-e29b-41d4-a716-446655440000/posts"; + let normalized = normalize_path(path); + assert_eq!(normalized, "/users/:id/posts"); + } + + #[test] + fn test_normalize_path_with_numeric_id() { + let path = "/users/12345/posts"; + let normalized = normalize_path(path); + assert_eq!(normalized, "/users/:id/posts"); + } + + #[test] + fn test_normalize_path_without_ids() { + let path = "/users/profile/settings"; + let normalized = normalize_path(path); + assert_eq!(normalized, "/users/profile/settings"); + } + + #[test] + fn test_is_id_like() { + // UUIDs + assert!(is_id_like("550e8400-e29b-41d4-a716-446655440000")); + + // Numeric IDs + assert!(is_id_like("12345")); + assert!(is_id_like("1")); + + // Hex strings + assert!(is_id_like("deadbeef")); + assert!(is_id_like("abc123def456")); + + // Not IDs + assert!(!is_id_like("users")); + assert!(!is_id_like("profile")); + assert!(!is_id_like("")); + } + + #[test] + fn test_rustapi_info_gauge_set() { + let metrics = MetricsLayer::new(); + let handler = metrics.handler(); + let response = handler(); + + let http_response = crate::response::IntoResponse::into_response(response); + let body = http_response.into_body(); + + // The body should contain rustapi_info metric + // We can't easily read the body here, but we verified the metric is registered + } + + // **Feature: phase4-ergonomics-v1, Property 9: Request Metrics Recording** + // + // For any HTTP request processed by the system with the `metrics` feature enabled, + // the `http_requests_total` counter should be incremented with correct method, path, + // and status labels, and the `http_request_duration_seconds` histogram should record + // the request duration. + // + // **Validates: Requirements 5.2, 5.3** + proptest! { + #![proptest_config(ProptestConfig::with_cases(100))] + + #[test] + fn prop_request_metrics_recording( + method_idx in 0usize..5usize, + path in "/[a-z]{1,10}", + status_code in 200u16..600u16, + ) { + let rt = tokio::runtime::Runtime::new().unwrap(); + let result: Result<(), TestCaseError> = rt.block_on(async { + // Create a fresh metrics layer for each test + let metrics = MetricsLayer::new(); + + // Create middleware stack + let mut stack = LayerStack::new(); + stack.push(Box::new(metrics.clone())); + + // Map index to HTTP method + let methods = [Method::GET, Method::POST, Method::PUT, Method::DELETE, Method::PATCH]; + let method = methods[method_idx].clone(); + + // Create handler that returns the specified status + let response_status = StatusCode::from_u16(status_code).unwrap_or(StatusCode::OK); + let handler: BoxedNext = Arc::new(move |_req: crate::request::Request| { + let status = response_status; + Box::pin(async move { + http::Response::builder() + .status(status) + .body(http_body_util::Full::new(Bytes::from("test"))) + .unwrap() + }) as Pin + Send + 'static>> + }); + + // Execute request + let request = create_test_request(method.clone(), &path); + let response = stack.execute(request, handler).await; + + // Verify response status matches + prop_assert_eq!(response.status(), response_status); + + // Verify metrics were recorded + let metric_families = metrics.registry().gather(); + + // Find http_requests_total metric + let requests_total = metric_families + .iter() + .find(|mf| mf.get_name() == "http_requests_total"); + prop_assert!( + requests_total.is_some(), + "http_requests_total metric should exist" + ); + + let requests_total = requests_total.unwrap(); + let metrics_vec = requests_total.get_metric(); + + // Find the metric with matching labels + let matching_metric = metrics_vec.iter().find(|m| { + let labels = m.get_label(); + let method_label = labels.iter().find(|l| l.get_name() == "method"); + let path_label = labels.iter().find(|l| l.get_name() == "path"); + let status_label = labels.iter().find(|l| l.get_name() == "status"); + + method_label.map(|l| l.get_value()) == Some(method.as_str()) + && path_label.map(|l| l.get_value()) == Some(&path) + && status_label.map(|l| l.get_value()) == Some(&status_code.to_string()) + }); + + prop_assert!( + matching_metric.is_some(), + "Should have metric with method={}, path={}, status={}. Available metrics: {:?}", + method.as_str(), + path, + status_code, + metrics_vec.iter().map(|m| m.get_label()).collect::>() + ); + + // Verify counter was incremented + let counter_value = matching_metric.unwrap().get_counter().get_value(); + prop_assert!( + counter_value >= 1.0, + "Counter should be at least 1, got {}", + counter_value + ); + + // Find http_request_duration_seconds metric + let duration_metric = metric_families + .iter() + .find(|mf| mf.get_name() == "http_request_duration_seconds"); + prop_assert!( + duration_metric.is_some(), + "http_request_duration_seconds metric should exist" + ); + + let duration_metric = duration_metric.unwrap(); + let duration_vec = duration_metric.get_metric(); + + // Find the histogram with matching labels + let matching_histogram = duration_vec.iter().find(|m| { + let labels = m.get_label(); + let method_label = labels.iter().find(|l| l.get_name() == "method"); + let path_label = labels.iter().find(|l| l.get_name() == "path"); + + method_label.map(|l| l.get_value()) == Some(method.as_str()) + && path_label.map(|l| l.get_value()) == Some(&path) + }); + + prop_assert!( + matching_histogram.is_some(), + "Should have histogram with method={}, path={}", + method.as_str(), + path + ); + + // Verify histogram has recorded at least one observation + let histogram = matching_histogram.unwrap().get_histogram(); + prop_assert!( + histogram.get_sample_count() >= 1, + "Histogram should have at least 1 sample, got {}", + histogram.get_sample_count() + ); + + // Verify duration is reasonable (less than 10 seconds) + let sum = histogram.get_sample_sum(); + prop_assert!( + sum < 10.0, + "Duration sum should be less than 10 seconds, got {}", + sum + ); + + Ok(()) + }); + result?; + } + } + + #[test] + fn test_metrics_layer_records_request() { + let rt = tokio::runtime::Runtime::new().unwrap(); + rt.block_on(async { + let metrics = MetricsLayer::new(); + + let mut stack = LayerStack::new(); + stack.push(Box::new(metrics.clone())); + + let handler: BoxedNext = Arc::new(|_req: crate::request::Request| { + Box::pin(async { + http::Response::builder() + .status(StatusCode::OK) + .body(http_body_util::Full::new(Bytes::from("ok"))) + .unwrap() + }) as Pin + Send + 'static>> + }); + + let request = create_test_request(Method::GET, "/test"); + let response = stack.execute(request, handler).await; + + assert_eq!(response.status(), StatusCode::OK); + + // Verify metrics were recorded + let metric_families = metrics.registry().gather(); + let requests_total = metric_families + .iter() + .find(|mf| mf.get_name() == "http_requests_total"); + assert!(requests_total.is_some()); + }); + } + + #[test] + fn test_metrics_layer_with_multiple_requests() { + let rt = tokio::runtime::Runtime::new().unwrap(); + rt.block_on(async { + let metrics = MetricsLayer::new(); + + let mut stack = LayerStack::new(); + stack.push(Box::new(metrics.clone())); + + let handler: BoxedNext = Arc::new(|_req: crate::request::Request| { + Box::pin(async { + http::Response::builder() + .status(StatusCode::OK) + .body(http_body_util::Full::new(Bytes::from("ok"))) + .unwrap() + }) as Pin + Send + 'static>> + }); + + // Send multiple requests + for _ in 0..5 { + let request = create_test_request(Method::GET, "/test"); + let _ = stack.execute(request, handler.clone()).await; + } + + // Verify counter was incremented 5 times + let metric_families = metrics.registry().gather(); + let requests_total = metric_families + .iter() + .find(|mf| mf.get_name() == "http_requests_total") + .unwrap(); + + let metrics_vec = requests_total.get_metric(); + let matching_metric = metrics_vec.iter().find(|m| { + let labels = m.get_label(); + labels.iter().any(|l| l.get_name() == "method" && l.get_value() == "GET") + && labels.iter().any(|l| l.get_name() == "path" && l.get_value() == "/test") + && labels.iter().any(|l| l.get_name() == "status" && l.get_value() == "200") + }); + + assert!(matching_metric.is_some()); + assert_eq!(matching_metric.unwrap().get_counter().get_value(), 5.0); + }); + } +} diff --git a/crates/rustapi-core/src/middleware/mod.rs b/crates/rustapi-core/src/middleware/mod.rs index 98f103c..a3c1e95 100644 --- a/crates/rustapi-core/src/middleware/mod.rs +++ b/crates/rustapi-core/src/middleware/mod.rs @@ -16,10 +16,16 @@ //! .await //! ``` +mod body_limit; mod layer; +#[cfg(feature = "metrics")] +mod metrics; mod request_id; mod tracing_layer; +pub use body_limit::{BodyLimitLayer, DEFAULT_BODY_LIMIT}; pub use layer::{BoxedNext, LayerStack, MiddlewareLayer}; +#[cfg(feature = "metrics")] +pub use metrics::{MetricsLayer, MetricsResponse}; pub use request_id::{RequestId, RequestIdLayer}; pub use tracing_layer::TracingLayer; diff --git a/crates/rustapi-core/src/middleware/tracing_layer.rs b/crates/rustapi-core/src/middleware/tracing_layer.rs index c52bd04..394f724 100644 --- a/crates/rustapi-core/src/middleware/tracing_layer.rs +++ b/crates/rustapi-core/src/middleware/tracing_layer.rs @@ -1,30 +1,75 @@ -//! Tracing middleware +//! Enhanced Tracing middleware //! -//! Logs request method, path, status code, and duration for each request. +//! Logs request method, path, request_id, status code, and duration for each request. +//! Supports custom fields that are included in all request spans. use super::layer::{BoxedNext, MiddlewareLayer}; +use super::request_id::RequestId; use crate::request::Request; use crate::response::Response; use std::future::Future; use std::pin::Pin; use std::time::Instant; -use tracing::{info, warn, Level}; +use tracing::{info_span, Instrument, Level}; -/// Middleware layer that logs request information +/// Middleware layer that creates tracing spans for requests +/// +/// This layer creates a span for each request containing: +/// - HTTP method +/// - Request path +/// - Request ID (if RequestIdLayer is applied) +/// - Response status code +/// - Request duration +/// - Any custom fields configured via `with_field()` +/// +/// # Example +/// +/// ```rust,ignore +/// use rustapi_core::middleware::TracingLayer; +/// +/// RustApi::new() +/// .layer(TracingLayer::new() +/// .with_field("service", "my-api") +/// .with_field("version", "1.0.0")) +/// .route("/", get(handler)) +/// ``` #[derive(Clone)] pub struct TracingLayer { level: Level, + custom_fields: Vec<(String, String)>, } impl TracingLayer { /// Create a new TracingLayer with default INFO level pub fn new() -> Self { - Self { level: Level::INFO } + Self { + level: Level::INFO, + custom_fields: Vec::new(), + } } /// Create a TracingLayer with a specific log level pub fn with_level(level: Level) -> Self { - Self { level } + Self { + level, + custom_fields: Vec::new(), + } + } + + /// Add a custom field to all request spans + /// + /// Custom fields are included in every span created by this layer. + /// + /// # Example + /// + /// ```rust,ignore + /// TracingLayer::new() + /// .with_field("service", "my-api") + /// .with_field("environment", "production") + /// ``` + pub fn with_field(mut self, key: impl Into, value: impl Into) -> Self { + self.custom_fields.push((key.into(), value.into())); + self } } @@ -41,63 +86,110 @@ impl MiddlewareLayer for TracingLayer { next: BoxedNext, ) -> Pin + Send + 'static>> { let level = self.level; - let method = req.method().clone(); + let method = req.method().to_string(); let path = req.uri().path().to_string(); + let custom_fields = self.custom_fields.clone(); + + // Extract request_id if available + let request_id = req + .extensions() + .get::() + .map(|id| id.as_str().to_string()) + .unwrap_or_else(|| "unknown".to_string()); Box::pin(async move { let start = Instant::now(); - // Call the next handler - let response = next(req).await; + // Create span with all fields + // We use info_span! as the base and record custom fields dynamically + let span = info_span!( + "http_request", + method = %method, + path = %path, + request_id = %request_id, + status = tracing::field::Empty, + duration_ms = tracing::field::Empty, + error = tracing::field::Empty, + ); + + // Record custom fields in the span + for (key, value) in &custom_fields { + span.record(key.as_str(), value.as_str()); + } + + // Execute the request within the span + let response = async { + next(req).await + } + .instrument(span.clone()) + .await; let duration = start.elapsed(); let status = response.status(); + let status_code = status.as_u16(); + + // Record response fields + span.record("status", status_code); + span.record("duration_ms", duration.as_millis() as u64); - // Log based on status code + // Record error if status indicates failure + if status.is_client_error() || status.is_server_error() { + span.record("error", true); + } + + // Log based on status code and configured level + let _enter = span.enter(); if status.is_success() { match level { Level::TRACE => tracing::trace!( method = %method, path = %path, - status = %status.as_u16(), + request_id = %request_id, + status = %status_code, duration_ms = %duration.as_millis(), "Request completed" ), Level::DEBUG => tracing::debug!( method = %method, path = %path, - status = %status.as_u16(), + request_id = %request_id, + status = %status_code, duration_ms = %duration.as_millis(), "Request completed" ), - Level::INFO => info!( + Level::INFO => tracing::info!( method = %method, path = %path, - status = %status.as_u16(), + request_id = %request_id, + status = %status_code, duration_ms = %duration.as_millis(), "Request completed" ), - Level::WARN => warn!( + Level::WARN => tracing::warn!( method = %method, path = %path, - status = %status.as_u16(), + request_id = %request_id, + status = %status_code, duration_ms = %duration.as_millis(), "Request completed" ), Level::ERROR => tracing::error!( method = %method, path = %path, - status = %status.as_u16(), + request_id = %request_id, + status = %status_code, duration_ms = %duration.as_millis(), "Request completed" ), } } else { - warn!( + tracing::warn!( method = %method, path = %path, - status = %status.as_u16(), + request_id = %request_id, + status = %status_code, duration_ms = %duration.as_millis(), + error = true, "Request failed" ); } @@ -114,13 +206,355 @@ impl MiddlewareLayer for TracingLayer { #[cfg(test)] mod tests { use super::*; + use crate::middleware::layer::{BoxedNext, LayerStack}; + use crate::middleware::request_id::RequestIdLayer; + use bytes::Bytes; + use http::{Extensions, Method, StatusCode}; + use proptest::prelude::*; + use proptest::test_runner::TestCaseError; + use std::collections::HashMap; + use std::sync::Arc; + use tracing_subscriber::layer::SubscriberExt; + + /// Create a test request with the given method and path + fn create_test_request(method: Method, path: &str) -> crate::request::Request { + let uri: http::Uri = path.parse().unwrap(); + let builder = http::Request::builder().method(method).uri(uri); + + let req = builder.body(()).unwrap(); + let (parts, _) = req.into_parts(); + + crate::request::Request::new( + parts, + Bytes::new(), + Arc::new(Extensions::new()), + HashMap::new(), + ) + } #[test] fn test_tracing_layer_creation() { let layer = TracingLayer::new(); assert_eq!(layer.level, Level::INFO); + assert!(layer.custom_fields.is_empty()); let layer = TracingLayer::with_level(Level::DEBUG); assert_eq!(layer.level, Level::DEBUG); } + + #[test] + fn test_tracing_layer_with_custom_fields() { + let layer = TracingLayer::new() + .with_field("service", "test-api") + .with_field("version", "1.0.0"); + + assert_eq!(layer.custom_fields.len(), 2); + assert_eq!(layer.custom_fields[0], ("service".to_string(), "test-api".to_string())); + assert_eq!(layer.custom_fields[1], ("version".to_string(), "1.0.0".to_string())); + } + + #[test] + fn test_tracing_layer_clone() { + let layer = TracingLayer::new() + .with_field("key", "value"); + + let cloned = layer.clone(); + assert_eq!(cloned.level, layer.level); + assert_eq!(cloned.custom_fields, layer.custom_fields); + } + + /// A test subscriber that captures span fields for verification + #[derive(Clone)] + struct SpanFieldCapture { + captured_fields: Arc>>, + } + + #[derive(Debug, Clone)] + struct CapturedSpan { + name: String, + fields: HashMap, + } + + impl SpanFieldCapture { + fn new() -> Self { + Self { + captured_fields: Arc::new(std::sync::Mutex::new(Vec::new())), + } + } + + fn get_spans(&self) -> Vec { + self.captured_fields.lock().unwrap().clone() + } + } + + impl tracing_subscriber::Layer for SpanFieldCapture + where + S: tracing::Subscriber + for<'lookup> tracing_subscriber::registry::LookupSpan<'lookup>, + { + fn on_new_span( + &self, + attrs: &tracing::span::Attributes<'_>, + _id: &tracing::span::Id, + _ctx: tracing_subscriber::layer::Context<'_, S>, + ) { + let mut fields = HashMap::new(); + let mut visitor = FieldVisitor { fields: &mut fields }; + attrs.record(&mut visitor); + + let span = CapturedSpan { + name: attrs.metadata().name().to_string(), + fields, + }; + + self.captured_fields.lock().unwrap().push(span); + } + + fn on_record( + &self, + id: &tracing::span::Id, + values: &tracing::span::Record<'_>, + ctx: tracing_subscriber::layer::Context<'_, S>, + ) { + if let Some(_span) = ctx.span(id) { + let mut captured = self.captured_fields.lock().unwrap(); + if let Some(last_span) = captured.last_mut() { + let mut visitor = FieldVisitor { fields: &mut last_span.fields }; + values.record(&mut visitor); + } + } + } + } + + struct FieldVisitor<'a> { + fields: &'a mut HashMap, + } + + impl<'a> tracing::field::Visit for FieldVisitor<'a> { + fn record_debug(&mut self, field: &tracing::field::Field, value: &dyn std::fmt::Debug) { + self.fields.insert(field.name().to_string(), format!("{:?}", value)); + } + + fn record_str(&mut self, field: &tracing::field::Field, value: &str) { + self.fields.insert(field.name().to_string(), value.to_string()); + } + + fn record_i64(&mut self, field: &tracing::field::Field, value: i64) { + self.fields.insert(field.name().to_string(), value.to_string()); + } + + fn record_u64(&mut self, field: &tracing::field::Field, value: u64) { + self.fields.insert(field.name().to_string(), value.to_string()); + } + + fn record_bool(&mut self, field: &tracing::field::Field, value: bool) { + self.fields.insert(field.name().to_string(), value.to_string()); + } + } + + // **Feature: phase4-ergonomics-v1, Property 8: Tracing Span Completeness** + // + // For any HTTP request processed by the system with tracing enabled, the resulting + // span should contain: request method, request path, request ID, response status code, + // and response duration. + // + // **Validates: Requirements 4.1, 4.2, 4.3, 4.4** + proptest! { + #![proptest_config(ProptestConfig::with_cases(100))] + + #[test] + fn prop_tracing_span_completeness( + method_idx in 0usize..5usize, + path in "/[a-z]{1,10}(/[a-z]{1,10})?", + status_code in 200u16..600u16, + custom_key in "[a-z]{3,10}", + custom_value in "[a-z0-9]{3,20}", + ) { + let rt = tokio::runtime::Runtime::new().unwrap(); + let result: Result<(), TestCaseError> = rt.block_on(async { + // Set up span capture + let capture = SpanFieldCapture::new(); + let subscriber = tracing_subscriber::registry().with(capture.clone()); + + // Use a guard to set the subscriber for this test + let _guard = tracing::subscriber::set_default(subscriber); + + // Create middleware stack with RequestIdLayer and TracingLayer + let mut stack = LayerStack::new(); + stack.push(Box::new(RequestIdLayer::new())); + stack.push(Box::new(TracingLayer::new() + .with_field(&custom_key, &custom_value))); + + // Map index to HTTP method + let methods = [Method::GET, Method::POST, Method::PUT, Method::DELETE, Method::PATCH]; + let method = methods[method_idx].clone(); + + // Create handler that returns the specified status + let response_status = StatusCode::from_u16(status_code).unwrap_or(StatusCode::OK); + let handler: BoxedNext = Arc::new(move |_req: crate::request::Request| { + let status = response_status; + Box::pin(async move { + http::Response::builder() + .status(status) + .body(http_body_util::Full::new(Bytes::from("test"))) + .unwrap() + }) as Pin + Send + 'static>> + }); + + // Execute request + let request = create_test_request(method.clone(), &path); + let response = stack.execute(request, handler).await; + + // Verify response status matches + prop_assert_eq!(response.status(), response_status); + + // Find the http_request span + let spans = capture.get_spans(); + let http_span = spans.iter().find(|s| s.name == "http_request"); + + prop_assert!(http_span.is_some(), "Should have created an http_request span"); + let span = http_span.unwrap(); + + // Verify required fields are present + // Method + prop_assert!( + span.fields.contains_key("method"), + "Span should contain 'method' field. Fields: {:?}", span.fields + ); + prop_assert_eq!( + span.fields.get("method").map(|s| s.trim_matches('"')), + Some(method.as_str()), + "Method should match request method" + ); + + // Path + prop_assert!( + span.fields.contains_key("path"), + "Span should contain 'path' field. Fields: {:?}", span.fields + ); + prop_assert_eq!( + span.fields.get("path").map(|s| s.trim_matches('"')), + Some(path.as_str()), + "Path should match request path" + ); + + // Request ID + prop_assert!( + span.fields.contains_key("request_id"), + "Span should contain 'request_id' field. Fields: {:?}", span.fields + ); + let request_id = span.fields.get("request_id").unwrap(); + // Request ID should be a UUID format (36 chars with hyphens) or "unknown" + let request_id_trimmed = request_id.trim_matches('"'); + prop_assert!( + request_id_trimmed == "unknown" || request_id_trimmed.len() == 36, + "Request ID should be UUID format or 'unknown', got: {}", request_id + ); + + // Status code (recorded after response) + prop_assert!( + span.fields.contains_key("status"), + "Span should contain 'status' field. Fields: {:?}", span.fields + ); + let recorded_status: u16 = span.fields.get("status") + .and_then(|s| s.parse().ok()) + .unwrap_or(0); + prop_assert_eq!( + recorded_status, + status_code, + "Status should match response status code" + ); + + // Duration (recorded after response) + prop_assert!( + span.fields.contains_key("duration_ms"), + "Span should contain 'duration_ms' field. Fields: {:?}", span.fields + ); + let duration: u64 = span.fields.get("duration_ms") + .and_then(|s| s.parse().ok()) + .unwrap_or(u64::MAX); + prop_assert!( + duration < 10000, // Should complete in less than 10 seconds + "Duration should be reasonable, got: {} ms", duration + ); + + // Error field should be present for error responses + if response_status.is_client_error() || response_status.is_server_error() { + prop_assert!( + span.fields.contains_key("error"), + "Span should contain 'error' field for error responses. Fields: {:?}", span.fields + ); + } + + Ok(()) + }); + result?; + } + } + + #[test] + fn test_tracing_layer_records_request_id() { + let rt = tokio::runtime::Runtime::new().unwrap(); + rt.block_on(async { + let capture = SpanFieldCapture::new(); + let subscriber = tracing_subscriber::registry().with(capture.clone()); + let _guard = tracing::subscriber::set_default(subscriber); + + let mut stack = LayerStack::new(); + stack.push(Box::new(RequestIdLayer::new())); + stack.push(Box::new(TracingLayer::new())); + + let handler: BoxedNext = Arc::new(|_req: crate::request::Request| { + Box::pin(async { + http::Response::builder() + .status(StatusCode::OK) + .body(http_body_util::Full::new(Bytes::from("ok"))) + .unwrap() + }) as Pin + Send + 'static>> + }); + + let request = create_test_request(Method::GET, "/test"); + let _response = stack.execute(request, handler).await; + + let spans = capture.get_spans(); + let http_span = spans.iter().find(|s| s.name == "http_request"); + assert!(http_span.is_some(), "Should have http_request span"); + + let span = http_span.unwrap(); + assert!(span.fields.contains_key("request_id"), "Should have request_id field"); + }); + } + + #[test] + fn test_tracing_layer_records_error_for_failures() { + let rt = tokio::runtime::Runtime::new().unwrap(); + rt.block_on(async { + let capture = SpanFieldCapture::new(); + let subscriber = tracing_subscriber::registry().with(capture.clone()); + let _guard = tracing::subscriber::set_default(subscriber); + + let mut stack = LayerStack::new(); + stack.push(Box::new(TracingLayer::new())); + + let handler: BoxedNext = Arc::new(|_req: crate::request::Request| { + Box::pin(async { + http::Response::builder() + .status(StatusCode::INTERNAL_SERVER_ERROR) + .body(http_body_util::Full::new(Bytes::from("error"))) + .unwrap() + }) as Pin + Send + 'static>> + }); + + let request = create_test_request(Method::GET, "/test"); + let response = stack.execute(request, handler).await; + + assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR); + + let spans = capture.get_spans(); + let http_span = spans.iter().find(|s| s.name == "http_request"); + assert!(http_span.is_some(), "Should have http_request span"); + + let span = http_span.unwrap(); + assert!(span.fields.contains_key("error"), "Should have error field for 5xx response"); + }); + } } diff --git a/crates/rustapi-core/src/path_validation.rs b/crates/rustapi-core/src/path_validation.rs new file mode 100644 index 0000000..54a9482 --- /dev/null +++ b/crates/rustapi-core/src/path_validation.rs @@ -0,0 +1,530 @@ +//! Route path validation utilities +//! +//! This module provides compile-time and runtime validation for route paths. +//! The validation logic is shared between the proc-macro crate (for compile-time +//! validation) and the core crate (for runtime validation and testing). + +/// Result of path validation +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum PathValidationError { + /// Path must start with '/' + MustStartWithSlash { path: String }, + /// Path contains empty segment (double slash) + EmptySegment { path: String }, + /// Nested braces are not allowed + NestedBraces { path: String, position: usize }, + /// Unmatched closing brace + UnmatchedClosingBrace { path: String, position: usize }, + /// Empty parameter name + EmptyParameterName { path: String, position: usize }, + /// Invalid parameter name (contains invalid characters) + InvalidParameterName { path: String, param_name: String, position: usize }, + /// Parameter name starts with digit + ParameterStartsWithDigit { path: String, param_name: String, position: usize }, + /// Unclosed brace + UnclosedBrace { path: String }, + /// Invalid character in path + InvalidCharacter { path: String, character: char, position: usize }, +} + +impl std::fmt::Display for PathValidationError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + PathValidationError::MustStartWithSlash { path } => { + write!(f, "route path must start with '/', got: \"{}\"", path) + } + PathValidationError::EmptySegment { path } => { + write!(f, "route path contains empty segment (double slash): \"{}\"", path) + } + PathValidationError::NestedBraces { path, position } => { + write!(f, "nested braces are not allowed in route path at position {}: \"{}\"", position, path) + } + PathValidationError::UnmatchedClosingBrace { path, position } => { + write!(f, "unmatched closing brace '}}' at position {} in route path: \"{}\"", position, path) + } + PathValidationError::EmptyParameterName { path, position } => { + write!(f, "empty parameter name '{{}}' at position {} in route path: \"{}\"", position, path) + } + PathValidationError::InvalidParameterName { path, param_name, position } => { + write!(f, "invalid parameter name '{{{}}}' at position {} - parameter names must contain only alphanumeric characters and underscores: \"{}\"", param_name, position, path) + } + PathValidationError::ParameterStartsWithDigit { path, param_name, position } => { + write!(f, "parameter name '{{{}}}' cannot start with a digit at position {}: \"{}\"", param_name, position, path) + } + PathValidationError::UnclosedBrace { path } => { + write!(f, "unclosed brace '{{' in route path (missing closing '}}'): \"{}\"", path) + } + PathValidationError::InvalidCharacter { path, character, position } => { + write!(f, "invalid character '{}' at position {} in route path: \"{}\"", character, position, path) + } + } + } +} + +impl std::error::Error for PathValidationError {} + +/// Validate route path syntax +/// +/// Returns Ok(()) if the path is valid, or Err with a descriptive error. +/// +/// # Valid paths +/// - Must start with '/' +/// - Can contain alphanumeric characters, '-', '_', '.', '/' +/// - Can contain path parameters in the form `{param_name}` +/// - Parameter names must be valid identifiers (alphanumeric + underscore, not starting with digit) +/// +/// # Invalid paths +/// - Paths not starting with '/' +/// - Paths with empty segments (double slashes like '//') +/// - Paths with unclosed or nested braces +/// - Paths with empty parameter names like '{}' +/// - Paths with invalid parameter names +/// - Paths with invalid characters +/// +/// # Examples +/// +/// ``` +/// use rustapi_core::path_validation::validate_path; +/// +/// // Valid paths +/// assert!(validate_path("/").is_ok()); +/// assert!(validate_path("/users").is_ok()); +/// assert!(validate_path("/users/{id}").is_ok()); +/// assert!(validate_path("/users/{user_id}/posts/{post_id}").is_ok()); +/// +/// // Invalid paths +/// assert!(validate_path("users").is_err()); // Missing leading / +/// assert!(validate_path("/users//posts").is_err()); // Double slash +/// assert!(validate_path("/users/{").is_err()); // Unclosed brace +/// assert!(validate_path("/users/{}").is_err()); // Empty parameter +/// assert!(validate_path("/users/{123}").is_err()); // Parameter starts with digit +/// ``` +pub fn validate_path(path: &str) -> Result<(), PathValidationError> { + // Path must start with / + if !path.starts_with('/') { + return Err(PathValidationError::MustStartWithSlash { + path: path.to_string(), + }); + } + + // Check for empty path segments (double slashes) + if path.contains("//") { + return Err(PathValidationError::EmptySegment { + path: path.to_string(), + }); + } + + // Validate path parameter syntax + let mut brace_depth = 0; + let mut param_start = None; + + for (i, ch) in path.char_indices() { + match ch { + '{' => { + if brace_depth > 0 { + return Err(PathValidationError::NestedBraces { + path: path.to_string(), + position: i, + }); + } + brace_depth += 1; + param_start = Some(i); + } + '}' => { + if brace_depth == 0 { + return Err(PathValidationError::UnmatchedClosingBrace { + path: path.to_string(), + position: i, + }); + } + brace_depth -= 1; + + // Check that parameter name is not empty + if let Some(start) = param_start { + let param_name = &path[start + 1..i]; + if param_name.is_empty() { + return Err(PathValidationError::EmptyParameterName { + path: path.to_string(), + position: start, + }); + } + // Validate parameter name contains only valid identifier characters + if !param_name.chars().all(|c| c.is_alphanumeric() || c == '_') { + return Err(PathValidationError::InvalidParameterName { + path: path.to_string(), + param_name: param_name.to_string(), + position: start, + }); + } + // Parameter name must not start with a digit + if param_name.chars().next().map(|c| c.is_ascii_digit()).unwrap_or(false) { + return Err(PathValidationError::ParameterStartsWithDigit { + path: path.to_string(), + param_name: param_name.to_string(), + position: start, + }); + } + } + param_start = None; + } + // Check for invalid characters in path (outside of parameters) + _ if brace_depth == 0 => { + // Allow alphanumeric, -, _, ., /, and common URL characters + if !ch.is_alphanumeric() && !"-_./*".contains(ch) { + return Err(PathValidationError::InvalidCharacter { + path: path.to_string(), + character: ch, + position: i, + }); + } + } + _ => {} + } + } + + // Check for unclosed braces + if brace_depth > 0 { + return Err(PathValidationError::UnclosedBrace { + path: path.to_string(), + }); + } + + Ok(()) +} + +/// Check if a path is valid (convenience function) +pub fn is_valid_path(path: &str) -> bool { + validate_path(path).is_ok() +} + +#[cfg(test)] +mod tests { + use super::*; + use proptest::prelude::*; + + // Unit tests for specific cases + #[test] + fn test_valid_paths() { + assert!(validate_path("/").is_ok()); + assert!(validate_path("/users").is_ok()); + assert!(validate_path("/users/{id}").is_ok()); + assert!(validate_path("/users/{user_id}").is_ok()); + assert!(validate_path("/users/{user_id}/posts").is_ok()); + assert!(validate_path("/users/{user_id}/posts/{post_id}").is_ok()); + assert!(validate_path("/api/v1/users").is_ok()); + assert!(validate_path("/api-v1/users").is_ok()); + assert!(validate_path("/api_v1/users").is_ok()); + assert!(validate_path("/api.v1/users").is_ok()); + assert!(validate_path("/users/*").is_ok()); // Wildcard + } + + #[test] + fn test_missing_leading_slash() { + let result = validate_path("users"); + assert!(matches!(result, Err(PathValidationError::MustStartWithSlash { .. }))); + + let result = validate_path("users/{id}"); + assert!(matches!(result, Err(PathValidationError::MustStartWithSlash { .. }))); + } + + #[test] + fn test_double_slash() { + let result = validate_path("/users//posts"); + assert!(matches!(result, Err(PathValidationError::EmptySegment { .. }))); + + let result = validate_path("//users"); + assert!(matches!(result, Err(PathValidationError::EmptySegment { .. }))); + } + + #[test] + fn test_unclosed_brace() { + let result = validate_path("/users/{id"); + assert!(matches!(result, Err(PathValidationError::UnclosedBrace { .. }))); + + let result = validate_path("/users/{"); + assert!(matches!(result, Err(PathValidationError::UnclosedBrace { .. }))); + } + + #[test] + fn test_unmatched_closing_brace() { + let result = validate_path("/users/id}"); + assert!(matches!(result, Err(PathValidationError::UnmatchedClosingBrace { .. }))); + + let result = validate_path("/users/}"); + assert!(matches!(result, Err(PathValidationError::UnmatchedClosingBrace { .. }))); + } + + #[test] + fn test_empty_parameter_name() { + let result = validate_path("/users/{}"); + assert!(matches!(result, Err(PathValidationError::EmptyParameterName { .. }))); + + let result = validate_path("/users/{}/posts"); + assert!(matches!(result, Err(PathValidationError::EmptyParameterName { .. }))); + } + + #[test] + fn test_nested_braces() { + let result = validate_path("/users/{{id}}"); + assert!(matches!(result, Err(PathValidationError::NestedBraces { .. }))); + + let result = validate_path("/users/{outer{inner}}"); + assert!(matches!(result, Err(PathValidationError::NestedBraces { .. }))); + } + + #[test] + fn test_parameter_starts_with_digit() { + let result = validate_path("/users/{123}"); + assert!(matches!(result, Err(PathValidationError::ParameterStartsWithDigit { .. }))); + + let result = validate_path("/users/{1id}"); + assert!(matches!(result, Err(PathValidationError::ParameterStartsWithDigit { .. }))); + } + + #[test] + fn test_invalid_parameter_name() { + let result = validate_path("/users/{id-name}"); + assert!(matches!(result, Err(PathValidationError::InvalidParameterName { .. }))); + + let result = validate_path("/users/{id.name}"); + assert!(matches!(result, Err(PathValidationError::InvalidParameterName { .. }))); + + let result = validate_path("/users/{id name}"); + assert!(matches!(result, Err(PathValidationError::InvalidParameterName { .. }))); + } + + #[test] + fn test_invalid_characters() { + let result = validate_path("/users?query"); + assert!(matches!(result, Err(PathValidationError::InvalidCharacter { .. }))); + + let result = validate_path("/users#anchor"); + assert!(matches!(result, Err(PathValidationError::InvalidCharacter { .. }))); + + let result = validate_path("/users@domain"); + assert!(matches!(result, Err(PathValidationError::InvalidCharacter { .. }))); + } + + // **Feature: phase4-ergonomics-v1, Property 2: Invalid Path Syntax Rejection** + // + // For any route path string that contains invalid syntax (e.g., unclosed braces, + // invalid characters), the system should reject it with a clear error message. + // + // **Validates: Requirements 1.5** + proptest! { + #![proptest_config(ProptestConfig::with_cases(100))] + + /// Property: Valid paths are accepted + /// + /// For any path that follows the valid path structure: + /// - Starts with / + /// - Contains only valid segments (alphanumeric, -, _, .) + /// - Has properly formed parameters {name} where name is a valid identifier + /// + /// The validation should succeed. + #[test] + fn prop_valid_paths_accepted( + // Generate valid path segments (non-empty to avoid double slashes) + segments in prop::collection::vec("[a-zA-Z][a-zA-Z0-9_-]{0,10}", 0..5), + // Generate valid parameter names (must start with letter or underscore) + params in prop::collection::vec("[a-zA-Z_][a-zA-Z0-9_]{0,10}", 0..3), + ) { + // Build a valid path from segments and parameters + let mut path = String::from("/"); + + for (i, segment) in segments.iter().enumerate() { + if i > 0 { + path.push('/'); + } + path.push_str(segment); + } + + // Add parameters at the end (only if we have segments or it's the root) + for param in params.iter() { + if path != "/" { + path.push('/'); + } + path.push('{'); + path.push_str(param); + path.push('}'); + } + + // If path is just "/", that's valid + // Otherwise ensure we have a valid structure + let result = validate_path(&path); + prop_assert!( + result.is_ok(), + "Valid path '{}' should be accepted, but got error: {:?}", + path, + result.err() + ); + } + + /// Property: Paths without leading slash are rejected + /// + /// For any path that doesn't start with '/', validation should fail + /// with MustStartWithSlash error. + #[test] + fn prop_missing_leading_slash_rejected( + // Generate path content that doesn't start with / + content in "[a-zA-Z][a-zA-Z0-9/_-]{0,20}", + ) { + // Ensure the path doesn't start with / + let path = if content.starts_with('/') { + format!("x{}", content) + } else { + content + }; + + let result = validate_path(&path); + prop_assert!( + matches!(result, Err(PathValidationError::MustStartWithSlash { .. })), + "Path '{}' without leading slash should be rejected with MustStartWithSlash, got: {:?}", + path, + result + ); + } + + /// Property: Paths with unclosed braces are rejected + /// + /// For any path containing an unclosed '{', validation should fail. + #[test] + fn prop_unclosed_brace_rejected( + // Use a valid prefix without double slashes + prefix in "/[a-zA-Z][a-zA-Z0-9_-]{0,10}", + param_start in "[a-zA-Z_][a-zA-Z0-9_]{0,5}", + ) { + // Create a path with an unclosed brace + let path = format!("{}/{{{}", prefix, param_start); + + let result = validate_path(&path); + prop_assert!( + matches!(result, Err(PathValidationError::UnclosedBrace { .. })), + "Path '{}' with unclosed brace should be rejected with UnclosedBrace, got: {:?}", + path, + result + ); + } + + /// Property: Paths with unmatched closing braces are rejected + /// + /// For any path containing a '}' without a matching '{', validation should fail. + #[test] + fn prop_unmatched_closing_brace_rejected( + // Use a valid prefix without double slashes + prefix in "/[a-zA-Z][a-zA-Z0-9_-]{0,10}", + suffix in "[a-zA-Z0-9_]{0,5}", + ) { + // Create a path with an unmatched closing brace + let path = format!("{}/{}}}", prefix, suffix); + + let result = validate_path(&path); + prop_assert!( + matches!(result, Err(PathValidationError::UnmatchedClosingBrace { .. })), + "Path '{}' with unmatched closing brace should be rejected, got: {:?}", + path, + result + ); + } + + /// Property: Paths with empty parameter names are rejected + /// + /// For any path containing '{}', validation should fail with EmptyParameterName. + #[test] + fn prop_empty_parameter_rejected( + // Use a valid prefix without double slashes + prefix in "/[a-zA-Z][a-zA-Z0-9_-]{0,10}", + has_suffix in proptest::bool::ANY, + suffix_content in "[a-zA-Z][a-zA-Z0-9_-]{0,10}", + ) { + // Create a path with an empty parameter + let suffix = if has_suffix { + format!("/{}", suffix_content) + } else { + String::new() + }; + let path = format!("{}/{{}}{}", prefix, suffix); + + let result = validate_path(&path); + prop_assert!( + matches!(result, Err(PathValidationError::EmptyParameterName { .. })), + "Path '{}' with empty parameter should be rejected with EmptyParameterName, got: {:?}", + path, + result + ); + } + + /// Property: Paths with parameters starting with digits are rejected + /// + /// For any path containing a parameter that starts with a digit, + /// validation should fail with ParameterStartsWithDigit. + #[test] + fn prop_parameter_starting_with_digit_rejected( + // Use a valid prefix without double slashes + prefix in "/[a-zA-Z][a-zA-Z0-9_-]{0,10}", + digit in "[0-9]", + rest in "[a-zA-Z0-9_]{0,5}", + ) { + // Create a path with a parameter starting with a digit + let path = format!("{}/{{{}{}}}", prefix, digit, rest); + + let result = validate_path(&path); + prop_assert!( + matches!(result, Err(PathValidationError::ParameterStartsWithDigit { .. })), + "Path '{}' with parameter starting with digit should be rejected, got: {:?}", + path, + result + ); + } + + /// Property: Paths with double slashes are rejected + /// + /// For any path containing '//', validation should fail with EmptySegment. + #[test] + fn prop_double_slash_rejected( + prefix in "/[a-zA-Z0-9_-]{0,10}", + suffix in "[a-zA-Z0-9/_-]{0,10}", + ) { + // Create a path with double slash + let path = format!("{}//{}", prefix, suffix); + + let result = validate_path(&path); + prop_assert!( + matches!(result, Err(PathValidationError::EmptySegment { .. })), + "Path '{}' with double slash should be rejected with EmptySegment, got: {:?}", + path, + result + ); + } + + /// Property: Error messages contain the original path + /// + /// For any invalid path, the error message should contain the original path + /// for debugging purposes. + #[test] + fn prop_error_contains_path( + // Generate various invalid paths + invalid_type in 0..5usize, + content in "[a-zA-Z][a-zA-Z0-9_]{1,10}", + ) { + let path = match invalid_type { + 0 => content.clone(), // Missing leading slash + 1 => format!("/{}//test", content), // Double slash + 2 => format!("/{}/{{", content), // Unclosed brace + 3 => format!("/{}/{{}}", content), // Empty parameter + 4 => format!("/{}/{{1{content}}}", content = content), // Parameter starts with digit + _ => content.clone(), + }; + + let result = validate_path(&path); + if let Err(err) = result { + let error_message = err.to_string(); + prop_assert!( + error_message.contains(&path) || error_message.contains(&content), + "Error message '{}' should contain the path or content for debugging", + error_message + ); + } + } + } +} diff --git a/crates/rustapi-core/src/request.rs b/crates/rustapi-core/src/request.rs index 8c29834..afa84d0 100644 --- a/crates/rustapi-core/src/request.rs +++ b/crates/rustapi-core/src/request.rs @@ -1,4 +1,43 @@ //! Request types for RustAPI +//! +//! This module provides the [`Request`] type which wraps an incoming HTTP request +//! and provides access to all its components. +//! +//! # Accessing Request Data +//! +//! While extractors are the preferred way to access request data in handlers, +//! the `Request` type provides direct access when needed: +//! +//! ```rust,ignore +//! // In middleware or custom extractors +//! fn process_request(req: &Request) { +//! let method = req.method(); +//! let path = req.path(); +//! let headers = req.headers(); +//! let query = req.query_string(); +//! } +//! ``` +//! +//! # Path Parameters +//! +//! Path parameters extracted from the URL pattern are available via: +//! +//! ```rust,ignore +//! // For route "/users/{id}" +//! let id = req.path_param("id"); +//! let all_params = req.path_params(); +//! ``` +//! +//! # Request Body +//! +//! The body can only be consumed once: +//! +//! ```rust,ignore +//! if let Some(body) = req.take_body() { +//! // Process body bytes +//! } +//! // Subsequent calls return None +//! ``` use bytes::Bytes; use http::{request::Parts, Extensions, HeaderMap, Method, Uri, Version}; diff --git a/crates/rustapi-core/src/response.rs b/crates/rustapi-core/src/response.rs index a8adc3a..d798b66 100644 --- a/crates/rustapi-core/src/response.rs +++ b/crates/rustapi-core/src/response.rs @@ -1,4 +1,74 @@ //! Response types for RustAPI +//! +//! This module provides types for building HTTP responses. The core trait is +//! [`IntoResponse`], which allows any type to be converted into an HTTP response. +//! +//! # Built-in Response Types +//! +//! | Type | Status | Content-Type | Description | +//! |------|--------|--------------|-------------| +//! | `String` / `&str` | 200 | text/plain | Plain text response | +//! | `()` | 200 | - | Empty response | +//! | [`Json`] | 200 | application/json | JSON response | +//! | [`Created`] | 201 | application/json | Created resource | +//! | [`NoContent`] | 204 | - | No content response | +//! | [`Html`] | 200 | text/html | HTML response | +//! | [`Redirect`] | 3xx | - | HTTP redirect | +//! | [`WithStatus`] | N | varies | Custom status code | +//! | [`ApiError`] | varies | application/json | Error response | +//! +//! # Example +//! +//! ```rust,ignore +//! use rustapi_core::{Json, Created, NoContent, IntoResponse}; +//! use serde::Serialize; +//! +//! #[derive(Serialize)] +//! struct User { +//! id: i64, +//! name: String, +//! } +//! +//! // Return JSON with 200 OK +//! async fn get_user() -> Json { +//! Json(User { id: 1, name: "Alice".to_string() }) +//! } +//! +//! // Return JSON with 201 Created +//! async fn create_user() -> Created { +//! Created(User { id: 2, name: "Bob".to_string() }) +//! } +//! +//! // Return 204 No Content +//! async fn delete_user() -> NoContent { +//! NoContent +//! } +//! +//! // Return custom status code +//! async fn accepted() -> WithStatus { +//! WithStatus("Request accepted".to_string()) +//! } +//! ``` +//! +//! # Tuple Responses +//! +//! You can also return tuples to customize the response: +//! +//! ```rust,ignore +//! use http::StatusCode; +//! +//! // (StatusCode, body) +//! async fn custom_status() -> (StatusCode, String) { +//! (StatusCode::ACCEPTED, "Accepted".to_string()) +//! } +//! +//! // (StatusCode, headers, body) +//! async fn with_headers() -> (StatusCode, HeaderMap, String) { +//! let mut headers = HeaderMap::new(); +//! headers.insert("X-Custom", "value".parse().unwrap()); +//! (StatusCode::OK, headers, "Hello".to_string()) +//! } +//! ``` use crate::error::{ApiError, ErrorResponse}; use bytes::Bytes; @@ -96,10 +166,11 @@ impl IntoResponse for Result { } // Implement for ApiError -// Implement for ApiError +// Implement for ApiError with environment-aware error masking impl IntoResponse for ApiError { fn into_response(self) -> Response { let status = self.status; + // ErrorResponse::from now handles environment-aware masking let error_response = ErrorResponse::from(self); let body = serde_json::to_vec(&error_response).unwrap_or_else(|_| { br#"{"error":{"type":"internal_error","message":"Failed to serialize error"}}"#.to_vec() diff --git a/crates/rustapi-core/src/router.rs b/crates/rustapi-core/src/router.rs index ab3a74b..3d1b6ac 100644 --- a/crates/rustapi-core/src/router.rs +++ b/crates/rustapi-core/src/router.rs @@ -1,4 +1,45 @@ //! Router implementation using radix tree (matchit) +//! +//! This module provides HTTP routing functionality for RustAPI. Routes are +//! registered using path patterns and HTTP method handlers. +//! +//! # Path Patterns +//! +//! Routes support dynamic path parameters using `{param}` syntax: +//! +//! - `/users` - Static path +//! - `/users/{id}` - Single parameter +//! - `/users/{user_id}/posts/{post_id}` - Multiple parameters +//! +//! # Example +//! +//! ```rust,ignore +//! use rustapi_core::{Router, get, post, put, delete}; +//! +//! async fn list_users() -> &'static str { "List users" } +//! async fn get_user() -> &'static str { "Get user" } +//! async fn create_user() -> &'static str { "Create user" } +//! async fn update_user() -> &'static str { "Update user" } +//! async fn delete_user() -> &'static str { "Delete user" } +//! +//! let router = Router::new() +//! .route("/users", get(list_users).post(create_user)) +//! .route("/users/{id}", get(get_user).put(update_user).delete(delete_user)); +//! ``` +//! +//! # Method Chaining +//! +//! Multiple HTTP methods can be registered for the same path using method chaining: +//! +//! ```rust,ignore +//! .route("/users", get(list).post(create)) +//! .route("/users/{id}", get(show).put(update).delete(destroy)) +//! ``` +//! +//! # Route Conflict Detection +//! +//! The router detects conflicting routes at registration time and provides +//! helpful error messages with resolution guidance. use crate::handler::{into_boxed_handler, BoxedHandler, Handler}; use http::{Extensions, Method}; @@ -7,6 +48,63 @@ use rustapi_openapi::Operation; use std::collections::HashMap; use std::sync::Arc; +/// Information about a registered route for conflict detection +#[derive(Debug, Clone)] +pub struct RouteInfo { + /// The original path pattern (e.g., "/users/{id}") + pub path: String, + /// The HTTP methods registered for this path + pub methods: Vec, +} + +/// Error returned when a route conflict is detected +#[derive(Debug, Clone)] +pub struct RouteConflictError { + /// The path that was being registered + pub new_path: String, + /// The HTTP method that conflicts + pub method: Option, + /// The existing path that conflicts + pub existing_path: String, + /// Detailed error message from the underlying router + pub details: String, +} + +impl std::fmt::Display for RouteConflictError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + writeln!(f, "\n╭─────────────────────────────────────────────────────────────╮")?; + writeln!(f, "│ ROUTE CONFLICT DETECTED │")?; + writeln!(f, "╰─────────────────────────────────────────────────────────────╯")?; + writeln!(f)?; + writeln!(f, " Conflicting routes:")?; + writeln!(f, " → Existing: {}", self.existing_path)?; + writeln!(f, " → New: {}", self.new_path)?; + writeln!(f)?; + if let Some(ref method) = self.method { + writeln!(f, " HTTP Method: {}", method)?; + writeln!(f)?; + } + writeln!(f, " Details: {}", self.details)?; + writeln!(f)?; + writeln!(f, " How to resolve:")?; + writeln!(f, " 1. Use different path patterns for each route")?; + writeln!(f, " 2. If paths must be similar, ensure parameter names differ")?; + writeln!(f, " 3. Consider using different HTTP methods if appropriate")?; + writeln!(f)?; + writeln!(f, " Example:")?; + writeln!(f, " Instead of:")?; + writeln!(f, " .route(\"/users/{{id}}\", get(handler1))")?; + writeln!(f, " .route(\"/users/{{user_id}}\", get(handler2))")?; + writeln!(f)?; + writeln!(f, " Use:")?; + writeln!(f, " .route(\"/users/{{id}}\", get(handler1))")?; + writeln!(f, " .route(\"/users/{{id}}/profile\", get(handler2))")?; + Ok(()) + } +} + +impl std::error::Error for RouteConflictError {} + /// HTTP method router for a single path pub struct MethodRouter { handlers: HashMap, @@ -113,6 +211,8 @@ where pub struct Router { inner: MatchitRouter, state: Arc, + /// Track registered routes for conflict detection + registered_routes: HashMap, } impl Router { @@ -121,6 +221,7 @@ impl Router { Self { inner: MatchitRouter::new(), state: Arc::new(Extensions::new()), + registered_routes: HashMap::new(), } } @@ -129,14 +230,58 @@ impl Router { // Convert {param} style to :param for matchit let matchit_path = convert_path_params(path); + // Get the methods being registered + let methods: Vec = method_router.handlers.keys().cloned().collect(); + match self.inner.insert(matchit_path.clone(), method_router) { - Ok(_) => {} + Ok(_) => { + // Track the registered route + self.registered_routes.insert( + matchit_path.clone(), + RouteInfo { + path: path.to_string(), + methods, + }, + ); + } Err(e) => { - panic!("Route conflict: {} - {}", path, e); + // Find the existing conflicting route + let existing_path = self.find_conflicting_route(&matchit_path) + .map(|info| info.path.clone()) + .unwrap_or_else(|| "".to_string()); + + let conflict_error = RouteConflictError { + new_path: path.to_string(), + method: methods.first().cloned(), + existing_path, + details: e.to_string(), + }; + + panic!("{}", conflict_error); } } self } + + /// Find a conflicting route by checking registered routes + fn find_conflicting_route(&self, matchit_path: &str) -> Option<&RouteInfo> { + // Try to find an exact match first + if let Some(info) = self.registered_routes.get(matchit_path) { + return Some(info); + } + + // Try to find a route that would conflict (same structure but different param names) + let normalized_new = normalize_path_for_comparison(matchit_path); + + for (registered_path, info) in &self.registered_routes { + let normalized_existing = normalize_path_for_comparison(registered_path); + if normalized_new == normalized_existing { + return Some(info); + } + } + + None + } /// Add application state pub fn state(mut self, state: S) -> Self { @@ -184,6 +329,11 @@ impl Router { pub(crate) fn state_ref(&self) -> Arc { self.state.clone() } + + /// Get registered routes (for testing and debugging) + pub fn registered_routes(&self) -> &HashMap { + &self.registered_routes + } } impl Default for Router { @@ -225,6 +375,33 @@ fn convert_path_params(path: &str) -> String { result } +/// Normalize a path for conflict comparison by replacing parameter names with a placeholder +fn normalize_path_for_comparison(path: &str) -> String { + let mut result = String::with_capacity(path.len()); + let mut in_param = false; + + for ch in path.chars() { + match ch { + ':' => { + in_param = true; + result.push_str(":_"); + } + '/' => { + in_param = false; + result.push('/'); + } + _ if in_param => { + // Skip parameter name characters + } + _ => { + result.push(ch); + } + } + } + + result +} + #[cfg(test)] mod tests { use super::*; @@ -238,4 +415,278 @@ mod tests { ); assert_eq!(convert_path_params("/static/path"), "/static/path"); } + + #[test] + fn test_normalize_path_for_comparison() { + assert_eq!(normalize_path_for_comparison("/users/:id"), "/users/:_"); + assert_eq!(normalize_path_for_comparison("/users/:user_id"), "/users/:_"); + assert_eq!( + normalize_path_for_comparison("/users/:id/posts/:post_id"), + "/users/:_/posts/:_" + ); + assert_eq!(normalize_path_for_comparison("/static/path"), "/static/path"); + } + + #[test] + #[should_panic(expected = "ROUTE CONFLICT DETECTED")] + fn test_route_conflict_detection() { + async fn handler1() -> &'static str { "handler1" } + async fn handler2() -> &'static str { "handler2" } + + let _router = Router::new() + .route("/users/{id}", get(handler1)) + .route("/users/{user_id}", get(handler2)); // This should panic + } + + #[test] + fn test_no_conflict_different_paths() { + async fn handler1() -> &'static str { "handler1" } + async fn handler2() -> &'static str { "handler2" } + + let router = Router::new() + .route("/users/{id}", get(handler1)) + .route("/users/{id}/profile", get(handler2)); + + assert_eq!(router.registered_routes().len(), 2); + } + + #[test] + fn test_route_info_tracking() { + async fn handler() -> &'static str { "handler" } + + let router = Router::new() + .route("/users/{id}", get(handler)); + + let routes = router.registered_routes(); + assert_eq!(routes.len(), 1); + + let info = routes.get("/users/:id").unwrap(); + assert_eq!(info.path, "/users/{id}"); + assert_eq!(info.methods.len(), 1); + assert_eq!(info.methods[0], Method::GET); + } +} + +#[cfg(test)] +mod property_tests { + use super::*; + use proptest::prelude::*; + use std::panic::{catch_unwind, AssertUnwindSafe}; + + // **Feature: phase4-ergonomics-v1, Property 1: Route Conflict Detection** + // + // For any two routes with the same path and HTTP method registered on the same + // RustApi instance, the system should detect the conflict and report an error + // at startup time. + // + // **Validates: Requirements 1.2** + proptest! { + #![proptest_config(ProptestConfig::with_cases(100))] + + /// Property: Routes with identical path structure but different parameter names conflict + /// + /// For any valid path with parameters, registering two routes with the same + /// structure but different parameter names should be detected as a conflict. + #[test] + fn prop_same_structure_different_param_names_conflict( + // Generate valid path segments + segments in prop::collection::vec("[a-z][a-z0-9]{0,5}", 1..4), + // Generate two different parameter names + param1 in "[a-z][a-z0-9]{0,5}", + param2 in "[a-z][a-z0-9]{0,5}", + ) { + // Ensure param names are different + prop_assume!(param1 != param2); + + // Build two paths with same structure but different param names + let mut path1 = String::from("/"); + let mut path2 = String::from("/"); + + for segment in &segments { + path1.push_str(segment); + path1.push('/'); + path2.push_str(segment); + path2.push('/'); + } + + path1.push('{'); + path1.push_str(¶m1); + path1.push('}'); + + path2.push('{'); + path2.push_str(¶m2); + path2.push('}'); + + // Try to register both routes - should panic + let result = catch_unwind(AssertUnwindSafe(|| { + async fn handler1() -> &'static str { "handler1" } + async fn handler2() -> &'static str { "handler2" } + + let _router = Router::new() + .route(&path1, get(handler1)) + .route(&path2, get(handler2)); + })); + + prop_assert!( + result.is_err(), + "Routes '{}' and '{}' should conflict but didn't", + path1, path2 + ); + } + + /// Property: Routes with different path structures don't conflict + /// + /// For any two paths with different structures (different number of segments + /// or different static segments), they should not conflict. + #[test] + fn prop_different_structures_no_conflict( + // Generate different path segments for two routes + segments1 in prop::collection::vec("[a-z][a-z0-9]{0,5}", 1..3), + segments2 in prop::collection::vec("[a-z][a-z0-9]{0,5}", 1..3), + // Optional parameter at the end + has_param1 in any::(), + has_param2 in any::(), + ) { + // Build two paths + let mut path1 = String::from("/"); + let mut path2 = String::from("/"); + + for segment in &segments1 { + path1.push_str(segment); + path1.push('/'); + } + path1.pop(); // Remove trailing slash + + for segment in &segments2 { + path2.push_str(segment); + path2.push('/'); + } + path2.pop(); // Remove trailing slash + + if has_param1 { + path1.push_str("/{id}"); + } + + if has_param2 { + path2.push_str("/{id}"); + } + + // Normalize paths for comparison + let norm1 = normalize_path_for_comparison(&convert_path_params(&path1)); + let norm2 = normalize_path_for_comparison(&convert_path_params(&path2)); + + // Only test if paths are actually different + prop_assume!(norm1 != norm2); + + // Try to register both routes - should succeed + let result = catch_unwind(AssertUnwindSafe(|| { + async fn handler1() -> &'static str { "handler1" } + async fn handler2() -> &'static str { "handler2" } + + let router = Router::new() + .route(&path1, get(handler1)) + .route(&path2, get(handler2)); + + router.registered_routes().len() + })); + + prop_assert!( + result.is_ok(), + "Routes '{}' and '{}' should not conflict but did", + path1, path2 + ); + + if let Ok(count) = result { + prop_assert_eq!(count, 2, "Should have registered 2 routes"); + } + } + + /// Property: Conflict error message contains both route paths + /// + /// When a conflict is detected, the error message should include both + /// the existing route path and the new conflicting route path. + #[test] + fn prop_conflict_error_contains_both_paths( + // Generate a valid path segment + segment in "[a-z][a-z0-9]{1,5}", + param1 in "[a-z][a-z0-9]{1,5}", + param2 in "[a-z][a-z0-9]{1,5}", + ) { + prop_assume!(param1 != param2); + + let path1 = format!("/{}/{{{}}}", segment, param1); + let path2 = format!("/{}/{{{}}}", segment, param2); + + let result = catch_unwind(AssertUnwindSafe(|| { + async fn handler1() -> &'static str { "handler1" } + async fn handler2() -> &'static str { "handler2" } + + let _router = Router::new() + .route(&path1, get(handler1)) + .route(&path2, get(handler2)); + })); + + prop_assert!(result.is_err(), "Should have panicked due to conflict"); + + // Check that the panic message contains useful information + if let Err(panic_info) = result { + if let Some(msg) = panic_info.downcast_ref::() { + prop_assert!( + msg.contains("ROUTE CONFLICT DETECTED"), + "Error should contain 'ROUTE CONFLICT DETECTED', got: {}", + msg + ); + prop_assert!( + msg.contains("Existing:") && msg.contains("New:"), + "Error should contain both 'Existing:' and 'New:' labels, got: {}", + msg + ); + prop_assert!( + msg.contains("How to resolve:"), + "Error should contain resolution guidance, got: {}", + msg + ); + } + } + } + + /// Property: Exact duplicate paths conflict + /// + /// Registering the exact same path twice should always be detected as a conflict. + #[test] + fn prop_exact_duplicate_paths_conflict( + // Generate valid path segments + segments in prop::collection::vec("[a-z][a-z0-9]{0,5}", 1..4), + has_param in any::(), + ) { + // Build a path + let mut path = String::from("/"); + + for segment in &segments { + path.push_str(segment); + path.push('/'); + } + path.pop(); // Remove trailing slash + + if has_param { + path.push_str("/{id}"); + } + + // Try to register the same path twice - should panic + let result = catch_unwind(AssertUnwindSafe(|| { + async fn handler1() -> &'static str { "handler1" } + async fn handler2() -> &'static str { "handler2" } + + let _router = Router::new() + .route(&path, get(handler1)) + .route(&path, get(handler2)); + })); + + prop_assert!( + result.is_err(), + "Registering path '{}' twice should conflict but didn't", + path + ); + } + } } diff --git a/crates/rustapi-core/src/test_client.rs b/crates/rustapi-core/src/test_client.rs new file mode 100644 index 0000000..7b17691 --- /dev/null +++ b/crates/rustapi-core/src/test_client.rs @@ -0,0 +1,748 @@ +//! TestClient for integration testing without network binding +//! +//! This module provides a test client that allows sending simulated HTTP requests +//! through the full middleware and handler pipeline without starting a real server. +//! +//! # Example +//! +//! ```rust,ignore +//! use rustapi_core::{RustApi, TestClient, get}; +//! +//! async fn hello() -> &'static str { +//! "Hello, World!" +//! } +//! +//! #[tokio::test] +//! async fn test_hello() { +//! let app = RustApi::new().route("/", get(hello)); +//! let client = TestClient::new(app); +//! +//! let response = client.get("/").await; +//! response.assert_status(200); +//! assert_eq!(response.text(), "Hello, World!"); +//! } +//! ``` + +use crate::middleware::{BoxedNext, LayerStack, BodyLimitLayer, DEFAULT_BODY_LIMIT}; +use crate::request::Request; +use crate::response::Response; +use crate::router::{RouteMatch, Router}; +use crate::error::ApiError; +use crate::response::IntoResponse; +use bytes::Bytes; +use http::{header, HeaderMap, HeaderValue, Method, StatusCode}; +use http_body_util::BodyExt; +use serde::{de::DeserializeOwned, Serialize}; +use std::future::Future; +use std::pin::Pin; +use std::sync::Arc; + +/// Test client for integration testing without network binding +/// +/// TestClient wraps a RustApi instance and allows sending simulated HTTP requests +/// through the full middleware and handler pipeline. +pub struct TestClient { + router: Arc, + layers: Arc, +} + +impl TestClient { + /// Create a new test client from a RustApi instance + /// + /// # Example + /// + /// ```rust,ignore + /// let app = RustApi::new().route("/", get(handler)); + /// let client = TestClient::new(app); + /// ``` + pub fn new(app: crate::app::RustApi) -> Self { + // Get the router and layers from the app + let layers = app.layers().clone(); + let router = app.into_router(); + + // Apply body limit layer if not already present + let mut layers = layers; + layers.prepend(Box::new(BodyLimitLayer::new(DEFAULT_BODY_LIMIT))); + + Self { + router: Arc::new(router), + layers: Arc::new(layers), + } + } + + /// Create a new test client with custom body limit + pub fn with_body_limit(app: crate::app::RustApi, limit: usize) -> Self { + let layers = app.layers().clone(); + let router = app.into_router(); + + let mut layers = layers; + layers.prepend(Box::new(BodyLimitLayer::new(limit))); + + Self { + router: Arc::new(router), + layers: Arc::new(layers), + } + } + + /// Send a GET request + /// + /// # Example + /// + /// ```rust,ignore + /// let response = client.get("/users").await; + /// ``` + pub async fn get(&self, path: &str) -> TestResponse { + self.request(TestRequest::get(path)).await + } + + /// Send a POST request with JSON body + /// + /// # Example + /// + /// ```rust,ignore + /// let response = client.post_json("/users", &CreateUser { name: "Alice" }).await; + /// ``` + pub async fn post_json(&self, path: &str, body: &T) -> TestResponse { + self.request(TestRequest::post(path).json(body)).await + } + + /// Send a request with full control + /// + /// # Example + /// + /// ```rust,ignore + /// let response = client.request( + /// TestRequest::put("/users/1") + /// .header("Authorization", "Bearer token") + /// .json(&UpdateUser { name: "Bob" }) + /// ).await; + /// ``` + pub async fn request(&self, req: TestRequest) -> TestResponse { + let method = req.method.clone(); + let path = req.path.clone(); + + // Match the route to get path params + let (handler, params) = match self.router.match_route(&path, &method) { + RouteMatch::Found { handler, params } => (handler.clone(), params), + RouteMatch::NotFound => { + let response = ApiError::not_found(format!("No route found for {} {}", method, path)) + .into_response(); + return TestResponse::from_response(response).await; + } + RouteMatch::MethodNotAllowed { allowed } => { + let allowed_str: Vec<&str> = allowed.iter().map(|m| m.as_str()).collect(); + let mut response = ApiError::new( + StatusCode::METHOD_NOT_ALLOWED, + "method_not_allowed", + format!("Method {} not allowed for {}", method, path), + ) + .into_response(); + + response.headers_mut().insert( + header::ALLOW, + allowed_str.join(", ").parse().unwrap(), + ); + return TestResponse::from_response(response).await; + } + }; + + // Build the internal Request + let uri: http::Uri = path.parse().unwrap_or_else(|_| "/".parse().unwrap()); + let mut builder = http::Request::builder() + .method(method) + .uri(uri); + + // Add headers + for (key, value) in req.headers.iter() { + builder = builder.header(key, value); + } + + let http_req = builder.body(()).unwrap(); + let (parts, _) = http_req.into_parts(); + + let body_bytes = req.body.unwrap_or_default(); + + let request = Request::new( + parts, + body_bytes, + self.router.state_ref(), + params, + ); + + // Create the final handler as a BoxedNext + let final_handler: BoxedNext = Arc::new(move |req: Request| { + let handler = handler.clone(); + Box::pin(async move { handler(req).await }) + as Pin + Send + 'static>> + }); + + // Execute through middleware stack + let response = self.layers.execute(request, final_handler).await; + + TestResponse::from_response(response).await + } +} + +/// Test request builder +/// +/// Provides a fluent API for building test requests with custom methods, +/// headers, and body content. +#[derive(Debug, Clone)] +pub struct TestRequest { + method: Method, + path: String, + headers: HeaderMap, + body: Option, +} + +impl TestRequest { + /// Create a new request with the given method and path + fn new(method: Method, path: &str) -> Self { + Self { + method, + path: path.to_string(), + headers: HeaderMap::new(), + body: None, + } + } + + /// Create a GET request + pub fn get(path: &str) -> Self { + Self::new(Method::GET, path) + } + + /// Create a POST request + pub fn post(path: &str) -> Self { + Self::new(Method::POST, path) + } + + /// Create a PUT request + pub fn put(path: &str) -> Self { + Self::new(Method::PUT, path) + } + + /// Create a PATCH request + pub fn patch(path: &str) -> Self { + Self::new(Method::PATCH, path) + } + + /// Create a DELETE request + pub fn delete(path: &str) -> Self { + Self::new(Method::DELETE, path) + } + + /// Add a header to the request + /// + /// # Example + /// + /// ```rust,ignore + /// let req = TestRequest::get("/") + /// .header("Authorization", "Bearer token") + /// .header("Accept", "application/json"); + /// ``` + pub fn header(mut self, key: &str, value: &str) -> Self { + if let (Ok(name), Ok(val)) = ( + key.parse::(), + HeaderValue::from_str(value), + ) { + self.headers.insert(name, val); + } + self + } + + /// Set the request body as JSON + /// + /// This automatically sets the Content-Type header to `application/json`. + /// + /// # Example + /// + /// ```rust,ignore + /// let req = TestRequest::post("/users") + /// .json(&CreateUser { name: "Alice" }); + /// ``` + pub fn json(mut self, body: &T) -> Self { + match serde_json::to_vec(body) { + Ok(bytes) => { + self.body = Some(Bytes::from(bytes)); + self.headers.insert( + header::CONTENT_TYPE, + HeaderValue::from_static("application/json"), + ); + } + Err(_) => { + // If serialization fails, leave body empty + } + } + self + } + + /// Set the request body as raw bytes + /// + /// # Example + /// + /// ```rust,ignore + /// let req = TestRequest::post("/upload") + /// .body("raw content"); + /// ``` + pub fn body(mut self, body: impl Into) -> Self { + self.body = Some(body.into()); + self + } + + /// Set the Content-Type header + pub fn content_type(self, content_type: &str) -> Self { + self.header("content-type", content_type) + } +} + +/// Test response with assertion helpers +/// +/// Provides methods to inspect and assert on the response status, headers, and body. +#[derive(Debug)] +pub struct TestResponse { + status: StatusCode, + headers: HeaderMap, + body: Bytes, +} + +impl TestResponse { + /// Create a TestResponse from an HTTP response + async fn from_response(response: Response) -> Self { + let (parts, body) = response.into_parts(); + let body_bytes = body.collect().await + .map(|b| b.to_bytes()) + .unwrap_or_default(); + + Self { + status: parts.status, + headers: parts.headers, + body: body_bytes, + } + } + + /// Get the response status code + pub fn status(&self) -> StatusCode { + self.status + } + + /// Get the response headers + pub fn headers(&self) -> &HeaderMap { + &self.headers + } + + /// Get the response body as bytes + pub fn body(&self) -> &Bytes { + &self.body + } + + /// Get the response body as a string + /// + /// Returns an empty string if the body is not valid UTF-8. + pub fn text(&self) -> String { + String::from_utf8_lossy(&self.body).to_string() + } + + /// Parse the response body as JSON + /// + /// # Example + /// + /// ```rust,ignore + /// let user: User = response.json().unwrap(); + /// ``` + pub fn json(&self) -> Result { + serde_json::from_slice(&self.body) + } + + /// Assert that the response has the expected status code + /// + /// # Panics + /// + /// Panics if the status code doesn't match. + /// + /// # Example + /// + /// ```rust,ignore + /// response.assert_status(StatusCode::OK); + /// response.assert_status(200); + /// ``` + pub fn assert_status>(&self, expected: S) -> &Self { + let expected = expected.into(); + assert_eq!( + self.status, expected, + "Expected status {}, got {}. Body: {}", + expected, self.status, self.text() + ); + self + } + + /// Assert that the response has the expected header value + /// + /// # Panics + /// + /// Panics if the header doesn't exist or doesn't match. + /// + /// # Example + /// + /// ```rust,ignore + /// response.assert_header("content-type", "application/json"); + /// ``` + pub fn assert_header(&self, key: &str, expected: &str) -> &Self { + let actual = self.headers + .get(key) + .and_then(|v| v.to_str().ok()) + .unwrap_or(""); + + assert_eq!( + actual, expected, + "Expected header '{}' to be '{}', got '{}'", + key, expected, actual + ); + self + } + + /// Assert that the response body matches the expected JSON value + /// + /// # Panics + /// + /// Panics if the body can't be parsed as JSON or doesn't match. + /// + /// # Example + /// + /// ```rust,ignore + /// response.assert_json(&User { id: 1, name: "Alice".to_string() }); + /// ``` + pub fn assert_json(&self, expected: &T) -> &Self { + let actual: T = self.json().expect("Failed to parse response body as JSON"); + assert_eq!( + &actual, expected, + "JSON body mismatch" + ); + self + } + + /// Assert that the response body contains the expected string + /// + /// # Panics + /// + /// Panics if the body doesn't contain the expected string. + pub fn assert_body_contains(&self, expected: &str) -> &Self { + let body = self.text(); + assert!( + body.contains(expected), + "Expected body to contain '{}', got '{}'", + expected, body + ); + self + } +} + + +#[cfg(test)] +mod tests { + use super::*; + use crate::app::RustApi; + use crate::router::get; + use proptest::prelude::*; + use serde::{Deserialize, Serialize}; + + // Simple handler for testing + async fn hello() -> &'static str { + "Hello, World!" + } + + // Handler that returns JSON as string + async fn json_string_handler() -> String { + r#"{"message":"test","count":42}"#.to_string() + } + + // JSON data structure for testing + #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] + struct TestData { + message: String, + count: i32, + } + + // Handler that echoes body as string + async fn echo_body(body: crate::extract::Body) -> String { + String::from_utf8_lossy(&body.0).to_string() + } + + #[tokio::test] + async fn test_client_get_request() { + let app = RustApi::new().route("/", get(hello)); + let client = TestClient::new(app); + + let response = client.get("/").await; + assert_eq!(response.status(), StatusCode::OK); + assert_eq!(response.text(), "Hello, World!"); + } + + #[tokio::test] + async fn test_client_not_found() { + let app = RustApi::new().route("/", get(hello)); + let client = TestClient::new(app); + + let response = client.get("/nonexistent").await; + assert_eq!(response.status(), StatusCode::NOT_FOUND); + } + + #[tokio::test] + async fn test_client_json_response() { + let app = RustApi::new().route("/json", get(json_string_handler)); + let client = TestClient::new(app); + + let response = client.get("/json").await; + response.assert_status(StatusCode::OK); + + let data: TestData = response.json().unwrap(); + assert_eq!(data.message, "test"); + assert_eq!(data.count, 42); + } + + #[tokio::test] + async fn test_client_post_json() { + let app = RustApi::new().route("/echo", crate::router::post(echo_body)); + let client = TestClient::new(app); + + let input = TestData { + message: "hello".to_string(), + count: 123, + }; + + let response = client.post_json("/echo", &input).await; + response.assert_status(StatusCode::OK); + + let output: TestData = response.json().unwrap(); + assert_eq!(output, input); + } + + #[tokio::test] + async fn test_request_builder_methods() { + // Test all HTTP methods are available + let get_req = TestRequest::get("/test"); + assert_eq!(get_req.method, Method::GET); + + let post_req = TestRequest::post("/test"); + assert_eq!(post_req.method, Method::POST); + + let put_req = TestRequest::put("/test"); + assert_eq!(put_req.method, Method::PUT); + + let patch_req = TestRequest::patch("/test"); + assert_eq!(patch_req.method, Method::PATCH); + + let delete_req = TestRequest::delete("/test"); + assert_eq!(delete_req.method, Method::DELETE); + } + + #[tokio::test] + async fn test_request_builder_headers() { + let req = TestRequest::get("/test") + .header("Authorization", "Bearer token") + .header("Accept", "application/json"); + + assert!(req.headers.contains_key("authorization")); + assert!(req.headers.contains_key("accept")); + } + + #[tokio::test] + async fn test_request_builder_json_sets_content_type() { + let data = TestData { + message: "test".to_string(), + count: 1, + }; + + let req = TestRequest::post("/test").json(&data); + + assert!(req.body.is_some()); + assert_eq!( + req.headers.get(header::CONTENT_TYPE).unwrap(), + "application/json" + ); + } + + #[tokio::test] + async fn test_response_assertions() { + let app = RustApi::new().route("/json", get(json_string_handler)); + let client = TestClient::new(app); + + let response = client.get("/json").await; + + // Chain assertions + response + .assert_status(StatusCode::OK) + .assert_body_contains("test"); + } + + #[tokio::test] + async fn test_response_assert_json() { + let app = RustApi::new().route("/json", get(json_string_handler)); + let client = TestClient::new(app); + + let response = client.get("/json").await; + + let expected = TestData { + message: "test".to_string(), + count: 42, + }; + + response.assert_json(&expected); + } + + // **Feature: phase4-ergonomics-v1, Property 10: TestClient Request/Response Handling** + // + // For any request sent through TestClient, it should be processed through the full + // middleware and handler pipeline, and the response should be accessible with correct + // status, headers, and body. When sending JSON, the Content-Type header should be + // automatically set to `application/json`. + // + // **Validates: Requirements 6.1, 6.2, 6.3, 6.4** + proptest! { + #![proptest_config(ProptestConfig::with_cases(100))] + + #[test] + fn prop_test_client_request_response_handling( + message in "[a-zA-Z0-9 ]{1,50}", + count in 0i32..1000, + ) { + let rt = tokio::runtime::Runtime::new().unwrap(); + rt.block_on(async { + // Create app with echo handler + let app = RustApi::new().route("/echo", crate::router::post(echo_body)); + let client = TestClient::new(app); + + // Create test data + let input = TestData { + message: message.clone(), + count, + }; + + // Send request through TestClient + let response = client.post_json("/echo", &input).await; + + // Verify response status is accessible + prop_assert_eq!(response.status(), StatusCode::OK); + + // Verify response body is accessible and correct + let output: TestData = response.json().expect("Should parse JSON"); + prop_assert_eq!(output.message, message); + prop_assert_eq!(output.count, count); + + Ok(()) + })?; + } + + #[test] + fn prop_test_client_json_content_type_auto_set( + message in "[a-zA-Z0-9]{1,20}", + ) { + // Verify that when sending JSON, Content-Type is automatically set + let data = TestData { + message, + count: 1, + }; + + let req = TestRequest::post("/test").json(&data); + + // Content-Type should be set to application/json + let content_type = req.headers.get(header::CONTENT_TYPE); + prop_assert!(content_type.is_some()); + prop_assert_eq!( + content_type.unwrap().to_str().unwrap(), + "application/json" + ); + + // Body should be set + prop_assert!(req.body.is_some()); + } + + #[test] + fn prop_test_client_processes_through_middleware( + path in "/[a-z]{1,10}", + ) { + let rt = tokio::runtime::Runtime::new().unwrap(); + rt.block_on(async { + // Create app with a simple handler + let app = RustApi::new().route(&path, get(hello)); + let client = TestClient::new(app); + + // Request should go through middleware pipeline + let response = client.get(&path).await; + + // Should get successful response + prop_assert_eq!(response.status(), StatusCode::OK); + prop_assert_eq!(response.text(), "Hello, World!"); + + Ok(()) + })?; + } + + #[test] + fn prop_test_client_not_found_for_unregistered_paths( + registered_path in "/[a-z]{1,5}", + unregistered_path in "/[a-z]{6,10}", + ) { + let rt = tokio::runtime::Runtime::new().unwrap(); + rt.block_on(async { + // Create app with one route + let app = RustApi::new().route(®istered_path, get(hello)); + let client = TestClient::new(app); + + // Request to unregistered path should return 404 + let response = client.get(&unregistered_path).await; + prop_assert_eq!(response.status(), StatusCode::NOT_FOUND); + + Ok(()) + })?; + } + } + + #[tokio::test] + async fn test_client_method_not_allowed() { + let app = RustApi::new().route("/get-only", get(hello)); + let client = TestClient::new(app); + + // POST to a GET-only route should return 405 + let response = client.request(TestRequest::post("/get-only")).await; + assert_eq!(response.status(), StatusCode::METHOD_NOT_ALLOWED); + + // Should have Allow header + assert!(response.headers().contains_key(header::ALLOW)); + } + + #[tokio::test] + async fn test_client_custom_headers() { + // Handler that echoes back a specific header value + async fn echo_header(body: crate::extract::Body) -> String { + // For this test, we just verify the request goes through + // The header checking is done via the body echo + String::from_utf8_lossy(&body.0).to_string() + } + + let app = RustApi::new().route("/check", crate::router::post(echo_header)); + let client = TestClient::new(app); + + let response = client.request( + TestRequest::post("/check") + .header("X-Custom-Header", "test-value") + .body("test body") + ).await; + + assert_eq!(response.status(), StatusCode::OK); + assert_eq!(response.text(), "test body"); + } + + #[tokio::test] + async fn test_client_raw_body() { + let app = RustApi::new().route("/echo", crate::router::post(echo_body)); + let client = TestClient::new(app); + + let response = client.request( + TestRequest::post("/echo") + .body("raw body content") + ).await; + + assert_eq!(response.status(), StatusCode::OK); + assert_eq!(response.text(), "raw body content"); + } +} diff --git a/crates/rustapi-extras/Cargo.toml b/crates/rustapi-extras/Cargo.toml index 9ffbef2..bdda1d9 100644 --- a/crates/rustapi-extras/Cargo.toml +++ b/crates/rustapi-extras/Cargo.toml @@ -12,6 +12,7 @@ documentation.workspace = true [dependencies] # Core dependency rustapi-core = { workspace = true } +rustapi-openapi = { workspace = true } # Async tokio = { workspace = true } diff --git a/crates/rustapi-extras/src/config/mod.rs b/crates/rustapi-extras/src/config/mod.rs index 7363590..fa94ad1 100644 --- a/crates/rustapi-extras/src/config/mod.rs +++ b/crates/rustapi-extras/src/config/mod.rs @@ -515,13 +515,19 @@ mod tests { std::env::remove_var("UNIT_TEST_NUMBER"); } + #[derive(Debug, Deserialize, PartialEq)] + struct MissingVarTestConfig { + missing_var_test_string: String, + missing_var_test_number: u32, + } + #[test] fn test_config_from_env_missing_var() { - // Ensure the variable doesn't exist - std::env::remove_var("UNIT_TEST_STRING"); - std::env::remove_var("UNIT_TEST_NUMBER"); + // Ensure the variables don't exist (use unique names to avoid race conditions) + std::env::remove_var("MISSING_VAR_TEST_STRING"); + std::env::remove_var("MISSING_VAR_TEST_NUMBER"); - let result = Config::::from_env(); + let result = Config::::from_env(); assert!(result.is_err()); } diff --git a/crates/rustapi-extras/src/jwt/mod.rs b/crates/rustapi-extras/src/jwt/mod.rs index cf35458..60d3a8e 100644 --- a/crates/rustapi-extras/src/jwt/mod.rs +++ b/crates/rustapi-extras/src/jwt/mod.rs @@ -26,6 +26,7 @@ use http_body_util::Full; use jsonwebtoken::{decode, DecodingKey, Validation}; use rustapi_core::middleware::{BoxedNext, MiddlewareLayer}; use rustapi_core::{ApiError, FromRequestParts, Request, Response, Result}; +use rustapi_openapi::{Operation, OperationModifier}; use serde::de::DeserializeOwned; use serde::Serialize; use std::future::Future; @@ -85,13 +86,15 @@ impl JwtValidation { /// } /// /// let app = RustApi::new() -/// .layer(JwtLayer::::new("my-secret-key")) +/// .layer(JwtLayer::::new("my-secret-key") +/// .skip_paths(vec!["/health", "/docs", "/auth/login"])) /// .route("/protected", get(protected_handler)); /// ``` #[derive(Clone)] pub struct JwtLayer { secret: Arc, validation: JwtValidation, + skip_paths: Arc>, _claims: PhantomData, } @@ -101,6 +104,7 @@ impl JwtLayer { Self { secret: Arc::new(secret.into()), validation: JwtValidation::default(), + skip_paths: Arc::new(Vec::new()), _claims: PhantomData, } } @@ -111,6 +115,22 @@ impl JwtLayer { self } + /// Skip JWT validation for specific paths. + /// + /// Paths that start with any of the provided prefixes will bypass JWT validation. + /// This is useful for public endpoints like health checks, documentation, and login. + /// + /// # Example + /// + /// ```ignore + /// let layer = JwtLayer::::new("secret") + /// .skip_paths(vec!["/health", "/docs", "/auth/login"]); + /// ``` + pub fn skip_paths(mut self, paths: Vec<&str>) -> Self { + self.skip_paths = Arc::new(paths.into_iter().map(String::from).collect()); + self + } + /// Get the configured secret. pub fn secret(&self) -> &str { &self.secret @@ -141,8 +161,15 @@ impl MiddlewareLayer for Jw ) -> Pin + Send + 'static>> { let secret = self.secret.clone(); let validation = self.validation.clone(); + let skip_paths = self.skip_paths.clone(); Box::pin(async move { + // Check if this path should skip JWT validation + let path = req.uri().path(); + if skip_paths.iter().any(|skip| path.starts_with(skip)) { + return next(req).await; + } + // Extract the Authorization header let auth_header = req.headers().get(http::header::AUTHORIZATION); @@ -299,6 +326,35 @@ impl FromRequestParts for AuthUser { } } +// Implement OperationModifier for AuthUser to enable use in handlers +impl OperationModifier for AuthUser { + fn update_operation(op: &mut Operation) { + // Add 401 Unauthorized response to OpenAPI spec + use rustapi_openapi::{MediaType, ResponseSpec, SchemaRef}; + use std::collections::HashMap; + + op.responses.insert( + "401".to_string(), + ResponseSpec { + description: "Unauthorized - Invalid or missing JWT token".to_string(), + content: { + let mut map = HashMap::new(); + map.insert( + "application/json".to_string(), + MediaType { + schema: SchemaRef::Ref { + reference: "#/components/schemas/ErrorSchema".to_string(), + }, + }, + ); + Some(map) + }, + ..Default::default() + }, + ); + } +} + /// Helper function to create a JWT token (useful for testing) /// /// # Example diff --git a/crates/rustapi-macros/src/lib.rs b/crates/rustapi-macros/src/lib.rs index c3efcb4..4e67000 100644 --- a/crates/rustapi-macros/src/lib.rs +++ b/crates/rustapi-macros/src/lib.rs @@ -8,11 +8,149 @@ //! - `#[rustapi::put("/path")]` - PUT route handler //! - `#[rustapi::patch("/path")]` - PATCH route handler //! - `#[rustapi::delete("/path")]` - DELETE route handler +//! +//! ## Debugging +//! +//! Set `RUSTAPI_DEBUG=1` environment variable during compilation to see +//! expanded macro output for debugging purposes. use proc_macro::TokenStream; use quote::quote; use syn::{parse_macro_input, ItemFn, LitStr}; +/// Check if RUSTAPI_DEBUG is enabled at compile time +fn is_debug_enabled() -> bool { + std::env::var("RUSTAPI_DEBUG") + .map(|v| v == "1" || v.eq_ignore_ascii_case("true")) + .unwrap_or(false) +} + +/// Print debug output if RUSTAPI_DEBUG=1 is set +fn debug_output(name: &str, tokens: &proc_macro2::TokenStream) { + if is_debug_enabled() { + eprintln!("\n=== RUSTAPI_DEBUG: {} ===", name); + eprintln!("{}", tokens); + eprintln!("=== END {} ===\n", name); + } +} + +/// Validate route path syntax at compile time +/// +/// Returns Ok(()) if the path is valid, or Err with a descriptive error message. +fn validate_path_syntax(path: &str, span: proc_macro2::Span) -> Result<(), syn::Error> { + // Path must start with / + if !path.starts_with('/') { + return Err(syn::Error::new( + span, + format!("route path must start with '/', got: \"{}\"", path), + )); + } + + // Check for empty path segments (double slashes) + if path.contains("//") { + return Err(syn::Error::new( + span, + format!("route path contains empty segment (double slash): \"{}\"", path), + )); + } + + // Validate path parameter syntax + let mut brace_depth = 0; + let mut param_start = None; + + for (i, ch) in path.char_indices() { + match ch { + '{' => { + if brace_depth > 0 { + return Err(syn::Error::new( + span, + format!( + "nested braces are not allowed in route path at position {}: \"{}\"", + i, path + ), + )); + } + brace_depth += 1; + param_start = Some(i); + } + '}' => { + if brace_depth == 0 { + return Err(syn::Error::new( + span, + format!( + "unmatched closing brace '}}' at position {} in route path: \"{}\"", + i, path + ), + )); + } + brace_depth -= 1; + + // Check that parameter name is not empty + if let Some(start) = param_start { + let param_name = &path[start + 1..i]; + if param_name.is_empty() { + return Err(syn::Error::new( + span, + format!( + "empty parameter name '{{}}' at position {} in route path: \"{}\"", + start, path + ), + )); + } + // Validate parameter name contains only valid identifier characters + if !param_name.chars().all(|c| c.is_alphanumeric() || c == '_') { + return Err(syn::Error::new( + span, + format!( + "invalid parameter name '{{{}}}' at position {} - parameter names must contain only alphanumeric characters and underscores: \"{}\"", + param_name, start, path + ), + )); + } + // Parameter name must not start with a digit + if param_name.chars().next().map(|c| c.is_ascii_digit()).unwrap_or(false) { + return Err(syn::Error::new( + span, + format!( + "parameter name '{{{}}}' cannot start with a digit at position {}: \"{}\"", + param_name, start, path + ), + )); + } + } + param_start = None; + } + // Check for invalid characters in path (outside of parameters) + _ if brace_depth == 0 => { + // Allow alphanumeric, -, _, ., /, and common URL characters + if !ch.is_alphanumeric() && !"-_./*".contains(ch) { + return Err(syn::Error::new( + span, + format!( + "invalid character '{}' at position {} in route path: \"{}\"", + ch, i, path + ), + )); + } + } + _ => {} + } + } + + // Check for unclosed braces + if brace_depth > 0 { + return Err(syn::Error::new( + span, + format!( + "unclosed brace '{{' in route path (missing closing '}}'): \"{}\"", + path + ), + )); + } + + Ok(()) +} + /// Main entry point macro for RustAPI applications /// /// This macro wraps your async main function with the tokio runtime. @@ -47,6 +185,8 @@ pub fn main(_attr: TokenStream, item: TokenStream) -> TokenStream { } }; + debug_output("main", &expanded); + TokenStream::from(expanded) } @@ -66,6 +206,11 @@ fn generate_route_handler(method: &str, attr: TokenStream, item: TokenStream) -> let path_value = path.value(); + // Validate path syntax at compile time + if let Err(err) = validate_path_syntax(&path_value, path.span()) { + return err.to_compile_error().into(); + } + // Generate a companion module with route info let route_fn_name = syn::Ident::new( &format!("{}_route", fn_name), @@ -122,6 +267,8 @@ fn generate_route_handler(method: &str, attr: TokenStream, item: TokenStream) -> } }; + debug_output(&format!("{} {}", method, path_value), &expanded); + TokenStream::from(expanded) } diff --git a/examples/auth-api/Cargo.toml b/examples/auth-api/Cargo.toml new file mode 100644 index 0000000..feac445 --- /dev/null +++ b/examples/auth-api/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "auth-api" +version = "0.1.0" +edition = "2021" +publish = false +description = "Authentication example demonstrating JWT middleware" + +[dependencies] +rustapi-rs = { path = "../../crates/rustapi-rs", features = ["jwt", "rate-limit"] } +tokio = { version = "1.35", features = ["full"] } +serde = { version = "1.0", features = ["derive"] } +validator = { workspace = true } +utoipa = { workspace = true } diff --git a/examples/auth-api/src/main.rs b/examples/auth-api/src/main.rs new file mode 100644 index 0000000..5b76e18 --- /dev/null +++ b/examples/auth-api/src/main.rs @@ -0,0 +1,230 @@ +//! Authentication API Example for RustAPI +//! +//! This example demonstrates: +//! - JWT authentication middleware +//! - Protected routes +//! - Rate limiting +//! +//! Run with: cargo run -p auth-api +//! Then visit: http://127.0.0.1:8080/docs +//! +//! ## Testing the API +//! +//! 1. Login to get a token: +//! ```bash +//! curl -X POST http://127.0.0.1:8080/auth/login \ +//! -H "Content-Type: application/json" \ +//! -d '{"username": "admin", "password": "secret"}' +//! ``` +//! +//! 2. Access protected route with token: +//! ```bash +//! curl http://127.0.0.1:8080/protected/profile \ +//! -H "Authorization: Bearer " +//! ``` + +use rustapi_rs::prelude::*; +use std::time::{Duration, SystemTime, UNIX_EPOCH}; + +// ============================================ +// Configuration +// ============================================ + +/// JWT secret key (in production, use environment variable) +const JWT_SECRET: &str = "super-secret-key-change-in-production"; + +/// Token expiration time (1 hour) +const TOKEN_EXPIRY_SECS: u64 = 3600; + +// ============================================ +// Data Models +// ============================================ + +/// JWT claims structure +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Claims { + /// Subject (user ID) + pub sub: String, + /// Username + pub username: String, + /// User role + pub role: String, + /// Expiration timestamp + pub exp: u64, +} + +/// Login request body +#[derive(Debug, Deserialize, Validate, Schema)] +pub struct LoginRequest { + #[validate(length(min = 1, max = 50))] + pub username: String, + #[validate(length(min = 1, max = 100))] + pub password: String, +} + +/// Login response with JWT token +#[derive(Debug, Serialize, Schema)] +pub struct LoginResponse { + pub token: String, + pub token_type: String, + pub expires_in: u64, +} + +/// User profile response +#[derive(Debug, Serialize, Schema)] +pub struct UserProfile { + pub user_id: String, + pub username: String, + pub role: String, +} + +/// Public message response +#[derive(Debug, Serialize, Schema)] +pub struct Message { + pub message: String, +} + +// ============================================ +// Public Handlers (No Auth Required) +// ============================================ + +/// Public endpoint - no authentication required +#[rustapi_rs::get("/")] +#[rustapi_rs::tag("Public")] +#[rustapi_rs::summary("Welcome")] +async fn welcome() -> Json { + Json(Message { + message: "Welcome to the Auth API! Login at /auth/login".to_string(), + }) +} + +/// Health check endpoint +#[rustapi_rs::get("/health")] +#[rustapi_rs::tag("Public")] +#[rustapi_rs::summary("Health Check")] +async fn health() -> &'static str { + "OK" +} + +/// Login endpoint - returns JWT token +#[rustapi_rs::post("/auth/login")] +#[rustapi_rs::tag("Authentication")] +#[rustapi_rs::summary("Login")] +#[rustapi_rs::description("Authenticate with username and password to receive a JWT token.")] +async fn login(ValidatedJson(body): ValidatedJson) -> Result, ApiError> { + // In production, verify credentials against database + // For demo, accept admin/secret + if body.username != "admin" || body.password != "secret" { + return Err(ApiError::unauthorized("Invalid username or password")); + } + + // Calculate expiration time + let exp = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs() + + TOKEN_EXPIRY_SECS; + + // Create claims + let claims = Claims { + sub: "user-123".to_string(), + username: body.username, + role: "admin".to_string(), + exp, + }; + + // Generate token + let token = create_token(&claims, JWT_SECRET) + .map_err(|e| ApiError::internal(format!("Failed to create token: {}", e)))?; + + Ok(Json(LoginResponse { + token, + token_type: "Bearer".to_string(), + expires_in: TOKEN_EXPIRY_SECS, + })) +} + + +// ============================================ +// Protected Handlers (Auth Required) +// ============================================ + +/// Get current user's profile (requires authentication) +async fn get_profile(AuthUser(claims): AuthUser) -> Json { + Json(UserProfile { + user_id: claims.sub, + username: claims.username, + role: claims.role, + }) +} + +/// Admin-only endpoint +async fn admin_only(AuthUser(claims): AuthUser) -> Result, ApiError> { + if claims.role != "admin" { + return Err(ApiError::forbidden("Admin access required")); + } + + Ok(Json(Message { + message: format!("Hello admin {}! You have full access.", claims.username), + })) +} + +/// Protected data endpoint +async fn get_protected_data(AuthUser(claims): AuthUser) -> Json { + Json(Message { + message: format!("Secret data for user: {}", claims.username), + }) +} + +// ============================================ +// Main +// ============================================ + +#[tokio::main] +async fn main() -> Result<(), Box> { + println!("🔐 Authentication API Example"); + println!(); + println!("Public Endpoints:"); + println!(" GET / - Welcome message"); + println!(" GET /health - Health check"); + println!(" POST /auth/login - Login (username: admin, password: secret)"); + println!(); + println!("Protected Endpoints (require JWT token):"); + println!(" GET /protected/profile - Get user profile"); + println!(" GET /protected/admin - Admin only"); + println!(" GET /protected/data - Protected data"); + println!(); + println!("Documentation:"); + println!(" GET /docs - Swagger UI (Basic Auth: docs / docs123)"); + println!(); + println!("Server running at http://127.0.0.1:8080"); + + // Create the app with JWT middleware for protected routes + // Public routes (/health, /auth/login, /) are excluded from JWT validation + // Docs has its own Basic Auth protection + let app = RustApi::new() + .body_limit(1024 * 1024) // 1MB limit + .layer(RequestIdLayer::new()) + .layer(TracingLayer::new()) + // Rate limiting: 100 requests per minute + .layer(RateLimitLayer::new(100, Duration::from_secs(60))) + // JWT middleware - skip public paths (docs has its own auth) + .layer(JwtLayer::::new(JWT_SECRET) + .skip_paths(vec!["/health", "/docs", "/auth/login", "/"])) + .register_schema::() + .register_schema::() + .register_schema::() + .register_schema::() + // Public routes + .mount_route(welcome_route()) + .mount_route(health_route()) + .mount_route(login_route()) + // Protected routes + .route("/protected/profile", get(get_profile)) + .route("/protected/admin", get(admin_only)) + .route("/protected/data", get(get_protected_data)) + // Docs with Basic Auth protection + .docs_with_auth("/docs", "docs", "docs123"); + + app.run("127.0.0.1:8080").await +} diff --git a/examples/crud-api/Cargo.toml b/examples/crud-api/Cargo.toml new file mode 100644 index 0000000..74679ad --- /dev/null +++ b/examples/crud-api/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "crud-api" +version = "0.1.0" +edition = "2021" +publish = false +description = "CRUD API example demonstrating all RustAPI features" + +[dependencies] +rustapi-rs = { path = "../../crates/rustapi-rs", features = ["full"] } +tokio = { version = "1.35", features = ["full"] } +serde = { version = "1.0", features = ["derive"] } +validator = { workspace = true } +utoipa = { workspace = true } diff --git a/examples/crud-api/src/main.rs b/examples/crud-api/src/main.rs new file mode 100644 index 0000000..9c54cce --- /dev/null +++ b/examples/crud-api/src/main.rs @@ -0,0 +1,344 @@ +//! CRUD API Example for RustAPI +//! +//! This example demonstrates a complete CRUD API with: +//! - All HTTP methods (GET, POST, PUT, PATCH, DELETE) +//! - Request validation +//! - Error handling +//! - Middleware (RequestId, Tracing, Body Limit) +//! - OpenAPI documentation with Swagger UI +//! +//! Run with: cargo run -p crud-api +//! Then visit: http://127.0.0.1:8080/docs + +use rustapi_rs::prelude::*; +use std::collections::HashMap; +use std::sync::{Arc, RwLock}; + +// ============================================ +// Data Models +// ============================================ + +/// A task in our todo list +#[derive(Debug, Clone, Serialize, Deserialize, Schema)] +pub struct Task { + pub id: u64, + pub title: String, + pub description: Option, + pub completed: bool, +} + +/// Request body for creating a task +#[derive(Debug, Deserialize, Validate, Schema)] +pub struct CreateTask { + #[validate(length(min = 1, max = 200, message = "Title must be 1-200 characters"))] + pub title: String, + #[validate(length(max = 1000, message = "Description must be at most 1000 characters"))] + pub description: Option, +} + +/// Request body for updating a task +#[derive(Debug, Deserialize, Validate, Schema)] +pub struct UpdateTask { + #[validate(length(min = 1, max = 200, message = "Title must be 1-200 characters"))] + pub title: String, + #[validate(length(max = 1000, message = "Description must be at most 1000 characters"))] + pub description: Option, + pub completed: bool, +} + +/// Request body for partial task update +#[derive(Debug, Deserialize, Schema)] +pub struct PatchTask { + pub title: Option, + pub description: Option, + pub completed: Option, +} + +/// Query parameters for listing tasks +#[derive(Debug, Deserialize, IntoParams)] +pub struct ListParams { + /// Filter by completion status + pub completed: Option, + /// Page number (1-indexed) + #[param(minimum = 1)] + pub page: Option, + /// Items per page + #[param(minimum = 1, maximum = 100)] + pub limit: Option, +} + +/// Paginated response wrapper +#[derive(Debug, Serialize, Schema)] +pub struct PaginatedTasks { + pub tasks: Vec, + pub total: usize, + pub page: u32, + pub limit: u32, +} + +// ============================================ +// In-Memory Database +// ============================================ + +/// Simple in-memory task store +#[derive(Clone)] +pub struct TaskStore { + tasks: Arc>>, + next_id: Arc>, +} + +impl TaskStore { + pub fn new() -> Self { + Self { + tasks: Arc::new(RwLock::new(HashMap::new())), + next_id: Arc::new(RwLock::new(1)), + } + } + + pub fn create(&self, create: CreateTask) -> Task { + let mut next_id = self.next_id.write().unwrap(); + let id = *next_id; + *next_id += 1; + + let task = Task { + id, + title: create.title, + description: create.description, + completed: false, + }; + + self.tasks.write().unwrap().insert(id, task.clone()); + task + } + + pub fn get(&self, id: u64) -> Option { + self.tasks.read().unwrap().get(&id).cloned() + } + + pub fn list(&self, completed: Option) -> Vec { + let tasks = self.tasks.read().unwrap(); + let mut result: Vec = tasks + .values() + .filter(|t| completed.map_or(true, |c| t.completed == c)) + .cloned() + .collect(); + result.sort_by_key(|t| t.id); + result + } + + pub fn update(&self, id: u64, update: UpdateTask) -> Option { + let mut tasks = self.tasks.write().unwrap(); + if let Some(task) = tasks.get_mut(&id) { + task.title = update.title; + task.description = update.description; + task.completed = update.completed; + Some(task.clone()) + } else { + None + } + } + + pub fn patch(&self, id: u64, patch: PatchTask) -> Option { + let mut tasks = self.tasks.write().unwrap(); + if let Some(task) = tasks.get_mut(&id) { + if let Some(title) = patch.title { + task.title = title; + } + if let Some(description) = patch.description { + task.description = Some(description); + } + if let Some(completed) = patch.completed { + task.completed = completed; + } + Some(task.clone()) + } else { + None + } + } + + pub fn delete(&self, id: u64) -> bool { + self.tasks.write().unwrap().remove(&id).is_some() + } +} + +impl Default for TaskStore { + fn default() -> Self { + Self::new() + } +} + + +// ============================================ +// Handlers +// ============================================ + +/// List all tasks with optional filtering and pagination +#[rustapi_rs::get("/tasks")] +#[rustapi_rs::tag("Tasks")] +#[rustapi_rs::summary("List Tasks")] +#[rustapi_rs::description("Returns a paginated list of tasks. Can filter by completion status.")] +async fn list_tasks( + State(store): State, + Query(params): Query, +) -> Json { + let all_tasks = store.list(params.completed); + let total = all_tasks.len(); + + let page = params.page.unwrap_or(1); + let limit = params.limit.unwrap_or(10); + let skip = ((page - 1) * limit) as usize; + + let tasks: Vec = all_tasks + .into_iter() + .skip(skip) + .take(limit as usize) + .collect(); + + Json(PaginatedTasks { + tasks, + total, + page, + limit, + }) +} + +/// Get a single task by ID +#[rustapi_rs::get("/tasks/{id}")] +#[rustapi_rs::tag("Tasks")] +#[rustapi_rs::summary("Get Task")] +#[rustapi_rs::description("Returns a single task by its ID. Returns 404 if not found.")] +async fn get_task( + State(store): State, + Path(id): Path, +) -> Result, ApiError> { + store + .get(id) + .map(Json) + .ok_or_else(|| ApiError::not_found(format!("Task {} not found", id))) +} + +/// Create a new task +#[rustapi_rs::post("/tasks")] +#[rustapi_rs::tag("Tasks")] +#[rustapi_rs::summary("Create Task")] +#[rustapi_rs::description("Creates a new task. Validates title (1-200 chars) and description (max 1000 chars).")] +async fn create_task( + State(store): State, + ValidatedJson(body): ValidatedJson, +) -> Created { + let task = store.create(body); + Created(task) +} + +/// Update a task completely (PUT) +#[rustapi_rs::put("/tasks/{id}")] +#[rustapi_rs::tag("Tasks")] +#[rustapi_rs::summary("Update Task")] +#[rustapi_rs::description("Replaces a task entirely. All fields are required.")] +async fn update_task( + State(store): State, + Path(id): Path, + ValidatedJson(body): ValidatedJson, +) -> Result, ApiError> { + store + .update(id, body) + .map(Json) + .ok_or_else(|| ApiError::not_found(format!("Task {} not found", id))) +} + +/// Partially update a task (PATCH) +#[rustapi_rs::patch("/tasks/{id}")] +#[rustapi_rs::tag("Tasks")] +#[rustapi_rs::summary("Patch Task")] +#[rustapi_rs::description("Partially updates a task. Only provided fields are updated.")] +async fn patch_task( + State(store): State, + Path(id): Path, + Json(body): Json, +) -> Result, ApiError> { + store + .patch(id, body) + .map(Json) + .ok_or_else(|| ApiError::not_found(format!("Task {} not found", id))) +} + +/// Delete a task +#[rustapi_rs::delete("/tasks/{id}")] +#[rustapi_rs::tag("Tasks")] +#[rustapi_rs::summary("Delete Task")] +#[rustapi_rs::description("Deletes a task by ID. Returns 204 on success, 404 if not found.")] +async fn delete_task( + State(store): State, + Path(id): Path, +) -> Result { + if store.delete(id) { + Ok(NoContent) + } else { + Err(ApiError::not_found(format!("Task {} not found", id))) + } +} + +/// Health check endpoint +#[rustapi_rs::get("/health")] +#[rustapi_rs::tag("System")] +#[rustapi_rs::summary("Health Check")] +async fn health() -> &'static str { + "OK" +} + +// ============================================ +// Main +// ============================================ + +#[tokio::main] +async fn main() -> Result<(), Box> { + // Initialize task store with some sample data + let store = TaskStore::new(); + store.create(CreateTask { + title: "Learn RustAPI".to_string(), + description: Some("Build a web API with RustAPI framework".to_string()), + }); + store.create(CreateTask { + title: "Write tests".to_string(), + description: Some("Add unit and integration tests".to_string()), + }); + store.create(CreateTask { + title: "Deploy to production".to_string(), + description: None, + }); + + println!("🚀 CRUD API Example"); + println!(); + println!("Endpoints:"); + println!(" GET /tasks - List all tasks"); + println!(" GET /tasks/:id - Get a task"); + println!(" POST /tasks - Create a task"); + println!(" PUT /tasks/:id - Update a task"); + println!(" PATCH /tasks/:id - Partially update a task"); + println!(" DELETE /tasks/:id - Delete a task"); + println!(" GET /health - Health check"); + println!(" GET /docs - Swagger UI"); + println!(); + println!("Server running at http://127.0.0.1:8080"); + + RustApi::new() + .state(store) + .body_limit(1024 * 1024) // 1MB limit + .layer(RequestIdLayer::new()) + .layer(TracingLayer::new()) + .register_schema::() + .register_schema::() + .register_schema::() + .register_schema::() + .register_schema::() + .mount_route(list_tasks_route()) + .mount_route(get_task_route()) + .mount_route(create_task_route()) + .mount_route(update_task_route()) + .mount_route(patch_task_route()) + .mount_route(delete_task_route()) + .mount_route(health_route()) + .docs("/docs") + .run("127.0.0.1:8080") + .await +}