Skip to content

Commit

Permalink
OAuth 2.0 Device Authorization Grant (RFC 8628)
Browse files Browse the repository at this point in the history
  Most requests had to be done by our own implementation to work around
  with oauth2 crate limitations.
  • Loading branch information
sorah committed Nov 22, 2024
1 parent c13e39a commit 9a5bfe0
Show file tree
Hide file tree
Showing 10 changed files with 469 additions and 20 deletions.
21 changes: 21 additions & 0 deletions proto/mairu.proto
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ service Agent {
rpc InitiateOauthCode(InitiateOAuthCodeRequest) returns (InitiateOAuthCodeResponse);
rpc CompleteOauthCode(CompleteOAuthCodeRequest) returns (CompleteOAuthCodeResponse);

rpc InitiateOauthDeviceCode(InitiateOAuthDeviceCodeRequest) returns (InitiateOAuthDeviceCodeResponse);
rpc CompleteOauthDeviceCode(CompleteOAuthDeviceCodeRequest) returns (CompleteOAuthDeviceCodeResponse);

rpc RefreshAwsSsoClientRegistration(RefreshAwsSsoClientRegistrationRequest) returns (RefreshAwsSsoClientRegistrationResponse);
rpc InitiateAwsSsoDevice(InitiateAwsSsoDeviceRequest) returns (InitiateAwsSsoDeviceResponse);
rpc CompleteAwsSsoDevice(CompleteAwsSsoDeviceRequest) returns (CompleteAwsSsoDeviceResponse);
Expand Down Expand Up @@ -86,6 +89,24 @@ message CompleteOAuthCodeRequest {
message CompleteOAuthCodeResponse {
}

message InitiateOAuthDeviceCodeRequest {
string server_id = 1;
}
message InitiateOAuthDeviceCodeResponse {
string handle = 1;
string user_code = 2;
string verification_uri = 3;
string verification_uri_complete = 4;
google.protobuf.Timestamp expires_at = 5;
int32 interval = 6;
}

message CompleteOAuthDeviceCodeRequest {
string handle = 1;
}
message CompleteOAuthDeviceCodeResponse {
}

message RefreshAwsSsoClientRegistrationRequest {
string server_id = 1;
}
Expand Down
70 changes: 69 additions & 1 deletion src/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,68 @@ impl crate::proto::agent_server::Agent for Agent {
Ok(tonic::Response::new(CompleteOAuthCodeResponse {}))
}

#[tracing::instrument(skip_all)]
async fn initiate_oauth_device_code(
&self,
request: tonic::Request<InitiateOAuthDeviceCodeRequest>,
) -> Result<tonic::Response<InitiateOAuthDeviceCodeResponse>, tonic::Status> {
let req = request.get_ref();

let server = match crate::config::Server::find_from_fs(&req.server_id).await {
Ok(server) => server,
Err(crate::Error::ConfigError(e)) => return Err(tonic::Status::internal(e)),
Err(crate::Error::UserError(e)) => return Err(tonic::Status::not_found(e)),
Err(e) => return Err(tonic::Status::internal(e.to_string())),
};

server.validate().map_err(|e| {
tonic::Status::failed_precondition(format!(
"Server '{}' has invalid configuration; {:}",
server.id(),
e,
))
})?;

let flow = crate::oauth_device_code::OAuthDeviceCodeFlow::initiate(&server)
.await
.map_err(|e| {
tracing::error!(err = ?e, "OAuthDeviceCodeFlow initiate failure");
tonic::Status::internal(e.to_string())
})?;

let response = (&flow).into();

tracing::debug!(flow = ?flow, "Initiated OAuth 2.0 Device Code flow");
self.auth_flow_manager
.store(crate::auth_flow_manager::AuthFlow::OAuthDeviceCode(flow));

return Ok(tonic::Response::new(response));
}

#[tracing::instrument(skip_all)]
async fn complete_oauth_device_code(
&self,
request: tonic::Request<CompleteOAuthDeviceCodeRequest>,
) -> Result<tonic::Response<CompleteOAuthDeviceCodeResponse>, tonic::Status> {
let req = request.get_ref();
let Some(flow0) = self.auth_flow_manager.retrieve(&req.handle) else {
return Err(tonic::Status::not_found("flow handle not found"));
};
let completion = {
let crate::auth_flow_manager::AuthFlow::OAuthDeviceCode(flow) = flow0.as_ref() else {
return Err(tonic::Status::invalid_argument(
"flow handle is not for the grant type",
));
};
tracing::trace!(flow = ?flow0.as_ref(), "Completing OAuth 2.0 Device Code Grant flow...");
flow.complete().await
};

self.accept_completed_auth_flow(flow0, completion)?;

Ok(tonic::Response::new(CompleteOAuthDeviceCodeResponse {}))
}

#[tracing::instrument(skip_all)]
async fn refresh_aws_sso_client_registration(
&self,
Expand Down Expand Up @@ -325,7 +387,13 @@ impl Agent {
) -> tonic::Result<()> {
let token = match completion {
Ok(t) => t,
Err(crate::Error::AuthNotReadyError) => {
Err(crate::Error::AuthNotReadyError { slow_down: true }) => {
tracing::debug!(flow = ?flow.as_ref(), "not yet ready, slow down");
return Err(tonic::Status::resource_exhausted(
"not yet ready, slow down".to_string(),
));
}
Err(crate::Error::AuthNotReadyError { slow_down: false }) => {
tracing::debug!(flow = ?flow.as_ref(), "not yet ready");
return Err(tonic::Status::failed_precondition(
"not yet ready".to_string(),
Expand Down
2 changes: 2 additions & 0 deletions src/auth_flow_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ pub const MAX_ITEMS: usize = 15;
pub enum AuthFlow {
Nop,
OAuthCode(crate::oauth_code::OAuthCodeFlow),
OAuthDeviceCode(crate::oauth_device_code::OAuthDeviceCodeFlow),
AwsSsoDevice(crate::oauth_awssso::AwsSsoDeviceFlow),
}

Expand All @@ -13,6 +14,7 @@ impl AuthFlow {
match self {
AuthFlow::Nop => "",
AuthFlow::OAuthCode(f) => &f.handle,
AuthFlow::OAuthDeviceCode(f) => &f.handle,
AuthFlow::AwsSsoDevice(f) => &f.handle,
}
}
Expand Down
68 changes: 65 additions & 3 deletions src/cmd/login.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,16 @@ pub async fn login(
server.validate()?;

let oauth = server.oauth.as_ref().unwrap();
let oauth_grant_type = args
.oauth_grant_type
.unwrap_or_else(|| oauth.default_grant_type());
let oauth_grant_type = match args.oauth_grant_type {
Some(x) => Ok(x),
None => oauth.default_grant_type(),
}?;

tracing::debug!(oauth_grant_type = ?oauth_grant_type, server = ?server, "Using OAuth");

match oauth_grant_type {
crate::config::OAuthGrantType::Code => do_oauth_code(agent, server).await,
crate::config::OAuthGrantType::DeviceCode => do_oauth_device_code(agent, server).await,
crate::config::OAuthGrantType::AwsSso => do_awssso(agent, server).await,
}
}
Expand Down Expand Up @@ -89,6 +91,66 @@ pub async fn do_oauth_code(
Ok(())
}

pub async fn do_oauth_device_code(
agent: &mut crate::agent::AgentConn,
server: crate::config::Server,
) -> Result<(), anyhow::Error> {
server.try_oauth_device_code_grant()?;

let session = agent
.initiate_oauth_device_code(crate::proto::InitiateOAuthDeviceCodeRequest {
server_id: server.id().to_owned(),
})
.await?
.into_inner();
tracing::debug!(session = ?session, "Initiated flow");

let product = env!("CARGO_PKG_NAME");
let server_id = server.id();
let server_url = &server.url;
let user_code = &session.user_code;
let mut authorize_url = &session.verification_uri_complete;
if authorize_url.is_empty() {
authorize_url = &session.verification_uri;
}

crate::terminal::send(&indoc::formatdoc! {"
:: {product} :: Login to {server_id} ({server_url}) ::::::::
:: {product} ::
:: {product} :: Your Verification Code: {user_code}
:: {product} :: To authorize, visit: {authorize_url}
:: {product} ::
"})
.await;

let mut interval = session.interval as u64;
loop {
tokio::time::sleep(std::time::Duration::from_secs(interval)).await;
let completion = agent
.complete_oauth_device_code(crate::proto::CompleteOAuthDeviceCodeRequest {
handle: session.handle.clone(),
})
.await;

match completion {
Ok(_) => break,
Err(e) if e.code() == tonic::Code::ResourceExhausted => {
interval += 5;
tracing::debug!(interval = ?interval, "Received slow_down request");
}
Err(e) if e.code() == tonic::Code::FailedPrecondition => {
// continue
}
Err(e) => {
anyhow::bail!(e);
}
}
}

tracing::info!("Logged in");
Ok(())
}

pub async fn do_awssso(
agent: &mut crate::agent::AgentConn,
server: crate::config::Server,
Expand Down
54 changes: 40 additions & 14 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,24 @@ impl Server {
Ok((oauth, code_grant))
}

pub fn try_oauth_device_code_grant(
&self,
) -> crate::Result<(&ServerOAuth, &ServerDeviceCodeGrant)> {
let Some(oauth) = self.oauth.as_ref() else {
return Err(crate::Error::ConfigError(format!(
"Server '{}' is missing OAuth 2.0 client configuration",
self.id()
)));
};
match oauth.device_code_grant {
Some(ref grant) => Ok((oauth, grant)),
None => Err(crate::Error::ConfigError(format!(
"Server '{}' is missing OAuth 2.0 Device Code Grant configuration",
self.id()
))),
}
}

pub fn try_oauth_awssso(&self) -> crate::Result<&ServerOAuth> {
if self.aws_sso.is_none() {
return Err(crate::Error::ConfigError(format!(
Expand Down Expand Up @@ -307,6 +325,7 @@ impl TryFrom<crate::proto::GetServerResponse> for Server {
#[serde(rename_all = "snake_case")]
pub enum OAuthGrantType {
Code,
DeviceCode,
AwsSso,
}

Expand All @@ -315,6 +334,7 @@ impl std::str::FromStr for OAuthGrantType {
fn from_str(s: &str) -> Result<OAuthGrantType, crate::Error> {
match s {
"code" => Ok(OAuthGrantType::Code),
"device_code" => Ok(OAuthGrantType::DeviceCode),
"aws_sso" => Ok(OAuthGrantType::AwsSso),
_ => Err(crate::Error::UserError(
"unknown oauth_grant_type".to_owned(),
Expand All @@ -333,7 +353,7 @@ pub struct ServerOAuth {
pub scope: Vec<String>,
default_grant_type: Option<OAuthGrantType>,
pub code_grant: Option<ServerCodeGrant>,
pub device_grant: Option<ServerDeviceGrant>,
pub device_code_grant: Option<ServerDeviceCodeGrant>,

/// Expiration time of dynamically registered OAuth2 client registration, such as AWS SSO
/// clients.
Expand All @@ -346,14 +366,15 @@ fn default_oauth_scope() -> Vec<String> {

impl ServerOAuth {
pub fn validate(&self) -> Result<(), crate::Error> {
if self.code_grant.is_none() && self.device_grant.is_none() {
if self.code_grant.is_none() && self.device_code_grant.is_none() {
return Err(crate::Error::ConfigError(
"Either oauth.code_grant or oauth.device_grant must be provided, but absent"
"Either oauth.code_grant or oauth.device_code_grant must be provided, but absent"
.to_owned(),
));
}
if match self.default_grant_type {
None => false,
Some(OAuthGrantType::DeviceCode) => self.device_code_grant.is_none(),
Some(OAuthGrantType::Code) => self.code_grant.is_none(),
Some(OAuthGrantType::AwsSso) => false,
} {
Expand All @@ -364,16 +385,21 @@ impl ServerOAuth {
Ok(())
}

pub fn default_grant_type(&self) -> OAuthGrantType {
self.default_grant_type.unwrap_or_else(|| {
if self.code_grant.is_some() {
return OAuthGrantType::Code;
}
if self.device_grant.is_some() {
// TODO: implement
pub fn default_grant_type(&self) -> Result<OAuthGrantType, crate::Error> {
match self.default_grant_type {
Some(x) => Ok(x),
None => {
if self.code_grant.is_some() {
return Ok(OAuthGrantType::Code);
}
if self.device_code_grant.is_some() {
return Ok(OAuthGrantType::DeviceCode);
}
Err(crate::Error::ConfigError(
"cannot determine default grant_type".to_string(),
))
}
unreachable!();
})
}
}
}

Expand All @@ -384,7 +410,7 @@ pub struct ServerCodeGrant {
}

#[derive(serde::Serialize, serde::Deserialize, Debug, Clone)]
pub struct ServerDeviceGrant {
pub struct ServerDeviceCodeGrant {
pub device_authorization_endpoint: Option<url::Url>,
}

Expand Down Expand Up @@ -435,7 +461,7 @@ impl AwsSsoClientRegistrationCache {
token_endpoint: None,
scope: sso.scope.clone(),
code_grant: None,
device_grant: None,
device_code_grant: None,
client_expires_at: Some(self.expires_at),
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ pub enum Error {
UrlParseError(#[from] url::ParseError),

#[error("AuthNotReadyError: flow not yet ready")]
AuthNotReadyError,
AuthNotReadyError { slow_down: bool },

#[error(transparent)]
OAuth2RequestTokenError(
Expand Down
Loading

0 comments on commit 9a5bfe0

Please sign in to comment.