Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions crates/rohas-codegen/src/generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ impl Generator {
"handlers/events",
"handlers/cron",
"handlers/websockets",
"middlewares",
];

for dir in &dirs {
Expand Down Expand Up @@ -77,6 +78,7 @@ impl Generator {
typescript::generate_events(schema, output_dir)?;
typescript::generate_crons(schema, output_dir)?;
typescript::generate_websockets(schema, output_dir)?;
typescript::generate_middlewares(schema, output_dir)?;
typescript::generate_index(schema, output_dir)?;

info!("Generating TypeScript configuration files");
Expand All @@ -98,6 +100,7 @@ impl Generator {
python::generate_events(schema, output_dir)?;
python::generate_crons(schema, output_dir)?;
python::generate_websockets(schema, output_dir)?;
python::generate_middlewares(schema, output_dir)?;
python::generate_init(schema, output_dir)?;

info!("Generating Python configuration files");
Expand Down
82 changes: 82 additions & 0 deletions crates/rohas-codegen/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,88 @@ pub fn generate_websockets(schema: &Schema, output_dir: &Path) -> Result<()> {
Ok(())
}

pub fn generate_middlewares(schema: &Schema, output_dir: &Path) -> Result<()> {
use std::collections::HashSet;

let mut middleware_names = HashSet::new();

for api in &schema.apis {
for middleware in &api.middlewares {
middleware_names.insert(middleware.clone());
}
}

for ws in &schema.websockets {
for middleware in &ws.middlewares {
middleware_names.insert(middleware.clone());
}
}

if middleware_names.is_empty() {
return Ok(());
}

let middlewares_dir = output_dir.join("middlewares");
for middleware_name in middleware_names {
let file_name = format!("{}.py", templates::to_snake_case(&middleware_name));
let middleware_path = middlewares_dir.join(&file_name);

if !middleware_path.exists() {
let content = generate_middleware_stub(&middleware_name);
fs::write(middleware_path, content)?;
}
}

Ok(())
}

fn generate_middleware_stub(middleware_name: &str) -> String {
let mut content = String::new();

content.push_str("from typing import Dict, Any, Optional\n");
content.push_str("from generated.state import State\n\n");

content.push_str(&format!(
"async def {}_middleware(context: Dict[str, Any], state: State) -> Optional[Dict[str, Any]]:\n",
templates::to_snake_case(middleware_name)
));
content.push_str(" \"\"\"\n");
content.push_str(&format!(" Middleware function for {}.\n\n", middleware_name));
content.push_str(" Args:\n");
content.push_str(" context: Request context containing:\n");
content.push_str(" - payload: Request payload (for APIs)\n");
content.push_str(" - query_params: Query parameters (for APIs)\n");
content.push_str(" - connection: WebSocket connection info (for WebSockets)\n");
content.push_str(" - websocket_name: WebSocket name (for WebSockets)\n");
content.push_str(" - api_name: API name (for APIs)\n");
content.push_str(" - trace_id: Trace ID\n");
content.push_str(" state: State object for logging and triggering events\n\n");
content.push_str(" Returns:\n");
content.push_str(" Optional[Dict[str, Any]]: Modified context with 'payload' and/or 'query_params' keys,\n");
content.push_str(" or None to pass through unchanged. Return a dict with 'error' key to reject the request.\n\n");
content.push_str(" To reject the request, raise an exception \n");
content.push_str(" \"\"\"\n");
content.push_str(" # TODO: Implement middleware logic\n");
content.push_str(" # Example: Validate authentication\n");
content.push_str(" # Example: Rate limiting\n");
content.push_str(" # Example: Logging\n");
content.push_str(" # Example: Modify payload/query_params\n");
content.push_str(" # \n");
content.push_str(" # To modify the request:\n");
content.push_str(" # return {\n");
content.push_str(" # 'payload': modified_payload,\n");
content.push_str(" # 'query_params': modified_query_params\n");
content.push_str(" # }\n");
content.push_str(" # \n");
content.push_str(" # To reject the request:\n");
content.push_str(" # raise Exception('Access denied')\n");
content.push_str(" \n");
content.push_str(" # Pass through unchanged\n");
content.push_str(" return None\n");

content
}

fn generate_websocket_content(ws: &WebSocket) -> String {
let mut content = String::new();

Expand Down
93 changes: 93 additions & 0 deletions crates/rohas-codegen/src/typescript.rs
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,99 @@ pub fn generate_websockets(schema: &Schema, output_dir: &Path) -> Result<()> {
Ok(())
}

pub fn generate_middlewares(schema: &Schema, output_dir: &Path) -> Result<()> {
use std::collections::HashSet;

let mut middleware_names = HashSet::new();

for api in &schema.apis {
for middleware in &api.middlewares {
middleware_names.insert(middleware.clone());
}
}

for ws in &schema.websockets {
for middleware in &ws.middlewares {
middleware_names.insert(middleware.clone());
}
}

if middleware_names.is_empty() {
return Ok(());
}

let middlewares_dir = output_dir.join("middlewares");
for middleware_name in middleware_names {
let file_name = format!("{}.ts", middleware_name);
let middleware_path = middlewares_dir.join(&file_name);

if !middleware_path.exists() {
let content = generate_middleware_stub(&middleware_name);
fs::write(middleware_path, content)?;
}
}

Ok(())
}

fn generate_middleware_stub(middleware_name: &str) -> String {
let mut content = String::new();

content.push_str("import { State } from '@generated/state';\n\n");

content.push_str("export interface MiddlewareContext {\n");
content.push_str(" payload?: any;\n");
content.push_str(" query_params?: Record<string, string>;\n");
content.push_str(" connection?: any;\n");
content.push_str(" websocket_name?: string;\n");
content.push_str(" api_name?: string;\n");
content.push_str(" trace_id?: string;\n");
content.push_str("}\n\n");

content.push_str(&format!(
"export async function {}Middleware(\n",
middleware_name
));
content.push_str(" context: MiddlewareContext,\n");
content.push_str(" state: State\n");
content.push_str("): Promise<MiddlewareContext | null> {\n");
content.push_str(" /**\n");
content.push_str(&format!(" * Middleware function for {}.\n", middleware_name));
content.push_str(" * \n");
content.push_str(" * @param context - Request context containing:\n");
content.push_str(" * - payload: Request payload (for APIs)\n");
content.push_str(" * - query_params: Query parameters (for APIs)\n");
content.push_str(" * - connection: WebSocket connection info (for WebSockets)\n");
content.push_str(" * - websocket_name: WebSocket name (for WebSockets)\n");
content.push_str(" * - api_name: API name (for APIs)\n");
content.push_str(" * - trace_id: Trace ID\n");
content.push_str(" * @param state - State object for logging and triggering events\n");
content.push_str(" * @returns Modified context with 'payload' and/or 'query_params' keys,\n");
content.push_str(" * or null to pass through unchanged. Throw an error to reject the request.\n");
content.push_str(" * \n");
content.push_str(" * To reject the request, throw an error:\n");
content.push_str(" * throw new Error('Access denied');\n");
content.push_str(" * \n");
content.push_str(" * To modify the request:\n");
content.push_str(" * return {\n");
content.push_str(" * ...context,\n");
content.push_str(" * payload: modifiedPayload,\n");
content.push_str(" * query_params: modifiedQueryParams\n");
content.push_str(" * };\n");
content.push_str(" */\n");
content.push_str(" // TODO: Implement middleware logic\n");
content.push_str(" // Example: Validate authentication\n");
content.push_str(" // Example: Rate limiting\n");
content.push_str(" // Example: Logging\n");
content.push_str(" // Example: Modify payload/query_params\n");
content.push_str(" \n");
content.push_str(" // Pass through unchanged\n");
content.push_str(" return null;\n");
content.push_str("}\n");

content
}

fn generate_websocket_content(ws: &WebSocket) -> String {
let mut content = String::new();

Expand Down
118 changes: 115 additions & 3 deletions crates/rohas-engine/src/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use chrono::Utc;
use rohas_codegen::templates;
use rohas_parser::{HttpMethod, Schema};
use rohas_runtime::Executor;
use serde_json::Value;
use serde_json::{json, Value};
use std::{collections::HashMap, sync::Arc};
use tracing::{debug, info_span};

Expand Down Expand Up @@ -307,11 +307,31 @@ async fn api_handler(
}
}

let middleware_result = execute_middlewares(
state.clone(),
&api.middlewares,
payload.clone(),
query_params.clone(),
&trace_id,
&api_name,
)
.await;

if let Err(e) = middleware_result {
state
.trace_store
.complete_trace(&trace_id, crate::trace::TraceStatus::Failed, Some(e.clone()))
.await;
return Err(ApiError::BadRequest(e));
}

let (final_payload, final_query_params) = middleware_result.unwrap();

let result = execute_handler(
state.clone(),
handler_name.clone(),
payload,
query_params,
final_payload,
final_query_params,
api_triggers,
api_name,
trace_id.clone(),
Expand Down Expand Up @@ -383,6 +403,98 @@ fn parse_query_string(query: &str) -> HashMap<String, String> {
.collect()
}

async fn execute_middlewares(
state: ApiState,
middlewares: &[String],
mut payload: Value,
mut query_params: HashMap<String, String>,
trace_id: &str,
api_name: &str,
) -> Result<(Value, HashMap<String, String>), String> {
if middlewares.is_empty() {
return Ok((payload, query_params));
}

debug!("Executing {} middlewares for API: {}", middlewares.len(), api_name);

for middleware_name in middlewares {
let middleware_handler_name = match state.config.language {
config::Language::TypeScript => middleware_name.clone(),
config::Language::Python => templates::to_snake_case(middleware_name.as_str()),
};

debug!("Executing middleware: {}", middleware_handler_name);

let middleware_context = json!({
"payload": payload,
"query_params": query_params,
"api_name": api_name,
"trace_id": trace_id,
});

let mut context = rohas_runtime::HandlerContext::new(&middleware_handler_name, middleware_context);
context.metadata.insert("middleware".to_string(), "true".to_string());
context.metadata.insert("api_name".to_string(), api_name.to_string());

let start = std::time::Instant::now();
let result = state.executor.execute_with_context(context).await;
let duration_ms = start.elapsed().as_millis() as u64;

if let Ok(ref exec_result) = result {
state
.trace_store
.add_step(
trace_id,
format!("middleware:{}", middleware_handler_name),
duration_ms.max(exec_result.execution_time_ms),
exec_result.success,
exec_result.error.clone(),
)
.await;
}

match result {
Ok(exec_result) => {
if !exec_result.success {
let error_msg = exec_result.error.unwrap_or_else(|| {
format!("Middleware '{}' rejected the request", middleware_name)
});
return Err(error_msg);
}

if let Some(data) = exec_result.data {
if let Value::Object(middleware_data) = data {
if let Some(new_payload) = middleware_data.get("payload") {
payload = new_payload.clone();
}

if let Some(new_query_params) = middleware_data.get("query_params") {
if let Value::Object(params_obj) = new_query_params {
query_params = params_obj
.iter()
.filter_map(|(k, v)| {
if let Value::String(s) = v {
Some((k.clone(), s.clone()))
} else {
None
}
})
.collect();
}
}
}
}
}
Err(e) => {
let error_msg = format!("Middleware '{}' execution failed: {}", middleware_name, e);
return Err(error_msg);
}
}
}

Ok((payload, query_params))
}

async fn execute_handler(
state: ApiState,
handler_name: String,
Expand Down
2 changes: 2 additions & 0 deletions crates/rohas-engine/src/workbench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -859,6 +859,7 @@ pub struct WebSocketEndpoint {
pub on_disconnect: Vec<String>,
pub triggers: Vec<String>,
pub broadcast: bool,
pub middlewares: Vec<String>,
}

#[derive(Serialize, Deserialize)]
Expand Down Expand Up @@ -913,6 +914,7 @@ async fn get_endpoints(State(state): State<ApiState>) -> Result<Response, Workbe
on_disconnect: ws.on_disconnect.clone(),
triggers: ws.triggers.clone(),
broadcast: ws.broadcast,
middlewares: ws.middlewares.clone(),
})
.collect();

Expand Down
Loading