1use std::collections::BTreeMap;
8
9use camino::Utf8PathBuf;
10use mas_iana::jose::JsonWebSignatureAlg;
11use schemars::JsonSchema;
12use serde::{Deserialize, Serialize, de::Error};
13use serde_with::{serde_as, skip_serializing_none};
14use ulid::Ulid;
15use url::Url;
16
17use crate::{ClientSecret, ClientSecretRaw, ConfigurationSection};
18
/// The `upstream_oauth2` configuration section: the list of upstream OAuth 2.0
/// providers.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, Default)]
pub struct UpstreamOAuth2Config {
    /// Configured upstream providers (empty in the `Default` value).
    pub providers: Vec<Provider>,
}
25
26impl UpstreamOAuth2Config {
27 pub(crate) fn is_default(&self) -> bool {
29 self.providers.is_empty()
30 }
31}
32
33impl ConfigurationSection for UpstreamOAuth2Config {
34 const PATH: Option<&'static str> = Some("upstream_oauth2");
35
36 fn validate(
37 &self,
38 figment: &figment::Figment,
39 ) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
40 for (index, provider) in self.providers.iter().enumerate() {
41 let annotate = |mut error: figment::Error| {
42 error.metadata = figment
43 .find_metadata(&format!("{root}.providers", root = Self::PATH.unwrap()))
44 .cloned();
45 error.profile = Some(figment::Profile::Default);
46 error.path = vec![
47 Self::PATH.unwrap().to_owned(),
48 "providers".to_owned(),
49 index.to_string(),
50 ];
51 error
52 };
53
54 if !matches!(provider.discovery_mode, DiscoveryMode::Disabled)
55 && provider.issuer.is_none()
56 {
57 return Err(annotate(figment::Error::custom(
58 "The `issuer` field is required when discovery is enabled",
59 ))
60 .into());
61 }
62
63 match provider.token_endpoint_auth_method {
64 TokenAuthMethod::None
65 | TokenAuthMethod::PrivateKeyJwt
66 | TokenAuthMethod::SignInWithApple => {
67 if provider.client_secret.is_some() {
68 return Err(annotate(figment::Error::custom(
69 "Unexpected field `client_secret` for the selected authentication method",
70 )).into());
71 }
72 }
73 TokenAuthMethod::ClientSecretBasic
74 | TokenAuthMethod::ClientSecretPost
75 | TokenAuthMethod::ClientSecretJwt => {
76 if provider.client_secret.is_none() {
77 return Err(annotate(figment::Error::missing_field("client_secret")).into());
78 }
79 }
80 }
81
82 match provider.token_endpoint_auth_method {
83 TokenAuthMethod::None
84 | TokenAuthMethod::ClientSecretBasic
85 | TokenAuthMethod::ClientSecretPost
86 | TokenAuthMethod::SignInWithApple => {
87 if provider.token_endpoint_auth_signing_alg.is_some() {
88 return Err(annotate(figment::Error::custom(
89 "Unexpected field `token_endpoint_auth_signing_alg` for the selected authentication method",
90 )).into());
91 }
92 }
93 TokenAuthMethod::ClientSecretJwt | TokenAuthMethod::PrivateKeyJwt => {
94 if provider.token_endpoint_auth_signing_alg.is_none() {
95 return Err(annotate(figment::Error::missing_field(
96 "token_endpoint_auth_signing_alg",
97 ))
98 .into());
99 }
100 }
101 }
102
103 match provider.token_endpoint_auth_method {
104 TokenAuthMethod::SignInWithApple => {
105 if provider.sign_in_with_apple.is_none() {
106 return Err(
107 annotate(figment::Error::missing_field("sign_in_with_apple")).into(),
108 );
109 }
110 }
111
112 _ => {
113 if provider.sign_in_with_apple.is_some() {
114 return Err(annotate(figment::Error::custom(
115 "Unexpected field `sign_in_with_apple` for the selected authentication method",
116 )).into());
117 }
118 }
119 }
120
121 if provider.claims_imports.skip_confirmation {
122 if provider.claims_imports.localpart.action != ImportAction::Require {
123 return Err(annotate(figment::Error::custom(
124 "The field `action` must be `require` when `skip_confirmation` is set to `true`",
125 )).with_path("claims_imports.localpart").into());
126 }
127
128 if provider.claims_imports.email.action == ImportAction::Suggest {
129 return Err(annotate(figment::Error::custom(
130 "The field `action` must not be `suggest` when `skip_confirmation` is set to `true`",
131 )).with_path("claims_imports.email").into());
132 }
133
134 if provider.claims_imports.displayname.action == ImportAction::Suggest {
135 return Err(annotate(figment::Error::custom(
136 "The field `action` must not be `suggest` when `skip_confirmation` is set to `true`",
137 )).with_path("claims_imports.displayname").into());
138 }
139 }
140
141 if matches!(
142 provider.claims_imports.localpart.on_conflict,
143 OnConflict::Add | OnConflict::Replace | OnConflict::Set
144 ) && !matches!(
145 provider.claims_imports.localpart.action,
146 ImportAction::Force | ImportAction::Require
147 ) {
148 return Err(annotate(figment::Error::custom(
149 "The field `action` must be either `force` or `require` when `on_conflict` is set to `add`, `replace` or `set`",
150 )).with_path("claims_imports.localpart").into());
151 }
152 }
153
154 Ok(())
155 }
156}
157
/// Value of a provider's `response_mode` setting.
// NOTE(review): presumably maps to the OAuth 2.0 `response_mode`
// authorization parameter — confirm in the consumer of this config.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum ResponseMode {
    /// Serialized as `query`.
    Query,

    /// Serialized as `form_post`.
    FormPost,
}
172
/// Authentication method used by a provider at its token endpoint
/// (`token_endpoint_auth_method`).
///
/// The method determines which of `client_secret`,
/// `token_endpoint_auth_signing_alg` and `sign_in_with_apple` must be set on
/// the provider.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum TokenAuthMethod {
    /// `none`: no `client_secret`, no signing algorithm, no
    /// `sign_in_with_apple`.
    None,

    /// `client_secret_basic`: requires `client_secret`; no signing algorithm.
    ClientSecretBasic,

    /// `client_secret_post`: requires `client_secret`; no signing algorithm.
    ClientSecretPost,

    /// `client_secret_jwt`: requires both `client_secret` and
    /// `token_endpoint_auth_signing_alg`.
    ClientSecretJwt,

    /// `private_key_jwt`: requires `token_endpoint_auth_signing_alg`; no
    /// `client_secret`.
    PrivateKeyJwt,

    /// `sign_in_with_apple`: requires `sign_in_with_apple`; no `client_secret`
    /// and no signing algorithm.
    SignInWithApple,
}
199
/// Action applied to an imported claim: the `action` field of the localpart,
/// displayname and email import preferences.
// NOTE(review): the runtime meaning of each action is implemented outside this
// file — confirm before documenting it further.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum ImportAction {
    /// `ignore` (the default).
    #[default]
    Ignore,

    /// `suggest`. Rejected for email and displayname when
    /// `skip_confirmation` is set.
    Suggest,

    /// `force`.
    Force,

    /// `require`. Mandatory for the localpart when `skip_confirmation` is set.
    Require,
}
217
218impl ImportAction {
219 #[allow(clippy::trivially_copy_pass_by_ref)]
220 const fn is_default(&self) -> bool {
221 matches!(self, ImportAction::Ignore)
222 }
223}
224
/// Value of `claims_imports.localpart.on_conflict`.
// NOTE(review): the conflict-resolution behaviour of each variant is
// implemented outside this file — confirm before documenting it further.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum OnConflict {
    /// `fail` (the default).
    #[default]
    Fail,

    /// `add`. Only accepted when the localpart `action` is `force` or
    /// `require`.
    Add,

    /// `replace`. Only accepted when the localpart `action` is `force` or
    /// `require`.
    Replace,

    /// `set`. Only accepted when the localpart `action` is `force` or
    /// `require`.
    Set,
}
244
245impl OnConflict {
246 #[allow(clippy::trivially_copy_pass_by_ref)]
247 const fn is_default(&self) -> bool {
248 matches!(self, OnConflict::Fail)
249 }
250}
251
/// Import preferences for the subject (`claims_imports.subject`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
pub struct SubjectImportPreference {
    /// Optional template for the subject; unset by default.
    // NOTE(review): presumably rendered against the upstream claims — confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub template: Option<String>,
}
261
262impl SubjectImportPreference {
263 const fn is_default(&self) -> bool {
264 self.template.is_none()
265 }
266}
267
/// Import preferences for the localpart (`claims_imports.localpart`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
pub struct LocalpartImportPreference {
    /// Import action; `ignore` by default. Must be `require` when
    /// `skip_confirmation` is set.
    #[serde(default, skip_serializing_if = "ImportAction::is_default")]
    pub action: ImportAction,

    /// Optional template for the localpart; unset by default.
    // NOTE(review): presumably rendered against the upstream claims — confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub template: Option<String>,

    /// Conflict setting; `fail` by default. `add`, `replace` and `set` require
    /// `action` to be `force` or `require`.
    #[serde(default, skip_serializing_if = "OnConflict::is_default")]
    pub on_conflict: OnConflict,
}
285
286impl LocalpartImportPreference {
287 const fn is_default(&self) -> bool {
288 self.action.is_default() && self.template.is_none()
289 }
290}
291
/// Import preferences for the display name (`claims_imports.displayname`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
pub struct DisplaynameImportPreference {
    /// Import action; `ignore` by default. Must not be `suggest` when
    /// `skip_confirmation` is set.
    #[serde(default, skip_serializing_if = "ImportAction::is_default")]
    pub action: ImportAction,

    /// Optional template for the display name; unset by default.
    // NOTE(review): presumably rendered against the upstream claims — confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub template: Option<String>,
}
305
306impl DisplaynameImportPreference {
307 const fn is_default(&self) -> bool {
308 self.action.is_default() && self.template.is_none()
309 }
310}
311
/// Import preferences for the email address (`claims_imports.email`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
pub struct EmailImportPreference {
    /// Import action; `ignore` by default. Must not be `suggest` when
    /// `skip_confirmation` is set.
    #[serde(default, skip_serializing_if = "ImportAction::is_default")]
    pub action: ImportAction,

    /// Optional template for the email address; unset by default.
    // NOTE(review): presumably rendered against the upstream claims — confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub template: Option<String>,
}
325
326impl EmailImportPreference {
327 const fn is_default(&self) -> bool {
328 self.action.is_default() && self.template.is_none()
329 }
330}
331
/// Import preferences for the account name (`claims_imports.account_name`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
pub struct AccountNameImportPreference {
    /// Optional template for the account name; unset by default.
    // NOTE(review): presumably rendered against the upstream claims — confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub template: Option<String>,
}
342
343impl AccountNameImportPreference {
344 const fn is_default(&self) -> bool {
345 self.template.is_none()
346 }
347}
348
/// Per-claim import preferences of a provider (`claims_imports`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
pub struct ClaimsImports {
    /// Preferences for the subject.
    #[serde(default, skip_serializing_if = "SubjectImportPreference::is_default")]
    pub subject: SubjectImportPreference,

    /// `false` by default (and omitted from serialization when `false`). When
    /// `true`, the localpart action must be `require` and the displayname and
    /// email actions must not be `suggest`.
    #[serde(default, skip_serializing_if = "std::ops::Not::not")]
    pub skip_confirmation: bool,

    /// Preferences for the localpart.
    #[serde(default, skip_serializing_if = "LocalpartImportPreference::is_default")]
    pub localpart: LocalpartImportPreference,

    /// Preferences for the display name.
    #[serde(
        default,
        skip_serializing_if = "DisplaynameImportPreference::is_default"
    )]
    pub displayname: DisplaynameImportPreference,

    /// Preferences for the email address.
    #[serde(default, skip_serializing_if = "EmailImportPreference::is_default")]
    pub email: EmailImportPreference,

    /// Preferences for the account name.
    #[serde(
        default,
        skip_serializing_if = "AccountNameImportPreference::is_default"
    )]
    pub account_name: AccountNameImportPreference,
}
385
386impl ClaimsImports {
387 const fn is_default(&self) -> bool {
388 self.subject.is_default()
389 && self.localpart.is_default()
390 && !self.skip_confirmation
391 && self.displayname.is_default()
392 && self.email.is_default()
393 && self.account_name.is_default()
394 }
395}
396
/// Value of a provider's `discovery_mode` setting.
// NOTE(review): the exact discovery behaviour of `oidc` vs `insecure` is
// implemented outside this file — confirm before documenting it further.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema, Default)]
#[serde(rename_all = "snake_case")]
pub enum DiscoveryMode {
    /// `oidc` (the default). Requires the provider's `issuer` to be set.
    #[default]
    Oidc,

    /// `insecure`. Requires the provider's `issuer` to be set.
    Insecure,

    /// `disabled`. The only mode in which `issuer` may be omitted.
    Disabled,
}
411
412impl DiscoveryMode {
413 #[allow(clippy::trivially_copy_pass_by_ref)]
414 const fn is_default(&self) -> bool {
415 matches!(self, DiscoveryMode::Oidc)
416 }
417}
418
/// Value of a provider's `pkce_method` setting.
// NOTE(review): presumably controls whether PKCE is used in the authorization
// code flow (`auto` deciding on its own) — confirm in the consumer.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema, Default)]
#[serde(rename_all = "snake_case")]
pub enum PkceMethod {
    /// `auto` (the default).
    #[default]
    Auto,

    /// `always`.
    Always,

    /// `never`.
    Never,
}
436
437impl PkceMethod {
438 #[allow(clippy::trivially_copy_pass_by_ref)]
439 const fn is_default(&self) -> bool {
440 matches!(self, PkceMethod::Auto)
441 }
442}
443
/// serde `default` for boolean fields that are on unless configured otherwise.
fn default_true() -> bool {
    true
}

/// serde `skip_serializing_if` companion of [`default_true`]: `true` when the
/// flag still holds its default (`true`) value.
// serde passes the field by reference, hence `&bool`.
#[allow(clippy::trivially_copy_pass_by_ref)]
fn is_default_true(&flag: &bool) -> bool {
    flag
}
452
453#[allow(clippy::ref_option)]
454fn is_signed_response_alg_default(signed_response_alg: &JsonWebSignatureAlg) -> bool {
455 *signed_response_alg == signed_response_alg_default()
456}
457
458#[allow(clippy::unnecessary_wraps)]
459fn signed_response_alg_default() -> JsonWebSignatureAlg {
460 JsonWebSignatureAlg::Rs256
461}
462
/// Settings required by the `sign_in_with_apple` token endpoint auth method.
// NOTE(review): nothing in this file checks that exactly one of `private_key`
// and `private_key_file` is set — confirm where that is enforced.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct SignInWithApple {
    /// Path of the private key file (alternative to `private_key`).
    #[serde(skip_serializing_if = "Option::is_none")]
    #[schemars(with = "Option<String>")]
    pub private_key_file: Option<Utf8PathBuf>,

    /// Inline private key (alternative to `private_key_file`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub private_key: Option<String>,

    /// Apple team ID.
    pub team_id: String,

    /// Apple key ID.
    pub key_id: String,
}
480
/// serde `default` for a provider's `scope`: `"openid"`.
fn default_scope() -> String {
    String::from("openid")
}

/// serde `skip_serializing_if` companion of [`default_scope`]: `true` when
/// `scope` equals the default value exactly.
fn is_default_scope(scope: &str) -> bool {
    default_scope() == scope
}
488
/// Value of a provider's `on_backchannel_logout` setting.
// NOTE(review): how each variant is acted upon is implemented outside this
// file — confirm before documenting it further.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema, Default)]
#[serde(rename_all = "snake_case")]
pub enum OnBackchannelLogout {
    /// `do_nothing` (the default).
    #[default]
    DoNothing,

    /// `logout_browser_only`.
    LogoutBrowserOnly,

    /// `logout_all`.
    LogoutAll,
}
504
505impl OnBackchannelLogout {
506 #[allow(clippy::trivially_copy_pass_by_ref)]
507 const fn is_default(&self) -> bool {
508 matches!(self, OnBackchannelLogout::DoNothing)
509 }
510}
511
/// Configuration of a single upstream OAuth 2.0 provider.
///
/// Cross-field constraints (fields required or forbidden depending on
/// `token_endpoint_auth_method`, `discovery_mode` and `claims_imports`) are not
/// enforced during deserialization; `UpstreamOAuth2Config::validate` checks
/// them.
#[serde_as]
#[skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct Provider {
    /// Whether the provider is enabled; `true` by default.
    #[serde(default = "default_true", skip_serializing_if = "is_default_true")]
    pub enabled: bool,

    /// Unique identifier of the provider (a ULID).
    #[schemars(
        with = "String",
        regex(pattern = r"^[0123456789ABCDEFGHJKMNPQRSTVWXYZ]{26}$"),
        description = "A ULID as per https://github.com/ulid/spec"
    )]
    pub id: Ulid,

    /// Optional `synapse_idp_id`.
    // NOTE(review): presumably the matching identity provider ID on the
    // Synapse side — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub synapse_idp_id: Option<String>,

    /// Issuer of the provider. Required unless `discovery_mode` is `disabled`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub issuer: Option<String>,

    /// Optional human-readable name of the provider.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub human_name: Option<String>,

    /// Optional brand name of the provider.
    // NOTE(review): how this value is consumed is not visible here — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub brand_name: Option<String>,

    /// Client ID used with this provider.
    pub client_id: String,

    /// Client secret, given either inline (`client_secret`) or as a file path
    /// (`client_secret_file`). Required for `client_secret_basic`,
    /// `client_secret_post` and `client_secret_jwt`; must be absent for the
    /// other token endpoint auth methods.
    #[schemars(with = "ClientSecretRaw")]
    #[serde_as(as = "serde_with::TryFromInto<ClientSecretRaw>")]
    #[serde(flatten)]
    pub client_secret: Option<ClientSecret>,

    /// Authentication method used at the token endpoint.
    pub token_endpoint_auth_method: TokenAuthMethod,

    /// Sign in with Apple settings. Required when `token_endpoint_auth_method`
    /// is `sign_in_with_apple`, and must be absent otherwise.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub sign_in_with_apple: Option<SignInWithApple>,

    /// Signing algorithm for token endpoint authentication. Required for
    /// `client_secret_jwt` and `private_key_jwt`, and must be absent otherwise.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub token_endpoint_auth_signing_alg: Option<JsonWebSignatureAlg>,

    /// Signing algorithm for ID tokens; `RS256` by default.
    #[serde(
        default = "signed_response_alg_default",
        skip_serializing_if = "is_signed_response_alg_default"
    )]
    pub id_token_signed_response_alg: JsonWebSignatureAlg,

    /// Scope to request; `openid` by default.
    #[serde(default = "default_scope", skip_serializing_if = "is_default_scope")]
    pub scope: String,

    /// Discovery mode; `oidc` by default.
    #[serde(default, skip_serializing_if = "DiscoveryMode::is_default")]
    pub discovery_mode: DiscoveryMode,

    /// PKCE setting; `auto` by default.
    #[serde(default, skip_serializing_if = "PkceMethod::is_default")]
    pub pkce_method: PkceMethod,

    /// Whether to fetch userinfo; `false` by default.
    #[serde(default)]
    pub fetch_userinfo: bool,

    /// Optional signing algorithm for userinfo responses.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub userinfo_signed_response_alg: Option<JsonWebSignatureAlg>,

    // NOTE(review): the four optional endpoint URLs below presumably override
    // discovered values (or stand in for them when discovery is disabled) —
    // confirm in the consumer.
    /// Optional authorization endpoint URL.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub authorization_endpoint: Option<Url>,

    /// Optional userinfo endpoint URL.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub userinfo_endpoint: Option<Url>,

    /// Optional token endpoint URL.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub token_endpoint: Option<Url>,

    /// Optional JWKS URL.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub jwks_uri: Option<Url>,

    /// Optional response mode.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_mode: Option<ResponseMode>,

    /// Claim import preferences; all defaults when omitted.
    #[serde(default, skip_serializing_if = "ClaimsImports::is_default")]
    pub claims_imports: ClaimsImports,

    /// Extra key/value parameters; empty by default.
    // NOTE(review): presumably appended to the authorization request — confirm.
    #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
    pub additional_authorization_parameters: BTreeMap<String, String>,

    /// Whether to forward the login hint; `false` by default.
    #[serde(default)]
    pub forward_login_hint: bool,

    /// Back-channel logout behaviour; `do_nothing` by default.
    #[serde(default, skip_serializing_if = "OnBackchannelLogout::is_default")]
    pub on_backchannel_logout: OnBackchannelLogout,
}
697
698impl Provider {
699 pub async fn client_secret(&self) -> anyhow::Result<Option<String>> {
707 Ok(match &self.client_secret {
708 Some(client_secret) => Some(client_secret.value().await?),
709 None => None,
710 })
711 }
712}
713
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use figment::{
        Figment, Jail,
        providers::{Format, Yaml},
    };
    use tokio::{runtime::Handle, task};

    use super::*;

    /// Loads one provider per token endpoint auth method from YAML and checks
    /// how `client_secret` / `client_secret_file` are parsed and resolved.
    ///
    /// Only deserialization is exercised here; `validate` is not called (the
    /// `private_key_jwt` entry has no `token_endpoint_auth_signing_alg`).
    #[tokio::test]
    async fn load_config() {
        // `Handle::block_on` below must not be called from an async worker
        // thread, hence `spawn_blocking`.
        task::spawn_blocking(|| {
            Jail::expect_with(|jail| {
                // NOTE: `Provider` has no `jwks` field; extraction still
                // succeeds, so that key is ignored.
                jail.create_file(
                    "config.yaml",
                    r#"
                      upstream_oauth2:
                        providers:
                          - id: 01GFWR28C4KNE04WG3HKXB7C9R
                            client_id: upstream-oauth2
                            token_endpoint_auth_method: none

                          - id: 01GFWR32NCQ12B8Z0J8CPXRRB6
                            client_id: upstream-oauth2
                            client_secret_file: secret
                            token_endpoint_auth_method: client_secret_basic

                          - id: 01GFWR3WHR93Y5HK389H28VHZ9
                            client_id: upstream-oauth2
                            client_secret: c1!3n753c237
                            token_endpoint_auth_method: client_secret_post

                          - id: 01GFWR43R2ZZ8HX9CVBNW9TJWG
                            client_id: upstream-oauth2
                            client_secret_file: secret
                            token_endpoint_auth_method: client_secret_jwt

                          - id: 01GFWR4BNFDCC4QDG6AMSP1VRR
                            client_id: upstream-oauth2
                            token_endpoint_auth_method: private_key_jwt
                            jwks:
                              keys:
                              - kid: "03e84aed4ef4431014e8617567864c4efaaaede9"
                                kty: "RSA"
                                alg: "RS256"
                                use: "sig"
                                e: "AQAB"
                                n: "ma2uRyBeSEOatGuDpCiV9oIxlDWix_KypDYuhQfEzqi_BiF4fV266OWfyjcABbam59aJMNvOnKW3u_eZM-PhMCBij5MZ-vcBJ4GfxDJeKSn-GP_dJ09rpDcILh8HaWAnPmMoi4DC0nrfE241wPISvZaaZnGHkOrfN_EnA5DligLgVUbrA5rJhQ1aSEQO_gf1raEOW3DZ_ACU3qhtgO0ZBG3a5h7BPiRs2sXqb2UCmBBgwyvYLDebnpE7AotF6_xBIlR-Cykdap3GHVMXhrIpvU195HF30ZoBU4dMd-AeG6HgRt4Cqy1moGoDgMQfbmQ48Hlunv9_Vi2e2CLvYECcBw"

                              - kid: "d01c1abe249269f72ef7ca2613a86c9f05e59567"
                                kty: "RSA"
                                alg: "RS256"
                                use: "sig"
                                e: "AQAB"
                                n: "0hukqytPwrj1RbMYhYoepCi3CN5k7DwYkTe_Cmb7cP9_qv4ok78KdvFXt5AnQxCRwBD7-qTNkkfMWO2RxUMBdQD0ED6tsSb1n5dp0XY8dSWiBDCX8f6Hr-KolOpvMLZKRy01HdAWcM6RoL9ikbjYHUEW1C8IJnw3MzVHkpKFDL354aptdNLaAdTCBvKzU9WpXo10g-5ctzSlWWjQuecLMQ4G1mNdsR1LHhUENEnOvgT8cDkX0fJzLbEbyBYkdMgKggyVPEB1bg6evG4fTKawgnf0IDSPxIU-wdS9wdSP9ZCJJPLi5CEp-6t6rE_sb2dGcnzjCGlembC57VwpkUvyMw"
                    "#,
                )?;
                // Target of the `client_secret_file: secret` entries above;
                // same value as the inline secret of the third provider.
                jail.create_file("secret", r"c1!3n753c237")?;

                let config = Figment::new()
                    .merge(Yaml::file("config.yaml"))
                    .extract_inner::<UpstreamOAuth2Config>("upstream_oauth2")?;

                assert_eq!(config.providers.len(), 5);

                assert_eq!(
                    config.providers[1].id,
                    Ulid::from_str("01GFWR32NCQ12B8Z0J8CPXRRB6").unwrap()
                );

                // `client_secret_file` maps to `ClientSecret::File`, inline
                // `client_secret` to `ClientSecret::Value`, absence to `None`.
                assert!(config.providers[0].client_secret.is_none());
                assert!(matches!(config.providers[1].client_secret, Some(ClientSecret::File(ref p)) if p == "secret"));
                assert!(matches!(config.providers[2].client_secret, Some(ClientSecret::Value(ref v)) if v == "c1!3n753c237"));
                assert!(matches!(config.providers[3].client_secret, Some(ClientSecret::File(ref p)) if p == "secret"));
                assert!(config.providers[4].client_secret.is_none());

                // Both kinds resolve to the same string through the async
                // accessor.
                Handle::current().block_on(async move {
                    assert_eq!(config.providers[1].client_secret().await.unwrap().unwrap(), "c1!3n753c237");
                    assert_eq!(config.providers[2].client_secret().await.unwrap().unwrap(), "c1!3n753c237");
                    assert_eq!(config.providers[3].client_secret().await.unwrap().unwrap(), "c1!3n753c237");
                });

                Ok(())
            });
        }).await.unwrap();
    }
}