diff --git a/crates/matrix-sdk/src/oidc/cross_process.rs b/crates/matrix-sdk/src/oidc/cross_process.rs index 9d7cca4cb..65c29c0cb 100644 --- a/crates/matrix-sdk/src/oidc/cross_process.rs +++ b/crates/matrix-sdk/src/oidc/cross_process.rs @@ -221,7 +221,7 @@ pub enum CrossProcessRefreshLockError { MissingLock, /// Cross-process lock was set, but without session callbacks. - #[error("reload session callback must be set with Oidc::set_callbacks() for the cross-process lock to work")] + #[error("reload session callback must be set with Client::set_session_callbacks() for the cross-process lock to work")] MissingReloadSession, /// Session tokens returned by the reload_session callback were not for @@ -271,7 +271,7 @@ mod tests { let tmp_dir = tempfile::tempdir()?; let client = test_client_builder(Some("https://example.org".to_owned())) - .sqlite_store(tmp_dir, None) + .sqlite_store(&tmp_dir, None) .build() .await .unwrap(); @@ -331,7 +331,7 @@ mod tests { let tmp_dir = tempfile::tempdir()?; let client = - test_client_builder(Some(server.uri())).sqlite_store(tmp_dir, None).build().await?; + test_client_builder(Some(server.uri())).sqlite_store(&tmp_dir, None).build().await?; let oidc = Oidc { client: client.clone(), backend: Arc::new(MockImpl::new()) }; @@ -385,7 +385,7 @@ mod tests { let tmp_dir = tempfile::tempdir()?; let client = test_client_builder(Some("https://example.org".to_owned())) - .sqlite_store(tmp_dir, None) + .sqlite_store(&tmp_dir, None) .build() .await?; @@ -437,11 +437,150 @@ mod tests { Ok(()) } + #[async_test] + async fn test_cross_process_concurrent_refresh() -> anyhow::Result<()> { + // Create the backend. + let prev_tokens = OidcSessionTokens { + access_token: "prev-access-token".to_owned(), + refresh_token: Some("prev-refresh-token".to_owned()), + latest_id_token: None, + }; + + let next_tokens = OidcSessionTokens { + access_token: "next-access-token".to_owned(), + refresh_token: Some("next-refresh-token".to_owned()), + latest_id_token: None, + }; + + let backend = Arc::new( + MockImpl::new() + .next_session_tokens(next_tokens.clone()) + .expected_refresh_token(prev_tokens.refresh_token.clone().unwrap()), + ); + + // Create the first client. + let tmp_dir = tempfile::tempdir()?; + let client = test_client_builder(Some("https://example.org".to_owned())) + .sqlite_store(&tmp_dir, None) + .build() + .await?; + + let oidc = Oidc { client: client.clone(), backend: backend.clone() }; + oidc.enable_cross_process_refresh_lock("client1".to_owned()).await?; + oidc.restore_session(tests::mock_session(prev_tokens.clone())).await?; + + // Create a second client, without restoring it, to test that a token update + // before restoration doesn't cause new issues. + let unrestored_client = test_client_builder(Some("https://example.org".to_owned())) + .sqlite_store(&tmp_dir, None) + .build() + .await?; + let unrestored_oidc = Oidc { client: unrestored_client.clone(), backend: backend.clone() }; + unrestored_oidc.enable_cross_process_refresh_lock("unrestored_client".to_owned()).await?; + + { + // Create a third client that will run a refresh while the others two are doing + // nothing. + let client3 = test_client_builder(Some("https://example.org".to_owned())) + .sqlite_store(&tmp_dir, None) + .build() + .await?; + + let oidc3 = Oidc { client: client3.clone(), backend: backend.clone() }; + oidc3.enable_cross_process_refresh_lock("client3".to_owned()).await?; + oidc3.restore_session(tests::mock_session(prev_tokens.clone())).await?; + + // Run a refresh in the second client; this will invalidate the tokens from the + // first token. + oidc3.refresh_access_token().await?; + + assert_eq!(oidc3.session_tokens(), Some(next_tokens.clone())); + + // Reading from the cross-process lock for the second client only shows the new + // tokens. + let xp_manager = + oidc3.ctx().cross_process_token_refresh_manager.get().context("must have lock")?; + let guard = xp_manager.spin_lock().await?; + let actual_hash = compute_session_hash(&next_tokens); + assert_eq!(guard.db_hash, Some(actual_hash)); + assert_eq!(guard.hash_guard.as_ref(), Some(&actual_hash)); + assert!(!guard.hash_mismatch); + } + + { + // Restoring the client that was not restored yet will work Just Fine. + let oidc = unrestored_oidc; + + unrestored_client.set_session_callbacks( + Box::new({ + // This is only called because of extra checks in the code. + let tokens = next_tokens.clone(); + move |_| Ok(crate::authentication::SessionTokens::Oidc(tokens.clone())) + }), + Box::new(|_| panic!("save_session_callback shouldn't be called here")), + )?; + + oidc.restore_session(tests::mock_session(prev_tokens.clone())).await?; + + // And this client is now aware of the latest tokens. + let xp_manager = + oidc.ctx().cross_process_token_refresh_manager.get().context("must have lock")?; + let guard = xp_manager.spin_lock().await?; + let next_hash = compute_session_hash(&next_tokens); + assert_eq!(guard.db_hash, Some(next_hash)); + assert_eq!(guard.hash_guard.as_ref(), Some(&next_hash)); + assert!(!guard.hash_mismatch); + + drop(oidc); + drop(unrestored_client); + } + + { + // The cross process lock has been correctly updated, and the next attempt to + // take it will result in a mismatch. + let xp_manager = + oidc.ctx().cross_process_token_refresh_manager.get().context("must have lock")?; + let guard = xp_manager.spin_lock().await?; + let previous_hash = compute_session_hash(&prev_tokens); + let next_hash = compute_session_hash(&next_tokens); + assert_eq!(guard.db_hash, Some(next_hash)); + assert_eq!(guard.hash_guard.as_ref(), Some(&previous_hash)); + assert!(guard.hash_mismatch); + } + + client.set_session_callbacks( + Box::new({ + // This is only called because of extra checks in the code. + let tokens = next_tokens.clone(); + move |_| Ok(crate::authentication::SessionTokens::Oidc(tokens.clone())) + }), + Box::new(|_| panic!("save_session_callback shouldn't be called here")), + )?; + + oidc.refresh_access_token().await?; + + { + // The next attempt to take the lock isn't a mismatch. + let xp_manager = + oidc.ctx().cross_process_token_refresh_manager.get().context("must have lock")?; + let guard = xp_manager.spin_lock().await?; + let actual_hash = compute_session_hash(&next_tokens); + assert_eq!(guard.db_hash, Some(actual_hash)); + assert_eq!(guard.hash_guard.as_ref(), Some(&actual_hash)); + assert!(!guard.hash_mismatch); + } + + // There should have been at most one refresh. + assert_eq!(*backend.num_refreshes.lock().unwrap(), 1); + + Ok(()) + } + #[async_test] async fn test_logout() -> anyhow::Result<()> { let tmp_dir = tempfile::tempdir()?; let client = test_client_builder(Some("https://example.org".to_owned())) - .sqlite_store(tmp_dir, None) + .sqlite_store(&tmp_dir, None) .build() .await?;