228 lines
7.0 KiB
Rust
228 lines
7.0 KiB
Rust
use anyhow::Result;
|
|
use axum::{
|
|
extract::{Request, ConnectInfo},
|
|
http::{HeaderMap, StatusCode},
|
|
middleware::Next,
|
|
response::Response,
|
|
};
|
|
use base64::Engine;
|
|
use log::{info, warn};
|
|
use serde::{Deserialize, Serialize};
|
|
use std::collections::HashMap;
|
|
use std::net::SocketAddr;
|
|
use std::path::PathBuf;
|
|
use std::sync::Arc;
|
|
use std::time::Instant;
|
|
use tokio::sync::Mutex;
|
|
use utoipa::ToSchema;
|
|
|
|
/// Settings used to start the HTTP server.
///
/// `Debug` is implemented by hand so that secrets are never written out when
/// the config is logged with `{:?}`; it only shows whether each secret is set.
#[derive(Clone)]
pub struct ServerConfig {
    /// Address to bind to.
    pub address: String,
    /// Port to bind to, if explicitly configured.
    pub port: Option<u16>,
    /// Plaintext password required by clients, if any.
    pub password: Option<String>,
    /// Password hash to verify client credentials against, if any.
    pub password_hash: Option<String>,
}

impl std::fmt::Debug for ServerConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Replace a configured secret with a placeholder; keep `None` visible
        // so it is still clear whether authentication is enabled.
        let redact = |secret: &Option<String>| secret.as_ref().map(|_| "<redacted>");
        f.debug_struct("ServerConfig")
            .field("address", &self.address)
            .field("port", &self.port)
            .field("password", &redact(&self.password))
            .field("password_hash", &redact(&self.password_hash))
            .finish()
    }
}
|
|
|
|
#[derive(Clone)]
|
|
pub struct AppState {
|
|
pub db: Arc<Mutex<rusqlite::Connection>>,
|
|
pub data_dir: PathBuf,
|
|
pub password: Option<String>,
|
|
pub password_hash: Option<String>,
|
|
}
|
|
|
|
#[derive(Serialize, Deserialize, ToSchema)]
|
|
#[schema(description = "Standard API response wrapper containing success status, data payload, and error information")]
|
|
pub struct ApiResponse<T> {
|
|
pub success: bool,
|
|
pub data: Option<T>,
|
|
pub error: Option<String>,
|
|
}
|
|
|
|
// Specific response types for OpenAPI documentation
|
|
#[derive(Serialize, Deserialize, ToSchema)]
|
|
pub struct ItemInfoListResponse {
|
|
pub success: bool,
|
|
pub data: Option<Vec<ItemInfo>>,
|
|
pub error: Option<String>,
|
|
}
|
|
|
|
#[derive(Serialize, Deserialize, ToSchema)]
|
|
pub struct ItemInfoResponse {
|
|
pub success: bool,
|
|
pub data: Option<ItemInfo>,
|
|
pub error: Option<String>,
|
|
}
|
|
|
|
#[derive(Serialize, Deserialize, ToSchema)]
|
|
pub struct ItemContentInfoResponse {
|
|
pub success: bool,
|
|
pub data: Option<ItemContentInfo>,
|
|
pub error: Option<String>,
|
|
}
|
|
|
|
#[derive(Serialize, Deserialize, ToSchema)]
|
|
pub struct MetadataResponse {
|
|
pub success: bool,
|
|
pub data: Option<HashMap<String, String>>,
|
|
pub error: Option<String>,
|
|
}
|
|
|
|
#[derive(Serialize, Deserialize, ToSchema)]
|
|
pub struct StatusInfoResponse {
|
|
pub success: bool,
|
|
pub data: Option<crate::common::status::StatusInfo>,
|
|
pub error: Option<String>,
|
|
}
|
|
|
|
#[derive(Serialize, Deserialize, ToSchema)]
|
|
#[schema(description = "Complete information about a stored item including metadata and tags")]
|
|
pub struct ItemInfo {
|
|
#[schema(example = 42)]
|
|
pub id: i64,
|
|
#[schema(example = "2023-12-01T15:30:45Z")]
|
|
pub ts: String,
|
|
#[schema(example = 1024)]
|
|
pub size: Option<i64>,
|
|
#[schema(example = "gzip")]
|
|
pub compression: String,
|
|
#[schema(example = json!(["important", "work", "document"]))]
|
|
pub tags: Vec<String>,
|
|
#[schema(example = json!({"file.mime": "text/plain", "file.encoding": "utf-8", "line_count": "42"}))]
|
|
pub metadata: HashMap<String, String>,
|
|
}
|
|
|
|
#[derive(Serialize, Deserialize, ToSchema)]
|
|
#[schema(description = "Item information including content and metadata, with binary detection")]
|
|
pub struct ItemContentInfo {
|
|
#[serde(flatten)]
|
|
#[schema(example = json!({"file.mime": "text/plain", "file.encoding": "utf-8", "line_count": "42"}))]
|
|
pub metadata: HashMap<String, String>,
|
|
#[schema(example = "Hello, world!\nThis is the content of the file.")]
|
|
pub content: Option<String>,
|
|
#[schema(example = false)]
|
|
pub binary: bool,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
pub struct TagsQuery {
|
|
pub tags: Option<String>,
|
|
#[serde(default)]
|
|
pub allow_binary: bool,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
pub struct ListItemsQuery {
|
|
pub tags: Option<String>,
|
|
pub order: Option<String>,
|
|
pub start: Option<u32>,
|
|
pub count: Option<u32>,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
pub struct ItemQuery {
|
|
#[serde(default)]
|
|
pub allow_binary: bool,
|
|
}
|
|
|
|
fn check_bearer_auth(auth_str: &str, expected_password: &str, expected_hash: &Option<String>) -> bool {
|
|
if !auth_str.starts_with("Bearer ") {
|
|
return false;
|
|
}
|
|
|
|
let provided_password = &auth_str[7..];
|
|
|
|
// If we have a password hash, verify against it
|
|
if let Some(hash) = expected_hash {
|
|
return pwhash::unix::verify(provided_password, hash);
|
|
}
|
|
|
|
// Otherwise, do direct comparison
|
|
provided_password == expected_password
|
|
}
|
|
|
|
fn check_basic_auth(auth_str: &str, expected_password: &str, expected_hash: &Option<String>) -> bool {
|
|
if !auth_str.starts_with("Basic ") {
|
|
return false;
|
|
}
|
|
|
|
let encoded = &auth_str[6..];
|
|
if let Ok(decoded_bytes) = base64::engine::general_purpose::STANDARD.decode(encoded) {
|
|
if let Ok(decoded_str) = String::from_utf8(decoded_bytes) {
|
|
if let Some(colon_pos) = decoded_str.find(':') {
|
|
let provided_password = &decoded_str[colon_pos + 1..];
|
|
|
|
// If we have a password hash, verify against it
|
|
if let Some(hash) = expected_hash {
|
|
return pwhash::unix::verify(provided_password, hash);
|
|
}
|
|
|
|
// Otherwise, do direct comparison
|
|
let expected_credentials = format!("keep:{}", expected_password);
|
|
return decoded_str == expected_credentials;
|
|
}
|
|
}
|
|
}
|
|
false
|
|
}
|
|
|
|
pub fn check_auth(headers: &HeaderMap, password: &Option<String>, password_hash: &Option<String>) -> bool {
|
|
// If neither password nor hash is set, no authentication required
|
|
if password.is_none() && password_hash.is_none() {
|
|
return true;
|
|
}
|
|
|
|
if let Some(auth_header) = headers.get("authorization") {
|
|
if let Ok(auth_str) = auth_header.to_str() {
|
|
return check_bearer_auth(auth_str, password.as_deref().unwrap_or(""), password_hash) ||
|
|
check_basic_auth(auth_str, password.as_deref().unwrap_or(""), password_hash);
|
|
}
|
|
}
|
|
false
|
|
}
|
|
|
|
pub async fn logging_middleware(
|
|
ConnectInfo(addr): ConnectInfo<SocketAddr>,
|
|
request: Request,
|
|
next: Next,
|
|
) -> Response {
|
|
let method = request.method().clone();
|
|
let uri = request.uri().clone();
|
|
let content_length = request.headers()
|
|
.get("content-length")
|
|
.and_then(|v| v.to_str().ok())
|
|
.and_then(|s| s.parse::<u64>().ok())
|
|
.unwrap_or(0);
|
|
|
|
let start = Instant::now();
|
|
let response = next.run(request).await;
|
|
let duration = start.elapsed();
|
|
|
|
info!("{} {} {} {} {} bytes - {:?}", addr, method, uri, response.status(), content_length, duration);
|
|
|
|
response
|
|
}
|
|
|
|
pub fn create_auth_middleware(
|
|
password: Option<String>,
|
|
password_hash: Option<String>,
|
|
) -> impl Fn(ConnectInfo<SocketAddr>, Request, Next) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Response, StatusCode>> + Send>> + Clone {
|
|
move |ConnectInfo(addr): ConnectInfo<SocketAddr>, request: Request, next: Next| {
|
|
let password = password.clone();
|
|
let password_hash = password_hash.clone();
|
|
Box::pin(async move {
|
|
let headers = request.headers().clone();
|
|
let uri = request.uri().clone();
|
|
|
|
if !check_auth(&headers, &password, &password_hash) {
|
|
warn!("Unauthorized request to {} from {}", uri, addr);
|
|
return Err(StatusCode::UNAUTHORIZED);
|
|
}
|
|
|
|
let response = next.run(request).await;
|
|
Ok(response)
|
|
})
|
|
}
|
|
}
|