diff --git a/Cargo.toml b/Cargo.toml index 66477c4..bc26852 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -42,7 +42,7 @@ comfy-table = "7.2.0" pwhash = "1.0.0" regex = "1.9.5" ringbuf = "0.3" -rmcp = { version = "0.2.0", features = ["server"] } +rmcp = { version = "0.2.0", features = ["server"], optional = true } rusqlite = { version = "0.37.0", features = ["bundled", "array", "chrono"] } rusqlite_migration = "2.3.0" serde = { version = "1.0.219", features = ["derive"] } @@ -85,6 +85,9 @@ all-filter-plugins = [] # Individual plugin features magic = ["magic"] +# MCP feature (Model Context Protocol support) +mcp = ["rmcp"] + [dev-dependencies] tempfile = "3.3.0" rand = "0.8.5" diff --git a/src/modes/server/api/mod.rs b/src/modes/server/api/mod.rs index 72e5d52..27d00ea 100644 --- a/src/modes/server/api/mod.rs +++ b/src/modes/server/api/mod.rs @@ -1,5 +1,6 @@ pub mod item; pub mod status; +#[cfg(feature = "mcp")] pub mod mcp; use axum::{ @@ -67,6 +68,7 @@ pub fn add_routes(router: Router) -> Router { .route("/api/item/{item_id}/content", get(item::handle_get_item_content)) // MCP endpoints + #[cfg(feature = "mcp")] .route("/mcp/sse", get(mcp::handle_mcp_sse)) } diff --git a/src/modes/server/mod.rs b/src/modes/server/mod.rs new file mode 100644 index 0000000..b1ced07 --- /dev/null +++ b/src/modes/server/mod.rs @@ -0,0 +1,136 @@ +use anyhow::Result; +use axum::{ + Router, + routing::post, +}; +use clap::Command; +use log::{debug, info}; +use std::net::SocketAddr; +use std::path::PathBuf; +use std::sync::Arc; +use tokio::sync::Mutex; +use tower_http::cors::CorsLayer; +use tower::ServiceBuilder; +use tower_http::trace::TraceLayer; +use crate::config; +use crate::services::item_service::ItemService; + +pub mod common; +mod api; +mod pages; +#[cfg(feature = "mcp")] +mod mcp; + +pub use common::{AppState, logging_middleware, create_auth_middleware}; + +pub fn mode_server( + cmd: &mut Command, + settings: &config::Settings, + conn: &mut rusqlite::Connection, + data_path: PathBuf, +) -> Result<()> { + // Get server address from args or config with default + let server_address = if let Some(addr) = &settings.server_address() { + addr.clone() + } else if let Some(server_config) = &settings.server { + server_config.address.clone().unwrap_or_else(|| "127.0.0.1".to_string()) + } else { + "127.0.0.1".to_string() + }; + + // Get server port from args or config with default + let server_port = if let Some(port) = settings.server_port() { + port + } else if let Some(server_config) = &settings.server { + server_config.port.unwrap_or(21080) + } else { + 21080 + }; + + let server_config = common::ServerConfig { + address: server_address, + port: Some(server_port), + password: settings.server_password(), + password_hash: settings.server_password_hash(), + }; + + // Create ItemService once + let item_service = ItemService::new(data_path.clone()); + + // We need to move the connection into the async runtime + let rt = tokio::runtime::Runtime::new()?; + // Take ownership of the connection and move it into the async runtime + let owned_conn = std::mem::replace(conn, rusqlite::Connection::open_in_memory()?); + let cmd = cmd.clone(); + let settings = settings.clone(); + rt.block_on(run_server(server_config, owned_conn, data_path, item_service, cmd, settings)) +} + +async fn run_server( + config: common::ServerConfig, + conn: rusqlite::Connection, + data_dir: PathBuf, + item_service: ItemService, + _cmd: Command, + settings: config::Settings, +) -> Result<()> { + // Construct address with port + let bind_address = if let Some(port) = config.port { + format!("{}:{}", config.address, port) + } else { + format!("{}:21080", config.address) + }; + + debug!("SERVER: Starting REST HTTP server on {}", bind_address); + + // Use the existing database connection + let db_conn = Arc::new(Mutex::new(conn)); + + let state = AppState { + db: db_conn, + data_dir: data_dir.clone(), + item_service: Arc::new(item_service), + cmd: Arc::new(Mutex::new(Command::new("keep"))), + settings: Arc::new(settings.clone()), + }; + + #[cfg(feature = "mcp")] + let mcp_router = Router::new() + .route("/mcp", post(mcp::handle_mcp_request)) + .with_state(state.clone()); + + // Create the app with documentation routes open and others protected + let app = Router::new() + // Add documentation routes without authentication + .merge(api::add_docs_routes(Router::new())) + // Add API, pages, and MCP routes with authentication + .merge( + Router::new() + .merge(api::add_routes(Router::new())) + .merge(pages::add_routes(Router::new())) + #[cfg(feature = "mcp")] + .merge(mcp_router) + .layer(axum::middleware::from_fn(create_auth_middleware(config.password.clone(), config.password_hash.clone()))) + ) + // Apply state to all routes + .with_state(state) + // Add other middleware layers to all routes + .layer(axum::middleware::from_fn(logging_middleware)) + .layer( + ServiceBuilder::new() + .layer(TraceLayer::new_for_http()) + .layer(CorsLayer::permissive()) + ); + + let addr: SocketAddr = bind_address.parse()?; + + info!("SERVER: HTTP server listening on {}", addr); + + let listener = tokio::net::TcpListener::bind(addr).await?; + axum::serve( + listener, + app.into_make_service_with_connect_info::() + ).await?; + + Ok(()) +}