Skip to content

Commit

Permalink
Save refresh tokens, correctly handle user code expiry (#17)
Browse files Browse the repository at this point in the history
  • Loading branch information
dpaoliello authored Sep 13, 2023
1 parent 407f361 commit cff8c3b
Show file tree
Hide file tree
Showing 5 changed files with 195 additions and 19 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ serde = "1.0"
serde_json = "1.0"
sys-info = "0.9"
tokio = { version = "1.26", features = ["rt", "net", "time", "rt-multi-thread"] }
windows-sys = {version = "0.48", features = ["Win32_Security_Credentials"] }

# Build openssl from source instead of linking it.
# Required for cross-compilation.
Expand Down
141 changes: 122 additions & 19 deletions src/auth.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::cred_store;
use crate::http::{AppendPaths, Client};
use anyhow::{anyhow, bail, Context, Result};
use reqwest::{StatusCode, Url};
Expand Down Expand Up @@ -56,13 +57,13 @@ pub enum AuthMessage {
}

impl Authenticator {
pub fn new(sender: Sender<AuthMessage>, base_url: &str) -> Self {
pub fn new(sender: Sender<AuthMessage>, base_url: &str, refresh_token: Option<String>) -> Self {
let base_url = Url::parse(base_url).unwrap();
Self {
client: Client::new(),
refresh_after: Instant::now(),
access_token: None,
refresh_token: None,
refresh_token,
sender,
device_code_url: base_url.append_path("devicecode"),
token_url: base_url.append_path("token"),
Expand All @@ -72,7 +73,8 @@ impl Authenticator {
pub async fn get_token(&mut self) -> Result<String> {
if self.access_token.is_none() || Instant::now() > self.refresh_after {
let response = if let Some(refresh_token) = &self.refresh_token {
self.client
let result = self
.client
.post::<TokenResponse>(
self.token_url.clone(),
&[
Expand All @@ -84,7 +86,12 @@ impl Authenticator {
None,
)
.await
.with_context(|| "Refresh token")?
.with_context(|| "Refresh token");
if result.is_err() {
// On error, assume we have a bad refresh token.
self.refresh_token = None;
}
result?
} else {
'outer: loop {
let device_response = self
Expand Down Expand Up @@ -126,17 +133,22 @@ impl Authenticator {
if let TokenResponse::Failure(TokenResponseError { error, .. }) =
&token_response
{
if error == "authorization_pending" {
tokio::time::sleep(Duration::from_secs(device_response.interval))
match error.as_str() {
"authorization_pending" => {
tokio::time::sleep(Duration::from_secs(
device_response.interval,
))
.await;

if device_response_expiry <= Instant::now() {
// Code has expired, get a new one.
continue 'outer;
} else {
// Check if the user has approved the code.
continue;
if device_response_expiry <= Instant::now() {
// Code has expired, get a new one.
continue 'outer;
} else {
// Check if the user has approved the code.
continue;
}
}
_ => continue 'outer,
}
}

Expand All @@ -157,6 +169,7 @@ impl Authenticator {
.checked_sub(REFRESH_TOKEN_PADDING)
.and_then(|expires_in| Instant::now().checked_add(expires_in))
.ok_or_else(|| anyhow!("Token expires too quickly"))?;
cred_store::store_refresh_token(&response.refresh_token);
self.refresh_token = Some(response.refresh_token);
self.access_token = Some(response.access_token);
}
Expand Down Expand Up @@ -196,17 +209,17 @@ async fn auth_then_refresh() {
.create();

let (sender, mut reciever) = tokio::sync::mpsc::channel(8);
let mut authenticator = Authenticator::new(sender, &url);
let mut authenticator = Authenticator::new(sender, &url, None);

// Initial get token.
let token = authenticator.get_token().await.unwrap();
assert_eq!(token, "ac");
assert_eq!(authenticator.refresh_token.as_ref().unwrap(), "rt");
assert_eq!(
reciever.recv().await.unwrap(),
reciever.try_recv().unwrap(),
AuthMessage::HasClientCode("vu".to_string(), "uc".to_string())
);
assert_eq!(reciever.recv().await.unwrap(), AuthMessage::Completed);
assert_eq!(reciever.try_recv().unwrap(), AuthMessage::Completed);

device_mock.assert();
token_mock.assert();
Expand Down Expand Up @@ -270,7 +283,7 @@ async fn device_code_expired() {
),
mockito::Matcher::UrlEncoded("device_code".into(), "dc1".into()),
]))
.with_body(r#"{ "error": "authorization_pending", "error_description": ""}"#)
.with_body(r#"{ "error": "expired_token", "error_description": ""}"#)
.with_status(400)
.expect(1)
.create();
Expand All @@ -290,20 +303,110 @@ async fn device_code_expired() {
.create();

let (sender, mut reciever) = tokio::sync::mpsc::channel(8);
let mut authenticator = Authenticator::new(sender, &url);
let mut authenticator = Authenticator::new(sender, &url, None);
let token = authenticator.get_token().await.unwrap();
assert_eq!(token, "ac");
assert_eq!(authenticator.refresh_token.as_ref().unwrap(), "rt");
assert_eq!(
reciever.recv().await.unwrap(),
reciever.try_recv().unwrap(),
AuthMessage::HasClientCode("vu1".to_string(), "uc1".to_string())
);
assert_eq!(
reciever.recv().await.unwrap(),
reciever.try_recv().unwrap(),
AuthMessage::HasClientCode("vu2".to_string(), "uc2".to_string())
);

device_mock.assert();
failed_token_mock.assert();
success_token_mock.assert();
}

#[tokio::test]
async fn with_existing_refresh_token() {
let mut server = mockito::Server::new();
let url = server.url();

let (sender, mut reciever) = tokio::sync::mpsc::channel(8);
let mut authenticator = Authenticator::new(sender, &url, Some("rt".to_string()));

// We have a refresh token, so it should be used.
let refresh_token_mock = server
.mock("POST", "/token")
.match_body(mockito::Matcher::AllOf(vec![
mockito::Matcher::UrlEncoded("client_id".into(), CLIENT_ID.into()),
mockito::Matcher::UrlEncoded("grant_type".into(), "refresh_token".into()),
mockito::Matcher::UrlEncoded("scope".into(), SCOPE.into()),
mockito::Matcher::UrlEncoded("refresh_token".into(), "rt".into()),
]))
.with_body(r#"{ "access_token": "ac2", "refresh_token": "rt2", "expires_in": 3600 } "#)
.expect(1)
.create();
let token = authenticator.get_token().await.unwrap();
assert_eq!(token, "ac2");
assert_eq!(authenticator.refresh_token.as_ref().unwrap(), "rt2");
assert!(matches!(
reciever.try_recv(),
Err(tokio::sync::mpsc::error::TryRecvError::Empty)
));
refresh_token_mock.assert();
}

#[tokio::test]
async fn with_existing_but_expired_refresh_token() {
let mut server = mockito::Server::new();
let url = server.url();

let (sender, mut reciever) = tokio::sync::mpsc::channel(8);
let mut authenticator = Authenticator::new(sender, &url, Some("rt".to_string()));

// We have a refresh token, so it should be used.
let refresh_token_mock = server
.mock("POST", "/token")
.match_body(mockito::Matcher::AllOf(vec![
mockito::Matcher::UrlEncoded("client_id".into(), CLIENT_ID.into()),
mockito::Matcher::UrlEncoded("grant_type".into(), "refresh_token".into()),
mockito::Matcher::UrlEncoded("scope".into(), SCOPE.into()),
mockito::Matcher::UrlEncoded("refresh_token".into(), "rt".into()),
]))
.with_body(r#"{ "error": "Refresh token is expired", "error_description": ""}"#)
.with_status(400)
.expect(1)
.create();
// But it's expired, so there will be a call to the normal flow
let device_mock = server.mock("POST", "/devicecode")
.match_body(mockito::Matcher::AllOf(vec![
mockito::Matcher::UrlEncoded("client_id".into(), CLIENT_ID.into()),
mockito::Matcher::UrlEncoded("scope".into(), SCOPE.into())
]))
.with_body(r#"{ "device_code": "dc", "user_code": "uc", "verification_uri": "vu", "interval": 0, "expires_in": 3600 } "#)
.expect(1)
.create();
let token_mock = server
.mock("POST", "/token")
.match_body(mockito::Matcher::AllOf(vec![
mockito::Matcher::UrlEncoded("client_id".into(), CLIENT_ID.into()),
mockito::Matcher::UrlEncoded(
"grant_type".into(),
"urn:ietf:params:oauth:grant-type:device_code".into(),
),
mockito::Matcher::UrlEncoded("device_code".into(), "dc".into()),
]))
.with_body(r#"{ "access_token": "ac2", "refresh_token": "rt2", "expires_in": 60 } "#)
.expect(1)
.create();
let token = authenticator.get_token().await;
assert!(token.is_err());
refresh_token_mock.assert();

let token = authenticator.get_token().await.unwrap();
assert_eq!(token, "ac2");
assert_eq!(authenticator.refresh_token.as_ref().unwrap(), "rt2");
assert_eq!(
reciever.try_recv().unwrap(),
AuthMessage::HasClientCode("vu".to_string(), "uc".to_string())
);
assert_eq!(reciever.try_recv().unwrap(), AuthMessage::Completed);
device_mock.assert();
token_mock.assert();
refresh_token_mock.assert();
}
69 changes: 69 additions & 0 deletions src/cred_store.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
#[cfg(windows)]
mod windows {
use windows_sys::core::PCWSTR;
use windows_sys::w;
use windows_sys::Win32::Foundation::{FILETIME, TRUE};
use windows_sys::Win32::Security::Credentials::{
CredFree, CredReadW, CredWriteW, CREDENTIALW, CRED_PERSIST_LOCAL_MACHINE, CRED_TYPE_GENERIC,
};

const TARGET_NAME: PCWSTR = w!("OneDriveSlideShow");

pub fn get_refresh_token() -> Option<String> {
let mut p_credential: *mut CREDENTIALW = std::ptr::null_mut() as *mut _;
let bytes = unsafe {
if CredReadW(
TARGET_NAME,
CRED_TYPE_GENERIC,
0,
&mut p_credential as *mut _,
) != TRUE
{
return None;
}
std::slice::from_raw_parts(
(*p_credential).CredentialBlob,
(*p_credential).CredentialBlobSize as usize,
)
};
let token = String::from_utf8(bytes.to_vec()).map_err(Box::new);
unsafe { CredFree(p_credential as *mut _) };
token.ok()
}

pub fn store_refresh_token(cred: &str) {
let credential = CREDENTIALW {
Flags: 0,
Type: CRED_TYPE_GENERIC,
TargetName: TARGET_NAME as *mut _,
Comment: w!("OneDrive Slideshow Refresh Token") as *mut _,
LastWritten: FILETIME {
dwLowDateTime: 0,
dwHighDateTime: 0,
},
CredentialBlobSize: cred.len() as u32,
CredentialBlob: cred.as_bytes().as_ptr() as *mut u8,
Persist: CRED_PERSIST_LOCAL_MACHINE,
AttributeCount: 0,
Attributes: std::ptr::null_mut(),
TargetAlias: std::ptr::null_mut(),
UserName: std::ptr::null_mut(),
};
unsafe {
CredWriteW(&credential, 0);
}
}
}
#[cfg(windows)]
pub use windows::*;

#[cfg(not(windows))]
mod other {
pub fn get_refresh_token() -> Option<String> {
None
}

pub fn store_refresh_token(_cred: &str) {}
}
#[cfg(not(windows))]
pub use other::*;
2 changes: 2 additions & 0 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#![cfg_attr(not(debug_assertions), windows_subsystem = "windows")] // hide console window on Windows in release

mod auth;
mod cred_store;
mod http;
mod image_loader;

Expand Down Expand Up @@ -164,6 +165,7 @@ async fn image_load_loop(ui_sender: Sender<Result<AppState>>, ctx: egui::Context
let mut authenticator = Authenticator::new(
auth_sender,
"https://login.microsoftonline.com/consumers/oauth2/v2.0",
cred_store::get_refresh_token(),
);
let loader = ImageLoader::new(
"https://graph.microsoft.com/v1.0/me/drive",
Expand Down

0 comments on commit cff8c3b

Please sign in to comment.