From 560ba6e20c70afdbc3a091d399c2e0d4923dfac6 Mon Sep 17 00:00:00 2001 From: Andrew Phillips Date: Fri, 13 Mar 2026 22:04:38 -0300 Subject: [PATCH] fix: count_bounded error counting, clippy if-let, auth test dedup, doc tests - count_bounded: break on iterator error instead of counting errors as tokens - collapse nested if-let chains with let-chains in auth middleware - document JWT/Basic Auth as mutually exclusive - TailTokensFilter::clone uses empty buffer (always pre-filter) - fix 9 broken doc examples in server/common.rs - remove 7 duplicate auth tests from auth.rs (covered by auth_tests.rs) --- src/filter_plugin/tokens.rs | 2 +- src/modes/server/auth.rs | 113 -------------------------------- src/modes/server/common.rs | 124 ++++++++++++++++++------------------ src/tokenizer/mod.rs | 7 +- 4 files changed, 67 insertions(+), 179 deletions(-) diff --git a/src/filter_plugin/tokens.rs b/src/filter_plugin/tokens.rs index 8b0d57f..b6aa59b 100644 --- a/src/filter_plugin/tokens.rs +++ b/src/filter_plugin/tokens.rs @@ -275,7 +275,7 @@ impl FilterPlugin for TailTokensFilter { fn clone_box(&self) -> Box { Box::new(Self { count: self.count, - buffer: self.buffer.clone(), + buffer: Vec::new(), tokenizer: get_tokenizer(self.encoding).clone(), encoding: self.encoding, }) diff --git a/src/modes/server/auth.rs b/src/modes/server/auth.rs index 57d0f67..5b06291 100644 --- a/src/modes/server/auth.rs +++ b/src/modes/server/auth.rs @@ -116,116 +116,3 @@ pub fn validate_jwt(token: &str, secret: &str) -> Result { Ok(token_data.claims) } - -#[cfg(test)] -mod tests { - use super::*; - use jsonwebtoken::{EncodingKey, Header, encode}; - - fn make_token(claims: &serde_json::Value, secret: &str) -> String { - let header = Header::new(jsonwebtoken::Algorithm::HS256); - encode( - &header, - claims, - &EncodingKey::from_secret(secret.as_bytes()), - ) - .unwrap() - } - - #[test] - fn test_validate_jwt_valid_token() { - let secret = "test-secret"; - let claims = serde_json::json!({ - "sub": "test-client", - "exp": 9999999999usize, - "read": true, - "write": true, - "delete": false - }); - let token = make_token(&claims, secret); - - let result = validate_jwt(&token, secret); - assert!(result.is_ok()); - let claims = result.unwrap(); - assert_eq!(claims.sub, "test-client"); - assert!(claims.read); - assert!(claims.write); - assert!(!claims.delete); - } - - #[test] - fn test_validate_jwt_expired_token() { - let secret = "test-secret"; - let claims = serde_json::json!({ - "sub": "test-client", - "exp": 1000000000usize, - "read": true - }); - let token = make_token(&claims, secret); - - let result = validate_jwt(&token, secret); - assert!(result.is_err()); - assert_eq!(result.unwrap_err(), "Token expired"); - } - - #[test] - fn test_validate_jwt_wrong_secret() { - let claims = serde_json::json!({ - "sub": "test-client", - "exp": 9999999999usize, - "read": true - }); - let token = make_token(&claims, "correct-secret"); - - let result = validate_jwt(&token, "wrong-secret"); - assert!(result.is_err()); - } - - #[test] - fn test_validate_jwt_malformed_token() { - let result = validate_jwt("not.a.jwt", "secret"); - assert!(result.is_err()); - } - - #[test] - fn test_required_permission() { - assert_eq!(required_permission(&Method::GET), "read"); - assert_eq!(required_permission(&Method::HEAD), "read"); - assert_eq!(required_permission(&Method::POST), "write"); - assert_eq!(required_permission(&Method::PUT), "write"); - assert_eq!(required_permission(&Method::PATCH), "write"); - assert_eq!(required_permission(&Method::DELETE), "delete"); - } - - #[test] - fn test_check_permission() { - let claims = Claims { - sub: "test".to_string(), - exp: 9999999999, - read: true, - write: false, - delete: true, - }; - - assert!(check_permission(&claims, "read")); - assert!(!check_permission(&claims, "write")); - assert!(check_permission(&claims, "delete")); - assert!(!check_permission(&claims, "unknown")); - } - - #[test] - fn test_check_permission_default_false() { - // When fields are missing from JSON, serde(default) makes them false - let secret = "test-secret"; - let claims = serde_json::json!({ - "sub": "test-client", - "exp": 9999999999usize - }); - let token = make_token(&claims, secret); - - let claims = validate_jwt(&token, secret).unwrap(); - assert!(!claims.read); - assert!(!claims.write); - assert!(!claims.delete); - } -} diff --git a/src/modes/server/common.rs b/src/modes/server/common.rs index e57bbbf..a5a6394 100644 --- a/src/modes/server/common.rs +++ b/src/modes/server/common.rs @@ -7,10 +7,10 @@ use crate::services::item_service::ItemService; /// /// # Usage /// -/// ```rust +/// ```rust,ignore +/// // Illustrative — requires runtime values (db connection, settings). /// use keep::modes::server::common::{ServerConfig, AppState}; -/// let config = ServerConfig { address: "127.0.0.1".to_string(), ..Default::default() }; -/// let state = AppState { /* ... */ }; +/// let config = ServerConfig { address: "127.0.0.1".to_string(), port: Some(8080), /* ... */ }; /// ``` use anyhow::Result; use axum::{ @@ -38,7 +38,8 @@ use utoipa::ToSchema; /// /// # Examples /// -/// ``` +/// ```rust +/// use keep::modes::server::common::ServerConfig; /// let config = ServerConfig { /// address: "127.0.0.1".to_string(), /// port: Some(8080), @@ -105,7 +106,8 @@ pub struct ServerConfig { /// /// # Examples /// -/// ```rust +/// ```rust,ignore +/// // AppState requires runtime values (db connection, settings) not available in doctests. /// use keep::modes::server::common::AppState; /// use std::sync::Arc; /// use tokio::sync::Mutex; @@ -155,9 +157,9 @@ pub struct AppState { /// /// ```rust /// use keep::modes::server::common::ApiResponse; -/// let response: ApiResponse> = ApiResponse { +/// let response: ApiResponse = ApiResponse { /// success: true, -/// data: Some(items), +/// data: Some("items".to_string()), /// error: None, /// }; /// ``` @@ -190,7 +192,7 @@ pub struct ApiResponse { /// use keep::modes::server::common::ItemInfoListResponse; /// let response = ItemInfoListResponse { /// success: true, -/// data: Some(vec![item_info]), +/// data: Some(vec![]), /// error: None, /// }; /// ``` @@ -220,7 +222,7 @@ pub struct ItemInfoListResponse { /// use keep::modes::server::common::ItemInfoResponse; /// let response = ItemInfoResponse { /// success: true, -/// data: Some(item_info), +/// data: None, /// error: None, /// }; /// ``` @@ -250,7 +252,7 @@ pub struct ItemInfoResponse { /// use keep::modes::server::common::ItemContentInfoResponse; /// let response = ItemContentInfoResponse { /// success: true, -/// data: Some(content_info), +/// data: None, /// error: None, /// }; /// ``` @@ -280,7 +282,7 @@ pub struct ItemContentInfoResponse { /// use keep::modes::server::common::MetadataResponse; /// let response = MetadataResponse { /// success: true, -/// data: Some(meta_map), +/// data: None, /// error: None, /// }; /// ``` @@ -310,7 +312,7 @@ pub struct MetadataResponse { /// use keep::modes::server::common::StatusInfoResponse; /// let response = StatusInfoResponse { /// success: true, -/// data: Some(status_info), +/// data: None, /// error: None, /// }; /// ``` @@ -488,6 +490,7 @@ pub struct ListItemsQuery { /// length: 1024, /// stream: false, /// as_meta: false, +/// decompress: true, /// }; /// ``` #[derive(Debug, Deserialize, utoipa::ToSchema)] @@ -538,6 +541,7 @@ pub struct ItemQuery { /// length: 1024, /// stream: false, /// as_meta: false, +/// decompress: true, /// }; /// ``` #[derive(Debug, Deserialize, utoipa::ToSchema)] @@ -678,15 +682,15 @@ pub fn check_auth( let effective_username = username.as_deref().unwrap_or("keep"); - if let Some(auth_header) = headers.get("authorization") { - if let Ok(auth_str) = auth_header.to_str() { - return check_basic_auth( - auth_str, - effective_username, - password.as_deref().unwrap_or(""), - password_hash, - ); - } + if let Some(auth_header) = headers.get("authorization") + && let Ok(auth_str) = auth_header.to_str() + { + return check_basic_auth( + auth_str, + effective_username, + password.as_deref().unwrap_or(""), + password_hash, + ); } false } @@ -817,9 +821,13 @@ pub async fn logging_middleware( /// Creates authentication middleware for the application. /// /// This function returns a middleware that enforces authentication on protected routes. -/// When `jwt_secret` is set, it validates JWT tokens and checks permission claims -/// (read, write, delete) based on the HTTP method. Otherwise, it falls back to -/// Basic Auth password authentication. +/// +/// **JWT and Basic Auth are mutually exclusive.** When `jwt_secret` is set, the +/// middleware validates JWT (HS256) tokens and checks permission claims (read, write, +/// delete) based on the HTTP method. Requests without a valid Bearer token are +/// rejected with 401 — Basic Auth is **not** consulted as a fallback. +/// +/// When `jwt_secret` is not set, Basic Auth password authentication is used instead. /// /// # Arguments /// @@ -831,13 +839,6 @@ pub async fn logging_middleware( /// # Returns /// /// A clonable async middleware function for Axum. -/// -/// # Examples -/// -/// ``` -/// let auth_middleware = create_auth_middleware(None, Some("pass".to_string()), None, None); -/// router.layer(auth_middleware); -/// ``` #[allow(clippy::type_complexity)] pub fn create_auth_middleware( username: Option, @@ -868,40 +869,39 @@ pub fn create_auth_middleware( } // JWT authentication takes priority when secret is configured - if let Some(ref secret) = jwt_secret { - if let Some(auth_header) = headers.get("authorization") { - if let Ok(auth_str) = auth_header.to_str() { - if let Some(token) = auth_str.strip_prefix("Bearer ") { - match super::auth::validate_jwt(token, secret) { - Ok(claims) => { - let required = super::auth::required_permission(&method); - if !super::auth::check_permission(&claims, required) { - warn!( - "Forbidden: {method} {uri} from {addr} \ - (sub={}, missing permission: {required})", - claims.sub - ); - let mut response = - Response::new(axum::body::Body::from("Forbidden")); - *response.status_mut() = StatusCode::FORBIDDEN; - return Ok(response); - } - // JWT valid and authorized, proceed - let response = next.run(request).await; - return Ok(response); - } - Err(e) => { - warn!("JWT validation failed for {uri} from {addr}: {e}"); - let mut response = - Response::new(axum::body::Body::from("Unauthorized")); - *response.status_mut() = StatusCode::UNAUTHORIZED; - return Ok(response); - } - } + if let Some(ref secret) = jwt_secret + && let Some(auth_header) = headers.get("authorization") + && let Ok(auth_str) = auth_header.to_str() + && let Some(token) = auth_str.strip_prefix("Bearer ") + { + match super::auth::validate_jwt(token, secret) { + Ok(claims) => { + let required = super::auth::required_permission(&method); + if !super::auth::check_permission(&claims, required) { + warn!( + "Forbidden: {method} {uri} from {addr} \ + (sub={}, missing permission: {required})", + claims.sub + ); + let mut response = Response::new(axum::body::Body::from("Forbidden")); + *response.status_mut() = StatusCode::FORBIDDEN; + return Ok(response); } + // JWT valid and authorized, proceed + let response = next.run(request).await; + return Ok(response); + } + Err(e) => { + warn!("JWT validation failed for {uri} from {addr}: {e}"); + let mut response = Response::new(axum::body::Body::from("Unauthorized")); + *response.status_mut() = StatusCode::UNAUTHORIZED; + return Ok(response); } } - // JWT secret configured but no valid Bearer token provided + } + + // JWT secret configured but no valid Bearer token provided + if jwt_secret.is_some() { warn!("Missing JWT token for {uri} from {addr}"); let mut response = Response::new(axum::body::Body::from("Unauthorized")); *response.status_mut() = StatusCode::UNAUTHORIZED; diff --git a/src/tokenizer/mod.rs b/src/tokenizer/mod.rs index fc1b5dc..7aeeb37 100644 --- a/src/tokenizer/mod.rs +++ b/src/tokenizer/mod.rs @@ -114,9 +114,10 @@ impl Tokenizer { let mut count = 0usize; let mut byte_pos = 0usize; for token_str in self.bpe.split_by_token_iter(text, false) { - if let Ok(s) = token_str { - byte_pos += s.len(); - } + let Ok(s) = token_str else { + break; + }; + byte_pos += s.len(); count += 1; if count >= max_tokens { break;