fix: count_bounded error counting, clippy if-let, auth test dedup, doc tests

- count_bounded: break on iterator error instead of counting errors as tokens
- collapse nested if-let chains with let-chains in auth middleware
- document JWT/Basic Auth as mutually exclusive
- TailTokensFilter::clone uses empty buffer (always pre-filter)
- fix 9 broken doc examples in server/common.rs
- remove 7 duplicate auth tests from auth.rs (covered by auth_tests.rs)
This commit is contained in:
2026-03-13 22:04:38 -03:00
parent a07bb6b350
commit 560ba6e20c
4 changed files with 67 additions and 179 deletions

View File

@@ -275,7 +275,7 @@ impl FilterPlugin for TailTokensFilter {
fn clone_box(&self) -> Box<dyn FilterPlugin> { fn clone_box(&self) -> Box<dyn FilterPlugin> {
Box::new(Self { Box::new(Self {
count: self.count, count: self.count,
buffer: self.buffer.clone(), buffer: Vec::new(),
tokenizer: get_tokenizer(self.encoding).clone(), tokenizer: get_tokenizer(self.encoding).clone(),
encoding: self.encoding, encoding: self.encoding,
}) })

View File

@@ -116,116 +116,3 @@ pub fn validate_jwt(token: &str, secret: &str) -> Result<Claims, String> {
Ok(token_data.claims) Ok(token_data.claims)
} }
#[cfg(test)]
mod tests {
use super::*;
use jsonwebtoken::{EncodingKey, Header, encode};
fn make_token(claims: &serde_json::Value, secret: &str) -> String {
let header = Header::new(jsonwebtoken::Algorithm::HS256);
encode(
&header,
claims,
&EncodingKey::from_secret(secret.as_bytes()),
)
.unwrap()
}
#[test]
fn test_validate_jwt_valid_token() {
let secret = "test-secret";
let claims = serde_json::json!({
"sub": "test-client",
"exp": 9999999999usize,
"read": true,
"write": true,
"delete": false
});
let token = make_token(&claims, secret);
let result = validate_jwt(&token, secret);
assert!(result.is_ok());
let claims = result.unwrap();
assert_eq!(claims.sub, "test-client");
assert!(claims.read);
assert!(claims.write);
assert!(!claims.delete);
}
#[test]
fn test_validate_jwt_expired_token() {
let secret = "test-secret";
let claims = serde_json::json!({
"sub": "test-client",
"exp": 1000000000usize,
"read": true
});
let token = make_token(&claims, secret);
let result = validate_jwt(&token, secret);
assert!(result.is_err());
assert_eq!(result.unwrap_err(), "Token expired");
}
#[test]
fn test_validate_jwt_wrong_secret() {
let claims = serde_json::json!({
"sub": "test-client",
"exp": 9999999999usize,
"read": true
});
let token = make_token(&claims, "correct-secret");
let result = validate_jwt(&token, "wrong-secret");
assert!(result.is_err());
}
#[test]
fn test_validate_jwt_malformed_token() {
let result = validate_jwt("not.a.jwt", "secret");
assert!(result.is_err());
}
#[test]
fn test_required_permission() {
assert_eq!(required_permission(&Method::GET), "read");
assert_eq!(required_permission(&Method::HEAD), "read");
assert_eq!(required_permission(&Method::POST), "write");
assert_eq!(required_permission(&Method::PUT), "write");
assert_eq!(required_permission(&Method::PATCH), "write");
assert_eq!(required_permission(&Method::DELETE), "delete");
}
#[test]
fn test_check_permission() {
let claims = Claims {
sub: "test".to_string(),
exp: 9999999999,
read: true,
write: false,
delete: true,
};
assert!(check_permission(&claims, "read"));
assert!(!check_permission(&claims, "write"));
assert!(check_permission(&claims, "delete"));
assert!(!check_permission(&claims, "unknown"));
}
#[test]
fn test_check_permission_default_false() {
// When fields are missing from JSON, serde(default) makes them false
let secret = "test-secret";
let claims = serde_json::json!({
"sub": "test-client",
"exp": 9999999999usize
});
let token = make_token(&claims, secret);
let claims = validate_jwt(&token, secret).unwrap();
assert!(!claims.read);
assert!(!claims.write);
assert!(!claims.delete);
}
}

View File

@@ -7,10 +7,10 @@ use crate::services::item_service::ItemService;
/// ///
/// # Usage /// # Usage
/// ///
/// ```rust /// ```rust,ignore
/// // Illustrative — requires runtime values (db connection, settings).
/// use keep::modes::server::common::{ServerConfig, AppState}; /// use keep::modes::server::common::{ServerConfig, AppState};
/// let config = ServerConfig { address: "127.0.0.1".to_string(), ..Default::default() }; /// let config = ServerConfig { address: "127.0.0.1".to_string(), port: Some(8080), /* ... */ };
/// let state = AppState { /* ... */ };
/// ``` /// ```
use anyhow::Result; use anyhow::Result;
use axum::{ use axum::{
@@ -38,7 +38,8 @@ use utoipa::ToSchema;
/// ///
/// # Examples /// # Examples
/// ///
/// ``` /// ```rust
/// use keep::modes::server::common::ServerConfig;
/// let config = ServerConfig { /// let config = ServerConfig {
/// address: "127.0.0.1".to_string(), /// address: "127.0.0.1".to_string(),
/// port: Some(8080), /// port: Some(8080),
@@ -105,7 +106,8 @@ pub struct ServerConfig {
/// ///
/// # Examples /// # Examples
/// ///
/// ```rust /// ```rust,ignore
/// // AppState requires runtime values (db connection, settings) not available in doctests.
/// use keep::modes::server::common::AppState; /// use keep::modes::server::common::AppState;
/// use std::sync::Arc; /// use std::sync::Arc;
/// use tokio::sync::Mutex; /// use tokio::sync::Mutex;
@@ -155,9 +157,9 @@ pub struct AppState {
/// ///
/// ```rust /// ```rust
/// use keep::modes::server::common::ApiResponse; /// use keep::modes::server::common::ApiResponse;
/// let response: ApiResponse<Vec<ItemInfo>> = ApiResponse { /// let response: ApiResponse<String> = ApiResponse {
/// success: true, /// success: true,
/// data: Some(items), /// data: Some("items".to_string()),
/// error: None, /// error: None,
/// }; /// };
/// ``` /// ```
@@ -190,7 +192,7 @@ pub struct ApiResponse<T> {
/// use keep::modes::server::common::ItemInfoListResponse; /// use keep::modes::server::common::ItemInfoListResponse;
/// let response = ItemInfoListResponse { /// let response = ItemInfoListResponse {
/// success: true, /// success: true,
/// data: Some(vec![item_info]), /// data: Some(vec![]),
/// error: None, /// error: None,
/// }; /// };
/// ``` /// ```
@@ -220,7 +222,7 @@ pub struct ItemInfoListResponse {
/// use keep::modes::server::common::ItemInfoResponse; /// use keep::modes::server::common::ItemInfoResponse;
/// let response = ItemInfoResponse { /// let response = ItemInfoResponse {
/// success: true, /// success: true,
/// data: Some(item_info), /// data: None,
/// error: None, /// error: None,
/// }; /// };
/// ``` /// ```
@@ -250,7 +252,7 @@ pub struct ItemInfoResponse {
/// use keep::modes::server::common::ItemContentInfoResponse; /// use keep::modes::server::common::ItemContentInfoResponse;
/// let response = ItemContentInfoResponse { /// let response = ItemContentInfoResponse {
/// success: true, /// success: true,
/// data: Some(content_info), /// data: None,
/// error: None, /// error: None,
/// }; /// };
/// ``` /// ```
@@ -280,7 +282,7 @@ pub struct ItemContentInfoResponse {
/// use keep::modes::server::common::MetadataResponse; /// use keep::modes::server::common::MetadataResponse;
/// let response = MetadataResponse { /// let response = MetadataResponse {
/// success: true, /// success: true,
/// data: Some(meta_map), /// data: None,
/// error: None, /// error: None,
/// }; /// };
/// ``` /// ```
@@ -310,7 +312,7 @@ pub struct MetadataResponse {
/// use keep::modes::server::common::StatusInfoResponse; /// use keep::modes::server::common::StatusInfoResponse;
/// let response = StatusInfoResponse { /// let response = StatusInfoResponse {
/// success: true, /// success: true,
/// data: Some(status_info), /// data: None,
/// error: None, /// error: None,
/// }; /// };
/// ``` /// ```
@@ -488,6 +490,7 @@ pub struct ListItemsQuery {
/// length: 1024, /// length: 1024,
/// stream: false, /// stream: false,
/// as_meta: false, /// as_meta: false,
/// decompress: true,
/// }; /// };
/// ``` /// ```
#[derive(Debug, Deserialize, utoipa::ToSchema)] #[derive(Debug, Deserialize, utoipa::ToSchema)]
@@ -538,6 +541,7 @@ pub struct ItemQuery {
/// length: 1024, /// length: 1024,
/// stream: false, /// stream: false,
/// as_meta: false, /// as_meta: false,
/// decompress: true,
/// }; /// };
/// ``` /// ```
#[derive(Debug, Deserialize, utoipa::ToSchema)] #[derive(Debug, Deserialize, utoipa::ToSchema)]
@@ -678,8 +682,9 @@ pub fn check_auth(
let effective_username = username.as_deref().unwrap_or("keep"); let effective_username = username.as_deref().unwrap_or("keep");
if let Some(auth_header) = headers.get("authorization") { if let Some(auth_header) = headers.get("authorization")
if let Ok(auth_str) = auth_header.to_str() { && let Ok(auth_str) = auth_header.to_str()
{
return check_basic_auth( return check_basic_auth(
auth_str, auth_str,
effective_username, effective_username,
@@ -687,7 +692,6 @@ pub fn check_auth(
password_hash, password_hash,
); );
} }
}
false false
} }
@@ -817,9 +821,13 @@ pub async fn logging_middleware(
/// Creates authentication middleware for the application. /// Creates authentication middleware for the application.
/// ///
/// This function returns a middleware that enforces authentication on protected routes. /// This function returns a middleware that enforces authentication on protected routes.
/// When `jwt_secret` is set, it validates JWT tokens and checks permission claims ///
/// (read, write, delete) based on the HTTP method. Otherwise, it falls back to /// **JWT and Basic Auth are mutually exclusive.** When `jwt_secret` is set, the
/// Basic Auth password authentication. /// middleware validates JWT (HS256) tokens and checks permission claims (read, write,
/// delete) based on the HTTP method. Requests without a valid Bearer token are
/// rejected with 401 — Basic Auth is **not** consulted as a fallback.
///
/// When `jwt_secret` is not set, Basic Auth password authentication is used instead.
/// ///
/// # Arguments /// # Arguments
/// ///
@@ -831,13 +839,6 @@ pub async fn logging_middleware(
/// # Returns /// # Returns
/// ///
/// A clonable async middleware function for Axum. /// A clonable async middleware function for Axum.
///
/// # Examples
///
/// ```
/// let auth_middleware = create_auth_middleware(None, Some("pass".to_string()), None, None);
/// router.layer(auth_middleware);
/// ```
#[allow(clippy::type_complexity)] #[allow(clippy::type_complexity)]
pub fn create_auth_middleware( pub fn create_auth_middleware(
username: Option<String>, username: Option<String>,
@@ -868,10 +869,11 @@ pub fn create_auth_middleware(
} }
// JWT authentication takes priority when secret is configured // JWT authentication takes priority when secret is configured
if let Some(ref secret) = jwt_secret { if let Some(ref secret) = jwt_secret
if let Some(auth_header) = headers.get("authorization") { && let Some(auth_header) = headers.get("authorization")
if let Ok(auth_str) = auth_header.to_str() { && let Ok(auth_str) = auth_header.to_str()
if let Some(token) = auth_str.strip_prefix("Bearer ") { && let Some(token) = auth_str.strip_prefix("Bearer ")
{
match super::auth::validate_jwt(token, secret) { match super::auth::validate_jwt(token, secret) {
Ok(claims) => { Ok(claims) => {
let required = super::auth::required_permission(&method); let required = super::auth::required_permission(&method);
@@ -881,8 +883,7 @@ pub fn create_auth_middleware(
(sub={}, missing permission: {required})", (sub={}, missing permission: {required})",
claims.sub claims.sub
); );
let mut response = let mut response = Response::new(axum::body::Body::from("Forbidden"));
Response::new(axum::body::Body::from("Forbidden"));
*response.status_mut() = StatusCode::FORBIDDEN; *response.status_mut() = StatusCode::FORBIDDEN;
return Ok(response); return Ok(response);
} }
@@ -892,16 +893,15 @@ pub fn create_auth_middleware(
} }
Err(e) => { Err(e) => {
warn!("JWT validation failed for {uri} from {addr}: {e}"); warn!("JWT validation failed for {uri} from {addr}: {e}");
let mut response = let mut response = Response::new(axum::body::Body::from("Unauthorized"));
Response::new(axum::body::Body::from("Unauthorized"));
*response.status_mut() = StatusCode::UNAUTHORIZED; *response.status_mut() = StatusCode::UNAUTHORIZED;
return Ok(response); return Ok(response);
} }
} }
} }
}
}
// JWT secret configured but no valid Bearer token provided // JWT secret configured but no valid Bearer token provided
if jwt_secret.is_some() {
warn!("Missing JWT token for {uri} from {addr}"); warn!("Missing JWT token for {uri} from {addr}");
let mut response = Response::new(axum::body::Body::from("Unauthorized")); let mut response = Response::new(axum::body::Body::from("Unauthorized"));
*response.status_mut() = StatusCode::UNAUTHORIZED; *response.status_mut() = StatusCode::UNAUTHORIZED;

View File

@@ -114,9 +114,10 @@ impl Tokenizer {
let mut count = 0usize; let mut count = 0usize;
let mut byte_pos = 0usize; let mut byte_pos = 0usize;
for token_str in self.bpe.split_by_token_iter(text, false) { for token_str in self.bpe.split_by_token_iter(text, false) {
if let Ok(s) = token_str { let Ok(s) = token_str else {
break;
};
byte_pos += s.len(); byte_pos += s.len();
}
count += 1; count += 1;
if count >= max_tokens { if count >= max_tokens {
break; break;