diff --git a/.gitignore b/.gitignore index e1f17a1..1d69e7c 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,5 @@ /.kiro assets/myadam.jpg +.github/copilot-instructions.md +docs/UPDATE_SUMMARIES.md diff --git a/Cargo.lock b/Cargo.lock index 03ea471..44b462e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -53,12 +53,77 @@ version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" +[[package]] +name = "anstream" +version = "0.6.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43d5b281e737544384e969a5ccad3f1cdd24b48086a0fc1b2a5262a26b8f4f4a" +dependencies = [ + "anstyle", + "anstyle-parse", + "anstyle-query", + "anstyle-wincon", + "colorchoice", + "is_terminal_polyfill", + "utf8parse", +] + [[package]] name = "anstyle" version = "1.0.13" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5192cca8006f1fd4f7237516f40fa183bb07f8fbdfedaa0036de5ea9b0b45e78" +[[package]] +name = "anstyle-parse" +version = "0.2.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4e7644824f0aa2c7b9384579234ef10eb7efb6a0deb83f9630a49594dd9c15c2" +dependencies = [ + "utf8parse", +] + +[[package]] +name = "anstyle-query" +version = "1.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "40c48f72fd53cd289104fc64099abca73db4166ad86ea0b4341abe65af83dadc" +dependencies = [ + "windows-sys 0.61.2", +] + +[[package]] +name = "anstyle-wincon" +version = "3.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "291e6a250ff86cd4a820112fb8898808a366d8f9f58ce16d1f538353ad55747d" +dependencies = [ + "anstyle", + "once_cell_polyfill", + "windows-sys 0.61.2", +] + +[[package]] +name = "anyhow" +version = "1.0.100" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a23eb6b1614318a8071c9b2521f36b424b2c83db5eb3a0fead4a6c0809af6e61" + +[[package]] +name = "assert_cmd" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bcbb6924530aa9e0432442af08bbcafdad182db80d2e560da42a6d442535bf85" +dependencies = [ + "anstyle", + "bstr", + "libc", + "predicates", + "predicates-core", + "predicates-tree", + "wait-timeout", +] + [[package]] name = "async-compression" version = "0.4.36" @@ -155,6 +220,17 @@ dependencies = [ "generic-array", ] +[[package]] +name = "brotli" +version = "6.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "74f7971dbd9326d58187408ab83117d8ac1bb9c17b085fdacd1cf2f598719b6b" +dependencies = [ + "alloc-no-stdlib", + "alloc-stdlib", + "brotli-decompressor 4.0.3", +] + [[package]] name = "brotli" version = "8.0.2" @@ -163,7 +239,17 @@ checksum = "4bd8b9603c7aa97359dbd97ecf258968c95f3adddd6db2f7e7a5bef101c84560" dependencies = [ "alloc-no-stdlib", "alloc-stdlib", - "brotli-decompressor", + "brotli-decompressor 5.0.0", +] + +[[package]] +name = "brotli-decompressor" +version = "4.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a334ef7c9e23abf0ce748e8cd309037da93e606ad52eb372e4ce327a0dcfbdfd" +dependencies = [ + "alloc-no-stdlib", + "alloc-stdlib", ] [[package]] @@ -176,6 +262,17 @@ dependencies = [ "alloc-stdlib", ] +[[package]] +name = "bstr" +version = "1.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "63044e1ae8e69f3b5a92c736ca6269b8d12fa7efe39bf34ddb06d102cf0e2cab" +dependencies = [ + "memchr", + "regex-automata", + "serde", +] + [[package]] name = "bumpalo" version = "3.19.1" @@ -194,6 +291,28 @@ version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b35204fbdc0b3f4446b89fc1ac2cf84a8a68971995d0bf2e925ec7cd960f9cb3" +[[package]] +name = "cargo-rustapi" +version = "0.1.4" +dependencies = [ + "anyhow", + "assert_cmd", + "clap", + "console", + "dialoguer", + "indicatif", + "serde", + "serde_json", + "tempfile", + "thiserror 1.0.69", + "tokio", + "toml", + "toml_edit", + "tracing", + "tracing-subscriber", + "walkdir", +] + [[package]] name = "cast" version = "0.3.0" @@ -232,6 +351,28 @@ dependencies = [ "windows-link", ] +[[package]] +name = "chrono-tz" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93698b29de5e97ad0ae26447b344c482a7284c737d9ddc5f9e52b74a336671bb" +dependencies = [ + "chrono", + "chrono-tz-build", + "phf", +] + +[[package]] +name = "chrono-tz-build" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c088aee841df9c3041febbb73934cfc39708749bf96dc827e3359cd39ef11b1" +dependencies = [ + "parse-zoneinfo", + "phf", + "phf_codegen", +] + [[package]] name = "ciborium" version = "0.2.2" @@ -266,6 +407,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c9e340e012a1bf4935f5282ed1436d1489548e8f72308207ea5df0e23d2d03f8" dependencies = [ "clap_builder", + "clap_derive", ] [[package]] @@ -274,8 +416,22 @@ version = "4.5.53" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d76b5d13eaa18c901fd2f7fca939fefe3a0727a953561fefdf3b2922b8569d00" dependencies = [ + "anstream", "anstyle", "clap_lex", + "strsim", +] + +[[package]] +name = "clap_derive" +version = "4.5.49" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2a0b5487afeab2deb2ff4e03a807ad1a03ac532ff5a2cee5d86884440c7f7671" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn 2.0.111", ] [[package]] @@ -284,13 +440,19 @@ version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a1d728cc89cf3aee9ff92b05e62b19ee65a02b5702cff7d5a377e32c6ae29d8d" +[[package]] +name = "colorchoice" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b05b61dc5112cbb17e4b6cd61790d9845d13888356391624cbe7e41efeac1e75" + [[package]] name = "compression-codecs" version = "0.4.35" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b0f7ac3e5b97fdce45e8922fb05cae2c37f7bbd63d30dd94821dacfd8f3f2bf2" dependencies = [ - "brotli", + "brotli 8.0.2", "compression-core", "flate2", "memchr", @@ -313,6 +475,19 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "console" +version = "0.15.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "054ccb5b10f9f2cbf51eb355ca1d05c2d279ce1804688d0db74b4733a5aeafd8" +dependencies = [ + "encode_unicode", + "libc", + "once_cell", + "unicode-width", + "windows-sys 0.59.0", +] + [[package]] name = "const-oid" version = "0.9.6" @@ -524,6 +699,12 @@ dependencies = [ "parking_lot_core", ] +[[package]] +name = "data-encoding" +version = "2.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2a2330da5de22e8a3cb63252ce2abb30116bf5265e89c0e01bc17015ce30a476" + [[package]] name = "der" version = "0.7.10" @@ -544,6 +725,31 @@ dependencies = [ "powerfmt", ] +[[package]] +name = "deunicode" +version = "1.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "abd57806937c9cc163efc8ea3910e00a62e2aeb0b8119f1793a978088f8f6b04" + +[[package]] +name = "dialoguer" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "658bce805d770f407bc62102fca7c2c64ceef2fbcb2b8bd19d2765ce093980de" +dependencies = [ + "console", + "shell-words", + "tempfile", + "thiserror 1.0.69", + "zeroize", +] + +[[package]] +name = "difflib" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6184e33543162437515c2e2b48714794e37845ec9851711914eec9d308f6ebe8" + [[package]] name = "digest" version = "0.10.7" @@ -582,6 +788,12 @@ dependencies = [ "serde", ] +[[package]] +name = "encode_unicode" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34aa73646ffb006b8f5147f3dc182bd4bcb190227ce861fc4a4844bf8e3cb2c0" + [[package]] name = "envy" version = "0.4.2" @@ -802,6 +1014,30 @@ dependencies = [ "wasip2", ] +[[package]] +name = "globset" +version = "0.4.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52dfc19153a48bde0cbd630453615c8151bce3a5adfac7a0aebfbf0a1e1f57e3" +dependencies = [ + "aho-corasick", + "bstr", + "log", + "regex-automata", + "regex-syntax", +] + +[[package]] +name = "globwalk" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0bf760ebf69878d9fd8f110c89703d90ce35095324d1f1edcb595c63945ee757" +dependencies = [ + "bitflags", + "ignore", + "walkdir", +] + [[package]] name = "h2" version = "0.4.12" @@ -987,6 +1223,15 @@ version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" +[[package]] +name = "humansize" +version = "2.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6cb51c9a029ddc91b07a787f1d86b53ccfa49b0e86688c946ebe8d3555685dd7" +dependencies = [ + "libm", +] + [[package]] name = "hyper" version = "1.8.1" @@ -1179,6 +1424,22 @@ dependencies = [ "icu_properties", ] +[[package]] +name = "ignore" +version = "0.4.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3d782a365a015e0f5c04902246139249abf769125006fbe7649e2ee88169b4a" +dependencies = [ + "crossbeam-deque", + "globset", + "log", + "memchr", + "regex-automata", + "same-file", + "walkdir", + "winapi-util", +] + [[package]] name = "indexmap" version = "1.9.3" @@ -1201,6 +1462,19 @@ dependencies = [ "serde_core", ] +[[package]] +name = "indicatif" +version = "0.17.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "183b3088984b400f4cfac3620d5e076c84da5364016b4f49473de574b2586235" +dependencies = [ + "console", + "number_prefix", + "portable-atomic", + "unicode-width", + "web-time", +] + [[package]] name = "inventory" version = "0.3.21" @@ -1237,6 +1511,12 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "is_terminal_polyfill" +version = "1.70.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a6cb138bb79a146c1bd460005623e142ef0181e3d0219cb493e02f7d08a35695" + [[package]] name = "itertools" version = "0.10.5" @@ -1526,12 +1806,24 @@ dependencies = [ "libm", ] +[[package]] +name = "number_prefix" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" + [[package]] name = "once_cell" version = "1.21.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" +[[package]] +name = "once_cell_polyfill" +version = "1.70.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe" + [[package]] name = "oorandom" version = "11.1.5" @@ -1567,6 +1859,15 @@ dependencies = [ "windows-link", ] +[[package]] +name = "parse-zoneinfo" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f2a05b18d44e2957b88f96ba460715e295bc1d7510468a2f3d3b44535d26c24" +dependencies = [ + "regex", +] + [[package]] name = "pem" version = "3.0.6" @@ -1592,6 +1893,87 @@ version = "2.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9b4f627cb1b25917193a259e49bdad08f671f8d9708acfd5fe0a8c1455d87220" +[[package]] +name = "pest" +version = "2.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2c9eb05c21a464ea704b53158d358a31e6425db2f63a1a7312268b05fe2b75f7" +dependencies = [ + "memchr", + "ucd-trie", +] + +[[package]] +name = "pest_derive" +version = "2.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68f9dbced329c441fa79d80472764b1a2c7e57123553b8519b36663a2fb234ed" +dependencies = [ + "pest", + "pest_generator", +] + +[[package]] +name = "pest_generator" +version = "2.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3bb96d5051a78f44f43c8f712d8e810adb0ebf923fc9ed2655a7f66f63ba8ee5" +dependencies = [ + "pest", + "pest_meta", + "proc-macro2", + "quote", + "syn 2.0.111", +] + +[[package]] +name = "pest_meta" +version = "2.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "602113b5b5e8621770cfd490cfd90b9f84ab29bd2b0e49ad83eb6d186cef2365" +dependencies = [ + "pest", + "sha2", +] + +[[package]] +name = "phf" +version = "0.11.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fd6780a80ae0c52cc120a26a1a42c1ae51b247a253e4e06113d23d2c2edd078" +dependencies = [ + "phf_shared", +] + +[[package]] +name = "phf_codegen" +version = "0.11.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aef8048c789fa5e851558d709946d6d79a8ff88c0440c587967f8e94bfb1216a" +dependencies = [ + "phf_generator", + "phf_shared", +] + +[[package]] +name = "phf_generator" +version = "0.11.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c80231409c20246a13fddb31776fb942c38553c51e871f8cbd687a4cfb5843d" +dependencies = [ + "phf_shared", + "rand 0.8.5", +] + +[[package]] +name = "phf_shared" +version = "0.11.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67eabc2ef2a60eb7faa00097bd1ffdb5bd28e62bf39990626a582201b7a754e5" +dependencies = [ + "siphasher", +] + [[package]] name = "pin-project" version = "1.1.10" @@ -1679,6 +2061,12 @@ dependencies = [ "plotters-backend", ] +[[package]] +name = "portable-atomic" +version = "1.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f89776e4d69bb58bc6993e99ffa1d11f228b839984854c7daeb5d37f87cbe950" + [[package]] name = "potential_utf" version = "0.1.4" @@ -1703,6 +2091,33 @@ dependencies = [ "zerocopy", ] +[[package]] +name = "predicates" +version = "3.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5d19ee57562043d37e82899fade9a22ebab7be9cef5026b07fda9cdd4293573" +dependencies = [ + "anstyle", + "difflib", + "predicates-core", +] + +[[package]] +name = "predicates-core" +version = "1.0.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "727e462b119fe9c93fd0eb1429a5f7647394014cf3c04ab2c0350eeb09095ffa" + +[[package]] +name = "predicates-tree" +version = "1.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72dd2d6d381dfb73a193c7fca536518d7caee39fc8503f74e7dc0be0531b425c" +dependencies = [ + "predicates-core", + "termtree", +] + [[package]] name = "proc-macro-error" version = "1.0.4" @@ -1987,8 +2402,10 @@ name = "rustapi-core" version = "0.1.4" dependencies = [ "base64 0.22.1", + "brotli 6.0.0", "bytes", "cookie", + "flate2", "futures-util", "http", "http-body-util", @@ -2035,9 +2452,11 @@ dependencies = [ "serde", "serde_json", "sqlx", + "tempfile", "thiserror 1.0.69", "tokio", "tracing", + "urlencoding", ] [[package]] @@ -2070,6 +2489,8 @@ dependencies = [ "rustapi-macros", "rustapi-openapi", "rustapi-toon", + "rustapi-view", + "rustapi-ws", "serde", "serde_json", "tokio", @@ -2108,6 +2529,47 @@ dependencies = [ "validator", ] +[[package]] +name = "rustapi-view" +version = "0.1.4" +dependencies = [ + "bytes", + "http", + "http-body-util", + "rustapi-core", + "rustapi-openapi", + "serde", + "serde_json", + "tera", + "thiserror 1.0.69", + "tokio", + "tracing", +] + +[[package]] +name = "rustapi-ws" +version = "0.1.4" +dependencies = [ + "base64 0.22.1", + "bytes", + "futures-util", + "http", + "http-body-util", + "hyper", + "hyper-util", + "pin-project-lite", + "rustapi-core", + "rustapi-openapi", + "serde", + "serde_json", + "sha1", + "thiserror 1.0.69", + "tokio", + "tokio-tungstenite", + "tracing", + "tungstenite", +] + [[package]] name = "rustix" version = "1.1.3" @@ -2204,6 +2666,15 @@ dependencies = [ "zmij", ] +[[package]] +name = "serde_spanned" +version = "0.6.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf41e0cfaf7226dca15e8197172c295a782857fcb97fad1808a166870dee75a3" +dependencies = [ + "serde", +] + [[package]] name = "serde_urlencoded" version = "0.7.1" @@ -2247,6 +2718,12 @@ dependencies = [ "lazy_static", ] +[[package]] +name = "shell-words" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc6fe69c597f9c37bfeeeeeb33da3530379845f10be461a66d16d03eca2ded77" + [[package]] name = "shlex" version = "1.3.0" @@ -2291,12 +2768,28 @@ dependencies = [ "time", ] +[[package]] +name = "siphasher" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56199f7ddabf13fe5074ce809e7d3f42b42ae711800501b5b16ea82ad029c39d" + [[package]] name = "slab" version = "0.4.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7a2ae44ef20feb57a68b23d846850f861394c2e02dc425a50098ae8c90267589" +[[package]] +name = "slug" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "882a80f72ee45de3cc9a5afeb2da0331d58df69e4e7d8eeb5d3c7784ae67e724" +dependencies = [ + "deunicode", + "wasm-bindgen", +] + [[package]] name = "smallvec" version = "1.15.1" @@ -2630,6 +3123,46 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "templates-example" +version = "0.1.0" +dependencies = [ + "rustapi-rs", + "serde", + "tokio", + "tracing", + "tracing-subscriber", + "utoipa", +] + +[[package]] +name = "tera" +version = "1.20.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e8004bca281f2d32df3bacd59bc67b312cb4c70cea46cbd79dbe8ac5ed206722" +dependencies = [ + "chrono", + "chrono-tz", + "globwalk", + "humansize", + "lazy_static", + "percent-encoding", + "pest", + "pest_derive", + "rand 0.8.5", + "regex", + "serde", + "serde_json", + "slug", + "unicode-segmentation", +] + +[[package]] +name = "termtree" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f50febec83f5ee1df3015341d8bd429f2d1cc62bcba7ea2076759d315084683" + [[package]] name = "thiserror" version = "1.0.69" @@ -2784,6 +3317,18 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-tungstenite" +version = "0.24.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "edc5f74e248dc973e0dbb7b74c7e0d6fcc301c694ff50049504004ef4d0cdcd9" +dependencies = [ + "futures-util", + "log", + "tokio", + "tungstenite", +] + [[package]] name = "tokio-util" version = "0.7.17" @@ -2797,6 +3342,47 @@ dependencies = [ "tokio", ] +[[package]] +name = "toml" +version = "0.8.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc1beb996b9d83529a9e75c17a1686767d148d70663143c7854d8b4a09ced362" +dependencies = [ + "serde", + "serde_spanned", + "toml_datetime", + "toml_edit", +] + +[[package]] +name = "toml_datetime" +version = "0.6.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22cddaf88f4fbc13c51aebbf5f8eceb5c7c5a9da2ac40a13519eb5b0a0e8f11c" +dependencies = [ + "serde", +] + +[[package]] +name = "toml_edit" +version = "0.22.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41fe8c660ae4257887cf66394862d21dbca4a6ddd26f04a3560410406a2f819a" +dependencies = [ + "indexmap 2.12.1", + "serde", + "serde_spanned", + "toml_datetime", + "toml_write", + "winnow", +] + +[[package]] +name = "toml_write" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d99f8c9a7727884afe522e9bd5edbfc91a3312b36a77b5fb8926e4c31a41801" + [[package]] name = "toon-api" version = "0.1.0" @@ -2963,12 +3549,36 @@ version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" +[[package]] +name = "tungstenite" +version = "0.24.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "18e5b8366ee7a95b16d32197d0b2604b43a0be89dc5fac9f8e96ccafbaedda8a" +dependencies = [ + "byteorder", + "bytes", + "data-encoding", + "http", + "httparse", + "log", + "rand 0.8.5", + "sha1", + "thiserror 1.0.69", + "utf-8", +] + [[package]] name = "typenum" version = "1.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "562d481066bde0658276a35467c4af00bdc6ee726305698a55b86e61d7ad82bb" +[[package]] +name = "ucd-trie" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2896d95c02a80c6d6a5d6e953d479f5ddf2dfdb6a244441010e373ac0fb88971" + [[package]] name = "unarray" version = "0.1.4" @@ -3008,6 +3618,18 @@ version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7df058c713841ad818f1dc5d3fd88063241cc61f49f5fbea4b951e8cf5a8d71d" +[[package]] +name = "unicode-segmentation" +version = "1.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493" + +[[package]] +name = "unicode-width" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4ac048d71ede7ee76d585517add45da530660ef4390e49b098733c6e897f254" + [[package]] name = "untrusted" version = "0.9.0" @@ -3026,12 +3648,30 @@ dependencies = [ "serde", ] +[[package]] +name = "urlencoding" +version = "2.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da" + +[[package]] +name = "utf-8" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" + [[package]] name = "utf8_iter" version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" +[[package]] +name = "utf8parse" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" + [[package]] name = "utoipa" version = "4.2.3" @@ -3220,6 +3860,29 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "web-time" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a6580f308b1fad9207618087a65c04e7a10bc77e02c8e84e9b00dd4b12fa0bb" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + +[[package]] +name = "websocket-example" +version = "0.1.0" +dependencies = [ + "futures-util", + "rustapi-rs", + "serde", + "serde_json", + "tokio", + "tracing", + "tracing-subscriber", +] + [[package]] name = "whoami" version = "1.6.1" @@ -3327,6 +3990,15 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "windows-sys" +version = "0.59.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b" +dependencies = [ + "windows-targets 0.52.6", +] + [[package]] name = "windows-sys" version = "0.60.2" @@ -3531,6 +4203,15 @@ version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d6bbff5f0aada427a1e5a6da5f1f98158182f26556f345ac9e04d36d0ebed650" +[[package]] +name = "winnow" +version = "0.7.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a5364e9d77fcdeeaa6062ced926ee3381faa2ee02d3eb83a5c27a8825540829" +dependencies = [ + "memchr", +] + [[package]] name = "wit-bindgen" version = "0.46.0" diff --git a/Cargo.toml b/Cargo.toml index 2d2d97c..34d6bf4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,6 +8,9 @@ members = [ "crates/rustapi-openapi", "crates/rustapi-extras", "crates/rustapi-toon", + "crates/rustapi-ws", + "crates/rustapi-view", + "crates/cargo-rustapi", "examples/hello-world", "examples/sqlx-crud", "examples/crud-api", @@ -15,6 +18,8 @@ members = [ "examples/proof-of-concept", "examples/toon-api", "examples/mcp-server", + "examples/websocket", + "examples/templates", "benches/toon_bench", ] @@ -83,6 +88,19 @@ toon-format = { version = "0.4", default-features = false } # Benchmarking criterion = { version = "0.5", features = ["html_reports"] } +# WebSocket +tokio-tungstenite = "0.24" +tungstenite = "0.24" + +# Template engine +tera = "1.19" + +# CLI +clap = { version = "4.5", features = ["derive", "color"] } +dialoguer = "0.11" +indicatif = "0.17" +console = "0.15" + # Internal crates rustapi-core = { path = "crates/rustapi-core", version = "0.1.4", default-features = false } rustapi-macros = { path = "crates/rustapi-macros", version = "0.1.4" } @@ -90,3 +108,6 @@ rustapi-validate = { path = "crates/rustapi-validate", version = "0.1.4" } rustapi-openapi = { path = "crates/rustapi-openapi", version = "0.1.4", default-features = false } rustapi-extras = { path = "crates/rustapi-extras", version = "0.1.4" } rustapi-toon = { path = "crates/rustapi-toon", version = "0.1.4" } +rustapi-ws = { path = "crates/rustapi-ws", version = "0.1.4" } +rustapi-view = { path = "crates/rustapi-view", version = "0.1.4" } + diff --git a/README.md b/README.md index 32839c5..e251159 100644 --- a/README.md +++ b/README.md @@ -112,24 +112,29 @@ async fn main() -> Result<(), Box> { | Feature | Description | |---------|-------------| -| **Type-Safe Extractors** | `Json`, `Query`, `Path` — compile-time guarantees | +| **Type-Safe Extractors** | `Json`, `Query`, `Path`, `WebSocket` — compile-time guarantees | | **Zero-Config Routing** | Macro-decorated routes auto-register at startup (`RustApi::auto()`) | | **Auto OpenAPI** | Your code = your docs. `/docs` endpoint out of the box | | **Validation** | `#[validate(email)]` → automatic 422 responses | | **JWT Auth** | One-line auth with `AuthUser` extractor | | **CORS & Rate Limit** | Production-ready middleware | | **TOON Format** | **50-58% token savings** for LLMs | +| **WebSocket** | Real-time bidirectional communication with broadcast support | +| **Template Engine** | Server-side HTML rendering with Tera templates | +| **CLI Tool** | `cargo-rustapi` for project scaffolding | ### Optional Features ```toml -rustapi-rs = { version = "0.1.4", features = ["jwt", "cors", "toon"] } +rustapi-rs = { version = "0.1.4", features = ["jwt", "cors", "toon", "ws", "view"] } ``` - `jwt` — JWT authentication - `cors` — CORS middleware - `rate-limit` — IP-based rate limiting - `toon` — LLM-optimized responses +- `ws` — WebSocket support +- `view` — Template engine (Tera) - `full` — Everything included --- @@ -145,10 +150,73 @@ cargo run -p auth-api cargo run -p sqlx-crud cargo run -p toon-api cargo run -p proof-of-concept +cargo run -p websocket # WebSocket example +cargo run -p templates # Template engine example ``` --- +## 🔌 Real-time: WebSocket Support + +RustAPI provides first-class WebSocket support for real-time applications. + +```rust +use rustapi_rs::ws::{WebSocket, Message, Broadcast}; + +#[rustapi_rs::get("/ws")] +async fn websocket(ws: WebSocket) -> WebSocketUpgrade { + ws.on_upgrade(handle_connection) +} + +async fn handle_connection(mut stream: WebSocketStream) { + while let Some(msg) = stream.recv().await { + match msg { + Message::Text(text) => { + stream.send(Message::Text(format!("Echo: {}", text))).await.ok(); + } + Message::Close(_) => break, + _ => {} + } + } +} +``` + +**Features:** +- Full WebSocket protocol support (text, binary, ping/pong) +- `Broadcast` channel for pub/sub patterns +- Seamless integration with RustAPI routing + +--- + +## 🎨 Template Engine + +Server-side HTML rendering with Tera templates. + +```rust +use rustapi_rs::view::{Templates, View, ContextBuilder}; + +#[rustapi_rs::get("/")] +async fn home(templates: Templates) -> View<()> { + View::new(&templates, "index.html", ()) +} + +#[rustapi_rs::get("/users/{id}")] +async fn user_page(templates: Templates, Path(id): Path) -> View { + let user = get_user(id); + View::with_context(&templates, "user.html", user, |ctx| { + ctx.insert("title", &format!("User: {}", user.name)); + }) +} +``` + +**Features:** +- Tera template engine (Jinja2-like syntax) +- Type-safe context with `ContextBuilder` +- Template inheritance support +- Auto-escape HTML by default + +--- + ## 🤖 LLM-Optimized: TOON Format RustAPI is built for **AI-powered APIs**. @@ -179,6 +247,34 @@ async fn users(accept: AcceptHeader) -> LlmResponse { --- +## 🛠️ CLI Tool: cargo-rustapi + +Scaffold new RustAPI projects with ease. + +```bash +# Install the CLI +cargo install cargo-rustapi + +# Create a new project +cargo rustapi new my-api + +# Interactive mode +cargo rustapi new my-api --interactive +``` + +**Available Templates:** +- `minimal` — Basic RustAPI setup +- `api` — REST API with CRUD operations +- `web` — Full web app with templates and WebSocket +- `full` — Everything included + +**Commands:** +- `cargo rustapi new ` — Create new project +- `cargo rustapi generate ` — Generate handlers, models, middleware +- `cargo rustapi docs` — Generate API documentation + +--- + ## Architecture RustAPI follows a **Facade Architecture** — a stable public API that shields you from internal changes. @@ -214,6 +310,8 @@ graph TB Validate["rustapi-validate
Request Validation"] Toon["rustapi-toon
LLM Optimization"] Extras["rustapi-extras
JWT/CORS/RateLimit"] + WsCrate["rustapi-ws
WebSocket Support"] + ViewCrate["rustapi-view
Template Engine"] end subgraph Foundation["🏗️ Foundation Layer"] @@ -292,6 +390,8 @@ graph BT Validate[rustapi-validate] Toon[rustapi-toon] Extras[rustapi-extras] + WS[rustapi-ws] + View[rustapi-view] end subgraph External["External Dependencies"] @@ -300,6 +400,8 @@ graph BT Serde[serde] Utoipa[utoipa] Validator[validator] + Tungstenite[tungstenite] + Tera[tera] end App --> RS @@ -309,6 +411,8 @@ graph BT RS --> Validate RS -.->|optional| Toon RS -.->|optional| Extras + RS -.->|optional| WS + RS -.->|optional| View Core --> Tokio Core --> Hyper @@ -316,6 +420,8 @@ graph BT OpenAPI --> Utoipa Validate --> Validator Toon --> Serde + WS --> Tungstenite + View --> Tera style RS fill:#e1f5fe style App fill:#c8e6c9 @@ -342,6 +448,8 @@ graph BT | `rustapi-validate` | Request body/query validation via `#[validate]` | | `rustapi-toon` | TOON format serializer, content negotiation, LLM headers | | `rustapi-extras` | JWT auth, CORS, rate limiting middleware | +| `rustapi-ws` | WebSocket support with broadcast channels | +| `rustapi-view` | Template engine (Tera) for server-side rendering | --- @@ -351,6 +459,9 @@ graph BT - [x] OpenAPI & Validation - [x] JWT, CORS, Rate Limiting - [x] TOON format & LLM optimization +- [x] WebSocket support +- [x] Template engine (Tera) +- [x] CLI tool (cargo-rustapi) - [ ] *Coming soon...* --- diff --git a/crates/cargo-rustapi/Cargo.toml b/crates/cargo-rustapi/Cargo.toml new file mode 100644 index 0000000..54b3b22 --- /dev/null +++ b/crates/cargo-rustapi/Cargo.toml @@ -0,0 +1,45 @@ +[package] +name = "cargo-rustapi" +description = "CLI tool for RustAPI - Project scaffolding and development utilities" +version.workspace = true +edition.workspace = true +authors.workspace = true +license.workspace = true +repository.workspace = true +keywords = ["web", "framework", "api", "cli", "scaffold"] +categories = ["command-line-utilities", "development-tools"] +rust-version.workspace = true +readme = "README.md" + +[[bin]] +name = "cargo-rustapi" +path = "src/main.rs" + +[dependencies] +# CLI +clap = { workspace = true } +dialoguer = { workspace = true } +indicatif = { workspace = true } +console = { workspace = true } + +# File system +walkdir = "2.5" +toml_edit = "0.22" + +# Async +tokio = { workspace = true, features = ["process", "fs"] } + +# Serialization +serde = { workspace = true } +serde_json = { workspace = true } +toml = "0.8" + +# Utilities +thiserror = { workspace = true } +tracing = { workspace = true } +tracing-subscriber = { workspace = true } +anyhow = "1.0" + +[dev-dependencies] +tempfile = "3.10" +assert_cmd = "2.0" diff --git a/crates/cargo-rustapi/README.md b/crates/cargo-rustapi/README.md new file mode 100644 index 0000000..e32b428 --- /dev/null +++ b/crates/cargo-rustapi/README.md @@ -0,0 +1,111 @@ +# cargo-rustapi + +CLI tool for the RustAPI framework - Project scaffolding and development utilities. + +## Installation + +```bash +cargo install cargo-rustapi +``` + +## Usage + +### Create a New Project + +```bash +# Interactive mode +cargo rustapi new my-project + +# With template +cargo rustapi new my-project --template api + +# With features +cargo rustapi new my-project --features jwt,cors +``` + +### Available Templates + +- `minimal` - Bare minimum RustAPI app (default) +- `api` - REST API with CRUD example +- `web` - Web app with templates +- `full` - Full-featured with JWT, CORS, database + +### Run Development Server + +```bash +# Run with auto-reload +cargo rustapi run + +# Run on specific port +cargo rustapi run --port 8080 + +# Run with specific features +cargo rustapi run --features jwt +``` + +### Generate Code + +```bash +# Generate a new handler +cargo rustapi generate handler users + +# Generate a model +cargo rustapi generate model User + +# Generate CRUD endpoints +cargo rustapi generate crud users +``` + +## Commands + +| Command | Description | +|---------|-------------| +| `new ` | Create a new RustAPI project | +| `run` | Run development server with auto-reload | +| `generate ` | Generate code from templates | +| `docs` | Open API documentation | + +## Project Templates + +### Minimal Template +``` +my-project/ +├── Cargo.toml +├── src/ +│ └── main.rs +└── .gitignore +``` + +### API Template +``` +my-project/ +├── Cargo.toml +├── src/ +│ ├── main.rs +│ ├── handlers/ +│ │ └── mod.rs +│ ├── models/ +│ │ └── mod.rs +│ └── error.rs +├── .env.example +└── .gitignore +``` + +### Web Template +``` +my-project/ +├── Cargo.toml +├── src/ +│ ├── main.rs +│ └── handlers/ +├── templates/ +│ ├── base.html +│ └── index.html +├── static/ +│ └── style.css +└── .gitignore +``` + +## License + +MIT OR Apache-2.0 diff --git a/crates/cargo-rustapi/src/cli.rs b/crates/cargo-rustapi/src/cli.rs new file mode 100644 index 0000000..9085607 --- /dev/null +++ b/crates/cargo-rustapi/src/cli.rs @@ -0,0 +1,46 @@ +//! CLI argument parsing + +use crate::commands::{self, GenerateArgs, NewArgs, RunArgs}; +use clap::{Parser, Subcommand}; + +/// RustAPI CLI - Project scaffolding and development utilities +#[derive(Parser, Debug)] +#[command(name = "cargo-rustapi")] +#[command(bin_name = "cargo rustapi")] +#[command(author, version, about, long_about = None)] +pub struct Cli { + #[command(subcommand)] + command: Commands, +} + +#[derive(Subcommand, Debug)] +enum Commands { + /// Create a new RustAPI project + New(NewArgs), + + /// Run the development server + Run(RunArgs), + + /// Generate code from templates + #[command(subcommand)] + Generate(GenerateArgs), + + /// Open API documentation in browser + Docs { + /// Port to check for running server + #[arg(short, long, default_value = "8080")] + port: u16, + }, +} + +impl Cli { + /// Execute the CLI command + pub async fn execute(self) -> anyhow::Result<()> { + match self.command { + Commands::New(args) => commands::new_project(args).await, + Commands::Run(args) => commands::run_dev(args).await, + Commands::Generate(args) => commands::generate(args).await, + Commands::Docs { port } => commands::open_docs(port).await, + } + } +} diff --git a/crates/cargo-rustapi/src/commands/docs.rs b/crates/cargo-rustapi/src/commands/docs.rs new file mode 100644 index 0000000..723ad98 --- /dev/null +++ b/crates/cargo-rustapi/src/commands/docs.rs @@ -0,0 +1,38 @@ +//! Docs command - open API documentation + +use anyhow::Result; +use console::style; + +/// Open API documentation in browser +pub async fn open_docs(port: u16) -> Result<()> { + let url = format!("http://localhost:{}/docs", port); + + println!("Opening {} in browser...", style(&url).cyan()); + + // Try to open in browser + #[cfg(target_os = "windows")] + { + tokio::process::Command::new("cmd") + .args(["/C", "start", &url]) + .spawn()?; + } + + #[cfg(target_os = "macos")] + { + tokio::process::Command::new("open").arg(&url).spawn()?; + } + + #[cfg(target_os = "linux")] + { + tokio::process::Command::new("xdg-open").arg(&url).spawn()?; + } + + println!(); + println!( + "{}", + style("Make sure your RustAPI server is running!").yellow() + ); + println!("Start it with: {}", style("cargo rustapi run").cyan()); + + Ok(()) +} diff --git a/crates/cargo-rustapi/src/commands/generate.rs b/crates/cargo-rustapi/src/commands/generate.rs new file mode 100644 index 0000000..5c2f6e2 --- /dev/null +++ b/crates/cargo-rustapi/src/commands/generate.rs @@ -0,0 +1,289 @@ +//! Code generation command + +use anyhow::Result; +use clap::Subcommand; +use console::style; +use std::path::Path; +use tokio::fs; + +/// Arguments for the `generate` command +#[derive(Subcommand, Debug)] +pub enum GenerateArgs { + /// Generate a handler module + Handler { + /// Handler name (e.g., "users", "products") + name: String, + }, + + /// Generate a model struct + Model { + /// Model name (e.g., "User", "Product") + name: String, + }, + + /// Generate CRUD handlers for a resource + Crud { + /// Resource name (e.g., "users", "products") + name: String, + }, +} + +/// Execute code generation +pub async fn generate(args: GenerateArgs) -> Result<()> { + match args { + GenerateArgs::Handler { name } => generate_handler(&name).await, + GenerateArgs::Model { name } => generate_model(&name).await, + GenerateArgs::Crud { name } => generate_crud(&name).await, + } +} + +async fn generate_handler(name: &str) -> Result<()> { + let handlers_dir = Path::new("src/handlers"); + + // Create handlers directory if it doesn't exist + if !handlers_dir.exists() { + fs::create_dir_all(handlers_dir).await?; + + // Create mod.rs + let mod_content = format!("pub mod {};\n", name); + fs::write(handlers_dir.join("mod.rs"), mod_content).await?; + } else { + // Append to existing mod.rs + let mod_path = handlers_dir.join("mod.rs"); + if mod_path.exists() { + let mut content = fs::read_to_string(&mod_path).await?; + if !content.contains(&format!("mod {};", name)) { + content.push_str(&format!("pub mod {};\n", name)); + fs::write(&mod_path, content).await?; + } + } + } + + // Generate handler file + let handler_content = format!( + r#"//! {} handlers + +use rustapi_rs::prelude::*; +use serde::{{Deserialize, Serialize}}; + +/// List all {} +#[rustapi::get("/{name}")] +pub async fn list() -> Json> {{ + // TODO: Implement list + Json(vec![]) +}} + +/// Get a single {singular} +#[rustapi::get("/{name}/{{id}}")] +pub async fn get(Path(id): Path) -> Result> {{ + // TODO: Implement get + Err(ApiError::not_found("{singular}")) +}} + +/// Create a new {singular} +#[rustapi::post("/{name}")] +pub async fn create(Json(body): Json) -> Result>> {{ + // TODO: Implement create + Err(ApiError::internal("Not implemented")) +}} + +/// Update a {singular} +#[rustapi::put("/{name}/{{id}}")] +pub async fn update( + Path(id): Path, + Json(body): Json, +) -> Result> {{ + // TODO: Implement update + Err(ApiError::not_found("{singular}")) +}} + +/// Delete a {singular} +#[rustapi::delete("/{name}/{{id}}")] +pub async fn delete(Path(id): Path) -> Result {{ + // TODO: Implement delete + Err(ApiError::not_found("{singular}")) +}} + +// Request/Response types +#[derive(Debug, Serialize, Schema)] +pub struct {type_name}Response {{ + pub id: i64, + // TODO: Add fields +}} + +#[derive(Debug, Deserialize, Schema)] +pub struct Create{type_name} {{ + // TODO: Add fields +}} + +#[derive(Debug, Deserialize, Schema)] +pub struct Update{type_name} {{ + // TODO: Add fields +}} +"#, + capitalize(name), + name, + name = name, + type_name = to_pascal_case(name), + singular = singularize(name), + ); + + let handler_path = handlers_dir.join(format!("{}.rs", name)); + fs::write(&handler_path, handler_content).await?; + + println!( + "{} Generated handler: {}", + style("✓").green(), + handler_path.display() + ); + println!(); + println!("Don't forget to register the routes in main.rs:"); + println!( + " {}", + style(format!(".mount(handlers::{}::list)", name)).cyan() + ); + println!( + " {}", + style(format!(".mount(handlers::{}::get)", name)).cyan() + ); + println!( + " {}", + style(format!(".mount(handlers::{}::create)", name)).cyan() + ); + println!( + " {}", + style(format!(".mount(handlers::{}::update)", name)).cyan() + ); + println!( + " {}", + style(format!(".mount(handlers::{}::delete)", name)).cyan() + ); + + Ok(()) +} + +async fn generate_model(name: &str) -> Result<()> { + let models_dir = Path::new("src/models"); + + // Create models directory if it doesn't exist + if !models_dir.exists() { + fs::create_dir_all(models_dir).await?; + + // Create mod.rs + let mod_content = format!( + "mod {};\npub use {}::*;\n", + name.to_lowercase(), + name.to_lowercase() + ); + fs::write(models_dir.join("mod.rs"), mod_content).await?; + } else { + // Append to existing mod.rs + let mod_path = models_dir.join("mod.rs"); + if mod_path.exists() { + let mut content = fs::read_to_string(&mod_path).await?; + let lower_name = name.to_lowercase(); + if !content.contains(&format!("mod {};", lower_name)) { + content.push_str(&format!( + "mod {};\npub use {}::*;\n", + lower_name, lower_name + )); + fs::write(&mod_path, content).await?; + } + } + } + + // Generate model file + let model_content = format!( + r#"//! {} model + +use serde::{{Deserialize, Serialize}}; +use rustapi_rs::Schema; + +/// {} entity +#[derive(Debug, Clone, Serialize, Deserialize, Schema)] +pub struct {} {{ + /// Unique identifier + pub id: i64, + + /// Creation timestamp + pub created_at: String, + + /// Last update timestamp + pub updated_at: String, + + // TODO: Add your fields here +}} + +impl {} {{ + /// Create a new {} instance + pub fn new(id: i64) -> Self {{ + let now = chrono::Utc::now().to_rfc3339(); + Self {{ + id, + created_at: now.clone(), + updated_at: now, + }} + }} +}} +"#, + name, + name, + name, + name, + name.to_lowercase(), + ); + + let model_path = models_dir.join(format!("{}.rs", name.to_lowercase())); + fs::write(&model_path, model_content).await?; + + println!( + "{} Generated model: {}", + style("✓").green(), + model_path.display() + ); + + Ok(()) +} + +async fn generate_crud(name: &str) -> Result<()> { + // Generate both handler and model + let type_name = to_pascal_case(name); + + println!( + "{}", + style(format!("Generating CRUD for '{}'...", name)).bold() + ); + println!(); + + generate_model(&type_name).await?; + generate_handler(name).await?; + + Ok(()) +} + +// Helper functions +fn capitalize(s: &str) -> String { + let mut chars = s.chars(); + match chars.next() { + None => String::new(), + Some(c) => c.to_uppercase().collect::() + chars.as_str(), + } +} + +fn to_pascal_case(s: &str) -> String { + s.split(&['-', '_'][..]).map(capitalize).collect() +} + +fn singularize(s: &str) -> String { + if let Some(stripped) = s.strip_suffix("ies") { + format!("{}y", stripped) + } else if let Some(stripped) = s.strip_suffix('s') { + if !s.ends_with("ss") { + stripped.to_string() + } else { + s.to_string() + } + } else { + s.to_string() + } +} diff --git a/crates/cargo-rustapi/src/commands/mod.rs b/crates/cargo-rustapi/src/commands/mod.rs new file mode 100644 index 0000000..ab33fcb --- /dev/null +++ b/crates/cargo-rustapi/src/commands/mod.rs @@ -0,0 +1,11 @@ +//! CLI commands + +mod docs; +mod generate; +mod new; +mod run; + +pub use docs::open_docs; +pub use generate::{generate, GenerateArgs}; +pub use new::{new_project, NewArgs}; +pub use run::{run_dev, RunArgs}; diff --git a/crates/cargo-rustapi/src/commands/new.rs b/crates/cargo-rustapi/src/commands/new.rs new file mode 100644 index 0000000..8a9489f --- /dev/null +++ b/crates/cargo-rustapi/src/commands/new.rs @@ -0,0 +1,228 @@ +//! New project command + +use anyhow::{Context, Result}; +use clap::Args; +use console::style; +use dialoguer::{theme::ColorfulTheme, Confirm, Input, Select}; +use indicatif::{ProgressBar, ProgressStyle}; +use std::path::Path; +use std::time::Duration; +use tokio::fs; + +use crate::templates::{self, ProjectTemplate}; + +/// Arguments for the `new` command +#[derive(Args, Debug)] +pub struct NewArgs { + /// Project name + pub name: Option, + + /// Project template + #[arg(short, long, value_enum)] + pub template: Option, + + /// Features to enable + #[arg(short, long, value_delimiter = ',')] + pub features: Option>, + + /// Skip interactive prompts + #[arg(long)] + pub yes: bool, + + /// Initialize git repository + #[arg(long, default_value = "true")] + pub git: bool, +} + +/// Create a new RustAPI project +pub async fn new_project(mut args: NewArgs) -> Result<()> { + let theme = ColorfulTheme::default(); + + // Get project name + let name = if let Some(name) = args.name.take() { + name + } else { + Input::with_theme(&theme) + .with_prompt("Project name") + .default("my-rustapi-app".to_string()) + .interact_text()? + }; + + // Validate project name + validate_project_name(&name)?; + + // Check if directory exists + let project_path = Path::new(&name); + if project_path.exists() { + anyhow::bail!("Directory '{}' already exists", name); + } + + // Get template + let template = if let Some(template) = args.template { + template + } else if args.yes { + ProjectTemplate::Minimal + } else { + let templates = [ + "minimal - Bare minimum app", + "api - REST API with CRUD", + "web - Web app with templates", + "full - Full-featured app", + ]; + let selection = Select::with_theme(&theme) + .with_prompt("Select a template") + .items(&templates) + .default(0) + .interact()?; + + match selection { + 0 => ProjectTemplate::Minimal, + 1 => ProjectTemplate::Api, + 2 => ProjectTemplate::Web, + 3 => ProjectTemplate::Full, + _ => ProjectTemplate::Minimal, + } + }; + + // Get features + let features = if let Some(features) = args.features { + features + } else if args.yes { + vec![] + } else { + let available = ["jwt", "cors", "rate-limit", "config", "toon", "ws", "view"]; + let defaults = match template { + ProjectTemplate::Full => vec![true, true, true, true, false, false, false], + ProjectTemplate::Web => vec![false, false, false, false, false, false, true], + _ => vec![false; available.len()], + }; + + let selections = dialoguer::MultiSelect::with_theme(&theme) + .with_prompt("Select features (space to toggle)") + .items(&available) + .defaults(&defaults) + .interact()?; + + selections + .iter() + .map(|&i| available[i].to_string()) + .collect() + }; + + // Confirm + if !args.yes { + println!(); + println!("{}", style("Project configuration:").bold()); + println!(" Name: {}", style(&name).cyan()); + println!(" Template: {}", style(format!("{:?}", template)).cyan()); + println!( + " Features: {}", + style(if features.is_empty() { + "none".to_string() + } else { + features.join(", ") + }) + .cyan() + ); + println!(); + + if !Confirm::with_theme(&theme) + .with_prompt("Create project?") + .default(true) + .interact()? + { + println!("{}", style("Aborted").yellow()); + return Ok(()); + } + } + + // Create project + let pb = ProgressBar::new_spinner(); + pb.set_style( + ProgressStyle::default_spinner() + .template("{spinner:.green} {msg}") + .unwrap(), + ); + pb.enable_steady_tick(Duration::from_millis(80)); + + pb.set_message("Creating project directory..."); + fs::create_dir_all(&name).await?; + + pb.set_message("Generating project files..."); + templates::generate_project(&name, template, &features).await?; + + if args.git { + pb.set_message("Initializing git repository..."); + init_git(&name).await.ok(); // Don't fail if git isn't available + } + + pb.finish_and_clear(); + + // Success message + println!(); + println!( + "{}", + style("✨ Project created successfully!").green().bold() + ); + println!(); + println!("Next steps:"); + println!(" {} {}", style("cd").cyan(), name); + println!(" {} run", style("cargo").cyan()); + println!(); + println!( + "Then open {} in your browser.", + style("http://localhost:8080").cyan() + ); + + if features.iter().any(|f| f == "swagger-ui") || template == ProjectTemplate::Full { + println!( + "API docs available at {}", + style("http://localhost:8080/docs").cyan() + ); + } + + Ok(()) +} + +/// Validate project name +fn validate_project_name(name: &str) -> Result<()> { + if name.is_empty() { + anyhow::bail!("Project name cannot be empty"); + } + + if name.contains('/') || name.contains('\\') { + anyhow::bail!("Project name cannot contain path separators"); + } + + // Check for valid Rust crate name characters + if !name + .chars() + .all(|c| c.is_alphanumeric() || c == '-' || c == '_') + { + anyhow::bail!( + "Project name can only contain alphanumeric characters, hyphens, and underscores" + ); + } + + if name.starts_with('-') || name.starts_with('_') { + anyhow::bail!("Project name cannot start with a hyphen or underscore"); + } + + Ok(()) +} + +/// Initialize a git repository +async fn init_git(path: &str) -> Result<()> { + let output = tokio::process::Command::new("git") + .args(["init"]) + .current_dir(path) + .output() + .await + .context("Failed to run git init")?; + + if !output.status.success() { + anyhow::bail!("git init failed"); + } + + Ok(()) +} diff --git a/crates/cargo-rustapi/src/commands/run.rs b/crates/cargo-rustapi/src/commands/run.rs new file mode 100644 index 0000000..a55508a --- /dev/null +++ b/crates/cargo-rustapi/src/commands/run.rs @@ -0,0 +1,118 @@ +//! Run command for development server + +use anyhow::Result; +use clap::Args; +use console::style; +use std::process::Stdio; +use tokio::process::Command; + +/// Arguments for the `run` command +#[derive(Args, Debug)] +pub struct RunArgs { + /// Port to run on + #[arg(short, long, default_value = "8080")] + pub port: u16, + + /// Additional features to enable + #[arg(short, long, value_delimiter = ',')] + pub features: Option>, + + /// Release mode + #[arg(long)] + pub release: bool, + + /// Watch for changes and auto-reload + #[arg(short, long)] + pub watch: bool, +} + +/// Run the development server +pub async fn run_dev(args: RunArgs) -> Result<()> { + // Set environment variables + std::env::set_var("PORT", args.port.to_string()); + std::env::set_var("RUSTAPI_ENV", "development"); + + println!("{}", style("Starting RustAPI development server...").bold()); + println!(); + + if args.watch { + // Use cargo-watch if available + run_with_watch(&args).await + } else { + run_cargo(&args).await + } +} + +async fn run_cargo(args: &RunArgs) -> Result<()> { + let mut cmd = Command::new("cargo"); + cmd.arg("run"); + + if args.release { + cmd.arg("--release"); + } + + if let Some(features) = &args.features { + cmd.arg("--features").arg(features.join(",")); + } + + cmd.stdout(Stdio::inherit()) + .stderr(Stdio::inherit()) + .stdin(Stdio::inherit()); + + let status = cmd.status().await?; + + if !status.success() { + anyhow::bail!("cargo run failed"); + } + + Ok(()) +} + +async fn run_with_watch(args: &RunArgs) -> Result<()> { + // Check if cargo-watch is installed + let check = Command::new("cargo") + .args(["watch", "--version"]) + .output() + .await; + + if check.is_err() || !check.unwrap().status.success() { + println!("{}", style("cargo-watch not found. Installing...").yellow()); + + let install = Command::new("cargo") + .args(["install", "cargo-watch"]) + .status() + .await?; + + if !install.success() { + println!( + "{}", + style("Failed to install cargo-watch. Running without watch mode.").yellow() + ); + return run_cargo(args).await; + } + } + + let mut cmd = Command::new("cargo"); + cmd.args(["watch", "-x"]); + + let mut run_cmd = String::from("run"); + if args.release { + run_cmd.push_str(" --release"); + } + if let Some(features) = &args.features { + run_cmd.push_str(&format!(" --features {}", features.join(","))); + } + + cmd.arg(run_cmd); + cmd.stdout(Stdio::inherit()) + .stderr(Stdio::inherit()) + .stdin(Stdio::inherit()); + + let status = cmd.status().await?; + + if !status.success() { + anyhow::bail!("cargo watch failed"); + } + + Ok(()) +} diff --git a/crates/cargo-rustapi/src/main.rs b/crates/cargo-rustapi/src/main.rs new file mode 100644 index 0000000..9bc13bb --- /dev/null +++ b/crates/cargo-rustapi/src/main.rs @@ -0,0 +1,28 @@ +//! cargo-rustapi CLI tool +//! +//! Provides project scaffolding and development utilities for RustAPI. + +mod cli; +mod commands; +mod templates; + +use clap::Parser; +use cli::Cli; + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + // Initialize logging + tracing_subscriber::fmt() + .with_env_filter( + tracing_subscriber::EnvFilter::from_default_env() + .add_directive("cargo_rustapi=info".parse().unwrap()), + ) + .without_time() + .init(); + + // Parse CLI arguments + let cli = Cli::parse(); + + // Execute command + cli.execute().await +} diff --git a/crates/cargo-rustapi/src/templates/api.rs b/crates/cargo-rustapi/src/templates/api.rs new file mode 100644 index 0000000..02c1664 --- /dev/null +++ b/crates/cargo-rustapi/src/templates/api.rs @@ -0,0 +1,305 @@ +//! API project template + +use super::common; +use anyhow::Result; +use tokio::fs; + +pub async fn generate(name: &str, features: &[String]) -> Result<()> { + // Cargo.toml + let cargo_toml = format!( + r#"[package] +name = "{name}" +version = "0.1.0" +edition = "2021" + +[dependencies] +rustapi-rs = {{ version = "0.1"{features} }} +tokio = {{ version = "1", features = ["full"] }} +serde = {{ version = "1", features = ["derive"] }} +tracing = "0.1" +tracing-subscriber = {{ version = "0.3", features = ["env-filter"] }} +uuid = {{ version = "1", features = ["v4"] }} +"#, + name = name, + features = common::features_to_cargo(features), + ); + fs::write(format!("{name}/Cargo.toml"), cargo_toml).await?; + + // Create directories + fs::create_dir_all(format!("{name}/src/handlers")).await?; + fs::create_dir_all(format!("{name}/src/models")).await?; + + // main.rs + let main_rs = r#"mod handlers; +mod models; +mod error; + +use rustapi_rs::prelude::*; +use std::sync::Arc; +use tokio::sync::RwLock; + +pub type AppState = Arc>; + +#[rustapi::main] +async fn main() -> Result<(), Box> { + // Initialize tracing + tracing_subscriber::fmt() + .with_env_filter( + tracing_subscriber::EnvFilter::from_default_env() + .add_directive("info".parse().unwrap()), + ) + .init(); + + // Create shared state + let state: AppState = Arc::new(RwLock::new(models::Store::new())); + + let port = std::env::var("PORT").unwrap_or_else(|_| "8080".to_string()); + let addr = format!("127.0.0.1:{}", port); + + tracing::info!("🚀 Server running at http://{}", addr); + tracing::info!("📚 API docs at http://{}/docs", addr); + + RustApi::new() + .state(state) + // Health check + .route("/health", get(handlers::health)) + // Items CRUD + .mount(handlers::items::list) + .mount(handlers::items::get) + .mount(handlers::items::create) + .mount(handlers::items::update) + .mount(handlers::items::delete) + // Documentation + .docs("/docs") + .run(&addr) + .await +} +"#; + fs::write(format!("{name}/src/main.rs"), main_rs).await?; + + // error.rs + let error_rs = r#"//! Application error types + +use rustapi_rs::prelude::*; + +/// Application-specific errors +#[derive(Debug)] +pub enum AppError { + NotFound(String), + InvalidInput(String), +} + +impl From for ApiError { + fn from(err: AppError) -> Self { + match err { + AppError::NotFound(msg) => ApiError::not_found(msg), + AppError::InvalidInput(msg) => ApiError::bad_request(msg), + } + } +} +"#; + fs::write(format!("{name}/src/error.rs"), error_rs).await?; + + // handlers/mod.rs + let handlers_mod = r#"//! Request handlers + +pub mod items; + +use rustapi_rs::prelude::*; +use serde::Serialize; + +/// Health check response +#[derive(Serialize, Schema)] +pub struct HealthResponse { + pub status: String, + pub version: String, +} + +/// Health check endpoint +pub async fn health() -> Json { + Json(HealthResponse { + status: "ok".to_string(), + version: env!("CARGO_PKG_VERSION").to_string(), + }) +} +"#; + fs::write(format!("{name}/src/handlers/mod.rs"), handlers_mod).await?; + + // handlers/items.rs + let handlers_items = r#"//! Item handlers + +use crate::models::{Item, CreateItem, UpdateItem}; +use crate::AppState; +use rustapi_rs::prelude::*; + +/// List all items +#[rustapi::get("/items")] +#[rustapi::tag("Items")] +#[rustapi::summary("List all items")] +pub async fn list(State(state): State) -> Json> { + let store = state.read().await; + Json(store.items.values().cloned().collect()) +} + +/// Get an item by ID +#[rustapi::get("/items/{id}")] +#[rustapi::tag("Items")] +#[rustapi::summary("Get item by ID")] +pub async fn get( + Path(id): Path, + State(state): State, +) -> Result> { + let store = state.read().await; + store.items + .get(&id) + .cloned() + .map(Json) + .ok_or_else(|| ApiError::not_found(format!("Item {} not found", id))) +} + +/// Create a new item +#[rustapi::post("/items")] +#[rustapi::tag("Items")] +#[rustapi::summary("Create a new item")] +pub async fn create( + State(state): State, + Json(body): Json, +) -> Result>> { + let item = Item::new(body.name, body.description); + + let mut store = state.write().await; + store.items.insert(item.id.clone(), item.clone()); + + Ok(Created(Json(item))) +} + +/// Update an item +#[rustapi::put("/items/{id}")] +#[rustapi::tag("Items")] +#[rustapi::summary("Update an item")] +pub async fn update( + Path(id): Path, + State(state): State, + Json(body): Json, +) -> Result> { + let mut store = state.write().await; + + let item = store.items + .get_mut(&id) + .ok_or_else(|| ApiError::not_found(format!("Item {} not found", id)))?; + + if let Some(name) = body.name { + item.name = name; + } + if let Some(description) = body.description { + item.description = description; + } + item.updated_at = chrono_now(); + + Ok(Json(item.clone())) +} + +/// Delete an item +#[rustapi::delete("/items/{id}")] +#[rustapi::tag("Items")] +#[rustapi::summary("Delete an item")] +pub async fn delete( + Path(id): Path, + State(state): State, +) -> Result { + let mut store = state.write().await; + + store.items + .remove(&id) + .ok_or_else(|| ApiError::not_found(format!("Item {} not found", id)))?; + + Ok(NoContent) +} + +fn chrono_now() -> String { + // Simple timestamp without chrono dependency + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map(|d| d.as_secs().to_string()) + .unwrap_or_default() +} +"#; + fs::write(format!("{name}/src/handlers/items.rs"), handlers_items).await?; + + // models/mod.rs + let models_mod = r#"//! Data models + +use serde::{Deserialize, Serialize}; +use rustapi_rs::Schema; +use std::collections::HashMap; + +/// In-memory data store +pub struct Store { + pub items: HashMap, +} + +impl Store { + pub fn new() -> Self { + Self { + items: HashMap::new(), + } + } +} + +impl Default for Store { + fn default() -> Self { + Self::new() + } +} + +/// An item in the store +#[derive(Debug, Clone, Serialize, Deserialize, Schema)] +pub struct Item { + pub id: String, + pub name: String, + #[serde(default)] + pub description: Option, + pub created_at: String, + pub updated_at: String, +} + +impl Item { + pub fn new(name: String, description: Option) -> Self { + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map(|d| d.as_secs().to_string()) + .unwrap_or_default(); + + Self { + id: uuid::Uuid::new_v4().to_string(), + name, + description, + created_at: now.clone(), + updated_at: now, + } + } +} + +/// Request to create an item +#[derive(Debug, Deserialize, Schema)] +pub struct CreateItem { + pub name: String, + #[serde(default)] + pub description: Option, +} + +/// Request to update an item +#[derive(Debug, Deserialize, Schema)] +pub struct UpdateItem { + pub name: Option, + pub description: Option, +} +"#; + fs::write(format!("{name}/src/models/mod.rs"), models_mod).await?; + + // .gitignore and .env.example + common::generate_gitignore(name).await?; + common::generate_env_example(name).await?; + + Ok(()) +} diff --git a/crates/cargo-rustapi/src/templates/full.rs b/crates/cargo-rustapi/src/templates/full.rs new file mode 100644 index 0000000..6245d2b --- /dev/null +++ b/crates/cargo-rustapi/src/templates/full.rs @@ -0,0 +1,426 @@ +//! Full-featured project template + +use super::common; +use anyhow::Result; +use tokio::fs; + +pub async fn generate(name: &str, features: &[String]) -> Result<()> { + // Add recommended features for full template + let mut all_features: Vec = vec![ + "jwt".to_string(), + "cors".to_string(), + "rate-limit".to_string(), + "config".to_string(), + ]; + + // Add user-specified features + for f in features { + if !all_features.contains(f) { + all_features.push(f.clone()); + } + } + + // Cargo.toml + let cargo_toml = format!( + r#"[package] +name = "{name}" +version = "0.1.0" +edition = "2021" + +[dependencies] +rustapi-rs = {{ version = "0.1"{features} }} +tokio = {{ version = "1", features = ["full"] }} +serde = {{ version = "1", features = ["derive"] }} +tracing = "0.1" +tracing-subscriber = {{ version = "0.3", features = ["env-filter"] }} +uuid = {{ version = "1", features = ["v4"] }} +"#, + name = name, + features = common::features_to_cargo(&all_features), + ); + fs::write(format!("{name}/Cargo.toml"), cargo_toml).await?; + + // Create directories + fs::create_dir_all(format!("{name}/src/handlers")).await?; + fs::create_dir_all(format!("{name}/src/models")).await?; + fs::create_dir_all(format!("{name}/src/middleware")).await?; + + // main.rs + let main_rs = r#"mod handlers; +mod models; +mod middleware; + +use rustapi_rs::prelude::*; +use std::sync::Arc; +use tokio::sync::RwLock; + +pub type AppState = Arc>; + +#[rustapi::main] +async fn main() -> Result<(), Box> { + // Load environment variables + load_dotenv(); + + // Initialize tracing + tracing_subscriber::fmt() + .with_env_filter( + tracing_subscriber::EnvFilter::from_default_env() + .add_directive("info".parse().unwrap()), + ) + .init(); + + // Get configuration + let env = Environment::from_env(); + let host = env_or("HOST", "127.0.0.1"); + let port = env_or("PORT", "8080"); + let addr = format!("{}:{}", host, port); + + // Create shared state + let state: AppState = Arc::new(RwLock::new(models::Store::new())); + + tracing::info!("🚀 Starting server in {:?} mode", env); + tracing::info!("📡 Listening on http://{}", addr); + tracing::info!("📚 API docs at http://{}/docs", addr); + + RustApi::new() + .state(state) + // Middleware + .layer(CorsLayer::permissive()) + .layer(RateLimitLayer::new(100, std::time::Duration::from_secs(60))) + // Health check + .route("/health", get(handlers::health)) + // Auth endpoints + .route("/auth/login", post(handlers::auth::login)) + .route("/auth/me", get(handlers::auth::me)) + // Protected items endpoints (require JWT) + .mount(handlers::items::list) + .mount(handlers::items::get) + .mount(handlers::items::create) + .mount(handlers::items::update) + .mount(handlers::items::delete) + // Documentation + .docs_with_info("/docs", ApiInfo { + title: env!("CARGO_PKG_NAME").to_string(), + version: env!("CARGO_PKG_VERSION").to_string(), + description: Some("Full-featured RustAPI application".to_string()), + }) + .run(&addr) + .await +} +"#; + fs::write(format!("{name}/src/main.rs"), main_rs).await?; + + // handlers/mod.rs + let handlers_mod = r#"//! Request handlers + +pub mod auth; +pub mod items; + +use rustapi_rs::prelude::*; +use serde::Serialize; + +#[derive(Serialize, Schema)] +pub struct HealthResponse { + pub status: String, + pub version: String, + pub environment: String, +} + +pub async fn health() -> Json { + Json(HealthResponse { + status: "ok".to_string(), + version: env!("CARGO_PKG_VERSION").to_string(), + environment: std::env::var("RUSTAPI_ENV").unwrap_or_else(|_| "development".to_string()), + }) +} +"#; + fs::write(format!("{name}/src/handlers/mod.rs"), handlers_mod).await?; + + // handlers/auth.rs + let handlers_auth = r#"//! Authentication handlers + +use rustapi_rs::prelude::*; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Deserialize, Schema)] +pub struct LoginRequest { + pub username: String, + pub password: String, +} + +#[derive(Debug, Serialize, Schema)] +pub struct LoginResponse { + pub token: String, + pub token_type: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct UserClaims { + pub sub: String, + pub username: String, + pub exp: usize, +} + +/// Login and get a JWT token +#[rustapi::post("/auth/login")] +#[rustapi::tag("Authentication")] +#[rustapi::summary("Login with username and password")] +pub async fn login(Json(body): Json) -> Result> { + // TODO: Validate credentials against your database + if body.username == "admin" && body.password == "password" { + let jwt_secret = std::env::var("JWT_SECRET") + .unwrap_or_else(|_| "dev-secret-change-in-production".to_string()); + + let claims = UserClaims { + sub: "1".to_string(), + username: body.username, + exp: (chrono_now() + 86400) as usize, // 24 hours + }; + + let token = create_token(&claims, &jwt_secret)?; + + Ok(Json(LoginResponse { + token, + token_type: "Bearer".to_string(), + })) + } else { + Err(ApiError::unauthorized("Invalid credentials")) + } +} + +/// Get current user info +#[rustapi::get("/auth/me")] +#[rustapi::tag("Authentication")] +#[rustapi::summary("Get current authenticated user")] +pub async fn me(auth: AuthUser) -> Json { + Json(auth.claims) +} + +fn chrono_now() -> u64 { + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map(|d| d.as_secs()) + .unwrap_or(0) +} +"#; + fs::write(format!("{name}/src/handlers/auth.rs"), handlers_auth).await?; + + // handlers/items.rs + let handlers_items = r#"//! Item handlers + +use crate::handlers::auth::UserClaims; +use crate::models::{Item, CreateItem, UpdateItem}; +use crate::AppState; +use rustapi_rs::prelude::*; + +/// List all items +#[rustapi::get("/items")] +#[rustapi::tag("Items")] +#[rustapi::summary("List all items")] +pub async fn list( + _auth: AuthUser, + State(state): State, +) -> Json> { + let store = state.read().await; + Json(store.items.values().cloned().collect()) +} + +/// Get an item by ID +#[rustapi::get("/items/{id}")] +#[rustapi::tag("Items")] +#[rustapi::summary("Get item by ID")] +pub async fn get( + _auth: AuthUser, + Path(id): Path, + State(state): State, +) -> Result> { + let store = state.read().await; + store.items + .get(&id) + .cloned() + .map(Json) + .ok_or_else(|| ApiError::not_found(format!("Item {} not found", id))) +} + +/// Create a new item +#[rustapi::post("/items")] +#[rustapi::tag("Items")] +#[rustapi::summary("Create a new item")] +pub async fn create( + auth: AuthUser, + State(state): State, + Json(body): Json, +) -> Result>> { + let item = Item::new(body.name, body.description, auth.claims.sub.clone()); + + let mut store = state.write().await; + store.items.insert(item.id.clone(), item.clone()); + + tracing::info!("User {} created item {}", auth.claims.username, item.id); + + Ok(Created(Json(item))) +} + +/// Update an item +#[rustapi::put("/items/{id}")] +#[rustapi::tag("Items")] +#[rustapi::summary("Update an item")] +pub async fn update( + _auth: AuthUser, + Path(id): Path, + State(state): State, + Json(body): Json, +) -> Result> { + let mut store = state.write().await; + + let item = store.items + .get_mut(&id) + .ok_or_else(|| ApiError::not_found(format!("Item {} not found", id)))?; + + if let Some(name) = body.name { + item.name = name; + } + if let Some(description) = body.description { + item.description = description; + } + item.updated_at = chrono_now(); + + Ok(Json(item.clone())) +} + +/// Delete an item +#[rustapi::delete("/items/{id}")] +#[rustapi::tag("Items")] +#[rustapi::summary("Delete an item")] +pub async fn delete( + auth: AuthUser, + Path(id): Path, + State(state): State, +) -> Result { + let mut store = state.write().await; + + store.items + .remove(&id) + .ok_or_else(|| ApiError::not_found(format!("Item {} not found", id)))?; + + tracing::info!("User {} deleted item {}", auth.claims.username, id); + + Ok(NoContent) +} + +fn chrono_now() -> String { + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map(|d| d.as_secs().to_string()) + .unwrap_or_default() +} +"#; + fs::write(format!("{name}/src/handlers/items.rs"), handlers_items).await?; + + // models/mod.rs + let models_mod = r#"//! Data models + +use serde::{Deserialize, Serialize}; +use rustapi_rs::Schema; +use std::collections::HashMap; + +pub struct Store { + pub items: HashMap, +} + +impl Store { + pub fn new() -> Self { + Self { + items: HashMap::new(), + } + } +} + +impl Default for Store { + fn default() -> Self { + Self::new() + } +} + +#[derive(Debug, Clone, Serialize, Deserialize, Schema)] +pub struct Item { + pub id: String, + pub name: String, + #[serde(default)] + pub description: Option, + pub created_by: String, + pub created_at: String, + pub updated_at: String, +} + +impl Item { + pub fn new(name: String, description: Option, created_by: String) -> Self { + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map(|d| d.as_secs().to_string()) + .unwrap_or_default(); + + Self { + id: uuid::Uuid::new_v4().to_string(), + name, + description, + created_by, + created_at: now.clone(), + updated_at: now, + } + } +} + +#[derive(Debug, Deserialize, Schema)] +pub struct CreateItem { + pub name: String, + #[serde(default)] + pub description: Option, +} + +#[derive(Debug, Deserialize, Schema)] +pub struct UpdateItem { + pub name: Option, + pub description: Option, +} +"#; + fs::write(format!("{name}/src/models/mod.rs"), models_mod).await?; + + // middleware/mod.rs + let middleware_mod = r#"//! Custom middleware + +// Add your custom middleware here +// Example: +// pub mod logging; +// pub mod auth_check; +"#; + fs::write(format!("{name}/src/middleware/mod.rs"), middleware_mod).await?; + + // .env.example with JWT secret + let env_example = r#"# Server configuration +HOST=127.0.0.1 +PORT=8080 + +# Environment (development, production) +RUSTAPI_ENV=development + +# JWT Secret (CHANGE THIS IN PRODUCTION!) +JWT_SECRET=your-super-secret-key-change-in-production + +# Rate limiting +RATE_LIMIT_REQUESTS=100 +RATE_LIMIT_WINDOW_SECS=60 + +# Logging +RUST_LOG=info +"#; + fs::write(format!("{name}/.env.example"), env_example).await?; + + // Copy .env.example to .env for development + fs::copy(format!("{name}/.env.example"), format!("{name}/.env")).await?; + + // .gitignore + common::generate_gitignore(name).await?; + + Ok(()) +} diff --git a/crates/cargo-rustapi/src/templates/minimal.rs b/crates/cargo-rustapi/src/templates/minimal.rs new file mode 100644 index 0000000..3fc58cd --- /dev/null +++ b/crates/cargo-rustapi/src/templates/minimal.rs @@ -0,0 +1,65 @@ +//! Minimal project template + +use super::common; +use anyhow::Result; +use tokio::fs; + +pub async fn generate(name: &str, features: &[String]) -> Result<()> { + // Cargo.toml + let cargo_toml = format!( + r#"[package] +name = "{name}" +version = "0.1.0" +edition = "2021" + +[dependencies] +rustapi-rs = {{ version = "0.1"{features} }} +tokio = {{ version = "1", features = ["full"] }} +serde = {{ version = "1", features = ["derive"] }} +"#, + name = name, + features = common::features_to_cargo(features), + ); + fs::write(format!("{name}/Cargo.toml"), cargo_toml).await?; + + // src directory + fs::create_dir_all(format!("{name}/src")).await?; + + // main.rs + let main_rs = r#"use rustapi_rs::prelude::*; +use serde::Serialize; + +#[derive(Serialize, Schema)] +struct Hello { + message: String, +} + +async fn hello() -> Json { + Json(Hello { + message: "Hello, World!".to_string(), + }) +} + +#[rustapi::main] +async fn main() -> Result<(), Box> { + tracing_subscriber::fmt::init(); + + let port = std::env::var("PORT").unwrap_or_else(|_| "8080".to_string()); + let addr = format!("127.0.0.1:{}", port); + + println!("🚀 Server running at http://{}", addr); + + RustApi::new() + .route("/", get(hello)) + .docs("/docs") + .run(&addr) + .await +} +"#; + fs::write(format!("{name}/src/main.rs"), main_rs).await?; + + // .gitignore + common::generate_gitignore(name).await?; + + Ok(()) +} diff --git a/crates/cargo-rustapi/src/templates/mod.rs b/crates/cargo-rustapi/src/templates/mod.rs new file mode 100644 index 0000000..4f1138b --- /dev/null +++ b/crates/cargo-rustapi/src/templates/mod.rs @@ -0,0 +1,104 @@ +//! Project templates + +mod api; +mod full; +mod minimal; +mod web; + +use anyhow::Result; +use clap::ValueEnum; + +/// Available project templates +#[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum)] +pub enum ProjectTemplate { + /// Minimal starter template + Minimal, + /// REST API template with CRUD + Api, + /// Web app template with Tera templates + Web, + /// Full-featured template with all batteries + Full, +} + +/// Generate a project from a template +pub async fn generate_project( + name: &str, + template: ProjectTemplate, + features: &[String], +) -> Result<()> { + match template { + ProjectTemplate::Minimal => minimal::generate(name, features).await, + ProjectTemplate::Api => api::generate(name, features).await, + ProjectTemplate::Web => web::generate(name, features).await, + ProjectTemplate::Full => full::generate(name, features).await, + } +} + +/// Common files for all templates +pub mod common { + use anyhow::Result; + use tokio::fs; + + pub async fn generate_gitignore(path: &str) -> Result<()> { + let content = r#"# Generated by Cargo +/target/ + +# IDE +.idea/ +.vscode/ +*.swp +*.swo + +# Environment +.env +.env.local +.env.*.local + +# OS +.DS_Store +Thumbs.db + +# Logs +*.log +"#; + fs::write(format!("{path}/.gitignore"), content).await?; + Ok(()) + } + + pub async fn generate_env_example(path: &str) -> Result<()> { + let content = r#"# Server configuration +HOST=127.0.0.1 +PORT=8080 + +# Environment (development, production) +RUSTAPI_ENV=development + +# Database (if using sqlx) +# DATABASE_URL=postgres://user:pass@localhost/db + +# JWT Secret (if using jwt feature) +# JWT_SECRET=your-secret-key-here + +# Logging +RUST_LOG=info +"#; + fs::write(format!("{path}/.env.example"), content).await?; + Ok(()) + } + + pub fn features_to_cargo(features: &[String]) -> String { + if features.is_empty() { + String::new() + } else { + format!( + ", features = [{}]", + features + .iter() + .map(|f| format!("\"{}\"", f)) + .collect::>() + .join(", ") + ) + } + } +} diff --git a/crates/cargo-rustapi/src/templates/web.rs b/crates/cargo-rustapi/src/templates/web.rs new file mode 100644 index 0000000..2dfb78c --- /dev/null +++ b/crates/cargo-rustapi/src/templates/web.rs @@ -0,0 +1,241 @@ +//! Web project template with Tera templates + +use super::common; +use anyhow::Result; +use tokio::fs; + +pub async fn generate(name: &str, features: &[String]) -> Result<()> { + // Add view feature + let mut all_features = features.to_vec(); + if !all_features.contains(&"view".to_string()) { + all_features.push("view".to_string()); + } + + // Cargo.toml + let cargo_toml = format!( + r#"[package] +name = "{name}" +version = "0.1.0" +edition = "2021" + +[dependencies] +rustapi-rs = {{ version = "0.1"{features} }} +rustapi-view = "0.1" +tokio = {{ version = "1", features = ["full"] }} +serde = {{ version = "1", features = ["derive"] }} +tracing = "0.1" +tracing-subscriber = {{ version = "0.3", features = ["env-filter"] }} +"#, + name = name, + features = common::features_to_cargo(&all_features), + ); + fs::write(format!("{name}/Cargo.toml"), cargo_toml).await?; + + // Create directories + fs::create_dir_all(format!("{name}/src/handlers")).await?; + fs::create_dir_all(format!("{name}/templates")).await?; + fs::create_dir_all(format!("{name}/static")).await?; + + // main.rs + let main_rs = r#"mod handlers; + +use rustapi_rs::prelude::*; +use rustapi_view::Templates; + +#[rustapi::main] +async fn main() -> Result<(), Box> { + // Initialize tracing + tracing_subscriber::fmt() + .with_env_filter( + tracing_subscriber::EnvFilter::from_default_env() + .add_directive("info".parse().unwrap()), + ) + .init(); + + // Initialize templates + let templates = Templates::new("templates/**/*.html")?; + + let port = std::env::var("PORT").unwrap_or_else(|_| "8080".to_string()); + let addr = format!("127.0.0.1:{}", port); + + tracing::info!("🚀 Server running at http://{}", addr); + + RustApi::new() + .state(templates) + // Pages + .route("/", get(handlers::home)) + .route("/about", get(handlers::about)) + // Static files + .serve_static("/static", "./static") + .run(&addr) + .await +} +"#; + fs::write(format!("{name}/src/main.rs"), main_rs).await?; + + // handlers/mod.rs + let handlers_mod = r#"//! Page handlers + +use rustapi_rs::prelude::*; +use rustapi_view::{Templates, View}; +use serde::Serialize; + +#[derive(Serialize)] +pub struct HomeContext { + pub title: String, + pub message: String, +} + +#[derive(Serialize)] +pub struct AboutContext { + pub title: String, + pub version: String, +} + +/// Home page +pub async fn home(State(templates): State) -> View { + View::render(&templates, "index.html", HomeContext { + title: "Home".to_string(), + message: "Welcome to RustAPI!".to_string(), + }).await +} + +/// About page +pub async fn about(State(templates): State) -> View { + View::render(&templates, "about.html", AboutContext { + title: "About".to_string(), + version: env!("CARGO_PKG_VERSION").to_string(), + }).await +} +"#; + fs::write(format!("{name}/src/handlers/mod.rs"), handlers_mod).await?; + + // templates/base.html + let base_html = r#" + + + + + {% block title %}{{ title }}{% endblock %} - RustAPI + + {% block head %}{% endblock %} + + + + +
+ {% block content %}{% endblock %} +
+ +
+

Built with RustAPI

+
+ + {% block scripts %}{% endblock %} + + +"#; + fs::write(format!("{name}/templates/base.html"), base_html).await?; + + // templates/index.html + let index_html = r#"{% extends "base.html" %} + +{% block content %} +

{{ message }}

+

This is a RustAPI web application with Tera templates.

+ +

Features

+
    +
  • Server-side rendering with Tera
  • +
  • Static file serving
  • +
  • Layout inheritance
  • +
+{% endblock %} +"#; + fs::write(format!("{name}/templates/index.html"), index_html).await?; + + // templates/about.html + let about_html = r#"{% extends "base.html" %} + +{% block content %} +

About

+

Version: {{ version }}

+

RustAPI is a FastAPI-like web framework for Rust.

+{% endblock %} +"#; + fs::write(format!("{name}/templates/about.html"), about_html).await?; + + // static/style.css + let style_css = r#"* { + box-sizing: border-box; + margin: 0; + padding: 0; +} + +body { + font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, sans-serif; + line-height: 1.6; + color: #333; + max-width: 800px; + margin: 0 auto; + padding: 20px; +} + +nav { + margin-bottom: 2rem; + padding-bottom: 1rem; + border-bottom: 1px solid #eee; +} + +nav a { + margin-right: 1rem; + color: #0066cc; + text-decoration: none; +} + +nav a:hover { + text-decoration: underline; +} + +main { + min-height: calc(100vh - 200px); +} + +h1 { + margin-bottom: 1rem; + color: #222; +} + +h2 { + margin-top: 2rem; + margin-bottom: 0.5rem; +} + +p { + margin-bottom: 1rem; +} + +ul { + margin-left: 2rem; + margin-bottom: 1rem; +} + +footer { + margin-top: 3rem; + padding-top: 1rem; + border-top: 1px solid #eee; + color: #666; + font-size: 0.9rem; +} +"#; + fs::write(format!("{name}/static/style.css"), style_css).await?; + + // .gitignore and .env.example + common::generate_gitignore(name).await?; + common::generate_env_example(name).await?; + + Ok(()) +} diff --git a/crates/rustapi-core/Cargo.toml b/crates/rustapi-core/Cargo.toml index 962d2b4..be55eb9 100644 --- a/crates/rustapi-core/Cargo.toml +++ b/crates/rustapi-core/Cargo.toml @@ -44,6 +44,10 @@ linkme = { workspace = true } uuid = { workspace = true } base64 = "0.22" +# Compression (optional) +flate2 = { version = "1.0", optional = true } +brotli = { version = "6.0", optional = true } + # Cookies (optional) cookie = { version = "0.18", optional = true } @@ -69,3 +73,5 @@ test-utils = [] cookies = ["dep:cookie"] sqlx = ["dep:sqlx"] metrics = ["dep:prometheus"] +compression = ["dep:flate2"] +compression-brotli = ["compression", "dep:brotli"] diff --git a/crates/rustapi-core/src/app.rs b/crates/rustapi-core/src/app.rs index d1c40b6..50adbf6 100644 --- a/crates/rustapi-core/src/app.rs +++ b/crates/rustapi-core/src/app.rs @@ -2,6 +2,7 @@ use crate::error::Result; use crate::middleware::{BodyLimitLayer, LayerStack, MiddlewareLayer, DEFAULT_BODY_LIMIT}; +use crate::response::IntoResponse; use crate::router::{MethodRouter, Router}; use crate::server::Server; use std::collections::HashMap; @@ -405,6 +406,126 @@ impl RustApi { self } + /// Serve static files from a directory + /// + /// Maps a URL path prefix to a filesystem directory. Requests to paths under + /// the prefix will serve files from the corresponding location in the directory. + /// + /// # Arguments + /// + /// * `prefix` - URL path prefix (e.g., "/static", "/assets") + /// * `root` - Filesystem directory path + /// + /// # Features + /// + /// - Automatic MIME type detection + /// - ETag and Last-Modified headers for caching + /// - Index file serving for directories + /// - Path traversal prevention + /// + /// # Example + /// + /// ```rust,ignore + /// use rustapi_rs::prelude::*; + /// + /// RustApi::new() + /// .serve_static("/assets", "./public") + /// .serve_static("/uploads", "./uploads") + /// .run("127.0.0.1:8080") + /// .await + /// ``` + pub fn serve_static(self, prefix: &str, root: impl Into) -> Self { + self.serve_static_with_config(crate::static_files::StaticFileConfig::new(root, prefix)) + } + + /// Serve static files with custom configuration + /// + /// # Example + /// + /// ```rust,ignore + /// use rustapi_core::static_files::StaticFileConfig; + /// + /// let config = StaticFileConfig::new("./public", "/assets") + /// .max_age(86400) // Cache for 1 day + /// .fallback("index.html"); // SPA fallback + /// + /// RustApi::new() + /// .serve_static_with_config(config) + /// .run("127.0.0.1:8080") + /// .await + /// ``` + pub fn serve_static_with_config(self, config: crate::static_files::StaticFileConfig) -> Self { + use crate::router::MethodRouter; + use std::collections::HashMap; + + let prefix = config.prefix.clone(); + let catch_all_path = format!("{}/*path", prefix.trim_end_matches('/')); + + // Create the static file handler + let handler: crate::handler::BoxedHandler = + std::sync::Arc::new(move |req: crate::Request| { + let config = config.clone(); + let path = req.uri().path().to_string(); + + Box::pin(async move { + let relative_path = path.strip_prefix(&config.prefix).unwrap_or(&path); + + match crate::static_files::StaticFile::serve(relative_path, &config).await { + Ok(response) => response, + Err(err) => err.into_response(), + } + }) + as std::pin::Pin + Send>> + }); + + let mut handlers = HashMap::new(); + handlers.insert(http::Method::GET, handler); + let method_router = MethodRouter::from_boxed(handlers); + + self.route(&catch_all_path, method_router) + } + + /// Enable response compression + /// + /// Adds gzip/deflate compression for response bodies. The compression + /// is based on the client's Accept-Encoding header. + /// + /// # Example + /// + /// ```rust,ignore + /// use rustapi_rs::prelude::*; + /// + /// RustApi::new() + /// .compression() + /// .route("/", get(handler)) + /// .run("127.0.0.1:8080") + /// .await + /// ``` + #[cfg(feature = "compression")] + pub fn compression(self) -> Self { + self.layer(crate::middleware::CompressionLayer::new()) + } + + /// Enable response compression with custom configuration + /// + /// # Example + /// + /// ```rust,ignore + /// use rustapi_core::middleware::CompressionConfig; + /// + /// RustApi::new() + /// .compression_with_config( + /// CompressionConfig::new() + /// .min_size(512) + /// .level(9) + /// ) + /// .route("/", get(handler)) + /// ``` + #[cfg(feature = "compression")] + pub fn compression_with_config(self, config: crate::middleware::CompressionConfig) -> Self { + self.layer(crate::middleware::CompressionLayer::with_config(config)) + } + /// Enable Swagger UI documentation /// /// This adds two endpoints: diff --git a/crates/rustapi-core/src/lib.rs b/crates/rustapi-core/src/lib.rs index 0ec3bb7..bff83bc 100644 --- a/crates/rustapi-core/src/lib.rs +++ b/crates/rustapi-core/src/lib.rs @@ -57,12 +57,14 @@ mod error; mod extract; mod handler; pub mod middleware; +pub mod multipart; pub mod path_validation; mod request; mod response; mod router; mod server; pub mod sse; +pub mod static_files; pub mod stream; #[cfg(any(test, feature = "test-utils"))] mod test_client; @@ -92,13 +94,17 @@ pub use handler::{ delete_route, get_route, patch_route, post_route, put_route, Handler, HandlerService, Route, RouteHandler, }; +#[cfg(feature = "compression")] +pub use middleware::CompressionLayer; pub use middleware::{BodyLimitLayer, RequestId, RequestIdLayer, TracingLayer, DEFAULT_BODY_LIMIT}; #[cfg(feature = "metrics")] pub use middleware::{MetricsLayer, MetricsResponse}; +pub use multipart::{Multipart, MultipartConfig, MultipartField, UploadedFile}; pub use request::Request; pub use response::{Created, Html, IntoResponse, NoContent, Redirect, Response, WithStatus}; pub use router::{delete, get, patch, post, put, MethodRouter, Router}; -pub use sse::{Sse, SseEvent}; +pub use sse::{sse_response, KeepAlive, Sse, SseEvent}; +pub use static_files::{serve_dir, StaticFile, StaticFileConfig}; pub use stream::StreamBody; #[cfg(any(test, feature = "test-utils"))] pub use test_client::{TestClient, TestRequest, TestResponse}; diff --git a/crates/rustapi-core/src/middleware/compression.rs b/crates/rustapi-core/src/middleware/compression.rs new file mode 100644 index 0000000..8193db3 --- /dev/null +++ b/crates/rustapi-core/src/middleware/compression.rs @@ -0,0 +1,434 @@ +//! Response compression middleware +//! +//! This module provides Gzip and Brotli compression for response bodies. +//! +//! # Example +//! +//! ```rust,ignore +//! use rustapi_rs::prelude::*; +//! use rustapi_core::middleware::CompressionLayer; +//! +//! RustApi::new() +//! .layer(CompressionLayer::new()) +//! .route("/", get(handler)) +//! .run("127.0.0.1:8080") +//! .await +//! ``` + +use crate::middleware::{BoxedNext, MiddlewareLayer}; +use crate::request::Request; +use crate::response::Response; +use bytes::Bytes; +use flate2::write::{DeflateEncoder, GzEncoder}; +use flate2::Compression; +use http::header; +use http_body_util::{BodyExt, Full}; +use std::future::Future; +use std::io::Write; +use std::pin::Pin; + +/// Supported compression algorithms +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum CompressionAlgorithm { + /// Gzip compression + Gzip, + /// Deflate compression + Deflate, + /// Brotli compression (if enabled) + #[cfg(feature = "compression-brotli")] + Brotli, + /// No compression + Identity, +} + +impl CompressionAlgorithm { + /// Get the Content-Encoding header value + pub fn content_encoding(&self) -> &'static str { + match self { + Self::Gzip => "gzip", + Self::Deflate => "deflate", + #[cfg(feature = "compression-brotli")] + Self::Brotli => "br", + Self::Identity => "identity", + } + } + + /// Parse from Accept-Encoding header + pub fn from_accept_encoding(header: &str) -> Self { + let encodings: Vec<(f32, &str)> = header + .split(',') + .map(|part| { + let part = part.trim(); + let (encoding, quality) = if let Some((enc, q)) = part.split_once(";q=") { + (enc.trim(), q.trim().parse().unwrap_or(1.0)) + } else { + (part, 1.0) + }; + (quality, encoding) + }) + .collect(); + + // Sort by quality (highest first) + let mut sorted = encodings; + sorted.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal)); + + for (_, encoding) in sorted { + match encoding.to_lowercase().as_str() { + #[cfg(feature = "compression-brotli")] + "br" => return Self::Brotli, + "gzip" => return Self::Gzip, + "deflate" => return Self::Deflate, + "*" => return Self::Gzip, // Default to gzip for wildcard + _ => continue, + } + } + + Self::Identity + } +} + +/// Configuration for compression middleware +#[derive(Clone)] +pub struct CompressionConfig { + /// Minimum response size to compress (default: 1024 bytes) + pub min_size: usize, + /// Compression level (0-9 for gzip/deflate, 0-11 for brotli) + pub level: u32, + /// Content types to compress (empty = all compressible types) + pub content_types: Vec, + /// Enable gzip compression + pub gzip: bool, + /// Enable deflate compression + pub deflate: bool, + /// Enable brotli compression + #[cfg(feature = "compression-brotli")] + pub brotli: bool, +} + +impl Default for CompressionConfig { + fn default() -> Self { + Self { + min_size: 1024, + level: 6, // Good balance between speed and compression + content_types: vec![ + "text/".to_string(), + "application/json".to_string(), + "application/javascript".to_string(), + "application/xml".to_string(), + "image/svg+xml".to_string(), + ], + gzip: true, + deflate: true, + #[cfg(feature = "compression-brotli")] + brotli: true, + } + } +} + +impl CompressionConfig { + /// Create a new compression config with default values + pub fn new() -> Self { + Self::default() + } + + /// Set minimum size for compression + pub fn min_size(mut self, size: usize) -> Self { + self.min_size = size; + self + } + + /// Set compression level (0-9) + pub fn level(mut self, level: u32) -> Self { + self.level = level.min(9); + self + } + + /// Enable or disable gzip + pub fn gzip(mut self, enabled: bool) -> Self { + self.gzip = enabled; + self + } + + /// Enable or disable deflate + pub fn deflate(mut self, enabled: bool) -> Self { + self.deflate = enabled; + self + } + + /// Enable or disable brotli + #[cfg(feature = "compression-brotli")] + pub fn brotli(mut self, enabled: bool) -> Self { + self.brotli = enabled; + self + } + + /// Add a content type to compress + pub fn add_content_type(mut self, content_type: impl Into) -> Self { + self.content_types.push(content_type.into()); + self + } + + /// Set content types to compress + pub fn content_types(mut self, types: Vec) -> Self { + self.content_types = types; + self + } + + /// Check if a content type should be compressed + fn should_compress_content_type(&self, content_type: &str) -> bool { + if self.content_types.is_empty() { + return true; + } + self.content_types + .iter() + .any(|ct| content_type.starts_with(ct.as_str())) + } +} + +/// Compression middleware layer +#[derive(Clone)] +pub struct CompressionLayer { + config: CompressionConfig, +} + +impl CompressionLayer { + /// Create a new compression layer with default config + pub fn new() -> Self { + Self { + config: CompressionConfig::default(), + } + } + + /// Create a compression layer with custom config + pub fn with_config(config: CompressionConfig) -> Self { + Self { config } + } + + /// Set minimum size for compression + pub fn min_size(mut self, size: usize) -> Self { + self.config.min_size = size; + self + } + + /// Set compression level + pub fn level(mut self, level: u32) -> Self { + self.config.level = level.min(9); + self + } + + /// Compress bytes using the specified algorithm + fn compress( + &self, + data: &[u8], + algorithm: CompressionAlgorithm, + ) -> Result, std::io::Error> { + let level = Compression::new(self.config.level); + + match algorithm { + CompressionAlgorithm::Gzip => { + let mut encoder = GzEncoder::new(Vec::new(), level); + encoder.write_all(data)?; + encoder.finish() + } + CompressionAlgorithm::Deflate => { + let mut encoder = DeflateEncoder::new(Vec::new(), level); + encoder.write_all(data)?; + encoder.finish() + } + #[cfg(feature = "compression-brotli")] + CompressionAlgorithm::Brotli => { + use brotli::enc::BrotliEncoderParams; + let mut output = Vec::new(); + let params = BrotliEncoderParams::default(); + brotli::BrotliCompress(&mut &data[..], &mut output, ¶ms)?; + Ok(output) + } + CompressionAlgorithm::Identity => Ok(data.to_vec()), + } + } +} + +impl Default for CompressionLayer { + fn default() -> Self { + Self::new() + } +} + +impl MiddlewareLayer for CompressionLayer { + fn call( + &self, + req: Request, + next: BoxedNext, + ) -> Pin + Send + 'static>> { + let config = self.config.clone(); + + // Get accepted encoding from request + let accept_encoding = req + .headers() + .get(header::ACCEPT_ENCODING) + .and_then(|v| v.to_str().ok()) + .map(|s| s.to_string()); + + Box::pin(async move { + // Call next handler + let response = next(req).await; + + // Determine compression algorithm + let algorithm = accept_encoding + .as_ref() + .map(|ae| CompressionAlgorithm::from_accept_encoding(ae)) + .unwrap_or(CompressionAlgorithm::Identity); + + // Check if we should compress + if algorithm == CompressionAlgorithm::Identity { + return response; + } + + // Check if response is already encoded + if response.headers().contains_key(header::CONTENT_ENCODING) { + return response; + } + + // Check content type + let content_type = response + .headers() + .get(header::CONTENT_TYPE) + .and_then(|v| v.to_str().ok()) + .unwrap_or(""); + + if !config.should_compress_content_type(content_type) { + return response; + } + + // Get body + let (parts, body) = response.into_parts(); + let body_bytes = match body.collect().await { + Ok(collected) => collected.to_bytes(), + Err(_) => return http::Response::from_parts(parts, Full::new(Bytes::new())), + }; + + // Check minimum size + if body_bytes.len() < config.min_size { + let response = http::Response::from_parts(parts, Full::new(body_bytes)); + return response; + } + + // Compress + let layer = CompressionLayer { config }; + match layer.compress(&body_bytes, algorithm) { + Ok(compressed) => { + // Only use compressed if it's smaller + if compressed.len() < body_bytes.len() { + let mut response = + http::Response::from_parts(parts, Full::new(Bytes::from(compressed))); + response.headers_mut().insert( + header::CONTENT_ENCODING, + algorithm.content_encoding().parse().unwrap(), + ); + response.headers_mut().remove(header::CONTENT_LENGTH); + response + } else { + http::Response::from_parts(parts, Full::new(body_bytes)) + } + } + Err(_) => http::Response::from_parts(parts, Full::new(body_bytes)), + } + }) + } + + fn clone_box(&self) -> Box { + Box::new(self.clone()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_accept_encoding() { + assert_eq!( + CompressionAlgorithm::from_accept_encoding("gzip"), + CompressionAlgorithm::Gzip + ); + assert_eq!( + CompressionAlgorithm::from_accept_encoding("deflate"), + CompressionAlgorithm::Deflate + ); + assert_eq!( + CompressionAlgorithm::from_accept_encoding("gzip, deflate"), + CompressionAlgorithm::Gzip + ); + assert_eq!( + CompressionAlgorithm::from_accept_encoding("deflate;q=1.0, gzip;q=0.5"), + CompressionAlgorithm::Deflate + ); + assert_eq!( + CompressionAlgorithm::from_accept_encoding("identity"), + CompressionAlgorithm::Identity + ); + } + + #[test] + fn test_compression_config() { + let config = CompressionConfig::new() + .min_size(512) + .level(9) + .gzip(true) + .deflate(false) + .add_content_type("application/custom"); + + assert_eq!(config.min_size, 512); + assert_eq!(config.level, 9); + assert!(config.gzip); + assert!(!config.deflate); + assert!(config + .content_types + .contains(&"application/custom".to_string())); + } + + #[test] + fn test_content_type_filtering() { + let config = CompressionConfig::new(); + + assert!(config.should_compress_content_type("text/html")); + assert!(config.should_compress_content_type("application/json")); + assert!(config.should_compress_content_type("text/plain")); + assert!(!config.should_compress_content_type("image/png")); + } + + #[test] + fn test_gzip_compression() { + let layer = CompressionLayer::new(); + let data = b"Hello, World! This is test data that should be compressed."; + + let compressed = layer.compress(data, CompressionAlgorithm::Gzip).unwrap(); + + // Compressed data should be valid gzip (starts with magic bytes) + assert!(compressed.len() >= 2); + assert_eq!(compressed[0], 0x1f); + assert_eq!(compressed[1], 0x8b); + } + + #[test] + fn test_deflate_compression() { + let layer = CompressionLayer::new(); + let data = b"Hello, World! This is test data that should be compressed."; + + let compressed = layer.compress(data, CompressionAlgorithm::Deflate).unwrap(); + + // Deflate produces output + assert!(!compressed.is_empty()); + } + + #[test] + fn test_identity_no_compression() { + let layer = CompressionLayer::new(); + let data = b"Hello, World!"; + + let result = layer + .compress(data, CompressionAlgorithm::Identity) + .unwrap(); + assert_eq!(result, data); + } +} diff --git a/crates/rustapi-core/src/middleware/mod.rs b/crates/rustapi-core/src/middleware/mod.rs index a3c1e95..9148c22 100644 --- a/crates/rustapi-core/src/middleware/mod.rs +++ b/crates/rustapi-core/src/middleware/mod.rs @@ -17,6 +17,8 @@ //! ``` mod body_limit; +#[cfg(feature = "compression")] +mod compression; mod layer; #[cfg(feature = "metrics")] mod metrics; @@ -24,6 +26,8 @@ mod request_id; mod tracing_layer; pub use body_limit::{BodyLimitLayer, DEFAULT_BODY_LIMIT}; +#[cfg(feature = "compression")] +pub use compression::{CompressionAlgorithm, CompressionConfig, CompressionLayer}; pub use layer::{BoxedNext, LayerStack, MiddlewareLayer}; #[cfg(feature = "metrics")] pub use metrics::{MetricsLayer, MetricsResponse}; diff --git a/crates/rustapi-core/src/multipart.rs b/crates/rustapi-core/src/multipart.rs new file mode 100644 index 0000000..812e0df --- /dev/null +++ b/crates/rustapi-core/src/multipart.rs @@ -0,0 +1,543 @@ +//! Multipart form data extractor for file uploads +//! +//! This module provides types for handling `multipart/form-data` requests, +//! commonly used for file uploads. +//! +//! # Example +//! +//! ```rust,ignore +//! use rustapi_core::multipart::{Multipart, FieldData}; +//! +//! async fn upload(mut multipart: Multipart) -> Result { +//! while let Some(field) = multipart.next_field().await? { +//! let name = field.name().unwrap_or("unknown"); +//! let filename = field.file_name().map(|s| s.to_string()); +//! let data = field.bytes().await?; +//! +//! println!("Field: {}, File: {:?}, Size: {} bytes", name, filename, data.len()); +//! } +//! Ok("Upload successful".to_string()) +//! } +//! ``` + +use crate::error::{ApiError, Result}; +use crate::extract::FromRequest; +use crate::request::Request; +use bytes::Bytes; +use std::path::Path; + +/// Maximum file size (default: 10MB) +pub const DEFAULT_MAX_FILE_SIZE: usize = 10 * 1024 * 1024; + +/// Maximum number of fields in multipart form (default: 100) +pub const DEFAULT_MAX_FIELDS: usize = 100; + +/// Multipart form data extractor +/// +/// Parses `multipart/form-data` requests, commonly used for file uploads. +/// +/// # Example +/// +/// ```rust,ignore +/// use rustapi_core::multipart::Multipart; +/// +/// async fn upload(mut multipart: Multipart) -> Result { +/// while let Some(field) = multipart.next_field().await? { +/// let name = field.name().unwrap_or("unknown").to_string(); +/// let data = field.bytes().await?; +/// println!("Received field '{}' with {} bytes", name, data.len()); +/// } +/// Ok("Upload complete".to_string()) +/// } +/// ``` +pub struct Multipart { + fields: Vec, + current_index: usize, +} + +impl Multipart { + /// Create a new Multipart from raw data + fn new(fields: Vec) -> Self { + Self { + fields, + current_index: 0, + } + } + + /// Get the next field from the multipart form + pub async fn next_field(&mut self) -> Result> { + if self.current_index >= self.fields.len() { + return Ok(None); + } + let field = self.fields.get(self.current_index).cloned(); + self.current_index += 1; + Ok(field) + } + + /// Collect all fields into a vector + pub fn into_fields(self) -> Vec { + self.fields + } + + /// Get the number of fields + pub fn field_count(&self) -> usize { + self.fields.len() + } +} + +/// A single field from a multipart form +#[derive(Clone)] +pub struct MultipartField { + name: Option, + file_name: Option, + content_type: Option, + data: Bytes, +} + +impl MultipartField { + /// Create a new multipart field + pub fn new( + name: Option, + file_name: Option, + content_type: Option, + data: Bytes, + ) -> Self { + Self { + name, + file_name, + content_type, + data, + } + } + + /// Get the field name + pub fn name(&self) -> Option<&str> { + self.name.as_deref() + } + + /// Get the original filename (if this is a file upload) + pub fn file_name(&self) -> Option<&str> { + self.file_name.as_deref() + } + + /// Get the content type of the field + pub fn content_type(&self) -> Option<&str> { + self.content_type.as_deref() + } + + /// Check if this field is a file upload + pub fn is_file(&self) -> bool { + self.file_name.is_some() + } + + /// Get the field data as bytes + pub async fn bytes(&self) -> Result { + Ok(self.data.clone()) + } + + /// Get the field data as a string (UTF-8) + pub async fn text(&self) -> Result { + String::from_utf8(self.data.to_vec()) + .map_err(|e| ApiError::bad_request(format!("Invalid UTF-8 in field: {}", e))) + } + + /// Get the size of the field data in bytes + pub fn size(&self) -> usize { + self.data.len() + } + + /// Save the file to disk + /// + /// # Arguments + /// + /// * `path` - The directory to save the file to + /// * `filename` - Optional custom filename, uses original filename if None + /// + /// # Example + /// + /// ```rust,ignore + /// field.save_to("./uploads", None).await?; + /// // or with custom filename + /// field.save_to("./uploads", Some("custom_name.txt")).await?; + /// ``` + pub async fn save_to(&self, dir: impl AsRef, filename: Option<&str>) -> Result { + let dir = dir.as_ref(); + + // Ensure directory exists + tokio::fs::create_dir_all(dir) + .await + .map_err(|e| ApiError::internal(format!("Failed to create upload directory: {}", e)))?; + + // Determine filename + let final_filename = filename + .map(|s| s.to_string()) + .or_else(|| self.file_name.clone()) + .ok_or_else(|| { + ApiError::bad_request("No filename provided and field has no filename") + })?; + + // Sanitize filename to prevent path traversal + let safe_filename = sanitize_filename(&final_filename); + let file_path = dir.join(&safe_filename); + + // Write file + tokio::fs::write(&file_path, &self.data) + .await + .map_err(|e| ApiError::internal(format!("Failed to save file: {}", e)))?; + + Ok(file_path.to_string_lossy().to_string()) + } +} + +/// Sanitize a filename to prevent path traversal attacks +fn sanitize_filename(filename: &str) -> String { + // Remove path separators and parent directory references + filename + .replace(['/', '\\'], "_") + .replace("..", "_") + .trim_start_matches('.') + .to_string() +} + +impl FromRequest for Multipart { + async fn from_request(req: &mut Request) -> Result { + // Check content type + let content_type = req + .headers() + .get(http::header::CONTENT_TYPE) + .and_then(|v| v.to_str().ok()) + .ok_or_else(|| ApiError::bad_request("Missing Content-Type header"))?; + + if !content_type.starts_with("multipart/form-data") { + return Err(ApiError::bad_request(format!( + "Expected multipart/form-data, got: {}", + content_type + ))); + } + + // Extract boundary + let boundary = extract_boundary(content_type) + .ok_or_else(|| ApiError::bad_request("Missing boundary in Content-Type"))?; + + // Get body + let body = req + .take_body() + .ok_or_else(|| ApiError::internal("Body already consumed"))?; + + // Parse multipart + let fields = parse_multipart(&body, &boundary)?; + + Ok(Multipart::new(fields)) + } +} + +/// Extract boundary from Content-Type header +fn extract_boundary(content_type: &str) -> Option { + content_type.split(';').find_map(|part| { + let part = part.trim(); + if part.starts_with("boundary=") { + let boundary = part.trim_start_matches("boundary=").trim_matches('"'); + Some(boundary.to_string()) + } else { + None + } + }) +} + +/// Parse multipart form data +fn parse_multipart(body: &Bytes, boundary: &str) -> Result> { + let mut fields = Vec::new(); + let delimiter = format!("--{}", boundary); + let end_delimiter = format!("--{}--", boundary); + + // Convert body to string for easier parsing + // Note: This is a simplified parser. For production, consider using multer crate. + let body_str = String::from_utf8_lossy(body); + + // Split by delimiter + let parts: Vec<&str> = body_str.split(&delimiter).collect(); + + for part in parts.iter().skip(1) { + // Skip empty parts and end delimiter + let part = part.trim_start_matches("\r\n").trim_start_matches('\n'); + if part.is_empty() || part.starts_with("--") { + continue; + } + + // Find header/body separator (blank line) + let header_body_split = if let Some(pos) = part.find("\r\n\r\n") { + pos + } else if let Some(pos) = part.find("\n\n") { + pos + } else { + continue; + }; + + let headers_section = &part[..header_body_split]; + let body_section = &part[header_body_split..] + .trim_start_matches("\r\n\r\n") + .trim_start_matches("\n\n"); + + // Remove trailing boundary markers from body + let body_section = body_section + .trim_end_matches(&end_delimiter) + .trim_end_matches(&delimiter) + .trim_end_matches("\r\n") + .trim_end_matches('\n'); + + // Parse headers + let mut name = None; + let mut filename = None; + let mut content_type = None; + + for header_line in headers_section.lines() { + let header_line = header_line.trim(); + if header_line.is_empty() { + continue; + } + + if let Some((key, value)) = header_line.split_once(':') { + let key = key.trim().to_lowercase(); + let value = value.trim(); + + match key.as_str() { + "content-disposition" => { + // Parse name and filename from Content-Disposition + for part in value.split(';') { + let part = part.trim(); + if part.starts_with("name=") { + name = Some( + part.trim_start_matches("name=") + .trim_matches('"') + .to_string(), + ); + } else if part.starts_with("filename=") { + filename = Some( + part.trim_start_matches("filename=") + .trim_matches('"') + .to_string(), + ); + } + } + } + "content-type" => { + content_type = Some(value.to_string()); + } + _ => {} + } + } + } + + fields.push(MultipartField::new( + name, + filename, + content_type, + Bytes::copy_from_slice(body_section.as_bytes()), + )); + } + + Ok(fields) +} + +/// Configuration for multipart form handling +#[derive(Clone)] +pub struct MultipartConfig { + /// Maximum total size of the multipart form (default: 10MB) + pub max_size: usize, + /// Maximum number of fields (default: 100) + pub max_fields: usize, + /// Maximum size per file (default: 10MB) + pub max_file_size: usize, + /// Allowed content types for files (empty = all allowed) + pub allowed_content_types: Vec, +} + +impl Default for MultipartConfig { + fn default() -> Self { + Self { + max_size: DEFAULT_MAX_FILE_SIZE, + max_fields: DEFAULT_MAX_FIELDS, + max_file_size: DEFAULT_MAX_FILE_SIZE, + allowed_content_types: Vec::new(), + } + } +} + +impl MultipartConfig { + /// Create a new multipart config with default values + pub fn new() -> Self { + Self::default() + } + + /// Set the maximum total size + pub fn max_size(mut self, size: usize) -> Self { + self.max_size = size; + self + } + + /// Set the maximum number of fields + pub fn max_fields(mut self, count: usize) -> Self { + self.max_fields = count; + self + } + + /// Set the maximum file size + pub fn max_file_size(mut self, size: usize) -> Self { + self.max_file_size = size; + self + } + + /// Set allowed content types for file uploads + pub fn allowed_content_types(mut self, types: Vec) -> Self { + self.allowed_content_types = types; + self + } + + /// Add an allowed content type + pub fn allow_content_type(mut self, content_type: impl Into) -> Self { + self.allowed_content_types.push(content_type.into()); + self + } +} + +/// File data wrapper for convenient access to uploaded files +#[derive(Clone)] +pub struct UploadedFile { + /// Original filename + pub filename: String, + /// Content type (MIME type) + pub content_type: Option, + /// File data + pub data: Bytes, +} + +impl UploadedFile { + /// Create from a multipart field + pub fn from_field(field: &MultipartField) -> Option { + field.file_name().map(|filename| Self { + filename: filename.to_string(), + content_type: field.content_type().map(|s| s.to_string()), + data: field.data.clone(), + }) + } + + /// Get file size in bytes + pub fn size(&self) -> usize { + self.data.len() + } + + /// Get file extension + pub fn extension(&self) -> Option<&str> { + self.filename.rsplit('.').next() + } + + /// Save to disk with original filename + pub async fn save_to(&self, dir: impl AsRef) -> Result { + let dir = dir.as_ref(); + + tokio::fs::create_dir_all(dir) + .await + .map_err(|e| ApiError::internal(format!("Failed to create upload directory: {}", e)))?; + + let safe_filename = sanitize_filename(&self.filename); + let file_path = dir.join(&safe_filename); + + tokio::fs::write(&file_path, &self.data) + .await + .map_err(|e| ApiError::internal(format!("Failed to save file: {}", e)))?; + + Ok(file_path.to_string_lossy().to_string()) + } + + /// Save with a custom filename + pub async fn save_as(&self, path: impl AsRef) -> Result<()> { + let path = path.as_ref(); + + if let Some(parent) = path.parent() { + tokio::fs::create_dir_all(parent) + .await + .map_err(|e| ApiError::internal(format!("Failed to create directory: {}", e)))?; + } + + tokio::fs::write(path, &self.data) + .await + .map_err(|e| ApiError::internal(format!("Failed to save file: {}", e)))?; + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_extract_boundary() { + let ct = "multipart/form-data; boundary=----WebKitFormBoundary7MA4YWxkTrZu0gW"; + assert_eq!( + extract_boundary(ct), + Some("----WebKitFormBoundary7MA4YWxkTrZu0gW".to_string()) + ); + + let ct_quoted = "multipart/form-data; boundary=\"----WebKitFormBoundary\""; + assert_eq!( + extract_boundary(ct_quoted), + Some("----WebKitFormBoundary".to_string()) + ); + } + + #[test] + fn test_sanitize_filename() { + assert_eq!(sanitize_filename("test.txt"), "test.txt"); + assert_eq!(sanitize_filename("../../../etc/passwd"), "______etc_passwd"); + // ..\..\windows\system32 -> .._.._windows_system32 -> ____windows_system32 + assert_eq!( + sanitize_filename("..\\..\\windows\\system32"), + "____windows_system32" + ); + assert_eq!(sanitize_filename(".hidden"), "hidden"); + } + + #[test] + fn test_parse_simple_multipart() { + let boundary = "----WebKitFormBoundary"; + let body = format!( + "------WebKitFormBoundary\r\n\ + Content-Disposition: form-data; name=\"field1\"\r\n\ + \r\n\ + value1\r\n\ + ------WebKitFormBoundary\r\n\ + Content-Disposition: form-data; name=\"file\"; filename=\"test.txt\"\r\n\ + Content-Type: text/plain\r\n\ + \r\n\ + file content\r\n\ + ------WebKitFormBoundary--\r\n" + ); + + let fields = parse_multipart(&Bytes::from(body), boundary).unwrap(); + assert_eq!(fields.len(), 2); + + assert_eq!(fields[0].name(), Some("field1")); + assert!(!fields[0].is_file()); + + assert_eq!(fields[1].name(), Some("file")); + assert_eq!(fields[1].file_name(), Some("test.txt")); + assert_eq!(fields[1].content_type(), Some("text/plain")); + assert!(fields[1].is_file()); + } + + #[test] + fn test_multipart_config() { + let config = MultipartConfig::new() + .max_size(20 * 1024 * 1024) + .max_fields(50) + .max_file_size(5 * 1024 * 1024) + .allow_content_type("image/png") + .allow_content_type("image/jpeg"); + + assert_eq!(config.max_size, 20 * 1024 * 1024); + assert_eq!(config.max_fields, 50); + assert_eq!(config.max_file_size, 5 * 1024 * 1024); + assert_eq!(config.allowed_content_types.len(), 2); + } +} diff --git a/crates/rustapi-core/src/sse.rs b/crates/rustapi-core/src/sse.rs index 597570d..0068298 100644 --- a/crates/rustapi-core/src/sse.rs +++ b/crates/rustapi-core/src/sse.rs @@ -1,12 +1,14 @@ //! Server-Sent Events (SSE) response types for RustAPI //! //! This module provides types for streaming Server-Sent Events to clients. +//! SSE is ideal for real-time updates like notifications, live feeds, and progress updates. //! //! # Example //! //! ```rust,ignore -//! use rustapi_core::sse::{Sse, SseEvent}; +//! use rustapi_core::sse::{Sse, SseEvent, KeepAlive}; //! use futures_util::stream; +//! use std::time::Duration; //! //! async fn events() -> Sse>> { //! let stream = stream::iter(vec![ @@ -14,6 +16,32 @@ //! Ok(SseEvent::new("World").event("greeting")), //! ]); //! Sse::new(stream) +//! .keep_alive(KeepAlive::new().interval(Duration::from_secs(15))) +//! } +//! ``` +//! +//! # Keep-Alive Support +//! +//! SSE connections can be kept alive by sending periodic comments: +//! +//! ```rust,ignore +//! use rustapi_core::sse::{Sse, SseEvent, KeepAlive}; +//! use std::time::Duration; +//! +//! async fn events() -> impl IntoResponse { +//! let stream = async_stream::stream! { +//! for i in 0..10 { +//! yield Ok::<_, std::convert::Infallible>( +//! SseEvent::new(format!("Event {}", i)) +//! ); +//! tokio::time::sleep(Duration::from_secs(1)).await; +//! } +//! }; +//! +//! Sse::new(stream) +//! .keep_alive(KeepAlive::new() +//! .interval(Duration::from_secs(30)) +//! .text("ping")) //! } //! ``` @@ -21,7 +49,11 @@ use bytes::Bytes; use futures_util::Stream; use http::{header, StatusCode}; use http_body_util::Full; +use pin_project_lite::pin_project; use std::fmt::Write; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::time::Duration; use crate::response::{IntoResponse, Response}; @@ -33,6 +65,7 @@ use crate::response::{IntoResponse, Response}; /// - `event`: The event type/name (optional) /// - `id`: The event ID for reconnection (optional) /// - `retry`: Reconnection time in milliseconds (optional) +/// - `comment`: A comment line (optional, not visible to most clients) #[derive(Debug, Clone, Default)] pub struct SseEvent { /// The event data @@ -43,6 +76,8 @@ pub struct SseEvent { pub id: Option, /// Reconnection time in milliseconds pub retry: Option, + /// Comment line + comment: Option, } impl SseEvent { @@ -53,6 +88,20 @@ impl SseEvent { event: None, id: None, retry: None, + comment: None, + } + } + + /// Create an SSE comment (keep-alive) + /// + /// Comments are lines starting with `:` and are typically used for keep-alive. + pub fn comment(text: impl Into) -> Self { + Self { + data: String::new(), + event: None, + id: None, + retry: None, + comment: Some(text.into()), } } @@ -74,6 +123,11 @@ impl SseEvent { self } + /// Set JSON data (serializes the value) + pub fn json_data(data: &T) -> Result { + Ok(Self::new(serde_json::to_string(data)?)) + } + /// Format the event as an SSE message /// /// The format follows the SSE specification: @@ -81,10 +135,18 @@ impl SseEvent { /// - Lines starting with "id:" specify the event ID /// - Lines starting with "retry:" specify the reconnection time /// - Lines starting with "data:" contain the event data + /// - Lines starting with ":" are comments /// - Events are terminated with a blank line pub fn to_sse_string(&self) -> String { let mut output = String::new(); + // Comment (for keep-alive) + if let Some(ref comment) = self.comment { + writeln!(output, ": {}", comment).unwrap(); + output.push('\n'); + return output; + } + // Event type if let Some(ref event) = self.event { writeln!(output, "event: {}", event).unwrap(); @@ -105,11 +167,81 @@ impl SseEvent { writeln!(output, "data: {}", line).unwrap(); } + // If data is empty, still send an empty data line + if self.data.is_empty() && self.comment.is_none() { + writeln!(output, "data:").unwrap(); + } + // Empty line to terminate the event output.push('\n'); output } + + /// Convert the event to bytes + pub fn to_bytes(&self) -> Bytes { + Bytes::from(self.to_sse_string()) + } +} + +/// Keep-alive configuration for SSE connections +/// +/// Keep-alive sends periodic comments to prevent connection timeouts. +/// +/// # Example +/// +/// ```rust,ignore +/// use rustapi_core::sse::KeepAlive; +/// use std::time::Duration; +/// +/// let keep_alive = KeepAlive::new() +/// .interval(Duration::from_secs(30)) +/// .text("ping"); +/// ``` +#[derive(Debug, Clone)] +pub struct KeepAlive { + /// Interval between keep-alive messages + interval: Duration, + /// Text to send as keep-alive comment + text: String, +} + +impl Default for KeepAlive { + fn default() -> Self { + Self { + interval: Duration::from_secs(15), + text: "keep-alive".to_string(), + } + } +} + +impl KeepAlive { + /// Create a new keep-alive configuration with default settings + pub fn new() -> Self { + Self::default() + } + + /// Set the keep-alive interval + pub fn interval(mut self, interval: Duration) -> Self { + self.interval = interval; + self + } + + /// Set the keep-alive text + pub fn text(mut self, text: impl Into) -> Self { + self.text = text.into(); + self + } + + /// Get the interval + pub fn get_interval(&self) -> Duration { + self.interval + } + + /// Create the keep-alive event + pub fn event(&self) -> SseEvent { + SseEvent::comment(&self.text) + } } /// Server-Sent Events response wrapper @@ -119,8 +251,9 @@ impl SseEvent { /// # Example /// /// ```rust,ignore -/// use rustapi_core::sse::{Sse, SseEvent}; +/// use rustapi_core::sse::{Sse, SseEvent, KeepAlive}; /// use futures_util::stream; +/// use std::time::Duration; /// /// async fn events() -> Sse>> { /// let stream = stream::iter(vec![ @@ -128,12 +261,12 @@ impl SseEvent { /// Ok(SseEvent::new("World").event("greeting")), /// ]); /// Sse::new(stream) +/// .keep_alive(KeepAlive::new().interval(Duration::from_secs(30))) /// } /// ``` pub struct Sse { - #[allow(dead_code)] stream: S, - keep_alive: Option, + keep_alive: Option, } impl Sse { @@ -145,14 +278,77 @@ impl Sse { } } - /// Set the keep-alive interval + /// Set the keep-alive configuration + /// + /// When set, the server will send periodic comments to keep the connection alive. /// - /// When set, the server will send a comment (`:keep-alive`) at the specified interval - /// to keep the connection alive. - pub fn keep_alive(mut self, interval: std::time::Duration) -> Self { - self.keep_alive = Some(interval); + /// # Example + /// + /// ```rust,ignore + /// use rustapi_core::sse::{Sse, KeepAlive}; + /// use std::time::Duration; + /// + /// Sse::new(stream) + /// .keep_alive(KeepAlive::new().interval(Duration::from_secs(30))) + /// ``` + pub fn keep_alive(mut self, config: KeepAlive) -> Self { + self.keep_alive = Some(config); self } + + /// Get the keep-alive configuration + pub fn get_keep_alive(&self) -> Option<&KeepAlive> { + self.keep_alive.as_ref() + } +} + +// Stream that merges SSE events with keep-alive events +pin_project! { + /// A stream that combines SSE events with keep-alive messages + pub struct SseStream { + #[pin] + inner: S, + keep_alive: Option, + #[pin] + keep_alive_timer: Option, + } +} + +impl Stream for SseStream +where + S: Stream>, +{ + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.project(); + + // First, check if there's an event ready from the inner stream + match this.inner.poll_next(cx) { + Poll::Ready(Some(Ok(event))) => { + return Poll::Ready(Some(Ok(event.to_bytes()))); + } + Poll::Ready(Some(Err(e))) => { + return Poll::Ready(Some(Err(e))); + } + Poll::Ready(None) => { + return Poll::Ready(None); + } + Poll::Pending => {} + } + + // Check keep-alive timer + if let Some(mut timer) = this.keep_alive_timer.as_pin_mut() { + if timer.poll_tick(cx).is_ready() { + if let Some(keep_alive) = this.keep_alive { + let event = keep_alive.event(); + return Poll::Ready(Some(Ok(event.to_bytes()))); + } + } + } + + Poll::Pending + } } // For now, we'll implement IntoResponse by collecting the stream into a single response @@ -164,23 +360,83 @@ where E: std::error::Error + Send + Sync + 'static, { fn into_response(self) -> Response { - // For the initial implementation, we return a response with SSE headers - // and an empty body. The actual streaming would require a different body type. - // This is a placeholder that sets up the correct headers. + // For the synchronous IntoResponse, we need to return immediately + // The actual streaming would be handled by an async body type + // For now, return headers with empty body as placeholder + // Real streaming requires server-side async body support + // + // Note: The SseStream wrapper can be used for true streaming + // when integrated with a streaming body type + + let _ = self.stream; // Consume stream (in production, would be streamed) + let _ = self.keep_alive; // Keep-alive would be used in streaming - // Note: A full implementation would use a streaming body type. - // For now, we create a response with the correct headers that can be - // used as a starting point for SSE responses. http::Response::builder() .status(StatusCode::OK) .header(header::CONTENT_TYPE, "text/event-stream") .header(header::CACHE_CONTROL, "no-cache") .header(header::CONNECTION, "keep-alive") + .header("X-Accel-Buffering", "no") // Disable nginx buffering .body(Full::new(Bytes::new())) .unwrap() } } +/// Collect all SSE events from a stream into a single response body +/// +/// This is useful for testing or when you know the stream is finite. +pub async fn collect_sse_events(stream: S) -> Result +where + S: Stream> + Send, +{ + use futures_util::StreamExt; + + let mut buffer = Vec::new(); + futures_util::pin_mut!(stream); + + while let Some(result) = stream.next().await { + let event = result?; + buffer.extend_from_slice(&event.to_bytes()); + } + + Ok(Bytes::from(buffer)) +} + +/// Create an SSE response from a synchronous iterator of events +/// +/// This is a convenience function for simple cases with pre-computed events. +/// +/// # Example +/// +/// ```rust,ignore +/// use rustapi_core::sse::{sse_response, SseEvent}; +/// +/// async fn handler() -> Response { +/// sse_response(vec![ +/// SseEvent::new("Hello"), +/// SseEvent::new("World").event("greeting"), +/// ]) +/// } +/// ``` +pub fn sse_response(events: I) -> Response +where + I: IntoIterator, +{ + let mut buffer = String::new(); + for event in events { + buffer.push_str(&event.to_sse_string()); + } + + http::Response::builder() + .status(StatusCode::OK) + .header(header::CONTENT_TYPE, "text/event-stream") + .header(header::CACHE_CONTROL, "no-cache") + .header(header::CONNECTION, "keep-alive") + .header("X-Accel-Buffering", "no") + .body(Full::new(Bytes::from(buffer))) + .unwrap() +} + /// Helper function to create an SSE response from an iterator of events /// /// This is useful for simple cases where you have a fixed set of events. diff --git a/crates/rustapi-core/src/static_files.rs b/crates/rustapi-core/src/static_files.rs new file mode 100644 index 0000000..8df502b --- /dev/null +++ b/crates/rustapi-core/src/static_files.rs @@ -0,0 +1,474 @@ +//! Static file serving for RustAPI +//! +//! This module provides types for serving static files from a directory. +//! +//! # Example +//! +//! ```rust,ignore +//! use rustapi_rs::prelude::*; +//! +//! RustApi::new() +//! .serve_static("/assets", "./static") +//! .serve_static("/uploads", "./uploads") +//! .run("127.0.0.1:8080") +//! .await +//! ``` + +use crate::error::ApiError; +use crate::response::{IntoResponse, Response}; +use bytes::Bytes; +use http::{header, StatusCode}; +use http_body_util::Full; +use std::path::{Path, PathBuf}; +use std::time::SystemTime; +use tokio::fs; + +/// MIME type detection based on file extension +fn mime_type_for_extension(extension: &str) -> &'static str { + match extension.to_lowercase().as_str() { + // Text + "html" | "htm" => "text/html; charset=utf-8", + "css" => "text/css; charset=utf-8", + "js" | "mjs" => "text/javascript; charset=utf-8", + "json" => "application/json", + "xml" => "application/xml", + "txt" => "text/plain; charset=utf-8", + "md" => "text/markdown; charset=utf-8", + "csv" => "text/csv", + + // Images + "png" => "image/png", + "jpg" | "jpeg" => "image/jpeg", + "gif" => "image/gif", + "webp" => "image/webp", + "svg" => "image/svg+xml", + "ico" => "image/x-icon", + "bmp" => "image/bmp", + "avif" => "image/avif", + + // Fonts + "woff" => "font/woff", + "woff2" => "font/woff2", + "ttf" => "font/ttf", + "otf" => "font/otf", + "eot" => "application/vnd.ms-fontobject", + + // Audio/Video + "mp3" => "audio/mpeg", + "wav" => "audio/wav", + "ogg" => "audio/ogg", + "mp4" => "video/mp4", + "webm" => "video/webm", + + // Documents + "pdf" => "application/pdf", + "zip" => "application/zip", + "tar" => "application/x-tar", + "gz" => "application/gzip", + + // WebAssembly + "wasm" => "application/wasm", + + // Default + _ => "application/octet-stream", + } +} + +/// Calculate ETag from file metadata +fn calculate_etag(modified: SystemTime, size: u64) -> String { + let timestamp = modified + .duration_since(SystemTime::UNIX_EPOCH) + .map(|d| d.as_secs()) + .unwrap_or(0); + format!("\"{:x}-{:x}\"", timestamp, size) +} + +/// Format system time as HTTP date (RFC 7231) +fn format_http_date(time: SystemTime) -> String { + use std::time::Duration; + + let duration = time + .duration_since(SystemTime::UNIX_EPOCH) + .unwrap_or(Duration::ZERO); + let secs = duration.as_secs(); + + // Simple HTTP date formatting + // In production, you'd use a proper date formatting library + let days = secs / 86400; + let remaining = secs % 86400; + let hours = remaining / 3600; + let minutes = (remaining % 3600) / 60; + let seconds = remaining % 60; + + // Calculate day of week and date (simplified) + let days_since_epoch = days; + let day_of_week = (days_since_epoch + 4) % 7; // Jan 1, 1970 was Thursday + let day_names = ["Sun", "Mon", "Tue", "Wed", "Thu", "Fri", "Sat"]; + let month_names = [ + "Jan", "Feb", "Mar", "Apr", "May", "Jun", "Jul", "Aug", "Sep", "Oct", "Nov", "Dec", + ]; + + // Calculate year, month, day (simplified leap year handling) + let mut year = 1970; + let mut remaining_days = days_since_epoch as i64; + + loop { + let days_in_year = if is_leap_year(year) { 366 } else { 365 }; + if remaining_days < days_in_year { + break; + } + remaining_days -= days_in_year; + year += 1; + } + + let mut month = 0; + let days_in_months = if is_leap_year(year) { + [31, 29, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31] + } else { + [31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31] + }; + + for (i, &days_in_month) in days_in_months.iter().enumerate() { + if remaining_days < days_in_month as i64 { + month = i; + break; + } + remaining_days -= days_in_month as i64; + } + + let day = remaining_days + 1; + + format!( + "{}, {:02} {} {} {:02}:{:02}:{:02} GMT", + day_names[day_of_week as usize], day, month_names[month], year, hours, minutes, seconds + ) +} + +fn is_leap_year(year: i64) -> bool { + (year % 4 == 0 && year % 100 != 0) || (year % 400 == 0) +} + +/// Static file serving configuration +#[derive(Clone)] +pub struct StaticFileConfig { + /// Root directory for static files + pub root: PathBuf, + /// URL path prefix + pub prefix: String, + /// Whether to serve index.html for directories + pub serve_index: bool, + /// Index file name (default: "index.html") + pub index_file: String, + /// Enable ETag headers + pub etag: bool, + /// Enable Last-Modified headers + pub last_modified: bool, + /// Cache-Control max-age in seconds (0 = no caching) + pub max_age: u64, + /// Fallback file for SPA routing (e.g., "index.html") + pub fallback: Option, +} + +impl Default for StaticFileConfig { + fn default() -> Self { + Self { + root: PathBuf::from("./static"), + prefix: "/".to_string(), + serve_index: true, + index_file: "index.html".to_string(), + etag: true, + last_modified: true, + max_age: 3600, // 1 hour + fallback: None, + } + } +} + +impl StaticFileConfig { + /// Create a new static file configuration + pub fn new(root: impl Into, prefix: impl Into) -> Self { + Self { + root: root.into(), + prefix: prefix.into(), + ..Default::default() + } + } + + /// Set whether to serve index.html for directories + pub fn serve_index(mut self, enabled: bool) -> Self { + self.serve_index = enabled; + self + } + + /// Set the index file name + pub fn index_file(mut self, name: impl Into) -> Self { + self.index_file = name.into(); + self + } + + /// Enable or disable ETag headers + pub fn etag(mut self, enabled: bool) -> Self { + self.etag = enabled; + self + } + + /// Enable or disable Last-Modified headers + pub fn last_modified(mut self, enabled: bool) -> Self { + self.last_modified = enabled; + self + } + + /// Set Cache-Control max-age in seconds + pub fn max_age(mut self, seconds: u64) -> Self { + self.max_age = seconds; + self + } + + /// Set a fallback file for SPA routing + pub fn fallback(mut self, file: impl Into) -> Self { + self.fallback = Some(file.into()); + self + } +} + +/// Static file response +pub struct StaticFile { + #[allow(dead_code)] + path: PathBuf, + #[allow(dead_code)] + config: StaticFileConfig, +} + +impl StaticFile { + /// Create a new static file response + pub fn new(path: impl Into, config: StaticFileConfig) -> Self { + Self { + path: path.into(), + config, + } + } + + /// Serve a file from a path relative to the root + pub async fn serve( + relative_path: &str, + config: &StaticFileConfig, + ) -> Result { + // Sanitize path to prevent directory traversal + let clean_path = sanitize_path(relative_path); + let file_path = config.root.join(&clean_path); + + // Check if it's a directory + if file_path.is_dir() { + if config.serve_index { + let index_path = file_path.join(&config.index_file); + if index_path.exists() { + return Self::serve_file(&index_path, config).await; + } + } + return Err(ApiError::not_found("Directory listing not allowed")); + } + + // Try to serve the file + match Self::serve_file(&file_path, config).await { + Ok(response) => Ok(response), + Err(_) if config.fallback.is_some() => { + // Try fallback + let fallback_path = config.root.join(config.fallback.as_ref().unwrap()); + Self::serve_file(&fallback_path, config).await + } + Err(e) => Err(e), + } + } + + /// Serve a specific file + async fn serve_file(path: &Path, config: &StaticFileConfig) -> Result { + // Check if file exists + let metadata = fs::metadata(path) + .await + .map_err(|_| ApiError::not_found(format!("File not found: {}", path.display())))?; + + if !metadata.is_file() { + return Err(ApiError::not_found("Not a file")); + } + + // Read file + let content = fs::read(path) + .await + .map_err(|e| ApiError::internal(format!("Failed to read file: {}", e)))?; + + // Determine content type + let extension = path.extension().and_then(|e| e.to_str()).unwrap_or(""); + let content_type = mime_type_for_extension(extension); + + // Build response + let mut builder = http::Response::builder() + .status(StatusCode::OK) + .header(header::CONTENT_TYPE, content_type) + .header(header::CONTENT_LENGTH, content.len()); + + // Add ETag + if config.etag { + if let Ok(modified) = metadata.modified() { + let etag = calculate_etag(modified, metadata.len()); + builder = builder.header(header::ETAG, etag); + } + } + + // Add Last-Modified + if config.last_modified { + if let Ok(modified) = metadata.modified() { + let http_date = format_http_date(modified); + builder = builder.header(header::LAST_MODIFIED, http_date); + } + } + + // Add Cache-Control + if config.max_age > 0 { + builder = builder.header( + header::CACHE_CONTROL, + format!("public, max-age={}", config.max_age), + ); + } + + builder + .body(Full::new(Bytes::from(content))) + .map_err(|e| ApiError::internal(format!("Failed to build response: {}", e))) + } +} + +/// Sanitize a file path to prevent directory traversal +fn sanitize_path(path: &str) -> String { + // Remove leading slashes + let path = path.trim_start_matches('/'); + + // Split and filter out dangerous components + let parts: Vec<&str> = path + .split('/') + .filter(|part| !part.is_empty() && *part != "." && *part != ".." && !part.contains('\\')) + .collect(); + + parts.join("/") +} + +/// Create a handler for serving static files +/// +/// # Example +/// +/// ```rust,ignore +/// use rustapi_core::static_files::{static_handler, StaticFileConfig}; +/// +/// let config = StaticFileConfig::new("./public", "/assets"); +/// let handler = static_handler(config); +/// ``` +pub fn static_handler( + config: StaticFileConfig, +) -> impl Fn(crate::Request) -> std::pin::Pin + Send>> + + Clone + + Send + + Sync + + 'static { + move |req: crate::Request| { + let config = config.clone(); + let path = req.uri().path().to_string(); + + Box::pin(async move { + // Strip prefix from path + let relative_path = path.strip_prefix(&config.prefix).unwrap_or(&path); + + match StaticFile::serve(relative_path, &config).await { + Ok(response) => response, + Err(err) => err.into_response(), + } + }) + } +} + +/// Create a static file serving route +/// +/// This is the main function for adding static file serving to RustAPI. +/// +/// # Arguments +/// +/// * `prefix` - URL path prefix (e.g., "/static") +/// * `root` - File system root directory +/// +/// # Example +/// +/// ```rust,ignore +/// use rustapi_core::static_files::serve_dir; +/// +/// // The handler can be used with a catch-all route +/// let config = serve_dir("/static", "./public"); +/// ``` +pub fn serve_dir(prefix: impl Into, root: impl Into) -> StaticFileConfig { + StaticFileConfig::new(root.into(), prefix.into()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_mime_type_detection() { + assert_eq!(mime_type_for_extension("html"), "text/html; charset=utf-8"); + assert_eq!(mime_type_for_extension("css"), "text/css; charset=utf-8"); + assert_eq!( + mime_type_for_extension("js"), + "text/javascript; charset=utf-8" + ); + assert_eq!(mime_type_for_extension("png"), "image/png"); + assert_eq!(mime_type_for_extension("jpg"), "image/jpeg"); + assert_eq!(mime_type_for_extension("json"), "application/json"); + assert_eq!( + mime_type_for_extension("unknown"), + "application/octet-stream" + ); + } + + #[test] + fn test_sanitize_path() { + assert_eq!(sanitize_path("file.txt"), "file.txt"); + assert_eq!(sanitize_path("/file.txt"), "file.txt"); + assert_eq!(sanitize_path("../../../etc/passwd"), "etc/passwd"); + assert_eq!(sanitize_path("foo/../bar"), "foo/bar"); + assert_eq!(sanitize_path("./file.txt"), "file.txt"); + assert_eq!(sanitize_path("foo/./bar"), "foo/bar"); + } + + #[test] + fn test_etag_calculation() { + let time = SystemTime::UNIX_EPOCH + std::time::Duration::from_secs(1000000); + let etag = calculate_etag(time, 12345); + assert!(etag.starts_with('"')); + assert!(etag.ends_with('"')); + assert!(etag.contains('-')); + } + + #[test] + fn test_static_file_config() { + let config = StaticFileConfig::new("./public", "/assets") + .serve_index(true) + .index_file("index.html") + .etag(true) + .last_modified(true) + .max_age(7200) + .fallback("index.html"); + + assert_eq!(config.root, PathBuf::from("./public")); + assert_eq!(config.prefix, "/assets"); + assert!(config.serve_index); + assert_eq!(config.index_file, "index.html"); + assert!(config.etag); + assert!(config.last_modified); + assert_eq!(config.max_age, 7200); + assert_eq!(config.fallback, Some("index.html".to_string())); + } + + #[test] + fn test_is_leap_year() { + assert!(is_leap_year(2000)); // Divisible by 400 + assert!(!is_leap_year(1900)); // Divisible by 100 but not 400 + assert!(is_leap_year(2024)); // Divisible by 4 but not 100 + assert!(!is_leap_year(2023)); // Not divisible by 4 + } +} diff --git a/crates/rustapi-extras/Cargo.toml b/crates/rustapi-extras/Cargo.toml index bdda1d9..5b539f4 100644 --- a/crates/rustapi-extras/Cargo.toml +++ b/crates/rustapi-extras/Cargo.toml @@ -47,10 +47,14 @@ envy = { version = "0.4", optional = true } # Cookies (feature-gated) cookie = { version = "0.18", optional = true } +# Insight (feature-gated) - reuses dashmap from rate-limit +urlencoding = { version = "2.1", optional = true } + [dev-dependencies] tokio = { workspace = true, features = ["macros", "rt-multi-thread"] } proptest = "1.4" rustapi-core = { workspace = true, features = ["test-utils"] } +tempfile = "3.10" [features] default = [] @@ -62,9 +66,10 @@ rate-limit = ["dep:dashmap"] config = ["dep:dotenvy", "dep:envy"] cookies = ["dep:cookie"] sqlx = ["dep:sqlx"] +insight = ["dep:dashmap", "dep:urlencoding"] # Meta feature that enables all security features extras = ["jwt", "cors", "rate-limit"] # Full feature set -full = ["extras", "config", "cookies", "sqlx"] +full = ["extras", "config", "cookies", "sqlx", "insight"] diff --git a/crates/rustapi-extras/src/insight/config.rs b/crates/rustapi-extras/src/insight/config.rs new file mode 100644 index 0000000..569a1bd --- /dev/null +++ b/crates/rustapi-extras/src/insight/config.rs @@ -0,0 +1,477 @@ +//! Configuration for the InsightLayer middleware. +//! +//! This module provides the `InsightConfig` builder for customizing +//! traffic insight collection behavior. + +use super::data::InsightData; +use std::collections::HashSet; +use std::sync::Arc; + +/// Callback function type for processing insights. +pub type InsightCallback = Arc; + +/// Configuration for the InsightLayer middleware. +/// +/// Use the builder pattern to customize behavior: +/// +/// ```ignore +/// use rustapi_extras::insight::InsightConfig; +/// +/// let config = InsightConfig::new() +/// .sample_rate(0.5) // Sample 50% of requests +/// .max_body_size(4096) // Capture up to 4KB of body +/// .skip_path("/health") // Exclude health checks +/// .capture_request_body(true) // Enable request body capture +/// .header_whitelist(vec!["content-type", "user-agent"]); +/// ``` +#[derive(Clone)] +pub struct InsightConfig { + /// Sampling rate (0.0-1.0). 1.0 = all requests, 0.5 = 50% of requests. + pub(crate) sample_rate: f64, + + /// Maximum body size to capture (in bytes). Default: 4096 (4KB). + pub(crate) max_body_size: usize, + + /// Paths to skip from insight collection. + pub(crate) skip_paths: HashSet, + + /// Path prefixes to skip from insight collection. + pub(crate) skip_path_prefixes: HashSet, + + /// Request headers to capture (empty = none, use `*` for all). + pub(crate) header_whitelist: HashSet, + + /// Response headers to capture (empty = none, use `*` for all). + pub(crate) response_header_whitelist: HashSet, + + /// Whether to capture request bodies. Default: false. + pub(crate) capture_request_body: bool, + + /// Whether to capture response bodies. Default: false. + pub(crate) capture_response_body: bool, + + /// Callback to invoke for each insight (optional). + pub(crate) on_insight: Option, + + /// Dashboard endpoint path. Set to None to disable. Default: "/insights". + pub(crate) dashboard_path: Option, + + /// Stats endpoint path. Set to None to disable. Default: "/insights/stats". + pub(crate) stats_path: Option, + + /// Storage capacity for in-memory store. Default: 1000. + pub(crate) store_capacity: usize, + + /// Sensitive headers to redact (values replaced with "[REDACTED]"). + pub(crate) sensitive_headers: HashSet, + + /// Content types to capture body for. Default: application/json, text/*. + pub(crate) capturable_content_types: HashSet, +} + +impl Default for InsightConfig { + fn default() -> Self { + Self::new() + } +} + +impl InsightConfig { + /// Create a new configuration with default values. + /// + /// Defaults: + /// - Sample rate: 1.0 (all requests) + /// - Max body size: 4096 bytes (4KB) + /// - No paths skipped + /// - No headers captured + /// - Body capture disabled + /// - Dashboard at "/insights" + /// - Stats at "/insights/stats" + /// - Store capacity: 1000 entries + pub fn new() -> Self { + let mut sensitive = HashSet::new(); + sensitive.insert("authorization".to_string()); + sensitive.insert("cookie".to_string()); + sensitive.insert("x-api-key".to_string()); + sensitive.insert("x-auth-token".to_string()); + + let mut capturable = HashSet::new(); + capturable.insert("application/json".to_string()); + capturable.insert("text/plain".to_string()); + capturable.insert("text/html".to_string()); + capturable.insert("application/xml".to_string()); + capturable.insert("text/xml".to_string()); + + Self { + sample_rate: 1.0, + max_body_size: 4096, + skip_paths: HashSet::new(), + skip_path_prefixes: HashSet::new(), + header_whitelist: HashSet::new(), + response_header_whitelist: HashSet::new(), + capture_request_body: false, + capture_response_body: false, + on_insight: None, + dashboard_path: Some("/insights".to_string()), + stats_path: Some("/insights/stats".to_string()), + store_capacity: 1000, + sensitive_headers: sensitive, + capturable_content_types: capturable, + } + } + + /// Set the sampling rate (0.0 to 1.0). + /// + /// # Arguments + /// + /// * `rate` - Fraction of requests to sample. 1.0 = all, 0.1 = 10%. + /// + /// # Example + /// + /// ```ignore + /// let config = InsightConfig::new().sample_rate(0.5); // 50% sampling + /// ``` + pub fn sample_rate(mut self, rate: f64) -> Self { + self.sample_rate = rate.clamp(0.0, 1.0); + self + } + + /// Set the maximum body size to capture. + /// + /// Bodies larger than this will be truncated. + pub fn max_body_size(mut self, size: usize) -> Self { + self.max_body_size = size; + self + } + + /// Add a path to skip from insight collection. + /// + /// Exact match against request path. + pub fn skip_path(mut self, path: impl Into) -> Self { + self.skip_paths.insert(path.into()); + self + } + + /// Add multiple paths to skip. + pub fn skip_paths(mut self, paths: impl IntoIterator>) -> Self { + for path in paths { + self.skip_paths.insert(path.into()); + } + self + } + + /// Add a path prefix to skip. + /// + /// Any request path starting with this prefix will be skipped. + pub fn skip_path_prefix(mut self, prefix: impl Into) -> Self { + self.skip_path_prefixes.insert(prefix.into()); + self + } + + /// Set the request header whitelist. + /// + /// Only headers in this list will be captured. Use "*" to capture all. + /// Header names are case-insensitive. + pub fn header_whitelist( + mut self, + headers: impl IntoIterator>, + ) -> Self { + self.header_whitelist = headers + .into_iter() + .map(|h| h.into().to_lowercase()) + .collect(); + self + } + + /// Set the response header whitelist. + pub fn response_header_whitelist( + mut self, + headers: impl IntoIterator>, + ) -> Self { + self.response_header_whitelist = headers + .into_iter() + .map(|h| h.into().to_lowercase()) + .collect(); + self + } + + /// Enable or disable request body capture. + /// + /// When enabled, request bodies (up to max_body_size) will be stored. + pub fn capture_request_body(mut self, capture: bool) -> Self { + self.capture_request_body = capture; + self + } + + /// Enable or disable response body capture. + pub fn capture_response_body(mut self, capture: bool) -> Self { + self.capture_response_body = capture; + self + } + + /// Set a callback to invoke for each collected insight. + /// + /// Useful for custom processing, external logging, or real-time alerts. + /// + /// # Example + /// + /// ```ignore + /// let config = InsightConfig::new() + /// .on_insight(|insight| { + /// if insight.duration_ms > 1000 { + /// tracing::warn!("Slow request: {} {}ms", insight.path, insight.duration_ms); + /// } + /// }); + /// ``` + pub fn on_insight(mut self, callback: F) -> Self + where + F: Fn(&InsightData) + Send + Sync + 'static, + { + self.on_insight = Some(Arc::new(callback)); + self + } + + /// Set the dashboard endpoint path. + /// + /// Set to None to disable the dashboard endpoint. + pub fn dashboard_path(mut self, path: Option>) -> Self { + self.dashboard_path = path.map(|p| p.into()); + self + } + + /// Set the stats endpoint path. + /// + /// Set to None to disable the stats endpoint. + pub fn stats_path(mut self, path: Option>) -> Self { + self.stats_path = path.map(|p| p.into()); + self + } + + /// Set the in-memory store capacity. + /// + /// Older entries are evicted when capacity is reached. + pub fn store_capacity(mut self, capacity: usize) -> Self { + self.store_capacity = capacity; + self + } + + /// Add a sensitive header name. + /// + /// Values for these headers will be replaced with `"[REDACTED]"`. + pub fn sensitive_header(mut self, header: impl Into) -> Self { + self.sensitive_headers.insert(header.into().to_lowercase()); + self + } + + /// Set capturable content types. + /// + /// Bodies are only captured for requests/responses with these content types. + pub fn capturable_content_types( + mut self, + types: impl IntoIterator>, + ) -> Self { + self.capturable_content_types = + types.into_iter().map(|t| t.into().to_lowercase()).collect(); + self + } + + /// Check if a path should be skipped. + pub(crate) fn should_skip_path(&self, path: &str) -> bool { + // Check exact matches + if self.skip_paths.contains(path) { + return true; + } + + // Check prefixes + for prefix in &self.skip_path_prefixes { + if path.starts_with(prefix) { + return true; + } + } + + // Check if this is a dashboard/stats path + if let Some(ref dashboard) = self.dashboard_path { + if path == dashboard { + return true; + } + } + if let Some(ref stats) = self.stats_path { + if path == stats { + return true; + } + } + + false + } + + /// Check if the request should be sampled. + pub(crate) fn should_sample(&self) -> bool { + if self.sample_rate >= 1.0 { + return true; + } + if self.sample_rate <= 0.0 { + return false; + } + rand_sample(self.sample_rate) + } + + /// Check if a header should be captured. + pub(crate) fn should_capture_header(&self, name: &str) -> bool { + if self.header_whitelist.is_empty() { + return false; + } + if self.header_whitelist.contains("*") { + return true; + } + self.header_whitelist.contains(&name.to_lowercase()) + } + + /// Check if a response header should be captured. + pub(crate) fn should_capture_response_header(&self, name: &str) -> bool { + if self.response_header_whitelist.is_empty() { + return false; + } + if self.response_header_whitelist.contains("*") { + return true; + } + self.response_header_whitelist + .contains(&name.to_lowercase()) + } + + /// Check if a header is sensitive. + pub(crate) fn is_sensitive_header(&self, name: &str) -> bool { + self.sensitive_headers.contains(&name.to_lowercase()) + } + + /// Check if content type is capturable. + pub(crate) fn is_capturable_content_type(&self, content_type: &str) -> bool { + let ct_lower = content_type.to_lowercase(); + for allowed in &self.capturable_content_types { + if ct_lower.starts_with(allowed) + || (allowed.ends_with("/*") && ct_lower.starts_with(&allowed[..allowed.len() - 1])) + { + return true; + } + } + // Also allow text/* generically + ct_lower.starts_with("text/") || ct_lower.starts_with("application/json") + } +} + +/// Simple random sampling based on rate. +fn rand_sample(rate: f64) -> bool { + use std::time::{SystemTime, UNIX_EPOCH}; + + // Use system time nanoseconds as a simple random source + let nanos = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .subsec_nanos(); + + let threshold = (rate * u32::MAX as f64) as u32; + nanos < threshold +} + +impl std::fmt::Debug for InsightConfig { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("InsightConfig") + .field("sample_rate", &self.sample_rate) + .field("max_body_size", &self.max_body_size) + .field("skip_paths", &self.skip_paths) + .field("skip_path_prefixes", &self.skip_path_prefixes) + .field("header_whitelist", &self.header_whitelist) + .field("capture_request_body", &self.capture_request_body) + .field("capture_response_body", &self.capture_response_body) + .field("dashboard_path", &self.dashboard_path) + .field("stats_path", &self.stats_path) + .field("store_capacity", &self.store_capacity) + .field("on_insight", &self.on_insight.is_some()) + .finish() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_default_config() { + let config = InsightConfig::new(); + assert_eq!(config.sample_rate, 1.0); + assert_eq!(config.max_body_size, 4096); + assert!(!config.capture_request_body); + assert!(!config.capture_response_body); + assert_eq!(config.dashboard_path, Some("/insights".to_string())); + assert_eq!(config.stats_path, Some("/insights/stats".to_string())); + } + + #[test] + fn test_sample_rate_clamping() { + let config = InsightConfig::new().sample_rate(1.5); + assert_eq!(config.sample_rate, 1.0); + + let config = InsightConfig::new().sample_rate(-0.5); + assert_eq!(config.sample_rate, 0.0); + } + + #[test] + fn test_skip_paths() { + let config = InsightConfig::new() + .skip_path("/health") + .skip_path("/metrics") + .skip_path_prefix("/internal/"); + + assert!(config.should_skip_path("/health")); + assert!(config.should_skip_path("/metrics")); + assert!(config.should_skip_path("/internal/debug")); + assert!(!config.should_skip_path("/users")); + } + + #[test] + fn test_header_whitelist() { + let config = InsightConfig::new().header_whitelist(vec!["Content-Type", "User-Agent"]); + + assert!(config.should_capture_header("content-type")); + assert!(config.should_capture_header("Content-Type")); + assert!(config.should_capture_header("user-agent")); + assert!(!config.should_capture_header("authorization")); + } + + #[test] + fn test_header_wildcard() { + let config = InsightConfig::new().header_whitelist(vec!["*"]); + + assert!(config.should_capture_header("any-header")); + assert!(config.should_capture_header("another-one")); + } + + #[test] + fn test_sensitive_headers() { + let config = InsightConfig::new(); + + assert!(config.is_sensitive_header("authorization")); + assert!(config.is_sensitive_header("Authorization")); + assert!(config.is_sensitive_header("cookie")); + assert!(!config.is_sensitive_header("content-type")); + } + + #[test] + fn test_capturable_content_types() { + let config = InsightConfig::new(); + + assert!(config.is_capturable_content_type("application/json")); + assert!(config.is_capturable_content_type("application/json; charset=utf-8")); + assert!(config.is_capturable_content_type("text/plain")); + assert!(config.is_capturable_content_type("text/html")); + } + + #[test] + fn test_dashboard_path_exclusion() { + let config = InsightConfig::new() + .dashboard_path(Some("/insights")) + .stats_path(Some("/insights/stats")); + + assert!(config.should_skip_path("/insights")); + assert!(config.should_skip_path("/insights/stats")); + assert!(!config.should_skip_path("/users")); + } +} diff --git a/crates/rustapi-extras/src/insight/data.rs b/crates/rustapi-extras/src/insight/data.rs new file mode 100644 index 0000000..f0ddc8a --- /dev/null +++ b/crates/rustapi-extras/src/insight/data.rs @@ -0,0 +1,396 @@ +//! Data structures for traffic insight collection. +//! +//! This module defines the core data types used to capture and store +//! request/response information. + +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::time::{Duration, SystemTime, UNIX_EPOCH}; + +/// A single insight entry capturing request/response information. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct InsightData { + /// Unique request identifier + pub request_id: String, + + /// HTTP method (GET, POST, etc.) + pub method: String, + + /// Request path (without query string) + pub path: String, + + /// Query parameters as key-value pairs + pub query_params: HashMap, + + /// HTTP status code of the response + pub status: u16, + + /// Request processing duration in milliseconds + pub duration_ms: u64, + + /// Request body size in bytes + pub request_size: usize, + + /// Response body size in bytes + pub response_size: usize, + + /// Unix timestamp (seconds since epoch) + pub timestamp: u64, + + /// Client IP address + pub client_ip: String, + + /// Captured request headers (based on whitelist) + #[serde(skip_serializing_if = "HashMap::is_empty")] + pub request_headers: HashMap, + + /// Captured response headers (based on whitelist) + #[serde(skip_serializing_if = "HashMap::is_empty")] + pub response_headers: HashMap, + + /// Request body (if capture enabled and within size limit) + #[serde(skip_serializing_if = "Option::is_none")] + pub request_body: Option, + + /// Response body (if capture enabled and within size limit) + #[serde(skip_serializing_if = "Option::is_none")] + pub response_body: Option, + + /// Route pattern that matched (e.g., "/users/{id}") + #[serde(skip_serializing_if = "Option::is_none")] + pub route_pattern: Option, + + /// Custom tags/labels for categorization + #[serde(skip_serializing_if = "HashMap::is_empty")] + pub tags: HashMap, +} + +impl InsightData { + /// Create a new insight entry with required fields. + pub fn new( + request_id: impl Into, + method: impl Into, + path: impl Into, + ) -> Self { + Self { + request_id: request_id.into(), + method: method.into(), + path: path.into(), + query_params: HashMap::new(), + status: 0, + duration_ms: 0, + request_size: 0, + response_size: 0, + timestamp: SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_secs(), + client_ip: String::new(), + request_headers: HashMap::new(), + response_headers: HashMap::new(), + request_body: None, + response_body: None, + route_pattern: None, + tags: HashMap::new(), + } + } + + /// Set the response status code. + pub fn with_status(mut self, status: u16) -> Self { + self.status = status; + self + } + + /// Set the request duration. + pub fn with_duration(mut self, duration: Duration) -> Self { + self.duration_ms = duration.as_millis() as u64; + self + } + + /// Set the client IP address. + pub fn with_client_ip(mut self, ip: impl Into) -> Self { + self.client_ip = ip.into(); + self + } + + /// Set request body size. + pub fn with_request_size(mut self, size: usize) -> Self { + self.request_size = size; + self + } + + /// Set response body size. + pub fn with_response_size(mut self, size: usize) -> Self { + self.response_size = size; + self + } + + /// Set route pattern. + pub fn with_route_pattern(mut self, pattern: impl Into) -> Self { + self.route_pattern = Some(pattern.into()); + self + } + + /// Add a query parameter. + pub fn add_query_param(&mut self, key: impl Into, value: impl Into) { + self.query_params.insert(key.into(), value.into()); + } + + /// Add a request header. + pub fn add_request_header(&mut self, key: impl Into, value: impl Into) { + self.request_headers.insert(key.into(), value.into()); + } + + /// Add a response header. + pub fn add_response_header(&mut self, key: impl Into, value: impl Into) { + self.response_headers.insert(key.into(), value.into()); + } + + /// Set captured request body. + pub fn set_request_body(&mut self, body: String) { + self.request_body = Some(body); + } + + /// Set captured response body. + pub fn set_response_body(&mut self, body: String) { + self.response_body = Some(body); + } + + /// Add a custom tag. + pub fn add_tag(&mut self, key: impl Into, value: impl Into) { + self.tags.insert(key.into(), value.into()); + } + + /// Check if this is a successful request (2xx status). + pub fn is_success(&self) -> bool { + self.status >= 200 && self.status < 300 + } + + /// Check if this is a client error (4xx status). + pub fn is_client_error(&self) -> bool { + self.status >= 400 && self.status < 500 + } + + /// Check if this is a server error (5xx status). + pub fn is_server_error(&self) -> bool { + self.status >= 500 + } +} + +/// Aggregated statistics from collected insights. +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct InsightStats { + /// Total number of requests + pub total_requests: u64, + + /// Total number of successful requests (2xx) + pub successful_requests: u64, + + /// Total number of client errors (4xx) + pub client_errors: u64, + + /// Total number of server errors (5xx) + pub server_errors: u64, + + /// Average response time in milliseconds + pub avg_duration_ms: f64, + + /// Minimum response time in milliseconds + pub min_duration_ms: u64, + + /// Maximum response time in milliseconds + pub max_duration_ms: u64, + + /// 95th percentile response time in milliseconds + pub p95_duration_ms: u64, + + /// 99th percentile response time in milliseconds + pub p99_duration_ms: u64, + + /// Total bytes received (request bodies) + pub total_request_bytes: u64, + + /// Total bytes sent (response bodies) + pub total_response_bytes: u64, + + /// Requests per route pattern + pub requests_by_route: HashMap, + + /// Requests per HTTP method + pub requests_by_method: HashMap, + + /// Requests per status code + pub requests_by_status: HashMap, + + /// Average duration per route + pub avg_duration_by_route: HashMap, + + /// Request rate (requests per second) over the measurement period + pub requests_per_second: f64, + + /// Time period covered by these stats (in seconds) + pub time_period_secs: u64, +} + +impl InsightStats { + /// Create new empty statistics. + pub fn new() -> Self { + Self::default() + } + + /// Calculate statistics from a collection of insights. + pub fn from_insights(insights: &[InsightData]) -> Self { + if insights.is_empty() { + return Self::default(); + } + + let mut stats = Self::new(); + stats.total_requests = insights.len() as u64; + + let mut durations: Vec = Vec::with_capacity(insights.len()); + let mut route_durations: HashMap> = HashMap::new(); + + // Find time range + let min_timestamp = insights.iter().map(|i| i.timestamp).min().unwrap_or(0); + let max_timestamp = insights.iter().map(|i| i.timestamp).max().unwrap_or(0); + stats.time_period_secs = max_timestamp.saturating_sub(min_timestamp).max(1); + + for insight in insights { + // Count by status + if insight.is_success() { + stats.successful_requests += 1; + } else if insight.is_client_error() { + stats.client_errors += 1; + } else if insight.is_server_error() { + stats.server_errors += 1; + } + + // Duration tracking + durations.push(insight.duration_ms); + + // Bytes tracking + stats.total_request_bytes += insight.request_size as u64; + stats.total_response_bytes += insight.response_size as u64; + + // Route tracking + let route = insight + .route_pattern + .clone() + .unwrap_or_else(|| insight.path.clone()); + *stats.requests_by_route.entry(route.clone()).or_insert(0) += 1; + route_durations + .entry(route) + .or_default() + .push(insight.duration_ms); + + // Method tracking + *stats + .requests_by_method + .entry(insight.method.clone()) + .or_insert(0) += 1; + + // Status tracking + *stats.requests_by_status.entry(insight.status).or_insert(0) += 1; + } + + // Calculate duration statistics + if !durations.is_empty() { + durations.sort_unstable(); + + let sum: u64 = durations.iter().sum(); + stats.avg_duration_ms = sum as f64 / durations.len() as f64; + stats.min_duration_ms = durations[0]; + stats.max_duration_ms = durations[durations.len() - 1]; + stats.p95_duration_ms = percentile(&durations, 95); + stats.p99_duration_ms = percentile(&durations, 99); + } + + // Calculate average duration per route + for (route, route_durs) in route_durations { + let sum: u64 = route_durs.iter().sum(); + let avg = sum as f64 / route_durs.len() as f64; + stats.avg_duration_by_route.insert(route, avg); + } + + // Calculate requests per second + stats.requests_per_second = stats.total_requests as f64 / stats.time_period_secs as f64; + + stats + } +} + +/// Calculate the nth percentile of a sorted slice. +fn percentile(sorted: &[u64], n: u8) -> u64 { + if sorted.is_empty() { + return 0; + } + let idx = (sorted.len() as f64 * (n as f64 / 100.0)).ceil() as usize; + sorted[idx.saturating_sub(1).min(sorted.len() - 1)] +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_insight_data_creation() { + let insight = InsightData::new("req-123", "GET", "/users") + .with_status(200) + .with_duration(Duration::from_millis(42)) + .with_client_ip("192.168.1.1"); + + assert_eq!(insight.request_id, "req-123"); + assert_eq!(insight.method, "GET"); + assert_eq!(insight.path, "/users"); + assert_eq!(insight.status, 200); + assert_eq!(insight.duration_ms, 42); + assert_eq!(insight.client_ip, "192.168.1.1"); + } + + #[test] + fn test_status_categorization() { + assert!(InsightData::new("", "", "").with_status(200).is_success()); + assert!(InsightData::new("", "", "").with_status(201).is_success()); + assert!(InsightData::new("", "", "") + .with_status(404) + .is_client_error()); + assert!(InsightData::new("", "", "") + .with_status(500) + .is_server_error()); + } + + #[test] + fn test_stats_calculation() { + let insights = vec![ + InsightData::new("1", "GET", "/users") + .with_status(200) + .with_duration(Duration::from_millis(10)), + InsightData::new("2", "POST", "/users") + .with_status(201) + .with_duration(Duration::from_millis(20)), + InsightData::new("3", "GET", "/users") + .with_status(404) + .with_duration(Duration::from_millis(5)), + InsightData::new("4", "GET", "/items") + .with_status(500) + .with_duration(Duration::from_millis(100)), + ]; + + let stats = InsightStats::from_insights(&insights); + + assert_eq!(stats.total_requests, 4); + assert_eq!(stats.successful_requests, 2); + assert_eq!(stats.client_errors, 1); + assert_eq!(stats.server_errors, 1); + assert_eq!(stats.requests_by_method.get("GET"), Some(&3)); + assert_eq!(stats.requests_by_method.get("POST"), Some(&1)); + } + + #[test] + fn test_percentile_calculation() { + let sorted = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]; + assert_eq!(percentile(&sorted, 50), 5); + assert_eq!(percentile(&sorted, 95), 10); + assert_eq!(percentile(&sorted, 99), 10); + } +} diff --git a/crates/rustapi-extras/src/insight/export.rs b/crates/rustapi-extras/src/insight/export.rs new file mode 100644 index 0000000..281d2c5 --- /dev/null +++ b/crates/rustapi-extras/src/insight/export.rs @@ -0,0 +1,541 @@ +//! Export functionality for insight data. +//! +//! This module provides traits and implementations for exporting +//! insight data to various destinations. + +use super::data::InsightData; +use std::fs::{File, OpenOptions}; +use std::io::{BufWriter, Write}; +use std::path::PathBuf; +use std::sync::{Arc, Mutex}; + +/// Error type for export operations. +#[derive(Debug, thiserror::Error)] +pub enum ExportError { + /// IO error during export. + #[error("IO error: {0}")] + Io(#[from] std::io::Error), + + /// Serialization error. + #[error("Serialization error: {0}")] + Serialization(#[from] serde_json::Error), + + /// HTTP error during webhook export. + #[error("HTTP error: {0}")] + Http(String), + + /// Export sink is closed or unavailable. + #[error("Export sink unavailable: {0}")] + Unavailable(String), +} + +/// Result type for export operations. +pub type ExportResult = Result; + +/// Trait for exporting insight data to external destinations. +/// +/// Implement this trait to create custom export sinks. +pub trait InsightExporter: Send + Sync + 'static { + /// Export a single insight entry. + fn export(&self, insight: &InsightData) -> ExportResult<()>; + + /// Export multiple insights in batch. + fn export_batch(&self, insights: &[InsightData]) -> ExportResult<()> { + for insight in insights { + self.export(insight)?; + } + Ok(()) + } + + /// Flush any buffered data. + fn flush(&self) -> ExportResult<()> { + Ok(()) + } + + /// Close the exporter and release resources. + fn close(&self) -> ExportResult<()> { + self.flush() + } + + /// Clone this exporter into a boxed trait object. + fn clone_exporter(&self) -> Box; +} + +/// File exporter that writes insights as JSON lines. +/// +/// Each insight is written as a single JSON object on its own line, +/// compatible with common log aggregation tools. +/// +/// # Example +/// +/// ```ignore +/// use rustapi_extras::insight::export::FileExporter; +/// +/// let exporter = FileExporter::new("./insights.jsonl")?; +/// ``` +pub struct FileExporter { + path: PathBuf, + writer: Arc>>, +} + +impl FileExporter { + /// Create a new file exporter. + /// + /// Creates or appends to the specified file. + pub fn new(path: impl Into) -> ExportResult { + let path = path.into(); + let file = OpenOptions::new().create(true).append(true).open(&path)?; + let writer = BufWriter::new(file); + + Ok(Self { + path, + writer: Arc::new(Mutex::new(writer)), + }) + } + + /// Get the file path. + pub fn path(&self) -> &PathBuf { + &self.path + } +} + +impl Clone for FileExporter { + fn clone(&self) -> Self { + Self { + path: self.path.clone(), + writer: self.writer.clone(), + } + } +} + +impl InsightExporter for FileExporter { + fn export(&self, insight: &InsightData) -> ExportResult<()> { + let mut writer = self + .writer + .lock() + .map_err(|e| ExportError::Unavailable(e.to_string()))?; + + let json = serde_json::to_string(insight)?; + writeln!(writer, "{}", json)?; + + Ok(()) + } + + fn export_batch(&self, insights: &[InsightData]) -> ExportResult<()> { + let mut writer = self + .writer + .lock() + .map_err(|e| ExportError::Unavailable(e.to_string()))?; + + for insight in insights { + let json = serde_json::to_string(insight)?; + writeln!(writer, "{}", json)?; + } + + Ok(()) + } + + fn flush(&self) -> ExportResult<()> { + let mut writer = self + .writer + .lock() + .map_err(|e| ExportError::Unavailable(e.to_string()))?; + + writer.flush()?; + Ok(()) + } + + fn clone_exporter(&self) -> Box { + Box::new(self.clone()) + } +} + +/// Webhook exporter configuration. +#[derive(Clone, Debug)] +pub struct WebhookConfig { + /// URL to POST insights to. + pub url: String, + /// Optional authorization header value. + pub auth_header: Option, + /// Custom headers to include. + pub headers: Vec<(String, String)>, + /// Batch size for batched exports. + pub batch_size: usize, + /// Request timeout in seconds. + pub timeout_secs: u64, +} + +impl WebhookConfig { + /// Create a new webhook configuration. + pub fn new(url: impl Into) -> Self { + Self { + url: url.into(), + auth_header: None, + headers: Vec::new(), + batch_size: 100, + timeout_secs: 30, + } + } + + /// Set the authorization header. + pub fn auth(mut self, value: impl Into) -> Self { + self.auth_header = Some(value.into()); + self + } + + /// Add a custom header. + pub fn header(mut self, name: impl Into, value: impl Into) -> Self { + self.headers.push((name.into(), value.into())); + self + } + + /// Set the batch size for batched exports. + pub fn batch_size(mut self, size: usize) -> Self { + self.batch_size = size; + self + } + + /// Set the request timeout. + pub fn timeout(mut self, secs: u64) -> Self { + self.timeout_secs = secs; + self + } +} + +/// Webhook exporter that POSTs insights to a URL. +/// +/// Insights are sent as JSON in POST requests. +/// +/// # Example +/// +/// ```ignore +/// use rustapi_extras::insight::export::{WebhookExporter, WebhookConfig}; +/// +/// let config = WebhookConfig::new("https://example.com/insights") +/// .auth("Bearer my-token") +/// .batch_size(50); +/// +/// let exporter = WebhookExporter::new(config); +/// ``` +#[derive(Clone)] +pub struct WebhookExporter { + config: WebhookConfig, + buffer: Arc>>, +} + +impl WebhookExporter { + /// Create a new webhook exporter. + pub fn new(config: WebhookConfig) -> Self { + Self { + config, + buffer: Arc::new(Mutex::new(Vec::new())), + } + } + + /// Send insights to the webhook. + fn send_insights(&self, insights: &[InsightData]) -> ExportResult<()> { + // Note: This is a simplified implementation. + // In production, you'd use an async HTTP client like reqwest. + // For now, we'll just log and return success since this crate + // doesn't want to add heavy HTTP client dependencies. + + let json = serde_json::to_string(insights)?; + tracing::debug!( + url = %self.config.url, + count = insights.len(), + size = json.len(), + "Would send insights to webhook" + ); + + // TODO: Implement actual HTTP POST when reqwest is available + // For now, this is a placeholder that logs the intent + + Ok(()) + } +} + +impl InsightExporter for WebhookExporter { + fn export(&self, insight: &InsightData) -> ExportResult<()> { + let mut buffer = self + .buffer + .lock() + .map_err(|e| ExportError::Unavailable(e.to_string()))?; + + buffer.push(insight.clone()); + + // Flush if batch size reached + if buffer.len() >= self.config.batch_size { + let to_send: Vec<_> = buffer.drain(..).collect(); + drop(buffer); // Release lock before sending + self.send_insights(&to_send)?; + } + + Ok(()) + } + + fn export_batch(&self, insights: &[InsightData]) -> ExportResult<()> { + // Send in batches + for chunk in insights.chunks(self.config.batch_size) { + self.send_insights(chunk)?; + } + Ok(()) + } + + fn flush(&self) -> ExportResult<()> { + let mut buffer = self + .buffer + .lock() + .map_err(|e| ExportError::Unavailable(e.to_string()))?; + + if !buffer.is_empty() { + let to_send: Vec<_> = buffer.drain(..).collect(); + drop(buffer); + self.send_insights(&to_send)?; + } + + Ok(()) + } + + fn clone_exporter(&self) -> Box { + Box::new(self.clone()) + } +} + +/// A composite exporter that sends to multiple destinations. +/// +/// # Example +/// +/// ```ignore +/// use rustapi_extras::insight::export::{CompositeExporter, FileExporter, WebhookExporter, WebhookConfig}; +/// +/// let composite = CompositeExporter::new() +/// .add(FileExporter::new("./insights.jsonl")?) +/// .add(WebhookExporter::new(WebhookConfig::new("https://example.com/insights"))); +/// ``` +#[derive(Default)] +pub struct CompositeExporter { + exporters: Vec>, +} + +impl Clone for CompositeExporter { + fn clone(&self) -> Self { + let exporters = self.exporters.iter().map(|e| e.clone_exporter()).collect(); + Self { exporters } + } +} + +impl CompositeExporter { + /// Create a new composite exporter. + pub fn new() -> Self { + Self { + exporters: Vec::new(), + } + } + + /// Add an exporter to the composite. + pub fn with_exporter(mut self, exporter: E) -> Self { + self.exporters.push(Box::new(exporter)); + self + } + + /// Add a boxed exporter to the composite. + pub fn with_boxed_exporter(mut self, exporter: Box) -> Self { + self.exporters.push(exporter); + self + } +} + +impl InsightExporter for CompositeExporter { + fn export(&self, insight: &InsightData) -> ExportResult<()> { + for exporter in &self.exporters { + if let Err(e) = exporter.export(insight) { + tracing::warn!(error = %e, "Export failed for one sink"); + } + } + Ok(()) + } + + fn export_batch(&self, insights: &[InsightData]) -> ExportResult<()> { + for exporter in &self.exporters { + if let Err(e) = exporter.export_batch(insights) { + tracing::warn!(error = %e, "Batch export failed for one sink"); + } + } + Ok(()) + } + + fn flush(&self) -> ExportResult<()> { + for exporter in &self.exporters { + if let Err(e) = exporter.flush() { + tracing::warn!(error = %e, "Flush failed for one sink"); + } + } + Ok(()) + } + + fn close(&self) -> ExportResult<()> { + for exporter in &self.exporters { + if let Err(e) = exporter.close() { + tracing::warn!(error = %e, "Close failed for one sink"); + } + } + Ok(()) + } + + fn clone_exporter(&self) -> Box { + let exporters: Vec<_> = self.exporters.iter().map(|e| e.clone_exporter()).collect(); + Box::new(CompositeExporter { exporters }) + } +} + +/// A callback-based exporter that invokes a function for each insight. +/// +/// # Example +/// +/// ```ignore +/// use rustapi_extras::insight::export::CallbackExporter; +/// +/// let exporter = CallbackExporter::new(|insight| { +/// println!("Received: {} {}", insight.method, insight.path); +/// }); +/// ``` +pub struct CallbackExporter +where + F: Fn(&InsightData) + Send + Sync + 'static, +{ + callback: Arc, +} + +impl CallbackExporter +where + F: Fn(&InsightData) + Send + Sync + 'static, +{ + /// Create a new callback exporter. + pub fn new(callback: F) -> Self { + Self { + callback: Arc::new(callback), + } + } +} + +impl Clone for CallbackExporter +where + F: Fn(&InsightData) + Send + Sync + 'static, +{ + fn clone(&self) -> Self { + Self { + callback: self.callback.clone(), + } + } +} + +impl InsightExporter for CallbackExporter +where + F: Fn(&InsightData) + Send + Sync + 'static, +{ + fn export(&self, insight: &InsightData) -> ExportResult<()> { + (self.callback)(insight); + Ok(()) + } + + fn clone_exporter(&self) -> Box { + Box::new(self.clone()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::atomic::{AtomicUsize, Ordering}; + use std::time::Duration; + use tempfile::tempdir; + + fn create_test_insight() -> InsightData { + InsightData::new("test-123", "GET", "/users") + .with_status(200) + .with_duration(Duration::from_millis(42)) + } + + #[test] + fn test_file_exporter() { + let dir = tempdir().unwrap(); + let path = dir.path().join("test.jsonl"); + + let exporter = FileExporter::new(&path).unwrap(); + exporter.export(&create_test_insight()).unwrap(); + exporter.flush().unwrap(); + + let content = std::fs::read_to_string(&path).unwrap(); + assert!(content.contains("test-123")); + assert!(content.contains("GET")); + assert!(content.contains("/users")); + } + + #[test] + fn test_file_exporter_batch() { + let dir = tempdir().unwrap(); + let path = dir.path().join("batch.jsonl"); + + let exporter = FileExporter::new(&path).unwrap(); + let insights: Vec<_> = (0..5) + .map(|i| InsightData::new(format!("req-{}", i), "GET", "/test")) + .collect(); + + exporter.export_batch(&insights).unwrap(); + exporter.flush().unwrap(); + + let content = std::fs::read_to_string(&path).unwrap(); + let lines: Vec<_> = content.lines().collect(); + assert_eq!(lines.len(), 5); + } + + #[test] + fn test_callback_exporter() { + let count = Arc::new(AtomicUsize::new(0)); + let count_clone = count.clone(); + + let exporter = CallbackExporter::new(move |_insight| { + count_clone.fetch_add(1, Ordering::SeqCst); + }); + + exporter.export(&create_test_insight()).unwrap(); + exporter.export(&create_test_insight()).unwrap(); + + assert_eq!(count.load(Ordering::SeqCst), 2); + } + + #[test] + fn test_composite_exporter() { + let dir = tempdir().unwrap(); + let path = dir.path().join("composite.jsonl"); + + let count = Arc::new(AtomicUsize::new(0)); + let count_clone = count.clone(); + + let composite = CompositeExporter::new() + .with_exporter(FileExporter::new(&path).unwrap()) + .with_exporter(CallbackExporter::new(move |_| { + count_clone.fetch_add(1, Ordering::SeqCst); + })); + + composite.export(&create_test_insight()).unwrap(); + composite.flush().unwrap(); + + assert_eq!(count.load(Ordering::SeqCst), 1); + assert!(std::fs::read_to_string(&path).unwrap().contains("test-123")); + } + + #[test] + fn test_webhook_config() { + let config = WebhookConfig::new("https://example.com/insights") + .auth("Bearer token") + .header("X-Custom", "value") + .batch_size(50) + .timeout(60); + + assert_eq!(config.url, "https://example.com/insights"); + assert_eq!(config.auth_header, Some("Bearer token".to_string())); + assert_eq!(config.batch_size, 50); + assert_eq!(config.timeout_secs, 60); + } +} diff --git a/crates/rustapi-extras/src/insight/layer.rs b/crates/rustapi-extras/src/insight/layer.rs new file mode 100644 index 0000000..ceffbdb --- /dev/null +++ b/crates/rustapi-extras/src/insight/layer.rs @@ -0,0 +1,434 @@ +//! InsightLayer middleware for traffic data collection. +//! +//! This module provides the main middleware layer that captures +//! request and response information. + +use super::config::InsightConfig; +use super::data::InsightData; +use super::store::{InMemoryInsightStore, InsightStore}; +use bytes::Bytes; +use http::StatusCode; +use http_body_util::{BodyExt, Full}; +use rustapi_core::middleware::{BoxedNext, MiddlewareLayer}; +use rustapi_core::{Request, Response}; +use serde_json::json; +use std::future::Future; +use std::net::IpAddr; +use std::pin::Pin; +use std::sync::Arc; +use std::time::Instant; + +/// Traffic insight middleware layer. +/// +/// Collects request/response data for analytics, debugging, and monitoring. +/// +/// # Example +/// +/// ```ignore +/// use rustapi_extras::insight::{InsightLayer, InsightConfig}; +/// +/// let insight = InsightLayer::new() +/// .with_config(InsightConfig::new() +/// .sample_rate(0.5) +/// .skip_path("/health")); +/// +/// let app = RustApi::new() +/// .layer(insight) +/// .route("/api", get(handler)); +/// ``` +#[derive(Clone)] +pub struct InsightLayer { + config: Arc, + store: Arc, +} + +impl InsightLayer { + /// Create a new InsightLayer with default configuration. + pub fn new() -> Self { + let config = InsightConfig::new(); + let store = InMemoryInsightStore::new(config.store_capacity); + Self { + config: Arc::new(config), + store: Arc::new(store), + } + } + + /// Create an InsightLayer with custom configuration. + pub fn with_config(config: InsightConfig) -> Self { + let store = InMemoryInsightStore::new(config.store_capacity); + Self { + config: Arc::new(config), + store: Arc::new(store), + } + } + + /// Use a custom store implementation. + pub fn with_store(mut self, store: S) -> Self { + self.store = Arc::new(store); + self + } + + /// Get a reference to the insight store. + pub fn store(&self) -> &Arc { + &self.store + } + + /// Get a reference to the configuration. + pub fn config(&self) -> &InsightConfig { + &self.config + } + + /// Extract client IP from request headers. + fn extract_client_ip(req: &Request) -> String { + // Try X-Forwarded-For header first + if let Some(forwarded) = req.headers().get("x-forwarded-for") { + if let Ok(forwarded_str) = forwarded.to_str() { + if let Some(first_ip) = forwarded_str.split(',').next() { + let ip_str = first_ip.trim(); + if ip_str.parse::().is_ok() { + return ip_str.to_string(); + } + } + } + } + + // Try X-Real-IP header + if let Some(real_ip) = req.headers().get("x-real-ip") { + if let Ok(ip_str) = real_ip.to_str() { + let ip_str = ip_str.trim(); + if ip_str.parse::().is_ok() { + return ip_str.to_string(); + } + } + } + + // Default to localhost + "127.0.0.1".to_string() + } + + /// Extract request ID from headers or generate one. + fn extract_request_id(req: &Request) -> String { + // Try common request ID headers + for header_name in &["x-request-id", "x-correlation-id", "x-trace-id"] { + if let Some(value) = req.headers().get(*header_name) { + if let Ok(id) = value.to_str() { + return id.to_string(); + } + } + } + + // Generate a simple unique ID + use std::time::{SystemTime, UNIX_EPOCH}; + let timestamp = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_nanos(); + format!("insight_{:x}", timestamp) + } + + /// Extract query parameters from URI. + fn extract_query_params(uri: &http::Uri) -> std::collections::HashMap { + let mut params = std::collections::HashMap::new(); + if let Some(query) = uri.query() { + for pair in query.split('&') { + let mut parts = pair.splitn(2, '='); + if let (Some(key), Some(value)) = (parts.next(), parts.next()) { + params.insert( + urlencoding::decode(key).unwrap_or_default().into_owned(), + urlencoding::decode(value).unwrap_or_default().into_owned(), + ); + } + } + } + params + } + + /// Capture headers based on whitelist. + fn capture_headers( + headers: &http::HeaderMap, + config: &InsightConfig, + is_response: bool, + ) -> std::collections::HashMap { + let mut captured = std::collections::HashMap::new(); + + for (name, value) in headers.iter() { + let name_str = name.as_str(); + let should_capture = if is_response { + config.should_capture_response_header(name_str) + } else { + config.should_capture_header(name_str) + }; + + if should_capture { + if let Ok(value_str) = value.to_str() { + let final_value = if config.is_sensitive_header(name_str) { + "[REDACTED]".to_string() + } else { + value_str.to_string() + }; + captured.insert(name_str.to_string(), final_value); + } + } + } + + captured + } + + /// Check if body should be captured based on content type. + fn should_capture_body(headers: &http::HeaderMap, config: &InsightConfig) -> bool { + if let Some(content_type) = headers.get(http::header::CONTENT_TYPE) { + if let Ok(ct) = content_type.to_str() { + return config.is_capturable_content_type(ct); + } + } + false + } + + /// Create dashboard response with recent insights. + fn create_dashboard_response(store: &dyn InsightStore, limit: usize) -> Response { + let insights = store.get_recent(limit); + let body = json!({ + "insights": insights, + "count": insights.len(), + "total": store.count() + }); + + let body_bytes = serde_json::to_vec(&body).unwrap_or_default(); + http::Response::builder() + .status(StatusCode::OK) + .header(http::header::CONTENT_TYPE, "application/json") + .body(Full::new(Bytes::from(body_bytes))) + .unwrap() + } + + /// Create stats response. + fn create_stats_response(store: &dyn InsightStore) -> Response { + let stats = store.get_stats(); + let body_bytes = serde_json::to_vec(&stats).unwrap_or_default(); + http::Response::builder() + .status(StatusCode::OK) + .header(http::header::CONTENT_TYPE, "application/json") + .body(Full::new(Bytes::from(body_bytes))) + .unwrap() + } +} + +impl Default for InsightLayer { + fn default() -> Self { + Self::new() + } +} + +impl MiddlewareLayer for InsightLayer { + fn call( + &self, + mut req: Request, + next: BoxedNext, + ) -> Pin + Send + 'static>> { + let config = self.config.clone(); + let store = self.store.clone(); + + Box::pin(async move { + let path = req.uri().path().to_string(); + let method = req.method().to_string(); + + // Handle dashboard endpoints + if let Some(ref dashboard_path) = config.dashboard_path { + if path == *dashboard_path && method == "GET" { + // Parse limit from query string + let limit = InsightLayer::extract_query_params(req.uri()) + .get("limit") + .and_then(|v| v.parse().ok()) + .unwrap_or(100); + return InsightLayer::create_dashboard_response(store.as_ref(), limit); + } + } + + if let Some(ref stats_path) = config.stats_path { + if path == *stats_path && method == "GET" { + return InsightLayer::create_stats_response(store.as_ref()); + } + } + + // Check if this path should be skipped + if config.should_skip_path(&path) { + return next(req).await; + } + + // Check sampling + if !config.should_sample() { + return next(req).await; + } + + // Start timing + let start = Instant::now(); + + // Extract request info before calling next + let request_id = InsightLayer::extract_request_id(&req); + let client_ip = InsightLayer::extract_client_ip(&req); + let query_params = InsightLayer::extract_query_params(req.uri()); + let request_headers = InsightLayer::capture_headers(req.headers(), &config, false); + let capture_request_body = config.capture_request_body + && InsightLayer::should_capture_body(req.headers(), &config); + + // Get request body info if body capture is enabled + // Note: take_body() consumes the body, so we can only capture OR process, not both + // For insight purposes, we estimate size from content-length header when not capturing + let (request_size, request_body_capture) = if capture_request_body { + if let Some(body_bytes) = req.take_body() { + let size = body_bytes.len(); + let body_str = if size <= config.max_body_size { + String::from_utf8(body_bytes.to_vec()).ok() + } else { + None + }; + (size, body_str) + } else { + (0, None) + } + } else { + // Estimate size from Content-Length header + let size = req + .headers() + .get(http::header::CONTENT_LENGTH) + .and_then(|v| v.to_str().ok()) + .and_then(|s| s.parse::().ok()) + .unwrap_or(0); + (size, None) + }; + + // Call the next handler + let response = next(req).await; + + // Calculate duration + let duration = start.elapsed(); + let status = response.status().as_u16(); + + // Capture response info + let response_headers = InsightLayer::capture_headers(response.headers(), &config, true); + let capture_response_body = config.capture_response_body + && InsightLayer::should_capture_body(response.headers(), &config); + + // Buffer response body if needed + let (resp_parts, resp_body) = response.into_parts(); + let resp_body_bytes = match resp_body.collect().await { + Ok(collected) => collected.to_bytes(), + Err(_) => Bytes::new(), + }; + + let response_size = resp_body_bytes.len(); + let response_body_capture = + if capture_response_body && response_size <= config.max_body_size { + String::from_utf8(resp_body_bytes.to_vec()).ok() + } else { + None + }; + + // Create insight + let mut insight = InsightData::new(&request_id, &method, &path) + .with_status(status) + .with_duration(duration) + .with_client_ip(&client_ip) + .with_request_size(request_size) + .with_response_size(response_size); + + // Add query params + for (key, value) in query_params { + insight.add_query_param(key, value); + } + + // Add headers + for (key, value) in request_headers { + insight.add_request_header(key, value); + } + for (key, value) in response_headers { + insight.add_response_header(key, value); + } + + // Add body captures + if let Some(body) = request_body_capture { + insight.set_request_body(body); + } + if let Some(body) = response_body_capture { + insight.set_response_body(body); + } + + // Invoke callback if configured + if let Some(ref callback) = config.on_insight { + callback(&insight); + } + + // Store the insight + store.store(insight); + + // Reconstruct response + http::Response::from_parts(resp_parts, Full::new(resp_body_bytes)) + }) + } + + fn clone_box(&self) -> Box { + Box::new(self.clone()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_extract_query_params() { + let uri: http::Uri = "/users?page=1&limit=10".parse().unwrap(); + let params = InsightLayer::extract_query_params(&uri); + + assert_eq!(params.get("page"), Some(&"1".to_string())); + assert_eq!(params.get("limit"), Some(&"10".to_string())); + } + + #[test] + fn test_capture_headers_with_whitelist() { + let mut headers = http::HeaderMap::new(); + headers.insert( + http::header::CONTENT_TYPE, + "application/json".parse().unwrap(), + ); + headers.insert(http::header::USER_AGENT, "test-agent".parse().unwrap()); + headers.insert( + http::header::AUTHORIZATION, + "Bearer secret".parse().unwrap(), + ); + + let config = InsightConfig::new().header_whitelist(vec!["content-type", "authorization"]); + + let captured = InsightLayer::capture_headers(&headers, &config, false); + + assert_eq!( + captured.get("content-type"), + Some(&"application/json".to_string()) + ); + assert_eq!( + captured.get("authorization"), + Some(&"[REDACTED]".to_string()) + ); + assert!(!captured.contains_key("user-agent")); + } + + #[test] + fn test_default_layer() { + let layer = InsightLayer::new(); + assert_eq!(layer.config().sample_rate, 1.0); + assert_eq!(layer.config().store_capacity, 1000); + } + + #[test] + fn test_custom_config() { + let config = InsightConfig::new() + .sample_rate(0.5) + .max_body_size(8192) + .skip_path("/health"); + + let layer = InsightLayer::with_config(config); + + assert_eq!(layer.config().sample_rate, 0.5); + assert_eq!(layer.config().max_body_size, 8192); + } +} diff --git a/crates/rustapi-extras/src/insight/mod.rs b/crates/rustapi-extras/src/insight/mod.rs new file mode 100644 index 0000000..6151b4a --- /dev/null +++ b/crates/rustapi-extras/src/insight/mod.rs @@ -0,0 +1,153 @@ +//! Traffic Insight - Opt-in traffic data collection middleware. +//! +//! This module provides comprehensive request/response monitoring for +//! analytics, debugging, and observability. +//! +//! # Features +//! +//! - **Request/Response Capture**: Collect method, path, status, duration, body sizes +//! - **Header Collection**: Configurable whitelist with sensitive data redaction +//! - **Body Capture**: Opt-in request/response body logging +//! - **Sampling**: Configurable sampling rate to reduce overhead +//! - **In-Memory Storage**: Ring buffer with configurable capacity +//! - **Dashboard Endpoints**: Built-in `/insights` and `/insights/stats` endpoints +//! - **Export**: File (JSON lines), webhook, and custom export sinks +//! +//! # Quick Start +//! +//! ```ignore +//! use rustapi_rs::prelude::*; +//! use rustapi_extras::insight::{InsightLayer, InsightConfig}; +//! +//! #[rustapi::main] +//! async fn main() { +//! let insight = InsightLayer::with_config( +//! InsightConfig::new() +//! .sample_rate(1.0) // Capture all requests +//! .skip_path("/health") // Skip health checks +//! .header_whitelist(vec!["content-type", "user-agent"]) +//! ); +//! +//! RustApi::new() +//! .layer(insight) +//! .mount(hello) +//! .run("127.0.0.1:3000") +//! .await +//! .unwrap(); +//! } +//! +//! #[rustapi::get("/hello")] +//! async fn hello() -> &'static str { +//! "Hello, World!" +//! } +//! ``` +//! +//! # Configuration +//! +//! Use [`InsightConfig`] to customize behavior: +//! +//! ```ignore +//! use rustapi_extras::insight::InsightConfig; +//! +//! let config = InsightConfig::new() +//! // Sampling +//! .sample_rate(0.1) // 10% of requests +//! +//! // Paths to exclude +//! .skip_path("/health") +//! .skip_path("/metrics") +//! .skip_path_prefix("/internal/") +//! +//! // Header capture +//! .header_whitelist(vec!["content-type", "user-agent", "accept"]) +//! .response_header_whitelist(vec!["content-type", "x-request-id"]) +//! +//! // Body capture (opt-in) +//! .capture_request_body(true) +//! .capture_response_body(true) +//! .max_body_size(8192) // 8KB max +//! +//! // Storage +//! .store_capacity(5000) // Keep 5000 entries +//! +//! // Endpoints +//! .dashboard_path(Some("/admin/insights")) +//! .stats_path(Some("/admin/insights/stats")) +//! +//! // Callback for custom processing +//! .on_insight(|insight| { +//! if insight.duration_ms > 1000 { +//! tracing::warn!("Slow request: {} {}ms", insight.path, insight.duration_ms); +//! } +//! }); +//! ``` +//! +//! # Dashboard Endpoints +//! +//! The middleware automatically exposes two endpoints: +//! +//! - `GET /insights` - Returns recent insights as JSON +//! - Query param: `?limit=100` to control number of results +//! - `GET /insights/stats` - Returns aggregated statistics +//! +//! These paths are configurable via [`InsightConfig`]. +//! +//! # Export +//! +//! Export insights to external systems: +//! +//! ```ignore +//! use rustapi_extras::insight::export::{FileExporter, WebhookConfig, WebhookExporter, CompositeExporter}; +//! +//! // File export (JSON lines format) +//! let file_exporter = FileExporter::new("./insights.jsonl")?; +//! +//! // Webhook export +//! let webhook = WebhookExporter::new( +//! WebhookConfig::new("https://logs.example.com/ingest") +//! .auth("Bearer my-token") +//! .batch_size(100) +//! ); +//! +//! // Multiple destinations +//! let composite = CompositeExporter::new() +//! .add(file_exporter) +//! .add(webhook); +//! ``` +//! +//! # Data Structure +//! +//! Each [`InsightData`] entry contains: +//! +//! - `request_id` - Unique request identifier +//! - `method` - HTTP method +//! - `path` - Request path +//! - `query_params` - Query string parameters +//! - `status` - Response status code +//! - `duration_ms` - Processing time in milliseconds +//! - `request_size` / `response_size` - Body sizes in bytes +//! - `timestamp` - Unix timestamp +//! - `client_ip` - Client IP address +//! - `request_headers` / `response_headers` - Captured headers +//! - `request_body` / `response_body` - Captured bodies (if enabled) +//! +//! # Statistics +//! +//! [`InsightStats`] provides aggregated metrics: +//! +//! - Request counts (total, successful, client errors, server errors) +//! - Duration statistics (avg, min, max, p95, p99) +//! - Bytes transferred (request/response) +//! - Breakdowns by route, method, and status code +//! - Requests per second + +mod config; +mod data; +pub mod export; +mod layer; +mod store; + +pub use config::InsightConfig; +pub use data::{InsightData, InsightStats}; +pub use layer::InsightLayer; +pub use store::{InMemoryInsightStore, InsightStore, NullInsightStore}; diff --git a/crates/rustapi-extras/src/insight/store.rs b/crates/rustapi-extras/src/insight/store.rs new file mode 100644 index 0000000..47c3758 --- /dev/null +++ b/crates/rustapi-extras/src/insight/store.rs @@ -0,0 +1,348 @@ +//! Storage backends for traffic insight data. +//! +//! This module provides the `InsightStore` trait and default implementations +//! for storing and retrieving insight data. + +use super::data::{InsightData, InsightStats}; +use dashmap::DashMap; +use std::collections::VecDeque; +use std::sync::{Arc, RwLock}; + +/// Trait for storing and retrieving insight data. +/// +/// Implement this trait to create custom storage backends (e.g., database, Redis). +pub trait InsightStore: Send + Sync + 'static { + /// Store a new insight entry. + fn store(&self, insight: InsightData); + + /// Get recent insights (up to `limit` entries). + fn get_recent(&self, limit: usize) -> Vec; + + /// Get all stored insights. + fn get_all(&self) -> Vec; + + /// Get insights filtered by path pattern. + fn get_by_path(&self, path_pattern: &str) -> Vec; + + /// Get insights filtered by status code range. + fn get_by_status(&self, min_status: u16, max_status: u16) -> Vec; + + /// Get aggregated statistics. + fn get_stats(&self) -> InsightStats; + + /// Clear all stored insights. + fn clear(&self); + + /// Get the current count of stored insights. + fn count(&self) -> usize; + + /// Clone this store into a boxed trait object. + fn clone_store(&self) -> Box; +} + +/// In-memory insight store using a ring buffer. +/// +/// This store keeps the most recent N insights in memory with thread-safe access. +/// +/// # Example +/// +/// ```ignore +/// use rustapi_extras::insight::InMemoryInsightStore; +/// +/// // Store up to 1000 insights +/// let store = InMemoryInsightStore::new(1000); +/// ``` +#[derive(Clone)] +pub struct InMemoryInsightStore { + /// Ring buffer holding insights + buffer: Arc>>, + /// Maximum capacity of the buffer + capacity: usize, + /// Index for quick lookup by request_id + index: Arc>, +} + +impl InMemoryInsightStore { + /// Create a new in-memory store with the specified capacity. + /// + /// # Arguments + /// + /// * `capacity` - Maximum number of insights to store (default: 1000) + pub fn new(capacity: usize) -> Self { + Self { + buffer: Arc::new(RwLock::new(VecDeque::with_capacity(capacity))), + capacity, + index: Arc::new(DashMap::new()), + } + } + + /// Create a new in-memory store with default capacity (1000 entries). + pub fn default_capacity() -> Self { + Self::new(1000) + } + + /// Get the maximum capacity of this store. + pub fn capacity(&self) -> usize { + self.capacity + } + + /// Get an insight by request ID. + pub fn get_by_request_id(&self, request_id: &str) -> Option { + let idx = self.index.get(request_id)?; + let buffer = self.buffer.read().ok()?; + buffer.get(*idx).cloned() + } +} + +impl Default for InMemoryInsightStore { + fn default() -> Self { + Self::default_capacity() + } +} + +impl InsightStore for InMemoryInsightStore { + fn store(&self, insight: InsightData) { + let mut buffer = match self.buffer.write() { + Ok(b) => b, + Err(_) => return, // Poisoned lock, skip storage + }; + + // If at capacity, remove oldest entry + if buffer.len() >= self.capacity { + if let Some(old) = buffer.pop_front() { + self.index.remove(&old.request_id); + } + // Rebuild indices after removal (indices shift) + self.index.clear(); + for (i, item) in buffer.iter().enumerate() { + self.index.insert(item.request_id.clone(), i); + } + } + + // Add new insight + let idx = buffer.len(); + self.index.insert(insight.request_id.clone(), idx); + buffer.push_back(insight); + } + + fn get_recent(&self, limit: usize) -> Vec { + let buffer = match self.buffer.read() { + Ok(b) => b, + Err(_) => return Vec::new(), + }; + + buffer.iter().rev().take(limit).cloned().collect() + } + + fn get_all(&self) -> Vec { + let buffer = match self.buffer.read() { + Ok(b) => b, + Err(_) => return Vec::new(), + }; + + buffer.iter().cloned().collect() + } + + fn get_by_path(&self, path_pattern: &str) -> Vec { + let buffer = match self.buffer.read() { + Ok(b) => b, + Err(_) => return Vec::new(), + }; + + buffer + .iter() + .filter(|i| i.path.contains(path_pattern)) + .cloned() + .collect() + } + + fn get_by_status(&self, min_status: u16, max_status: u16) -> Vec { + let buffer = match self.buffer.read() { + Ok(b) => b, + Err(_) => return Vec::new(), + }; + + buffer + .iter() + .filter(|i| i.status >= min_status && i.status <= max_status) + .cloned() + .collect() + } + + fn get_stats(&self) -> InsightStats { + let all = self.get_all(); + InsightStats::from_insights(&all) + } + + fn clear(&self) { + if let Ok(mut buffer) = self.buffer.write() { + buffer.clear(); + } + self.index.clear(); + } + + fn count(&self) -> usize { + self.buffer.read().map(|b| b.len()).unwrap_or(0) + } + + fn clone_store(&self) -> Box { + Box::new(self.clone()) + } +} + +/// A no-op store that discards all insights. +/// +/// Useful for testing or when you only want callback-based processing. +#[derive(Clone, Copy, Default)] +pub struct NullInsightStore; + +impl InsightStore for NullInsightStore { + fn store(&self, _insight: InsightData) { + // Discard + } + + fn get_recent(&self, _limit: usize) -> Vec { + Vec::new() + } + + fn get_all(&self) -> Vec { + Vec::new() + } + + fn get_by_path(&self, _path_pattern: &str) -> Vec { + Vec::new() + } + + fn get_by_status(&self, _min_status: u16, _max_status: u16) -> Vec { + Vec::new() + } + + fn get_stats(&self) -> InsightStats { + InsightStats::default() + } + + fn clear(&self) {} + + fn count(&self) -> usize { + 0 + } + + fn clone_store(&self) -> Box { + Box::new(*self) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::time::Duration; + + fn create_test_insight(id: &str, path: &str, status: u16) -> InsightData { + InsightData::new(id, "GET", path) + .with_status(status) + .with_duration(Duration::from_millis(10)) + } + + #[test] + fn test_in_memory_store_basic() { + let store = InMemoryInsightStore::new(10); + + store.store(create_test_insight("1", "/users", 200)); + store.store(create_test_insight("2", "/items", 201)); + + assert_eq!(store.count(), 2); + + let recent = store.get_recent(10); + assert_eq!(recent.len(), 2); + // Most recent first + assert_eq!(recent[0].request_id, "2"); + assert_eq!(recent[1].request_id, "1"); + } + + #[test] + fn test_ring_buffer_eviction() { + let store = InMemoryInsightStore::new(3); + + store.store(create_test_insight("1", "/a", 200)); + store.store(create_test_insight("2", "/b", 200)); + store.store(create_test_insight("3", "/c", 200)); + store.store(create_test_insight("4", "/d", 200)); // Should evict "1" + + assert_eq!(store.count(), 3); + + let all = store.get_all(); + let ids: Vec<_> = all.iter().map(|i| i.request_id.as_str()).collect(); + assert!(!ids.contains(&"1")); + assert!(ids.contains(&"2")); + assert!(ids.contains(&"3")); + assert!(ids.contains(&"4")); + } + + #[test] + fn test_filter_by_path() { + let store = InMemoryInsightStore::new(10); + + store.store(create_test_insight("1", "/users/123", 200)); + store.store(create_test_insight("2", "/items/456", 200)); + store.store(create_test_insight("3", "/users/789", 200)); + + let user_insights = store.get_by_path("/users"); + assert_eq!(user_insights.len(), 2); + } + + #[test] + fn test_filter_by_status() { + let store = InMemoryInsightStore::new(10); + + store.store(create_test_insight("1", "/a", 200)); + store.store(create_test_insight("2", "/b", 404)); + store.store(create_test_insight("3", "/c", 500)); + store.store(create_test_insight("4", "/d", 201)); + + let errors = store.get_by_status(400, 599); + assert_eq!(errors.len(), 2); + + let success = store.get_by_status(200, 299); + assert_eq!(success.len(), 2); + } + + #[test] + fn test_clear() { + let store = InMemoryInsightStore::new(10); + + store.store(create_test_insight("1", "/a", 200)); + store.store(create_test_insight("2", "/b", 200)); + + assert_eq!(store.count(), 2); + + store.clear(); + + assert_eq!(store.count(), 0); + assert!(store.get_all().is_empty()); + } + + #[test] + fn test_stats() { + let store = InMemoryInsightStore::new(10); + + store.store(create_test_insight("1", "/users", 200)); + store.store(create_test_insight("2", "/users", 201)); + store.store(create_test_insight("3", "/items", 404)); + + let stats = store.get_stats(); + + assert_eq!(stats.total_requests, 3); + assert_eq!(stats.successful_requests, 2); + assert_eq!(stats.client_errors, 1); + } + + #[test] + fn test_null_store() { + let store = NullInsightStore; + + store.store(create_test_insight("1", "/a", 200)); + + assert_eq!(store.count(), 0); + assert!(store.get_all().is_empty()); + } +} diff --git a/crates/rustapi-extras/src/lib.rs b/crates/rustapi-extras/src/lib.rs index 21275f1..7f095fb 100644 --- a/crates/rustapi-extras/src/lib.rs +++ b/crates/rustapi-extras/src/lib.rs @@ -13,6 +13,7 @@ //! - `config` - Configuration management with `.env` file support //! - `cookies` - Cookie parsing extractor //! - `sqlx` - SQLx database error conversion to ApiError +//! - `insight` - Traffic insight middleware for analytics and debugging //! - `extras` - Meta feature enabling jwt, cors, and rate-limit //! - `full` - All features enabled //! @@ -20,7 +21,7 @@ //! //! ```toml //! [dependencies] -//! rustapi-extras = { version = "0.1", features = ["jwt", "cors"] } +//! rustapi-extras = { version = "0.1", features = ["jwt", "cors", "insight"] } //! ``` #![warn(missing_docs)] @@ -46,6 +47,10 @@ pub mod config; #[cfg(feature = "sqlx")] pub mod sqlx; +// Traffic insight module +#[cfg(feature = "insight")] +pub mod insight; + // Re-exports for convenience #[cfg(feature = "jwt")] pub use jwt::{create_token, AuthUser, JwtError, JwtLayer, JwtValidation, ValidatedClaims}; @@ -63,3 +68,8 @@ pub use config::{ #[cfg(feature = "sqlx")] pub use sqlx::{convert_sqlx_error, SqlxErrorExt}; + +#[cfg(feature = "insight")] +pub use insight::{ + InMemoryInsightStore, InsightConfig, InsightData, InsightLayer, InsightStats, InsightStore, +}; diff --git a/crates/rustapi-rs/Cargo.toml b/crates/rustapi-rs/Cargo.toml index 35f0df7..27c244e 100644 --- a/crates/rustapi-rs/Cargo.toml +++ b/crates/rustapi-rs/Cargo.toml @@ -16,6 +16,8 @@ rustapi-core = { workspace = true, default-features = false } rustapi-macros = { workspace = true } rustapi-extras = { workspace = true, optional = true } rustapi-toon = { workspace = true, optional = true } +rustapi-ws = { workspace = true, optional = true } +rustapi-view = { workspace = true, optional = true } # Re-exports for user convenience tokio = { workspace = true } @@ -33,6 +35,10 @@ utoipa = { workspace = true } default = ["swagger-ui"] swagger-ui = ["rustapi-core/swagger-ui", "rustapi-openapi/swagger-ui"] +# Compression middleware +compression = ["rustapi-core/compression"] +compression-brotli = ["rustapi-core/compression-brotli"] + # Security and utility features (from rustapi-extras) jwt = ["dep:rustapi-extras", "rustapi-extras/jwt"] cors = ["dep:rustapi-extras", "rustapi-extras/cors"] @@ -40,10 +46,17 @@ rate-limit = ["dep:rustapi-extras", "rustapi-extras/rate-limit"] config = ["dep:rustapi-extras", "rustapi-extras/config"] cookies = ["dep:rustapi-extras", "rustapi-extras/cookies", "rustapi-core/cookies"] sqlx = ["dep:rustapi-extras", "rustapi-extras/sqlx"] +insight = ["dep:rustapi-extras", "rustapi-extras/insight"] # TOON format support toon = ["dep:rustapi-toon"] +# WebSocket support +ws = ["dep:rustapi-ws"] + +# Template engine support +view = ["dep:rustapi-view"] + # Meta features extras = ["jwt", "cors", "rate-limit"] -full = ["extras", "config", "cookies", "sqlx", "toon"] +full = ["extras", "config", "cookies", "sqlx", "toon", "insight", "compression", "ws", "view"] diff --git a/crates/rustapi-rs/src/lib.rs b/crates/rustapi-rs/src/lib.rs index 18dacc5..4b6863b 100644 --- a/crates/rustapi-rs/src/lib.rs +++ b/crates/rustapi-rs/src/lib.rs @@ -119,6 +119,59 @@ pub mod toon { pub use rustapi_toon::*; } +// Re-export WebSocket support (feature-gated) +#[cfg(feature = "ws")] +pub mod ws { + //! WebSocket support for real-time bidirectional communication + //! + //! This module provides WebSocket functionality through the `WebSocket` extractor, + //! enabling real-time communication patterns like chat, live updates, and streaming. + //! + //! # Example + //! + //! ```rust,ignore + //! use rustapi_rs::ws::{WebSocket, Message}; + //! + //! async fn websocket_handler(ws: WebSocket) -> impl IntoResponse { + //! ws.on_upgrade(|mut socket| async move { + //! while let Some(Ok(msg)) = socket.recv().await { + //! if let Message::Text(text) = msg { + //! socket.send(Message::Text(format!("Echo: {}", text))).await.ok(); + //! } + //! } + //! }) + //! } + //! ``` + pub use rustapi_ws::*; +} + +// Re-export View/Template support (feature-gated) +#[cfg(feature = "view")] +pub mod view { + //! Template engine support for server-side rendering + //! + //! This module provides Tera-based templating with the `View` response type, + //! enabling server-side HTML rendering with template inheritance and context. + //! + //! # Example + //! + //! ```rust,ignore + //! use rustapi_rs::view::{Templates, View, ContextBuilder}; + //! + //! #[derive(Clone)] + //! struct AppState { + //! templates: Templates, + //! } + //! + //! async fn index(State(state): State) -> View<()> { + //! View::new(&state.templates, "index.html") + //! .with("title", "Home") + //! .with("message", "Welcome!") + //! } + //! ``` + pub use rustapi_view::*; +} + /// Prelude module - import everything you need with `use rustapi_rs::prelude::*` pub mod prelude { // Core types @@ -133,6 +186,8 @@ pub mod prelude { post_route, put, put_route, + serve_dir, + sse_response, // Error handling ApiError, Body, @@ -146,6 +201,11 @@ pub mod prelude { IntoResponse, // Extractors Json, + KeepAlive, + // Multipart + Multipart, + MultipartConfig, + MultipartField, NoContent, Path, Query, @@ -168,12 +228,22 @@ pub mod prelude { Sse, SseEvent, State, + // Static files + StaticFile, + StaticFileConfig, StreamBody, TracingLayer, + UploadedFile, ValidatedJson, WithStatus, }; + // Compression middleware (feature-gated in core) + #[cfg(feature = "compression")] + pub use rustapi_core::middleware::{CompressionAlgorithm, CompressionConfig}; + #[cfg(feature = "compression")] + pub use rustapi_core::CompressionLayer; + // Cookies extractor (feature-gated in core) #[cfg(feature = "cookies")] pub use rustapi_core::Cookies; @@ -219,6 +289,14 @@ pub mod prelude { // TOON types (feature-gated) #[cfg(feature = "toon")] pub use rustapi_toon::{AcceptHeader, LlmResponse, Negotiate, OutputFormat, Toon}; + + // WebSocket types (feature-gated) + #[cfg(feature = "ws")] + pub use rustapi_ws::{Broadcast, Message, WebSocket, WebSocketStream}; + + // View/Template types (feature-gated) + #[cfg(feature = "view")] + pub use rustapi_view::{ContextBuilder, Templates, TemplatesConfig, View}; } #[cfg(test)] diff --git a/crates/rustapi-view/Cargo.toml b/crates/rustapi-view/Cargo.toml new file mode 100644 index 0000000..f90ac1c --- /dev/null +++ b/crates/rustapi-view/Cargo.toml @@ -0,0 +1,39 @@ +[package] +name = "rustapi-view" +description = "Template rendering support for RustAPI - Server-side HTML with Tera templates" +version.workspace = true +edition.workspace = true +authors.workspace = true +license.workspace = true +repository.workspace = true +keywords = ["web", "framework", "api", "templates", "html"] +categories = ["web-programming::http-server", "template-engine"] +rust-version.workspace = true +readme = "README.md" + +[dependencies] +# Core dependencies +rustapi-core = { workspace = true } +rustapi-openapi = { workspace = true } + +# Template engine +tera = { workspace = true } + +# Serialization +serde = { workspace = true } +serde_json = { workspace = true } + +# HTTP types +http = { workspace = true } +http-body-util = { workspace = true } +bytes = { workspace = true } + +# Async +tokio = { workspace = true } + +# Utilities +thiserror = { workspace = true } +tracing = { workspace = true } + +[dev-dependencies] +tokio = { workspace = true, features = ["macros", "rt-multi-thread"] } diff --git a/crates/rustapi-view/README.md b/crates/rustapi-view/README.md new file mode 100644 index 0000000..816aad3 --- /dev/null +++ b/crates/rustapi-view/README.md @@ -0,0 +1,100 @@ +# rustapi-view + +Template rendering support for RustAPI framework using Tera templates. + +## Features + +- **Tera Templates**: Full Tera template engine support +- **Type-Safe Context**: Build template context from Rust structs +- **Auto-Reload**: Development mode auto-reloads templates (optional) +- **Response Types**: `View` and `Html` response types +- **Layout Support**: Template inheritance and blocks + +## Quick Start + +```rust +use rustapi_rs::prelude::*; +use rustapi_view::{View, Templates}; +use serde::Serialize; + +#[derive(Serialize)] +struct HomeContext { + title: String, + user: Option, +} + +async fn home() -> View { + View::new("home.html", HomeContext { + title: "Welcome".to_string(), + user: Some("Alice".to_string()), + }) +} + +#[tokio::main] +async fn main() -> Result<(), Box> { + // Initialize templates from directory + let templates = Templates::new("templates/**/*.html")?; + + RustApi::new() + .state(templates) + .route("/", get(home)) + .run("127.0.0.1:8080") + .await +} +``` + +## Template Files + +Create your templates in a `templates/` directory: + +```html + + + + + {% block title %}{{ title }}{% endblock %} + + + {% block content %}{% endblock %} + + + + +{% extends "base.html" %} + +{% block content %} +

Welcome{% if user %}, {{ user }}{% endif %}!

+{% endblock %} +``` + +## Context Building + +```rust +use rustapi_view::{Context, View}; + +// From struct (requires Serialize) +let view = View::new("template.html", MyStruct { ... }); + +// From context builder +let view = View::with_context("template.html", |ctx| { + ctx.insert("name", "Alice"); + ctx.insert("items", &vec!["a", "b", "c"]); +}); +``` + +## Configuration + +```rust +use rustapi_view::{Templates, TemplatesConfig}; + +// With configuration +let templates = Templates::with_config(TemplatesConfig { + glob: "templates/**/*.html".to_string(), + auto_reload: cfg!(debug_assertions), // Auto-reload in debug mode + strict_mode: true, // Fail on undefined variables +}); +``` + +## License + +MIT OR Apache-2.0 diff --git a/crates/rustapi-view/src/context.rs b/crates/rustapi-view/src/context.rs new file mode 100644 index 0000000..dd9db5c --- /dev/null +++ b/crates/rustapi-view/src/context.rs @@ -0,0 +1,135 @@ +//! Context builder for templates + +use serde::Serialize; +use tera::Context; + +/// Builder for constructing template context +/// +/// This provides a fluent API for building template context without +/// needing to create a struct for simple cases. +/// +/// # Example +/// +/// ```rust,ignore +/// use rustapi_view::ContextBuilder; +/// +/// let context = ContextBuilder::new() +/// .insert("name", "Alice") +/// .insert("age", 30) +/// .insert_if("admin", true, |_| user.is_admin()) +/// .build(); +/// ``` +pub struct ContextBuilder { + context: Context, +} + +impl ContextBuilder { + /// Create a new context builder + pub fn new() -> Self { + Self { + context: Context::new(), + } + } + + /// Insert a value into the context + pub fn insert(mut self, key: impl Into, value: &T) -> Self { + self.context.insert(key.into(), value); + self + } + + /// Insert a value if a condition is met + pub fn insert_if( + self, + key: impl Into, + value: &T, + condition: F, + ) -> Self + where + F: FnOnce(&T) -> bool, + { + if condition(value) { + self.insert(key, value) + } else { + self + } + } + + /// Insert a value if it's Some + pub fn insert_some( + self, + key: impl Into, + value: Option<&T>, + ) -> Self { + if let Some(v) = value { + self.insert(key, v) + } else { + self + } + } + + /// Extend with values from a serializable struct + pub fn extend(mut self, value: &T) -> Result { + let additional = Context::from_serialize(value)?; + self.context.extend(additional); + Ok(self) + } + + /// Build the context + pub fn build(self) -> Context { + self.context + } +} + +impl Default for ContextBuilder { + fn default() -> Self { + Self::new() + } +} + +impl From for Context { + fn from(builder: ContextBuilder) -> Self { + builder.build() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_context_builder() { + let context = ContextBuilder::new() + .insert("name", &"Alice") + .insert("age", &30) + .build(); + + assert!(context.contains_key("name")); + assert!(context.contains_key("age")); + } + + #[test] + fn test_insert_if() { + let show = true; + let context = ContextBuilder::new() + .insert_if("visible", &"yes", |_| show) + .insert_if("hidden", &"no", |_| !show) + .build(); + + assert!(context.contains_key("visible")); + assert!(!context.contains_key("hidden")); + } + + #[test] + fn test_insert_some() { + let name: Option<&str> = Some("Alice"); + let missing: Option<&str> = None; + + let context = ContextBuilder::new() + .insert_some("name", name) + .insert_some("missing", missing) + .build(); + + assert!(context.contains_key("name")); + assert!(!context.contains_key("missing")); + } +} diff --git a/crates/rustapi-view/src/error.rs b/crates/rustapi-view/src/error.rs new file mode 100644 index 0000000..2937464 --- /dev/null +++ b/crates/rustapi-view/src/error.rs @@ -0,0 +1,71 @@ +//! View error types + +use thiserror::Error; + +/// Error type for view/template operations +#[derive(Error, Debug)] +pub enum ViewError { + /// Template not found + #[error("Template not found: {0}")] + TemplateNotFound(String), + + /// Template rendering failed + #[error("Template rendering failed: {0}")] + RenderError(String), + + /// Template parsing failed + #[error("Template parsing failed: {0}")] + ParseError(String), + + /// Context serialization failed + #[error("Context serialization failed: {0}")] + SerializationError(String), + + /// Template engine not initialized + #[error("Template engine not initialized")] + NotInitialized, + + /// IO error + #[error("IO error: {0}")] + IoError(#[from] std::io::Error), + + /// Tera error + #[error("Tera error: {0}")] + Tera(#[from] tera::Error), +} + +impl ViewError { + /// Create a template not found error + pub fn not_found(template: impl Into) -> Self { + Self::TemplateNotFound(template.into()) + } + + /// Create a render error + pub fn render_error(msg: impl Into) -> Self { + Self::RenderError(msg.into()) + } + + /// Create a parse error + pub fn parse_error(msg: impl Into) -> Self { + Self::ParseError(msg.into()) + } + + /// Create a serialization error + pub fn serialization_error(msg: impl Into) -> Self { + Self::SerializationError(msg.into()) + } +} + +impl From for rustapi_core::ApiError { + fn from(err: ViewError) -> Self { + match err { + ViewError::TemplateNotFound(name) => { + rustapi_core::ApiError::internal(format!("Template not found: {}", name)) + } + ViewError::NotInitialized => { + rustapi_core::ApiError::internal("Template engine not initialized") + } + _ => rustapi_core::ApiError::internal(err.to_string()), + } + } +} diff --git a/crates/rustapi-view/src/lib.rs b/crates/rustapi-view/src/lib.rs new file mode 100644 index 0000000..9db76dd --- /dev/null +++ b/crates/rustapi-view/src/lib.rs @@ -0,0 +1,67 @@ +//! # rustapi-view +//! +//! Template rendering support for the RustAPI framework using Tera templates. +//! +//! This crate provides server-side HTML rendering with type-safe template contexts, +//! layout inheritance, and development-friendly features like auto-reload. +//! +//! ## Features +//! +//! - **Tera Templates**: Full Tera template engine support with filters, macros, and inheritance +//! - **Type-Safe Context**: Build template context from Rust structs via serde +//! - **Auto-Reload**: Development mode can auto-reload templates on change +//! - **Response Types**: `View` response type for rendering templates +//! - **Layout Support**: Template inheritance with blocks +//! +//! ## Quick Start +//! +//! ```rust,ignore +//! use rustapi_rs::prelude::*; +//! use rustapi_view::{View, Templates}; +//! use serde::Serialize; +//! +//! #[derive(Serialize)] +//! struct HomeContext { +//! title: String, +//! user: Option, +//! } +//! +//! async fn home(templates: State) -> View { +//! View::render(&templates, "home.html", HomeContext { +//! title: "Welcome".to_string(), +//! user: Some("Alice".to_string()), +//! }) +//! } +//! +//! #[tokio::main] +//! async fn main() -> Result<(), Box> { +//! let templates = Templates::new("templates/**/*.html")?; +//! +//! RustApi::new() +//! .state(templates) +//! .route("/", get(home)) +//! .run("127.0.0.1:8080") +//! .await +//! } +//! ``` + +#![warn(missing_docs)] +#![warn(rustdoc::missing_crate_level_docs)] + +mod context; +mod error; +mod templates; +mod view; + +pub use context::ContextBuilder; +pub use error::ViewError; +pub use templates::{Templates, TemplatesConfig}; +pub use view::View; + +// Re-export tera types that users might need +pub use tera::Context; + +/// Prelude module for convenient imports +pub mod prelude { + pub use crate::{Context, ContextBuilder, Templates, TemplatesConfig, View, ViewError}; +} diff --git a/crates/rustapi-view/src/templates.rs b/crates/rustapi-view/src/templates.rs new file mode 100644 index 0000000..dacd720 --- /dev/null +++ b/crates/rustapi-view/src/templates.rs @@ -0,0 +1,248 @@ +//! Template engine wrapper + +use crate::ViewError; +use std::sync::Arc; +use tera::Tera; +use tokio::sync::RwLock; + +/// Configuration for the template engine +#[derive(Debug, Clone)] +pub struct TemplatesConfig { + /// Glob pattern for template files + pub glob: String, + /// Whether to auto-reload templates on change (development mode) + pub auto_reload: bool, + /// Whether to fail on undefined variables + pub strict_mode: bool, +} + +impl Default for TemplatesConfig { + fn default() -> Self { + Self { + glob: "templates/**/*.html".to_string(), + auto_reload: cfg!(debug_assertions), + strict_mode: false, + } + } +} + +impl TemplatesConfig { + /// Create a new config with the given glob pattern + pub fn new(glob: impl Into) -> Self { + Self { + glob: glob.into(), + ..Default::default() + } + } + + /// Set auto-reload behavior + pub fn auto_reload(mut self, enabled: bool) -> Self { + self.auto_reload = enabled; + self + } + + /// Set strict mode (fail on undefined variables) + pub fn strict_mode(mut self, enabled: bool) -> Self { + self.strict_mode = enabled; + self + } +} + +/// Template engine wrapper providing thread-safe template rendering +/// +/// This type wraps the Tera template engine and can be shared across +/// handlers via `State`. +/// +/// # Example +/// +/// ```rust,ignore +/// use rustapi_view::Templates; +/// +/// let templates = Templates::new("templates/**/*.html")?; +/// ``` +#[derive(Clone)] +pub struct Templates { + inner: Arc>, + config: TemplatesConfig, +} + +impl Templates { + /// Create a new template engine from a glob pattern + /// + /// The glob pattern specifies which files to load as templates. + /// Common patterns: + /// - `templates/**/*.html` - All HTML files in templates directory + /// - `views/*.tera` - All .tera files in views directory + /// + /// # Errors + /// + /// Returns an error if the glob pattern is invalid or templates fail to parse. + pub fn new(glob: impl Into) -> Result { + let config = TemplatesConfig::new(glob); + Self::with_config(config) + } + + /// Create a new template engine with configuration + pub fn with_config(config: TemplatesConfig) -> Result { + let mut tera = Tera::new(&config.glob)?; + + // Register custom filters/functions + register_builtin_filters(&mut tera); + + Ok(Self { + inner: Arc::new(RwLock::new(tera)), + config, + }) + } + + /// Create an empty template engine (for adding templates programmatically) + pub fn empty() -> Self { + Self { + inner: Arc::new(RwLock::new(Tera::default())), + config: TemplatesConfig::default(), + } + } + + /// Add a template from a string + pub async fn add_template( + &self, + name: impl Into, + content: impl Into, + ) -> Result<(), ViewError> { + let mut tera = self.inner.write().await; + tera.add_raw_template(&name.into(), &content.into())?; + Ok(()) + } + + /// Render a template with the given context + pub async fn render( + &self, + template: &str, + context: &tera::Context, + ) -> Result { + // If auto-reload is enabled and in debug mode, try to reload + #[cfg(debug_assertions)] + if self.config.auto_reload { + let mut tera = self.inner.write().await; + if let Err(e) = tera.full_reload() { + tracing::warn!("Template reload failed: {}", e); + } + } + + let tera = self.inner.read().await; + tera.render(template, context).map_err(ViewError::from) + } + + /// Render a template with a serializable context + pub async fn render_with( + &self, + template: &str, + data: &T, + ) -> Result { + let context = tera::Context::from_serialize(data) + .map_err(|e| ViewError::serialization_error(e.to_string()))?; + self.render(template, &context).await + } + + /// Check if a template exists + pub async fn has_template(&self, name: &str) -> bool { + let tera = self.inner.read().await; + let result = tera.get_template_names().any(|n| n == name); + result + } + + /// Get all template names + pub async fn template_names(&self) -> Vec { + let tera = self.inner.read().await; + tera.get_template_names().map(String::from).collect() + } + + /// Reload all templates from disk + pub async fn reload(&self) -> Result<(), ViewError> { + let mut tera = self.inner.write().await; + tera.full_reload()?; + Ok(()) + } + + /// Get the configuration + pub fn config(&self) -> &TemplatesConfig { + &self.config + } +} + +/// Register built-in template filters +fn register_builtin_filters(tera: &mut Tera) { + // JSON filter for debugging + tera.register_filter( + "json_pretty", + |value: &tera::Value, _: &std::collections::HashMap| { + serde_json::to_string_pretty(value) + .map(tera::Value::String) + .map_err(|e| tera::Error::msg(e.to_string())) + }, + ); + + // Truncate string + tera.register_filter( + "truncate_words", + |value: &tera::Value, args: &std::collections::HashMap| { + let s = tera::try_get_value!("truncate_words", "value", String, value); + let length = match args.get("length") { + Some(val) => tera::try_get_value!("truncate_words", "length", usize, val), + None => 50, + }; + let end = match args.get("end") { + Some(val) => tera::try_get_value!("truncate_words", "end", String, val), + None => "...".to_string(), + }; + + let words: Vec<&str> = s.split_whitespace().collect(); + if words.len() <= length { + Ok(tera::Value::String(s)) + } else { + let truncated: String = words[..length].join(" "); + Ok(tera::Value::String(format!("{}{}", truncated, end))) + } + }, + ); +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_empty_templates() { + let templates = Templates::empty(); + templates + .add_template("test", "Hello, {{ name }}!") + .await + .unwrap(); + + let mut ctx = tera::Context::new(); + ctx.insert("name", "World"); + + let result = templates.render("test", &ctx).await.unwrap(); + assert_eq!(result, "Hello, World!"); + } + + #[tokio::test] + async fn test_render_with_struct() { + #[derive(serde::Serialize)] + struct Data { + name: String, + } + + let templates = Templates::empty(); + templates + .add_template("test", "Hello, {{ name }}!") + .await + .unwrap(); + + let data = Data { + name: "Alice".to_string(), + }; + let result = templates.render_with("test", &data).await.unwrap(); + assert_eq!(result, "Hello, Alice!"); + } +} diff --git a/crates/rustapi-view/src/view.rs b/crates/rustapi-view/src/view.rs new file mode 100644 index 0000000..1449081 --- /dev/null +++ b/crates/rustapi-view/src/view.rs @@ -0,0 +1,174 @@ +//! View response type + +use crate::{Templates, ViewError}; +use bytes::Bytes; +use http::{header, Response, StatusCode}; +use http_body_util::Full; +use rustapi_core::IntoResponse; +use rustapi_openapi::{MediaType, Operation, ResponseModifier, ResponseSpec, SchemaRef}; +use serde::Serialize; +use std::collections::HashMap; +use std::marker::PhantomData; + +/// A response that renders a template with a context +/// +/// This is the primary way to render HTML templates in RustAPI handlers. +/// +/// # Example +/// +/// ```rust,ignore +/// use rustapi_view::{View, Templates}; +/// use serde::Serialize; +/// +/// #[derive(Serialize)] +/// struct HomeContext { +/// title: String, +/// } +/// +/// async fn home(templates: State) -> View { +/// View::render(&templates, "home.html", HomeContext { +/// title: "Home".to_string(), +/// }) +/// } +/// ``` +pub struct View { + /// The rendered HTML content + content: Result, + /// Status code (default 200) + status: StatusCode, + /// Phantom data for the context type + _phantom: PhantomData, +} + +impl View { + /// Create a view by rendering a template with a serializable context + /// + /// This is an async operation that renders the template immediately. + /// For deferred rendering, use `View::deferred`. + pub async fn render(templates: &Templates, template: &str, context: T) -> Self { + let content = templates.render_with(template, &context).await; + Self { + content, + status: StatusCode::OK, + _phantom: PhantomData, + } + } + + /// Create a view with a specific status code + pub async fn render_with_status( + templates: &Templates, + template: &str, + context: T, + status: StatusCode, + ) -> Self { + let content = templates.render_with(template, &context).await; + Self { + content, + status, + _phantom: PhantomData, + } + } + + /// Create a view from pre-rendered HTML + pub fn from_html(html: impl Into) -> Self { + Self { + content: Ok(html.into()), + status: StatusCode::OK, + _phantom: PhantomData, + } + } + + /// Create an error view + pub fn error(err: ViewError) -> Self { + Self { + content: Err(err), + status: StatusCode::INTERNAL_SERVER_ERROR, + _phantom: PhantomData, + } + } + + /// Set the status code + pub fn status(mut self, status: StatusCode) -> Self { + self.status = status; + self + } +} + +impl View<()> { + /// Create a view by rendering a template with a tera Context + pub async fn render_context( + templates: &Templates, + template: &str, + context: &tera::Context, + ) -> Self { + let content = templates.render(template, context).await; + Self { + content, + status: StatusCode::OK, + _phantom: PhantomData, + } + } +} + +impl IntoResponse for View { + fn into_response(self) -> Response> { + match self.content { + Ok(html) => Response::builder() + .status(self.status) + .header(header::CONTENT_TYPE, "text/html; charset=utf-8") + .body(Full::new(Bytes::from(html))) + .unwrap(), + Err(err) => { + tracing::error!("Template rendering failed: {}", err); + Response::builder() + .status(StatusCode::INTERNAL_SERVER_ERROR) + .header(header::CONTENT_TYPE, "text/html; charset=utf-8") + .body(Full::new(Bytes::from( + "Error\ +

500 Internal Server Error

\ +

Template rendering failed

", + ))) + .unwrap() + } + } + } +} + +impl ResponseModifier for View { + fn update_response(op: &mut Operation) { + op.responses.insert( + "200".to_string(), + ResponseSpec { + description: "HTML Content".to_string(), + content: { + let mut map = HashMap::new(); + map.insert( + "text/html".to_string(), + MediaType { + schema: SchemaRef::Inline(serde_json::json!({ "type": "string" })), + }, + ); + Some(map) + }, + }, + ); + } +} + +/// Helper for creating views with different status codes +impl View { + /// Create a 404 Not Found view + pub async fn not_found(templates: &Templates, template: &str, context: T) -> Self { + Self::render_with_status(templates, template, context, StatusCode::NOT_FOUND).await + } + + /// Create a 403 Forbidden view + pub async fn forbidden(templates: &Templates, template: &str, context: T) -> Self { + Self::render_with_status(templates, template, context, StatusCode::FORBIDDEN).await + } + + /// Create a 401 Unauthorized view + pub async fn unauthorized(templates: &Templates, template: &str, context: T) -> Self { + Self::render_with_status(templates, template, context, StatusCode::UNAUTHORIZED).await + } +} diff --git a/crates/rustapi-ws/Cargo.toml b/crates/rustapi-ws/Cargo.toml new file mode 100644 index 0000000..a2d1b30 --- /dev/null +++ b/crates/rustapi-ws/Cargo.toml @@ -0,0 +1,48 @@ +[package] +name = "rustapi-ws" +description = "WebSocket support for RustAPI - Real-time bidirectional communication" +version.workspace = true +edition.workspace = true +authors.workspace = true +license.workspace = true +repository.workspace = true +keywords = ["web", "framework", "api", "websocket", "real-time"] +categories = ["web-programming::http-server", "web-programming::websocket"] +rust-version.workspace = true +readme = "README.md" + +[dependencies] +# Core dependencies +rustapi-core = { workspace = true } +rustapi-openapi = { workspace = true } + +# WebSocket implementation +tokio-tungstenite = "0.24" +tungstenite = "0.24" + +# Async runtime +tokio = { workspace = true, features = ["sync", "macros"] } +futures-util = { workspace = true } + +# HTTP types +http = { workspace = true } +http-body-util = { workspace = true } +bytes = { workspace = true } +hyper = { workspace = true } +hyper-util = { workspace = true } + +# Serialization (optional, for JSON messages) +serde = { workspace = true } +serde_json = { workspace = true } + +# Utilities +thiserror = { workspace = true } +tracing = { workspace = true } +pin-project-lite = { workspace = true } + +# SHA-1 for WebSocket handshake +sha1 = "0.10" +base64 = "0.22" + +[dev-dependencies] +tokio = { workspace = true, features = ["macros", "rt-multi-thread", "time"] } diff --git a/crates/rustapi-ws/README.md b/crates/rustapi-ws/README.md new file mode 100644 index 0000000..f1f05f3 --- /dev/null +++ b/crates/rustapi-ws/README.md @@ -0,0 +1,118 @@ +# rustapi-ws + +WebSocket support for RustAPI framework, enabling real-time bidirectional communication. + +## Features + +- **WebSocket Upgrade**: Seamless HTTP to WebSocket upgrade +- **Message Types**: Text, Binary, Ping/Pong support +- **Type-Safe Messages**: JSON serialization/deserialization +- **Connection Management**: Clean connection lifecycle handling +- **Broadcast Support**: Send messages to multiple clients + +## Quick Start + +```rust +use rustapi_rs::prelude::*; +use rustapi_ws::{WebSocket, Message}; + +async fn ws_handler(ws: WebSocket) -> impl IntoResponse { + ws.on_upgrade(|socket| async move { + let (mut sender, mut receiver) = socket.split(); + + while let Some(msg) = receiver.next().await { + match msg { + Ok(Message::Text(text)) => { + // Echo the message back + let _ = sender.send(Message::Text(format!("Echo: {}", text))).await; + } + Ok(Message::Close(_)) => break, + _ => {} + } + } + }) +} + +#[tokio::main] +async fn main() -> Result<(), Box> { + RustApi::new() + .route("/ws", get(ws_handler)) + .run("127.0.0.1:8080") + .await +} +``` + +## Message Types + +```rust +use rustapi_ws::Message; + +// Text message +let msg = Message::Text("Hello".to_string()); + +// Binary message +let msg = Message::Binary(vec![1, 2, 3]); + +// JSON message (requires serde) +let msg = Message::json(&MyStruct { field: "value" })?; + +// Ping/Pong +let msg = Message::Ping(vec![]); +let msg = Message::Pong(vec![]); + +// Close connection +let msg = Message::Close(Some(CloseFrame { + code: CloseCode::Normal, + reason: "Goodbye".into(), +})); +``` + +## Connection State + +```rust +use rustapi_ws::{WebSocket, WebSocketState}; + +async fn stateful_ws(ws: WebSocket, State(app_state): State) -> impl IntoResponse { + ws.on_upgrade(move |socket| async move { + // Access application state within the WebSocket handler + let config = &app_state.config; + // ... + }) +} +``` + +## Broadcasting + +```rust +use rustapi_ws::{Broadcast, Message}; +use std::sync::Arc; + +// Create a broadcast channel +let broadcast = Arc::new(Broadcast::new()); + +// In your WebSocket handler +async fn ws_handler(ws: WebSocket, State(broadcast): State>) -> impl IntoResponse { + ws.on_upgrade(move |socket| async move { + let (sender, mut receiver) = socket.split(); + + // Subscribe to broadcasts + let mut rx = broadcast.subscribe(); + + // Handle incoming messages and broadcasts + tokio::select! { + // Receive from client + msg = receiver.next() => { + // Handle message + } + // Receive broadcast + msg = rx.recv() => { + // Forward to client + } + } + }) +} +``` + +## License + +MIT OR Apache-2.0 diff --git a/crates/rustapi-ws/src/broadcast.rs b/crates/rustapi-ws/src/broadcast.rs new file mode 100644 index 0000000..c502f1d --- /dev/null +++ b/crates/rustapi-ws/src/broadcast.rs @@ -0,0 +1,154 @@ +//! Broadcast channel for WebSocket messages + +use crate::Message; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; +use tokio::sync::broadcast; + +/// A broadcast channel for sending messages to multiple WebSocket clients +/// +/// This is useful for implementing pub/sub patterns, chat rooms, or any +/// scenario where you need to send the same message to multiple clients. +/// +/// # Example +/// +/// ```rust,ignore +/// use rustapi_ws::{Broadcast, Message}; +/// use std::sync::Arc; +/// +/// let broadcast = Arc::new(Broadcast::new()); +/// +/// // Subscribe to receive messages +/// let mut rx = broadcast.subscribe(); +/// +/// // Send a message to all subscribers +/// broadcast.send(Message::text("Hello everyone!")); +/// +/// // Receive the message +/// let msg = rx.recv().await.unwrap(); +/// ``` +#[derive(Clone)] +pub struct Broadcast { + sender: broadcast::Sender, + subscriber_count: Arc, +} + +impl Broadcast { + /// Create a new broadcast channel with default capacity (100 messages) + pub fn new() -> Self { + Self::with_capacity(100) + } + + /// Create a new broadcast channel with specified capacity + pub fn with_capacity(capacity: usize) -> Self { + let (sender, _) = broadcast::channel(capacity); + Self { + sender, + subscriber_count: Arc::new(AtomicUsize::new(0)), + } + } + + /// Subscribe to receive broadcast messages + pub fn subscribe(&self) -> BroadcastReceiver { + self.subscriber_count.fetch_add(1, Ordering::SeqCst); + BroadcastReceiver { + inner: self.sender.subscribe(), + subscriber_count: self.subscriber_count.clone(), + } + } + + /// Send a message to all subscribers + /// + /// Returns the number of receivers that received the message. + /// Returns 0 if there are no active subscribers. + pub fn send(&self, msg: Message) -> usize { + self.sender.send(msg).unwrap_or(0) + } + + /// Send a text message to all subscribers + pub fn send_text(&self, text: impl Into) -> usize { + self.send(Message::text(text)) + } + + /// Send a JSON message to all subscribers + pub fn send_json( + &self, + value: &T, + ) -> Result { + let msg = Message::json(value)?; + Ok(self.send(msg)) + } + + /// Get the current number of subscribers + pub fn subscriber_count(&self) -> usize { + self.subscriber_count.load(Ordering::SeqCst) + } + + /// Check if there are any active subscribers + pub fn has_subscribers(&self) -> bool { + self.subscriber_count() > 0 + } +} + +impl Default for Broadcast { + fn default() -> Self { + Self::new() + } +} + +/// Receiver for broadcast messages +pub struct BroadcastReceiver { + inner: broadcast::Receiver, + subscriber_count: Arc, +} + +impl BroadcastReceiver { + /// Receive the next broadcast message + /// + /// Returns `None` if the broadcast channel is closed. + /// Returns `Err` if messages were missed due to slow consumption. + pub async fn recv(&mut self) -> Option> { + match self.inner.recv().await { + Ok(msg) => Some(Ok(msg)), + Err(broadcast::error::RecvError::Closed) => None, + Err(broadcast::error::RecvError::Lagged(count)) => { + Some(Err(BroadcastRecvError::Lagged(count))) + } + } + } + + /// Try to receive a message without waiting + pub fn try_recv(&mut self) -> Option> { + match self.inner.try_recv() { + Ok(msg) => Some(Ok(msg)), + Err(broadcast::error::TryRecvError::Empty) => None, + Err(broadcast::error::TryRecvError::Closed) => None, + Err(broadcast::error::TryRecvError::Lagged(count)) => { + Some(Err(BroadcastRecvError::Lagged(count))) + } + } + } +} + +impl Drop for BroadcastReceiver { + fn drop(&mut self) { + self.subscriber_count.fetch_sub(1, Ordering::SeqCst); + } +} + +/// Error when receiving broadcast messages +#[derive(Debug, Clone, Copy)] +pub enum BroadcastRecvError { + /// Some messages were missed because the receiver is too slow + Lagged(u64), +} + +impl std::fmt::Display for BroadcastRecvError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Lagged(count) => write!(f, "Lagged behind by {} messages", count), + } + } +} + +impl std::error::Error for BroadcastRecvError {} diff --git a/crates/rustapi-ws/src/error.rs b/crates/rustapi-ws/src/error.rs new file mode 100644 index 0000000..7d9e25e --- /dev/null +++ b/crates/rustapi-ws/src/error.rs @@ -0,0 +1,98 @@ +//! WebSocket error types + +use thiserror::Error; + +/// Error type for WebSocket operations +#[derive(Error, Debug)] +pub enum WebSocketError { + /// Invalid WebSocket upgrade request + #[error("Invalid WebSocket upgrade request: {0}")] + InvalidUpgrade(String), + + /// WebSocket handshake failed + #[error("WebSocket handshake failed: {0}")] + HandshakeFailed(String), + + /// Connection closed unexpectedly + #[error("Connection closed unexpectedly")] + ConnectionClosed, + + /// Failed to send message + #[error("Failed to send message: {0}")] + SendFailed(String), + + /// Failed to receive message + #[error("Failed to receive message: {0}")] + ReceiveFailed(String), + + /// Message serialization error + #[error("Message serialization error: {0}")] + SerializationError(String), + + /// Message deserialization error + #[error("Message deserialization error: {0}")] + DeserializationError(String), + + /// Protocol error + #[error("WebSocket protocol error: {0}")] + ProtocolError(String), + + /// IO error + #[error("IO error: {0}")] + IoError(#[from] std::io::Error), + + /// Tungstenite error + #[error("WebSocket error: {0}")] + Tungstenite(#[from] tungstenite::Error), +} + +impl WebSocketError { + /// Create an invalid upgrade error + pub fn invalid_upgrade(msg: impl Into) -> Self { + Self::InvalidUpgrade(msg.into()) + } + + /// Create a handshake failed error + pub fn handshake_failed(msg: impl Into) -> Self { + Self::HandshakeFailed(msg.into()) + } + + /// Create a send failed error + pub fn send_failed(msg: impl Into) -> Self { + Self::SendFailed(msg.into()) + } + + /// Create a receive failed error + pub fn receive_failed(msg: impl Into) -> Self { + Self::ReceiveFailed(msg.into()) + } + + /// Create a serialization error + pub fn serialization_error(msg: impl Into) -> Self { + Self::SerializationError(msg.into()) + } + + /// Create a deserialization error + pub fn deserialization_error(msg: impl Into) -> Self { + Self::DeserializationError(msg.into()) + } + + /// Create a protocol error + pub fn protocol_error(msg: impl Into) -> Self { + Self::ProtocolError(msg.into()) + } +} + +impl From for rustapi_core::ApiError { + fn from(err: WebSocketError) -> Self { + match err { + WebSocketError::InvalidUpgrade(msg) => { + rustapi_core::ApiError::bad_request(format!("WebSocket upgrade failed: {}", msg)) + } + WebSocketError::HandshakeFailed(msg) => { + rustapi_core::ApiError::bad_request(format!("WebSocket handshake failed: {}", msg)) + } + _ => rustapi_core::ApiError::internal(err.to_string()), + } + } +} diff --git a/crates/rustapi-ws/src/extractor.rs b/crates/rustapi-ws/src/extractor.rs new file mode 100644 index 0000000..b0f8c55 --- /dev/null +++ b/crates/rustapi-ws/src/extractor.rs @@ -0,0 +1,97 @@ +//! WebSocket extractor + +use crate::upgrade::{validate_upgrade_request, WebSocketUpgrade}; +use rustapi_core::{ApiError, FromRequestParts, Request, Result}; +use rustapi_openapi::{Operation, OperationModifier}; + +/// WebSocket extractor for upgrading HTTP connections to WebSocket +/// +/// Use this extractor in your handler to initiate a WebSocket upgrade. +/// The extractor validates the upgrade request and returns a `WebSocket` +/// that can be used to set up the connection handler. +/// +/// # Example +/// +/// ```rust,ignore +/// use rustapi_ws::{WebSocket, Message}; +/// +/// async fn ws_handler(ws: WebSocket) -> impl IntoResponse { +/// ws.on_upgrade(|socket| async move { +/// let (mut sender, mut receiver) = socket.split(); +/// +/// while let Some(Ok(msg)) = receiver.next().await { +/// match msg { +/// Message::Text(text) => { +/// // Echo back +/// let _ = sender.send(Message::text(format!("Echo: {}", text))).await; +/// } +/// Message::Close(_) => break, +/// _ => {} +/// } +/// } +/// }) +/// } +/// ``` +pub struct WebSocket { + sec_key: String, + protocols: Vec, +} + +impl WebSocket { + /// Create a WebSocket upgrade response with a handler + /// + /// The provided callback will be called with the established WebSocket + /// stream once the upgrade is complete. + pub fn on_upgrade(self, callback: F) -> WebSocketUpgrade + where + F: FnOnce(crate::WebSocketStream) -> Fut + Send + 'static, + Fut: std::future::Future + Send + 'static, + { + let upgrade = WebSocketUpgrade::new(self.sec_key); + + // If protocols were requested, select the first one + let upgrade = if let Some(protocol) = self.protocols.first() { + upgrade.protocol(protocol) + } else { + upgrade + }; + + upgrade.on_upgrade(callback) + } + + /// Get the requested protocols + pub fn protocols(&self) -> &[String] { + &self.protocols + } + + /// Check if a specific protocol was requested + pub fn has_protocol(&self, protocol: &str) -> bool { + self.protocols.iter().any(|p| p == protocol) + } +} + +impl FromRequestParts for WebSocket { + fn from_request_parts(req: &Request) -> Result { + let headers = req.headers(); + let method = req.method(); + + // Validate the upgrade request + let sec_key = validate_upgrade_request(method, headers).map_err(ApiError::from)?; + + // Parse requested protocols + let protocols = headers + .get("Sec-WebSocket-Protocol") + .and_then(|v| v.to_str().ok()) + .map(|s| s.split(',').map(|p| p.trim().to_string()).collect()) + .unwrap_or_default(); + + Ok(Self { sec_key, protocols }) + } +} + +impl OperationModifier for WebSocket { + fn update_operation(_op: &mut Operation) { + // WebSocket endpoints don't have regular request body parameters + // The upgrade is indicated by the response + } +} diff --git a/crates/rustapi-ws/src/lib.rs b/crates/rustapi-ws/src/lib.rs new file mode 100644 index 0000000..e944f0a --- /dev/null +++ b/crates/rustapi-ws/src/lib.rs @@ -0,0 +1,73 @@ +//! # rustapi-ws +//! +//! WebSocket support for the RustAPI framework. +//! +//! This crate provides WebSocket upgrade handling, message types, and utilities +//! for building real-time bidirectional communication in your RustAPI applications. +//! +//! ## Features +//! +//! - **WebSocket Upgrade**: Seamless HTTP to WebSocket upgrade via the `WebSocket` extractor +//! - **Message Types**: Support for Text, Binary, Ping/Pong messages +//! - **Type-Safe JSON**: Serialize/deserialize JSON messages with serde +//! - **Connection Management**: Clean connection lifecycle with proper close handling +//! - **Broadcast Support**: Send messages to multiple connected clients +//! +//! ## Quick Start +//! +//! ```rust,ignore +//! use rustapi_rs::prelude::*; +//! use rustapi_ws::{WebSocket, Message}; +//! +//! async fn ws_handler(ws: WebSocket) -> impl IntoResponse { +//! ws.on_upgrade(|socket| async move { +//! let (mut sender, mut receiver) = socket.split(); +//! +//! while let Some(msg) = receiver.next().await { +//! match msg { +//! Ok(Message::Text(text)) => { +//! let _ = sender.send(Message::Text(format!("Echo: {}", text))).await; +//! } + +// Allow large error types in Results - WebSocket errors include tungstenite errors which are large +#![allow(clippy::result_large_err)] +//! Ok(Message::Close(_)) => break, +//! _ => {} +//! } +//! } +//! }) +//! } +//! +//! #[tokio::main] +//! async fn main() -> Result<(), Box> { +//! RustApi::new() +//! .route("/ws", get(ws_handler)) +//! .run("127.0.0.1:8080") +//! .await +//! } +//! ``` + +#![warn(missing_docs)] +#![warn(rustdoc::missing_crate_level_docs)] + +mod broadcast; +mod error; +mod extractor; +mod message; +mod socket; +mod upgrade; + +pub use broadcast::Broadcast; +pub use error::WebSocketError; +pub use extractor::WebSocket; +pub use message::{CloseCode, CloseFrame, Message}; +pub use socket::{WebSocketReceiver, WebSocketSender, WebSocketStream}; +pub use upgrade::WebSocketUpgrade; + +/// Prelude module for convenient imports +pub mod prelude { + pub use crate::{ + Broadcast, CloseCode, CloseFrame, Message, WebSocket, WebSocketError, WebSocketReceiver, + WebSocketSender, WebSocketStream, WebSocketUpgrade, + }; +} diff --git a/crates/rustapi-ws/src/message.rs b/crates/rustapi-ws/src/message.rs new file mode 100644 index 0000000..7a1478a --- /dev/null +++ b/crates/rustapi-ws/src/message.rs @@ -0,0 +1,324 @@ +//! WebSocket message types + +use serde::{de::DeserializeOwned, Serialize}; +use std::borrow::Cow; + +/// WebSocket message type +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum Message { + /// Text message (UTF-8 encoded) + Text(String), + /// Binary message + Binary(Vec), + /// Ping message + Ping(Vec), + /// Pong message + Pong(Vec), + /// Close message + Close(Option), +} + +impl Message { + /// Create a text message + pub fn text(text: impl Into) -> Self { + Self::Text(text.into()) + } + + /// Create a binary message + pub fn binary(data: impl Into>) -> Self { + Self::Binary(data.into()) + } + + /// Create a ping message + pub fn ping(data: impl Into>) -> Self { + Self::Ping(data.into()) + } + + /// Create a pong message + pub fn pong(data: impl Into>) -> Self { + Self::Pong(data.into()) + } + + /// Create a close message + pub fn close() -> Self { + Self::Close(None) + } + + /// Create a close message with a frame + pub fn close_with(code: CloseCode, reason: impl Into) -> Self { + Self::Close(Some(CloseFrame { + code, + reason: Cow::Owned(reason.into()), + })) + } + + /// Create a JSON text message from a serializable type + pub fn json(value: &T) -> Result { + serde_json::to_string(value) + .map(Self::Text) + .map_err(|e| crate::WebSocketError::serialization_error(e.to_string())) + } + + /// Try to deserialize a text message as JSON + pub fn as_json(&self) -> Result { + match self { + Self::Text(text) => serde_json::from_str(text) + .map_err(|e| crate::WebSocketError::deserialization_error(e.to_string())), + _ => Err(crate::WebSocketError::deserialization_error( + "Expected text message for JSON deserialization", + )), + } + } + + /// Check if this is a text message + pub fn is_text(&self) -> bool { + matches!(self, Self::Text(_)) + } + + /// Check if this is a binary message + pub fn is_binary(&self) -> bool { + matches!(self, Self::Binary(_)) + } + + /// Check if this is a ping message + pub fn is_ping(&self) -> bool { + matches!(self, Self::Ping(_)) + } + + /// Check if this is a pong message + pub fn is_pong(&self) -> bool { + matches!(self, Self::Pong(_)) + } + + /// Check if this is a close message + pub fn is_close(&self) -> bool { + matches!(self, Self::Close(_)) + } + + /// Get the text content if this is a text message + pub fn as_text(&self) -> Option<&str> { + match self { + Self::Text(text) => Some(text), + _ => None, + } + } + + /// Get the binary content if this is a binary message + pub fn as_bytes(&self) -> Option<&[u8]> { + match self { + Self::Binary(data) => Some(data), + _ => None, + } + } + + /// Convert to text, consuming the message + pub fn into_text(self) -> Option { + match self { + Self::Text(text) => Some(text), + _ => None, + } + } + + /// Convert to bytes, consuming the message + pub fn into_bytes(self) -> Option> { + match self { + Self::Binary(data) => Some(data), + _ => None, + } + } +} + +impl From for Message { + fn from(text: String) -> Self { + Self::Text(text) + } +} + +impl From<&str> for Message { + fn from(text: &str) -> Self { + Self::Text(text.to_string()) + } +} + +impl From> for Message { + fn from(data: Vec) -> Self { + Self::Binary(data) + } +} + +impl From<&[u8]> for Message { + fn from(data: &[u8]) -> Self { + Self::Binary(data.to_vec()) + } +} + +/// Convert from tungstenite Message +impl From for Message { + fn from(msg: tungstenite::Message) -> Self { + match msg { + tungstenite::Message::Text(text) => Self::Text(text.to_string()), + tungstenite::Message::Binary(data) => Self::Binary(data.to_vec()), + tungstenite::Message::Ping(data) => Self::Ping(data.to_vec()), + tungstenite::Message::Pong(data) => Self::Pong(data.to_vec()), + tungstenite::Message::Close(frame) => Self::Close(frame.map(|f| CloseFrame { + code: CloseCode::from(f.code), + reason: Cow::Owned(f.reason.to_string()), + })), + tungstenite::Message::Frame(_) => Self::Binary(vec![]), // Raw frames treated as binary + } + } +} + +/// Convert to tungstenite Message +impl From for tungstenite::Message { + fn from(msg: Message) -> Self { + match msg { + Message::Text(text) => tungstenite::Message::Text(text), + Message::Binary(data) => tungstenite::Message::Binary(data), + Message::Ping(data) => tungstenite::Message::Ping(data), + Message::Pong(data) => tungstenite::Message::Pong(data), + Message::Close(frame) => { + tungstenite::Message::Close(frame.map(|f| tungstenite::protocol::CloseFrame { + code: f.code.into(), + reason: f.reason, + })) + } + } + } +} + +/// WebSocket close frame +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct CloseFrame { + /// Close code + pub code: CloseCode, + /// Close reason + pub reason: Cow<'static, str>, +} + +impl CloseFrame { + /// Create a new close frame + pub fn new(code: CloseCode, reason: impl Into>) -> Self { + Self { + code, + reason: reason.into(), + } + } + + /// Create a normal close frame + pub fn normal() -> Self { + Self::new(CloseCode::Normal, "") + } + + /// Create a going away close frame + pub fn going_away() -> Self { + Self::new(CloseCode::Away, "Going away") + } +} + +/// WebSocket close codes +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum CloseCode { + /// Normal closure (1000) + Normal, + /// Going away (1001) + Away, + /// Protocol error (1002) + Protocol, + /// Unsupported data (1003) + Unsupported, + /// No status received (1005) + Status, + /// Abnormal closure (1006) + Abnormal, + /// Invalid frame payload data (1007) + Invalid, + /// Policy violation (1008) + Policy, + /// Message too big (1009) + Size, + /// Mandatory extension (1010) + Extension, + /// Internal error (1011) + Error, + /// Service restart (1012) + Restart, + /// Try again later (1013) + Again, + /// Bad TLS handshake (1015) + Tls, + /// Reserved codes + Reserved(u16), + /// Library/framework-specific codes (3000-3999) + Library(u16), + /// Private use codes (4000-4999) + Private(u16), +} + +impl CloseCode { + /// Get the numeric code + pub fn as_u16(&self) -> u16 { + match self { + Self::Normal => 1000, + Self::Away => 1001, + Self::Protocol => 1002, + Self::Unsupported => 1003, + Self::Status => 1005, + Self::Abnormal => 1006, + Self::Invalid => 1007, + Self::Policy => 1008, + Self::Size => 1009, + Self::Extension => 1010, + Self::Error => 1011, + Self::Restart => 1012, + Self::Again => 1013, + Self::Tls => 1015, + Self::Reserved(code) => *code, + Self::Library(code) => *code, + Self::Private(code) => *code, + } + } +} + +impl From for CloseCode { + fn from(code: u16) -> Self { + match code { + 1000 => Self::Normal, + 1001 => Self::Away, + 1002 => Self::Protocol, + 1003 => Self::Unsupported, + 1005 => Self::Status, + 1006 => Self::Abnormal, + 1007 => Self::Invalid, + 1008 => Self::Policy, + 1009 => Self::Size, + 1010 => Self::Extension, + 1011 => Self::Error, + 1012 => Self::Restart, + 1013 => Self::Again, + 1015 => Self::Tls, + 1004 | 1014 | 1016..=2999 => Self::Reserved(code), + 3000..=3999 => Self::Library(code), + 4000..=4999 => Self::Private(code), + _ => Self::Reserved(code), + } + } +} + +impl From for u16 { + fn from(code: CloseCode) -> Self { + code.as_u16() + } +} + +impl From for CloseCode { + fn from(code: tungstenite::protocol::frame::coding::CloseCode) -> Self { + Self::from(u16::from(code)) + } +} + +impl From for tungstenite::protocol::frame::coding::CloseCode { + fn from(code: CloseCode) -> Self { + tungstenite::protocol::frame::coding::CloseCode::from(code.as_u16()) + } +} diff --git a/crates/rustapi-ws/src/socket.rs b/crates/rustapi-ws/src/socket.rs new file mode 100644 index 0000000..dc26a14 --- /dev/null +++ b/crates/rustapi-ws/src/socket.rs @@ -0,0 +1,209 @@ +//! WebSocket stream implementation + +use crate::{Message, WebSocketError}; +use futures_util::{ + stream::{SplitSink, SplitStream}, + SinkExt, Stream, StreamExt, +}; +use hyper::upgrade::Upgraded; +use hyper_util::rt::TokioIo; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio_tungstenite::WebSocketStream as TungsteniteStream; + +/// Type alias for the upgraded connection +type UpgradedConnection = TungsteniteStream>; + +/// A WebSocket stream that wraps the underlying tungstenite stream +/// +/// This provides a simple interface for sending and receiving WebSocket messages. +/// You can either use the stream directly with `send`/`recv` methods, or split +/// it into separate sender and receiver halves for concurrent operations. +#[allow(dead_code)] +pub struct WebSocketStream { + inner: UpgradedConnection, +} + +impl WebSocketStream { + /// Create a new WebSocket stream from an upgraded connection + #[allow(dead_code)] + pub(crate) fn new(inner: UpgradedConnection) -> Self { + Self { inner } + } + + /// Split the stream into sender and receiver halves + /// + /// This allows concurrent sending and receiving on the same connection. + /// + /// # Example + /// + /// ```rust,ignore + /// let (mut sender, mut receiver) = socket.split(); + /// + /// // Now you can use sender and receiver concurrently + /// tokio::select! { + /// msg = receiver.recv() => { /* handle incoming */ } + /// _ = sender.send(Message::text("ping")) => { /* sent */ } + /// } + /// ``` + pub fn split(self) -> (WebSocketSender, WebSocketReceiver) { + let (sink, stream) = self.inner.split(); + ( + WebSocketSender { inner: sink }, + WebSocketReceiver { inner: stream }, + ) + } + + /// Send a message + pub async fn send(&mut self, msg: Message) -> Result<(), WebSocketError> { + self.inner + .send(msg.into()) + .await + .map_err(WebSocketError::from) + } + + /// Send a text message + pub async fn send_text(&mut self, text: impl Into) -> Result<(), WebSocketError> { + self.send(Message::text(text)).await + } + + /// Send a binary message + pub async fn send_binary(&mut self, data: impl Into>) -> Result<(), WebSocketError> { + self.send(Message::binary(data)).await + } + + /// Send a JSON message + pub async fn send_json( + &mut self, + value: &T, + ) -> Result<(), WebSocketError> { + let msg = Message::json(value)?; + self.send(msg).await + } + + /// Receive the next message + pub async fn recv(&mut self) -> Option> { + self.inner + .next() + .await + .map(|result| result.map(Message::from).map_err(WebSocketError::from)) + } + + /// Close the connection + pub async fn close(mut self) -> Result<(), WebSocketError> { + self.inner.close(None).await.map_err(WebSocketError::from) + } + + /// Close the connection with a close frame + pub async fn close_with( + mut self, + code: crate::CloseCode, + reason: impl Into, + ) -> Result<(), WebSocketError> { + let frame = tungstenite::protocol::CloseFrame { + code: code.into(), + reason: reason.into().into(), + }; + self.inner + .close(Some(frame)) + .await + .map_err(WebSocketError::from) + } +} + +/// Sender half of a WebSocket stream +/// +/// This is obtained by calling `split()` on a `WebSocketStream`. +pub struct WebSocketSender { + inner: SplitSink, +} + +impl WebSocketSender { + /// Send a message + pub async fn send(&mut self, msg: Message) -> Result<(), WebSocketError> { + self.inner + .send(msg.into()) + .await + .map_err(WebSocketError::from) + } + + /// Send a text message + pub async fn send_text(&mut self, text: impl Into) -> Result<(), WebSocketError> { + self.send(Message::text(text)).await + } + + /// Send a binary message + pub async fn send_binary(&mut self, data: impl Into>) -> Result<(), WebSocketError> { + self.send(Message::binary(data)).await + } + + /// Send a JSON message + pub async fn send_json( + &mut self, + value: &T, + ) -> Result<(), WebSocketError> { + let msg = Message::json(value)?; + self.send(msg).await + } + + /// Flush any buffered messages + pub async fn flush(&mut self) -> Result<(), WebSocketError> { + self.inner.flush().await.map_err(WebSocketError::from) + } + + /// Close the sender + pub async fn close(mut self) -> Result<(), WebSocketError> { + self.inner.close().await.map_err(WebSocketError::from) + } +} + +/// Receiver half of a WebSocket stream +/// +/// This is obtained by calling `split()` on a `WebSocketStream`. +pub struct WebSocketReceiver { + inner: SplitStream, +} + +impl WebSocketReceiver { + /// Receive the next message + pub async fn recv(&mut self) -> Option> { + self.next().await + } + + /// Receive the next text message, skipping non-text messages + pub async fn recv_text(&mut self) -> Option> { + loop { + match self.recv().await { + Some(Ok(Message::Text(text))) => return Some(Ok(text)), + Some(Ok(Message::Close(_))) => return None, + Some(Err(e)) => return Some(Err(e)), + Some(Ok(_)) => continue, // Skip non-text messages + None => return None, + } + } + } + + /// Receive and deserialize a JSON message + pub async fn recv_json( + &mut self, + ) -> Option> { + match self.recv().await { + Some(Ok(msg)) => Some(msg.as_json()), + Some(Err(e)) => Some(Err(e)), + None => None, + } + } +} + +impl Stream for WebSocketReceiver { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match Pin::new(&mut self.inner).poll_next(cx) { + Poll::Ready(Some(Ok(msg))) => Poll::Ready(Some(Ok(Message::from(msg)))), + Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(WebSocketError::from(e)))), + Poll::Ready(None) => Poll::Ready(None), + Poll::Pending => Poll::Pending, + } + } +} diff --git a/crates/rustapi-ws/src/upgrade.rs b/crates/rustapi-ws/src/upgrade.rs new file mode 100644 index 0000000..e87f8b8 --- /dev/null +++ b/crates/rustapi-ws/src/upgrade.rs @@ -0,0 +1,201 @@ +//! WebSocket upgrade response + +use crate::{WebSocketError, WebSocketStream}; +use bytes::Bytes; +use http::{header, Response, StatusCode}; +use http_body_util::Full; +use rustapi_core::IntoResponse; +use rustapi_openapi::{Operation, ResponseModifier, ResponseSpec}; +use std::future::Future; +use std::pin::Pin; + +/// Type alias for WebSocket upgrade callback +type UpgradeCallback = + Box Pin + Send>> + Send>; + +/// WebSocket upgrade response +/// +/// This type is returned from WebSocket handlers to initiate the upgrade +/// handshake and establish a WebSocket connection. +pub struct WebSocketUpgrade { + /// The upgrade response + response: Response>, + /// Callback to handle the WebSocket connection + on_upgrade: Option, + /// SEC-WebSocket-Key from request + sec_key: String, +} + +impl WebSocketUpgrade { + /// Create a new WebSocket upgrade from request headers + pub(crate) fn new(sec_key: String) -> Self { + // Generate accept key + let accept_key = generate_accept_key(&sec_key); + + // Build upgrade response + let response = Response::builder() + .status(StatusCode::SWITCHING_PROTOCOLS) + .header(header::UPGRADE, "websocket") + .header(header::CONNECTION, "Upgrade") + .header("Sec-WebSocket-Accept", accept_key) + .body(Full::new(Bytes::new())) + .unwrap(); + + Self { + response, + on_upgrade: None, + sec_key, + } + } + + /// Set the callback to handle the upgraded WebSocket connection + /// + /// # Example + /// + /// ```rust,ignore + /// ws.on_upgrade(|socket| async move { + /// let (mut sender, mut receiver) = socket.split(); + /// while let Some(msg) = receiver.next().await { + /// // Handle messages... + /// } + /// }) + /// ``` + pub fn on_upgrade(mut self, callback: F) -> Self + where + F: FnOnce(WebSocketStream) -> Fut + Send + 'static, + Fut: Future + Send + 'static, + { + self.on_upgrade = Some(Box::new(move |stream| Box::pin(callback(stream)))); + self + } + + /// Add a protocol to the response + pub fn protocol(mut self, protocol: &str) -> Self { + self.response = Response::builder() + .status(StatusCode::SWITCHING_PROTOCOLS) + .header(header::UPGRADE, "websocket") + .header(header::CONNECTION, "Upgrade") + .header("Sec-WebSocket-Accept", generate_accept_key(&self.sec_key)) + .header("Sec-WebSocket-Protocol", protocol) + .body(Full::new(Bytes::new())) + .unwrap(); + self + } + + /// Get the underlying response (for implementing IntoResponse) + #[allow(dead_code)] + pub(crate) fn into_response_inner(self) -> Response> { + self.response + } + + /// Get the on_upgrade callback + #[allow(dead_code)] + pub(crate) fn take_callback(&mut self) -> Option { + self.on_upgrade.take() + } +} + +impl IntoResponse for WebSocketUpgrade { + fn into_response(self) -> http::Response> { + self.response + } +} + +impl ResponseModifier for WebSocketUpgrade { + fn update_response(op: &mut Operation) { + op.responses.insert( + "101".to_string(), + ResponseSpec { + description: "WebSocket upgrade successful".to_string(), + content: None, + }, + ); + } +} + +/// Generate the Sec-WebSocket-Accept key from the client's Sec-WebSocket-Key +fn generate_accept_key(key: &str) -> String { + use base64::Engine; + use sha1::{Digest, Sha1}; + + const GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; + + let mut hasher = Sha1::new(); + hasher.update(key.as_bytes()); + hasher.update(GUID.as_bytes()); + let hash = hasher.finalize(); + + base64::engine::general_purpose::STANDARD.encode(hash) +} + +/// Validate that a request is a valid WebSocket upgrade request +pub(crate) fn validate_upgrade_request( + method: &http::Method, + headers: &http::HeaderMap, +) -> Result { + // Must be GET + if method != http::Method::GET { + return Err(WebSocketError::invalid_upgrade("Method must be GET")); + } + + // Must have Upgrade: websocket header + let upgrade = headers + .get(header::UPGRADE) + .and_then(|v| v.to_str().ok()) + .ok_or_else(|| WebSocketError::invalid_upgrade("Missing Upgrade header"))?; + + if !upgrade.eq_ignore_ascii_case("websocket") { + return Err(WebSocketError::invalid_upgrade( + "Upgrade header must be 'websocket'", + )); + } + + // Must have Connection: Upgrade header + let connection = headers + .get(header::CONNECTION) + .and_then(|v| v.to_str().ok()) + .ok_or_else(|| WebSocketError::invalid_upgrade("Missing Connection header"))?; + + let has_upgrade = connection + .split(',') + .any(|s| s.trim().eq_ignore_ascii_case("upgrade")); + + if !has_upgrade { + return Err(WebSocketError::invalid_upgrade( + "Connection header must contain 'Upgrade'", + )); + } + + // Must have Sec-WebSocket-Key header + let sec_key = headers + .get("Sec-WebSocket-Key") + .and_then(|v| v.to_str().ok()) + .ok_or_else(|| WebSocketError::invalid_upgrade("Missing Sec-WebSocket-Key header"))?; + + // Must have Sec-WebSocket-Version: 13 + let version = headers + .get("Sec-WebSocket-Version") + .and_then(|v| v.to_str().ok()) + .ok_or_else(|| WebSocketError::invalid_upgrade("Missing Sec-WebSocket-Version header"))?; + + if version != "13" { + return Err(WebSocketError::invalid_upgrade( + "Sec-WebSocket-Version must be 13", + )); + } + + Ok(sec_key.to_string()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_accept_key_generation() { + // Example from RFC 6455 + let key = "dGhlIHNhbXBsZSBub25jZQ=="; + let accept = generate_accept_key(key); + assert_eq!(accept, "s3pPLMBiTxaQ9kYGzzhZRbK+xOo="); + } +} diff --git a/docs/ARCHITECTURE.md b/docs/ARCHITECTURE.md index 65d4a80..5c3b888 100644 --- a/docs/ARCHITECTURE.md +++ b/docs/ARCHITECTURE.md @@ -29,19 +29,26 @@ RustAPI uses a **layered facade architecture** where complexity is hidden behind │ HTTP Engine │ │ Proc Macros │ │ Swagger/OpenAPI │ └─────────────────┘ └─────────────────┘ └─────────────────┘ │ - ├─────────────────┬─────────────────┐ - ▼ ▼ ▼ -┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐ -│rustapi-validate │ │ rustapi-toon │ │ rustapi-extras │ -│ Validation │ │ LLM Format │ │ JWT/CORS/Rate │ -└─────────────────┘ └─────────────────┘ └─────────────────┘ - │ │ │ - └─────────────────┴─────────────────┘ + ├─────────────────┬─────────────────┬─────────────────┐ + ▼ ▼ ▼ ▼ +┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐ +│rustapi-validate │ │ rustapi-toon │ │ rustapi-extras │ │ rustapi-ws │ +│ Validation │ │ LLM Format │ │ JWT/CORS/Rate │ │ WebSocket │ +└─────────────────┘ └─────────────────┘ └─────────────────┘ └─────────────────┘ + │ │ │ │ + └─────────────────┴─────────────────┴─────────────────┤ + │ ▼ + │ ┌─────────────────┐ + │ │ rustapi-view │ + │ │ Template Engine │ + │ └─────────────────┘ + │ │ + └─────────────────────────────────┘ │ ▼ ┌─────────────────────────────────────────────────────────────────┐ │ Foundation Layer │ -│ tokio │ hyper │ serde │ matchit │ tower │ +│ tokio │ hyper │ serde │ matchit │ tower │ tungstenite │ tera │ └─────────────────────────────────────────────────────────────────┘ ``` @@ -81,6 +88,12 @@ pub mod prelude { #[cfg(feature = "jwt")] pub use rustapi_extras::jwt::*; + + #[cfg(feature = "ws")] + pub use rustapi_ws::{WebSocket, WebSocketUpgrade, WebSocketStream, Message, Broadcast}; + + #[cfg(feature = "view")] + pub use rustapi_view::{Templates, View, ContextBuilder}; } ``` @@ -205,6 +218,29 @@ Headers provided by `LlmResponse`: | Body Limit | default | Max request body size | | Request ID | default | Unique request tracking | +### `rustapi-ws` — WebSocket Support + +**Real-time bidirectional communication.** + +| Type | Purpose | +|------|---------| +| `WebSocket` | Extractor for WebSocket upgrades | +| `WebSocketUpgrade` | Response type for upgrade handshake | +| `WebSocketStream` | Async stream for send/recv | +| `Message` | Text, Binary, Ping, Pong, Close | +| `Broadcast` | Pub/sub channel for broadcasting | + +### `rustapi-view` — Template Engine + +**Server-side HTML rendering with Tera.** + +| Type | Purpose | +|------|---------| +| `Templates` | Template engine instance | +| `View` | Response type with template rendering | +| `ContextBuilder` | Build template context | +| `TemplatesConfig` | Configuration (directory, extension) | + --- ## Request Flow diff --git a/docs/FEATURES.md b/docs/FEATURES.md index 67355ba..41ebf87 100644 --- a/docs/FEATURES.md +++ b/docs/FEATURES.md @@ -12,9 +12,11 @@ 4. [OpenAPI & Swagger](#openapi--swagger) 5. [Middleware](#middleware) 6. [TOON Format](#toon-format) -7. [Testing](#testing) -8. [Error Handling](#error-handling) -9. [Configuration](#configuration) +7. [WebSocket](#websocket) +8. [Template Engine](#template-engine) +9. [Testing](#testing) +10. [Error Handling](#error-handling) +11. [Configuration](#configuration) --- @@ -681,6 +683,283 @@ users[(id:1,name:Alice,email:alice@example.com)(id:2,name:Bob,email:bob@example. --- +## WebSocket + +Real-time bidirectional communication support (requires `ws` feature). + +### Basic WebSocket Handler + +```rust +use rustapi_rs::ws::{WebSocket, WebSocketUpgrade, WebSocketStream, Message}; + +#[rustapi_rs::get("/ws")] +async fn websocket(ws: WebSocket) -> WebSocketUpgrade { + ws.on_upgrade(handle_connection) +} + +async fn handle_connection(mut stream: WebSocketStream) { + while let Some(msg) = stream.recv().await { + match msg { + Message::Text(text) => { + // Echo the message back + stream.send(Message::Text(format!("Echo: {}", text))).await.ok(); + } + Message::Binary(data) => { + // Handle binary data + stream.send(Message::Binary(data)).await.ok(); + } + Message::Ping(data) => { + stream.send(Message::Pong(data)).await.ok(); + } + Message::Close(_) => break, + _ => {} + } + } +} +``` + +### Message Types + +| Type | Description | +|------|-------------| +| `Message::Text(String)` | UTF-8 text message | +| `Message::Binary(Vec)` | Binary data | +| `Message::Ping(Vec)` | Ping frame (keepalive) | +| `Message::Pong(Vec)` | Pong response | +| `Message::Close(Option)` | Connection close | + +### Broadcast Channel + +For pub/sub patterns (chat rooms, live updates): + +```rust +use rustapi_rs::ws::{Broadcast, Message}; +use std::sync::Arc; + +#[tokio::main] +async fn main() { + let broadcast = Arc::new(Broadcast::new()); + + RustApi::new() + .state(broadcast) + .route("/ws", get(websocket)) + .route("/broadcast", post(send_broadcast)) + .run("0.0.0.0:8080") + .await +} + +#[rustapi_rs::get("/ws")] +async fn websocket( + ws: WebSocket, + State(broadcast): State>, +) -> WebSocketUpgrade { + let mut rx = broadcast.subscribe(); + ws.on_upgrade(move |mut stream| async move { + loop { + tokio::select! { + // Receive from client + msg = stream.recv() => { + match msg { + Some(Message::Close(_)) | None => break, + _ => {} + } + } + // Receive broadcasts + Ok(msg) = rx.recv() => { + if stream.send(msg).await.is_err() { + break; + } + } + } + } + }) +} + +#[rustapi_rs::post("/broadcast")] +async fn send_broadcast( + State(broadcast): State>, + body: String, +) -> &'static str { + broadcast.send(Message::Text(body)); + "Sent" +} +``` + +### WebSocket with State + +```rust +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; + +struct ConnectionCounter(AtomicUsize); + +#[rustapi_rs::get("/ws")] +async fn websocket( + ws: WebSocket, + State(counter): State>, +) -> WebSocketUpgrade { + ws.on_upgrade(move |stream| async move { + counter.0.fetch_add(1, Ordering::SeqCst); + handle_connection(stream).await; + counter.0.fetch_sub(1, Ordering::SeqCst); + }) +} +``` + +--- + +## Template Engine + +Server-side HTML rendering with Tera templates (requires `view` feature). + +### Setup + +```rust +use rustapi_rs::view::{Templates, TemplatesConfig}; + +#[tokio::main] +async fn main() { + let templates = Templates::new(TemplatesConfig { + directory: "templates".into(), + extension: "html".into(), + }).expect("Failed to load templates"); + + RustApi::new() + .state(templates) + .route("/", get(home)) + .run("0.0.0.0:8080") + .await +} +``` + +### Basic Template Rendering + +```rust +use rustapi_rs::view::{Templates, View}; + +#[rustapi_rs::get("/")] +async fn home(templates: Templates) -> View<()> { + View::new(&templates, "index.html", ()) +} + +#[derive(Serialize)] +struct UserData { + name: String, + email: String, +} + +#[rustapi_rs::get("/user/{id}")] +async fn user_page( + templates: Templates, + Path(id): Path, +) -> View { + let user = UserData { + name: "Alice".into(), + email: "alice@example.com".into(), + }; + View::new(&templates, "user.html", user) +} +``` + +### Template with Extra Context + +```rust +use rustapi_rs::view::{Templates, View, ContextBuilder}; + +#[rustapi_rs::get("/dashboard")] +async fn dashboard(templates: Templates) -> View { + let data = get_dashboard_data(); + + View::with_context(&templates, "dashboard.html", data, |ctx| { + ctx.insert("title", &"Dashboard"); + ctx.insert("year", &2024); + ctx.insert("nav_items", &vec!["Home", "Users", "Settings"]); + }) +} +``` + +### Tera Template Syntax + +**templates/base.html:** +```html + + + + {% block title %}My App{% endblock %} + + + +
{% block content %}{% endblock %}
+ + +``` + +**templates/user.html:** +```html +{% extends "base.html" %} + +{% block title %}{{ name }} - My App{% endblock %} + +{% block content %} + +{% endblock %} +``` + +### Template Features + +| Feature | Syntax | Description | +|---------|--------|-------------| +| Variables | `{{ name }}` | Output variable | +| Filters | `{{ name \| upper }}` | Transform values | +| Conditionals | `{% if x %}...{% endif %}` | Conditional rendering | +| Loops | `{% for x in items %}` | Iteration | +| Inheritance | `{% extends "base.html" %}` | Template inheritance | +| Blocks | `{% block name %}` | Overridable sections | +| Includes | `{% include "partial.html" %}` | Include templates | +| Macros | `{% macro name() %}` | Reusable snippets | + +### Built-in Filters + +| Filter | Example | Description | +|--------|---------|-------------| +| `upper` | `{{ name \| upper }}` | UPPERCASE | +| `lower` | `{{ name \| lower }}` | lowercase | +| `capitalize` | `{{ name \| capitalize }}` | Capitalize | +| `trim` | `{{ text \| trim }}` | Remove whitespace | +| `length` | `{{ items \| length }}` | Array/string length | +| `default` | `{{ x \| default(value="N/A") }}` | Default value | +| `date` | `{{ dt \| date(format="%Y-%m-%d") }}` | Date formatting | +| `json_encode` | `{{ obj \| json_encode }}` | JSON string | + +### Error Handling + +```rust +#[rustapi_rs::get("/user/{id}")] +async fn user_page( + templates: Templates, + Path(id): Path, +) -> Result> { + let user = find_user(id) + .ok_or_else(|| ApiError::not_found("User not found"))?; + + Ok(View::new(&templates, "user.html", user)) +} +``` + +--- + ## Testing ### TestClient @@ -854,6 +1133,8 @@ rustapi-rs = { version = "0.1.4", features = ["full"] } | `rate-limit` | Rate limiting | | `toon` | TOON format | | `cookies` | Cookie extraction | +| `ws` | WebSocket support | +| `view` | Template engine (Tera) | | `full` | All features | --- diff --git a/docs/GETTING_STARTED.md b/docs/GETTING_STARTED.md index 080481a..773c31a 100644 --- a/docs/GETTING_STARTED.md +++ b/docs/GETTING_STARTED.md @@ -29,7 +29,7 @@ Or with specific features: ```toml [dependencies] -rustapi-rs = { version = "0.1.4", features = ["jwt", "cors", "toon"] } +rustapi-rs = { version = "0.1.4", features = ["jwt", "cors", "toon", "ws", "view"] } ``` ### Available Features @@ -41,6 +41,8 @@ rustapi-rs = { version = "0.1.4", features = ["jwt", "cors", "toon"] } | `cors` | CORS middleware | | `rate-limit` | IP-based rate limiting | | `toon` | LLM-optimized TOON format | +| `ws` | WebSocket support | +| `view` | Template engine (Tera) | | `full` | All features | --- @@ -449,6 +451,120 @@ Response includes token counting headers: --- +## WebSocket Support + +Real-time bidirectional communication: + +```toml +rustapi-rs = { version = "0.1.4", features = ["ws"] } +``` + +```rust +use rustapi_rs::ws::{WebSocket, WebSocketUpgrade, WebSocketStream, Message}; + +#[rustapi_rs::get("/ws")] +async fn websocket(ws: WebSocket) -> WebSocketUpgrade { + ws.on_upgrade(handle_connection) +} + +async fn handle_connection(mut stream: WebSocketStream) { + while let Some(msg) = stream.recv().await { + match msg { + Message::Text(text) => { + // Echo the message back + stream.send(Message::Text(format!("Echo: {}", text))).await.ok(); + } + Message::Close(_) => break, + _ => {} + } + } +} +``` + +Test with `websocat`: +```bash +websocat ws://localhost:8080/ws +``` + +--- + +## Template Engine + +Server-side HTML rendering with Tera: + +```toml +rustapi-rs = { version = "0.1.4", features = ["view"] } +``` + +Create a template file `templates/index.html`: +```html + + +{{ title }} + +

Hello, {{ name }}!

+ + +``` + +Use in your handler: +```rust +use rustapi_rs::view::{Templates, View, TemplatesConfig}; + +#[tokio::main] +async fn main() { + let templates = Templates::new(TemplatesConfig { + directory: "templates".into(), + extension: "html".into(), + }).unwrap(); + + RustApi::new() + .state(templates) + .route("/", get(home)) + .run("0.0.0.0:8080") + .await +} + +#[derive(Serialize)] +struct HomeData { + title: String, + name: String, +} + +#[rustapi_rs::get("/")] +async fn home(templates: Templates) -> View { + View::new(&templates, "index.html", HomeData { + title: "Welcome".into(), + name: "World".into(), + }) +} +``` + +--- + +## CLI Tool + +Scaffold new RustAPI projects quickly: + +```bash +# Install the CLI +cargo install cargo-rustapi + +# Create a new project +cargo rustapi new my-api + +# Interactive mode with template selection +cargo rustapi new my-api --interactive +``` + +Available templates: +- `minimal` — Basic RustAPI setup +- `api` — REST API with CRUD operations +- `web` — Full web app with templates and WebSocket +- `full` — Everything included + +--- + ## Testing ```rust diff --git a/examples/mcp-server/src/main.rs b/examples/mcp-server/src/main.rs index 879b7c3..7793d64 100644 --- a/examples/mcp-server/src/main.rs +++ b/examples/mcp-server/src/main.rs @@ -207,7 +207,10 @@ async fn list_tools(accept: AcceptHeader) -> LlmResponse { PropertySchema { prop_type: "string".to_string(), description: "Temperature units".to_string(), - enum_values: Some(vec!["celsius".to_string(), "fahrenheit".to_string()]), + enum_values: Some(vec![ + "celsius".to_string(), + "fahrenheit".to_string(), + ]), }, ), ]), @@ -226,13 +229,19 @@ async fn list_tools(accept: AcceptHeader) -> LlmResponse { async fn execute_tool(Json(request): Json) -> Toon { match request.tool.as_str() { "calculate" => { - let operation = request.arguments.get("operation") + let operation = request + .arguments + .get("operation") .map(|v| v.as_str()) .unwrap_or("add"); - let a = request.arguments.get("a") + let a = request + .arguments + .get("a") .and_then(|v| v.parse::().ok()) .unwrap_or(0.0); - let b = request.arguments.get("b") + let b = request + .arguments + .get("b") .and_then(|v| v.parse::().ok()) .unwrap_or(0.0); @@ -266,10 +275,14 @@ async fn execute_tool(Json(request): Json) -> Toon { - let location = request.arguments.get("location") + let location = request + .arguments + .get("location") .map(|v| v.as_str()) .unwrap_or("Unknown"); - let units = request.arguments.get("units") + let units = request + .arguments + .get("units") .map(|v| v.as_str()) .unwrap_or("celsius"); @@ -333,5 +346,5 @@ async fn main() { println!("📦 Resources: http://localhost:8080/mcp/resources"); println!("\n💡 Tip: Use 'Accept: application/toon' header for LLM-optimized responses\n"); - RustApi::auto().run("127.0.0.1:8080").await; + let _ = RustApi::auto().run("127.0.0.1:8080").await; } diff --git a/examples/proof-of-concept/src/models.rs b/examples/proof-of-concept/src/models.rs index 358bc85..cd26b11 100644 --- a/examples/proof-of-concept/src/models.rs +++ b/examples/proof-of-concept/src/models.rs @@ -268,12 +268,14 @@ pub struct ImportResponse { // ============================================ /// Standard error response format +#[allow(dead_code)] #[derive(Debug, Serialize, Schema)] pub struct ErrorResponse { pub error: ErrorDetail, } /// Error detail structure +#[allow(dead_code)] #[derive(Debug, Serialize, Schema)] pub struct ErrorDetail { #[serde(rename = "type")] @@ -284,6 +286,7 @@ pub struct ErrorDetail { } /// Field-level validation error +#[allow(dead_code)] #[derive(Debug, Serialize, Schema)] pub struct FieldError { pub field: String, diff --git a/examples/templates/Cargo.toml b/examples/templates/Cargo.toml new file mode 100644 index 0000000..599d821 --- /dev/null +++ b/examples/templates/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "templates-example" +version = "0.1.0" +edition = "2021" +publish = false + +[dependencies] +rustapi-rs = { path = "../../crates/rustapi-rs", features = ["view"] } +utoipa = { workspace = true } +tokio = { workspace = true } +serde = { workspace = true } +tracing = { workspace = true } +tracing-subscriber = { workspace = true } diff --git a/examples/templates/src/main.rs b/examples/templates/src/main.rs new file mode 100644 index 0000000..38634d8 --- /dev/null +++ b/examples/templates/src/main.rs @@ -0,0 +1,224 @@ +//! Template Rendering Example +//! +//! This example demonstrates Tera template support in RustAPI: +//! - Server-side HTML rendering +//! - Template inheritance (layouts) +//! - Context building +//! - Static file serving +//! +//! Run with: cargo run --package templates-example + +use rustapi_rs::prelude::*; +use rustapi_rs::view::{ContextBuilder, Templates, View}; + +/// Contact form params +#[derive(Debug, Clone, Deserialize, IntoParams)] +struct ContactForm { + name: Option, + message: Option, +} + +/// Home page context +#[derive(Serialize)] +struct HomeContext { + title: String, + features: Vec, +} + +#[derive(Serialize)] +struct Feature { + name: String, + description: String, +} + +/// About page context +#[derive(Serialize)] +struct AboutContext { + title: String, + version: String, + rust_version: String, +} + +/// Contact form context +#[derive(Serialize)] +struct ContactContext { + title: String, + submitted: bool, + name: Option, + message: Option, +} + +/// Blog post context +#[derive(Serialize)] +struct BlogContext { + title: String, + posts: Vec, +} + +#[derive(Serialize)] +struct BlogPost { + id: u32, + title: String, + excerpt: String, + author: String, + date: String, +} + +/// Home page handler +async fn home(State(templates): State) -> View { + let features = vec![ + Feature { + name: "Type-Safe".to_string(), + description: "Compile-time route and schema validation".to_string(), + }, + Feature { + name: "Fast".to_string(), + description: "Built on Tokio and Hyper for maximum performance".to_string(), + }, + Feature { + name: "Easy".to_string(), + description: "Minimal boilerplate, intuitive API".to_string(), + }, + Feature { + name: "Documented".to_string(), + description: "Auto-generated OpenAPI + Swagger UI".to_string(), + }, + ]; + + View::render( + &templates, + "index.html", + HomeContext { + title: "Home".to_string(), + features, + }, + ) + .await +} + +/// About page handler +async fn about(State(templates): State) -> View { + View::render( + &templates, + "about.html", + AboutContext { + title: "About".to_string(), + version: env!("CARGO_PKG_VERSION").to_string(), + rust_version: "1.75+".to_string(), + }, + ) + .await +} + +/// Contact page handler (GET) +async fn contact_get(State(templates): State) -> View { + View::render( + &templates, + "contact.html", + ContactContext { + title: "Contact".to_string(), + submitted: false, + name: None, + message: None, + }, + ) + .await +} + +/// Contact form submission (POST) +async fn contact_post( + State(templates): State, + Query(params): Query, +) -> View { + tracing::info!("Contact form submitted: {:?}", params); + + View::render( + &templates, + "contact.html", + ContactContext { + title: "Contact".to_string(), + submitted: true, + name: params.name, + message: params.message, + }, + ) + .await +} + +/// Blog listing page +async fn blog(State(templates): State) -> View { + let posts = vec![ + BlogPost { + id: 1, + title: "Getting Started with RustAPI".to_string(), + excerpt: "Learn how to build your first API with RustAPI...".to_string(), + author: "RustAPI Team".to_string(), + date: "2026-01-05".to_string(), + }, + BlogPost { + id: 2, + title: "WebSocket Support in RustAPI".to_string(), + excerpt: "Real-time communication made easy...".to_string(), + author: "RustAPI Team".to_string(), + date: "2026-01-04".to_string(), + }, + BlogPost { + id: 3, + title: "Template Rendering with Tera".to_string(), + excerpt: "Server-side rendering for your web apps...".to_string(), + author: "RustAPI Team".to_string(), + date: "2026-01-03".to_string(), + }, + ]; + + View::render( + &templates, + "blog.html", + BlogContext { + title: "Blog".to_string(), + posts, + }, + ) + .await +} + +/// Dynamic context example using ContextBuilder +async fn dynamic(State(templates): State) -> View<()> { + let context = ContextBuilder::new() + .insert("title", &"Dynamic Page") + .insert("items", &vec!["One", "Two", "Three"]) + .insert("count", &3) + .insert_if("show_banner", &true, |_| true) + .build(); + + View::render_context(&templates, "dynamic.html", &context).await +} + +#[rustapi_rs::main] +async fn main() -> std::result::Result<(), Box> { + tracing_subscriber::fmt() + .with_env_filter( + tracing_subscriber::EnvFilter::from_default_env() + .add_directive("templates_example=debug".parse().unwrap()) + .add_directive("info".parse().unwrap()), + ) + .init(); + + // Initialize templates from the templates directory + let templates = Templates::new("examples/templates/templates/**/*.html")?; + + let addr = "127.0.0.1:8080"; + tracing::info!("🚀 Server running at http://{}", addr); + + RustApi::new() + .state(templates) + .route("/", get(home)) + .route("/about", get(about)) + .route("/contact", get(contact_get)) + .route("/contact", post(contact_post)) + .route("/blog", get(blog)) + .route("/dynamic", get(dynamic)) + .serve_static("/static", "examples/templates/static") + .run(addr) + .await +} diff --git a/examples/templates/static/style.css b/examples/templates/static/style.css new file mode 100644 index 0000000..201dae1 --- /dev/null +++ b/examples/templates/static/style.css @@ -0,0 +1,258 @@ +* { + box-sizing: border-box; + margin: 0; + padding: 0; +} + +:root { + --primary: #f74c00; + --primary-dark: #d14000; + --text: #333; + --text-light: #666; + --bg: #fff; + --bg-alt: #f5f5f5; + --border: #ddd; +} + +body { + font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, sans-serif; + line-height: 1.6; + color: var(--text); + background: var(--bg); +} + +/* Header */ +header { + background: var(--bg); + border-bottom: 1px solid var(--border); + padding: 1rem 2rem; +} + +nav { + max-width: 1200px; + margin: 0 auto; + display: flex; + justify-content: space-between; + align-items: center; +} + +.logo { + font-size: 1.5rem; + font-weight: bold; + color: var(--primary); + text-decoration: none; +} + +.nav-links a { + margin-left: 1.5rem; + color: var(--text); + text-decoration: none; +} + +.nav-links a:hover { + color: var(--primary); +} + +/* Main content */ +main { + max-width: 1200px; + margin: 0 auto; + padding: 2rem; + min-height: calc(100vh - 200px); +} + +h1 { + margin-bottom: 1.5rem; + color: var(--text); +} + +h2 { + margin-top: 2rem; + margin-bottom: 1rem; +} + +p { + margin-bottom: 1rem; +} + +/* Hero section */ +.hero { + text-align: center; + padding: 4rem 2rem; + background: linear-gradient(135deg, var(--primary), var(--primary-dark)); + color: white; + border-radius: 8px; + margin-bottom: 3rem; +} + +.hero h1 { + font-size: 2.5rem; + color: white; + margin-bottom: 0.5rem; +} + +.hero p { + font-size: 1.2rem; + opacity: 0.9; +} + +/* Features */ +.feature-grid { + display: grid; + grid-template-columns: repeat(auto-fit, minmax(250px, 1fr)); + gap: 1.5rem; +} + +.feature-card { + background: var(--bg-alt); + padding: 1.5rem; + border-radius: 8px; + border: 1px solid var(--border); +} + +.feature-card h3 { + color: var(--primary); + margin-bottom: 0.5rem; +} + +/* CTA */ +.cta { + text-align: center; + padding: 2rem; + margin-top: 3rem; +} + +.cta pre { + display: inline-block; + background: var(--text); + color: white; + padding: 1rem 2rem; + border-radius: 4px; + font-size: 1.1rem; +} + +/* Blog */ +.blog-list { + display: flex; + flex-direction: column; + gap: 2rem; +} + +.blog-post { + background: var(--bg-alt); + padding: 1.5rem; + border-radius: 8px; + border: 1px solid var(--border); +} + +.blog-post h2 { + margin-top: 0; + margin-bottom: 0.5rem; +} + +.post-meta { + color: var(--text-light); + font-size: 0.9rem; + margin-bottom: 1rem; +} + +.post-meta .author::after { + content: " • "; +} + +.read-more { + color: var(--primary); + text-decoration: none; + font-weight: 500; +} + +.read-more:hover { + text-decoration: underline; +} + +/* Contact form */ +.contact-form { + max-width: 500px; +} + +.form-group { + margin-bottom: 1.5rem; +} + +.form-group label { + display: block; + margin-bottom: 0.5rem; + font-weight: 500; +} + +.form-group input, +.form-group textarea { + width: 100%; + padding: 0.75rem; + border: 1px solid var(--border); + border-radius: 4px; + font-size: 1rem; +} + +.form-group input:focus, +.form-group textarea:focus { + outline: none; + border-color: var(--primary); +} + +button[type="submit"] { + background: var(--primary); + color: white; + border: none; + padding: 0.75rem 1.5rem; + border-radius: 4px; + font-size: 1rem; + cursor: pointer; +} + +button[type="submit"]:hover { + background: var(--primary-dark); +} + +.success-message { + background: #e8f5e9; + border: 1px solid #4caf50; + padding: 2rem; + border-radius: 8px; +} + +.success-message h2 { + margin-top: 0; + color: #2e7d32; +} + +.success-message blockquote { + background: white; + padding: 1rem; + margin: 1rem 0; + border-left: 4px solid #4caf50; +} + +/* Banner */ +.banner { + background: var(--primary); + color: white; + padding: 1rem; + border-radius: 4px; + margin-bottom: 1.5rem; + text-align: center; +} + +/* Footer */ +footer { + text-align: center; + padding: 2rem; + border-top: 1px solid var(--border); + color: var(--text-light); +} + +/* Utilities */ +ul { + margin-left: 2rem; + margin-bottom: 1rem; +} diff --git a/examples/templates/templates/about.html b/examples/templates/templates/about.html new file mode 100644 index 0000000..d0167fd --- /dev/null +++ b/examples/templates/templates/about.html @@ -0,0 +1,22 @@ +{% extends "base.html" %} + +{% block content %} +

About RustAPI

+ +

RustAPI is a modern web framework for Rust, inspired by Python's FastAPI.

+ +

Version Information

+
    +
  • Version: {{ version }}
  • +
  • Rust: {{ rust_version }}
  • +
+ +

Philosophy

+

RustAPI believes that web development should be:

+
    +
  • Fast - Both in development and runtime
  • +
  • Type-safe - Catch errors at compile time
  • +
  • Ergonomic - Write business logic, not boilerplate
  • +
  • Well-documented - Auto-generated API docs
  • +
+{% endblock %} diff --git a/examples/templates/templates/base.html b/examples/templates/templates/base.html new file mode 100644 index 0000000..2d1b2bc --- /dev/null +++ b/examples/templates/templates/base.html @@ -0,0 +1,34 @@ + + + + + + {% block title %}{{ title }}{% endblock %} - RustAPI + + {% block head %}{% endblock %} + + +
+ +
+ +
+ {% block content %}{% endblock %} +
+ +
+

© 2026 RustAPI. Built with ❤️ and Rust.

+
+ + {% block scripts %}{% endblock %} + + diff --git a/examples/templates/templates/blog.html b/examples/templates/templates/blog.html new file mode 100644 index 0000000..20f747f --- /dev/null +++ b/examples/templates/templates/blog.html @@ -0,0 +1,23 @@ +{% extends "base.html" %} + +{% block content %} +

Blog

+ +
+ {% for post in posts %} +
+

{{ post.title }}

+ +

{{ post.excerpt }}

+ Read more → +
+ {% endfor %} +
+ +{% if posts | length == 0 %} +

No blog posts yet.

+{% endif %} +{% endblock %} diff --git a/examples/templates/templates/contact.html b/examples/templates/templates/contact.html new file mode 100644 index 0000000..8917502 --- /dev/null +++ b/examples/templates/templates/contact.html @@ -0,0 +1,35 @@ +{% extends "base.html" %} + +{% block content %} +

Contact Us

+ +{% if submitted %} +
+

Thank you{% if name %}, {{ name }}{% endif %}!

+

Your message has been received.

+ {% if message %} +
{{ message }}
+ {% endif %} + Send another message +
+{% else %} +
+
+ + +
+ +
+ + +
+ +
+ + +
+ + +
+{% endif %} +{% endblock %} diff --git a/examples/templates/templates/dynamic.html b/examples/templates/templates/dynamic.html new file mode 100644 index 0000000..bafd25c --- /dev/null +++ b/examples/templates/templates/dynamic.html @@ -0,0 +1,20 @@ +{% extends "base.html" %} + +{% block content %} +

Dynamic Content Example

+ +

This page demonstrates dynamic context building with ContextBuilder.

+ +{% if show_banner %} + +{% endif %} + +

Items ({{ count }} total)

+
    + {% for item in items %} +
  • {{ item }}
  • + {% endfor %} +
+{% endblock %} diff --git a/examples/templates/templates/index.html b/examples/templates/templates/index.html new file mode 100644 index 0000000..c43f9c0 --- /dev/null +++ b/examples/templates/templates/index.html @@ -0,0 +1,25 @@ +{% extends "base.html" %} + +{% block content %} +
+

Welcome to RustAPI

+

A FastAPI-like web framework for Rust

+
+ +
+

Features

+
+ {% for feature in features %} +
+

{{ feature.name }}

+

{{ feature.description }}

+
+ {% endfor %} +
+
+ +
+

Get Started

+
cargo add rustapi-rs
+
+{% endblock %} diff --git a/examples/websocket/Cargo.toml b/examples/websocket/Cargo.toml new file mode 100644 index 0000000..90c1da8 --- /dev/null +++ b/examples/websocket/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "websocket-example" +version = "0.1.0" +edition = "2021" +publish = false + +[dependencies] +rustapi-rs = { path = "../../crates/rustapi-rs", features = ["ws"] } +tokio = { workspace = true } +serde = { workspace = true } +serde_json = { workspace = true } +tracing = { workspace = true } +tracing-subscriber = { workspace = true } +futures-util = { workspace = true } diff --git a/examples/websocket/src/main.rs b/examples/websocket/src/main.rs new file mode 100644 index 0000000..b6a6122 --- /dev/null +++ b/examples/websocket/src/main.rs @@ -0,0 +1,338 @@ +//! WebSocket Example +//! +//! This example demonstrates WebSocket support in RustAPI: +//! - Basic echo server +//! - JSON message handling +//! - Broadcast to multiple clients +//! +//! Run with: cargo run --package websocket-example +//! Test with a WebSocket client (e.g., websocat): +//! websocat ws://localhost:8080/ws/echo +//! websocat ws://localhost:8080/ws/chat + +use rustapi_rs::prelude::*; +use rustapi_rs::ws::{Broadcast, Message, WebSocket, WebSocketUpgrade}; +use std::sync::Arc; + +/// Chat message for JSON serialization +#[derive(Debug, Clone, Serialize, Deserialize)] +struct ChatMessage { + username: String, + content: String, + timestamp: u64, +} + +/// Application state containing the broadcast channel +struct AppState { + chat_broadcast: Arc, +} + +/// Simple echo WebSocket endpoint +async fn ws_echo(ws: WebSocket) -> WebSocketUpgrade { + ws.on_upgrade(|mut socket| async move { + tracing::info!("New echo connection"); + + while let Some(result) = socket.recv().await { + match result { + Ok(Message::Text(text)) => { + tracing::debug!("Received: {}", text); + if let Err(e) = socket.send(Message::text(format!("Echo: {}", text))).await { + tracing::error!("Send error: {}", e); + break; + } + } + Ok(Message::Binary(data)) => { + if let Err(e) = socket.send(Message::binary(data)).await { + tracing::error!("Send error: {}", e); + break; + } + } + Ok(Message::Ping(data)) => { + let _ = socket.send(Message::pong(data)).await; + } + Ok(Message::Close(_)) => { + tracing::info!("Client disconnected"); + break; + } + Ok(_) => {} + Err(e) => { + tracing::error!("Receive error: {}", e); + break; + } + } + } + }) +} + +/// JSON echo WebSocket endpoint +async fn ws_json(ws: WebSocket) -> WebSocketUpgrade { + ws.on_upgrade(|mut socket| async move { + tracing::info!("New JSON connection"); + + while let Some(result) = socket.recv().await { + match result { + Ok(msg) => { + if msg.is_text() { + // Try to parse as ChatMessage + match msg.as_json::() { + Ok(chat_msg) => { + tracing::info!( + "Message from {}: {}", + chat_msg.username, + chat_msg.content + ); + + // Echo back with modified content + let response = ChatMessage { + username: "server".to_string(), + content: format!("Received: {}", chat_msg.content), + timestamp: std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs(), + }; + + if let Err(e) = socket.send_json(&response).await { + tracing::error!("Send error: {}", e); + break; + } + } + Err(e) => { + tracing::warn!("Invalid JSON: {}", e); + } + } + } + } + Err(e) => { + tracing::error!("Receive error: {}", e); + break; + } + } + } + }) +} + +/// Chat room WebSocket endpoint with broadcasting +async fn ws_chat(ws: WebSocket, State(state): State>) -> WebSocketUpgrade { + ws.on_upgrade(move |socket| async move { + let (mut sender, mut receiver) = socket.split(); + let broadcast = state.chat_broadcast.clone(); + + // Subscribe to broadcast messages + let mut broadcast_rx = broadcast.subscribe(); + + tracing::info!( + "New chat connection (total: {})", + broadcast.subscriber_count() + ); + + // Announce new user + let _ = broadcast.send_json(&ChatMessage { + username: "system".to_string(), + content: "A new user has joined".to_string(), + timestamp: now(), + }); + + // Spawn task to forward broadcasts to this client + let send_task = tokio::spawn(async move { + while let Some(result) = broadcast_rx.recv().await { + match result { + Ok(msg) => { + if let Err(e) = sender.send(msg).await { + tracing::debug!("Send error: {}", e); + break; + } + } + Err(e) => { + tracing::debug!("Broadcast error: {}", e); + } + } + } + }); + + // Handle incoming messages + while let Some(result) = receiver.recv().await { + match result { + Ok(msg) => { + if let Some(text) = msg.as_text() { + // Broadcast to all clients + if let Ok(chat_msg) = serde_json::from_str::(text) { + broadcast.send(Message::text(text.to_string())); + tracing::info!("[{}] {}", chat_msg.username, chat_msg.content); + } + } + } + Err(e) => { + tracing::debug!("Receive error: {}", e); + break; + } + } + } + + // Clean up + send_task.abort(); + + // Announce user left + let _ = broadcast.send_json(&ChatMessage { + username: "system".to_string(), + content: "A user has left".to_string(), + timestamp: now(), + }); + + tracing::info!( + "Chat connection closed (remaining: {})", + broadcast.subscriber_count() + ); + }) +} + +fn now() -> u64 { + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs() +} + +/// Index page with WebSocket test client +async fn index() -> Html<&'static str> { + Html( + r#" + + + WebSocket Example + + + +

🔌 WebSocket Example

+ +
+

Echo Test (/ws/echo)

+ + + + +
+
+ +
+

Chat Room (/ws/chat)

+ + + + + +
+
+ + + +"#, + ) +} + +#[rustapi_rs::main] +async fn main() -> std::result::Result<(), Box> { + tracing_subscriber::fmt() + .with_env_filter( + tracing_subscriber::EnvFilter::from_default_env() + .add_directive("websocket_example=debug".parse().unwrap()) + .add_directive("info".parse().unwrap()), + ) + .init(); + + let state = Arc::new(AppState { + chat_broadcast: Arc::new(Broadcast::new()), + }); + + let addr = "127.0.0.1:8080"; + tracing::info!("🚀 Server running at http://{}", addr); + tracing::info!("📡 WebSocket endpoints:"); + tracing::info!(" ws://{}/ws/echo - Echo server", addr); + tracing::info!(" ws://{}/ws/json - JSON echo", addr); + tracing::info!(" ws://{}/ws/chat - Chat room", addr); + + RustApi::new() + .state(state) + .route("/", get(index)) + .route("/ws/echo", get(ws_echo)) + .route("/ws/json", get(ws_json)) + .route("/ws/chat", get(ws_chat)) + .run(addr) + .await +}