Skip to content
This repository has been archived by the owner on Jul 31, 2023. It is now read-only.

Commit

Permalink
Feature/remote shutdown (#45)
Browse files Browse the repository at this point in the history
* chore: Update dependencies

* feat: Add line number to logging

* feat: Add ability to shutdown server via websockets

* chore: Bump pyo3 and pyo3-log

* chore: Implement clippy suggestions

* chore: Run cargo fmt
  • Loading branch information
danielvschoor authored Nov 11, 2022
1 parent 6f8c6c1 commit 671b6fd
Show file tree
Hide file tree
Showing 10 changed files with 83 additions and 28 deletions.
9 changes: 5 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ requires-dist = ["portpicker", "aiohttp"]

[dependencies]
sc2-proto = { git = "https://github.com/aiarena/sc2-proto-rs.git" }
protobuf = { version = "^3.1.0", features = ["with-bytes"] }
protobuf = { version = "=3.2.0", features = ["with-bytes"] }
log = "^0.4.13"
env_logger = "0.9.0"
shellexpand = "^2.1.0"
Expand All @@ -32,16 +32,17 @@ serde = { version = "^1.0", features = ["derive"] }
serde_json = "^1.0"
bincode = { version = "^1.3.1", optional = true }
csv = "1.1.3"
pyo3-log = { version="^0.6.0", optional=true }
pyo3-log = { version= "0.7.0", optional=true }
tokio = { version = "1.19.0", features = ["time","macros","rt","rt-multi-thread"] }
futures-util = "0.3.21"
anyhow = "1.0.58"
chrono = "0.4.22"

[dependencies.tokio-tungstenite]
version = "0.17.1"
version = "0.17.2"

[dependencies.pyo3]
version = "^0.16"
version = "0.17.3"
optional = true
features = ["auto-initialize"]

Expand Down
13 changes: 12 additions & 1 deletion rust_ac/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,23 @@ def __init__(self, ip_addr: str, config: Optional[GameConfig] = None):
else:
self._config: GameConfig = config

async def shutdown_request(self):
"""
Connects to address with headers
"""
headers = {"shutdown": "true"}
addr = self._parse_url()

session = ClientSession()
await session.ws_connect(addr, headers=headers)
return

async def connect(self):
"""
Connects to address with headers
"""
ws, session = None, None
headers = {"Supervisor": "true"}
headers = {"supervisor": "true"}
addr = self._parse_url()
for i in range(60):
await asyncio.sleep(1)
Expand Down
2 changes: 1 addition & 1 deletion src/config/race.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::sc2::Race;

#[derive(PartialOrd, PartialEq, Debug)]
#[derive(PartialOrd, PartialEq, Eq, Debug)]
pub enum BotRace {
NoRace = 0,
Terran = 1,
Expand Down
4 changes: 2 additions & 2 deletions src/controller.rs
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ impl Controller {
async fn client_join_game(&mut self, index: usize, req: RequestJoinGame) -> Option<()> {
let ((client_name, client_race), client, old_req) = self.clients.remove(index);
debug!("{} client_join_game", client_name);
if old_req != None {
if old_req.is_some() {
error!("Client attempted to join a handler twice (dropping connection)");
return None;
}
Expand Down Expand Up @@ -382,7 +382,7 @@ impl Controller {
debug!("JoinGame from {:?}", self.clients[i].0);
let join_response = self.client_join_game(i, req).await;

if join_response == None {
if join_response.is_none() {
error!("Game creation / joining failed");
}
}
Expand Down
1 change: 1 addition & 0 deletions src/errors/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pub mod proxy_error;
5 changes: 5 additions & 0 deletions src/errors/proxy_error.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#[derive(Debug, Clone)]
pub enum ProxyError {
ShutdownRequest,
AcceptError,
}
2 changes: 1 addition & 1 deletion src/handler/player.rs
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ impl Player {
r.set_save_replay(RequestSaveReplay::new());
if let Some(response) = self.sc2_query(&r).await {
if response.has_save_replay() {
match File::create(&path) {
match File::create(path) {
Ok(mut buffer) => {
let data: &[u8] = response.save_replay().data();
buffer
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use pyo3::prelude::*;
pub mod build_info;
pub mod config;
pub mod controller;
pub mod errors;
pub mod handler;
pub mod maps;
pub mod paths;
Expand Down
16 changes: 15 additions & 1 deletion src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,24 @@ mod result;
pub mod sc2;
mod sc2process;
pub mod server;
use std::io::Write;
pub mod errors;

#[tokio::main]
async fn main() {
env_logger::init();
env_logger::Builder::new()
.format(|buf, record| {
writeln!(
buf,
"{}:{} {} [{}] - {}",
record.file().unwrap_or("unknown"),
record.line().unwrap_or(0),
chrono::Local::now().format("%Y-%m-%dT%H:%M:%S"),
record.level(),
record.args()
)
})
.init();
let s = server::RustServer::new("127.0.0.1:8642");
s.run().await.expect("Could not join");
}
58 changes: 40 additions & 18 deletions src/proxy.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
//! Proxy WebSocket receiver
use crate::errors::proxy_error::ProxyError;
use crate::server::ClientType;
use crossbeam::channel::Sender;
use futures_util::SinkExt;
use futures_util::StreamExt;
use log::info;
use log::{error, info};
use std::net::SocketAddr;
use tokio::net::{TcpListener, TcpStream, ToSocketAddrs};
use tokio_tungstenite::tungstenite::handshake::server::{
Expand Down Expand Up @@ -56,12 +57,15 @@ impl Client {
}

/// Accept a new connection
async fn get_connection(server: &mut TcpListener) -> Option<(ClientType, Client)> {
async fn get_connection(server: &mut TcpListener) -> Result<(ClientType, Client), ProxyError> {
let mut is_supervisor = false;
let callback = |req: &Request, response: Response| {
if req.headers().contains_key("supervisor") {
is_supervisor = true;
}
if req.headers().contains_key("shutdown") {
return Err(ErrorResponse::new(Some("Shutdown Requested".to_string())));
}
Ok(response)
};
let config = Some(WebSocketConfig {
Expand All @@ -77,31 +81,49 @@ async fn get_connection(server: &mut TcpListener) -> Option<(ClientType, Client)
match server.accept().await {
Ok((stream, peer)) => {
// let peer = stream.peer_addr().expect("connected streams should have a peer address");
if let Ok(ws_stream) = accept_hdr_async_with_config(stream, callback, config).await {
let client = Client {
stream: ws_stream,
addr: peer,
};
return if is_supervisor {
Some((ClientType::Controller, client))
} else {
Some((ClientType::Bot, client))
};
match accept_hdr_async_with_config(stream, callback, config).await {
Ok(ws_stream) => {
let client = Client {
stream: ws_stream,
addr: peer,
};
if is_supervisor {
Ok((ClientType::Controller, client))
} else {
Ok((ClientType::Bot, client))
}
}
Err(e) => {
info!("1{:?}", e);
Err(ProxyError::ShutdownRequest)
}
}
None
}
_ => None,
Err(e) => {
info!("2{:?}", e);
Err(ProxyError::AcceptError)
}
}
}

/// Run the proxy server
pub async fn run<A: ToSocketAddrs>(addr: A, channel_out: Sender<(ClientType, Client)>) -> ! {
pub async fn run<A: ToSocketAddrs>(addr: A, channel_out: Sender<(ClientType, Client)>) {
let mut server = TcpListener::bind(addr).await.expect("Unable to bind");

loop {
if let Some((c_type, client)) = get_connection(&mut server).await {
info!("Connection accepted: {:?}", client.addr);
channel_out.send((c_type, client)).expect("Send failed");
match get_connection(&mut server).await {
Ok((c_type, client)) => {
info!("Connection accepted: {:?}", client.addr);
channel_out.send((c_type, client)).expect("Send failed");
}
Err(ProxyError::AcceptError) => {
error!("Could not accept incoming request");
break;
}
Err(ProxyError::ShutdownRequest) => {
info!("Shutdown requested");
break;
}
}
}
}

0 comments on commit 671b6fd

Please sign in to comment.