feat: add support for salted password hash authentication
Co-authored-by: aider (openai/andrew/openrouter/qwen/qwen3-coder) <aider@aider.chat>
This commit is contained in:
@@ -143,6 +143,10 @@ pub struct OptionsArgs {
|
|||||||
#[arg(help("Password for server authentication (requires --server)"))]
|
#[arg(help("Password for server authentication (requires --server)"))]
|
||||||
pub server_password: Option<String>,
|
pub server_password: Option<String>,
|
||||||
|
|
||||||
|
#[arg(long, env("KEEP_SERVER_PASSWORD_HASH"))]
|
||||||
|
#[arg(help("Password hash for server authentication (requires --server)"))]
|
||||||
|
pub server_password_hash: Option<String>,
|
||||||
|
|
||||||
#[arg(long, help("Force output even when binary data would be sent to a TTY"))]
|
#[arg(long, help("Force output even when binary data would be sent to a TTY"))]
|
||||||
pub force: bool,
|
pub force: bool,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ pub struct ServerConfig {
|
|||||||
pub port: Option<u16>,
|
pub port: Option<u16>,
|
||||||
pub password_file: Option<PathBuf>,
|
pub password_file: Option<PathBuf>,
|
||||||
pub password: Option<String>,
|
pub password: Option<String>,
|
||||||
|
pub password_hash: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
@@ -49,6 +50,7 @@ pub struct Settings {
|
|||||||
pub quiet: bool,
|
pub quiet: bool,
|
||||||
pub force: bool,
|
pub force: bool,
|
||||||
pub server_password: Option<String>,
|
pub server_password: Option<String>,
|
||||||
|
pub server_password_hash: Option<String>,
|
||||||
pub server_address: Option<String>,
|
pub server_address: Option<String>,
|
||||||
pub server_port: Option<u16>,
|
pub server_port: Option<u16>,
|
||||||
pub compression: Option<String>,
|
pub compression: Option<String>,
|
||||||
@@ -90,6 +92,9 @@ impl Settings {
|
|||||||
let server_password = args.options.server_password.clone()
|
let server_password = args.options.server_password.clone()
|
||||||
.or_else(|| config.get_server_password().ok().flatten());
|
.or_else(|| config.get_server_password().ok().flatten());
|
||||||
|
|
||||||
|
let server_password_hash = args.options.server_password_hash.clone()
|
||||||
|
.or_else(|| config.server.as_ref().and_then(|s| s.password_hash.clone()));
|
||||||
|
|
||||||
let server_address = args.mode.server_address.clone()
|
let server_address = args.mode.server_address.clone()
|
||||||
.or_else(|| config.server.as_ref().and_then(|s| s.address.clone()));
|
.or_else(|| config.server.as_ref().and_then(|s| s.address.clone()));
|
||||||
|
|
||||||
@@ -119,6 +124,7 @@ impl Settings {
|
|||||||
quiet,
|
quiet,
|
||||||
force,
|
force,
|
||||||
server_password,
|
server_password,
|
||||||
|
server_password_hash,
|
||||||
server_address,
|
server_address,
|
||||||
server_port,
|
server_port,
|
||||||
compression,
|
compression,
|
||||||
|
|||||||
@@ -48,6 +48,7 @@ pub fn mode_server(
|
|||||||
address: server_address,
|
address: server_address,
|
||||||
port: Some(server_port),
|
port: Some(server_port),
|
||||||
password: settings.server_password.clone(),
|
password: settings.server_password.clone(),
|
||||||
|
password_hash: settings.server_password_hash.clone(),
|
||||||
};
|
};
|
||||||
|
|
||||||
// We need to move the connection into the async runtime
|
// We need to move the connection into the async runtime
|
||||||
@@ -78,6 +79,7 @@ async fn run_server(
|
|||||||
db: db_conn,
|
db: db_conn,
|
||||||
data_dir: data_dir.clone(),
|
data_dir: data_dir.clone(),
|
||||||
password: config.password.clone(),
|
password: config.password.clone(),
|
||||||
|
password_hash: config.password_hash.clone(),
|
||||||
};
|
};
|
||||||
|
|
||||||
let app = Router::new()
|
let app = Router::new()
|
||||||
@@ -89,7 +91,7 @@ async fn run_server(
|
|||||||
.with_state(state)
|
.with_state(state)
|
||||||
// Add middleware layers (applied in reverse order)
|
// Add middleware layers (applied in reverse order)
|
||||||
.layer(axum::middleware::from_fn(logging_middleware))
|
.layer(axum::middleware::from_fn(logging_middleware))
|
||||||
.layer(axum::middleware::from_fn(create_auth_middleware(config.password.clone())))
|
.layer(axum::middleware::from_fn(create_auth_middleware(config.password.clone(), config.password_hash.clone())))
|
||||||
.layer(
|
.layer(
|
||||||
ServiceBuilder::new()
|
ServiceBuilder::new()
|
||||||
.layer(TraceLayer::new_for_http())
|
.layer(TraceLayer::new_for_http())
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ pub struct ServerConfig {
|
|||||||
pub address: String,
|
pub address: String,
|
||||||
pub port: Option<u16>,
|
pub port: Option<u16>,
|
||||||
pub password: Option<String>,
|
pub password: Option<String>,
|
||||||
|
pub password_hash: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
@@ -29,6 +30,7 @@ pub struct AppState {
|
|||||||
pub db: Arc<Mutex<rusqlite::Connection>>,
|
pub db: Arc<Mutex<rusqlite::Connection>>,
|
||||||
pub data_dir: PathBuf,
|
pub data_dir: PathBuf,
|
||||||
pub password: Option<String>,
|
pub password: Option<String>,
|
||||||
|
pub password_hash: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, ToSchema)]
|
#[derive(Serialize, Deserialize, ToSchema)]
|
||||||
@@ -125,11 +127,23 @@ pub struct ItemQuery {
|
|||||||
pub allow_binary: bool,
|
pub allow_binary: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn check_bearer_auth(auth_str: &str, expected_password: &str) -> bool {
|
fn check_bearer_auth(auth_str: &str, expected_password: &str, expected_hash: &Option<String>) -> bool {
|
||||||
auth_str.starts_with("Bearer ") && &auth_str[7..] == expected_password
|
if !auth_str.starts_with("Bearer ") {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
let provided_password = &auth_str[7..];
|
||||||
|
|
||||||
|
// If we have a password hash, verify against it
|
||||||
|
if let Some(hash) = expected_hash {
|
||||||
|
return pwhash::unix::verify(provided_password, hash);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Otherwise, do direct comparison
|
||||||
|
provided_password == expected_password
|
||||||
}
|
}
|
||||||
|
|
||||||
fn check_basic_auth(auth_str: &str, expected_password: &str) -> bool {
|
fn check_basic_auth(auth_str: &str, expected_password: &str, expected_hash: &Option<String>) -> bool {
|
||||||
if !auth_str.starts_with("Basic ") {
|
if !auth_str.starts_with("Basic ") {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
@@ -137,25 +151,36 @@ fn check_basic_auth(auth_str: &str, expected_password: &str) -> bool {
|
|||||||
let encoded = &auth_str[6..];
|
let encoded = &auth_str[6..];
|
||||||
if let Ok(decoded_bytes) = base64::engine::general_purpose::STANDARD.decode(encoded) {
|
if let Ok(decoded_bytes) = base64::engine::general_purpose::STANDARD.decode(encoded) {
|
||||||
if let Ok(decoded_str) = String::from_utf8(decoded_bytes) {
|
if let Ok(decoded_str) = String::from_utf8(decoded_bytes) {
|
||||||
let expected_credentials = format!("keep:{}", expected_password);
|
if let Some(colon_pos) = decoded_str.find(':') {
|
||||||
return decoded_str == expected_credentials;
|
let provided_password = &decoded_str[colon_pos + 1..];
|
||||||
|
|
||||||
|
// If we have a password hash, verify against it
|
||||||
|
if let Some(hash) = expected_hash {
|
||||||
|
return pwhash::unix::verify(provided_password, hash);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Otherwise, do direct comparison
|
||||||
|
let expected_credentials = format!("keep:{}", expected_password);
|
||||||
|
return decoded_str == expected_credentials;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
false
|
false
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn check_auth(headers: &HeaderMap, password: &Option<String>) -> bool {
|
pub fn check_auth(headers: &HeaderMap, password: &Option<String>, password_hash: &Option<String>) -> bool {
|
||||||
if let Some(expected_password) = password {
|
// If neither password nor hash is set, no authentication required
|
||||||
if let Some(auth_header) = headers.get("authorization") {
|
if password.is_none() && password_hash.is_none() {
|
||||||
if let Ok(auth_str) = auth_header.to_str() {
|
return true;
|
||||||
return check_bearer_auth(auth_str, expected_password) ||
|
|
||||||
check_basic_auth(auth_str, expected_password);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
false
|
|
||||||
} else {
|
|
||||||
true // No password required
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if let Some(auth_header) = headers.get("authorization") {
|
||||||
|
if let Ok(auth_str) = auth_header.to_str() {
|
||||||
|
return check_bearer_auth(auth_str, password.as_deref().unwrap_or(""), password_hash) ||
|
||||||
|
check_basic_auth(auth_str, password.as_deref().unwrap_or(""), password_hash);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
false
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn logging_middleware(
|
pub async fn logging_middleware(
|
||||||
@@ -182,14 +207,16 @@ pub async fn logging_middleware(
|
|||||||
|
|
||||||
pub fn create_auth_middleware(
|
pub fn create_auth_middleware(
|
||||||
password: Option<String>,
|
password: Option<String>,
|
||||||
|
password_hash: Option<String>,
|
||||||
) -> impl Fn(ConnectInfo<SocketAddr>, Request, Next) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Response, StatusCode>> + Send>> + Clone {
|
) -> impl Fn(ConnectInfo<SocketAddr>, Request, Next) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Response, StatusCode>> + Send>> + Clone {
|
||||||
move |ConnectInfo(addr): ConnectInfo<SocketAddr>, request: Request, next: Next| {
|
move |ConnectInfo(addr): ConnectInfo<SocketAddr>, request: Request, next: Next| {
|
||||||
let password = password.clone();
|
let password = password.clone();
|
||||||
|
let password_hash = password_hash.clone();
|
||||||
Box::pin(async move {
|
Box::pin(async move {
|
||||||
let headers = request.headers().clone();
|
let headers = request.headers().clone();
|
||||||
let uri = request.uri().clone();
|
let uri = request.uri().clone();
|
||||||
|
|
||||||
if !check_auth(&headers, &password) {
|
if !check_auth(&headers, &password, &password_hash) {
|
||||||
warn!("Unauthorized request to {} from {}", uri, addr);
|
warn!("Unauthorized request to {} from {}", uri, addr);
|
||||||
return Err(StatusCode::UNAUTHORIZED);
|
return Err(StatusCode::UNAUTHORIZED);
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user