diff --git a/src/args.rs b/src/args.rs index 3f9dcc2..b1b78ce 100644 --- a/src/args.rs +++ b/src/args.rs @@ -143,6 +143,10 @@ pub struct OptionsArgs { #[arg(help("Password for server authentication (requires --server)"))] pub server_password: Option, + #[arg(long, env("KEEP_SERVER_PASSWORD_HASH"))] + #[arg(help("Password hash for server authentication (requires --server)"))] + pub server_password_hash: Option, + #[arg(long, help("Force output even when binary data would be sent to a TTY"))] pub force: bool, } diff --git a/src/config.rs b/src/config.rs index 3adf962..3f7184c 100644 --- a/src/config.rs +++ b/src/config.rs @@ -26,6 +26,7 @@ pub struct ServerConfig { pub port: Option, pub password_file: Option, pub password: Option, + pub password_hash: Option, } #[derive(Debug, Clone, Deserialize, Serialize)] @@ -49,6 +50,7 @@ pub struct Settings { pub quiet: bool, pub force: bool, pub server_password: Option, + pub server_password_hash: Option, pub server_address: Option, pub server_port: Option, pub compression: Option, @@ -90,6 +92,9 @@ impl Settings { let server_password = args.options.server_password.clone() .or_else(|| config.get_server_password().ok().flatten()); + let server_password_hash = args.options.server_password_hash.clone() + .or_else(|| config.server.as_ref().and_then(|s| s.password_hash.clone())); + let server_address = args.mode.server_address.clone() .or_else(|| config.server.as_ref().and_then(|s| s.address.clone())); @@ -119,6 +124,7 @@ impl Settings { quiet, force, server_password, + server_password_hash, server_address, server_port, compression, diff --git a/src/modes/server.rs b/src/modes/server.rs index 4dce4b0..5ffe5ed 100644 --- a/src/modes/server.rs +++ b/src/modes/server.rs @@ -48,6 +48,7 @@ pub fn mode_server( address: server_address, port: Some(server_port), password: settings.server_password.clone(), + password_hash: settings.server_password_hash.clone(), }; // We need to move the connection into the async runtime @@ -78,6 +79,7 @@ async fn run_server( db: db_conn, data_dir: data_dir.clone(), password: config.password.clone(), + password_hash: config.password_hash.clone(), }; let app = Router::new() @@ -89,7 +91,7 @@ async fn run_server( .with_state(state) // Add middleware layers (applied in reverse order) .layer(axum::middleware::from_fn(logging_middleware)) - .layer(axum::middleware::from_fn(create_auth_middleware(config.password.clone()))) + .layer(axum::middleware::from_fn(create_auth_middleware(config.password.clone(), config.password_hash.clone()))) .layer( ServiceBuilder::new() .layer(TraceLayer::new_for_http()) diff --git a/src/modes/server/common.rs b/src/modes/server/common.rs index d899dda..89c65ca 100644 --- a/src/modes/server/common.rs +++ b/src/modes/server/common.rs @@ -22,6 +22,7 @@ pub struct ServerConfig { pub address: String, pub port: Option, pub password: Option, + pub password_hash: Option, } #[derive(Clone)] @@ -29,6 +30,7 @@ pub struct AppState { pub db: Arc>, pub data_dir: PathBuf, pub password: Option, + pub password_hash: Option, } #[derive(Serialize, Deserialize, ToSchema)] @@ -125,11 +127,23 @@ pub struct ItemQuery { pub allow_binary: bool, } -fn check_bearer_auth(auth_str: &str, expected_password: &str) -> bool { - auth_str.starts_with("Bearer ") && &auth_str[7..] == expected_password +fn check_bearer_auth(auth_str: &str, expected_password: &str, expected_hash: &Option) -> bool { + if !auth_str.starts_with("Bearer ") { + return false; + } + + let provided_password = &auth_str[7..]; + + // If we have a password hash, verify against it + if let Some(hash) = expected_hash { + return pwhash::unix::verify(provided_password, hash); + } + + // Otherwise, do direct comparison + provided_password == expected_password } -fn check_basic_auth(auth_str: &str, expected_password: &str) -> bool { +fn check_basic_auth(auth_str: &str, expected_password: &str, expected_hash: &Option) -> bool { if !auth_str.starts_with("Basic ") { return false; } @@ -137,25 +151,36 @@ fn check_basic_auth(auth_str: &str, expected_password: &str) -> bool { let encoded = &auth_str[6..]; if let Ok(decoded_bytes) = base64::engine::general_purpose::STANDARD.decode(encoded) { if let Ok(decoded_str) = String::from_utf8(decoded_bytes) { - let expected_credentials = format!("keep:{}", expected_password); - return decoded_str == expected_credentials; + if let Some(colon_pos) = decoded_str.find(':') { + let provided_password = &decoded_str[colon_pos + 1..]; + + // If we have a password hash, verify against it + if let Some(hash) = expected_hash { + return pwhash::unix::verify(provided_password, hash); + } + + // Otherwise, do direct comparison + let expected_credentials = format!("keep:{}", expected_password); + return decoded_str == expected_credentials; + } } } false } -pub fn check_auth(headers: &HeaderMap, password: &Option) -> bool { - if let Some(expected_password) = password { - if let Some(auth_header) = headers.get("authorization") { - if let Ok(auth_str) = auth_header.to_str() { - return check_bearer_auth(auth_str, expected_password) || - check_basic_auth(auth_str, expected_password); - } - } - false - } else { - true // No password required +pub fn check_auth(headers: &HeaderMap, password: &Option, password_hash: &Option) -> bool { + // If neither password nor hash is set, no authentication required + if password.is_none() && password_hash.is_none() { + return true; } + + if let Some(auth_header) = headers.get("authorization") { + if let Ok(auth_str) = auth_header.to_str() { + return check_bearer_auth(auth_str, password.as_deref().unwrap_or(""), password_hash) || + check_basic_auth(auth_str, password.as_deref().unwrap_or(""), password_hash); + } + } + false } pub async fn logging_middleware( @@ -182,14 +207,16 @@ pub async fn logging_middleware( pub fn create_auth_middleware( password: Option, + password_hash: Option, ) -> impl Fn(ConnectInfo, Request, Next) -> std::pin::Pin> + Send>> + Clone { move |ConnectInfo(addr): ConnectInfo, request: Request, next: Next| { let password = password.clone(); + let password_hash = password_hash.clone(); Box::pin(async move { let headers = request.headers().clone(); let uri = request.uri().clone(); - if !check_auth(&headers, &password) { + if !check_auth(&headers, &password, &password_hash) { warn!("Unauthorized request to {} from {}", uri, addr); return Err(StatusCode::UNAUTHORIZED); }