Skip to content

Commit

Permalink
Merge pull request #15 from sorah/device-code
Browse files Browse the repository at this point in the history
OAuth 2.0 Device Authorization Grant (RFC 8628)
  • Loading branch information
sorah authored Nov 22, 2024
2 parents c13e39a + a23443a commit 53fff8f
Show file tree
Hide file tree
Showing 10 changed files with 454 additions and 20 deletions.
21 changes: 21 additions & 0 deletions proto/mairu.proto
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ service Agent {
rpc InitiateOauthCode(InitiateOAuthCodeRequest) returns (InitiateOAuthCodeResponse);
rpc CompleteOauthCode(CompleteOAuthCodeRequest) returns (CompleteOAuthCodeResponse);

rpc InitiateOauthDeviceCode(InitiateOAuthDeviceCodeRequest) returns (InitiateOAuthDeviceCodeResponse);
rpc CompleteOauthDeviceCode(CompleteOAuthDeviceCodeRequest) returns (CompleteOAuthDeviceCodeResponse);

rpc RefreshAwsSsoClientRegistration(RefreshAwsSsoClientRegistrationRequest) returns (RefreshAwsSsoClientRegistrationResponse);
rpc InitiateAwsSsoDevice(InitiateAwsSsoDeviceRequest) returns (InitiateAwsSsoDeviceResponse);
rpc CompleteAwsSsoDevice(CompleteAwsSsoDeviceRequest) returns (CompleteAwsSsoDeviceResponse);
Expand Down Expand Up @@ -86,6 +89,24 @@ message CompleteOAuthCodeRequest {
message CompleteOAuthCodeResponse {
}

message InitiateOAuthDeviceCodeRequest {
string server_id = 1;
}
message InitiateOAuthDeviceCodeResponse {
string handle = 1;
string user_code = 2;
string verification_uri = 3;
string verification_uri_complete = 4;
google.protobuf.Timestamp expires_at = 5;
int32 interval = 6;
}

message CompleteOAuthDeviceCodeRequest {
string handle = 1;
}
message CompleteOAuthDeviceCodeResponse {
}

message RefreshAwsSsoClientRegistrationRequest {
string server_id = 1;
}
Expand Down
70 changes: 69 additions & 1 deletion src/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,68 @@ impl crate::proto::agent_server::Agent for Agent {
Ok(tonic::Response::new(CompleteOAuthCodeResponse {}))
}

#[tracing::instrument(skip_all)]
async fn initiate_oauth_device_code(
&self,
request: tonic::Request<InitiateOAuthDeviceCodeRequest>,
) -> Result<tonic::Response<InitiateOAuthDeviceCodeResponse>, tonic::Status> {
let req = request.get_ref();

let server = match crate::config::Server::find_from_fs(&req.server_id).await {
Ok(server) => server,
Err(crate::Error::ConfigError(e)) => return Err(tonic::Status::internal(e)),
Err(crate::Error::UserError(e)) => return Err(tonic::Status::not_found(e)),
Err(e) => return Err(tonic::Status::internal(e.to_string())),
};

server.validate().map_err(|e| {
tonic::Status::failed_precondition(format!(
"Server '{}' has invalid configuration; {:}",
server.id(),
e,
))
})?;

let flow = crate::oauth_device_code::OAuthDeviceCodeFlow::initiate(&server)
.await
.map_err(|e| {
tracing::error!(err = ?e, "OAuthDeviceCodeFlow initiate failure");
tonic::Status::internal(e.to_string())
})?;

let response = (&flow).into();

tracing::debug!(flow = ?flow, "Initiated OAuth 2.0 Device Code flow");
self.auth_flow_manager
.store(crate::auth_flow_manager::AuthFlow::OAuthDeviceCode(flow));

return Ok(tonic::Response::new(response));
}

#[tracing::instrument(skip_all)]
async fn complete_oauth_device_code(
&self,
request: tonic::Request<CompleteOAuthDeviceCodeRequest>,
) -> Result<tonic::Response<CompleteOAuthDeviceCodeResponse>, tonic::Status> {
let req = request.get_ref();
let Some(flow0) = self.auth_flow_manager.retrieve(&req.handle) else {
return Err(tonic::Status::not_found("flow handle not found"));
};
let completion = {
let crate::auth_flow_manager::AuthFlow::OAuthDeviceCode(flow) = flow0.as_ref() else {
return Err(tonic::Status::invalid_argument(
"flow handle is not for the grant type",
));
};
tracing::trace!(flow = ?flow0.as_ref(), "Completing OAuth 2.0 Device Code Grant flow...");
flow.complete().await
};

self.accept_completed_auth_flow(flow0, completion)?;

Ok(tonic::Response::new(CompleteOAuthDeviceCodeResponse {}))
}

#[tracing::instrument(skip_all)]
async fn refresh_aws_sso_client_registration(
&self,
Expand Down Expand Up @@ -325,7 +387,13 @@ impl Agent {
) -> tonic::Result<()> {
let token = match completion {
Ok(t) => t,
Err(crate::Error::AuthNotReadyError) => {
Err(crate::Error::AuthNotReadyError { slow_down: true }) => {
tracing::debug!(flow = ?flow.as_ref(), "not yet ready, slow down");
return Err(tonic::Status::resource_exhausted(
"not yet ready, slow down".to_string(),
));
}
Err(crate::Error::AuthNotReadyError { slow_down: false }) => {
tracing::debug!(flow = ?flow.as_ref(), "not yet ready");
return Err(tonic::Status::failed_precondition(
"not yet ready".to_string(),
Expand Down
2 changes: 2 additions & 0 deletions src/auth_flow_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ pub const MAX_ITEMS: usize = 15;
pub enum AuthFlow {
Nop,
OAuthCode(crate::oauth_code::OAuthCodeFlow),
OAuthDeviceCode(crate::oauth_device_code::OAuthDeviceCodeFlow),
AwsSsoDevice(crate::oauth_awssso::AwsSsoDeviceFlow),
}

Expand All @@ -13,6 +14,7 @@ impl AuthFlow {
match self {
AuthFlow::Nop => "",
AuthFlow::OAuthCode(f) => &f.handle,
AuthFlow::OAuthDeviceCode(f) => &f.handle,
AuthFlow::AwsSsoDevice(f) => &f.handle,
}
}
Expand Down
68 changes: 65 additions & 3 deletions src/cmd/login.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,16 @@ pub async fn login(
server.validate()?;

let oauth = server.oauth.as_ref().unwrap();
let oauth_grant_type = args
.oauth_grant_type
.unwrap_or_else(|| oauth.default_grant_type());
let oauth_grant_type = match args.oauth_grant_type {
Some(x) => Ok(x),
None => oauth.default_grant_type(),
}?;

tracing::debug!(oauth_grant_type = ?oauth_grant_type, server = ?server, "Using OAuth");

match oauth_grant_type {
crate::config::OAuthGrantType::Code => do_oauth_code(agent, server).await,
crate::config::OAuthGrantType::DeviceCode => do_oauth_device_code(agent, server).await,
crate::config::OAuthGrantType::AwsSso => do_awssso(agent, server).await,
}
}
Expand Down Expand Up @@ -89,6 +91,66 @@ pub async fn do_oauth_code(
Ok(())
}

pub async fn do_oauth_device_code(
agent: &mut crate::agent::AgentConn,
server: crate::config::Server,
) -> Result<(), anyhow::Error> {
server.try_oauth_device_code_grant()?;

let session = agent
.initiate_oauth_device_code(crate::proto::InitiateOAuthDeviceCodeRequest {
server_id: server.id().to_owned(),
})
.await?
.into_inner();
tracing::debug!(session = ?session, "Initiated flow");

let product = env!("CARGO_PKG_NAME");
let server_id = server.id();
let server_url = &server.url;
let user_code = &session.user_code;
let mut authorize_url = &session.verification_uri_complete;
if authorize_url.is_empty() {
authorize_url = &session.verification_uri;
}

crate::terminal::send(&indoc::formatdoc! {"
:: {product} :: Login to {server_id} ({server_url}) ::::::::
:: {product} ::
:: {product} :: Your Verification Code: {user_code}
:: {product} :: To authorize, visit: {authorize_url}
:: {product} ::
"})
.await;

let mut interval = session.interval as u64;
loop {
tokio::time::sleep(std::time::Duration::from_secs(interval)).await;
let completion = agent
.complete_oauth_device_code(crate::proto::CompleteOAuthDeviceCodeRequest {
handle: session.handle.clone(),
})
.await;

match completion {
Ok(_) => break,
Err(e) if e.code() == tonic::Code::ResourceExhausted => {
interval += 5;
tracing::debug!(interval = ?interval, "Received slow_down request");
}
Err(e) if e.code() == tonic::Code::FailedPrecondition => {
// continue
}
Err(e) => {
anyhow::bail!(e);
}
}
}

tracing::info!("Logged in");
Ok(())
}

pub async fn do_awssso(
agent: &mut crate::agent::AgentConn,
server: crate::config::Server,
Expand Down
54 changes: 40 additions & 14 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,24 @@ impl Server {
Ok((oauth, code_grant))
}

pub fn try_oauth_device_code_grant(
&self,
) -> crate::Result<(&ServerOAuth, &ServerDeviceCodeGrant)> {
let Some(oauth) = self.oauth.as_ref() else {
return Err(crate::Error::ConfigError(format!(
"Server '{}' is missing OAuth 2.0 client configuration",
self.id()
)));
};
match oauth.device_code_grant {
Some(ref grant) => Ok((oauth, grant)),
None => Err(crate::Error::ConfigError(format!(
"Server '{}' is missing OAuth 2.0 Device Code Grant configuration",
self.id()
))),
}
}

pub fn try_oauth_awssso(&self) -> crate::Result<&ServerOAuth> {
if self.aws_sso.is_none() {
return Err(crate::Error::ConfigError(format!(
Expand Down Expand Up @@ -307,6 +325,7 @@ impl TryFrom<crate::proto::GetServerResponse> for Server {
#[serde(rename_all = "snake_case")]
pub enum OAuthGrantType {
Code,
DeviceCode,
AwsSso,
}

Expand All @@ -315,6 +334,7 @@ impl std::str::FromStr for OAuthGrantType {
fn from_str(s: &str) -> Result<OAuthGrantType, crate::Error> {
match s {
"code" => Ok(OAuthGrantType::Code),
"device_code" => Ok(OAuthGrantType::DeviceCode),
"aws_sso" => Ok(OAuthGrantType::AwsSso),
_ => Err(crate::Error::UserError(
"unknown oauth_grant_type".to_owned(),
Expand All @@ -333,7 +353,7 @@ pub struct ServerOAuth {
pub scope: Vec<String>,
default_grant_type: Option<OAuthGrantType>,
pub code_grant: Option<ServerCodeGrant>,
pub device_grant: Option<ServerDeviceGrant>,
pub device_code_grant: Option<ServerDeviceCodeGrant>,

/// Expiration time of dynamically registered OAuth2 client registration, such as AWS SSO
/// clients.
Expand All @@ -346,14 +366,15 @@ fn default_oauth_scope() -> Vec<String> {

impl ServerOAuth {
pub fn validate(&self) -> Result<(), crate::Error> {
if self.code_grant.is_none() && self.device_grant.is_none() {
if self.code_grant.is_none() && self.device_code_grant.is_none() {
return Err(crate::Error::ConfigError(
"Either oauth.code_grant or oauth.device_grant must be provided, but absent"
"Either oauth.code_grant or oauth.device_code_grant must be provided, but absent"
.to_owned(),
));
}
if match self.default_grant_type {
None => false,
Some(OAuthGrantType::DeviceCode) => self.device_code_grant.is_none(),
Some(OAuthGrantType::Code) => self.code_grant.is_none(),
Some(OAuthGrantType::AwsSso) => false,
} {
Expand All @@ -364,16 +385,21 @@ impl ServerOAuth {
Ok(())
}

pub fn default_grant_type(&self) -> OAuthGrantType {
self.default_grant_type.unwrap_or_else(|| {
if self.code_grant.is_some() {
return OAuthGrantType::Code;
}
if self.device_grant.is_some() {
// TODO: implement
pub fn default_grant_type(&self) -> Result<OAuthGrantType, crate::Error> {
match self.default_grant_type {
Some(x) => Ok(x),
None => {
if self.code_grant.is_some() {
return Ok(OAuthGrantType::Code);
}
if self.device_code_grant.is_some() {
return Ok(OAuthGrantType::DeviceCode);
}
Err(crate::Error::ConfigError(
"cannot determine default grant_type".to_string(),
))
}
unreachable!();
})
}
}
}

Expand All @@ -384,7 +410,7 @@ pub struct ServerCodeGrant {
}

#[derive(serde::Serialize, serde::Deserialize, Debug, Clone)]
pub struct ServerDeviceGrant {
pub struct ServerDeviceCodeGrant {
pub device_authorization_endpoint: Option<url::Url>,
}

Expand Down Expand Up @@ -435,7 +461,7 @@ impl AwsSsoClientRegistrationCache {
token_endpoint: None,
scope: sso.scope.clone(),
code_grant: None,
device_grant: None,
device_code_grant: None,
client_expires_at: Some(self.expires_at),
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ pub enum Error {
UrlParseError(#[from] url::ParseError),

#[error("AuthNotReadyError: flow not yet ready")]
AuthNotReadyError,
AuthNotReadyError { slow_down: bool },

#[error(transparent)]
OAuth2RequestTokenError(
Expand Down
Loading

0 comments on commit 53fff8f

Please sign in to comment.