From 17be6abaab0e4e7006fcdffe12b5e23ae6992252 Mon Sep 17 00:00:00 2001 From: Andrew Phillips Date: Sat, 14 Mar 2026 00:03:42 -0300 Subject: [PATCH] refactor: streaming, security hardening, and MCP removal MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Major overhaul of server architecture and security posture: - Streaming: Unified all I/O through PIPESIZE (8192-byte) buffers. POST bodies stream via MpscReader through the save pipeline. GET content streams from disk via decompression to client. Removed save_item_with_reader, get_item_content_info, ChannelReader. 413 responses keep partial items (nonfatal by design). - Security: XSS protection in all HTML pages via html_escape crate. Security headers middleware (nosniff, frame deny, referrer policy). CORS tightened to explicit headers. Input validation for tags (256 chars), metadata (128/4096), pagination (10k cap). Config file reads use from_utf8_lossy. Generic error messages in HTML. Diff endpoint has 10 MB per-item cap. max_body_size config option. - Panics eliminated: Path unwraps → proper error propagation. Mutex unwraps → map_err (registries) / expect with message (local). - MCP removed: Deleted all MCP code, rmcp dependency, mcp feature. - Docs: Updated README, DESIGN, AGENTS to reflect all changes. --- AGENTS.md | 20 +- Cargo.lock | 111 +----- Cargo.toml | 5 +- DESIGN.md | 22 +- Dockerfile | 2 +- README.md | 64 ++-- src/args.rs | 4 + src/common/schema.rs | 2 +- src/common/status.rs | 28 +- src/config.rs | 30 +- src/filter_plugin/exec.rs | 3 +- src/filter_plugin/head.rs | 6 +- src/filter_plugin/skip.rs | 6 +- src/filter_plugin/tail.rs | 6 +- src/filter_plugin/tokens.rs | 9 +- src/meta_plugin/cwd.rs | 3 +- src/meta_plugin/digest.rs | 3 +- src/meta_plugin/env.rs | 3 +- src/meta_plugin/exec.rs | 3 +- src/meta_plugin/hostname.rs | 3 +- src/meta_plugin/keep_pid.rs | 3 +- src/meta_plugin/magic_file.rs | 3 +- src/meta_plugin/mod.rs | 12 +- src/meta_plugin/read_rate.rs | 3 +- src/meta_plugin/read_time.rs | 3 +- src/meta_plugin/shell.rs | 3 +- src/meta_plugin/shell_pid.rs | 3 +- src/meta_plugin/text.rs | 3 +- src/meta_plugin/tokens.rs | 3 +- src/meta_plugin/user.rs | 3 +- src/modes/client/save.rs | 4 +- src/modes/info.rs | 2 +- src/modes/list.rs | 6 +- src/modes/server/api/item.rs | 592 +++++++++++++++++++++-------- src/modes/server/api/mcp.rs | 72 ---- src/modes/server/api/mod.rs | 14 +- src/modes/server/api/status.rs | 36 +- src/modes/server/common.rs | 51 ++- src/modes/server/mcp/mod.rs | 83 ---- src/modes/server/mcp/server.rs | 83 ---- src/modes/server/mcp/tools.rs | 344 ----------------- src/modes/server/mod.rs | 47 ++- src/modes/server/pages.rs | 71 +++- src/modes/status.rs | 2 +- src/modes/status_plugins.rs | 2 +- src/services/async_data_service.rs | 36 +- src/services/async_item_service.rs | 153 ++++---- src/services/filter_service.rs | 12 +- src/services/item_service.rs | 161 +------- src/services/status_service.rs | 8 +- src/services/sync_data_service.rs | 34 +- 51 files changed, 876 insertions(+), 1309 deletions(-) delete mode 100644 src/modes/server/api/mcp.rs delete mode 100644 src/modes/server/mcp/mod.rs delete mode 100644 src/modes/server/mcp/server.rs delete mode 100644 src/modes/server/mcp/tools.rs diff --git a/AGENTS.md b/AGENTS.md index f8f3bca..93bb2e2 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -29,11 +29,27 @@ TERM=dumb cargo build --features server # With server feature - Filter plugins must implement `filter()`, `clone_box()`, and `options()` - Meta plugins extend 
`BaseMetaPlugin` for boilerplate reduction - Enum string representations: `#[strum(serialize_all = "snake_case")]` -- Lint rules: `deny(clippy::all)`, `deny(unsafe_code)` (except `libc::umask` in main.rs) -- Feature flags: `default = ["magic", "lz4", "gzip"]`; optional: `server`, `mcp`, `swagger` +- Lint rules: `deny(clippy::all)`, `deny(unsafe_code)` (except `libc::umask` in main.rs, `unsafe impl Send` in `src/meta_plugin/magic_file.rs` for `SendCookie`) +- Feature flags: `default = ["magic", "lz4", "gzip"]`; optional: `server`, `swagger` ## Testing - Tests in `src/tests/` mirroring `src/` structure; shared helpers in `src/tests/common/test_helpers.rs` - Key helpers: `create_temp_dir()`, `create_temp_db()`, `test_compression_engine()` - Test naming: `test__` + +## Streaming Constraint + +**At no point should the whole file be in memory at once.** All I/O must use fixed-size buffers: + +- `PIPESIZE` = 8192 bytes (`src/common/mod.rs:10`) +- Server POST body streams through `save_item_raw_streaming` via `MpscReader` +- Server GET content streams via streaming reader (not `read_to_end`) +- When `max_body_size` is exceeded, return `413` but keep the partial item (nonfatal by design) +- Filter/meta plugins use `PIPESIZE`-sized buffers + +## HTML Rendering + +- Use `html_escape` crate for all user-controlled data in HTML pages +- `esc()` for text content, `esc_attr()` for HTML attributes +- Security headers middleware: `X-Content-Type-Options: nosniff`, `X-Frame-Options: DENY`, `Referrer-Policy: strict-origin-when-cross-origin` diff --git a/Cargo.lock b/Cargo.lock index de8c97c..720ab6f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -860,12 +860,6 @@ version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "92773504d58c093f6de2459af4af33faa518c13451eb8f2b5698ed3d36e7c813" -[[package]] -name = "dyn-clone" -version = "1.0.20" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d0881ea181b1df73ff77ffaaf9c7544ecc11e82fba9b5f27b262a3c73a332555" - [[package]] name = "either" version = "1.15.0" @@ -1283,6 +1277,15 @@ dependencies = [ "digest 0.9.0", ] +[[package]] +name = "html-escape" +version = "0.2.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d1ad449764d627e22bfd7cd5e8868264fc9236e07c752972b4080cd351cb476" +dependencies = [ + "utf8-width", +] + [[package]] name = "http" version = "1.4.0" @@ -1658,6 +1661,7 @@ dependencies = [ "flate2", "futures", "gethostname", + "html-escape", "http-body-util", "humansize", "hyper", @@ -1680,7 +1684,6 @@ dependencies = [ "rand 0.9.2", "regex", "ringbuf", - "rmcp", "rusqlite", "rusqlite_migration", "serde", @@ -2050,12 +2053,6 @@ dependencies = [ "windows-link", ] -[[package]] -name = "paste" -version = "1.0.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" - [[package]] name = "pathdiff" version = "0.2.3" @@ -2388,40 +2385,6 @@ dependencies = [ "portable-atomic-util", ] -[[package]] -name = "rmcp" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "37f2048a81a7ff7e8ef6bc5abced70c3d9114c8f03d85d7aaaafd9fd04f12e9e" -dependencies = [ - "base64", - "chrono", - "futures", - "paste", - "pin-project-lite", - "rmcp-macros", - "schemars", - "serde", - "serde_json", - "thiserror 2.0.18", - "tokio", - "tokio-util", - "tracing", -] - -[[package]] -name = "rmcp-macros" -version = "0.2.1" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "72398e694b9f6dbb5de960cf158c8699e6a1854cb5bbaac7de0646b2005763c4" -dependencies = [ - "darling", - "proc-macro2", - "quote", - "serde_json", - "syn", -] - [[package]] name = "ron" version = "0.12.0" @@ -2591,31 +2554,6 @@ dependencies = [ "winapi-util", ] -[[package]] -name = "schemars" -version = "0.8.22" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3fbf2ae1b8bc8e02df939598064d22402220cd5bbcca1c76f7d6a310974d5615" -dependencies = [ - "chrono", - "dyn-clone", - "schemars_derive", - "serde", - "serde_json", -] - -[[package]] -name = "schemars_derive" -version = "0.8.22" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32e265784ad618884abaea0600a9adf15393368d840e0222d101a072f3f7534d" -dependencies = [ - "proc-macro2", - "quote", - "serde_derive_internals", - "syn", -] - [[package]] name = "scopeguard" version = "1.2.0" @@ -2670,17 +2608,6 @@ dependencies = [ "syn", ] -[[package]] -name = "serde_derive_internals" -version = "0.29.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "18d26a20a969b9e3fdf2fc2d9f21eda6c40e2de84c9408bb5d3b05d499aae711" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - [[package]] name = "serde_json" version = "1.0.149" @@ -3216,21 +3143,9 @@ checksum = "63e71662fa4b2a2c3a26f570f037eb95bb1f85397f3cd8076caed2f026a6d100" dependencies = [ "log", "pin-project-lite", - "tracing-attributes", "tracing-core", ] -[[package]] -name = "tracing-attributes" -version = "0.1.31" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7490cfa5ec963746568740651ac6781f701c9c5ea257c58e057f3ba8cf69e8da" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - [[package]] name = "tracing-core" version = "0.1.36" @@ -3368,6 +3283,12 @@ version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" +[[package]] +name = "utf8-width" +version = "0.1.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1292c0d970b54115d14f2492fe0170adf21d68a1de108eebc51c1df4f346a091" + [[package]] name = "utf8_iter" version = "1.0.4" diff --git a/Cargo.toml b/Cargo.toml index 9ba8840..df07211 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -44,7 +44,6 @@ comfy-table = "7.2" pwhash = "1.0" regex = "1.10" ringbuf = "0.4" -rmcp = { version = "0.2", features = ["server"], optional = true } rusqlite = { version = "0.37", features = ["bundled", "array", "chrono"] } rusqlite_migration = "2.3" serde = { version = "1.0", features = ["derive"] } @@ -71,6 +70,7 @@ pest = "2.8" pest_derive = "2.8" dirs = "6.0" similar = { version = "2.7", default-features = false, features = ["text"] } +html-escape = "0.2" ureq = { version = "3", features = ["json"], optional = true } os_pipe = { version = "1", optional = true } axum-server = { version = "0.8", features = ["tls-rustls"], optional = true } @@ -102,9 +102,6 @@ all-filter-plugins = [] # Individual plugin features magic = ["dep:magic"] -# MCP feature (Model Context Protocol support) -mcp = ["dep:rmcp"] - # Swagger UI feature swagger = ["dep:utoipa-swagger-ui"] diff --git a/DESIGN.md b/DESIGN.md index 7fe5b44..63571eb 100644 --- a/DESIGN.md +++ b/DESIGN.md @@ -128,16 +128,19 @@ ### Item Operations - `GET /api/item/` - Get a list of items as JSON. 
Optional params: `order=newest|oldest`, `start=0`, `count=100`, `tags=tag1,tag2` -- `POST /api/item/` - Add a new item (body: raw content). Query params: `tags`, `metadata` (JSON), `compress=true|false`, `meta=true|false` +- `POST /api/item/` - Add a new item (body: raw content, **streamed** through fixed-size 8192-byte buffers). Query params: `tags`, `metadata` (JSON), `compress=true|false`, `meta=true|false` - `POST /api/item/<#>/meta` - Add metadata to an existing item (body: JSON object) - `DELETE /api/item/<#>` - Delete an item - `GET /api/item/latest` - Return the latest item as JSON. Optional params: `tags=tag1,tag2`, `allow_binary=true|false` - `GET /api/item/latest/meta` - Return the latest item metadata as JSON. Optional params: `tags=tag1,tag2` -- `GET /api/item/latest/content` - Return the raw content of the latest item. Optional params: `tags=tag1,tag2`, `decompress=true|false` +- `GET /api/item/latest/content` - Return the raw content of the latest item (**streamed**). Optional params: `tags=tag1,tag2`, `decompress=true|false` - `GET /api/item/<#>` - Return the item as JSON. Optional params: `allow_binary=true|false` - `GET /api/item/<#>/meta` - Return the item metadata as JSON -- `GET /api/item/<#>/content` - Return the raw content of the item. Optional params: `decompress=true|false` -- `GET /api/diff` - Diff two items. Params: `id_a`, `id_b` +- `GET /api/item/<#>/content` - Return the raw content of the item (**streamed**). Optional params: `decompress=true|false` +- `GET /api/diff` - Diff two items. Params: `id_a`, `id_b` (individual items capped at 10 MB) + +### Server Configuration +- `max_body_size` - Maximum POST body size in bytes (default: unlimited). When exceeded, server returns `413 PAYLOAD_TOO_LARGE` while keeping the partial item already saved through the streaming pipeline. Set to `0` for unlimited. 
### Server Modes - **Plain HTTP** (default): `tokio::net::TcpListener` + `axum::serve()` @@ -149,6 +152,8 @@ - Dumb clients (curl) use defaults (`compress=true`, `meta=true`), server handles everything - GET responses include `X-Keep-Compression` header when `decompress=false` - Streaming save uses chunked transfer encoding for constant memory usage +- **Universal streaming**: All server paths (POST, GET, diff) use `PIPESIZE` (8192) byte buffers +- **413 partial item**: When `max_body_size` is exceeded, the server returns `413` but keeps the partial item already saved through the pipeline (nonfatal design — pipes continue normally) ### Authentication - Bearer token authentication: `Authorization: Bearer ` @@ -207,12 +212,19 @@ - TLS/HTTPS support via rustls when certificate and key are provided - Proper resource cleanup using RAII patterns - Safe handling of external processes with proper stdin/stdout management +- **Streaming architecture**: All server I/O uses fixed-size 8192-byte buffers; no full file contents held in memory +- **XSS protection**: All user-controlled data in HTML pages is escaped via `html-escape` +- **Security headers**: `X-Content-Type-Options: nosniff`, `X-Frame-Options: DENY`, `Referrer-Policy: strict-origin-when-cross-origin` +- **CORS**: Explicit allowed headers only (`Content-Type`, `Authorization`, `Accept`); no wildcard headers +- **Input limits**: Tags (256 chars), metadata keys (128 chars), metadata values (4096 chars), pagination (10,000 max) +- **Config file size**: 4 KB cap with `from_utf8_lossy` for safe UTF-8 handling +- **Error sanitization**: Internal errors never exposed in HTML responses +- **No `unsafe_code`**: Enforced via `#![deny(unsafe_code)]` (exceptions: `libc::umask` in main.rs, `unsafe impl Send` for `SendCookie` in magic_file.rs) ## Feature Flags - `server` - HTTP REST API server (axum-based) - `tls` - HTTPS/TLS support for server (axum-server + rustls) - `client` - HTTP client for remote server (ureq-based, includes streaming save) -- `mcp` - Model Context Protocol for AI assistant integration - `swagger` - OpenAPI/Swagger UI documentation - `magic` - File type detection via libmagic - `lz4` - LZ4 compression (internal) diff --git a/Dockerfile b/Dockerfile index 84c8aba..a71d601 100644 --- a/Dockerfile +++ b/Dockerfile @@ -23,7 +23,7 @@ RUN cargo fetch --target x86_64-unknown-linux-musl # magic feature excluded (requires shared libmagic; fallback uses `file` command) COPY src/ src/ RUN cargo build --release --target x86_64-unknown-linux-musl \ - --no-default-features --features lz4,gzip,server,mcp,swagger,client,tls \ + --no-default-features --features lz4,gzip,server,swagger,client,tls \ && strip target/x86_64-unknown-linux-musl/release/keep # Runtime stage - scratch since binary is fully static diff --git a/README.md b/README.md index f936e34..740dd85 100644 --- a/README.md +++ b/README.md @@ -33,7 +33,6 @@ keep --get api-data - [Server Mode](#server-mode) - [Client Mode](#client-mode) - [API Endpoints](#api-endpoints) -- [MCP (Model Context Protocol)](#mcp-model-context-protocol) - [Shell Integration](#shell-integration) - [Feature Flags](#feature-flags) - [License](#license) @@ -46,7 +45,6 @@ keep --get api-data - **Filters** — Apply transformations (head, tail, grep, strip ANSI) on retrieval - **Querying** — List, search, diff items with flexible formatting - **Client/server architecture** — Optional HTTP server with streaming support -- **MCP support** — Model Context Protocol integration for AI assistants - **Modular design** — 
Extensible plugin system for compression, metadata, and filtering ## Installation @@ -82,7 +80,7 @@ cargo build --release --features server cargo build --release --features client # Server + client + all optional features -cargo build --release --features server,tls,client,swagger,mcp +cargo build --release --features server,client,swagger ``` ## Quick Start @@ -356,6 +354,7 @@ KEEP_META_build=1234 echo "data" | keep --save tag --meta env=staging | `KEEP_SERVER_PASSWORD_HASH` | Server password hash | none | | `KEEP_SERVER_JWT_SECRET` | JWT secret for token auth | none | | `KEEP_SERVER_JWT_SECRET_FILE` | Path to JWT secret file | none | +| `KEEP_SERVER_MAX_BODY_SIZE` | Maximum POST body size in bytes (0=unlimited) | unlimited | | `KEEP_SERVER_CERT` | TLS certificate file path (PEM) | none | | `KEEP_SERVER_KEY` | TLS private key file path (PEM) | none | | `KEEP_CLIENT_URL` | Remote keep server URL | none | @@ -416,6 +415,8 @@ server: port: 21080 username: "keep" password: "secret" + # Maximum POST body size in bytes (0 = unlimited) + # max_body_size: 52428800 # 50 MB # JWT authentication (takes priority over password) # jwt_secret: "my-secret-key" # jwt_secret_file: /path/to/jwt_secret @@ -612,6 +613,33 @@ keep --client-url https://localhost:21080 --save my-tag The server accepts data from both dumb clients (raw HTTP/curl) and smart clients (the keep CLI). +#### Server Streaming + +The server streams all data through fixed-size buffers (8192 bytes). At no point is the entire file content held in memory. + +- **POST**: Body streams through the compression and storage pipeline in chunks. When `max_body_size` is exceeded, the server returns `413 PAYLOAD_TOO_LARGE` while keeping the partial item already saved through the pipeline. +- **GET**: Content streams from disk through decompression to the client using the same fixed-size buffers. +- **Diff**: Individual items are capped at 10 MB for the diff endpoint to prevent unbounded memory use. + +##### Max Body Size + +Control the maximum accepted body size with: + +```sh +# Via CLI flag (bytes) +keep --server --server-max-body-size 52428800 + +# Via environment variable +export KEEP_SERVER__MAX_BODY_SIZE=52428800 +keep --server + +# Via config file (config.yml) +server: + max_body_size: 52428800 # 50 MB +``` + +When set to `0` or omitted, no limit is enforced. + #### Server Query Parameters The server supports query parameters that control processing: @@ -696,7 +724,7 @@ Client save uses a 3-thread streaming pipeline for constant memory usage regardl - **Streamer thread**: Reads compressed bytes from pipe, streams to server via chunked HTTP POST - **Main thread**: After streaming completes, sends computed metadata (digest, hostname, size) to server -Memory usage is O(PIPESIZE) — typically 64KB — regardless of how much data is being stored. +Memory usage is O(PIPESIZE) — typically 8 KB — regardless of how much data is being stored. #### Example: Remote Pipeline @@ -769,25 +797,16 @@ cargo build --features server,swagger Swagger UI available at `/swagger`, OpenAPI spec at `/openapi.json`. -## MCP (Model Context Protocol) +#### Security -AI assistant integration via the Model Context Protocol. Enable with the `mcp` feature. +The server applies the following security measures: -```sh -cargo build --features server,mcp -``` - -MCP endpoint available at `/mcp/sse` when the server is running. 
- -### Available Tools - -| Tool | Description | Parameters | -|------|-------------|------------| -| `save_item` | Save new content | `content`, `tags[]`, `metadata{}` | -| `get_item` | Get item by ID | `id` | -| `get_latest_item` | Get latest item | `tags[]` | -| `list_items` | List items | `tags[]`, `limit`, `offset` | -| `search_items` | Search items | `tags[]`, `metadata{}` | +- **Input validation**: Item IDs are validated as positive integers; tags and metadata have length limits (256 and 128 characters respectively). +- **XSS protection**: All user-controlled data rendered into HTML pages is escaped. +- **Security headers**: Responses include `X-Content-Type-Options: nosniff`, `X-Frame-Options: DENY`, and `Referrer-Policy: strict-origin-when-cross-origin`. +- **CORS**: Explicit allowed headers (`Content-Type`, `Authorization`, `Accept`); no wildcard headers. +- **Path traversal**: Item IDs are validated to prevent directory traversal attacks. +- **Internal errors**: Internal error details are never exposed in HTML responses — only generic messages are shown. ## Shell Integration @@ -821,7 +840,6 @@ curl -s api.example.com | @ api-response | `server` | No | HTTP REST API server | | `tls` | No | HTTPS/TLS server support (requires `server`) | | `client` | No | HTTP client for remote server | -| `mcp` | No | Model Context Protocol support | | `swagger` | No | Swagger UI for API docs | | `bzip2` | No | BZip2 compression (external program) | | `xz` | No | XZ compression (external program) | @@ -838,7 +856,7 @@ cargo build --features server,tls cargo build --features client # Everything -cargo build --features server,tls,client,mcp,swagger,magic +cargo build --features server,tls,client,swagger,magic ``` ## License diff --git a/src/args.rs b/src/args.rs index 2cb754e..08a9e53 100644 --- a/src/args.rs +++ b/src/args.rs @@ -165,6 +165,10 @@ pub struct OptionsArgs { #[arg(help("Path to file containing JWT secret (requires --server)"))] pub server_jwt_secret_file: Option, + #[arg(long, env("KEEP_SERVER_MAX_BODY_SIZE"))] + #[arg(help("Maximum request body size in bytes (requires --server, default: unlimited)"))] + pub server_max_body_size: Option, + #[cfg(feature = "client")] #[arg(long, env("KEEP_CLIENT_URL"), help_heading("Client Options"))] #[arg(help("Remote keep server URL for client mode"))] diff --git a/src/common/schema.rs b/src/common/schema.rs index 62c1eea..4a12644 100644 --- a/src/common/schema.rs +++ b/src/common/schema.rs @@ -125,7 +125,7 @@ pub fn gather_meta_plugin_schemas() -> Vec { pub fn gather_filter_plugin_schemas() -> Vec { use crate::services::filter_service::get_available_filter_plugins; - let plugins = get_available_filter_plugins(); + let plugins = get_available_filter_plugins().unwrap_or_default(); let mut schemas: Vec = plugins .into_iter() .map(|(name, creator)| { diff --git a/src/common/status.rs b/src/common/status.rs index 588da9f..f3615ba 100644 --- a/src/common/status.rs +++ b/src/common/status.rs @@ -27,6 +27,22 @@ pub struct StatusInfo { pub configured_meta_plugins: Option>, } +impl Default for StatusInfo { + fn default() -> Self { + Self { + paths: PathInfo { + data: String::new(), + database: String::new(), + }, + compression: Vec::new(), + meta_plugins: std::collections::HashMap::new(), + enabled_meta_plugins: Vec::new(), + filter_plugins: Vec::new(), + configured_meta_plugins: None, + } + } +} + #[derive(serde::Serialize, serde::Deserialize)] #[cfg_attr(feature = "server", derive(ToSchema))] pub struct PathInfo { @@ -59,17 +75,17 @@ pub fn 
generate_status_info( db_path: PathBuf, enabled_meta_plugins: &[MetaPluginType], enabled_compression_type: Option, -) -> StatusInfo { +) -> anyhow::Result { log::debug!("STATUS: Starting status info generation"); let path_info = PathInfo { data: data_path .into_os_string() .into_string() - .expect("Unable to convert data path to string"), + .map_err(|_| anyhow::anyhow!("Unable to convert data path to string"))?, database: db_path .into_os_string() .into_string() - .expect("Unable to convert DB path to string"), + .map_err(|_| anyhow::anyhow!("Unable to convert DB path to string"))?, }; let _default_type = crate::compression_engine::default_compression_type(); @@ -183,7 +199,7 @@ pub fn generate_status_info( } // Populate filter plugin info from the global registry - let filter_plugins_map = crate::services::filter_service::get_available_filter_plugins(); + let filter_plugins_map = crate::services::filter_service::get_available_filter_plugins()?; let filter_plugins_info: Vec = filter_plugins_map .into_iter() .map(|(name, creator)| { @@ -196,12 +212,12 @@ pub fn generate_status_info( }) .collect(); - StatusInfo { + Ok(StatusInfo { paths: path_info, compression: compression_info, meta_plugins: meta_plugins_map, enabled_meta_plugins: enabled_meta_plugins_vec, filter_plugins: filter_plugins_info, configured_meta_plugins: None, - } + }) } diff --git a/src/config.rs b/src/config.rs index a70375e..db136c0 100644 --- a/src/config.rs +++ b/src/config.rs @@ -152,6 +152,7 @@ pub struct ServerConfig { pub cert_file: Option, pub key_file: Option, pub cors_origin: Option, + pub max_body_size: Option, } #[derive(Debug, Clone, Deserialize, Serialize)] @@ -256,7 +257,11 @@ impl Settings { // Override with CLI args if let Some(dir) = &args.options.dir { debug!("CONFIG: Overriding dir with CLI arg: {dir:?}"); - config_builder = config_builder.set_override("dir", dir.to_str().unwrap())?; + config_builder = config_builder.set_override( + "dir", + dir.to_str() + .ok_or_else(|| anyhow::anyhow!("non-UTF-8 directory path"))?, + )?; } if args.options.human_readable { @@ -316,6 +321,10 @@ impl Settings { .set_override("server.key_file", server_key.to_string_lossy().as_ref())?; } + if let Some(max_body_size) = args.options.server_max_body_size { + config_builder = config_builder.set_override("server.max_body_size", max_body_size)?; + } + if let Some(compression) = &args.item.compression { config_builder = config_builder.set_override("compression_plugin.name", compression.as_str())?; @@ -486,10 +495,10 @@ impl Settings { // First check for password_file if let Some(password_file) = &server.password_file { debug!("CONFIG: Reading password from file: {password_file:?}"); - let password = fs::read_to_string(password_file) - .with_context(|| format!("Failed to read password file: {password_file:?}"))? - .trim() - .to_string(); + let password = fs::read(password_file) + .with_context(|| format!("Failed to read password file: {password_file:?}"))?; + let end = password.len().min(4096); + let password = String::from_utf8_lossy(&password[..end]).trim().to_string(); return Ok(Some(password)); } @@ -521,12 +530,11 @@ impl Settings { // First check for jwt_secret_file if let Some(jwt_secret_file) = &server.jwt_secret_file { debug!("CONFIG: Reading JWT secret from file: {jwt_secret_file:?}"); - let secret = fs::read_to_string(jwt_secret_file) - .with_context(|| { - format!("Failed to read JWT secret file: {jwt_secret_file:?}") - })? 
- .trim() - .to_string(); + let secret = fs::read(jwt_secret_file).with_context(|| { + format!("Failed to read JWT secret file: {jwt_secret_file:?}") + })?; + let end = secret.len().min(4096); + let secret = String::from_utf8_lossy(&secret[..end]).trim().to_string(); return Ok(Some(secret)); } diff --git a/src/filter_plugin/exec.rs b/src/filter_plugin/exec.rs index ec4b4ab..e2434e0 100644 --- a/src/filter_plugin/exec.rs +++ b/src/filter_plugin/exec.rs @@ -224,5 +224,6 @@ fn register_exec_filter() { stdin_writer: None, stdout_reader: None, }) - }); + }) + .expect("Failed to register exec filter"); } diff --git a/src/filter_plugin/head.rs b/src/filter_plugin/head.rs index ec5de92..b8520fe 100644 --- a/src/filter_plugin/head.rs +++ b/src/filter_plugin/head.rs @@ -283,6 +283,8 @@ impl FilterPlugin for HeadLinesFilter { // Register the plugin at module initialization time #[ctor::ctor] fn register_head_filters() { - register_filter_plugin("head_bytes", || Box::new(HeadBytesFilter::new(0))); - register_filter_plugin("head_lines", || Box::new(HeadLinesFilter::new(0))); + register_filter_plugin("head_bytes", || Box::new(HeadBytesFilter::new(0))) + .expect("Failed to register head_bytes filter"); + register_filter_plugin("head_lines", || Box::new(HeadLinesFilter::new(0))) + .expect("Failed to register head_lines filter"); } diff --git a/src/filter_plugin/skip.rs b/src/filter_plugin/skip.rs index 84ebc36..fd6a3f2 100644 --- a/src/filter_plugin/skip.rs +++ b/src/filter_plugin/skip.rs @@ -150,6 +150,8 @@ impl FilterPlugin for SkipLinesFilter { // Register the plugin at module initialization time #[ctor::ctor] fn register_skip_filters() { - register_filter_plugin("skip_bytes", || Box::new(SkipBytesFilter::new(0))); - register_filter_plugin("skip_lines", || Box::new(SkipLinesFilter::new(0))); + register_filter_plugin("skip_bytes", || Box::new(SkipBytesFilter::new(0))) + .expect("Failed to register skip_bytes filter"); + register_filter_plugin("skip_lines", || Box::new(SkipLinesFilter::new(0))) + .expect("Failed to register skip_lines filter"); } diff --git a/src/filter_plugin/tail.rs b/src/filter_plugin/tail.rs index 50ef25d..321142b 100644 --- a/src/filter_plugin/tail.rs +++ b/src/filter_plugin/tail.rs @@ -169,6 +169,8 @@ impl FilterPlugin for TailLinesFilter { // Register the plugin at module initialization time #[ctor::ctor] fn register_tail_filters() { - register_filter_plugin("tail_bytes", || Box::new(TailBytesFilter::new(0))); - register_filter_plugin("tail_lines", || Box::new(TailLinesFilter::new(0))); + register_filter_plugin("tail_bytes", || Box::new(TailBytesFilter::new(0))) + .expect("Failed to register tail_bytes filter"); + register_filter_plugin("tail_lines", || Box::new(TailLinesFilter::new(0))) + .expect("Failed to register tail_lines filter"); } diff --git a/src/filter_plugin/tokens.rs b/src/filter_plugin/tokens.rs index b6aa59b..899dfd3 100644 --- a/src/filter_plugin/tokens.rs +++ b/src/filter_plugin/tokens.rs @@ -377,9 +377,12 @@ fn map_lossy_pos_to_bytes(original: &[u8], lossy: &str, lossy_pos: usize) -> usi #[ctor::ctor] fn register_token_filters() { - register_filter_plugin("head_tokens", || Box::new(HeadTokensFilter::new(0))); - register_filter_plugin("skip_tokens", || Box::new(SkipTokensFilter::new(0))); - register_filter_plugin("tail_tokens", || Box::new(TailTokensFilter::new(0))); + register_filter_plugin("head_tokens", || Box::new(HeadTokensFilter::new(0))) + .expect("Failed to register head_tokens filter"); + register_filter_plugin("skip_tokens", || 
Box::new(SkipTokensFilter::new(0))) + .expect("Failed to register skip_tokens filter"); + register_filter_plugin("tail_tokens", || Box::new(TailTokensFilter::new(0))) + .expect("Failed to register tail_tokens filter"); } #[cfg(test)] diff --git a/src/meta_plugin/cwd.rs b/src/meta_plugin/cwd.rs index a63c95c..70601a1 100644 --- a/src/meta_plugin/cwd.rs +++ b/src/meta_plugin/cwd.rs @@ -128,5 +128,6 @@ use crate::meta_plugin::register_meta_plugin; fn register_cwd_plugin() { register_meta_plugin(MetaPluginType::Cwd, |options, outputs| { Box::new(CwdMetaPlugin::new(options, outputs)) - }); + }) + .expect("Failed to register CwdMetaPlugin"); } diff --git a/src/meta_plugin/digest.rs b/src/meta_plugin/digest.rs index 8ae60c3..2753ec0 100644 --- a/src/meta_plugin/digest.rs +++ b/src/meta_plugin/digest.rs @@ -271,5 +271,6 @@ use crate::meta_plugin::register_meta_plugin; fn register_digest_plugin() { register_meta_plugin(MetaPluginType::Digest, |options, outputs| { Box::new(DigestMetaPlugin::new(options, outputs)) - }); + }) + .expect("Failed to register DigestMetaPlugin"); } diff --git a/src/meta_plugin/env.rs b/src/meta_plugin/env.rs index 91f16e1..a07afaf 100644 --- a/src/meta_plugin/env.rs +++ b/src/meta_plugin/env.rs @@ -227,5 +227,6 @@ use crate::meta_plugin::register_meta_plugin; fn register_env_plugin() { register_meta_plugin(MetaPluginType::Env, |options, outputs| { Box::new(EnvMetaPlugin::new(options, outputs)) - }); + }) + .expect("Failed to register EnvMetaPlugin"); } diff --git a/src/meta_plugin/exec.rs b/src/meta_plugin/exec.rs index b87d6af..0fb478c 100644 --- a/src/meta_plugin/exec.rs +++ b/src/meta_plugin/exec.rs @@ -311,5 +311,6 @@ fn register_exec_plugin() { options, outputs, )) - }); + }) + .expect("Failed to register ExecMetaPlugin"); } diff --git a/src/meta_plugin/hostname.rs b/src/meta_plugin/hostname.rs index d4e2f67..dedf88f 100644 --- a/src/meta_plugin/hostname.rs +++ b/src/meta_plugin/hostname.rs @@ -406,5 +406,6 @@ use crate::meta_plugin::register_meta_plugin; fn register_hostname_plugin() { register_meta_plugin(MetaPluginType::Hostname, |options, outputs| { Box::new(HostnameMetaPlugin::new(options, outputs)) - }); + }) + .expect("Failed to register HostnameMetaPlugin"); } diff --git a/src/meta_plugin/keep_pid.rs b/src/meta_plugin/keep_pid.rs index 5a18f6f..be10150 100644 --- a/src/meta_plugin/keep_pid.rs +++ b/src/meta_plugin/keep_pid.rs @@ -204,5 +204,6 @@ use crate::meta_plugin::register_meta_plugin; fn register_keep_pid_plugin() { register_meta_plugin(MetaPluginType::KeepPid, |options, outputs| { Box::new(KeepPidMetaPlugin::new(options, outputs)) - }); + }) + .expect("Failed to register KeepPidMetaPlugin"); } diff --git a/src/meta_plugin/magic_file.rs b/src/meta_plugin/magic_file.rs index 8fffe19..5a7290d 100644 --- a/src/meta_plugin/magic_file.rs +++ b/src/meta_plugin/magic_file.rs @@ -457,5 +457,6 @@ use crate::meta_plugin::register_meta_plugin; fn register_magic_file_plugin() { register_meta_plugin(MetaPluginType::MagicFile, |options, outputs| { Box::new(MagicFileMetaPlugin::new(options, outputs)) - }); + }) + .expect("Failed to register MagicFileMetaPlugin"); } diff --git a/src/meta_plugin/mod.rs b/src/meta_plugin/mod.rs index 1ca2bc5..36d61dc 100644 --- a/src/meta_plugin/mod.rs +++ b/src/meta_plugin/mod.rs @@ -578,11 +578,15 @@ static META_PLUGIN_REGISTRY: Lazy anyhow::Result<()> { META_PLUGIN_REGISTRY .lock() - .unwrap() + .map_err(|e| anyhow::anyhow!("plugin registry poisoned: {e}"))? 
.insert(meta_plugin_type, constructor); + Ok(()) } pub fn get_meta_plugin( @@ -590,7 +594,9 @@ pub fn get_meta_plugin( options: Option>, outputs: Option>, ) -> anyhow::Result> { - let registry = META_PLUGIN_REGISTRY.lock().unwrap(); + let registry = META_PLUGIN_REGISTRY + .lock() + .map_err(|e| anyhow::anyhow!("plugin registry poisoned: {e}"))?; if let Some(constructor) = registry.get(&meta_plugin_type) { return Ok(constructor(options, outputs)); } diff --git a/src/meta_plugin/read_rate.rs b/src/meta_plugin/read_rate.rs index 265ae23..802d494 100644 --- a/src/meta_plugin/read_rate.rs +++ b/src/meta_plugin/read_rate.rs @@ -237,5 +237,6 @@ use crate::meta_plugin::register_meta_plugin; fn register_read_rate_plugin() { register_meta_plugin(MetaPluginType::ReadRate, |options, outputs| { Box::new(ReadRateMetaPlugin::new(options, outputs)) - }); + }) + .expect("Failed to register ReadRateMetaPlugin"); } diff --git a/src/meta_plugin/read_time.rs b/src/meta_plugin/read_time.rs index f01832e..a7369f1 100644 --- a/src/meta_plugin/read_time.rs +++ b/src/meta_plugin/read_time.rs @@ -124,5 +124,6 @@ use crate::meta_plugin::register_meta_plugin; fn register_read_time_plugin() { register_meta_plugin(MetaPluginType::ReadTime, |options, outputs| { Box::new(ReadTimeMetaPlugin::new(options, outputs)) - }); + }) + .expect("Failed to register ReadTimeMetaPlugin"); } diff --git a/src/meta_plugin/shell.rs b/src/meta_plugin/shell.rs index cb84070..e971d11 100644 --- a/src/meta_plugin/shell.rs +++ b/src/meta_plugin/shell.rs @@ -240,5 +240,6 @@ use crate::meta_plugin::register_meta_plugin; fn register_shell_plugin() { register_meta_plugin(MetaPluginType::Shell, |options, outputs| { Box::new(ShellMetaPlugin::new(options, outputs)) - }); + }) + .expect("Failed to register ShellMetaPlugin"); } diff --git a/src/meta_plugin/shell_pid.rs b/src/meta_plugin/shell_pid.rs index 0c98621..dc5458d 100644 --- a/src/meta_plugin/shell_pid.rs +++ b/src/meta_plugin/shell_pid.rs @@ -132,5 +132,6 @@ use crate::meta_plugin::register_meta_plugin; fn register_shell_pid_plugin() { register_meta_plugin(MetaPluginType::ShellPid, |options, outputs| { Box::new(ShellPidMetaPlugin::new(options, outputs)) - }); + }) + .expect("Failed to register ShellPidMetaPlugin"); } diff --git a/src/meta_plugin/text.rs b/src/meta_plugin/text.rs index 395a235..101346a 100644 --- a/src/meta_plugin/text.rs +++ b/src/meta_plugin/text.rs @@ -818,5 +818,6 @@ use crate::meta_plugin::register_meta_plugin; fn register_text_plugin() { register_meta_plugin(MetaPluginType::Text, |options, outputs| { Box::new(TextMetaPlugin::new(options, outputs)) - }); + }) + .expect("Failed to register TextMetaPlugin"); } diff --git a/src/meta_plugin/tokens.rs b/src/meta_plugin/tokens.rs index 13d3998..837b116 100644 --- a/src/meta_plugin/tokens.rs +++ b/src/meta_plugin/tokens.rs @@ -312,5 +312,6 @@ use crate::meta_plugin::register_meta_plugin; fn register_tokens_plugin() { register_meta_plugin(MetaPluginType::Tokens, |options, outputs| { Box::new(TokensMetaPlugin::new(options, outputs)) - }); + }) + .expect("Failed to register TokensMetaPlugin"); } diff --git a/src/meta_plugin/user.rs b/src/meta_plugin/user.rs index 0bb529e..f6fbf55 100644 --- a/src/meta_plugin/user.rs +++ b/src/meta_plugin/user.rs @@ -166,5 +166,6 @@ use crate::meta_plugin::register_meta_plugin; fn register_user_plugin() { register_meta_plugin(MetaPluginType::User, |options, outputs| { Box::new(UserMetaPlugin::new(options, outputs)) - }); + }) + .expect("Failed to register UserMetaPlugin"); } diff --git 
a/src/modes/client/save.rs b/src/modes/client/save.rs index 211a7e8..bea4c19 100644 --- a/src/modes/client/save.rs +++ b/src/modes/client/save.rs @@ -92,7 +92,7 @@ pub fn mode( let digest = format!("{:x}", hasher.finalize()); // Set shared state for main thread - let mut shared = shared_reader.lock().unwrap(); + let mut shared = shared_reader.lock().expect("client save mutex poisoned"); *shared = (total_bytes, digest.clone()); Ok((total_bytes, digest)) @@ -135,7 +135,7 @@ pub fn mode( // Read results from shared state let (uncompressed_size, digest) = { - let shared = shared.lock().unwrap(); + let shared = shared.lock().expect("client save mutex poisoned"); shared.clone() }; diff --git a/src/modes/info.rs b/src/modes/info.rs index 5692a42..5567074 100644 --- a/src/modes/info.rs +++ b/src/modes/info.rs @@ -162,7 +162,7 @@ fn show_item( item_path_buf.push(item_id.to_string()); let path_str = item_path_buf .to_str() - .expect("Unable to get item path") + .ok_or_else(|| anyhow::anyhow!("non-UTF-8 item path"))? .to_string(); table.add_row(vec![ Cell::new("Path").add_attribute(Attribute::Bold), diff --git a/src/modes/list.rs b/src/modes/list.rs index 69b9158..48fa277 100644 --- a/src/modes/list.rs +++ b/src/modes/list.rs @@ -240,7 +240,11 @@ pub fn mode_list( Ok(metadata) => format_size(metadata.len(), settings.human_readable), Err(_) => "Missing".to_string(), }, - ColumnType::FilePath => item_path.clone().into_os_string().into_string().unwrap(), + ColumnType::FilePath => item_path + .clone() + .into_os_string() + .into_string() + .unwrap_or_else(|os| os.to_string_lossy().into_owned()), ColumnType::Tags => tags.join(" "), ColumnType::Meta => match meta_name { Some(meta_name) => match meta.get(meta_name) { diff --git a/src/modes/server/api/item.rs b/src/modes/server/api/item.rs index 5564fe0..8e62397 100644 --- a/src/modes/server/api/item.rs +++ b/src/modes/server/api/item.rs @@ -16,7 +16,9 @@ use axum::{ use http_body_util::BodyExt; use log::{debug, warn}; use std::collections::HashMap; -use std::io::{Cursor, Read}; +use std::io::Read; +use std::sync::Arc; +use std::sync::atomic::{AtomicBool, Ordering}; use tokio::sync::mpsc; use tokio::task; @@ -24,14 +26,15 @@ use tokio::task; /// /// Used in `spawn_blocking` contexts to consume data from an async body /// stream as a regular reader. Blocks on each `read()` call until data -/// is available or the channel is closed. -struct ChannelReader { +/// is available or the channel is closed. No internal buffering — each +/// channel message is consumed fully before the next is fetched. 
+struct MpscReader { rx: mpsc::Receiver, std::io::Error>>, current: Vec, pos: usize, } -impl ChannelReader { +impl MpscReader { fn new(rx: mpsc::Receiver, std::io::Error>>) -> Self { Self { rx, @@ -41,9 +44,8 @@ impl ChannelReader { } } -impl Read for ChannelReader { +impl Read for MpscReader { fn read(&mut self, buf: &mut [u8]) -> std::io::Result { - // If we have buffered data, return it first if self.pos < self.current.len() { let remaining = &self.current[self.pos..]; let n = std::cmp::min(buf.len(), remaining.len()); @@ -52,7 +54,6 @@ impl Read for ChannelReader { return Ok(n); } - // Need more data from the channel - block until available match self.rx.blocking_recv() { Some(Ok(data)) => { let n = std::cmp::min(buf.len(), data.len()); @@ -64,7 +65,7 @@ impl Read for ChannelReader { Ok(n) } Some(Err(e)) => Err(e), - None => Ok(0), // Channel closed, EOF + None => Ok(0), } } } @@ -116,23 +117,6 @@ fn get_mime_type(metadata: &HashMap) -> String { .unwrap_or_else(|| "application/octet-stream".to_string()) } -/// Helper function to apply offset and length to content -fn apply_offset_length(content: &[u8], offset: u64, length: u64) -> &[u8] { - let content_len = content.len() as u64; - let start = std::cmp::min(offset, content_len); - let end = if length > 0 { - std::cmp::min(start + length, content_len) - } else { - content_len - }; - - if start < content_len { - &content[start as usize..end as usize] - } else { - &[] - } -} - /// Helper function to handle item not found errors fn handle_item_error(error: CoreError) -> StatusCode { match error { @@ -203,7 +187,7 @@ pub async fn handle_list_items( // Apply pagination let start = params.start.unwrap_or(0) as usize; - let count = params.count.unwrap_or(100) as usize; + let count = params.count.unwrap_or(100).min(10000) as usize; let items_with_meta: Vec<_> = items_with_meta .into_iter() .skip(start) @@ -253,7 +237,7 @@ async fn handle_as_meta_response( handle_as_meta_response_with_metadata(data_service, item_id, &metadata, offset, length).await } -/// Handle as_meta=true response with pre-fetched metadata +/// Handle as_meta=true response with pre-fetched metadata using streaming async fn handle_as_meta_response_with_metadata( data_service: &AsyncDataService, item_id: i64, @@ -279,61 +263,121 @@ async fn handle_as_meta_response_with_metadata( .body(axum::body::Body::from(response_body.to_string())) .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR) } else { - // Get the content as text - match data_service.get_item_content_info(item_id, None).await { - Ok((content, _, _)) => { - // Apply offset and length - let content_len = content.len() as u64; - let start = std::cmp::min(offset, content_len); - let end = if length > 0 { - std::cmp::min(start + length, content_len) - } else { - content_len - }; + // Use streaming approach to read content + let db = data_service.db(); + let data_path = data_service.data_path().clone(); + let settings = data_service.settings(); - let response_content = if start < content_len { - &content[start as usize..end as usize] - } else { - &[] - }; + // Get streaming reader from sync service + let (reader, content_len_result) = tokio::task::spawn_blocking(move || { + let mut conn = db.blocking_lock(); + let sync_service = + crate::services::SyncDataService::new(data_path, settings.as_ref().clone()); + let (reader, item_with_meta) = sync_service.get_content(&mut conn, item_id)?; + let content_len = item_with_meta.item.size.unwrap_or(0); + Ok::<_, CoreError>((reader, content_len)) + }) + .await + .map_err(|e| { + 
warn!("Blocking task failed for item {item_id}: {e}"); + StatusCode::INTERNAL_SERVER_ERROR + })? + .map_err(|e| { + warn!("Failed to get content reader for item {item_id}: {e}"); + StatusCode::INTERNAL_SERVER_ERROR + })?; - // Convert to UTF-8 string - let content_str = match String::from_utf8(response_content.to_vec()) { - Ok(s) => s, - Err(_) => { - // This shouldn't happen since we checked is_binary, but handle it just in case - let response_body = serde_json::json!({ - "metadata": metadata, - "content": serde_json::Value::Null, - "error": "Content is not valid UTF-8" - }); + let content_len = content_len_result as u64; - let response = Response::builder() - .header(header::CONTENT_TYPE, "application/json") - .status(StatusCode::UNPROCESSABLE_ENTITY) - .body(axum::body::Body::from(response_body.to_string())) - .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; - return Ok(response); + // Calculate offset and length bounds + let start = std::cmp::min(offset, content_len); + let end = if length > 0 { + std::cmp::min(start + length, content_len) + } else { + content_len + }; + let response_len = end - start; + + // Read content with offset and length using fixed-size buffer + let content_str = tokio::task::spawn_blocking(move || { + let mut reader = reader; + let mut buf = [0u8; crate::common::PIPESIZE]; + let mut result = Vec::new(); + let mut bytes_read = 0u64; + + // Skip offset bytes + if offset > 0 { + let mut remaining = offset; + while remaining > 0 { + let to_read = std::cmp::min(remaining, buf.len() as u64) as usize; + match reader.read(&mut buf[..to_read]) { + Ok(0) => break, + Ok(n) => remaining -= n as u64, + Err(e) => return Err(CoreError::Io(e)), } - }; + } + } - // Return JSON with metadata and content + // Read up to length bytes + let mut remaining = if length > 0 { length } else { u64::MAX }; + while remaining > 0 && bytes_read < response_len { + let to_read = std::cmp::min(remaining, buf.len() as u64) as usize; + match reader.read(&mut buf[..to_read]) { + Ok(0) => break, + Ok(n) => { + result.extend_from_slice(&buf[..n]); + bytes_read += n as u64; + if length > 0 { + remaining -= n as u64; + } + } + Err(e) => return Err(CoreError::Io(e)), + } + } + + Ok::, CoreError>(result) + }) + .await + .map_err(|e| { + warn!("Blocking task failed for item {item_id}: {e}"); + StatusCode::INTERNAL_SERVER_ERROR + })? 
+ .map_err(|e| { + warn!("Failed to read content for item {item_id}: {e}"); + StatusCode::INTERNAL_SERVER_ERROR + })?; + + // Convert to UTF-8 string + let content_str = match String::from_utf8(content_str) { + Ok(s) => s, + Err(_) => { + // This shouldn't happen since we checked is_binary, but handle it just in case let response_body = serde_json::json!({ "metadata": metadata, - "content": content_str, - "error": serde_json::Value::Null + "content": serde_json::Value::Null, + "error": "Content is not valid UTF-8" }); - Response::builder() + let response = Response::builder() .header(header::CONTENT_TYPE, "application/json") + .status(StatusCode::UNPROCESSABLE_ENTITY) .body(axum::body::Body::from(response_body.to_string())) - .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR) + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + return Ok(response); } - Err(e) => { - warn!("Failed to get content for item {item_id}: {e}"); - Err(StatusCode::INTERNAL_SERVER_ERROR) - } - } + }; + + // Return JSON with metadata and content + let response_body = serde_json::json!({ + "metadata": metadata, + "content": content_str, + "error": serde_json::Value::Null + }); + + Response::builder() + .header(header::CONTENT_TYPE, "application/json") + .body(axum::body::Body::from(response_body.to_string())) + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR) } } @@ -380,6 +424,14 @@ pub async fn handle_post_item( .map(parse_comma_tags) .unwrap_or_default(); + // Validate tag lengths + for tag in &tags { + if tag.len() > 256 { + warn!("Tag too long: {} chars (max 256)", tag.len()); + return Err(StatusCode::BAD_REQUEST); + } + } + // Parse metadata from query parameter let metadata: HashMap = if let Some(ref meta_str) = params.metadata { serde_json::from_str(meta_str).map_err(|e| { @@ -390,91 +442,116 @@ pub async fn handle_post_item( HashMap::new() }; + // Validate metadata key/value lengths + for (key, value) in &metadata { + if key.len() > 128 { + warn!("Metadata key too long: {} chars (max 128)", key.len()); + return Err(StatusCode::BAD_REQUEST); + } + if value.len() > 4096 { + warn!("Metadata value too long: {} chars (max 4096)", value.len()); + return Err(StatusCode::BAD_REQUEST); + } + } + let compress = params.compress; let run_meta = params.meta; - // When server handles both compression and meta, save_item_with_reader - // buffers internally anyway, so collect body in memory. - // When client handles compression/meta, stream the body to avoid buffering. - let item_with_meta = if compress && run_meta { - let body_bytes = body - .collect() - .await - .map_err(|e| { - warn!("Failed to read request body: {e}"); - StatusCode::BAD_REQUEST - })? - .to_bytes(); + // Stream body through an mpsc channel with fixed-size frames. + // Size tracking ensures we never buffer the whole body in memory. + // Treat Some(0) as unlimited (None). + let max_body_size: Option = settings + .server + .as_ref() + .and_then(|s| s.max_body_size) + .filter(|&v| v > 0); - task::spawn_blocking(move || { - let mut conn = db.blocking_lock(); - let sync_service = - crate::services::SyncDataService::new(data_dir, settings.as_ref().clone()); - let mut cursor = Cursor::new(body_bytes.to_vec()); - sync_service.save_item_with_reader(&mut conn, &mut cursor, tags, metadata) - }) - .await - .map_err(|e| { - warn!("Failed to save item (task error): {e}"); - StatusCode::INTERNAL_SERVER_ERROR - })? - .map_err(|e| { - warn!("Failed to save item: {e}"); - StatusCode::INTERNAL_SERVER_ERROR - })? 
- } else { - // Stream body through a channel to avoid buffering in memory - let (tx, rx) = tokio::sync::mpsc::channel::, std::io::Error>>(16); + let (tx, rx) = mpsc::channel::, std::io::Error>>(16); + let body_truncated = Arc::new(AtomicBool::new(false)); + let truncated_flag = body_truncated.clone(); - // Task to read body frames and send through channel - tokio::spawn(async move { - let mut body = body; - loop { - match body.frame().await { - None => break, // Body complete - Some(Err(e)) => { - let _ = tx - .send(Err(std::io::Error::other(format!("Body error: {e}")))) - .await; - break; - } - Some(Ok(frame)) => { - if let Ok(data) = frame.into_data() { - if tx.send(Ok(data.to_vec())).await.is_err() { - break; // Receiver dropped - } + // Async task: read body frames, track size, stop when limit exceeded + tokio::spawn(async move { + let mut body = body; + let mut total_bytes: u64 = 0; + loop { + match body.frame().await { + None => break, + Some(Err(e)) => { + let _ = tx + .send(Err(std::io::Error::other(format!("Body error: {e}")))) + .await; + break; + } + Some(Ok(frame)) => { + if let Ok(data) = frame.into_data() { + total_bytes += data.len() as u64; + if let Some(limit) = max_body_size + && total_bytes > limit + { + truncated_flag.store(true, Ordering::Relaxed); + break; // Drop sender → reader sees EOF + } + if tx.send(Ok(data.to_vec())).await.is_err() { + break; } } } } - }); + } + }); - task::spawn_blocking(move || { - let mut conn = db.blocking_lock(); - let sync_service = - crate::services::SyncDataService::new(data_dir, settings.as_ref().clone()); + // Blocking task: consume streaming reader, save via save_item_raw_streaming + let item_with_meta = task::spawn_blocking(move || { + let mut conn = db.blocking_lock(); + let sync_service = + crate::services::SyncDataService::new(data_dir, settings.as_ref().clone()); + let mut reader = MpscReader::new(rx); + sync_service.save_item_raw_streaming( + &mut conn, + &mut reader, + tags, + metadata, + compress, + run_meta, + ) + }) + .await + .map_err(|e| { + warn!("Failed to save item (task error): {e}"); + StatusCode::INTERNAL_SERVER_ERROR + })? + .map_err(|e| { + warn!("Failed to save item: {e}"); + StatusCode::INTERNAL_SERVER_ERROR + })?; - // Convert async mpsc receiver into a sync Read - let mut stream_reader = ChannelReader::new(rx); - sync_service.save_item_raw_streaming( - &mut conn, - &mut stream_reader, - tags, - metadata, - compress, - run_meta, - ) - }) - .await - .map_err(|e| { - warn!("Failed to save item (task error): {e}"); - StatusCode::INTERNAL_SERVER_ERROR - })? - .map_err(|e| { - warn!("Failed to save item: {e}"); - StatusCode::INTERNAL_SERVER_ERROR - })? - }; + // If body was truncated due to size limit, return 413 but keep the partial item. + // + // Design rationale for preserving truncated data: + // + // 1. Keep is a temporary file manager — partial content is still useful data. + // The client sent it, the server received it, and it was written through the + // normal save pipeline (compression, meta plugins). Discarding it would waste + // work already done and lose data the user may want to inspect or retry from. + // + // 2. The streaming pipeline is nonfatal by design. The sender stops pushing + // frames, the channel closes cleanly (sender drop → reader sees EOF), and + // save_item_raw_streaming finishes with whatever it received. No corruption, + // no partial writes to the storage file — the compression engine flushes and + // closes normally. + // + // 3. No destructive rollback. 
Deleting the item would require either a second + // DB transaction (fragile under concurrency) or leaving orphaned files on + // disk. Keeping the item is simpler and safer. + // + // 4. The 413 response tells the client exactly what happened. The client can + // use GET /api/item/ to list items and find the partial one by timestamp + // if it needs to resume or clean up. + // + if body_truncated.load(Ordering::Relaxed) { + return Err(StatusCode::PAYLOAD_TOO_LARGE); + } let compression = item_with_meta.item.compression.clone(); let tags = item_with_meta.tags.iter().map(|t| t.name.clone()).collect(); @@ -696,37 +773,110 @@ async fn stream_item_content_response( /// Stream raw (unprocessed) content directly from the item file. /// -/// Returns the stored file bytes without decompression or filtering. +/// Returns the stored file bytes without decompression or filtering using streaming. async fn stream_raw_content_response( data_service: &AsyncDataService, item_id: i64, offset: u64, length: u64, ) -> Result { - // Get item info to find the file path and compression type + // Get item info to find the compression type let item_with_meta = data_service.get_item(item_id).await.map_err(|e| { warn!("Failed to get item {item_id} for raw content: {e}"); StatusCode::INTERNAL_SERVER_ERROR })?; let compression = item_with_meta.item.compression.clone(); + let content_size = item_with_meta.item.size.unwrap_or(0); - // Read raw file bytes - let content = data_service - .get_raw_item_content(item_id) + // Get streaming reader for raw content + let reader = data_service + .get_raw_item_content_reader(item_id) .await .map_err(|e| { - warn!("Failed to get raw content for item {item_id}: {e}"); + warn!("Failed to get raw content reader for item {item_id}: {e}"); StatusCode::INTERNAL_SERVER_ERROR })?; - let response_content = apply_offset_length(&content, offset, length); + // Calculate the actual response length + let content_len = content_size as u64; + let start = std::cmp::min(offset, content_len); + let end = if length > 0 { + std::cmp::min(start + length, content_len) + } else { + content_len + }; + let response_len = end - start; + + // Create a channel to stream data between blocking thread and async runtime + let (tx, rx) = tokio::sync::mpsc::channel::, std::io::Error>>(16); + + // Spawn blocking task to read with offset and length + tokio::task::spawn_blocking(move || { + let mut reader = reader; + let mut buf = [0u8; crate::common::PIPESIZE]; + + // Apply offset by reading and discarding bytes + if offset > 0 { + let mut remaining = offset; + while remaining > 0 { + let to_read = std::cmp::min(remaining, buf.len() as u64) as usize; + match reader.read(&mut buf[..to_read]) { + Ok(0) => break, + Ok(n) => remaining -= n as u64, + Err(e) => { + let _ = tx.blocking_send(Err(e)); + return; + } + } + } + } + + // Read and send data up to the specified length + let mut remaining_length = length; + + loop { + let to_read = if length > 0 { + std::cmp::min(remaining_length, buf.len() as u64) as usize + } else { + buf.len() + }; + + if to_read == 0 { + break; + } + + match reader.read(&mut buf[..to_read]) { + Ok(0) => break, + Ok(n) => { + let chunk = buf[..n].to_vec(); + if tx.blocking_send(Ok(chunk)).is_err() { + break; + } + if length > 0 { + remaining_length -= n as u64; + if remaining_length == 0 { + break; + } + } + } + Err(e) => { + let _ = tx.blocking_send(Err(e)); + break; + } + } + } + }); + + // Convert the receiver into a stream + let stream = tokio_stream::wrappers::ReceiverStream::new(rx); + let 
body = axum::body::Body::from_stream(stream); let response = Response::builder() .header(header::CONTENT_TYPE, "application/octet-stream") .header("X-Keep-Compression", &compression) - .header(header::CONTENT_LENGTH, response_content.len()) - .body(axum::body::Body::from(response_content.to_vec())) + .header(header::CONTENT_LENGTH, response_len) + .body(body) .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; Ok(response) @@ -776,24 +926,94 @@ async fn stream_item_content_response_with_metadata( } } } else { - debug!("NON-STREAMING: Building full response in memory"); - match data_service.get_item_content_info(item_id, None).await { - Ok((content, _, _)) => { - let response_content = apply_offset_length(&content, offset, length); + debug!("NON-STREAMING: Building full response in memory using streaming reader"); + // Use streaming approach even for non-streaming response + let db = data_service.db(); + let data_path = data_service.data_path().clone(); + let settings = data_service.settings(); - debug!( - "NON-STREAMING: Content length: {}, response length: {}", - content.len(), - response_content.len() - ); + // Get streaming reader from sync service + let (reader, content_len_result) = tokio::task::spawn_blocking(move || { + let mut conn = db.blocking_lock(); + let sync_service = + crate::services::SyncDataService::new(data_path, settings.as_ref().clone()); + let (reader, item_with_meta) = sync_service.get_content(&mut conn, item_id)?; + let content_len = item_with_meta.item.size.unwrap_or(0); + Ok::<_, CoreError>((reader, content_len)) + }) + .await + .map_err(|e| { + warn!("Blocking task failed for item {item_id}: {e}"); + StatusCode::INTERNAL_SERVER_ERROR + })? + .map_err(|e| { + warn!("Failed to get content reader for item {item_id}: {e}"); + StatusCode::INTERNAL_SERVER_ERROR + })?; - ResponseBuilder::binary(response_content, &mime_type) + let content_len = content_len_result as u64; + + // Calculate offset and length bounds + let start = std::cmp::min(offset, content_len); + let end = if length > 0 { + std::cmp::min(start + length, content_len) + } else { + content_len + }; + let response_len = end - start; + + // Read content with offset and length using fixed-size buffer + let content = tokio::task::spawn_blocking(move || { + let mut reader = reader; + let mut buf = [0u8; crate::common::PIPESIZE]; + let mut result = Vec::new(); + let mut bytes_read = 0u64; + + // Skip offset bytes + if offset > 0 { + let mut remaining = offset; + while remaining > 0 { + let to_read = std::cmp::min(remaining, buf.len() as u64) as usize; + match reader.read(&mut buf[..to_read]) { + Ok(0) => break, + Ok(n) => remaining -= n as u64, + Err(e) => return Err(CoreError::Io(e)), + } + } } - Err(e) => { - warn!("Failed to get content for item {item_id}: {e}"); - Err(StatusCode::INTERNAL_SERVER_ERROR) + + // Read up to length bytes + let mut remaining = if length > 0 { length } else { u64::MAX }; + while remaining > 0 && bytes_read < response_len { + let to_read = std::cmp::min(remaining, buf.len() as u64) as usize; + match reader.read(&mut buf[..to_read]) { + Ok(0) => break, + Ok(n) => { + result.extend_from_slice(&buf[..n]); + bytes_read += n as u64; + if length > 0 { + remaining -= n as u64; + } + } + Err(e) => return Err(CoreError::Io(e)), + } } - } + + Ok::, CoreError>(result) + }) + .await + .map_err(|e| { + warn!("Blocking task failed for item {item_id}: {e}"); + StatusCode::INTERNAL_SERVER_ERROR + })? 
+ .map_err(|e| { + warn!("Failed to read content for item {item_id}: {e}"); + StatusCode::INTERNAL_SERVER_ERROR + })?; + + debug!("NON-STREAMING: Response length: {}", content.len()); + + ResponseBuilder::binary(&content, &mime_type) } } @@ -870,6 +1090,10 @@ pub async fn handle_get_item_meta( State(state): State, Path(item_id): Path, ) -> Result>>, StatusCode> { + if item_id <= 0 { + return Err(StatusCode::BAD_REQUEST); + } + let data_service = create_data_service(&state); match data_service.get_item(item_id).await { @@ -893,6 +1117,10 @@ pub async fn handle_post_item_meta( Path(item_id): Path, Json(metadata): Json>, ) -> Result>, StatusCode> { + if item_id <= 0 { + return Err(StatusCode::BAD_REQUEST); + } + let data_service = create_data_service(&state); // Verify item exists @@ -942,6 +1170,10 @@ pub async fn handle_delete_item( State(state): State, Path(item_id): Path, ) -> Result>, StatusCode> { + if item_id <= 0 { + return Err(StatusCode::BAD_REQUEST); + } + let mut conn = state.db.lock().await; let sync_service = crate::services::SyncDataService::new( @@ -995,6 +1227,10 @@ pub async fn handle_get_item_info( State(state): State, Path(item_id): Path, ) -> Result>, StatusCode> { + if item_id <= 0 { + return Err(StatusCode::BAD_REQUEST); + } + let mut conn = state.db.lock().await; let sync_service = crate::services::SyncDataService::new( @@ -1097,6 +1333,28 @@ pub async fn handle_diff_items( let id_a = item_a.item.id.ok_or(StatusCode::BAD_REQUEST)?; let id_b = item_b.item.id.ok_or(StatusCode::BAD_REQUEST)?; + // Size limit for diff operation (10MB per item) + const MAX_DIFF_SIZE: i64 = 10 * 1024 * 1024; + + if let Some(size_a) = item_a.item.size + && size_a > MAX_DIFF_SIZE + { + warn!( + "Item A ({}) exceeds diff size limit: {} > {}", + id_a, size_a, MAX_DIFF_SIZE + ); + return Err(StatusCode::PAYLOAD_TOO_LARGE); + } + if let Some(size_b) = item_b.item.size + && size_b > MAX_DIFF_SIZE + { + warn!( + "Item B ({}) exceeds diff size limit: {} > {}", + id_b, size_b, MAX_DIFF_SIZE + ); + return Err(StatusCode::PAYLOAD_TOO_LARGE); + } + let (mut reader_a, _) = sync_service .get_content(&mut conn, id_a) .map_err(handle_item_error)?; diff --git a/src/modes/server/api/mcp.rs b/src/modes/server/api/mcp.rs deleted file mode 100644 index 467b457..0000000 --- a/src/modes/server/api/mcp.rs +++ /dev/null @@ -1,72 +0,0 @@ -use axum::{ - extract::State, - http::StatusCode, - response::sse::{Event, KeepAlive, Sse}, -}; -use futures::stream::{self, Stream}; -use log::{debug, info}; -use std::convert::Infallible; -use std::time::Duration; - -use crate::modes::server::common::AppState; -use crate::modes::server::mcp::KeepMcpServer; - -#[utoipa::path( - get, - path = "/mcp/sse", - operation_id = "mcp_sse", - summary = "MCP SSE endpoint", - description = "Server-Sent Events for Model Context Protocol. 
Enables AI tools to interact with Keep's storage and retrieval functions.", - responses( - (status = 200, description = "SSE stream established"), - (status = 401, description = "Unauthorized"), - (status = 500, description = "Internal server error") - ), - security( - ("bearerAuth" = []) - ), - tag = "mcp" -)] -pub async fn handle_mcp_sse( - State(state): State, -) -> Result>>, StatusCode> { - debug!("MCP: Starting SSE endpoint"); - - let _mcp_server = KeepMcpServer::new(state); - - // Create a simple message channel for SSE communication - let (tx, rx) = tokio::sync::mpsc::unbounded_channel::(); - - // Send initial connection message - let _ = tx.send("data: {\"type\":\"connection\",\"status\":\"connected\"}\n\n".to_string()); - - // For now, create a simple stream that sends periodic keep-alive messages - // In a full implementation, this would integrate with the rmcp transport layer - let stream = stream::unfold((rx, tx), |(mut rx, tx)| async move { - tokio::select! { - msg = rx.recv() => { - match msg { - Some(data) => { - let event = Event::default().data(data); - Some((Ok(event), (rx, tx))) - } - None => None, - } - } - _ = tokio::time::sleep(Duration::from_secs(30)) => { - let event = Event::default() - .event("keep-alive") - .data("ping"); - Some((Ok(event), (rx, tx))) - } - } - }); - - info!("MCP: SSE endpoint established"); - - Ok(Sse::new(stream).keep_alive( - KeepAlive::new() - .interval(Duration::from_secs(30)) - .text("keep-alive"), - )) -} diff --git a/src/modes/server/api/mod.rs b/src/modes/server/api/mod.rs index cbd78a9..fcb062a 100644 --- a/src/modes/server/api/mod.rs +++ b/src/modes/server/api/mod.rs @@ -1,7 +1,5 @@ pub mod common; pub mod item; -#[cfg(feature = "mcp")] -pub mod mcp; pub mod status; use axum::{ @@ -60,8 +58,7 @@ use utoipa_swagger_ui::SwaggerUi; struct ApiDoc; pub fn add_routes(router: Router) -> Router { - #[cfg_attr(not(feature = "mcp"), allow(unused_mut))] - let mut router = router + router // Status endpoints .route("/api/status", get(status::handle_status)) .route("/api/plugins/status", get(status::handle_plugins_status)) @@ -88,14 +85,7 @@ pub fn add_routes(router: Router) -> Router { ) .route("/api/item/{item_id}", delete(item::handle_delete_item)) .route("/api/item/{item_id}/info", get(item::handle_get_item_info)) - .route("/api/diff", get(item::handle_diff_items)); - - #[cfg(feature = "mcp")] - { - router = router.route("/mcp/sse", get(mcp::handle_mcp_sse)); - } - - router + .route("/api/diff", get(item::handle_diff_items)) } #[cfg(feature = "swagger")] diff --git a/src/modes/server/api/status.rs b/src/modes/server/api/status.rs index f65b5a1..28ffb3f 100644 --- a/src/modes/server/api/status.rs +++ b/src/modes/server/api/status.rs @@ -39,7 +39,7 @@ use crate::modes::server::common::{ApiResponse, AppState, StatusInfoResponse}; /// /// # Examples /// -/// ``` +/// ```ignore /// // In an Axum app: /// async fn app() -> Result, StatusCode> { /// handle_status(State(app_state)).await @@ -60,12 +60,17 @@ pub async fn handle_status( // Use the status service to generate status info showing configured plugins let status_service = crate::services::status_service::StatusService::new(); let mut cmd = state.cmd.lock().await; - let status_info = status_service.generate_status( - &mut cmd, - &state.settings, - state.data_dir.clone(), - db_path.into(), - ); + let status_info = status_service + .generate_status( + &mut cmd, + &state.settings, + state.data_dir.clone(), + db_path.into(), + ) + .map_err(|e| { + log::warn!("Failed to generate status: {e}"); + 
StatusCode::INTERNAL_SERVER_ERROR + })?; let response = StatusInfoResponse { success: true, @@ -112,12 +117,17 @@ pub async fn handle_plugins_status( let status_service = crate::services::status_service::StatusService::new(); let mut cmd = state.cmd.lock().await; - let status_info = status_service.generate_status( - &mut cmd, - &state.settings, - state.data_dir.clone(), - db_path.into(), - ); + let status_info = status_service + .generate_status( + &mut cmd, + &state.settings, + state.data_dir.clone(), + db_path.into(), + ) + .map_err(|e| { + log::warn!("Failed to generate status: {e}"); + StatusCode::INTERNAL_SERVER_ERROR + })?; let response_data = PluginsStatusResponse { meta_plugins: status_info.meta_plugins, diff --git a/src/modes/server/common.rs b/src/modes/server/common.rs index a5a6394..1de4bf7 100644 --- a/src/modes/server/common.rs +++ b/src/modes/server/common.rs @@ -726,34 +726,33 @@ fn check_basic_auth( } let encoded = &auth_str[6..]; - if let Ok(decoded_bytes) = base64::engine::general_purpose::STANDARD.decode(encoded) { - if let Ok(decoded_str) = String::from_utf8(decoded_bytes) { - if let Some(colon_pos) = decoded_str.find(':') { - let provided_username = &decoded_str[..colon_pos]; - let provided_password = &decoded_str[colon_pos + 1..]; + if let Ok(decoded_bytes) = base64::engine::general_purpose::STANDARD.decode(encoded) + && let Ok(decoded_str) = String::from_utf8(decoded_bytes) + && let Some(colon_pos) = decoded_str.find(':') + { + let provided_username = &decoded_str[..colon_pos]; + let provided_password = &decoded_str[colon_pos + 1..]; - // Check username with constant-time comparison - if !bool::from( - provided_username - .as_bytes() - .ct_eq(expected_username.as_bytes()), - ) { - return false; - } - - // If we have a password hash, verify against it - if let Some(hash) = expected_hash { - return pwhash::unix::verify(provided_password, hash); - } - - // Otherwise, do constant-time comparison to prevent timing attacks - return bool::from( - provided_password - .as_bytes() - .ct_eq(expected_password.as_bytes()), - ); - } + // Check username with constant-time comparison + if !bool::from( + provided_username + .as_bytes() + .ct_eq(expected_username.as_bytes()), + ) { + return false; } + + // If we have a password hash, verify against it + if let Some(hash) = expected_hash { + return pwhash::unix::verify(provided_password, hash); + } + + // Otherwise, do constant-time comparison to prevent timing attacks + return bool::from( + provided_password + .as_bytes() + .ct_eq(expected_password.as_bytes()), + ); } false } diff --git a/src/modes/server/mcp/mod.rs b/src/modes/server/mcp/mod.rs deleted file mode 100644 index 194b4d6..0000000 --- a/src/modes/server/mcp/mod.rs +++ /dev/null @@ -1,83 +0,0 @@ -pub mod server; -pub mod tools; - -pub use server::KeepMcpServer; - -/// Module for handling MCP (Model Context Protocol) requests in the server. -/// -/// Provides handlers for JSON-RPC style requests to interact with Keep's storage -/// via the API. -use axum::{Json, extract::State, http::StatusCode, response::IntoResponse}; -use serde::Deserialize; -use serde_json::Value; - -use crate::modes::server::common::ApiResponse; -use crate::modes::server::common::AppState; - -/// Request structure for MCP JSON-RPC calls. -/// -/// # Fields -/// -/// * `method` - The MCP method name (e.g., "save_item"). -/// * `params` - Optional JSON parameters for the method. 
-#[derive(Deserialize)] -pub struct McpRequest { - pub method: String, - pub params: Option, -} - -/// Handles an MCP request via the Axum framework. -/// -/// Parses the JSON request, delegates to `KeepMcpServer`, and returns an API response. -/// Attempts to parse the result as JSON; falls back to string if invalid. -/// -/// # Arguments -/// -/// * `State(state)` - The application state. -/// * `Json(request)` - The deserialized MCP request. -/// -/// # Returns -/// -/// An `IntoResponse` with status code and JSON API response. -/// -/// # Errors -/// -/// Returns 400 Bad Request on handler errors. -pub async fn handle_mcp_request( - State(state): State, - Json(request): Json, -) -> impl IntoResponse { - let mcp_server = KeepMcpServer::new(state); - - match mcp_server - .handle_request(&request.method, request.params) - .await - { - Ok(result) => match serde_json::from_str(&result) { - Ok(parsed_result) => { - let response = ApiResponse { - success: true, - data: Some(parsed_result), - error: None, - }; - (StatusCode::OK, Json(response)) - } - Err(_) => { - let response = ApiResponse { - success: true, - data: Some(serde_json::Value::String(result)), - error: None, - }; - (StatusCode::OK, Json(response)) - } - }, - Err(e) => { - let response = ApiResponse { - success: false, - data: None, - error: Some(e.to_string()), - }; - (StatusCode::BAD_REQUEST, Json(response)) - } - } -} diff --git a/src/modes/server/mcp/server.rs b/src/modes/server/mcp/server.rs deleted file mode 100644 index d507624..0000000 --- a/src/modes/server/mcp/server.rs +++ /dev/null @@ -1,83 +0,0 @@ -use log::debug; -use serde_json::Value; - -use super::tools::{KeepTools, ToolError}; -use crate::modes::server::common::AppState; - -/// Server handler for MCP (Model Context Protocol) requests. -/// -/// Routes requests to appropriate tools and handles responses. Clones AppState for tool usage. -/// -/// # Fields -/// -/// * `state` - The shared application state (DB, config, etc.). -#[derive(Clone)] -pub struct KeepMcpServer { - state: AppState, -} - -/// Creates a new `KeepMcpServer` instance. -/// -/// # Arguments -/// -/// * `state` - The application state containing DB, config, and services. -/// -/// # Returns -/// -/// A new `KeepMcpServer` instance. -/// -/// # Examples -/// -/// ``` -/// let server = KeepMcpServer::new(app_state); -/// ``` -impl KeepMcpServer { - pub fn new(state: AppState) -> Self { - Self { state } - } - - /// Handles an MCP request by routing to the appropriate tool. - /// - /// Supports methods like "save_item", "get_item", "list_items". Logs the request and delegates to KeepTools. - /// - /// # Arguments - /// - /// * `method` - The MCP method name (string). - /// * `params` - Optional JSON parameters as serde_json::Value. - /// - /// # Returns - /// - /// `Ok(String)` with JSON-serialized response on success, or `Err(ToolError)` on failure. - /// - /// # Errors - /// - /// * ToolError::UnknownTool if method unsupported. - /// * Propagates tool-specific errors (e.g., invalid args, DB failures). 
- /// - /// # Examples - /// - /// ``` - /// let result = server.handle_request("save_item", Some(params)).await?; - /// ``` - pub async fn handle_request( - &self, - method: &str, - params: Option, - ) -> Result { - debug!( - "MCP: Handling request '{}' with params: {:?}", - method, params - ); - - let tools = KeepTools::new(self.state.clone()); - - match method { - "save_item" => tools.save_item(params).await, - "get_item" => tools.get_item(params).await, - "get_latest_item" => tools.get_latest_item(params).await, - "list_items" => tools.list_items(params).await, - "search_items" => tools.search_items(params).await, - _ => Err(ToolError::UnknownTool(method.to_string())), - } - } -} diff --git a/src/modes/server/mcp/tools.rs b/src/modes/server/mcp/tools.rs deleted file mode 100644 index 4f52a49..0000000 --- a/src/modes/server/mcp/tools.rs +++ /dev/null @@ -1,344 +0,0 @@ -use anyhow::{Result, anyhow}; -use log::debug; -use serde_json::Value; -use std::collections::HashMap; - -use crate::modes::server::common::AppState; -use crate::services::async_item_service::AsyncItemService; -use crate::services::error::CoreError; - -#[derive(Debug, thiserror::Error)] -pub enum ToolError { - #[error("Unknown tool: {0}")] - UnknownTool(String), - #[error("Invalid arguments: {0}")] - InvalidArguments(String), - #[error("Database error: {0}")] - Database(#[from] rusqlite::Error), - #[error("IO error: {0}")] - Io(#[from] std::io::Error), - #[error("JSON error: {0}")] - Json(#[from] serde_json::Error), - #[error("Parse error: {0}")] - Parse(#[from] strum::ParseError), - #[error("Other error: {0}")] - Other(#[from] anyhow::Error), -} - -pub struct KeepTools { - state: AppState, -} - -impl KeepTools { - pub fn new(state: AppState) -> Self { - Self { state } - } - - pub async fn save_item(&self, args: Option) -> Result { - let args = - args.ok_or_else(|| ToolError::InvalidArguments("Missing arguments".to_string()))?; - - let content = args - .get("content") - .and_then(|v| v.as_str()) - .ok_or_else(|| ToolError::InvalidArguments("Missing 'content' field".to_string()))?; - - let tags: Vec = args - .get("tags") - .and_then(|v| v.as_array()) - .map(|arr| { - arr.iter() - .filter_map(|v| v.as_str().map(|s| s.to_string())) - .collect() - }) - .unwrap_or_default(); - - let metadata: HashMap = args - .get("metadata") - .and_then(|v| v.as_object()) - .map(|obj| { - obj.iter() - .filter_map(|(k, v)| v.as_str().map(|s| (k.clone(), s.to_string()))) - .collect() - }) - .unwrap_or_default(); - - debug!( - "MCP: Saving item with {} bytes, {} tags, {} metadata entries", - content.len(), - tags.len(), - metadata.len() - ); - - let service = AsyncItemService::new( - self.state.data_dir.clone(), - self.state.db.clone(), - self.state.item_service.clone(), - self.state.cmd.clone(), - self.state.settings.clone(), - ); - let item_with_meta = service - .save_item_from_mcp(content.as_bytes().to_vec(), tags, metadata) - .await - .map_err(|e| ToolError::Other(anyhow::Error::from(e)))?; - - let item_id = item_with_meta - .item - .id - .ok_or_else(|| anyhow!("Failed to get item ID"))?; - - Ok(format!("Successfully saved item with ID: {}", item_id)) - } - - pub async fn get_item(&self, args: Option) -> Result { - let args = - args.ok_or_else(|| ToolError::InvalidArguments("Missing arguments".to_string()))?; - - let item_id = args.get("id").and_then(|v| v.as_i64()).ok_or_else(|| { - ToolError::InvalidArguments("Missing or invalid 'id' field".to_string()) - })?; - - let service = AsyncItemService::new( - self.state.data_dir.clone(), - 
self.state.db.clone(), - self.state.item_service.clone(), - self.state.cmd.clone(), - self.state.settings.clone(), - ); - - let item_with_content = match service.get_item_content(item_id).await { - Ok(iwc) => iwc, - Err(CoreError::ItemNotFound(_)) => { - return Err(ToolError::InvalidArguments(format!( - "Item {} not found", - item_id - ))); - } - Err(e) => return Err(ToolError::Other(anyhow::Error::from(e))), - }; - - let content = String::from_utf8_lossy(&item_with_content.content).to_string(); - let tags: Vec = item_with_content - .item_with_meta - .tags - .iter() - .map(|t| t.name.clone()) - .collect(); - let metadata = item_with_content.item_with_meta.meta_as_map(); - let item = item_with_content.item_with_meta.item; - - let response = serde_json::json!({ - "id": item_id, - "content": content, - "timestamp": item.ts.to_rfc3339(), - "size": item.size, - "compression": item.compression, - "tags": tags, - "metadata": metadata, - }); - - Ok(serde_json::to_string_pretty(&response)?) - } - - pub async fn get_latest_item(&self, args: Option) -> Result { - let tags: Vec = args - .as_ref() - .and_then(|v| v.get("tags")) - .and_then(|v| v.as_array()) - .map(|arr| { - arr.iter() - .filter_map(|v| v.as_str().map(|s| s.to_string())) - .collect() - }) - .unwrap_or_default(); - - let service = AsyncItemService::new( - self.state.data_dir.clone(), - self.state.db.clone(), - self.state.item_service.clone(), - self.state.cmd.clone(), - self.state.settings.clone(), - ); - - let item_with_meta = match service.find_item(vec![], tags, HashMap::new()).await { - Ok(iwm) => iwm, - Err(CoreError::ItemNotFoundGeneric) => { - return Err(ToolError::InvalidArguments("No items found".to_string())); - } - Err(e) => return Err(ToolError::Other(anyhow::Error::from(e))), - }; - - let item_id = item_with_meta - .item - .id - .ok_or_else(|| anyhow!("Item missing ID after find"))?; - let item_with_content = service - .get_item_content(item_id) - .await - .map_err(|e| ToolError::Other(anyhow::Error::from(e)))?; - - let content = String::from_utf8_lossy(&item_with_content.content).to_string(); - let tags: Vec = item_with_content - .item_with_meta - .tags - .iter() - .map(|t| t.name.clone()) - .collect(); - let metadata = item_with_content.item_with_meta.meta_as_map(); - let item = item_with_content.item_with_meta.item; - - let response = serde_json::json!({ - "id": item_id, - "content": content, - "timestamp": item.ts.to_rfc3339(), - "size": item.size, - "compression": item.compression, - "tags": tags, - "metadata": metadata, - }); - - Ok(serde_json::to_string_pretty(&response)?) 
- } - - pub async fn list_items(&self, args: Option) -> Result { - let args_ref = args.as_ref(); - let tags: Vec = args_ref - .and_then(|v| v.get("tags")) - .and_then(|v| v.as_array()) - .map(|arr| { - arr.iter() - .filter_map(|v| v.as_str().map(|s| s.to_string())) - .collect() - }) - .unwrap_or_default(); - - let limit = args_ref - .and_then(|v| v.get("limit")) - .and_then(|v| v.as_u64()) - .unwrap_or(10) as usize; - - let offset = args_ref - .and_then(|v| v.get("offset")) - .and_then(|v| v.as_u64()) - .unwrap_or(0) as usize; - - let service = AsyncItemService::new( - self.state.data_dir.clone(), - self.state.db.clone(), - self.state.item_service.clone(), - self.state.cmd.clone(), - self.state.settings.clone(), - ); - let mut items_with_meta = service - .list_items(tags, HashMap::new()) - .await - .map_err(|e| ToolError::Other(anyhow::Error::from(e)))?; - - // Sort by timestamp (newest first) and apply pagination - items_with_meta.sort_by(|a, b| b.item.ts.cmp(&a.item.ts)); - let items_with_meta: Vec<_> = items_with_meta - .into_iter() - .skip(offset) - .take(limit) - .collect(); - - let items_info: Vec<_> = items_with_meta - .into_iter() - .map(|item_with_meta| { - let item_tags: Vec = - item_with_meta.tags.iter().map(|t| t.name.clone()).collect(); - let item_meta = item_with_meta.meta_as_map(); - let item = item_with_meta.item; - let item_id = item.id.unwrap_or(0); - - serde_json::json!({ - "id": item_id, - "timestamp": item.ts.to_rfc3339(), - "size": item.size, - "compression": item.compression, - "tags": item_tags, - "metadata": item_meta - }) - }) - .collect(); - - let response = serde_json::json!({ - "items": items_info, - "count": items_info.len(), - "offset": offset, - "limit": limit - }); - - Ok(serde_json::to_string_pretty(&response)?) - } - - pub async fn search_items(&self, args: Option) -> Result { - let tags: Vec = args - .as_ref() - .and_then(|v| v.get("tags")) - .and_then(|v| v.as_array()) - .map(|arr| { - arr.iter() - .filter_map(|v| v.as_str().map(|s| s.to_string())) - .collect() - }) - .unwrap_or_default(); - - let metadata: HashMap = args - .as_ref() - .and_then(|v| v.get("metadata")) - .and_then(|v| v.as_object()) - .map(|obj| { - obj.iter() - .filter_map(|(k, v)| v.as_str().map(|s| (k.clone(), s.to_string()))) - .collect() - }) - .unwrap_or_default(); - - let service = AsyncItemService::new( - self.state.data_dir.clone(), - self.state.db.clone(), - self.state.item_service.clone(), - self.state.cmd.clone(), - self.state.settings.clone(), - ); - let mut items_with_meta = service - .list_items(tags.clone(), metadata.clone()) - .await - .map_err(|e| ToolError::Other(anyhow::Error::from(e)))?; - - // Sort by timestamp (newest first) - items_with_meta.sort_by(|a, b| b.item.ts.cmp(&a.item.ts)); - - let items_info: Vec<_> = items_with_meta - .into_iter() - .map(|item_with_meta| { - let item_tags: Vec = - item_with_meta.tags.iter().map(|t| t.name.clone()).collect(); - let item_meta = item_with_meta.meta_as_map(); - let item = item_with_meta.item; - let item_id = item.id.unwrap_or(0); - - serde_json::json!({ - "id": item_id, - "timestamp": item.ts.to_rfc3339(), - "size": item.size, - "compression": item.compression, - "tags": item_tags, - "metadata": item_meta - }) - }) - .collect(); - - let response = serde_json::json!({ - "items": items_info, - "count": items_info.len(), - "search_criteria": { - "tags": tags, - "metadata": metadata - } - }); - - Ok(serde_json::to_string_pretty(&response)?) 
- } -} diff --git a/src/modes/server/mod.rs b/src/modes/server/mod.rs index 34d7ad7..be274af 100644 --- a/src/modes/server/mod.rs +++ b/src/modes/server/mod.rs @@ -1,7 +1,10 @@ use crate::config; use crate::services::item_service::ItemService; use anyhow::Result; -use axum::{Router, routing::post}; +use axum::Router; +use axum::http::{HeaderValue, header}; +use axum::middleware::Next; +use axum::response::Response; use clap::Command; use log::{debug, info}; use std::net::SocketAddr; @@ -15,12 +18,26 @@ use tower_http::trace::TraceLayer; mod api; pub mod auth; pub mod common; -#[cfg(feature = "mcp")] -mod mcp; mod pages; pub use common::{AppState, create_auth_middleware, logging_middleware}; +/// Adds security headers to all responses. +async fn security_headers(req: axum::extract::Request, next: Next) -> Response { + let mut response = next.run(req).await; + let headers = response.headers_mut(); + headers.insert( + header::X_CONTENT_TYPE_OPTIONS, + HeaderValue::from_static("nosniff"), + ); + headers.insert(header::X_FRAME_OPTIONS, HeaderValue::from_static("DENY")); + headers.insert( + header::REFERRER_POLICY, + HeaderValue::from_static("strict-origin-when-cross-origin"), + ); + response +} + pub fn mode_server( cmd: &mut Command, settings: &config::Settings, @@ -107,23 +124,10 @@ async fn run_server( settings: Arc::new(settings.clone()), }; - #[cfg(feature = "mcp")] - let mcp_router = Router::new() - .route("/mcp", post(mcp::handle_mcp_request)) - .with_state(state.clone()); - - #[cfg_attr(not(feature = "mcp"), allow(unused_mut))] - let mut protected_router = Router::new() + let protected_router = Router::new() .merge(api::add_routes(Router::new())) - .merge(pages::add_routes(Router::new())); - - #[cfg(feature = "mcp")] - { - protected_router = protected_router.merge(mcp_router); - } - - let protected_router = - protected_router.layer(axum::middleware::from_fn(create_auth_middleware( + .merge(pages::add_routes(Router::new())) + .layer(axum::middleware::from_fn(create_auth_middleware( config.username.clone(), config.password.clone(), config.password_hash.clone(), @@ -152,18 +156,19 @@ async fn run_server( axum::http::Method::PUT, axum::http::Method::DELETE, ]) - .allow_headers(tower_http::cors::Any) + .allow_headers([header::CONTENT_TYPE, header::AUTHORIZATION, header::ACCEPT]) }; // Create the app with documentation routes open and others protected let app = Router::new() // Add documentation routes without authentication .merge(api::add_docs_routes(Router::new())) - // Add API, pages, and MCP routes with authentication + // Add API and pages routes with authentication .merge(protected_router) // Apply state to all routes .with_state(state) // Add other middleware layers to all routes + .layer(axum::middleware::from_fn(security_headers)) .layer(axum::middleware::from_fn(logging_middleware)) .layer( ServiceBuilder::new() diff --git a/src/modes/server/pages.rs b/src/modes/server/pages.rs index ab392cd..074ee01 100644 --- a/src/modes/server/pages.rs +++ b/src/modes/server/pages.rs @@ -6,11 +6,22 @@ use axum::{ extract::{Path, Query, State}, response::{Html, Response}, }; +use html_escape::{encode_double_quoted_attribute, encode_text}; use log::debug; use rusqlite::Connection; use serde::Deserialize; use std::collections::HashMap; +/// Escape text content for safe HTML insertion. +fn esc(s: &str) -> String { + encode_text(s).to_string() +} + +/// Escape attribute values for safe HTML attribute insertion. 
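+/// The escaped value is only safe when placed inside a double-quoted HTML attribute.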
+fn esc_attr(s: &str) -> String { + encode_double_quoted_attribute(s).to_string() +} + #[derive(Deserialize)] /// Query parameters for the item list endpoint. /// @@ -62,7 +73,7 @@ fn default_count() -> usize { /// /// # Examples /// -/// ``` +/// ```ignore /// let app = pages::add_routes(axum::Router::new()); /// ``` pub fn add_routes(app: axum::Router) -> axum::Router { @@ -90,7 +101,9 @@ async fn list_items( .map_err(|_| Html("Internal Server Error".to_string()))?; Ok(response) } - Err(e) => Err(Html(format!("Error: {e}"))), + Err(_e) => Err(Html( + "An internal error occurred".to_string(), + )), } } @@ -121,7 +134,8 @@ fn build_item_list( // Apply pagination let start = params.start; - let end = std::cmp::min(start + params.count, sorted_items.len()); + let count = params.count.min(10000); + let end = std::cmp::min(start + count, sorted_items.len()); let page_items = if start < sorted_items.len() { sorted_items[start..std::cmp::min(end, sorted_items.len())].to_vec() } else { @@ -153,11 +167,11 @@ fn build_item_list( // Collect all tags from all items, keeping track of their timestamps let mut all_tags_with_time: Vec<(String, chrono::DateTime)> = Vec::new(); for item in &sorted_items { - if let Some(item_id) = item.id { - if let Some(tags) = tags_map.get(&item_id) { - for tag in tags { - all_tags_with_time.push((tag.name.clone(), item.ts)); - } + if let Some(item_id) = item.id + && let Some(tags) = tags_map.get(&item_id) + { + for tag in tags { + all_tags_with_time.push((tag.name.clone(), item.ts)); } } } @@ -184,7 +198,9 @@ fn build_item_list( html.push_str("

"); for tag in recent_tags { html.push_str(&format!( - "{tag}" + "{}", + esc_attr(&tag), + esc(&tag) )); } html.push_str("

"); @@ -196,7 +212,7 @@ fn build_item_list( // Table headers html.push_str(""); for column in columns { - html.push_str(&format!("{}", column.label)); + html.push_str(&format!("{}", esc(&column.label))); } html.push_str("Actions"); html.push_str(""); @@ -229,7 +245,13 @@ fn build_item_list( // Make sure we're using all tags for the item let tag_links: Vec = tags .iter() - .map(|t| format!("{}", t.name, t.name)) + .map(|t| { + format!( + "{}", + esc_attr(&t.name), + esc(&t.name) + ) + }) .collect(); tag_links.join(", ") } @@ -268,7 +290,15 @@ fn build_item_list( crate::config::ColumnAlignment::Center => "text-align: center;", }; - html.push_str(&format!("{display_value}")); + let rendered_value = if column.name == "tags" { + display_value // Already contains escaped HTML links + } else { + esc(&display_value) + }; + + html.push_str(&format!( + "{rendered_value}" + )); } // Actions column @@ -361,7 +391,9 @@ async fn show_item( .map_err(|_| Html("Internal Server Error".to_string()))?; Ok(response) } - Err(e) => Err(Html(format!("Error: {e}"))), + Err(_e) => Err(Html( + "An internal error occurred".to_string(), + )), } } @@ -396,7 +428,7 @@ fn build_item_details(conn: &Connection, id: i64) -> Result { )); html.push_str(&format!( "Compression{}", - item.compression + esc(&item.compression) )); // Tags row @@ -406,7 +438,13 @@ fn build_item_details(conn: &Connection, id: i64) -> Result { } else { let tag_links: Vec = tags .iter() - .map(|t| format!("{}", t.name, t.name)) + .map(|t| { + format!( + "{}", + esc_attr(&t.name), + esc(&t.name) + ) + }) .collect(); html.push_str(&tag_links.join(", ")); } @@ -419,7 +457,8 @@ fn build_item_details(conn: &Connection, id: i64) -> Result { for meta in metas { html.push_str(&format!( "{}{}", - meta.name, meta.value + esc(&meta.name), + esc(&meta.value) )); } } diff --git a/src/modes/status.rs b/src/modes/status.rs index cf3f1e1..fdcd49e 100644 --- a/src/modes/status.rs +++ b/src/modes/status.rs @@ -198,7 +198,7 @@ pub fn mode_status( let status_service = crate::services::status_service::StatusService::new(); let output_format = crate::modes::common::settings_output_format(settings); debug!("STATUS: About to generate status info"); - let status_info = status_service.generate_status(cmd, settings, data_path, db_path); + let status_info = status_service.generate_status(cmd, settings, data_path, db_path)?; debug!("STATUS: Status info generated successfully"); match output_format { diff --git a/src/modes/status_plugins.rs b/src/modes/status_plugins.rs index b332eae..e296630 100644 --- a/src/modes/status_plugins.rs +++ b/src/modes/status_plugins.rs @@ -298,7 +298,7 @@ pub fn mode_status_plugins( let status_service = crate::services::status_service::StatusService::new(); let output_format = crate::modes::common::settings_output_format(settings); debug!("STATUS_PLUGINS: About to generate status info"); - let status_info = status_service.generate_status(cmd, settings, data_path, db_path); + let status_info = status_service.generate_status(cmd, settings, data_path, db_path)?; debug!("STATUS_PLUGINS: Status info generated successfully"); match output_format { diff --git a/src/services/async_data_service.rs b/src/services/async_data_service.rs index 15bce7f..5fdb8fd 100644 --- a/src/services/async_data_service.rs +++ b/src/services/async_data_service.rs @@ -81,24 +81,6 @@ impl AsyncDataService { DataService::find_item(self, &mut conn, ids, tags, meta) } - pub async fn get_item_content_info( - &self, - id: i64, - _filter: Option, - ) -> Result<(Vec, ItemWithMeta, bool), 
CoreError> { - let mut conn = self.db.lock().await; - let (mut reader, item_with_meta) = self.get_content(&mut conn, id)?; - let mut content = Vec::new(); - reader.read_to_end(&mut content)?; - let is_binary = item_with_meta - .meta - .iter() - .find(|m| m.name == "text") - .map(|m| m.value == "false") - .unwrap_or(false); - Ok((content, item_with_meta, is_binary)) - } - pub async fn get_item_content_info_streaming( &self, id: i64, @@ -196,27 +178,29 @@ impl AsyncDataService { Ok((Box::pin(stream), content_length)) } - /// Get raw item content without decompression. + /// Get raw item content without decompression as a streaming reader. /// - /// Reads the stored file bytes directly from disk, bypassing decompression. + /// Opens the stored file directly from disk, bypassing decompression. /// Used when the client requests raw bytes with `decompress=false`. - pub async fn get_raw_item_content(&self, id: i64) -> Result, CoreError> { + /// Returns a boxed reader that can be used for streaming. + pub async fn get_raw_item_content_reader( + &self, + id: i64, + ) -> Result, CoreError> { let data_path = self.data_path.clone(); tokio::task::spawn_blocking(move || { let mut item_path = data_path; item_path.push(id.to_string()); - let mut file = std::fs::File::open(&item_path).map_err(|e| { + let file = std::fs::File::open(&item_path).map_err(|e| { CoreError::Io(std::io::Error::new( std::io::ErrorKind::NotFound, format!("Item file not found: {item_path:?}: {e}"), )) })?; - let mut content = Vec::new(); - file.read_to_end(&mut content)?; - Ok(content) + Ok(Box::new(file) as Box) }) .await .map_err(|e| CoreError::Other(anyhow::anyhow!("Task join error: {}", e)))? @@ -295,6 +279,6 @@ impl DataService for AsyncDataService { settings, data_path.to_path_buf(), db_path.to_path_buf(), - )) + )?) 
} } diff --git a/src/services/async_item_service.rs b/src/services/async_item_service.rs index a0c1cae..360f9bd 100644 --- a/src/services/async_item_service.rs +++ b/src/services/async_item_service.rs @@ -101,17 +101,6 @@ impl AsyncItemService { .await } - pub async fn get_item_content_info( - &self, - id: i64, - filter: Option, - ) -> Result<(Vec, String, bool), CoreError> { - self.execute_blocking(move |conn, item_service| { - item_service.get_item_content_info(conn, id, filter) - }) - .await - } - pub async fn stream_item_content_by_id( &self, item_id: i64, @@ -131,41 +120,13 @@ impl AsyncItemService { ), CoreError, > { - let content = self + // Use streaming approach: get reader and stream chunks in requested range + let (reader, mime_type, is_binary) = self .execute_blocking(move |conn, item_service| { - let item_with_content = item_service.get_item_content(conn, item_id)?; - Ok::<_, CoreError>(item_with_content.content) + item_service.get_item_content_info_streaming(conn, item_id, None) }) .await?; - // Clone content for use in the binary check closure - let content_clone = content.clone(); - - // Get metadata to determine MIME type and binary status - let (mime_type, is_binary) = { - let db = self.db.clone(); - let item_service = self.item_service.clone(); - tokio::task::spawn_blocking(move || { - let conn = db.blocking_lock(); - let item_with_meta = item_service.get_item(&conn, item_id)?; - let metadata = item_with_meta.meta_as_map(); - - let mime_type = metadata - .get("mime_type") - .map(|s| s.to_string()) - .unwrap_or_else(|| "application/octet-stream".to_string()); - - let is_binary = crate::common::is_binary::is_content_binary_from_metadata( - &metadata, - &content_clone, - ); - - Ok::<_, CoreError>((mime_type, is_binary)) - }) - .await - .unwrap()? 
- }; - // Check if content is binary when allow_binary is false if !allow_binary && is_binary { return Err(CoreError::InvalidInput( @@ -173,26 +134,76 @@ impl AsyncItemService { )); } - // Create a stream that reads only the requested portion - let content_len = content.len() as u64; + // Convert the reader into an async stream with offset and length applied + use tokio_util::bytes::Bytes; - // Apply offset and length constraints - let start = std::cmp::min(offset, content_len); - let end = if length > 0 { - std::cmp::min(start + length, content_len) - } else { - content_len - }; + // Create a channel to stream data between the blocking thread and async runtime + let (tx, rx) = tokio::sync::mpsc::channel::>(16); - let stream = if start < content_len { - let chunk = - tokio_util::bytes::Bytes::from(content[start as usize..end as usize].to_vec()); - Box::pin(tokio_stream::iter(vec![Ok(chunk)])) - } else { - Box::pin(tokio_stream::iter(vec![])) - }; + // Spawn a blocking task to read from the reader and send chunks + tokio::task::spawn_blocking(move || { + let mut reader = reader; + let mut buf = [0u8; PIPESIZE]; - Ok((stream, mime_type)) + // Apply offset by reading and discarding bytes + if offset > 0 { + let mut remaining = offset; + while remaining > 0 { + let to_read = std::cmp::min(remaining, buf.len() as u64) as usize; + match reader.read(&mut buf[..to_read]) { + Ok(0) => break, // EOF reached before offset + Ok(n) => remaining -= n as u64, + Err(e) => { + let _ = tx.blocking_send(Err(e)); + return; + } + } + } + } + + // Read and send data up to the specified length + let mut remaining_length = length; + + loop { + // Determine how much to read in this iteration + let to_read = if length > 0 { + // If length is specified, don't read more than remaining_length + std::cmp::min(remaining_length, buf.len() as u64) as usize + } else { + buf.len() + }; + + if to_read == 0 { + break; // We've read the requested length + } + + match reader.read(&mut buf[..to_read]) { + Ok(0) => break, // EOF + Ok(n) => { + let chunk = Bytes::copy_from_slice(&buf[..n]); + // Block on sending to the channel + if tx.blocking_send(Ok(chunk)).is_err() { + break; // Receiver dropped + } + if length > 0 { + remaining_length -= n as u64; + if remaining_length == 0 { + break; // Reached the requested length + } + } + } + Err(e) => { + let _ = tx.blocking_send(Err(e)); + break; + } + } + } + }); + + // Convert the receiver into a stream + let stream = tokio_stream::wrappers::ReceiverStream::new(rx); + + Ok((Box::pin(stream), mime_type)) } pub async fn stream_item_content_by_id_with_metadata( @@ -374,28 +385,6 @@ impl AsyncItemService { item_service.delete_item(&mut conn, id) }) .await - .unwrap() - } - - pub async fn save_item_from_mcp( - &self, - content: Vec, - tags: Vec, - metadata: HashMap, - ) -> Result { - let db = self.db.clone(); - let item_service = self.item_service.clone(); - let cmd = self.cmd.clone(); - let settings = self.settings.clone(); - - tokio::task::spawn_blocking(move || { - let mut conn = db.blocking_lock(); - let mut cmd = cmd.blocking_lock(); - let settings = settings.as_ref(); - item_service - .save_item_from_mcp(&content, &tags, &metadata, &mut cmd, settings, &mut conn) - }) - .await - .unwrap() + .map_err(|e| CoreError::Other(anyhow::anyhow!("task join error: {e}")))? 
} } diff --git a/src/services/filter_service.rs b/src/services/filter_service.rs index 6dde518..0a6a3b6 100644 --- a/src/services/filter_service.rs +++ b/src/services/filter_service.rs @@ -188,11 +188,12 @@ static FILTER_PLUGIN_REGISTRY: Lazy>> = /// ```ignore /// register_filter_plugin("custom_filter", || Box::new(CustomFilter::default())); /// ``` -pub fn register_filter_plugin(name: &str, constructor: FilterConstructor) { +pub fn register_filter_plugin(name: &str, constructor: FilterConstructor) -> anyhow::Result<()> { FILTER_PLUGIN_REGISTRY .lock() - .unwrap() + .map_err(|e| anyhow::anyhow!("plugin registry poisoned: {e}"))? .insert(name.to_string(), constructor); + Ok(()) } /// Retrieves a snapshot of all registered filter plugins. @@ -214,6 +215,9 @@ pub fn register_filter_plugin(name: &str, constructor: FilterConstructor) { /// let plugins = get_available_filter_plugins(); /// // Plugins are registered at startup via ctors; specific names may vary by configuration. /// ``` -pub fn get_available_filter_plugins() -> HashMap { - FILTER_PLUGIN_REGISTRY.lock().unwrap().clone() +pub fn get_available_filter_plugins() -> anyhow::Result> { + FILTER_PLUGIN_REGISTRY + .lock() + .map_err(|e| anyhow::anyhow!("plugin registry poisoned: {e}")) + .map(|guard| guard.clone()) } diff --git a/src/services/item_service.rs b/src/services/item_service.rs index 0176ab0..ca5e111 100644 --- a/src/services/item_service.rs +++ b/src/services/item_service.rs @@ -106,6 +106,8 @@ impl ItemService { /// Retrieves an item with its content, metadata, and tags. /// /// Loads the item, its metadata/tags, and decompresses the full content. + /// This method is intended for CLI use only and has a size guard (100MB). + /// For larger items or server use, use `get_item_content_info_streaming`. /// /// # Arguments /// @@ -120,6 +122,7 @@ impl ItemService { /// /// * `CoreError::ItemNotFound(id)` - If the item does not exist. /// * `CoreError::Io(...)` - If file read or decompression fails. + /// * `CoreError::InvalidInput(...)` - If item exceeds 100MB size limit. /// /// # Examples /// @@ -132,6 +135,9 @@ impl ItemService { conn: &Connection, id: i64, ) -> Result { + // Size limit for loading entire content into memory (100MB) + const MAX_CONTENT_SIZE: i64 = 100 * 1024 * 1024; + debug!("ITEM_SERVICE: Getting item content for id: {id}"); let item_with_meta = self.get_item(conn, id)?; let item_id = item_with_meta @@ -145,6 +151,16 @@ impl ItemService { ))); } + // Check size guard before loading content + if let Some(size) = item_with_meta.item.size + && size > MAX_CONTENT_SIZE + { + return Err(CoreError::InvalidInput(format!( + "Item {} exceeds size limit ({} > {}). Use streaming API for large items.", + item_id, size, MAX_CONTENT_SIZE + ))); + } + let mut item_path = self.data_path.clone(); item_path.push(item_id.to_string()); debug!("ITEM_SERVICE: Reading content from path: {item_path:?}"); @@ -164,47 +180,6 @@ impl ItemService { }) } - /// Retrieves item content with binary detection and optional filtering. - /// - /// Loads content, applies filters if specified, and determines MIME type and binary status. - /// - /// # Arguments - /// - /// * `conn` - Database connection. - /// * `id` - Item ID. - /// * `filter` - Optional filter string to apply to content. - /// - /// # Returns - /// - /// * `Result<(Vec, String, bool), CoreError>` - (content, MIME type, is_binary). - /// - /// # Errors - /// - /// * `CoreError::ItemNotFound(id)` - If item not found. - /// * Filter or compression errors. 
- /// - /// # Examples - /// - /// ```ignore - /// let (content, mime, is_binary) = item_service.get_item_content_info(&conn, 1, Some("head_lines(10)"))?; - /// ``` - pub fn get_item_content_info( - &self, - conn: &Connection, - id: i64, - filter: Option, - ) -> Result<(Vec, String, bool), CoreError> { - // Use streaming approach to handle all filtering options consistently - let (mut reader, mime_type, is_binary) = - self.get_item_content_info_streaming(conn, id, filter)?; - - // Read all the filtered content into a buffer - let mut content = Vec::new(); - reader.read_to_end(&mut content)?; - - Ok((content, mime_type, is_binary)) - } - /// Determines if item content is binary based on metadata or sampling. /// /// Checks existing "text" metadata first; if absent, samples the first 8192 bytes. @@ -717,110 +692,6 @@ impl ItemService { Ok(item) } - /// Saves pre-loaded content as a new item, typically from MCP (Machine-Common-Processing) sources. - /// - /// Bypasses streaming read, directly writes content and applies metadata/plugins. - /// - /// # Arguments - /// - /// * `content` - Byte slice of content to save. - /// * `tags` - Tags to associate. - /// * `metadata` - Initial metadata key-value pairs. - /// * `cmd` - Mutable command. - /// * `settings` - Settings. - /// * `conn` - Mutable database connection. - /// - /// # Returns - /// - /// * `Result` - The saved item with full details. - /// - /// # Errors - /// - /// * `CoreError::Database(...)` - If DB insert fails. - /// * `CoreError::Io(...)` - If file write fails. - /// - /// # Examples - /// - /// ```ignore - /// let content = b"Hello, world!"; - /// let tags = vec!["mcp".to_string()]; - /// let meta = HashMap::from([("source".to_string(), "api".to_string())]); - /// let item = service.save_item_from_mcp(content, &tags, &meta, &mut cmd, &settings, &mut conn)?; - /// ``` - pub fn save_item_from_mcp( - &self, - content: &[u8], - tags: &Vec, - metadata: &HashMap, - cmd: &mut Command, - settings: &Settings, - conn: &mut Connection, - ) -> Result { - debug!( - "ITEM_SERVICE: Starting save_item_from_mcp with {} bytes, {} tags, {} metadata entries", - content.len(), - tags.len(), - metadata.len() - ); - let compression_type = CompressionType::LZ4; - let compression_engine = get_compression_engine(compression_type.clone())?; - - let item_id; - let mut item; - - { - item = db::create_item(conn, compression_type.clone())?; - item_id = item - .id - .ok_or_else(|| CoreError::InvalidInput("Item missing ID".to_string()))?; - debug!("ITEM_SERVICE: Created MCP item with id: {item_id}"); - - // Add tags - for tag in tags { - db::add_tag(conn, item_id, tag)?; - } - debug!("ITEM_SERVICE: Added {} tags to MCP item", tags.len()); - - // Add custom metadata - for (key, value) in metadata { - db::add_meta(conn, item_id, key, value)?; - } - debug!( - "ITEM_SERVICE: Added {} custom metadata entries to MCP item", - metadata.len() - ); - } - - let mut item_path = self.data_path.clone(); - item_path.push(item_id.to_string()); - debug!("ITEM_SERVICE: Writing MCP item to path: {item_path:?}"); - - let mut writer = compression_engine.create(item_path.clone())?; - writer.write_all(content)?; - drop(writer); - - let mut plugins = self.meta_service.get_plugins(cmd, settings); - debug!( - "ITEM_SERVICE: Got {} configured meta plugins for MCP item", - plugins.len() - ); - - self.meta_service - .initialize_plugins(&mut plugins, conn, item_id); - self.meta_service - .process_chunk(&mut plugins, content, conn, item_id); - self.meta_service - .finalize_plugins(&mut 
plugins, conn, item_id); - debug!("ITEM_SERVICE: Processed MCP item through configured meta plugins"); - - item.size = Some(content.len() as i64); - db::update_item(conn, item.clone())?; - - debug!("ITEM_SERVICE: MCP item saved successfully"); - - self.get_item(conn, item_id) - } - /// Returns a reference to the internal compression service. /// /// # Returns diff --git a/src/services/status_service.rs b/src/services/status_service.rs index d4c8a8a..bde5cfe 100644 --- a/src/services/status_service.rs +++ b/src/services/status_service.rs @@ -74,7 +74,7 @@ impl StatusService { settings: &Settings, data_path: PathBuf, db_path: PathBuf, - ) -> StatusInfo { + ) -> anyhow::Result { // Get meta plugins directly from config let meta_plugin_types: Vec = crate::modes::common::settings_meta_plugin_types(cmd, settings); @@ -91,10 +91,10 @@ impl StatusService { db_path, &meta_plugin_types, enabled_compression_type, - ); + )?; // Add detailed filter plugins information - let filter_plugins_map = get_available_filter_plugins(); + let filter_plugins_map = get_available_filter_plugins()?; let mut filter_plugins_info = Vec::new(); for (name, creator) in filter_plugins_map { @@ -114,7 +114,7 @@ impl StatusService { // Add configured meta plugins information status_info.configured_meta_plugins = settings.meta_plugins.clone(); - status_info + Ok(status_info) } } diff --git a/src/services/sync_data_service.rs b/src/services/sync_data_service.rs index ebb36d8..f6a58ca 100644 --- a/src/services/sync_data_service.rs +++ b/src/services/sync_data_service.rs @@ -57,34 +57,6 @@ impl SyncDataService { .save_item(content, cmd, settings, tags, conn) } - pub fn save_item_with_reader( - &self, - conn: &mut Connection, - reader: &mut R, - tags: Vec, - metadata: HashMap, - ) -> Result { - let mut cmd = Command::new("keep"); - let settings = &self.settings; - let mut tags = tags; - - // Read content from reader - let mut content = Vec::new(); - reader.read_to_end(&mut content)?; - - let item = self.save_item(&*content, &mut cmd, settings, &mut tags, conn)?; - let item_id = item - .id - .ok_or_else(|| CoreError::InvalidInput("Item missing ID".to_string()))?; - - // Set metadata - for (key, value) in metadata { - crate::db::add_meta(conn, item_id, &key, &value)?; - } - - self.get_item(conn, item_id) - } - /// Save an item with granular control over compression and meta plugins. /// /// This method allows clients to control whether compression and meta plugins @@ -266,7 +238,9 @@ impl SyncDataService { db_path: PathBuf, ) -> StatusInfo { let status_service = StatusService::new(); - status_service.generate_status(cmd, settings, data_path, db_path) + status_service + .generate_status(cmd, settings, data_path, db_path) + .unwrap_or_else(|_| StatusInfo::default()) } } @@ -359,6 +333,6 @@ impl DataService for SyncDataService { settings, data_path.to_path_buf(), db_path.to_path_buf(), - )) + )?) } }