feat: add support for salted password hash authentication
Co-authored-by: aider (openai/andrew/openrouter/qwen/qwen3-coder) <aider@aider.chat>
This commit is contained in:
@@ -48,6 +48,7 @@ pub fn mode_server(
|
||||
address: server_address,
|
||||
port: Some(server_port),
|
||||
password: settings.server_password.clone(),
|
||||
password_hash: settings.server_password_hash.clone(),
|
||||
};
|
||||
|
||||
// We need to move the connection into the async runtime
|
||||
@@ -78,6 +79,7 @@ async fn run_server(
|
||||
db: db_conn,
|
||||
data_dir: data_dir.clone(),
|
||||
password: config.password.clone(),
|
||||
password_hash: config.password_hash.clone(),
|
||||
};
|
||||
|
||||
let app = Router::new()
|
||||
@@ -89,7 +91,7 @@ async fn run_server(
|
||||
.with_state(state)
|
||||
// Add middleware layers (applied in reverse order)
|
||||
.layer(axum::middleware::from_fn(logging_middleware))
|
||||
.layer(axum::middleware::from_fn(create_auth_middleware(config.password.clone())))
|
||||
.layer(axum::middleware::from_fn(create_auth_middleware(config.password.clone(), config.password_hash.clone())))
|
||||
.layer(
|
||||
ServiceBuilder::new()
|
||||
.layer(TraceLayer::new_for_http())
|
||||
|
||||
@@ -22,6 +22,7 @@ pub struct ServerConfig {
|
||||
pub address: String,
|
||||
pub port: Option<u16>,
|
||||
pub password: Option<String>,
|
||||
pub password_hash: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
@@ -29,6 +30,7 @@ pub struct AppState {
|
||||
pub db: Arc<Mutex<rusqlite::Connection>>,
|
||||
pub data_dir: PathBuf,
|
||||
pub password: Option<String>,
|
||||
pub password_hash: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, ToSchema)]
|
||||
@@ -125,11 +127,23 @@ pub struct ItemQuery {
|
||||
pub allow_binary: bool,
|
||||
}
|
||||
|
||||
fn check_bearer_auth(auth_str: &str, expected_password: &str) -> bool {
|
||||
auth_str.starts_with("Bearer ") && &auth_str[7..] == expected_password
|
||||
fn check_bearer_auth(auth_str: &str, expected_password: &str, expected_hash: &Option<String>) -> bool {
|
||||
if !auth_str.starts_with("Bearer ") {
|
||||
return false;
|
||||
}
|
||||
|
||||
let provided_password = &auth_str[7..];
|
||||
|
||||
// If we have a password hash, verify against it
|
||||
if let Some(hash) = expected_hash {
|
||||
return pwhash::unix::verify(provided_password, hash);
|
||||
}
|
||||
|
||||
// Otherwise, do direct comparison
|
||||
provided_password == expected_password
|
||||
}
|
||||
|
||||
fn check_basic_auth(auth_str: &str, expected_password: &str) -> bool {
|
||||
fn check_basic_auth(auth_str: &str, expected_password: &str, expected_hash: &Option<String>) -> bool {
|
||||
if !auth_str.starts_with("Basic ") {
|
||||
return false;
|
||||
}
|
||||
@@ -137,25 +151,36 @@ fn check_basic_auth(auth_str: &str, expected_password: &str) -> bool {
|
||||
let encoded = &auth_str[6..];
|
||||
if let Ok(decoded_bytes) = base64::engine::general_purpose::STANDARD.decode(encoded) {
|
||||
if let Ok(decoded_str) = String::from_utf8(decoded_bytes) {
|
||||
let expected_credentials = format!("keep:{}", expected_password);
|
||||
return decoded_str == expected_credentials;
|
||||
if let Some(colon_pos) = decoded_str.find(':') {
|
||||
let provided_password = &decoded_str[colon_pos + 1..];
|
||||
|
||||
// If we have a password hash, verify against it
|
||||
if let Some(hash) = expected_hash {
|
||||
return pwhash::unix::verify(provided_password, hash);
|
||||
}
|
||||
|
||||
// Otherwise, do direct comparison
|
||||
let expected_credentials = format!("keep:{}", expected_password);
|
||||
return decoded_str == expected_credentials;
|
||||
}
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
pub fn check_auth(headers: &HeaderMap, password: &Option<String>) -> bool {
|
||||
if let Some(expected_password) = password {
|
||||
if let Some(auth_header) = headers.get("authorization") {
|
||||
if let Ok(auth_str) = auth_header.to_str() {
|
||||
return check_bearer_auth(auth_str, expected_password) ||
|
||||
check_basic_auth(auth_str, expected_password);
|
||||
}
|
||||
}
|
||||
false
|
||||
} else {
|
||||
true // No password required
|
||||
pub fn check_auth(headers: &HeaderMap, password: &Option<String>, password_hash: &Option<String>) -> bool {
|
||||
// If neither password nor hash is set, no authentication required
|
||||
if password.is_none() && password_hash.is_none() {
|
||||
return true;
|
||||
}
|
||||
|
||||
if let Some(auth_header) = headers.get("authorization") {
|
||||
if let Ok(auth_str) = auth_header.to_str() {
|
||||
return check_bearer_auth(auth_str, password.as_deref().unwrap_or(""), password_hash) ||
|
||||
check_basic_auth(auth_str, password.as_deref().unwrap_or(""), password_hash);
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
pub async fn logging_middleware(
|
||||
@@ -182,14 +207,16 @@ pub async fn logging_middleware(
|
||||
|
||||
pub fn create_auth_middleware(
|
||||
password: Option<String>,
|
||||
password_hash: Option<String>,
|
||||
) -> impl Fn(ConnectInfo<SocketAddr>, Request, Next) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Response, StatusCode>> + Send>> + Clone {
|
||||
move |ConnectInfo(addr): ConnectInfo<SocketAddr>, request: Request, next: Next| {
|
||||
let password = password.clone();
|
||||
let password_hash = password_hash.clone();
|
||||
Box::pin(async move {
|
||||
let headers = request.headers().clone();
|
||||
let uri = request.uri().clone();
|
||||
|
||||
if !check_auth(&headers, &password) {
|
||||
if !check_auth(&headers, &password, &password_hash) {
|
||||
warn!("Unauthorized request to {} from {}", uri, addr);
|
||||
return Err(StatusCode::UNAUTHORIZED);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user