Files
keep/src/modes/server.rs

783 lines
26 KiB
Rust

use anyhow::{Result, anyhow};
use axum::{
extract::{ConnectInfo, Path, Query, State},
http::{HeaderMap, StatusCode},
response::{Html, Json},
routing::get,
Router,
};
use clap::Command;
use log::{debug, info, warn};
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::collections::HashMap;
use std::io::Read;
use std::net::SocketAddr;
use std::path::PathBuf;
use std::str::FromStr;
use std::sync::Arc;
use tokio::sync::Mutex;
use tower_http::cors::CorsLayer;
use tower::ServiceBuilder;
use tower_http::trace::TraceLayer;
use crate::compression_engine::{CompressionType, get_compression_engine};
use crate::db;
use crate::Args;
/// Settings for the HTTP server started by `mode_server`.
#[derive(Debug, Clone)]
pub struct ServerConfig {
    /// Listen address as given on the command line (parsed later as a socket address).
    pub address: String,
    /// Bearer token clients must present; `None` disables authentication.
    pub password: Option<String>,
}

/// Parses a bare address string into a config with no password. Parsing never fails.
impl FromStr for ServerConfig {
    type Err = anyhow::Error;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let address = s.to_owned();
        Ok(Self {
            address,
            password: None,
        })
    }
}
/// Shared state handed to every request handler through axum's `State` extractor.
#[derive(Clone)]
struct AppState {
    /// Single SQLite connection shared by all handlers; the async mutex serializes access to it.
    db: Arc<Mutex<rusqlite::Connection>>,
    /// Directory holding the database file (`keep-1.db`) and the per-item content files.
    data_dir: PathBuf,
    /// Expected bearer token, or `None` when authentication is disabled.
    password: Option<String>,
}
/// JSON envelope returned by the status, item, and content handlers.
#[derive(Serialize, Deserialize)]
struct ApiResponse<T> {
    /// Whether the request succeeded.
    success: bool,
    /// Payload on success; `None` on failure or for endpoints without a payload (e.g. DELETE).
    data: Option<T>,
    /// Human-readable failure reason when `success` is false.
    error: Option<String>,
}
/// Public view of a stored item, as returned by the `/item` endpoints.
#[derive(Serialize, Deserialize)]
struct ItemInfo {
    /// Item ID (handlers report 0 when the database row carries no ID).
    id: i64,
    /// The item's `ts` timestamp rendered as RFC 3339.
    ts: String,
    /// Size recorded in the database, if any (units not visible here — presumably bytes).
    size: Option<i64>,
    /// Compression engine name for the content file (parsed via `CompressionType::from_str`).
    compression: String,
    /// Names of the tags attached to the item.
    tags: Vec<String>,
    /// Metadata entries as name → value.
    metadata: HashMap<String, String>,
}
/// Payload of `GET /status`.
#[derive(Serialize, Deserialize)]
struct StatusInfo {
    /// Crate version (`CARGO_PKG_VERSION`).
    version: String,
    /// Path of the SQLite database file.
    database_path: String,
    /// Directory holding the database and item content files.
    data_directory: String,
    /// Compression engine names advertised by the server.
    compression_engines: Vec<String>,
    /// Metadata plugin names advertised by the server.
    meta_plugins: Vec<String>,
}
/// Query string accepted by the list, get, and content endpoints.
#[derive(Debug, Deserialize)]
struct TagsQuery {
    /// Optional comma-separated list of tag names (e.g. `?tags=a,b`).
    tags: Option<String>,
}
pub fn mode_server(
_cmd: &mut Command,
args: &Args,
conn: &mut rusqlite::Connection,
data_path: PathBuf,
) -> Result<()> {
let server_address = args.mode.server.as_ref().unwrap();
let config = ServerConfig {
address: server_address.clone(),
password: args.options.server_password.clone(),
};
// We need to move the connection into the async runtime
let rt = tokio::runtime::Runtime::new()?;
rt.block_on(run_server(config, conn, data_path))
}
async fn run_server(
config: ServerConfig,
_conn: &mut rusqlite::Connection,
data_dir: PathBuf,
) -> Result<()> {
debug!("Starting REST HTTP server on {}", config.address);
// Create a new database connection for the server
// Note: This is a simplified approach. In production, you'd want a connection pool
let mut db_path = data_dir.clone();
db_path.push("keep-1.db");
let new_conn = crate::db::open(db_path)?;
let db_conn = Arc::new(Mutex::new(new_conn));
let state = AppState {
db: db_conn,
data_dir: data_dir.clone(),
password: config.password.clone(),
};
let app = Router::new()
.route("/status", get(handle_status))
.route("/item/", get(handle_list_items).put(handle_put_item))
.route("/item/:id", get(handle_get_item).delete(handle_delete_item))
.route("/content", get(handle_get_content_latest))
.route("/content/:id", get(handle_get_content))
.route("/openapi.json", get(handle_openapi))
.route("/swagger/", get(handle_swagger_ui))
.layer(
ServiceBuilder::new()
.layer(TraceLayer::new_for_http())
.layer(CorsLayer::permissive())
)
.with_state(state);
let addr: SocketAddr = if config.address.starts_with('/') || config.address.starts_with("./") {
// Unix socket - not supported by axum directly, fall back to TCP
warn!("Unix sockets not yet implemented, falling back to TCP on 127.0.0.1:8080");
"127.0.0.1:8080".parse()?
} else {
config.address.parse()?
};
info!("SERVER: HTTP server listening on {}", addr);
let listener = tokio::net::TcpListener::bind(addr).await?;
axum::serve(
listener,
app.into_make_service_with_connect_info::<SocketAddr>()
).await?;
Ok(())
}
/// Returns whether the request is authorized.
///
/// When `password` is `None`, every request is allowed. Otherwise the request must carry an
/// `Authorization: Bearer <token>` header whose token equals `password`. The scheme name is
/// matched case-insensitively, as RFC 7235 requires. The token is compared without early exit,
/// so response timing does not reveal how many leading characters were correct.
fn check_auth(headers: &HeaderMap, password: &Option<String>) -> bool {
    /// Byte-wise equality that does not stop at the first mismatch (length is still observable).
    fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
        a.len() == b.len() && a.iter().zip(b).fold(0u8, |diff, (x, y)| diff | (x ^ y)) == 0
    }

    let expected = match password {
        Some(expected) => expected,
        None => return true, // No password required
    };

    headers
        .get("authorization")
        // Non-visible-ASCII header values cannot match and are rejected.
        .and_then(|value| value.to_str().ok())
        .and_then(|value| value.split_once(' '))
        .filter(|(scheme, _)| scheme.eq_ignore_ascii_case("Bearer"))
        .map_or(false, |(_, token)| {
            constant_time_eq(token.as_bytes(), expected.as_bytes())
        })
}
/// `GET /status` — reports the server version, storage locations, and the compression engines
/// and metadata plugins this endpoint advertises. Responds with 401 if authorization fails.
///
/// NOTE(review): the engine and plugin lists are hard-coded here rather than queried from the
/// compression/plugin registries, so they can drift from what is actually available.
async fn handle_status(
    State(state): State<AppState>,
    headers: HeaderMap,
    ConnectInfo(addr): ConnectInfo<SocketAddr>,
) -> Result<Json<ApiResponse<StatusInfo>>, StatusCode> {
    /// Turns a list of static names into owned strings.
    fn owned(names: &[&str]) -> Vec<String> {
        names.iter().map(|name| name.to_string()).collect()
    }

    info!("SERVER: GET /status from {}", addr);
    if !check_auth(&headers, &state.password) {
        warn!("Unauthorized request from {}", addr);
        return Err(StatusCode::UNAUTHORIZED);
    }

    let database_file = state.data_dir.join("keep-1.db");
    let info = StatusInfo {
        version: env!("CARGO_PKG_VERSION").to_string(),
        database_path: database_file.to_string_lossy().into_owned(),
        data_directory: state.data_dir.to_string_lossy().into_owned(),
        compression_engines: owned(&["gzip", "lz4", "none"]),
        meta_plugins: owned(&[
            "file_magic",
            "file_mime",
            "line_count",
            "word_count",
            "sha256",
        ]),
    };

    Ok(Json(ApiResponse {
        success: true,
        data: Some(info),
        error: None,
    }))
}
/// `GET /item/` — lists items, optionally filtered by the comma-separated `tags` query parameter.
///
/// Tags are trimmed, and empty entries (from `?tags=`, `a,,b`, or a trailing comma) are ignored.
/// If no non-empty tag remains, every item is listed. Tags and metadata for all returned items
/// are fetched with one batch query each.
///
/// Responds with 401 if authorization fails and 500 on any database error.
async fn handle_list_items(
    State(state): State<AppState>,
    Query(params): Query<TagsQuery>,
    headers: HeaderMap,
    ConnectInfo(addr): ConnectInfo<SocketAddr>,
) -> Result<Json<ApiResponse<Vec<ItemInfo>>>, StatusCode> {
    info!("SERVER: GET /item/ from {} with params: {:?}", addr, params);
    if !check_auth(&headers, &state.password) {
        warn!("Unauthorized request to /item/ from {}", addr);
        return Err(StatusCode::UNAUTHORIZED);
    }

    // Drop empty entries so `?tags=` means "no filter" rather than "match the empty tag".
    let tags: Vec<String> = params
        .tags
        .as_deref()
        .map(|s| {
            s.split(',')
                .map(str::trim)
                .filter(|t| !t.is_empty())
                .map(str::to_string)
                .collect()
        })
        .unwrap_or_default();

    let mut conn = state.db.lock().await;
    let items = if tags.is_empty() {
        db::get_items(&mut *conn).map_err(|e| {
            warn!("Failed to get items: {}", e);
            StatusCode::INTERNAL_SERVER_ERROR
        })?
    } else {
        db::get_items_matching(&mut *conn, &tags, &HashMap::new()).map_err(|e| {
            warn!("Failed to get items matching tags {:?}: {}", tags, e);
            StatusCode::INTERNAL_SERVER_ERROR
        })?
    };

    // Get item IDs for batch queries
    let item_ids: Vec<i64> = items.iter().filter_map(|item| item.id).collect();
    let tags_map = db::get_tags_for_items(&mut *conn, &item_ids).map_err(|e| {
        warn!("Failed to get tags for items: {}", e);
        StatusCode::INTERNAL_SERVER_ERROR
    })?;
    let meta_map = db::get_meta_for_items(&mut *conn, &item_ids).map_err(|e| {
        warn!("Failed to get metadata for items: {}", e);
        StatusCode::INTERNAL_SERVER_ERROR
    })?;

    let item_infos: Vec<ItemInfo> = items
        .into_iter()
        .map(|item| {
            // Items without an ID are reported as id 0 and get no tags/metadata.
            let item_id = item.id.unwrap_or(0);
            let item_tags = tags_map
                .get(&item_id)
                .map(|tags| tags.iter().map(|t| t.name.clone()).collect())
                .unwrap_or_default();
            let item_meta = meta_map.get(&item_id).cloned().unwrap_or_default();
            ItemInfo {
                id: item_id,
                ts: item.ts.to_rfc3339(),
                size: item.size,
                compression: item.compression,
                tags: item_tags,
                metadata: item_meta,
            }
        })
        .collect();

    Ok(Json(ApiResponse {
        success: true,
        data: Some(item_infos),
        error: None,
    }))
}
async fn handle_get_item(
State(state): State<AppState>,
Path(item_id): Path<String>,
Query(params): Query<TagsQuery>,
headers: HeaderMap,
ConnectInfo(addr): ConnectInfo<SocketAddr>,
) -> Result<Json<ApiResponse<ItemInfo>>, StatusCode> {
info!("SERVER: GET /item/{} from {} with params: {:?}", item_id, addr, params);
if !check_auth(&headers, &state.password) {
warn!("Unauthorized request to /item/{} from {}", item_id, addr);
return Err(StatusCode::UNAUTHORIZED);
}
let mut conn = state.db.lock().await;
let item = if let Ok(id) = item_id.parse::<i64>() {
db::get_item(&mut *conn, id).map_err(|e| {
warn!("Failed to get item {}: {}", id, e);
StatusCode::INTERNAL_SERVER_ERROR
})?
} else {
// Try to find by tags
if let Some(tags_str) = params.tags {
let tags: Vec<String> = tags_str.split(',').map(|t| t.trim().to_string()).collect();
db::get_item_matching(&mut *conn, &tags, &HashMap::new())
.map_err(|e| {
warn!("Failed to get item matching tags {:?}: {}", tags, e);
StatusCode::INTERNAL_SERVER_ERROR
})?
} else {
warn!("Invalid item ID '{}' and no tags provided", item_id);
return Err(StatusCode::BAD_REQUEST);
}
};
if let Some(item) = item {
let item_tags = db::get_item_tags(&mut *conn, &item)
.map_err(|e| {
warn!("Failed to get tags for item {}: {}", item.id.unwrap_or(0), e);
StatusCode::INTERNAL_SERVER_ERROR
})?
.into_iter()
.map(|t| t.name)
.collect();
let item_meta = db::get_item_meta(&mut *conn, &item)
.map_err(|e| {
warn!("Failed to get metadata for item {}: {}", item.id.unwrap_or(0), e);
StatusCode::INTERNAL_SERVER_ERROR
})?
.into_iter()
.map(|m| (m.name, m.value))
.collect();
let item_info = ItemInfo {
id: item.id.unwrap_or(0),
ts: item.ts.to_rfc3339(),
size: item.size,
compression: item.compression,
tags: item_tags,
metadata: item_meta,
};
let response = ApiResponse {
success: true,
data: Some(item_info),
error: None,
};
Ok(Json(response))
} else {
Err(StatusCode::NOT_FOUND)
}
}
/// `PUT /item/` — placeholder for uploading a new item.
///
/// After the authorization check, this always answers HTTP 200 with `success: false` and an
/// explanatory error. No request body is read and nothing is stored. Responds with 401 if
/// authorization fails.
///
/// NOTE(review): a real implementation must parse a multipart/form-data or JSON payload.
/// Until then, a `501 Not Implemented` status would describe the behavior more accurately
/// than 200.
async fn handle_put_item(
    State(state): State<AppState>,
    headers: HeaderMap,
    ConnectInfo(addr): ConnectInfo<SocketAddr>,
) -> Result<Json<ApiResponse<ItemInfo>>, StatusCode> {
    info!("SERVER: PUT /item/ from {}", addr);
    let authorized = check_auth(&headers, &state.password);
    if authorized {
        Ok(Json(ApiResponse::<ItemInfo> {
            success: false,
            data: None,
            error: Some(String::from("PUT /item/ not yet implemented")),
        }))
    } else {
        warn!("Unauthorized request to PUT /item/ from {}", addr);
        Err(StatusCode::UNAUTHORIZED)
    }
}
/// `DELETE /item/:id` — deletes the item with the given numeric ID via `db::delete_item`.
///
/// Responses:
/// - 400: the ID is not an integer;
/// - 401: authorization failed;
/// - 404: no such item exists;
/// - 500: a database error occurred.
///
/// On success the envelope carries no data.
///
/// NOTE(review): only `db::delete_item` is called here. Whether the item's content file under
/// `data_dir` is also removed depends on that function — confirm.
async fn handle_delete_item(
    State(state): State<AppState>,
    Path(item_id): Path<String>,
    headers: HeaderMap,
    ConnectInfo(addr): ConnectInfo<SocketAddr>,
) -> Result<Json<ApiResponse<()>>, StatusCode> {
    info!("SERVER: DELETE /item/{} from {}", item_id, addr);
    if !check_auth(&headers, &state.password) {
        warn!("Unauthorized request to DELETE /item/{} from {}", item_id, addr);
        return Err(StatusCode::UNAUTHORIZED);
    }

    let id = item_id
        .parse::<i64>()
        .map_err(|_| StatusCode::BAD_REQUEST)?;

    let mut conn = state.db.lock().await;
    let target = db::get_item(&mut *conn, id)
        .map_err(|e| {
            warn!("Failed to get item {} for deletion: {}", id, e);
            StatusCode::INTERNAL_SERVER_ERROR
        })?
        .ok_or(StatusCode::NOT_FOUND)?;

    db::delete_item(&mut *conn, target).map_err(|e| {
        warn!("Failed to delete item {}: {}", id, e);
        StatusCode::INTERNAL_SERVER_ERROR
    })?;

    Ok(Json(ApiResponse::<()> {
        success: true,
        data: None,
        error: None,
    }))
}
/// `GET /content` — returns the decompressed content of one item as a string.
///
/// When the comma-separated `tags` query parameter contains at least one non-empty tag, the item
/// is chosen by `db::get_item_matching`. Otherwise `db::get_item_last` is used; empty entries such
/// as `?tags=` are ignored.
///
/// Responses:
/// - 401: authorization failed;
/// - 404: no item was found;
/// - 500: a database error occurred.
///
/// A failed content read is reported as HTTP 200 with `success: false` and the error message.
async fn handle_get_content_latest(
    State(state): State<AppState>,
    Query(params): Query<TagsQuery>,
    headers: HeaderMap,
    ConnectInfo(addr): ConnectInfo<SocketAddr>,
) -> Result<Json<ApiResponse<String>>, StatusCode> {
    info!("SERVER: GET /content from {} with params: {:?}", addr, params);
    if !check_auth(&headers, &state.password) {
        warn!("Unauthorized request to /content from {}", addr);
        return Err(StatusCode::UNAUTHORIZED);
    }

    let tags: Vec<String> = params
        .tags
        .as_deref()
        .map(|s| {
            s.split(',')
                .map(str::trim)
                .filter(|t| !t.is_empty())
                .map(str::to_string)
                .collect()
        })
        .unwrap_or_default();

    let mut conn = state.db.lock().await;
    let item = if tags.is_empty() {
        db::get_item_last(&mut *conn).map_err(|e| {
            warn!("Failed to get last item for content: {}", e);
            StatusCode::INTERNAL_SERVER_ERROR
        })?
    } else {
        db::get_item_matching(&mut *conn, &tags, &HashMap::new()).map_err(|e| {
            warn!("Failed to get item matching tags {:?} for content: {}", tags, e);
            StatusCode::INTERNAL_SERVER_ERROR
        })?
    };
    // Release the shared database lock before the file read and decompression below, so
    // other requests are not serialized behind it.
    drop(conn);
    let item = item.ok_or(StatusCode::NOT_FOUND)?;

    let response = match get_item_content(&item, &state.data_dir).await {
        Ok(content) => ApiResponse {
            success: true,
            data: Some(content),
            error: None,
        },
        Err(e) => {
            warn!("Failed to get content for item {}: {}", item.id.unwrap_or(0), e);
            ApiResponse::<String> {
                success: false,
                data: None,
                error: Some(format!("Failed to retrieve content: {}", e)),
            }
        }
    };
    Ok(Json(response))
}
/// `GET /content/:id` — returns the decompressed content of the item with the given ID.
///
/// Responses:
/// - 400: the ID is not an integer or is not positive;
/// - 401: authorization failed;
/// - 404: no such item exists;
/// - 500: a database error occurred.
///
/// A failed content read is reported as HTTP 200 with `success: false` and the error message.
async fn handle_get_content(
    State(state): State<AppState>,
    Path(item_id): Path<String>,
    headers: HeaderMap,
    ConnectInfo(addr): ConnectInfo<SocketAddr>,
) -> Result<Json<ApiResponse<String>>, StatusCode> {
    info!("SERVER: GET /content/{} from {}", item_id, addr);
    if !check_auth(&headers, &state.password) {
        warn!("Unauthorized request to /content/{} from {}", item_id, addr);
        return Err(StatusCode::UNAUTHORIZED);
    }

    let id = item_id
        .parse::<i64>()
        .map_err(|_| StatusCode::BAD_REQUEST)?;
    // Validate that item ID is positive to prevent path traversal issues
    if id <= 0 {
        warn!("Invalid item ID {} from {}", id, addr);
        return Err(StatusCode::BAD_REQUEST);
    }

    let mut conn = state.db.lock().await;
    let item = db::get_item(&mut *conn, id).map_err(|e| {
        warn!("Failed to get item {} for content: {}", id, e);
        StatusCode::INTERNAL_SERVER_ERROR
    })?;
    // Release the shared database lock before the file read and decompression below, so
    // other requests are not serialized behind it.
    drop(conn);
    let item = item.ok_or(StatusCode::NOT_FOUND)?;

    let response = match get_item_content(&item, &state.data_dir).await {
        Ok(content) => ApiResponse {
            success: true,
            data: Some(content),
            error: None,
        },
        Err(e) => {
            warn!("Failed to get content for item {}: {}", id, e);
            ApiResponse::<String> {
                success: false,
                data: None,
                error: Some(format!("Failed to retrieve content: {}", e)),
            }
        }
    };
    Ok(Json(response))
}
/// `GET /openapi.json` — serves a hand-maintained OpenAPI 3.0 description of the routes
/// registered in `run_server`.
///
/// Every JSON handler wraps its payload in the `ApiResponse` envelope
/// (`{"success", "data", "error"}`). Response schemas are therefore written as that envelope,
/// with `data` narrowed to the concrete payload type. `/openapi.json` and `/swagger/` do not
/// call `check_auth`, so they override the global bearer-auth requirement with an empty list.
/// Keep this in sync with the handlers when routes or status codes change.
async fn handle_openapi() -> Json<serde_json::Value> {
    // `ApiResponse` envelope whose `data` field is narrowed to `data_schema`.
    let envelope = |data_schema: serde_json::Value| {
        json!({
            "allOf": [
                {"$ref": "#/components/schemas/ApiResponse"},
                {"type": "object", "properties": {"data": data_schema}}
            ]
        })
    };
    let item_schema = json!({"$ref": "#/components/schemas/ItemInfo"});
    let item_response = envelope(item_schema.clone());
    let item_list_response = envelope(json!({"type": "array", "items": item_schema}));
    let status_response = envelope(json!({"$ref": "#/components/schemas/StatusInfo"}));
    let content_response = envelope(json!({"type": "string"}));
    let empty_response = json!({"$ref": "#/components/schemas/ApiResponse"});
    let bad_request = json!({"$ref": "#/components/responses/BadRequest"});
    let unauthorized = json!({"$ref": "#/components/responses/Unauthorized"});
    let not_found = json!({"$ref": "#/components/responses/NotFound"});
    let server_error = json!({"$ref": "#/components/responses/ServerError"});
    let tags_param = json!({
        "name": "tags",
        "in": "query",
        "schema": {"type": "string"},
        "description": "Comma-separated list of tags to filter by; empty entries are ignored"
    });
    let integer_id_param = json!({
        "name": "id",
        "in": "path",
        "required": true,
        "schema": {"type": "integer"}
    });

    let openapi_spec = json!({
        "openapi": "3.0.0",
        "info": {
            "title": "Keep API",
            "version": "1.0.0",
            "description": "REST API for the Keep data storage system"
        },
        "servers": [
            {"url": "/", "description": "Local server"}
        ],
        "components": {
            "securitySchemes": {
                "bearerAuth": {"type": "http", "scheme": "bearer"}
            },
            "responses": {
                "BadRequest": {"description": "Malformed item ID, or non-numeric ID without tags"},
                "Unauthorized": {"description": "Missing or invalid bearer token"},
                "NotFound": {"description": "Item not found"},
                "ServerError": {"description": "Database error"}
            },
            "schemas": {
                "ApiResponse": {
                    "type": "object",
                    "required": ["success", "data", "error"],
                    "properties": {
                        "success": {"type": "boolean"},
                        "data": {"nullable": true},
                        "error": {"type": "string", "nullable": true}
                    }
                },
                "ItemInfo": {
                    "type": "object",
                    "required": ["id", "ts", "size", "compression", "tags", "metadata"],
                    "properties": {
                        "id": {"type": "integer"},
                        "ts": {"type": "string", "format": "date-time"},
                        "size": {"type": "integer", "nullable": true},
                        "compression": {"type": "string"},
                        "tags": {"type": "array", "items": {"type": "string"}},
                        "metadata": {
                            "type": "object",
                            "additionalProperties": {"type": "string"}
                        }
                    }
                },
                "StatusInfo": {
                    "type": "object",
                    "properties": {
                        "version": {"type": "string"},
                        "database_path": {"type": "string"},
                        "data_directory": {"type": "string"},
                        "compression_engines": {"type": "array", "items": {"type": "string"}},
                        "meta_plugins": {"type": "array", "items": {"type": "string"}}
                    }
                }
            }
        },
        "security": [{"bearerAuth": []}],
        "paths": {
            "/status": {
                "get": {
                    "summary": "Get system status",
                    "responses": {
                        "200": {
                            "description": "System status",
                            "content": {"application/json": {"schema": status_response}}
                        },
                        "401": unauthorized
                    }
                }
            },
            "/item/": {
                "get": {
                    "summary": "List items",
                    "parameters": [tags_param],
                    "responses": {
                        "200": {
                            "description": "List of items",
                            "content": {"application/json": {"schema": item_list_response}}
                        },
                        "401": unauthorized,
                        "500": server_error
                    }
                },
                "put": {
                    "summary": "Add new item (not yet implemented)",
                    "description": "Currently always returns 200 with `success: false`.",
                    "responses": {
                        "200": {
                            "description": "Not implemented; `error` explains why",
                            "content": {"application/json": {"schema": item_response}}
                        },
                        "401": unauthorized
                    }
                }
            },
            "/item/{id}": {
                "get": {
                    "summary": "Get item by ID",
                    "parameters": [
                        {
                            "name": "id",
                            "in": "path",
                            "required": true,
                            "schema": {"type": "string"},
                            "description": "Item ID, or any non-numeric value to select by the tags query parameter"
                        },
                        {
                            "name": "tags",
                            "in": "query",
                            "schema": {"type": "string"},
                            "description": "Comma-separated list of tags (used when ID is not numeric)"
                        }
                    ],
                    "responses": {
                        "200": {
                            "description": "Item information",
                            "content": {"application/json": {"schema": item_response}}
                        },
                        "400": bad_request,
                        "401": unauthorized,
                        "404": not_found,
                        "500": server_error
                    }
                },
                "delete": {
                    "summary": "Delete item by ID",
                    "parameters": [integer_id_param],
                    "responses": {
                        "200": {
                            "description": "Item deleted",
                            "content": {"application/json": {"schema": empty_response}}
                        },
                        "400": bad_request,
                        "401": unauthorized,
                        "404": not_found,
                        "500": server_error
                    }
                }
            },
            "/content": {
                "get": {
                    "summary": "Get content of latest item",
                    "parameters": [tags_param],
                    "responses": {
                        "200": {
                            "description": "Item content as text; `success` is false and `error` is set if the content could not be read",
                            "content": {"application/json": {"schema": content_response}}
                        },
                        "401": unauthorized,
                        "404": {"description": "No items found"},
                        "500": server_error
                    }
                }
            },
            "/content/{id}": {
                "get": {
                    "summary": "Get content by item ID",
                    "parameters": [integer_id_param],
                    "responses": {
                        "200": {
                            "description": "Item content as text; `success` is false and `error` is set if the content could not be read",
                            "content": {"application/json": {"schema": content_response}}
                        },
                        "400": bad_request,
                        "401": unauthorized,
                        "404": not_found,
                        "500": server_error
                    }
                }
            },
            "/openapi.json": {
                "get": {
                    "summary": "Get OpenAPI specification",
                    "security": [],
                    "responses": {
                        "200": {
                            "description": "OpenAPI specification",
                            "content": {"application/json": {"schema": {"type": "object"}}}
                        }
                    }
                }
            },
            "/swagger/": {
                "get": {
                    "summary": "Swagger UI",
                    "security": [],
                    "responses": {
                        "200": {
                            "description": "Swagger UI HTML page",
                            "content": {"text/html": {"schema": {"type": "string"}}}
                        }
                    }
                }
            }
        }
    });
    Json(openapi_spec)
}
/// Reads and decompresses the stored content of `item`, returning it as a string.
///
/// The content file is `data_dir/<item id>`. It is decoded with the compression engine named by
/// `item.compression`.
///
/// # Errors
/// Returns an error in any of these cases:
/// - the item has no ID, or a non-positive ID;
/// - its compression type is not recognized;
/// - the file cannot be opened or read;
/// - the decompressed bytes are not valid UTF-8.
///
/// NOTE(review): this is `async` but does blocking file I/O and decompression on the runtime
/// thread. Large items will stall other requests; consider `tokio::task::spawn_blocking`.
async fn get_item_content(item: &db::Item, data_dir: &std::path::Path) -> Result<String> {
    let item_id = item.id.ok_or_else(|| anyhow!("Item missing ID"))?;
    // Validate that item ID is positive to prevent path traversal issues
    if item_id <= 0 {
        return Err(anyhow!("Invalid item ID: {}", item_id));
    }
    let item_path = data_dir.join(item_id.to_string());
    let compression_type = CompressionType::from_str(&item.compression)?;
    let compression_engine = get_compression_engine(compression_type)?;
    // Read the content using the compression engine
    let mut reader = compression_engine.open(item_path)?;
    let mut content = String::new();
    reader.read_to_string(&mut content)?;
    Ok(content)
}
/// `GET /swagger/` — serves a static Swagger UI page that renders the spec from `/openapi.json`.
///
/// The page loads the `swagger-ui-dist` 3.52.5 assets from the unpkg CDN, so the viewing
/// browser needs network access.
///
/// NOTE(review): the CDN assets carry no Subresource Integrity hashes. Consider adding them or
/// vendoring the assets.
async fn handle_swagger_ui() -> Html<&'static str> {
    const PAGE: &str = r#"<!DOCTYPE html>
<html>
<head>
<title>Keep API Documentation</title>
<link rel="stylesheet" type="text/css" href="https://unpkg.com/swagger-ui-dist@3.52.5/swagger-ui.css" />
</head>
<body>
<div id="swagger-ui"></div>
<script src="https://unpkg.com/swagger-ui-dist@3.52.5/swagger-ui-bundle.js"></script>
<script>
SwaggerUIBundle({
url: '/openapi.json',
dom_id: '#swagger-ui',
presets: [
SwaggerUIBundle.presets.apis,
SwaggerUIBundle.presets.standalone
]
});
</script>
</body>
</html>"#;
    Html(PAGE)
}