use anyhow::Result; use axum::{ extract::{Request, ConnectInfo}, http::{HeaderMap, StatusCode}, middleware::Next, response::Response, }; use base64::Engine; use log::{info, warn}; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::net::SocketAddr; use std::path::PathBuf; use std::sync::Arc; use std::time::Instant; use tokio::sync::Mutex; use utoipa::ToSchema; #[derive(Debug, Clone)] pub struct ServerConfig { pub address: String, pub port: Option, pub password: Option, pub password_hash: Option, } #[derive(Clone)] pub struct AppState { pub db: Arc>, pub data_dir: PathBuf, pub password: Option, pub password_hash: Option, } #[derive(Serialize, Deserialize, ToSchema)] #[schema(description = "Standard API response wrapper containing success status, data payload, and error information")] pub struct ApiResponse { pub success: bool, pub data: Option, pub error: Option, } // Specific response types for OpenAPI documentation #[derive(Serialize, Deserialize, ToSchema)] pub struct ItemInfoListResponse { pub success: bool, pub data: Option>, pub error: Option, } #[derive(Serialize, Deserialize, ToSchema)] pub struct ItemInfoResponse { pub success: bool, pub data: Option, pub error: Option, } #[derive(Serialize, Deserialize, ToSchema)] pub struct ItemContentInfoResponse { pub success: bool, pub data: Option, pub error: Option, } #[derive(Serialize, Deserialize, ToSchema)] pub struct MetadataResponse { pub success: bool, pub data: Option>, pub error: Option, } #[derive(Serialize, Deserialize, ToSchema)] pub struct StatusInfoResponse { pub success: bool, pub data: Option, pub error: Option, } #[derive(Serialize, Deserialize, ToSchema)] #[schema(description = "Complete information about a stored item including metadata and tags")] pub struct ItemInfo { #[schema(example = 42)] pub id: i64, #[schema(example = "2023-12-01T15:30:45Z")] pub ts: String, #[schema(example = 1024)] pub size: Option, #[schema(example = "gzip")] pub compression: String, #[schema(example = json!(["important", "work", "document"]))] pub tags: Vec, #[schema(example = json!({"mime_type": "text/plain", "mime_encoding": "utf-8", "line_count": "42"}))] pub metadata: HashMap, } #[derive(Serialize, Deserialize, ToSchema)] #[schema(description = "Item information including content and metadata, with binary detection")] pub struct ItemContentInfo { #[serde(flatten)] #[schema(example = json!({"mime_type": "text/plain", "mime_encoding": "utf-8", "line_count": "42"}))] pub metadata: HashMap, #[schema(example = "Hello, world!\nThis is the content of the file.")] pub content: Option, #[schema(example = false)] pub binary: bool, } #[derive(Debug, Deserialize)] pub struct TagsQuery { pub tags: Option, #[serde(default)] pub allow_binary: bool, } #[derive(Debug, Deserialize)] pub struct ListItemsQuery { pub tags: Option, pub order: Option, pub start: Option, pub count: Option, } #[derive(Debug, Deserialize)] pub struct ItemQuery { #[serde(default)] pub allow_binary: bool, } fn check_bearer_auth(auth_str: &str, expected_password: &str, expected_hash: &Option) -> bool { if !auth_str.starts_with("Bearer ") { return false; } let provided_password = &auth_str[7..]; // If we have a password hash, verify against it if let Some(hash) = expected_hash { return pwhash::unix::verify(provided_password, hash); } // Otherwise, do direct comparison provided_password == expected_password } fn check_basic_auth(auth_str: &str, expected_password: &str, expected_hash: &Option) -> bool { if !auth_str.starts_with("Basic ") { return false; } let encoded = &auth_str[6..]; if let Ok(decoded_bytes) = base64::engine::general_purpose::STANDARD.decode(encoded) { if let Ok(decoded_str) = String::from_utf8(decoded_bytes) { if let Some(colon_pos) = decoded_str.find(':') { let provided_password = &decoded_str[colon_pos + 1..]; // If we have a password hash, verify against it if let Some(hash) = expected_hash { return pwhash::unix::verify(provided_password, hash); } // Otherwise, do direct comparison let expected_credentials = format!("keep:{}", expected_password); return decoded_str == expected_credentials; } } } false } pub fn check_auth(headers: &HeaderMap, password: &Option, password_hash: &Option) -> bool { // If neither password nor hash is set, no authentication required if password.is_none() && password_hash.is_none() { return true; } if let Some(auth_header) = headers.get("authorization") { if let Ok(auth_str) = auth_header.to_str() { return check_bearer_auth(auth_str, password.as_deref().unwrap_or(""), password_hash) || check_basic_auth(auth_str, password.as_deref().unwrap_or(""), password_hash); } } false } pub async fn logging_middleware( ConnectInfo(addr): ConnectInfo, request: Request, next: Next, ) -> Response { let method = request.method().clone(); let uri = request.uri().clone(); let content_length = request.headers() .get("content-length") .and_then(|v| v.to_str().ok()) .and_then(|s| s.parse::().ok()) .unwrap_or(0); let start = Instant::now(); let response = next.run(request).await; let duration = start.elapsed(); info!("{} {} {} {} {} bytes - {:?}", addr, method, uri, response.status(), content_length, duration); response } pub fn create_auth_middleware( password: Option, password_hash: Option, ) -> impl Fn(ConnectInfo, Request, Next) -> std::pin::Pin> + Send>> + Clone { move |ConnectInfo(addr): ConnectInfo, request: Request, next: Next| { let password = password.clone(); let password_hash = password_hash.clone(); Box::pin(async move { let headers = request.headers().clone(); let uri = request.uri().clone(); if !check_auth(&headers, &password, &password_hash) { warn!("Unauthorized request to {} from {}", uri, addr); return Err(StatusCode::UNAUTHORIZED); } let response = next.run(request).await; Ok(response) }) } }