diff --git a/Cargo.toml b/Cargo.toml index 89c8237..4f2a0c7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -43,6 +43,11 @@ uzers = "0.11.3" serde = { version = "1.0.219", features = ["derive"] } serde_json = "1.0.142" serde_yaml = "0.9.34" +tokio = { version = "1.0", features = ["full"] } +axum = "0.7" +tower = "0.4" +tower-http = { version = "0.5", features = ["cors", "fs"] } +hyper = { version = "1.0", features = ["full"] } [dev-dependencies] tempfile = "3.3.0" diff --git a/DESIGN.md b/DESIGN.md index 9c03987..ad4280e 100644 --- a/DESIGN.md +++ b/DESIGN.md @@ -31,6 +31,7 @@ - `modes/info.rs` - Show detailed item information - `modes/diff.rs` - Compare two items - `modes/status.rs` - Show system status and capabilities +- `modes/server.rs` - REST HTTP server mode with OpenAPI documentation - `modes/common.rs` - Shared utilities for all modes ### Database Module diff --git a/src/main.rs b/src/main.rs index fc82d76..96799ce 100644 --- a/src/main.rs +++ b/src/main.rs @@ -80,9 +80,13 @@ struct ModeArgs { ))] info: bool, - #[arg(group("mode"), help_heading("Mode Options"), short('S'), long, conflicts_with_all(["save", "get", "diff", "list", "update", "delete", "info"]))] + #[arg(group("mode"), help_heading("Mode Options"), short('S'), long, conflicts_with_all(["save", "get", "diff", "list", "update", "delete", "info", "server"]))] #[arg(help("Show status of directories and supported compression algorithms"))] status: bool, + + #[arg(group("mode"), help_heading("Mode Options"), long, conflicts_with_all(["save", "get", "diff", "list", "update", "delete", "info", "status"]))] + #[arg(help("Start REST HTTP server on specified address:port or socket path"))] + server: Option, } /** @@ -142,6 +146,10 @@ struct OptionsArgs { #[arg(long, value_enum, default_value("table"))] #[arg(help("Output format (only works with --info, --status, --list)"))] output_format: Option, + + #[arg(long, env("KEEP_SERVER_PASSWORD"))] + #[arg(help("Password for server authentication (requires --server)"))] + server_password: Option, } /** @@ -158,6 +166,7 @@ enum KeepModes { Delete, Info, Status, + Server, } /** @@ -251,6 +260,8 @@ fn main() -> Result<(), Error> { mode = KeepModes::Info; } else if args.mode.status { mode = KeepModes::Status; + } else if args.mode.server.is_some() { + mode = KeepModes::Server; } if mode == KeepModes::Unknown { @@ -279,6 +290,14 @@ fn main() -> Result<(), Error> { ).exit(); } + // Validate server password usage + if args.options.server_password.is_some() && mode != KeepModes::Server { + cmd.error( + ErrorKind::InvalidValue, + "--server-password can only be used with --server mode" + ).exit(); + } + debug!("MAIN: args: {:?}", args); debug!("MAIN: ids: {:?}", ids); debug!("MAIN: tags: {:?}", tags); @@ -333,6 +352,9 @@ fn main() -> Result<(), Error> { KeepModes::Status => { crate::modes::status::mode_status(&mut cmd, &args, data_path, db_path)? } + KeepModes::Server => { + crate::modes::server::mode_server(&mut cmd, &args, &mut conn, data_path)? + } _ => todo!(), } diff --git a/src/modes/mod.rs b/src/modes/mod.rs index 402c2e0..1f66a60 100644 --- a/src/modes/mod.rs +++ b/src/modes/mod.rs @@ -5,5 +5,6 @@ pub mod get; pub mod info; pub mod list; pub mod save; +pub mod server; pub mod status; pub mod update; diff --git a/src/modes/server.rs b/src/modes/server.rs index e69de29..70dba2d 100644 --- a/src/modes/server.rs +++ b/src/modes/server.rs @@ -0,0 +1,658 @@ +use anyhow::{anyhow, Result}; +use axum::{ + extract::{Path, Query, State}, + http::{HeaderMap, StatusCode}, + response::{Html, Json}, + routing::{delete, get, put}, + Router, +}; +use clap::Command; +use log::{debug, info, warn}; +use serde::{Deserialize, Serialize}; +use serde_json::json; +use std::collections::HashMap; +use std::net::SocketAddr; +use std::path::PathBuf; +use std::str::FromStr; +use std::sync::Arc; +use tokio::sync::Mutex; +use tower_http::cors::CorsLayer; + +use crate::db::{self, Item, Tag, Meta}; +use crate::modes::common::{format_size, OutputFormat}; +use crate::Args; + +#[derive(Debug, Clone)] +pub struct ServerConfig { + pub address: String, + pub password: Option, +} + +impl FromStr for ServerConfig { + type Err = anyhow::Error; + + fn from_str(s: &str) -> Result { + Ok(ServerConfig { + address: s.to_string(), + password: None, + }) + } +} + +#[derive(Clone)] +struct AppState { + db: Arc>, + data_dir: PathBuf, + password: Option, +} + +#[derive(Serialize, Deserialize)] +struct ApiResponse { + success: bool, + data: Option, + error: Option, +} + +#[derive(Serialize, Deserialize)] +struct ItemInfo { + id: i64, + ts: String, + size: Option, + compression: String, + tags: Vec, + metadata: HashMap, +} + +#[derive(Serialize, Deserialize)] +struct StatusInfo { + version: String, + database_path: String, + data_directory: String, + compression_engines: Vec, + meta_plugins: Vec, +} + +#[derive(Deserialize)] +struct TagsQuery { + tags: Option, +} + +pub fn mode_server( + cmd: &mut Command, + args: &Args, + conn: &mut rusqlite::Connection, + data_path: PathBuf, +) -> Result<()> { + let server_address = args.mode.server.as_ref().unwrap(); + + let config = ServerConfig { + address: server_address.clone(), + password: args.options.server_password.clone(), + }; + + // We need to move the connection into the async runtime + let rt = tokio::runtime::Runtime::new()?; + rt.block_on(run_server(config, conn, data_path)) +} + +async fn run_server( + config: ServerConfig, + conn: &mut rusqlite::Connection, + data_dir: PathBuf, +) -> Result<()> { + info!("Starting REST HTTP server on {}", config.address); + + // Move connection into Arc> for sharing across async tasks + // Note: This is a simplified approach. In production, you'd want a connection pool + let db_conn = Arc::new(Mutex::new(std::mem::replace(conn, unsafe { std::mem::zeroed() }))); + + let state = AppState { + db: db_conn, + data_dir: data_dir.clone(), + password: config.password.clone(), + }; + + let app = Router::new() + .route("/status", get(handle_status)) + .route("/item/", get(handle_list_items).put(handle_put_item)) + .route("/item/:id", get(handle_get_item).delete(handle_delete_item)) + .route("/content", get(handle_get_content_latest)) + .route("/content/:id", get(handle_get_content)) + .route("/openapi.json", get(handle_openapi)) + .route("/swagger/", get(handle_swagger_ui)) + .layer(CorsLayer::permissive()) + .with_state(state); + + let addr: SocketAddr = if config.address.starts_with('/') || config.address.starts_with("./") { + // Unix socket - not supported by axum directly, fall back to TCP + warn!("Unix sockets not yet implemented, falling back to TCP on 127.0.0.1:8080"); + "127.0.0.1:8080".parse()? + } else { + config.address.parse()? + }; + + info!("Server listening on {}", addr); + + let listener = tokio::net::TcpListener::bind(addr).await?; + axum::serve(listener, app).await?; + + Ok(()) +} + +fn check_auth(headers: &HeaderMap, password: &Option) -> bool { + if let Some(expected_password) = password { + if let Some(auth_header) = headers.get("authorization") { + if let Ok(auth_str) = auth_header.to_str() { + return auth_str.starts_with("Bearer ") && &auth_str[7..] == expected_password; + } + } + false + } else { + true // No password required + } +} + +async fn handle_status( + State(state): State, + headers: HeaderMap, +) -> Result>, StatusCode> { + if !check_auth(&headers, &state.password) { + return Err(StatusCode::UNAUTHORIZED); + } + + let mut db_path = state.data_dir.clone(); + db_path.push("keep-1.db"); + + let status = StatusInfo { + version: env!("CARGO_PKG_VERSION").to_string(), + database_path: db_path.to_string_lossy().to_string(), + data_directory: state.data_dir.to_string_lossy().to_string(), + compression_engines: vec!["gzip".to_string(), "lz4".to_string(), "none".to_string()], + meta_plugins: vec![ + "file_magic".to_string(), + "file_mime".to_string(), + "line_count".to_string(), + "word_count".to_string(), + "sha256".to_string(), + ], + }; + + let response = ApiResponse { + success: true, + data: Some(status), + error: None, + }; + + Ok(Json(response)) +} + +async fn handle_list_items( + State(state): State, + Query(params): Query, + headers: HeaderMap, +) -> Result>>, StatusCode> { + if !check_auth(&headers, &state.password) { + return Err(StatusCode::UNAUTHORIZED); + } + + let mut conn = state.db.lock().await; + + let tags: Vec = params.tags + .map(|s| s.split(',').map(|t| t.trim().to_string()).collect()) + .unwrap_or_default(); + + let items = if tags.is_empty() { + db::get_items(&mut *conn).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? + } else { + db::get_items_matching(&mut *conn, &tags, &HashMap::new()) + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? + }; + + // Get item IDs for batch queries + let item_ids: Vec = items.iter().filter_map(|item| item.id).collect(); + + // Get tags and metadata for all items + let tags_map = db::get_tags_for_items(&mut *conn, &item_ids) + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + let meta_map = db::get_meta_for_items(&mut *conn, &item_ids) + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + + let item_infos: Vec = items + .into_iter() + .map(|item| { + let item_id = item.id.unwrap_or(0); + let item_tags = tags_map.get(&item_id) + .map(|tags| tags.iter().map(|t| t.name.clone()).collect()) + .unwrap_or_default(); + let item_meta = meta_map.get(&item_id) + .cloned() + .unwrap_or_default(); + + ItemInfo { + id: item_id, + ts: item.ts.to_rfc3339(), + size: item.size, + compression: item.compression, + tags: item_tags, + metadata: item_meta, + } + }) + .collect(); + + let response = ApiResponse { + success: true, + data: Some(item_infos), + error: None, + }; + + Ok(Json(response)) +} + +async fn handle_get_item( + State(state): State, + Path(item_id): Path, + Query(params): Query, + headers: HeaderMap, +) -> Result>, StatusCode> { + if !check_auth(&headers, &state.password) { + return Err(StatusCode::UNAUTHORIZED); + } + + let mut conn = state.db.lock().await; + + let item = if let Ok(id) = item_id.parse::() { + db::get_item(&mut *conn, id).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? + } else { + // Try to find by tags + if let Some(tags_str) = params.tags { + let tags: Vec = tags_str.split(',').map(|t| t.trim().to_string()).collect(); + db::get_item_matching(&mut *conn, &tags, &HashMap::new()) + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? + } else { + return Err(StatusCode::BAD_REQUEST); + } + }; + + if let Some(item) = item { + let item_tags = db::get_item_tags(&mut *conn, &item) + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? + .into_iter() + .map(|t| t.name) + .collect(); + let item_meta = db::get_item_meta(&mut *conn, &item) + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? + .into_iter() + .map(|m| (m.name, m.value)) + .collect(); + + let item_info = ItemInfo { + id: item.id.unwrap_or(0), + ts: item.ts.to_rfc3339(), + size: item.size, + compression: item.compression, + tags: item_tags, + metadata: item_meta, + }; + + let response = ApiResponse { + success: true, + data: Some(item_info), + error: None, + }; + + Ok(Json(response)) + } else { + Err(StatusCode::NOT_FOUND) + } +} + +async fn handle_put_item( + State(state): State, + headers: HeaderMap, +) -> Result>, StatusCode> { + if !check_auth(&headers, &state.password) { + return Err(StatusCode::UNAUTHORIZED); + } + + // This is a simplified implementation + // In a real implementation, you'd need to properly parse multipart/form-data + // or JSON payload with the item data + + let response = ApiResponse:: { + success: false, + data: None, + error: Some("PUT /item/ not yet implemented".to_string()), + }; + + Ok(Json(response)) +} + +async fn handle_delete_item( + State(state): State, + Path(item_id): Path, + headers: HeaderMap, +) -> Result>, StatusCode> { + if !check_auth(&headers, &state.password) { + return Err(StatusCode::UNAUTHORIZED); + } + + if let Ok(id) = item_id.parse::() { + let mut conn = state.db.lock().await; + + if let Some(item) = db::get_item(&mut *conn, id).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? { + db::delete_item(&mut *conn, item).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + + let response = ApiResponse::<()> { + success: true, + data: None, + error: None, + }; + Ok(Json(response)) + } else { + Err(StatusCode::NOT_FOUND) + } + } else { + Err(StatusCode::BAD_REQUEST) + } +} + +async fn handle_get_content_latest( + State(state): State, + Query(params): Query, + headers: HeaderMap, +) -> Result>, StatusCode> { + if !check_auth(&headers, &state.password) { + return Err(StatusCode::UNAUTHORIZED); + } + + let mut conn = state.db.lock().await; + + let item = if let Some(tags_str) = params.tags { + let tags: Vec = tags_str.split(',').map(|t| t.trim().to_string()).collect(); + db::get_item_matching(&mut *conn, &tags, &HashMap::new()) + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? + } else { + db::get_item_last(&mut *conn).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? + }; + + if let Some(_item) = item { + // Get the actual content - this would need to be implemented + // based on how content is stored and retrieved in your system + let response = ApiResponse:: { + success: false, + data: None, + error: Some("Content retrieval not yet implemented".to_string()), + }; + Ok(Json(response)) + } else { + Err(StatusCode::NOT_FOUND) + } +} + +async fn handle_get_content( + State(state): State, + Path(item_id): Path, + headers: HeaderMap, +) -> Result>, StatusCode> { + if !check_auth(&headers, &state.password) { + return Err(StatusCode::UNAUTHORIZED); + } + + if let Ok(id) = item_id.parse::() { + let mut conn = state.db.lock().await; + + if let Some(_item) = db::get_item(&mut *conn, id).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? { + // Get the actual content - this would need to be implemented + // based on how content is stored and retrieved in your system + let response = ApiResponse:: { + success: false, + data: None, + error: Some("Content retrieval not yet implemented".to_string()), + }; + Ok(Json(response)) + } else { + Err(StatusCode::NOT_FOUND) + } + } else { + Err(StatusCode::BAD_REQUEST) + } +} + +async fn handle_openapi() -> Json { + let openapi_spec = json!({ + "openapi": "3.0.0", + "info": { + "title": "Keep API", + "version": "1.0.0", + "description": "REST API for the Keep data storage system" + }, + "servers": [ + { + "url": "/", + "description": "Local server" + } + ], + "components": { + "securitySchemes": { + "bearerAuth": { + "type": "http", + "scheme": "bearer" + } + }, + "schemas": { + "ItemInfo": { + "type": "object", + "properties": { + "id": {"type": "integer"}, + "ts": {"type": "string", "format": "date-time"}, + "size": {"type": "integer", "nullable": true}, + "compression": {"type": "string"}, + "tags": {"type": "array", "items": {"type": "string"}}, + "metadata": {"type": "object"} + } + }, + "StatusInfo": { + "type": "object", + "properties": { + "version": {"type": "string"}, + "database_path": {"type": "string"}, + "data_directory": {"type": "string"}, + "compression_engines": {"type": "array", "items": {"type": "string"}}, + "meta_plugins": {"type": "array", "items": {"type": "string"}} + } + } + } + }, + "security": [{"bearerAuth": []}], + "paths": { + "/status": { + "get": { + "summary": "Get system status", + "responses": { + "200": { + "description": "System status", + "content": { + "application/json": { + "schema": {"$ref": "#/components/schemas/StatusInfo"} + } + } + } + } + } + }, + "/item/": { + "get": { + "summary": "List items", + "parameters": [ + { + "name": "tags", + "in": "query", + "schema": {"type": "string"}, + "description": "Comma-separated list of tags to filter by" + } + ], + "responses": { + "200": { + "description": "List of items", + "content": { + "application/json": { + "schema": { + "type": "array", + "items": {"$ref": "#/components/schemas/ItemInfo"} + } + } + } + } + } + }, + "put": { + "summary": "Add new item", + "responses": { + "201": { + "description": "Item created", + "content": { + "application/json": { + "schema": {"$ref": "#/components/schemas/ItemInfo"} + } + } + } + } + } + }, + "/item/{id}": { + "get": { + "summary": "Get item by ID", + "parameters": [ + { + "name": "id", + "in": "path", + "required": true, + "schema": {"type": "string"}, + "description": "Item ID or use tags query parameter" + }, + { + "name": "tags", + "in": "query", + "schema": {"type": "string"}, + "description": "Comma-separated list of tags (when ID is not numeric)" + } + ], + "responses": { + "200": { + "description": "Item information", + "content": { + "application/json": { + "schema": {"$ref": "#/components/schemas/ItemInfo"} + } + } + }, + "404": {"description": "Item not found"} + } + }, + "delete": { + "summary": "Delete item by ID", + "parameters": [ + { + "name": "id", + "in": "path", + "required": true, + "schema": {"type": "integer"} + } + ], + "responses": { + "200": {"description": "Item deleted"}, + "404": {"description": "Item not found"} + } + } + }, + "/content": { + "get": { + "summary": "Get content of latest item", + "parameters": [ + { + "name": "tags", + "in": "query", + "schema": {"type": "string"}, + "description": "Comma-separated list of tags to filter by" + } + ], + "responses": { + "200": {"description": "Item content"}, + "404": {"description": "No items found"} + } + } + }, + "/content/{id}": { + "get": { + "summary": "Get content by item ID", + "parameters": [ + { + "name": "id", + "in": "path", + "required": true, + "schema": {"type": "integer"} + } + ], + "responses": { + "200": {"description": "Item content"}, + "404": {"description": "Item not found"} + } + } + }, + "/openapi.json": { + "get": { + "summary": "Get OpenAPI specification", + "responses": { + "200": { + "description": "OpenAPI specification", + "content": { + "application/json": {"schema": {"type": "object"}} + } + } + } + } + }, + "/swagger/": { + "get": { + "summary": "Swagger UI", + "responses": { + "200": { + "description": "Swagger UI HTML page", + "content": { + "text/html": {"schema": {"type": "string"}} + } + } + } + } + } + } + }); + + Json(openapi_spec) +} + +async fn handle_swagger_ui() -> Html<&'static str> { + let html = r#" + + + Keep API Documentation + + + +
+ + + +"#; + + Html(html) +}