diff --git a/src/modes/server.rs b/src/modes/server.rs index e7d3aba..b087660 100644 --- a/src/modes/server.rs +++ b/src/modes/server.rs @@ -16,7 +16,7 @@ mod common; mod api; mod pages; -pub use common::{ServerConfig, AppState, logging_middleware}; +pub use common::{ServerConfig, AppState, logging_middleware, create_auth_middleware}; pub fn mode_server( _cmd: &mut Command, @@ -55,21 +55,21 @@ async fn run_server( }; let app = Router::new() + // Add API, documentation, and pages routes first + .merge(api::add_routes(Router::new())) + .merge(api::add_docs_routes(Router::new())) + .merge(pages::add_routes(Router::new())) + // Apply state + .with_state(state) + // Add middleware layers (applied in reverse order) .layer(axum::middleware::from_fn(logging_middleware)) + .layer(axum::middleware::from_fn(create_auth_middleware(config.password.clone()))) .layer( ServiceBuilder::new() .layer(TraceLayer::new_for_http()) .layer(CorsLayer::permissive()) ); - // Add API, documentation, and pages routes - let app = api::add_routes(app); - let app = api::add_docs_routes(app); - let app = pages::add_routes(app); - - // Apply state to the router after all routes are added - let app = app.with_state(state); - let addr: SocketAddr = if config.address.starts_with('/') || config.address.starts_with("./") { // Unix socket - not supported by axum directly, fall back to TCP warn!("Unix sockets not yet implemented, falling back to TCP on 127.0.0.1:8080"); diff --git a/src/modes/server/api/item.rs b/src/modes/server/api/item.rs index b7d97c0..6bc133d 100644 --- a/src/modes/server/api/item.rs +++ b/src/modes/server/api/item.rs @@ -1,12 +1,11 @@ use axum::{ - extract::{ConnectInfo, Path, Query, State}, - http::{HeaderMap, StatusCode}, + extract::{Path, Query, State}, + http::{StatusCode}, response::{Json, Response, IntoResponse}, http::header, }; use log::warn; use std::collections::HashMap; -use std::net::SocketAddr; use std::path::PathBuf; use std::str::FromStr; use std::io::Read; @@ -14,7 +13,7 @@ use anyhow::{Result, anyhow}; use crate::compression_engine::{CompressionType, get_compression_engine}; use crate::db; -use crate::modes::server::common::{AppState, ApiResponse, ItemInfo, TagsQuery, check_auth, ListItemsQuery}; +use crate::modes::server::common::{AppState, ApiResponse, ItemInfo, TagsQuery, ListItemsQuery}; #[utoipa::path( get, @@ -37,13 +36,7 @@ use crate::modes::server::common::{AppState, ApiResponse, ItemInfo, TagsQuery, c pub async fn handle_list_items( State(state): State, Query(params): Query, - headers: HeaderMap, - ConnectInfo(addr): ConnectInfo, ) -> Result>>, StatusCode> { - if !check_auth(&headers, &state.password) { - warn!("Unauthorized request to /api/item/ from {}", addr); - return Err(StatusCode::UNAUTHORIZED); - } let mut conn = state.db.lock().await; @@ -138,13 +131,7 @@ pub async fn handle_list_items( )] pub async fn handle_post_item( State(state): State, - headers: HeaderMap, - ConnectInfo(addr): ConnectInfo, ) -> Result>, StatusCode> { - if !check_auth(&headers, &state.password) { - warn!("Unauthorized request to POST /api/item/ from {}", addr); - return Err(StatusCode::UNAUTHORIZED); - } // This is a simplified implementation // In a real implementation, you'd need to properly parse multipart/form-data @@ -178,17 +165,9 @@ pub async fn handle_post_item( pub async fn handle_delete_item( State(state): State, Path(item_id): Path, - headers: HeaderMap, - ConnectInfo(addr): ConnectInfo, ) -> Result>, StatusCode> { - if !check_auth(&headers, &state.password) { - warn!("Unauthorized request to DELETE /api/item/{} from {}", item_id, addr); - return Err(StatusCode::UNAUTHORIZED); - } - // Validate that item ID is positive to prevent path traversal issues if item_id <= 0 { - warn!("Invalid item ID {} from {}", item_id, addr); return Err(StatusCode::BAD_REQUEST); } @@ -198,10 +177,7 @@ pub async fn handle_delete_item( warn!("Failed to get item {} for deletion: {}", item_id, e); StatusCode::INTERNAL_SERVER_ERROR })? { - db::delete_item(&mut *conn, item).map_err(|e| { - warn!("Failed to delete item {}: {}", item_id, e); - StatusCode::INTERNAL_SERVER_ERROR - })?; + db::delete_item(&mut *conn, item).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; let response = ApiResponse::<()> { success: true, @@ -233,13 +209,7 @@ pub async fn handle_delete_item( pub async fn handle_get_item_latest( State(state): State, Query(params): Query, - headers: HeaderMap, - ConnectInfo(addr): ConnectInfo, ) -> Result>, StatusCode> { - if !check_auth(&headers, &state.password) { - warn!("Unauthorized request to /api/item/latest/content from {}", addr); - return Err(StatusCode::UNAUTHORIZED); - } let mut conn = state.db.lock().await; @@ -301,17 +271,9 @@ pub async fn handle_get_item_latest( pub async fn handle_get_item( State(state): State, Path(item_id): Path, - headers: HeaderMap, - ConnectInfo(addr): ConnectInfo, ) -> Result>, StatusCode> { - if !check_auth(&headers, &state.password) { - warn!("Unauthorized request to /api/item/{}/content from {}", item_id, addr); - return Err(StatusCode::UNAUTHORIZED); - } - // Validate that item ID is positive to prevent path traversal issues if item_id <= 0 { - warn!("Invalid item ID {} from {}", item_id, addr); return Err(StatusCode::BAD_REQUEST); } @@ -364,13 +326,7 @@ pub async fn handle_get_item( pub async fn handle_get_item_latest_content( State(state): State, Query(params): Query, - headers: HeaderMap, - ConnectInfo(addr): ConnectInfo, ) -> Result { - if !check_auth(&headers, &state.password) { - warn!("Unauthorized request to /api/item/latest/content from {}", addr); - return Err(StatusCode::UNAUTHORIZED); - } let mut conn = state.db.lock().await; @@ -427,17 +383,9 @@ pub async fn handle_get_item_latest_content( pub async fn handle_get_item_content( State(state): State, Path(item_id): Path, - headers: HeaderMap, - ConnectInfo(addr): ConnectInfo, ) -> Result { - if !check_auth(&headers, &state.password) { - warn!("Unauthorized request to /api/item/{}/content from {}", item_id, addr); - return Err(StatusCode::UNAUTHORIZED); - } - // Validate that item ID is positive to prevent path traversal issues if item_id <= 0 { - warn!("Invalid item ID {} from {}", item_id, addr); return Err(StatusCode::BAD_REQUEST); } @@ -539,13 +487,7 @@ async fn get_item_raw_content(item: &db::Item, data_dir: &PathBuf, conn: &mut ru pub async fn handle_get_item_latest_meta( State(state): State, Query(params): Query, - headers: HeaderMap, - ConnectInfo(addr): ConnectInfo, ) -> Result>>, StatusCode> { - if !check_auth(&headers, &state.password) { - warn!("Unauthorized request to /api/item/latest/meta from {}", addr); - return Err(StatusCode::UNAUTHORIZED); - } let mut conn = state.db.lock().await; @@ -604,13 +546,7 @@ pub async fn handle_get_item_latest_meta( pub async fn handle_get_item_meta( State(state): State, Path(item_id): Path, - headers: HeaderMap, - ConnectInfo(addr): ConnectInfo, ) -> Result>>, StatusCode> { - if !check_auth(&headers, &state.password) { - warn!("Unauthorized request to /api/item/{}/meta from {}", item_id, addr); - return Err(StatusCode::UNAUTHORIZED); - } let mut conn = state.db.lock().await; diff --git a/src/modes/server/api/mod.rs b/src/modes/server/api/mod.rs index eba8c02..3cf5c39 100644 --- a/src/modes/server/api/mod.rs +++ b/src/modes/server/api/mod.rs @@ -53,8 +53,9 @@ pub fn add_routes(router: Router) -> Router { } pub fn add_docs_routes(router: Router) -> Router { - router.merge( - SwaggerUi::new("/swagger-ui") - .url("/api-docs/openapi.json", ApiDoc::openapi()) - ) + router + .merge(SwaggerUi::new("/swagger").url("/openapi.json", ApiDoc::openapi())) + .route("/openapi.json", axum::routing::get(|| async { + axum::Json(ApiDoc::openapi()) + })) } diff --git a/src/modes/server/api/status.rs b/src/modes/server/api/status.rs index a40de4a..b0961bf 100644 --- a/src/modes/server/api/status.rs +++ b/src/modes/server/api/status.rs @@ -1,12 +1,11 @@ use axum::{ - extract::{ConnectInfo, State}, - http::{HeaderMap, StatusCode}, + extract::State, + http::StatusCode, response::Json, }; use log::warn; -use std::net::SocketAddr; -use crate::modes::server::common::{AppState, ApiResponse, check_auth}; +use crate::modes::server::common::{AppState, ApiResponse}; use crate::common::status::{generate_status_info, StatusInfo}; use crate::meta_plugin::MetaPluginType; @@ -24,13 +23,7 @@ use crate::meta_plugin::MetaPluginType; )] pub async fn handle_status( State(state): State, - headers: HeaderMap, - ConnectInfo(addr): ConnectInfo, ) -> Result>, StatusCode> { - if !check_auth(&headers, &state.password) { - warn!("Unauthorized request to /api/status from {}", addr); - return Err(StatusCode::UNAUTHORIZED); - } // Get database path let db_path = state.db.lock().await.path().unwrap_or("unknown").to_string();