use anyhow::{Result, anyhow}; use axum::{ extract::{ConnectInfo, Path, Query, State}, http::{HeaderMap, StatusCode}, response::{Html, Json}, routing::get, Router, }; use clap::Command; use log::{debug, info, warn}; use serde::{Deserialize, Serialize}; use serde_json::json; use std::collections::HashMap; use std::io::Read; use std::net::SocketAddr; use std::path::PathBuf; use std::str::FromStr; use std::sync::Arc; use std::time::Instant; use tokio::sync::Mutex; use tower_http::cors::CorsLayer; use tower::ServiceBuilder; use tower_http::trace::TraceLayer; use crate::compression_engine::{CompressionType, get_compression_engine}; use crate::db; use crate::Args; use crate::modes::status::{StatusInfo, generate_status_info}; use crate::compression_engine::{CompressionType as CompressionTypeEnum, COMPRESSION_PROGRAMS}; use crate::compression_engine::program::CompressionEngineProgram; use crate::meta_plugin::{MetaPluginType, get_meta_plugin}; use strum::IntoEnumIterator; #[derive(Debug, Clone)] pub struct ServerConfig { pub address: String, pub password: Option, } impl FromStr for ServerConfig { type Err = anyhow::Error; fn from_str(s: &str) -> Result { Ok(ServerConfig { address: s.to_string(), password: None, }) } } #[derive(Clone)] struct AppState { db: Arc>, data_dir: PathBuf, password: Option, args: Arc, } #[derive(Serialize, Deserialize)] struct ApiResponse { success: bool, data: Option, error: Option, } #[derive(Serialize, Deserialize)] struct ItemInfo { id: i64, ts: String, size: Option, compression: String, tags: Vec, metadata: HashMap, } #[derive(Debug, Deserialize)] struct TagsQuery { tags: Option, } pub fn mode_server( _cmd: &mut Command, args: &Args, conn: &mut rusqlite::Connection, data_path: PathBuf, ) -> Result<()> { let server_address = args.mode.server.as_ref().unwrap(); let config = ServerConfig { address: server_address.clone(), password: args.options.server_password.clone(), }; // We need to move the connection into the async runtime let rt = tokio::runtime::Runtime::new()?; // Take ownership of the connection and move it into the async runtime let owned_conn = std::mem::replace(conn, rusqlite::Connection::open_in_memory()?); rt.block_on(run_server(config, owned_conn, data_path, args)) } async fn run_server( config: ServerConfig, conn: rusqlite::Connection, data_dir: PathBuf, args: &Args, ) -> Result<()> { debug!("Starting REST HTTP server on {}", config.address); // Use the existing database connection let db_conn = Arc::new(Mutex::new(conn)); let state = AppState { db: db_conn, data_dir: data_dir.clone(), password: config.password.clone(), args: Arc::new(args.clone()), }; let app = Router::new() .route("/status", get(handle_status)) .route("/item/", get(handle_list_items).put(handle_put_item)) .route("/item/:id", get(handle_get_item).delete(handle_delete_item)) .route("/content", get(handle_get_content_latest)) .route("/content/:id", get(handle_get_content)) .route("/openapi.json", get(handle_openapi)) .route("/swagger/", get(handle_swagger_ui)) .layer(axum::middleware::from_fn(logging_middleware)) .layer( ServiceBuilder::new() .layer(TraceLayer::new_for_http()) .layer(CorsLayer::permissive()) ) .with_state(state); let addr: SocketAddr = if config.address.starts_with('/') || config.address.starts_with("./") { // Unix socket - not supported by axum directly, fall back to TCP warn!("Unix sockets not yet implemented, falling back to TCP on 127.0.0.1:8080"); "127.0.0.1:8080".parse()? } else { config.address.parse()? }; info!("SERVER: HTTP server listening on {}", addr); let listener = tokio::net::TcpListener::bind(addr).await?; axum::serve( listener, app.into_make_service_with_connect_info::() ).await?; Ok(()) } // Custom middleware for logging requests and responses async fn logging_middleware( req: axum::http::Request, next: axum::middleware::Next, ) -> Result, axum::response::Response> { let method = req.method().clone(); let uri = req.uri().clone(); let headers = req.headers().clone(); // Log incoming request info!("SERVER: {} {} - Headers: {:?}", method, uri, headers); let start = Instant::now(); let response = next.run(req).await; let duration = start.elapsed(); // Log response info!("SERVER: {} {} - Status: {} - Duration: {:?}", method, uri, response.status(), duration); Ok(response) } fn check_auth(headers: &HeaderMap, password: &Option) -> bool { if let Some(expected_password) = password { if let Some(auth_header) = headers.get("authorization") { if let Ok(auth_str) = auth_header.to_str() { return auth_str.starts_with("Bearer ") && &auth_str[7..] == expected_password; } } false } else { true // No password required } } async fn handle_status( State(state): State, headers: HeaderMap, ConnectInfo(addr): ConnectInfo, ) -> Result>, StatusCode> { if !check_auth(&headers, &state.password) { warn!("Unauthorized request from {}", addr); return Err(StatusCode::UNAUTHORIZED); } // Use the actual args that the server was started with let args = &state.args; // Determine which meta plugins would be enabled for a save operation let mut meta_plugin_types: Vec = crate::modes::common::cmd_args_meta_plugin_types(&mut Command::new("keep"), args); // Add digest type if specified let digest_type = crate::modes::common::cmd_args_digest_type(&mut Command::new("keep"), args); let digest_meta_plugin_type = match digest_type { crate::meta_plugin::MetaPluginType::DigestSha256 => Some(MetaPluginType::DigestSha256), crate::meta_plugin::MetaPluginType::DigestMd5 => Some(MetaPluginType::DigestMd5), _ => None, }; if let Some(digest_plugin_type) = digest_meta_plugin_type { if !meta_plugin_types.contains(&digest_plugin_type) { meta_plugin_types.push(digest_plugin_type); } } let mut db_path = state.data_dir.clone(); db_path.push("keep-1.db"); let status_info = generate_status_info(state.data_dir.clone(), db_path, &meta_plugin_types); let response = ApiResponse { success: true, data: Some(status_info), error: None, }; Ok(Json(response)) } async fn handle_list_items( State(state): State, Query(params): Query, headers: HeaderMap, ConnectInfo(addr): ConnectInfo, ) -> Result>>, StatusCode> { if !check_auth(&headers, &state.password) { warn!("Unauthorized request to /item/ from {}", addr); return Err(StatusCode::UNAUTHORIZED); } let mut conn = state.db.lock().await; let tags: Vec = params.tags .map(|s| s.split(',').map(|t| t.trim().to_string()).collect()) .unwrap_or_default(); let items = if tags.is_empty() { db::get_items(&mut *conn).map_err(|e| { warn!("Failed to get items: {}", e); StatusCode::INTERNAL_SERVER_ERROR })? } else { db::get_items_matching(&mut *conn, &tags, &HashMap::new()) .map_err(|e| { warn!("Failed to get items matching tags {:?}: {}", tags, e); StatusCode::INTERNAL_SERVER_ERROR })? }; // Get item IDs for batch queries let item_ids: Vec = items.iter().filter_map(|item| item.id).collect(); // Get tags and metadata for all items let tags_map = db::get_tags_for_items(&mut *conn, &item_ids) .map_err(|e| { warn!("Failed to get tags for items: {}", e); StatusCode::INTERNAL_SERVER_ERROR })?; let meta_map = db::get_meta_for_items(&mut *conn, &item_ids) .map_err(|e| { warn!("Failed to get metadata for items: {}", e); StatusCode::INTERNAL_SERVER_ERROR })?; let item_infos: Vec = items .into_iter() .map(|item| { let item_id = item.id.unwrap_or(0); let item_tags = tags_map.get(&item_id) .map(|tags| tags.iter().map(|t| t.name.clone()).collect()) .unwrap_or_default(); let item_meta = meta_map.get(&item_id) .cloned() .unwrap_or_default(); ItemInfo { id: item_id, ts: item.ts.to_rfc3339(), size: item.size, compression: item.compression, tags: item_tags, metadata: item_meta, } }) .collect(); let response = ApiResponse { success: true, data: Some(item_infos), error: None, }; Ok(Json(response)) } async fn handle_get_item( State(state): State, Path(item_id): Path, Query(params): Query, headers: HeaderMap, ConnectInfo(addr): ConnectInfo, ) -> Result>, StatusCode> { if !check_auth(&headers, &state.password) { warn!("Unauthorized request to /item/{} from {}", item_id, addr); return Err(StatusCode::UNAUTHORIZED); } let mut conn = state.db.lock().await; let item = if let Ok(id) = item_id.parse::() { db::get_item(&mut *conn, id).map_err(|e| { warn!("Failed to get item {}: {}", id, e); StatusCode::INTERNAL_SERVER_ERROR })? } else { // Try to find by tags if let Some(tags_str) = params.tags { let tags: Vec = tags_str.split(',').map(|t| t.trim().to_string()).collect(); db::get_item_matching(&mut *conn, &tags, &HashMap::new()) .map_err(|e| { warn!("Failed to get item matching tags {:?}: {}", tags, e); StatusCode::INTERNAL_SERVER_ERROR })? } else { warn!("Invalid item ID '{}' and no tags provided", item_id); return Err(StatusCode::BAD_REQUEST); } }; if let Some(item) = item { let item_tags = db::get_item_tags(&mut *conn, &item) .map_err(|e| { warn!("Failed to get tags for item {}: {}", item.id.unwrap_or(0), e); StatusCode::INTERNAL_SERVER_ERROR })? .into_iter() .map(|t| t.name) .collect(); let item_meta = db::get_item_meta(&mut *conn, &item) .map_err(|e| { warn!("Failed to get metadata for item {}: {}", item.id.unwrap_or(0), e); StatusCode::INTERNAL_SERVER_ERROR })? .into_iter() .map(|m| (m.name, m.value)) .collect(); let item_info = ItemInfo { id: item.id.unwrap_or(0), ts: item.ts.to_rfc3339(), size: item.size, compression: item.compression, tags: item_tags, metadata: item_meta, }; let response = ApiResponse { success: true, data: Some(item_info), error: None, }; Ok(Json(response)) } else { Err(StatusCode::NOT_FOUND) } } async fn handle_put_item( State(state): State, headers: HeaderMap, ConnectInfo(addr): ConnectInfo, ) -> Result>, StatusCode> { if !check_auth(&headers, &state.password) { warn!("Unauthorized request to PUT /item/ from {}", addr); return Err(StatusCode::UNAUTHORIZED); } // This is a simplified implementation // In a real implementation, you'd need to properly parse multipart/form-data // or JSON payload with the item data let response = ApiResponse:: { success: false, data: None, error: Some("PUT /item/ not yet implemented".to_string()), }; Ok(Json(response)) } async fn handle_delete_item( State(state): State, Path(item_id): Path, headers: HeaderMap, ConnectInfo(addr): ConnectInfo, ) -> Result>, StatusCode> { if !check_auth(&headers, &state.password) { warn!("Unauthorized request to DELETE /item/{} from {}", item_id, addr); return Err(StatusCode::UNAUTHORIZED); } if let Ok(id) = item_id.parse::() { let mut conn = state.db.lock().await; if let Some(item) = db::get_item(&mut *conn, id).map_err(|e| { warn!("Failed to get item {} for deletion: {}", id, e); StatusCode::INTERNAL_SERVER_ERROR })? { db::delete_item(&mut *conn, item).map_err(|e| { warn!("Failed to delete item {}: {}", id, e); StatusCode::INTERNAL_SERVER_ERROR })?; let response = ApiResponse::<()> { success: true, data: None, error: None, }; Ok(Json(response)) } else { Err(StatusCode::NOT_FOUND) } } else { Err(StatusCode::BAD_REQUEST) } } async fn handle_get_content_latest( State(state): State, Query(params): Query, headers: HeaderMap, ConnectInfo(addr): ConnectInfo, ) -> Result>, StatusCode> { if !check_auth(&headers, &state.password) { warn!("Unauthorized request to /content from {}", addr); return Err(StatusCode::UNAUTHORIZED); } let mut conn = state.db.lock().await; let item = if let Some(tags_str) = params.tags { let tags: Vec = tags_str.split(',').map(|t| t.trim().to_string()).collect(); db::get_item_matching(&mut *conn, &tags, &HashMap::new()) .map_err(|e| { warn!("Failed to get item matching tags {:?} for content: {}", tags, e); StatusCode::INTERNAL_SERVER_ERROR })? } else { db::get_item_last(&mut *conn).map_err(|e| { warn!("Failed to get last item for content: {}", e); StatusCode::INTERNAL_SERVER_ERROR })? }; if let Some(item) = item { match get_item_content(&item, &state.data_dir).await { Ok(content) => { let response = ApiResponse { success: true, data: Some(content), error: None, }; Ok(Json(response)) } Err(e) => { warn!("Failed to get content for item {}: {}", item.id.unwrap_or(0), e); let response = ApiResponse:: { success: false, data: None, error: Some(format!("Failed to retrieve content: {}", e)), }; Ok(Json(response)) } } } else { Err(StatusCode::NOT_FOUND) } } async fn handle_get_content( State(state): State, Path(item_id): Path, headers: HeaderMap, ConnectInfo(addr): ConnectInfo, ) -> Result>, StatusCode> { if !check_auth(&headers, &state.password) { warn!("Unauthorized request to /content/{} from {}", item_id, addr); return Err(StatusCode::UNAUTHORIZED); } if let Ok(id) = item_id.parse::() { // Validate that item ID is positive to prevent path traversal issues if id <= 0 { warn!("Invalid item ID {} from {}", id, addr); return Err(StatusCode::BAD_REQUEST); } let mut conn = state.db.lock().await; if let Some(item) = db::get_item(&mut *conn, id).map_err(|e| { warn!("Failed to get item {} for content: {}", id, e); StatusCode::INTERNAL_SERVER_ERROR })? { match get_item_content(&item, &state.data_dir).await { Ok(content) => { let response = ApiResponse { success: true, data: Some(content), error: None, }; Ok(Json(response)) } Err(e) => { warn!("Failed to get content for item {}: {}", id, e); let response = ApiResponse:: { success: false, data: None, error: Some(format!("Failed to retrieve content: {}", e)), }; Ok(Json(response)) } } } else { Err(StatusCode::NOT_FOUND) } } else { Err(StatusCode::BAD_REQUEST) } } async fn handle_openapi() -> Json { let openapi_spec = json!({ "openapi": "3.0.0", "info": { "title": "Keep API", "version": "1.0.0", "description": "REST API for the Keep data storage system" }, "servers": [ { "url": "/", "description": "Local server" } ], "components": { "securitySchemes": { "bearerAuth": { "type": "http", "scheme": "bearer" } }, "schemas": { "ItemInfo": { "type": "object", "properties": { "id": {"type": "integer"}, "ts": {"type": "string", "format": "date-time"}, "size": {"type": "integer", "nullable": true}, "compression": {"type": "string"}, "tags": {"type": "array", "items": {"type": "string"}}, "metadata": {"type": "object"} } }, "StatusInfo": { "type": "object", "properties": { "version": {"type": "string"}, "database_path": {"type": "string"}, "data_directory": {"type": "string"}, "compression_engines": {"type": "array", "items": {"type": "string"}}, "meta_plugins": {"type": "array", "items": {"type": "string"}} } } } }, "security": [{"bearerAuth": []}], "paths": { "/status": { "get": { "summary": "Get system status", "responses": { "200": { "description": "System status", "content": { "application/json": { "schema": {"$ref": "#/components/schemas/StatusInfo"} } } } } } }, "/item/": { "get": { "summary": "List items", "parameters": [ { "name": "tags", "in": "query", "schema": {"type": "string"}, "description": "Comma-separated list of tags to filter by" } ], "responses": { "200": { "description": "List of items", "content": { "application/json": { "schema": { "type": "array", "items": {"$ref": "#/components/schemas/ItemInfo"} } } } } } }, "put": { "summary": "Add new item", "responses": { "201": { "description": "Item created", "content": { "application/json": { "schema": {"$ref": "#/components/schemas/ItemInfo"} } } } } } }, "/item/{id}": { "get": { "summary": "Get item by ID", "parameters": [ { "name": "id", "in": "path", "required": true, "schema": {"type": "string"}, "description": "Item ID or use tags query parameter" }, { "name": "tags", "in": "query", "schema": {"type": "string"}, "description": "Comma-separated list of tags (when ID is not numeric)" } ], "responses": { "200": { "description": "Item information", "content": { "application/json": { "schema": {"$ref": "#/components/schemas/ItemInfo"} } } }, "404": {"description": "Item not found"} } }, "delete": { "summary": "Delete item by ID", "parameters": [ { "name": "id", "in": "path", "required": true, "schema": {"type": "integer"} } ], "responses": { "200": {"description": "Item deleted"}, "404": {"description": "Item not found"} } } }, "/content": { "get": { "summary": "Get content of latest item", "parameters": [ { "name": "tags", "in": "query", "schema": {"type": "string"}, "description": "Comma-separated list of tags to filter by" } ], "responses": { "200": {"description": "Item content"}, "404": {"description": "No items found"} } } }, "/content/{id}": { "get": { "summary": "Get content by item ID", "parameters": [ { "name": "id", "in": "path", "required": true, "schema": {"type": "integer"} } ], "responses": { "200": {"description": "Item content"}, "404": {"description": "Item not found"} } } }, } }); Json(openapi_spec) } async fn get_item_content(item: &db::Item, data_dir: &PathBuf) -> Result { let item_id = item.id.ok_or_else(|| anyhow!("Item missing ID"))?; // Validate that item ID is positive to prevent path traversal issues if item_id <= 0 { return Err(anyhow!("Invalid item ID: {}", item_id)); } let mut item_path = data_dir.clone(); item_path.push(item_id.to_string()); let compression_type = CompressionType::from_str(&item.compression)?; let compression_engine = get_compression_engine(compression_type)?; // Read the content using the compression engine let mut reader = compression_engine.open(item_path)?; let mut content = String::new(); reader.read_to_string(&mut content)?; Ok(content) } async fn handle_swagger_ui() -> Html<&'static str> { let html = r#" Keep API Documentation
"#; Html(html) }