Skip to content

Commit

Permalink
Merge pull request #8 from fetchfern/main
Browse files Browse the repository at this point in the history
PYRO-82: auth
  • Loading branch information
fetchfern authored Apr 21, 2024
2 parents 09250b5 + c514d75 commit 6bfa113
Show file tree
Hide file tree
Showing 5 changed files with 106 additions and 53 deletions.
1 change: 1 addition & 0 deletions crates/alerion_core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,4 @@ reqwest = { version = "0.12.3" }
smallvec = { version = "1.13.2", features = ["serde"] }
directories = "5.0.1"
bollard = "0.16.1"
bitflags = "2.5.0"
1 change: 1 addition & 0 deletions crates/alerion_core/src/filesystem.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ pub async fn setup_directories() -> anyhow::Result<ProjectDirs> {

Ok(project_dirs)
}

68 changes: 37 additions & 31 deletions crates/alerion_core/src/websocket/auth.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::collections::HashSet;

use bitflags::bitflags;
use jsonwebtoken::{Algorithm, DecodingKey, Validation};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
Expand All @@ -21,59 +22,62 @@ struct Claims {
unique_id: String,
}

#[derive(Debug, Default)]
pub struct Permissions {
pub connect: bool,
pub start: bool,
pub stop: bool,
pub restart: bool,
pub console: bool,
pub backup_read: bool,
pub admin_errors: bool,
pub admin_install: bool,
pub admin_transfer: bool,
bitflags! {
#[derive(Debug, Clone, Copy)]
pub struct Permissions: u32 {
const CONNECT = 1;
const START = 1 << 1;
const STOP = 1 << 2;
const RESTART = 1 << 3;
const CONSOLE = 1 << 4;
const BACKUP_READ = 1 << 5;
const ADMIN_ERRORS = 1 << 6;
const ADMIN_INSTALL = 1 << 7;
const ADMIN_TRANSFER = 1 << 8;
}
}

impl Permissions {
pub fn from_strings(strings: &[impl AsRef<str>]) -> Self {
let mut this = Permissions::default();
let mut this = Permissions::empty();

for s in strings {
match s.as_ref() {
"*" => {
this.connect = true;
this.start = true;
this.stop = true;
this.restart = true;
this.console = true;
this.backup_read = true;
this.insert(Permissions::CONNECT);
this.insert(Permissions::START);
this.insert(Permissions::STOP);
this.insert(Permissions::RESTART);
this.insert(Permissions::CONSOLE);
this.insert(Permissions::BACKUP_READ);

}
"websocket.connect" => {
this.connect = true;
this.insert(Permissions::CONNECT);
}
"control.start" => {
this.start = true;
this.insert(Permissions::START);
}
"control.stop" => {
this.stop = true;
this.insert(Permissions::STOP);
}
"control.restart" => {
this.restart = true;
this.insert(Permissions::RESTART);
}
"control.console" => {
this.console = true;
this.insert(Permissions::CONSOLE);
}
"backup.read" => {
this.backup_read = true;
this.insert(Permissions::BACKUP_READ);
}
"admin.websocket.errors" => {
this.admin_errors = true;
this.insert(Permissions::ADMIN_ERRORS);
}
"admin.websocket.install" => {
this.admin_install = true;
this.insert(Permissions::ADMIN_INSTALL);
}
"admin.websocket.transfer" => {
this.admin_transfer = true;
this.insert(Permissions::ADMIN_TRANSFER);
}
_ => {}
}
Expand All @@ -91,8 +95,10 @@ pub struct Auth {
impl Auth {
pub fn from_config(cfg: &AlerionConfig) -> Self {
let mut validation = Validation::new(Algorithm::HS256);
validation.required_spec_claims =
HashSet::from(["exp", "nbf", "aud", "iss"].map(ToOwned::to_owned));

let spec_claims = ["exp", "nbf", "aud", "iss"].map(ToOwned::to_owned);

validation.required_spec_claims = HashSet::from(spec_claims);
validation.leeway = 10;
validation.reject_tokens_expiring_in_less_than = 0;
validation.validate_exp = false;
Expand All @@ -107,10 +113,10 @@ impl Auth {
Self { validation, key }
}

pub fn is_valid(&self, auth: &str, server_uuid: &Uuid) -> bool {
pub fn validate(&self, auth: &str, server_uuid: &Uuid) -> Option<Permissions> {
jsonwebtoken::decode::<Claims>(auth, &self.key, &self.validation)
.ok()
.filter(|result| &result.claims.server_uuid == server_uuid)
.is_some()
.map(|result| Permissions::from_strings(&result.claims.permissions))
}
}
75 changes: 60 additions & 15 deletions crates/alerion_core/src/websocket/conn.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
use std::convert::Infallible;
use std::cell::Cell;

use actix::{Actor, ActorContext, Addr, Handler, StreamHandler};
use actix_web_actors::ws;
use alerion_datamodel::websocket::*;
use uuid::Uuid;

use super::auth::Auth;
use crate::config::AlerionConfig;

use super::auth::{Auth, Permissions};
use super::relay::ServerConnection;
use crate::config::AlerionConfig;

Expand All @@ -24,6 +27,7 @@ impl actix::Message for ServerMessage {
pub enum PanelMessage {
Command(String),
ReceiveLogs,
ReceiveInstallLog,
ReceiveStats,
}

Expand All @@ -33,10 +37,16 @@ impl actix::Message for PanelMessage {

pub type ConnectionAddr = Addr<WebsocketConnectionImpl>;

enum MessageError {
InvalidJwt,
Generic(String),
}

pub struct WebsocketConnectionImpl {
server_uuid: Uuid,
server_conn: ServerConnection,
auth: Auth,
permissions: Cell<Permissions>,
}

impl Actor for WebsocketConnectionImpl {
Expand Down Expand Up @@ -71,7 +81,7 @@ impl Handler<ServerMessage> for WebsocketConnectionImpl {
impl StreamHandler<Result<ws::Message, ws::ProtocolError>> for WebsocketConnectionImpl {
fn handle(&mut self, item: Result<ws::Message, ws::ProtocolError>, ctx: &mut Self::Context) {
use ws::Message;
// just ignore bad messages

let Ok(msg) = item else {
return;
};
Expand All @@ -91,6 +101,7 @@ impl WebsocketConnectionImpl {
server_uuid,
server_conn,
auth: Auth::from_config(cfg),
permissions: Cell::new(Permissions::empty()),
}
}

Expand All @@ -100,34 +111,52 @@ impl WebsocketConnectionImpl {

match event.event() {
EventType::Authentication => {
if self
.auth
.is_valid(&event.into_first_arg()?, &self.server_uuid)
{
self.server_conn.set_authenticated();
ctx.text(RawMessage::new_no_args(EventType::AuthenticationSuccess));
let maybe_permissions = self.auth.validate(&event.into_first_arg()?, &self.server_uuid);

if let Some(permissions) = maybe_permissions {
if permissions.contains(Permissions::CONNECT) {
self.permissions.set(permissions);
self.server_conn.set_authenticated();
ctx.text(RawMessage::new_no_args(EventType::AuthenticationSuccess));
}
} else {
self.send_error(ctx, MessageError::InvalidJwt);
}

Some(())
}

ty => {
if self.server_conn.is_authenticated() {
let permissions = self.permissions.get();

match ty {
EventType::SendCommand => {
self.server_conn.send_if_authenticated(|| {
PanelMessage::Command("silly".to_owned())
});
if permissions.contains(Permissions::CONSOLE) {
if let Some(command) = event.into_first_arg() {
self.server_conn.send_if_authenticated(PanelMessage::Command(command));
} else {
self.send_error(ctx, MessageError::InvalidJwt);
}

}
}

EventType::SendStats => {
self.server_conn
.send_if_authenticated(|| PanelMessage::ReceiveStats);
if permissions.contains(Permissions::CONSOLE) {
self.server_conn
.send_if_authenticated(PanelMessage::ReceiveStats);
}
}

EventType::SendLogs => {
self.server_conn
.send_if_authenticated(|| PanelMessage::ReceiveLogs);
if permissions.contains(Permissions::CONSOLE) {
self.server_conn.send_if_authenticated(PanelMessage::ReceiveLogs);

if permissions.contains(Permissions::ADMIN_INSTALL) {
self.server_conn.force_send(PanelMessage::ReceiveInstallLog);
}
}
}

e => todo!("{e:?}"),
Expand All @@ -138,4 +167,20 @@ impl WebsocketConnectionImpl {
}
}
}

#[inline(always)]
fn send_error(&self, ctx: &mut <Self as Actor>::Context, err: MessageError) {
let precise_errors = self.permissions.get().contains(Permissions::ADMIN_ERRORS);

let raw_msg = if precise_errors {
match err {
MessageError::InvalidJwt => RawMessage::new_no_args(EventType::JwtError),
MessageError::Generic(s) => RawMessage::new(EventType::DaemonError, s),
}
} else {
RawMessage::new(EventType::DaemonError, "An unexpected error occurred".to_owned())
};

ctx.text(raw_msg)
}
}
14 changes: 7 additions & 7 deletions crates/alerion_core/src/websocket/relay.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,17 +32,16 @@ impl ServerConnection {
self.auth_tracker.get_auth()
}

#[inline]
pub fn send_if_authenticated<F>(&self, msg: F)
where
F: FnOnce() -> PanelMessage,
{
pub fn send_if_authenticated(&self, msg: PanelMessage) {
if self.auth_tracker.get_auth() {
let value = msg();
let _ = self.sender.try_send((self.id, value));
let _ = self.sender.try_send((self.id, msg));
}
}

pub fn force_send(&self, msg: PanelMessage) {
let _ = self.sender.try_send((self.id, msg));
}

pub fn auth_tracker(&self) -> Arc<AuthTracker> {
Arc::clone(&self.auth_tracker)
}
Expand All @@ -61,6 +60,7 @@ impl ClientConnection {
}
}

/// Uses a closure because many messages might be expensive to compute.
pub fn send_if_authenticated<F>(&self, msg: F)
where
F: FnOnce() -> ServerMessage,
Expand Down

0 comments on commit 6bfa113

Please sign in to comment.