diff --git a/crates/rohas-codegen/src/generator.rs b/crates/rohas-codegen/src/generator.rs index f2954e5..2116842 100644 --- a/crates/rohas-codegen/src/generator.rs +++ b/crates/rohas-codegen/src/generator.rs @@ -48,6 +48,7 @@ impl Generator { "handlers/events", "handlers/cron", "handlers/websockets", + "middlewares", ]; for dir in &dirs { @@ -77,6 +78,7 @@ impl Generator { typescript::generate_events(schema, output_dir)?; typescript::generate_crons(schema, output_dir)?; typescript::generate_websockets(schema, output_dir)?; + typescript::generate_middlewares(schema, output_dir)?; typescript::generate_index(schema, output_dir)?; info!("Generating TypeScript configuration files"); @@ -98,6 +100,7 @@ impl Generator { python::generate_events(schema, output_dir)?; python::generate_crons(schema, output_dir)?; python::generate_websockets(schema, output_dir)?; + python::generate_middlewares(schema, output_dir)?; python::generate_init(schema, output_dir)?; info!("Generating Python configuration files"); diff --git a/crates/rohas-codegen/src/python.rs b/crates/rohas-codegen/src/python.rs index f0c849c..f0fc686 100644 --- a/crates/rohas-codegen/src/python.rs +++ b/crates/rohas-codegen/src/python.rs @@ -344,6 +344,88 @@ pub fn generate_websockets(schema: &Schema, output_dir: &Path) -> Result<()> { Ok(()) } +pub fn generate_middlewares(schema: &Schema, output_dir: &Path) -> Result<()> { + use std::collections::HashSet; + + let mut middleware_names = HashSet::new(); + + for api in &schema.apis { + for middleware in &api.middlewares { + middleware_names.insert(middleware.clone()); + } + } + + for ws in &schema.websockets { + for middleware in &ws.middlewares { + middleware_names.insert(middleware.clone()); + } + } + + if middleware_names.is_empty() { + return Ok(()); + } + + let middlewares_dir = output_dir.join("middlewares"); + for middleware_name in middleware_names { + let file_name = format!("{}.py", templates::to_snake_case(&middleware_name)); + let middleware_path = middlewares_dir.join(&file_name); + + if !middleware_path.exists() { + let content = generate_middleware_stub(&middleware_name); + fs::write(middleware_path, content)?; + } + } + + Ok(()) +} + +fn generate_middleware_stub(middleware_name: &str) -> String { + let mut content = String::new(); + + content.push_str("from typing import Dict, Any, Optional\n"); + content.push_str("from generated.state import State\n\n"); + + content.push_str(&format!( + "async def {}_middleware(context: Dict[str, Any], state: State) -> Optional[Dict[str, Any]]:\n", + templates::to_snake_case(middleware_name) + )); + content.push_str(" \"\"\"\n"); + content.push_str(&format!(" Middleware function for {}.\n\n", middleware_name)); + content.push_str(" Args:\n"); + content.push_str(" context: Request context containing:\n"); + content.push_str(" - payload: Request payload (for APIs)\n"); + content.push_str(" - query_params: Query parameters (for APIs)\n"); + content.push_str(" - connection: WebSocket connection info (for WebSockets)\n"); + content.push_str(" - websocket_name: WebSocket name (for WebSockets)\n"); + content.push_str(" - api_name: API name (for APIs)\n"); + content.push_str(" - trace_id: Trace ID\n"); + content.push_str(" state: State object for logging and triggering events\n\n"); + content.push_str(" Returns:\n"); + content.push_str(" Optional[Dict[str, Any]]: Modified context with 'payload' and/or 'query_params' keys,\n"); + content.push_str(" or None to pass through unchanged. Return a dict with 'error' key to reject the request.\n\n"); + content.push_str(" To reject the request, raise an exception \n"); + content.push_str(" \"\"\"\n"); + content.push_str(" # TODO: Implement middleware logic\n"); + content.push_str(" # Example: Validate authentication\n"); + content.push_str(" # Example: Rate limiting\n"); + content.push_str(" # Example: Logging\n"); + content.push_str(" # Example: Modify payload/query_params\n"); + content.push_str(" # \n"); + content.push_str(" # To modify the request:\n"); + content.push_str(" # return {\n"); + content.push_str(" # 'payload': modified_payload,\n"); + content.push_str(" # 'query_params': modified_query_params\n"); + content.push_str(" # }\n"); + content.push_str(" # \n"); + content.push_str(" # To reject the request:\n"); + content.push_str(" # raise Exception('Access denied')\n"); + content.push_str(" \n"); + content.push_str(" # Pass through unchanged\n"); + content.push_str(" return None\n"); + + content +} + fn generate_websocket_content(ws: &WebSocket) -> String { let mut content = String::new(); diff --git a/crates/rohas-codegen/src/typescript.rs b/crates/rohas-codegen/src/typescript.rs index f04f475..9b3f7e0 100644 --- a/crates/rohas-codegen/src/typescript.rs +++ b/crates/rohas-codegen/src/typescript.rs @@ -504,6 +504,99 @@ pub fn generate_websockets(schema: &Schema, output_dir: &Path) -> Result<()> { Ok(()) } +pub fn generate_middlewares(schema: &Schema, output_dir: &Path) -> Result<()> { + use std::collections::HashSet; + + let mut middleware_names = HashSet::new(); + + for api in &schema.apis { + for middleware in &api.middlewares { + middleware_names.insert(middleware.clone()); + } + } + + for ws in &schema.websockets { + for middleware in &ws.middlewares { + middleware_names.insert(middleware.clone()); + } + } + + if middleware_names.is_empty() { + return Ok(()); + } + + let middlewares_dir = output_dir.join("middlewares"); + for middleware_name in middleware_names { + let file_name = format!("{}.ts", middleware_name); + let middleware_path = middlewares_dir.join(&file_name); + + if !middleware_path.exists() { + let content = generate_middleware_stub(&middleware_name); + fs::write(middleware_path, content)?; + } + } + + Ok(()) +} + +fn generate_middleware_stub(middleware_name: &str) -> String { + let mut content = String::new(); + + content.push_str("import { State } from '@generated/state';\n\n"); + + content.push_str("export interface MiddlewareContext {\n"); + content.push_str(" payload?: any;\n"); + content.push_str(" query_params?: Record;\n"); + content.push_str(" connection?: any;\n"); + content.push_str(" websocket_name?: string;\n"); + content.push_str(" api_name?: string;\n"); + content.push_str(" trace_id?: string;\n"); + content.push_str("}\n\n"); + + content.push_str(&format!( + "export async function {}Middleware(\n", + middleware_name + )); + content.push_str(" context: MiddlewareContext,\n"); + content.push_str(" state: State\n"); + content.push_str("): Promise {\n"); + content.push_str(" /**\n"); + content.push_str(&format!(" * Middleware function for {}.\n", middleware_name)); + content.push_str(" * \n"); + content.push_str(" * @param context - Request context containing:\n"); + content.push_str(" * - payload: Request payload (for APIs)\n"); + content.push_str(" * - query_params: Query parameters (for APIs)\n"); + content.push_str(" * - connection: WebSocket connection info (for WebSockets)\n"); + content.push_str(" * - websocket_name: WebSocket name (for WebSockets)\n"); + content.push_str(" * - api_name: API name (for APIs)\n"); + content.push_str(" * - trace_id: Trace ID\n"); + content.push_str(" * @param state - State object for logging and triggering events\n"); + content.push_str(" * @returns Modified context with 'payload' and/or 'query_params' keys,\n"); + content.push_str(" * or null to pass through unchanged. Throw an error to reject the request.\n"); + content.push_str(" * \n"); + content.push_str(" * To reject the request, throw an error:\n"); + content.push_str(" * throw new Error('Access denied');\n"); + content.push_str(" * \n"); + content.push_str(" * To modify the request:\n"); + content.push_str(" * return {\n"); + content.push_str(" * ...context,\n"); + content.push_str(" * payload: modifiedPayload,\n"); + content.push_str(" * query_params: modifiedQueryParams\n"); + content.push_str(" * };\n"); + content.push_str(" */\n"); + content.push_str(" // TODO: Implement middleware logic\n"); + content.push_str(" // Example: Validate authentication\n"); + content.push_str(" // Example: Rate limiting\n"); + content.push_str(" // Example: Logging\n"); + content.push_str(" // Example: Modify payload/query_params\n"); + content.push_str(" \n"); + content.push_str(" // Pass through unchanged\n"); + content.push_str(" return null;\n"); + content.push_str("}\n"); + + content +} + fn generate_websocket_content(ws: &WebSocket) -> String { let mut content = String::new(); diff --git a/crates/rohas-engine/src/api.rs b/crates/rohas-engine/src/api.rs index 55dd7af..a2fede9 100644 --- a/crates/rohas-engine/src/api.rs +++ b/crates/rohas-engine/src/api.rs @@ -12,7 +12,7 @@ use chrono::Utc; use rohas_codegen::templates; use rohas_parser::{HttpMethod, Schema}; use rohas_runtime::Executor; -use serde_json::Value; +use serde_json::{json, Value}; use std::{collections::HashMap, sync::Arc}; use tracing::{debug, info_span}; @@ -307,11 +307,31 @@ async fn api_handler( } } + let middleware_result = execute_middlewares( + state.clone(), + &api.middlewares, + payload.clone(), + query_params.clone(), + &trace_id, + &api_name, + ) + .await; + + if let Err(e) = middleware_result { + state + .trace_store + .complete_trace(&trace_id, crate::trace::TraceStatus::Failed, Some(e.clone())) + .await; + return Err(ApiError::BadRequest(e)); + } + + let (final_payload, final_query_params) = middleware_result.unwrap(); + let result = execute_handler( state.clone(), handler_name.clone(), - payload, - query_params, + final_payload, + final_query_params, api_triggers, api_name, trace_id.clone(), @@ -383,6 +403,98 @@ fn parse_query_string(query: &str) -> HashMap { .collect() } +async fn execute_middlewares( + state: ApiState, + middlewares: &[String], + mut payload: Value, + mut query_params: HashMap, + trace_id: &str, + api_name: &str, +) -> Result<(Value, HashMap), String> { + if middlewares.is_empty() { + return Ok((payload, query_params)); + } + + debug!("Executing {} middlewares for API: {}", middlewares.len(), api_name); + + for middleware_name in middlewares { + let middleware_handler_name = match state.config.language { + config::Language::TypeScript => middleware_name.clone(), + config::Language::Python => templates::to_snake_case(middleware_name.as_str()), + }; + + debug!("Executing middleware: {}", middleware_handler_name); + + let middleware_context = json!({ + "payload": payload, + "query_params": query_params, + "api_name": api_name, + "trace_id": trace_id, + }); + + let mut context = rohas_runtime::HandlerContext::new(&middleware_handler_name, middleware_context); + context.metadata.insert("middleware".to_string(), "true".to_string()); + context.metadata.insert("api_name".to_string(), api_name.to_string()); + + let start = std::time::Instant::now(); + let result = state.executor.execute_with_context(context).await; + let duration_ms = start.elapsed().as_millis() as u64; + + if let Ok(ref exec_result) = result { + state + .trace_store + .add_step( + trace_id, + format!("middleware:{}", middleware_handler_name), + duration_ms.max(exec_result.execution_time_ms), + exec_result.success, + exec_result.error.clone(), + ) + .await; + } + + match result { + Ok(exec_result) => { + if !exec_result.success { + let error_msg = exec_result.error.unwrap_or_else(|| { + format!("Middleware '{}' rejected the request", middleware_name) + }); + return Err(error_msg); + } + + if let Some(data) = exec_result.data { + if let Value::Object(middleware_data) = data { + if let Some(new_payload) = middleware_data.get("payload") { + payload = new_payload.clone(); + } + + if let Some(new_query_params) = middleware_data.get("query_params") { + if let Value::Object(params_obj) = new_query_params { + query_params = params_obj + .iter() + .filter_map(|(k, v)| { + if let Value::String(s) = v { + Some((k.clone(), s.clone())) + } else { + None + } + }) + .collect(); + } + } + } + } + } + Err(e) => { + let error_msg = format!("Middleware '{}' execution failed: {}", middleware_name, e); + return Err(error_msg); + } + } + } + + Ok((payload, query_params)) +} + async fn execute_handler( state: ApiState, handler_name: String, diff --git a/crates/rohas-engine/src/workbench.rs b/crates/rohas-engine/src/workbench.rs index 818772b..e94841e 100644 --- a/crates/rohas-engine/src/workbench.rs +++ b/crates/rohas-engine/src/workbench.rs @@ -859,6 +859,7 @@ pub struct WebSocketEndpoint { pub on_disconnect: Vec, pub triggers: Vec, pub broadcast: bool, + pub middlewares: Vec, } #[derive(Serialize, Deserialize)] @@ -913,6 +914,7 @@ async fn get_endpoints(State(state): State) -> Result Result<(), String> { + if middlewares.is_empty() { + return Ok(()); + } + + debug!("Executing {} middlewares for WebSocket: {}", middlewares.len(), ws_name); + + for middleware_name in middlewares { + let middleware_handler_name = match state.config.language { + config::Language::TypeScript => middleware_name.clone(), + config::Language::Python => templates::to_snake_case(middleware_name.as_str()), + }; + + debug!("Executing WebSocket middleware: {}", middleware_handler_name); + + let mut context = rohas_runtime::HandlerContext::new(&middleware_handler_name, payload.clone()); + context.metadata.insert("middleware".to_string(), "true".to_string()); + context.metadata.insert("websocket_name".to_string(), ws_name.to_string()); + + let start = std::time::Instant::now(); + let result = state.executor.execute_with_context(context).await; + let duration_ms = start.elapsed().as_millis() as u64; + + if let Ok(ref exec_result) = result { + state + .trace_store + .add_step( + trace_id, + format!("middleware:{}", middleware_handler_name), + duration_ms.max(exec_result.execution_time_ms), + exec_result.success, + exec_result.error.clone(), + ) + .await; + } + + match result { + Ok(exec_result) => { + if !exec_result.success { + let error_msg = exec_result.error.unwrap_or_else(|| { + format!("Middleware '{}' rejected the WebSocket connection", middleware_name) + }); + return Err(error_msg); + } + } + Err(e) => { + let error_msg = format!("Middleware '{}' execution failed: {}", middleware_name, e); + return Err(error_msg); + } + } + } + + Ok(()) +} + pub async fn websocket_handler(socket: WebSocket, state: ApiState, ws_name: String) { let connection_id = Uuid::new_v4().to_string(); let ws_config = state @@ -41,6 +102,31 @@ pub async fn websocket_handler(socket: WebSocket, state: ApiState, ws_name: Stri ) .await; + if !ws_config.middlewares.is_empty() { + let middleware_payload = json!({ + "connection": connection.clone(), + "websocket_name": ws_name, + }); + + let middleware_result = execute_websocket_middlewares( + state.clone(), + &ws_config.middlewares, + middleware_payload, + &connection_trace_id, + &ws_name, + ) + .await; + + if let Err(e) = middleware_result { + error!("WebSocket middleware rejected connection: {}", e); + state + .trace_store + .complete_trace(&connection_trace_id, crate::trace::TraceStatus::Failed, Some(e)) + .await; + return; + } + } + if !ws_config.on_connect.is_empty() { for handler_name in &ws_config.on_connect { let handler_name = match state.config.language { diff --git a/crates/rohas-parser/src/ast.rs b/crates/rohas-parser/src/ast.rs index da3e59a..94c15f4 100644 --- a/crates/rohas-parser/src/ast.rs +++ b/crates/rohas-parser/src/ast.rs @@ -220,6 +220,7 @@ pub struct WebSocket { pub on_disconnect: Vec, pub triggers: Vec, pub broadcast: bool, + pub middlewares: Vec, } #[cfg(test)] diff --git a/crates/rohas-parser/src/parser.rs b/crates/rohas-parser/src/parser.rs index 83ef0bf..cf14bc2 100644 --- a/crates/rohas-parser/src/parser.rs +++ b/crates/rohas-parser/src/parser.rs @@ -199,7 +199,7 @@ impl Parser { Rule::trigger_list => { triggers = Self::parse_string_list(key)?; } - Rule::string_list => { + Rule::string_list | Rule::middleware_list => { middlewares = Self::parse_string_list(key)?; } _ => {} @@ -357,6 +357,7 @@ impl Parser { let mut on_disconnect = Vec::new(); let mut triggers = Vec::new(); let mut broadcast = false; + let mut middlewares = Vec::new(); for prop in inner { if prop.as_rule() == Rule::ws_property { @@ -387,6 +388,11 @@ impl Parser { Rule::trigger_list => { triggers = Self::parse_string_list(key)?; } + Rule::string_list | Rule::middleware_list => { + if prop_text.starts_with("middlewares:") { + middlewares = Self::parse_string_list(key)?; + } + } Rule::boolean => { if prop_text.starts_with("broadcast:") { broadcast = key.as_str() == "true"; @@ -407,6 +413,7 @@ impl Parser { on_disconnect, triggers, broadcast, + middlewares, }) } } diff --git a/crates/rohas-parser/src/rohas.pest b/crates/rohas-parser/src/rohas.pest index fb10993..3ddd164 100644 --- a/crates/rohas-parser/src/rohas.pest +++ b/crates/rohas-parser/src/rohas.pest @@ -34,12 +34,13 @@ api_property = { | ("body:" ~ ident) | ("response:" ~ ident) | ("triggers:" ~ trigger_list) - | ("middlewares:" ~ string_list) + | ("middlewares:" ~ middleware_list) } -http_method = { "GET" | "POST" | "PUT" | "PATCH" | "DELETE" } -trigger_list = { "[" ~ ident ~ ("," ~ ident)* ~ "]" } -string_list = { "[" ~ string ~ ("," ~ string)* ~ "]" } +http_method = { "GET" | "POST" | "PUT" | "PATCH" | "DELETE" } +trigger_list = { "[" ~ ident ~ ("," ~ ident)* ~ "]" } +string_list = { "[" ~ string ~ ("," ~ string)* ~ "]" } +middleware_list = { "[" ~ (ident | string) ~ ("," ~ (ident | string))* ~ "]" } // Event definition event = { "event" ~ ident ~ "{" ~ event_property* ~ "}" } @@ -71,4 +72,5 @@ ws_property = { | ("onDisconnect:" ~ handler_list) | ("triggers:" ~ trigger_list) | ("broadcast:" ~ boolean) + | ("middlewares:" ~ middleware_list) } diff --git a/crates/rohas-runtime/src/executor.rs b/crates/rohas-runtime/src/executor.rs index e2f0543..c5f20b5 100644 --- a/crates/rohas-runtime/src/executor.rs +++ b/crates/rohas-runtime/src/executor.rs @@ -110,10 +110,21 @@ impl Executor { fn resolve_handler_path(&self, handler_name: &str) -> Result { let handlers_dir = self.config.project_root.join("src/handlers"); + let middlewares_dir = self.config.project_root.join("src/middlewares"); let snake_case_name = templates::to_snake_case(handler_name); let possible_paths = [ + middlewares_dir.join(format!( + "{}.{}", + snake_case_name, + self.config.language.file_extension() + )), + middlewares_dir.join(format!( + "{}.{}", + handler_name, + self.config.language.file_extension() + )), handlers_dir.join(format!( "api/{}.{}", handler_name, diff --git a/crates/rohas-runtime/src/python_runtime.rs b/crates/rohas-runtime/src/python_runtime.rs index 4a7e676..fa36ee0 100644 --- a/crates/rohas-runtime/src/python_runtime.rs +++ b/crates/rohas-runtime/src/python_runtime.rs @@ -117,18 +117,22 @@ impl PythonRuntime { let sys = py.import("sys")?; let sys_path = sys.getattr("path")?; + if let Some(parent) = handler_path.parent() { + sys_path.call_method1("insert", (0, parent.to_str().unwrap()))?; + } + if let Some(root) = project_root { let src_path = root.join("src"); if src_path.exists() { - sys_path.call_method1("insert", (0, src_path.to_str().unwrap()))?; - debug!("Added to sys.path: {:?}", src_path); + let src_path_str = src_path.to_str().unwrap(); + let path_list: Vec = sys_path.extract()?; + if !path_list.contains(&src_path_str.to_string()) { + sys_path.call_method1("append", (src_path_str,))?; + debug!("Added to sys.path (appended): {:?}", src_path); + } } } - if let Some(parent) = handler_path.parent() { - sys_path.call_method1("insert", (0, parent.to_str().unwrap()))?; - } - let module_name = handler_path .file_stem() .and_then(|s| s.to_str()) @@ -161,8 +165,17 @@ impl PythonRuntime { .map(|n| n == "websockets") .unwrap_or(false); + let is_middleware = handler_path + .parent() + .and_then(|p| p.file_name()) + .and_then(|n| n.to_str()) + .map(|n| n == "middlewares") + .unwrap_or(false); + let function_name = if is_event_handler || is_websocket_handler { handler_name.to_string() + } else if is_middleware { + format!("{}_middleware", templates::to_snake_case(handler_name)) } else { Self::extract_function_name(handler_name) }; diff --git a/examples/hello-world/.gitignore b/examples/hello-world/.gitignore index ae689fe..bdf3be8 100644 --- a/examples/hello-world/.gitignore +++ b/examples/hello-world/.gitignore @@ -49,3 +49,4 @@ coverage/ # Rohas compiled output .rohas/ +src/generated/ diff --git a/examples/hello-world/schema/api/test_ws.ro b/examples/hello-world/schema/api/test_ws.ro index e10b9a2..c7e26c4 100644 --- a/examples/hello-world/schema/api/test_ws.ro +++ b/examples/hello-world/schema/api/test_ws.ro @@ -6,4 +6,5 @@ ws HelloWorld { onDisconnect: [on_disconnect_handler] triggers: [UserCreated] broadcast: true + middlewares: [auth, rate_limit, logging] } diff --git a/examples/hello-world/schema/api/user_api.ro b/examples/hello-world/schema/api/user_api.ro index 3c6cab6..40db2a5 100644 --- a/examples/hello-world/schema/api/user_api.ro +++ b/examples/hello-world/schema/api/user_api.ro @@ -16,6 +16,7 @@ api Test { path: "/test" response: String triggers: [UserCreated] + middlewares: [auth, rate_limit, logging] } diff --git a/examples/hello-world/src/middlewares/auth.py b/examples/hello-world/src/middlewares/auth.py new file mode 100644 index 0000000..b117404 --- /dev/null +++ b/examples/hello-world/src/middlewares/auth.py @@ -0,0 +1,40 @@ +from typing import Dict, Any, Optional +from generated.state import State + +async def auth_middleware(context: Dict[str, Any], state: State) -> Optional[Dict[str, Any]]: + """ + Middleware function for auth. + + Args: + context: Request context containing: + - payload: Request payload (for APIs) + - query_params: Query parameters (for APIs) + - connection: WebSocket connection info (for WebSockets) + - websocket_name: WebSocket name (for WebSockets) + - api_name: API name (for APIs) + - trace_id: Trace ID + state: State object for logging and triggering events + + Returns: + Optional[Dict[str, Any]]: Modified context with 'payload' and/or 'query_params' keys, + or None to pass through unchanged. Return a dict with 'error' key to reject the request. + + To reject the request, raise an exception + """ + # TODO: Implement middleware logic + # Example: Validate authentication + # Example: Rate limiting + # Example: Logging + # Example: Modify payload/query_params + # + # To modify the request: + # return { + # 'payload': modified_payload, + # 'query_params': modified_query_params + # } + # + # To reject the request: + # raise Exception('Access denied') + + # Pass through unchanged + return None diff --git a/examples/hello-world/src/middlewares/logging.py b/examples/hello-world/src/middlewares/logging.py new file mode 100644 index 0000000..3137b65 --- /dev/null +++ b/examples/hello-world/src/middlewares/logging.py @@ -0,0 +1,40 @@ +from typing import Dict, Any, Optional +from generated.state import State + +async def logging_middleware(context: Dict[str, Any], state: State) -> Optional[Dict[str, Any]]: + """ + Middleware function for logging. + + Args: + context: Request context containing: + - payload: Request payload (for APIs) + - query_params: Query parameters (for APIs) + - connection: WebSocket connection info (for WebSockets) + - websocket_name: WebSocket name (for WebSockets) + - api_name: API name (for APIs) + - trace_id: Trace ID + state: State object for logging and triggering events + + Returns: + Optional[Dict[str, Any]]: Modified context with 'payload' and/or 'query_params' keys, + or None to pass through unchanged. Return a dict with 'error' key to reject the request. + + To reject the request, raise an exception + """ + # TODO: Implement middleware logic + # Example: Validate authentication + # Example: Rate limiting + # Example: Logging + # Example: Modify payload/query_params + # + # To modify the request: + # return { + # 'payload': modified_payload, + # 'query_params': modified_query_params + # } + # + # To reject the request: + # raise Exception('Access denied') + + # Pass through unchanged + return None diff --git a/examples/hello-world/src/middlewares/rate_limit.py b/examples/hello-world/src/middlewares/rate_limit.py new file mode 100644 index 0000000..42ec858 --- /dev/null +++ b/examples/hello-world/src/middlewares/rate_limit.py @@ -0,0 +1,40 @@ +from typing import Dict, Any, Optional +from generated.state import State + +async def rate_limit_middleware(context: Dict[str, Any], state: State) -> Optional[Dict[str, Any]]: + """ + Middleware function for rate_limit. + + Args: + context: Request context containing: + - payload: Request payload (for APIs) + - query_params: Query parameters (for APIs) + - connection: WebSocket connection info (for WebSockets) + - websocket_name: WebSocket name (for WebSockets) + - api_name: API name (for APIs) + - trace_id: Trace ID + state: State object for logging and triggering events + + Returns: + Optional[Dict[str, Any]]: Modified context with 'payload' and/or 'query_params' keys, + or None to pass through unchanged. Return a dict with 'error' key to reject the request. + + To reject the request, raise an exception + """ + # TODO: Implement middleware logic + # Example: Validate authentication + # Example: Rate limiting + # Example: Logging + # Example: Modify payload/query_params + # + # To modify the request: + # return { + # 'payload': modified_payload, + # 'query_params': modified_query_params + # } + # + # To reject the request: + # raise Exception('Access denied') + + # Pass through unchanged + return None diff --git a/workbench/src/lib/workbench-data.ts b/workbench/src/lib/workbench-data.ts index ec11fa3..c48a07a 100644 --- a/workbench/src/lib/workbench-data.ts +++ b/workbench/src/lib/workbench-data.ts @@ -310,6 +310,7 @@ export type WebSocketEndpoint = { on_disconnect: string[]; triggers: string[]; broadcast: boolean; + middlewares: string[]; }; export type CronJob = {