diff --git a/src/http/mod.rs b/src/http/mod.rs index 3dd0c56..3e54065 100644 --- a/src/http/mod.rs +++ b/src/http/mod.rs @@ -3,6 +3,7 @@ use axum::{ Router, }; mod health_check; +mod subscription; mod types; mod unsubscription; @@ -14,4 +15,5 @@ pub fn router() -> Router { Router::new() .route("/health_check", get(health_check::health_check)) .route("/unsubscribe", post(unsubscription::handle_unsubscribe)) + .route("/subscriptions", post(subscription::create_subscription)) } diff --git a/src/http/subscription.rs b/src/http/subscription.rs new file mode 100644 index 0000000..2818b09 --- /dev/null +++ b/src/http/subscription.rs @@ -0,0 +1,67 @@ +use axum::{extract::State, http::StatusCode, Json}; +use sqlx::PgPool; + +use super::types::{CreateSubscriptionRequest, CreateSubscriptionResponse}; +use crate::AppState; + +pub async fn create_subscription( + State(state): State, + Json(payload): Json, +) -> Result, StatusCode> { + if payload.from_token.len() != payload.percentage.len() { + return Err(StatusCode::BAD_REQUEST); + } + + if !payload.to_token.starts_with("0x") && payload.to_token.len() != 42 { + return Err(StatusCode::BAD_REQUEST); + } + + if !payload.wallet_address.starts_with("0x") && payload.wallet_address.len() != 42 { + return Err(StatusCode::BAD_REQUEST); + } + + let mut tx = state + .db + .pool + .begin() + .await + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + + sqlx::query!( + r#" + INSERT INTO swap_subscription (wallet_address, to_token, is_active) + VALUES ($1, $2, true) + ON CONFLICT (wallet_address) + DO UPDATE SET to_token = $2, is_active = true, updated_at = NOW() + "#, + payload.wallet_address, + payload.to_token, + ) + .execute(&mut *tx) + .await + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + + for (token, percentage) in payload.from_token.iter().zip(payload.percentage.iter()) { + sqlx::query!( + r#" + INSERT INTO swap_subscription_from_token + (wallet_address, from_token, percentage) + VALUES ($1, $2, $3) + "#, + payload.wallet_address, + token, + percentage, + ) + .execute(&mut *tx) + .await + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + } + + tx.commit() + .await + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + + Ok(Json(CreateSubscriptionResponse { + wallet_address: payload.wallet_address, + })) +} diff --git a/src/http/types.rs b/src/http/types.rs index 2e77464..d074a08 100644 --- a/src/http/types.rs +++ b/src/http/types.rs @@ -4,6 +4,19 @@ use std::fmt::Formatter; use time::format_description::well_known::Rfc3339; use time::OffsetDateTime; +#[derive(Debug, Deserialize)] +pub struct CreateSubscriptionRequest { + pub wallet_address: String, + pub to_token: String, + pub from_token: Vec, + pub percentage: Vec, +} + +#[derive(Debug, Serialize)] +pub struct CreateSubscriptionResponse { + pub wallet_address: String, +} + #[derive(sqlx::Type)] pub struct TimeStamptz(pub OffsetDateTime); diff --git a/tests/api/main.rs b/tests/api/main.rs index 0ab062a..aab42b9 100644 --- a/tests/api/main.rs +++ b/tests/api/main.rs @@ -1,3 +1,4 @@ mod health_check; mod helpers; +mod subscription; mod unsubscription; diff --git a/tests/api/subscription.rs b/tests/api/subscription.rs new file mode 100644 index 0000000..d9ac293 --- /dev/null +++ b/tests/api/subscription.rs @@ -0,0 +1,223 @@ +use axum::{ + body::Body, + http::{header::CONTENT_TYPE, Request, StatusCode}, +}; +use serde_json::json; +use sqlx::PgPool; + +use crate::helpers::*; + +async fn clean_database(pool: &PgPool) { + let _ = sqlx::query!("SELECT COUNT(*) FROM swap_subscription") + .fetch_one(pool) + .await + .unwrap_or_else(|_| panic!("Database tables not ready")); + + sqlx::query!("DELETE FROM swap_subscription_from_token") + .execute(pool) + .await + .unwrap(); + sqlx::query!("DELETE FROM swap_subscription") + .execute(pool) + .await + .unwrap(); + + let count = sqlx::query!("SELECT COUNT(*) as count FROM swap_subscription") + .fetch_one(pool) + .await + .unwrap(); + + println!("Database cleaned. Subscription count: {:?}", count.count); +} + +#[tokio::test] +async fn test_subscribe_ok() { + let app = TestApp::new().await; + + clean_database(&app.db.pool).await; + + let payload = json!({ + "wallet_address": "0x742d35Cc6634C0532925a3b844Bc454e4438f44e", + "to_token": "0x1234567890123456789012345678901234567890", + "from_token": [ + "0xabcdef0123456789abcdef0123456789abcdef01", + "0x9876543210987654321098765432109876543210" + ], + "percentage": [60, 40] + }); + + let req = Request::builder() + .method("POST") + .uri("/subscriptions") + .header(CONTENT_TYPE, "application/json") + .body(Body::from(serde_json::to_string(&payload).unwrap())) + .unwrap(); + + let resp = app.request(req).await; + assert_eq!(resp.status(), StatusCode::OK); +} + +#[tokio::test] +async fn test_successful_subscription_creation() { + let app = TestApp::new().await; + + clean_database(&app.db.pool).await; + + let wallet_address = "0x742d35Cc6634C0532925a3b844Bc454e4438f44e"; + let to_token = "0x1234567890123456789012345678901234567890"; + let from_tokens = vec![ + "0xabcdef0123456789abcdef0123456789abcdef01", + "0x9876543210987654321098765432109876543210", + ]; + let percentages = vec![60, 40]; + + let payload = json!({ + "wallet_address": wallet_address, + "to_token": to_token, + "from_token": from_tokens, + "percentage": percentages + }); + + let req = Request::builder() + .method("POST") + .uri("/subscriptions") + .header(CONTENT_TYPE, "application/json") + .body(Body::from(serde_json::to_string(&payload).unwrap())) + .unwrap(); + + let resp = app.request(req).await; + + assert_eq!(resp.status(), StatusCode::OK); + + let subscription = sqlx::query!( + r#" + SELECT wallet_address, to_token, is_active + FROM swap_subscription + WHERE wallet_address = $1 + "#, + wallet_address + ) + .fetch_one(&app.db.pool) + .await + .unwrap(); + + assert_eq!(subscription.wallet_address, wallet_address); + assert_eq!(subscription.to_token, to_token); + assert!(subscription.is_active); + + let from_token_records = sqlx::query!( + r#" + SELECT from_token, percentage + FROM swap_subscription_from_token + WHERE wallet_address = $1 + "#, + wallet_address + ) + .fetch_all(&app.db.pool) + .await + .unwrap(); + + assert_eq!(from_token_records.len(), 2); + + let token_percentages: std::collections::HashMap<&str, i16> = from_token_records + .iter() + .map(|record| (record.from_token.as_str(), record.percentage)) + .collect(); + + assert_eq!( + token_percentages.get(from_tokens[0]), + Some(&(percentages[0] as i16)), + "First token {} should have percentage {}", + from_tokens[0], + percentages[0] + ); + + assert_eq!( + token_percentages.get(from_tokens[1]), + Some(&(percentages[1] as i16)), + "Second token {} should have percentage {}", + from_tokens[1], + percentages[1] + ); +} + +#[tokio::test] +async fn test_invalid_percentage_length() { + let app = TestApp::new().await; + + clean_database(&app.db.pool).await; + + let payload = json!({ + "wallet_address": "0x742d35Cc6634C0532925a3b844Bc454e4438f44e", + "to_token": "0x1234567890123456789012345678901234567890", + "from_token": [ + "0xabcdef0123456789abcdef0123456789abcdef01", + "0x9876543210987654321098765432109876543210" + ], + "percentage": [20] + }); + + let req = Request::builder() + .method("POST") + .uri("/subscriptions") + .header(CONTENT_TYPE, "application/json") + .body(Body::from(serde_json::to_string(&payload).unwrap())) + .unwrap(); + + let resp = app.request(req).await; + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); +} + +#[tokio::test] +async fn test_invalid_wallet_address() { + let app = TestApp::new().await; + + clean_database(&app.db.pool).await; + + let payload = json!({ + "wallet_address": "invalid_wallet_address", + "to_token": "0x1234567890123456789012345678901234567890", + "from_token": [ + "0xabcdef0123456789abcdef0123456789abcdef01", + "0x9876543210987654321098765432109876543210" + ], + "percentage": [20, 80] + }); + + let req = Request::builder() + .method("POST") + .uri("/subscriptions") + .header(CONTENT_TYPE, "application/json") + .body(Body::from(serde_json::to_string(&payload).unwrap())) + .unwrap(); + + let resp = app.request(req).await; + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); +} + +#[tokio::test] +async fn test_invalid_to_token_address() { + let app = TestApp::new().await; + + clean_database(&app.db.pool).await; + + let payload = json!({ + "wallet_address": "0x742d35Cc6634C0532925a3b844Bc454e4438f44e", + "to_token": "invalid_to_token", + "from_token": [ + "0xabcdef0123456789abcdef0123456789abcdef01", + "0x9876543210987654321098765432109876543210" + ], + "percentage": [20, 80] + }); + + let req = Request::builder() + .method("POST") + .uri("/subscriptions") + .header(CONTENT_TYPE, "application/json") + .body(Body::from(serde_json::to_string(&payload).unwrap())) + .unwrap(); + + let resp = app.request(req).await; + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); +}