refactor: streaming, security hardening, and MCP removal

Major overhaul of server architecture and security posture:

- Streaming: Unified all I/O through PIPESIZE (8192-byte) buffers.
  POST bodies stream via MpscReader through the save pipeline. GET
  content streams from disk through decompression to the client.
  Removed save_item_with_reader, get_item_content_info, and
  ChannelReader. 413 responses keep partial items (non-fatal by design).

- Security: XSS protection in all HTML pages via html_escape crate.
  Security headers middleware (nosniff, frame deny, referrer policy).
  CORS tightened to explicit headers. Input validation for tags
  (max 256 chars), metadata keys/values (max 128/4096 chars), and
  pagination (count capped at 10k). Config file reads use
  from_utf8_lossy. Generic error messages in HTML. Diff endpoint has
  a 10 MB per-item cap. New max_body_size config option.

- Panics eliminated: Path unwraps → proper error propagation.
  Mutex unwraps → map_err (registries) / expect with message (local).

- MCP removed: Deleted all MCP code, rmcp dependency, mcp feature.

- Docs: Updated README, DESIGN, AGENTS to reflect all changes.
This commit is contained in:
2026-03-14 00:03:42 -03:00
parent 560ba6e20c
commit 17be6abaab
51 changed files with 876 additions and 1309 deletions

View File

@@ -165,6 +165,10 @@ pub struct OptionsArgs {
#[arg(help("Path to file containing JWT secret (requires --server)"))]
pub server_jwt_secret_file: Option<PathBuf>,
#[arg(long, env("KEEP_SERVER_MAX_BODY_SIZE"))]
#[arg(help("Maximum request body size in bytes (requires --server, default: unlimited)"))]
pub server_max_body_size: Option<u64>,
#[cfg(feature = "client")]
#[arg(long, env("KEEP_CLIENT_URL"), help_heading("Client Options"))]
#[arg(help("Remote keep server URL for client mode"))]

View File

@@ -125,7 +125,7 @@ pub fn gather_meta_plugin_schemas() -> Vec<PluginSchema> {
pub fn gather_filter_plugin_schemas() -> Vec<PluginSchema> {
use crate::services::filter_service::get_available_filter_plugins;
let plugins = get_available_filter_plugins();
let plugins = get_available_filter_plugins().unwrap_or_default();
let mut schemas: Vec<PluginSchema> = plugins
.into_iter()
.map(|(name, creator)| {

View File

@@ -27,6 +27,22 @@ pub struct StatusInfo {
pub configured_meta_plugins: Option<Vec<crate::config::MetaPluginConfig>>,
}
impl Default for StatusInfo {
fn default() -> Self {
Self {
paths: PathInfo {
data: String::new(),
database: String::new(),
},
compression: Vec::new(),
meta_plugins: std::collections::HashMap::new(),
enabled_meta_plugins: Vec::new(),
filter_plugins: Vec::new(),
configured_meta_plugins: None,
}
}
}
#[derive(serde::Serialize, serde::Deserialize)]
#[cfg_attr(feature = "server", derive(ToSchema))]
pub struct PathInfo {
@@ -59,17 +75,17 @@ pub fn generate_status_info(
db_path: PathBuf,
enabled_meta_plugins: &[MetaPluginType],
enabled_compression_type: Option<CompressionType>,
) -> StatusInfo {
) -> anyhow::Result<StatusInfo> {
log::debug!("STATUS: Starting status info generation");
let path_info = PathInfo {
data: data_path
.into_os_string()
.into_string()
.expect("Unable to convert data path to string"),
.map_err(|_| anyhow::anyhow!("Unable to convert data path to string"))?,
database: db_path
.into_os_string()
.into_string()
.expect("Unable to convert DB path to string"),
.map_err(|_| anyhow::anyhow!("Unable to convert DB path to string"))?,
};
let _default_type = crate::compression_engine::default_compression_type();
@@ -183,7 +199,7 @@ pub fn generate_status_info(
}
// Populate filter plugin info from the global registry
let filter_plugins_map = crate::services::filter_service::get_available_filter_plugins();
let filter_plugins_map = crate::services::filter_service::get_available_filter_plugins()?;
let filter_plugins_info: Vec<FilterPluginInfo> = filter_plugins_map
.into_iter()
.map(|(name, creator)| {
@@ -196,12 +212,12 @@ pub fn generate_status_info(
})
.collect();
StatusInfo {
Ok(StatusInfo {
paths: path_info,
compression: compression_info,
meta_plugins: meta_plugins_map,
enabled_meta_plugins: enabled_meta_plugins_vec,
filter_plugins: filter_plugins_info,
configured_meta_plugins: None,
}
})
}

View File

@@ -152,6 +152,7 @@ pub struct ServerConfig {
pub cert_file: Option<PathBuf>,
pub key_file: Option<PathBuf>,
pub cors_origin: Option<String>,
pub max_body_size: Option<u64>,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
@@ -256,7 +257,11 @@ impl Settings {
// Override with CLI args
if let Some(dir) = &args.options.dir {
debug!("CONFIG: Overriding dir with CLI arg: {dir:?}");
config_builder = config_builder.set_override("dir", dir.to_str().unwrap())?;
config_builder = config_builder.set_override(
"dir",
dir.to_str()
.ok_or_else(|| anyhow::anyhow!("non-UTF-8 directory path"))?,
)?;
}
if args.options.human_readable {
@@ -316,6 +321,10 @@ impl Settings {
.set_override("server.key_file", server_key.to_string_lossy().as_ref())?;
}
if let Some(max_body_size) = args.options.server_max_body_size {
config_builder = config_builder.set_override("server.max_body_size", max_body_size)?;
}
if let Some(compression) = &args.item.compression {
config_builder =
config_builder.set_override("compression_plugin.name", compression.as_str())?;
@@ -486,10 +495,10 @@ impl Settings {
// First check for password_file
if let Some(password_file) = &server.password_file {
debug!("CONFIG: Reading password from file: {password_file:?}");
let password = fs::read_to_string(password_file)
.with_context(|| format!("Failed to read password file: {password_file:?}"))?
.trim()
.to_string();
let password = fs::read(password_file)
.with_context(|| format!("Failed to read password file: {password_file:?}"))?;
let end = password.len().min(4096);
let password = String::from_utf8_lossy(&password[..end]).trim().to_string();
return Ok(Some(password));
}
@@ -521,12 +530,11 @@ impl Settings {
// First check for jwt_secret_file
if let Some(jwt_secret_file) = &server.jwt_secret_file {
debug!("CONFIG: Reading JWT secret from file: {jwt_secret_file:?}");
let secret = fs::read_to_string(jwt_secret_file)
.with_context(|| {
format!("Failed to read JWT secret file: {jwt_secret_file:?}")
})?
.trim()
.to_string();
let secret = fs::read(jwt_secret_file).with_context(|| {
format!("Failed to read JWT secret file: {jwt_secret_file:?}")
})?;
let end = secret.len().min(4096);
let secret = String::from_utf8_lossy(&secret[..end]).trim().to_string();
return Ok(Some(secret));
}

View File

@@ -224,5 +224,6 @@ fn register_exec_filter() {
stdin_writer: None,
stdout_reader: None,
})
});
})
.expect("Failed to register exec filter");
}

View File

@@ -283,6 +283,8 @@ impl FilterPlugin for HeadLinesFilter {
// Register the plugin at module initialization time
#[ctor::ctor]
fn register_head_filters() {
register_filter_plugin("head_bytes", || Box::new(HeadBytesFilter::new(0)));
register_filter_plugin("head_lines", || Box::new(HeadLinesFilter::new(0)));
register_filter_plugin("head_bytes", || Box::new(HeadBytesFilter::new(0)))
.expect("Failed to register head_bytes filter");
register_filter_plugin("head_lines", || Box::new(HeadLinesFilter::new(0)))
.expect("Failed to register head_lines filter");
}

View File

@@ -150,6 +150,8 @@ impl FilterPlugin for SkipLinesFilter {
// Register the plugin at module initialization time
#[ctor::ctor]
fn register_skip_filters() {
register_filter_plugin("skip_bytes", || Box::new(SkipBytesFilter::new(0)));
register_filter_plugin("skip_lines", || Box::new(SkipLinesFilter::new(0)));
register_filter_plugin("skip_bytes", || Box::new(SkipBytesFilter::new(0)))
.expect("Failed to register skip_bytes filter");
register_filter_plugin("skip_lines", || Box::new(SkipLinesFilter::new(0)))
.expect("Failed to register skip_lines filter");
}

View File

@@ -169,6 +169,8 @@ impl FilterPlugin for TailLinesFilter {
// Register the plugin at module initialization time
#[ctor::ctor]
fn register_tail_filters() {
register_filter_plugin("tail_bytes", || Box::new(TailBytesFilter::new(0)));
register_filter_plugin("tail_lines", || Box::new(TailLinesFilter::new(0)));
register_filter_plugin("tail_bytes", || Box::new(TailBytesFilter::new(0)))
.expect("Failed to register tail_bytes filter");
register_filter_plugin("tail_lines", || Box::new(TailLinesFilter::new(0)))
.expect("Failed to register tail_lines filter");
}

View File

@@ -377,9 +377,12 @@ fn map_lossy_pos_to_bytes(original: &[u8], lossy: &str, lossy_pos: usize) -> usi
#[ctor::ctor]
fn register_token_filters() {
register_filter_plugin("head_tokens", || Box::new(HeadTokensFilter::new(0)));
register_filter_plugin("skip_tokens", || Box::new(SkipTokensFilter::new(0)));
register_filter_plugin("tail_tokens", || Box::new(TailTokensFilter::new(0)));
register_filter_plugin("head_tokens", || Box::new(HeadTokensFilter::new(0)))
.expect("Failed to register head_tokens filter");
register_filter_plugin("skip_tokens", || Box::new(SkipTokensFilter::new(0)))
.expect("Failed to register skip_tokens filter");
register_filter_plugin("tail_tokens", || Box::new(TailTokensFilter::new(0)))
.expect("Failed to register tail_tokens filter");
}
#[cfg(test)]

View File

@@ -128,5 +128,6 @@ use crate::meta_plugin::register_meta_plugin;
fn register_cwd_plugin() {
register_meta_plugin(MetaPluginType::Cwd, |options, outputs| {
Box::new(CwdMetaPlugin::new(options, outputs))
});
})
.expect("Failed to register CwdMetaPlugin");
}

View File

@@ -271,5 +271,6 @@ use crate::meta_plugin::register_meta_plugin;
fn register_digest_plugin() {
register_meta_plugin(MetaPluginType::Digest, |options, outputs| {
Box::new(DigestMetaPlugin::new(options, outputs))
});
})
.expect("Failed to register DigestMetaPlugin");
}

View File

@@ -227,5 +227,6 @@ use crate::meta_plugin::register_meta_plugin;
fn register_env_plugin() {
register_meta_plugin(MetaPluginType::Env, |options, outputs| {
Box::new(EnvMetaPlugin::new(options, outputs))
});
})
.expect("Failed to register EnvMetaPlugin");
}

View File

@@ -311,5 +311,6 @@ fn register_exec_plugin() {
options,
outputs,
))
});
})
.expect("Failed to register ExecMetaPlugin");
}

View File

@@ -406,5 +406,6 @@ use crate::meta_plugin::register_meta_plugin;
fn register_hostname_plugin() {
register_meta_plugin(MetaPluginType::Hostname, |options, outputs| {
Box::new(HostnameMetaPlugin::new(options, outputs))
});
})
.expect("Failed to register HostnameMetaPlugin");
}

View File

@@ -204,5 +204,6 @@ use crate::meta_plugin::register_meta_plugin;
fn register_keep_pid_plugin() {
register_meta_plugin(MetaPluginType::KeepPid, |options, outputs| {
Box::new(KeepPidMetaPlugin::new(options, outputs))
});
})
.expect("Failed to register KeepPidMetaPlugin");
}

View File

@@ -457,5 +457,6 @@ use crate::meta_plugin::register_meta_plugin;
fn register_magic_file_plugin() {
register_meta_plugin(MetaPluginType::MagicFile, |options, outputs| {
Box::new(MagicFileMetaPlugin::new(options, outputs))
});
})
.expect("Failed to register MagicFileMetaPlugin");
}

View File

@@ -578,11 +578,15 @@ static META_PLUGIN_REGISTRY: Lazy<Mutex<HashMap<MetaPluginType, PluginConstructo
///
/// * `meta_plugin_type` - The type of the meta plugin to register.
/// * `constructor` - The constructor function for creating plugin instances.
pub fn register_meta_plugin(meta_plugin_type: MetaPluginType, constructor: PluginConstructor) {
pub fn register_meta_plugin(
meta_plugin_type: MetaPluginType,
constructor: PluginConstructor,
) -> anyhow::Result<()> {
META_PLUGIN_REGISTRY
.lock()
.unwrap()
.map_err(|e| anyhow::anyhow!("plugin registry poisoned: {e}"))?
.insert(meta_plugin_type, constructor);
Ok(())
}
pub fn get_meta_plugin(
@@ -590,7 +594,9 @@ pub fn get_meta_plugin(
options: Option<std::collections::HashMap<String, serde_yaml::Value>>,
outputs: Option<std::collections::HashMap<String, serde_yaml::Value>>,
) -> anyhow::Result<Box<dyn MetaPlugin>> {
let registry = META_PLUGIN_REGISTRY.lock().unwrap();
let registry = META_PLUGIN_REGISTRY
.lock()
.map_err(|e| anyhow::anyhow!("plugin registry poisoned: {e}"))?;
if let Some(constructor) = registry.get(&meta_plugin_type) {
return Ok(constructor(options, outputs));
}

View File

@@ -237,5 +237,6 @@ use crate::meta_plugin::register_meta_plugin;
fn register_read_rate_plugin() {
register_meta_plugin(MetaPluginType::ReadRate, |options, outputs| {
Box::new(ReadRateMetaPlugin::new(options, outputs))
});
})
.expect("Failed to register ReadRateMetaPlugin");
}

View File

@@ -124,5 +124,6 @@ use crate::meta_plugin::register_meta_plugin;
fn register_read_time_plugin() {
register_meta_plugin(MetaPluginType::ReadTime, |options, outputs| {
Box::new(ReadTimeMetaPlugin::new(options, outputs))
});
})
.expect("Failed to register ReadTimeMetaPlugin");
}

View File

@@ -240,5 +240,6 @@ use crate::meta_plugin::register_meta_plugin;
fn register_shell_plugin() {
register_meta_plugin(MetaPluginType::Shell, |options, outputs| {
Box::new(ShellMetaPlugin::new(options, outputs))
});
})
.expect("Failed to register ShellMetaPlugin");
}

View File

@@ -132,5 +132,6 @@ use crate::meta_plugin::register_meta_plugin;
fn register_shell_pid_plugin() {
register_meta_plugin(MetaPluginType::ShellPid, |options, outputs| {
Box::new(ShellPidMetaPlugin::new(options, outputs))
});
})
.expect("Failed to register ShellPidMetaPlugin");
}

View File

@@ -818,5 +818,6 @@ use crate::meta_plugin::register_meta_plugin;
fn register_text_plugin() {
register_meta_plugin(MetaPluginType::Text, |options, outputs| {
Box::new(TextMetaPlugin::new(options, outputs))
});
})
.expect("Failed to register TextMetaPlugin");
}

View File

@@ -312,5 +312,6 @@ use crate::meta_plugin::register_meta_plugin;
fn register_tokens_plugin() {
register_meta_plugin(MetaPluginType::Tokens, |options, outputs| {
Box::new(TokensMetaPlugin::new(options, outputs))
});
})
.expect("Failed to register TokensMetaPlugin");
}

View File

@@ -166,5 +166,6 @@ use crate::meta_plugin::register_meta_plugin;
fn register_user_plugin() {
register_meta_plugin(MetaPluginType::User, |options, outputs| {
Box::new(UserMetaPlugin::new(options, outputs))
});
})
.expect("Failed to register UserMetaPlugin");
}

View File

@@ -92,7 +92,7 @@ pub fn mode(
let digest = format!("{:x}", hasher.finalize());
// Set shared state for main thread
let mut shared = shared_reader.lock().unwrap();
let mut shared = shared_reader.lock().expect("client save mutex poisoned");
*shared = (total_bytes, digest.clone());
Ok((total_bytes, digest))
@@ -135,7 +135,7 @@ pub fn mode(
// Read results from shared state
let (uncompressed_size, digest) = {
let shared = shared.lock().unwrap();
let shared = shared.lock().expect("client save mutex poisoned");
shared.clone()
};

View File

@@ -162,7 +162,7 @@ fn show_item(
item_path_buf.push(item_id.to_string());
let path_str = item_path_buf
.to_str()
.expect("Unable to get item path")
.ok_or_else(|| anyhow::anyhow!("non-UTF-8 item path"))?
.to_string();
table.add_row(vec![
Cell::new("Path").add_attribute(Attribute::Bold),

View File

@@ -240,7 +240,11 @@ pub fn mode_list(
Ok(metadata) => format_size(metadata.len(), settings.human_readable),
Err(_) => "Missing".to_string(),
},
ColumnType::FilePath => item_path.clone().into_os_string().into_string().unwrap(),
ColumnType::FilePath => item_path
.clone()
.into_os_string()
.into_string()
.unwrap_or_else(|os| os.to_string_lossy().into_owned()),
ColumnType::Tags => tags.join(" "),
ColumnType::Meta => match meta_name {
Some(meta_name) => match meta.get(meta_name) {

View File

@@ -16,7 +16,9 @@ use axum::{
use http_body_util::BodyExt;
use log::{debug, warn};
use std::collections::HashMap;
use std::io::{Cursor, Read};
use std::io::Read;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use tokio::sync::mpsc;
use tokio::task;
@@ -24,14 +26,15 @@ use tokio::task;
///
/// Used in `spawn_blocking` contexts to consume data from an async body
/// stream as a regular reader. Blocks on each `read()` call until data
/// is available or the channel is closed.
struct ChannelReader {
/// is available or the channel is closed. No internal buffering — each
/// channel message is consumed fully before the next is fetched.
struct MpscReader {
rx: mpsc::Receiver<Result<Vec<u8>, std::io::Error>>,
current: Vec<u8>,
pos: usize,
}
impl ChannelReader {
impl MpscReader {
fn new(rx: mpsc::Receiver<Result<Vec<u8>, std::io::Error>>) -> Self {
Self {
rx,
@@ -41,9 +44,8 @@ impl ChannelReader {
}
}
impl Read for ChannelReader {
impl Read for MpscReader {
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
// If we have buffered data, return it first
if self.pos < self.current.len() {
let remaining = &self.current[self.pos..];
let n = std::cmp::min(buf.len(), remaining.len());
@@ -52,7 +54,6 @@ impl Read for ChannelReader {
return Ok(n);
}
// Need more data from the channel - block until available
match self.rx.blocking_recv() {
Some(Ok(data)) => {
let n = std::cmp::min(buf.len(), data.len());
@@ -64,7 +65,7 @@ impl Read for ChannelReader {
Ok(n)
}
Some(Err(e)) => Err(e),
None => Ok(0), // Channel closed, EOF
None => Ok(0),
}
}
}
@@ -116,23 +117,6 @@ fn get_mime_type(metadata: &HashMap<String, String>) -> String {
.unwrap_or_else(|| "application/octet-stream".to_string())
}
/// Helper function to apply offset and length to content
fn apply_offset_length(content: &[u8], offset: u64, length: u64) -> &[u8] {
let content_len = content.len() as u64;
let start = std::cmp::min(offset, content_len);
let end = if length > 0 {
std::cmp::min(start + length, content_len)
} else {
content_len
};
if start < content_len {
&content[start as usize..end as usize]
} else {
&[]
}
}
/// Helper function to handle item not found errors
fn handle_item_error(error: CoreError) -> StatusCode {
match error {
@@ -203,7 +187,7 @@ pub async fn handle_list_items(
// Apply pagination
let start = params.start.unwrap_or(0) as usize;
let count = params.count.unwrap_or(100) as usize;
let count = params.count.unwrap_or(100).min(10000) as usize;
let items_with_meta: Vec<_> = items_with_meta
.into_iter()
.skip(start)
@@ -253,7 +237,7 @@ async fn handle_as_meta_response(
handle_as_meta_response_with_metadata(data_service, item_id, &metadata, offset, length).await
}
/// Handle as_meta=true response with pre-fetched metadata
/// Handle as_meta=true response with pre-fetched metadata using streaming
async fn handle_as_meta_response_with_metadata(
data_service: &AsyncDataService,
item_id: i64,
@@ -279,61 +263,121 @@ async fn handle_as_meta_response_with_metadata(
.body(axum::body::Body::from(response_body.to_string()))
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
} else {
// Get the content as text
match data_service.get_item_content_info(item_id, None).await {
Ok((content, _, _)) => {
// Apply offset and length
let content_len = content.len() as u64;
let start = std::cmp::min(offset, content_len);
let end = if length > 0 {
std::cmp::min(start + length, content_len)
} else {
content_len
};
// Use streaming approach to read content
let db = data_service.db();
let data_path = data_service.data_path().clone();
let settings = data_service.settings();
let response_content = if start < content_len {
&content[start as usize..end as usize]
} else {
&[]
};
// Get streaming reader from sync service
let (reader, content_len_result) = tokio::task::spawn_blocking(move || {
let mut conn = db.blocking_lock();
let sync_service =
crate::services::SyncDataService::new(data_path, settings.as_ref().clone());
let (reader, item_with_meta) = sync_service.get_content(&mut conn, item_id)?;
let content_len = item_with_meta.item.size.unwrap_or(0);
Ok::<_, CoreError>((reader, content_len))
})
.await
.map_err(|e| {
warn!("Blocking task failed for item {item_id}: {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?
.map_err(|e| {
warn!("Failed to get content reader for item {item_id}: {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?;
// Convert to UTF-8 string
let content_str = match String::from_utf8(response_content.to_vec()) {
Ok(s) => s,
Err(_) => {
// This shouldn't happen since we checked is_binary, but handle it just in case
let response_body = serde_json::json!({
"metadata": metadata,
"content": serde_json::Value::Null,
"error": "Content is not valid UTF-8"
});
let content_len = content_len_result as u64;
let response = Response::builder()
.header(header::CONTENT_TYPE, "application/json")
.status(StatusCode::UNPROCESSABLE_ENTITY)
.body(axum::body::Body::from(response_body.to_string()))
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
return Ok(response);
// Calculate offset and length bounds
let start = std::cmp::min(offset, content_len);
let end = if length > 0 {
std::cmp::min(start + length, content_len)
} else {
content_len
};
let response_len = end - start;
// Read content with offset and length using fixed-size buffer
let content_str = tokio::task::spawn_blocking(move || {
let mut reader = reader;
let mut buf = [0u8; crate::common::PIPESIZE];
let mut result = Vec::new();
let mut bytes_read = 0u64;
// Skip offset bytes
if offset > 0 {
let mut remaining = offset;
while remaining > 0 {
let to_read = std::cmp::min(remaining, buf.len() as u64) as usize;
match reader.read(&mut buf[..to_read]) {
Ok(0) => break,
Ok(n) => remaining -= n as u64,
Err(e) => return Err(CoreError::Io(e)),
}
};
}
}
// Return JSON with metadata and content
// Read up to length bytes
let mut remaining = if length > 0 { length } else { u64::MAX };
while remaining > 0 && bytes_read < response_len {
let to_read = std::cmp::min(remaining, buf.len() as u64) as usize;
match reader.read(&mut buf[..to_read]) {
Ok(0) => break,
Ok(n) => {
result.extend_from_slice(&buf[..n]);
bytes_read += n as u64;
if length > 0 {
remaining -= n as u64;
}
}
Err(e) => return Err(CoreError::Io(e)),
}
}
Ok::<Vec<u8>, CoreError>(result)
})
.await
.map_err(|e| {
warn!("Blocking task failed for item {item_id}: {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?
.map_err(|e| {
warn!("Failed to read content for item {item_id}: {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?;
// Convert to UTF-8 string
let content_str = match String::from_utf8(content_str) {
Ok(s) => s,
Err(_) => {
// This shouldn't happen since we checked is_binary, but handle it just in case
let response_body = serde_json::json!({
"metadata": metadata,
"content": content_str,
"error": serde_json::Value::Null
"content": serde_json::Value::Null,
"error": "Content is not valid UTF-8"
});
Response::builder()
let response = Response::builder()
.header(header::CONTENT_TYPE, "application/json")
.status(StatusCode::UNPROCESSABLE_ENTITY)
.body(axum::body::Body::from(response_body.to_string()))
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
return Ok(response);
}
Err(e) => {
warn!("Failed to get content for item {item_id}: {e}");
Err(StatusCode::INTERNAL_SERVER_ERROR)
}
}
};
// Return JSON with metadata and content
let response_body = serde_json::json!({
"metadata": metadata,
"content": content_str,
"error": serde_json::Value::Null
});
Response::builder()
.header(header::CONTENT_TYPE, "application/json")
.body(axum::body::Body::from(response_body.to_string()))
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
}
}
@@ -380,6 +424,14 @@ pub async fn handle_post_item(
.map(parse_comma_tags)
.unwrap_or_default();
// Validate tag lengths
for tag in &tags {
if tag.len() > 256 {
warn!("Tag too long: {} chars (max 256)", tag.len());
return Err(StatusCode::BAD_REQUEST);
}
}
// Parse metadata from query parameter
let metadata: HashMap<String, String> = if let Some(ref meta_str) = params.metadata {
serde_json::from_str(meta_str).map_err(|e| {
@@ -390,91 +442,116 @@ pub async fn handle_post_item(
HashMap::new()
};
// Validate metadata key/value lengths
for (key, value) in &metadata {
if key.len() > 128 {
warn!("Metadata key too long: {} chars (max 128)", key.len());
return Err(StatusCode::BAD_REQUEST);
}
if value.len() > 4096 {
warn!("Metadata value too long: {} chars (max 4096)", value.len());
return Err(StatusCode::BAD_REQUEST);
}
}
let compress = params.compress;
let run_meta = params.meta;
// When server handles both compression and meta, save_item_with_reader
// buffers internally anyway, so collect body in memory.
// When client handles compression/meta, stream the body to avoid buffering.
let item_with_meta = if compress && run_meta {
let body_bytes = body
.collect()
.await
.map_err(|e| {
warn!("Failed to read request body: {e}");
StatusCode::BAD_REQUEST
})?
.to_bytes();
// Stream body through an mpsc channel with fixed-size frames.
// Size tracking ensures we never buffer the whole body in memory.
// Treat Some(0) as unlimited (None).
let max_body_size: Option<u64> = settings
.server
.as_ref()
.and_then(|s| s.max_body_size)
.filter(|&v| v > 0);
task::spawn_blocking(move || {
let mut conn = db.blocking_lock();
let sync_service =
crate::services::SyncDataService::new(data_dir, settings.as_ref().clone());
let mut cursor = Cursor::new(body_bytes.to_vec());
sync_service.save_item_with_reader(&mut conn, &mut cursor, tags, metadata)
})
.await
.map_err(|e| {
warn!("Failed to save item (task error): {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?
.map_err(|e| {
warn!("Failed to save item: {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?
} else {
// Stream body through a channel to avoid buffering in memory
let (tx, rx) = tokio::sync::mpsc::channel::<Result<Vec<u8>, std::io::Error>>(16);
let (tx, rx) = mpsc::channel::<Result<Vec<u8>, std::io::Error>>(16);
let body_truncated = Arc::new(AtomicBool::new(false));
let truncated_flag = body_truncated.clone();
// Task to read body frames and send through channel
tokio::spawn(async move {
let mut body = body;
loop {
match body.frame().await {
None => break, // Body complete
Some(Err(e)) => {
let _ = tx
.send(Err(std::io::Error::other(format!("Body error: {e}"))))
.await;
break;
}
Some(Ok(frame)) => {
if let Ok(data) = frame.into_data() {
if tx.send(Ok(data.to_vec())).await.is_err() {
break; // Receiver dropped
}
// Async task: read body frames, track size, stop when limit exceeded
tokio::spawn(async move {
let mut body = body;
let mut total_bytes: u64 = 0;
loop {
match body.frame().await {
None => break,
Some(Err(e)) => {
let _ = tx
.send(Err(std::io::Error::other(format!("Body error: {e}"))))
.await;
break;
}
Some(Ok(frame)) => {
if let Ok(data) = frame.into_data() {
total_bytes += data.len() as u64;
if let Some(limit) = max_body_size
&& total_bytes > limit
{
truncated_flag.store(true, Ordering::Relaxed);
break; // Drop sender → reader sees EOF
}
if tx.send(Ok(data.to_vec())).await.is_err() {
break;
}
}
}
}
});
}
});
task::spawn_blocking(move || {
let mut conn = db.blocking_lock();
let sync_service =
crate::services::SyncDataService::new(data_dir, settings.as_ref().clone());
// Blocking task: consume streaming reader, save via save_item_raw_streaming
let item_with_meta = task::spawn_blocking(move || {
let mut conn = db.blocking_lock();
let sync_service =
crate::services::SyncDataService::new(data_dir, settings.as_ref().clone());
let mut reader = MpscReader::new(rx);
sync_service.save_item_raw_streaming(
&mut conn,
&mut reader,
tags,
metadata,
compress,
run_meta,
)
})
.await
.map_err(|e| {
warn!("Failed to save item (task error): {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?
.map_err(|e| {
warn!("Failed to save item: {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?;
// Convert async mpsc receiver into a sync Read
let mut stream_reader = ChannelReader::new(rx);
sync_service.save_item_raw_streaming(
&mut conn,
&mut stream_reader,
tags,
metadata,
compress,
run_meta,
)
})
.await
.map_err(|e| {
warn!("Failed to save item (task error): {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?
.map_err(|e| {
warn!("Failed to save item: {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?
};
// If body was truncated due to size limit, return 413 but keep the partial item.
//
// Design rationale for preserving truncated data:
//
// 1. Keep is a temporary file manager — partial content is still useful data.
// The client sent it, the server received it, and it was written through the
// normal save pipeline (compression, meta plugins). Discarding it would waste
// work already done and lose data the user may want to inspect or retry from.
//
// 2. The streaming pipeline is nonfatal by design. The sender stops pushing
// frames, the channel closes cleanly (sender drop → reader sees EOF), and
// save_item_raw_streaming finishes with whatever it received. No corruption,
// no partial writes to the storage file — the compression engine flushes and
// closes normally.
//
// 3. No destructive rollback. Deleting the item would require either a second
// DB transaction (fragile under concurrency) or leaving orphaned files on
// disk. Keeping the item is simpler and safer.
//
// 4. The 413 response tells the client exactly what happened. The client can
// use GET /api/item/ to list items and find the partial one by timestamp
// if it needs to resume or clean up.
//
if body_truncated.load(Ordering::Relaxed) {
return Err(StatusCode::PAYLOAD_TOO_LARGE);
}
let compression = item_with_meta.item.compression.clone();
let tags = item_with_meta.tags.iter().map(|t| t.name.clone()).collect();
@@ -696,37 +773,110 @@ async fn stream_item_content_response(
/// Stream raw (unprocessed) content directly from the item file.
///
/// Returns the stored file bytes without decompression or filtering.
/// Returns the stored file bytes without decompression or filtering using streaming.
async fn stream_raw_content_response(
data_service: &AsyncDataService,
item_id: i64,
offset: u64,
length: u64,
) -> Result<Response, StatusCode> {
// Get item info to find the file path and compression type
// Get item info to find the compression type
let item_with_meta = data_service.get_item(item_id).await.map_err(|e| {
warn!("Failed to get item {item_id} for raw content: {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?;
let compression = item_with_meta.item.compression.clone();
let content_size = item_with_meta.item.size.unwrap_or(0);
// Read raw file bytes
let content = data_service
.get_raw_item_content(item_id)
// Get streaming reader for raw content
let reader = data_service
.get_raw_item_content_reader(item_id)
.await
.map_err(|e| {
warn!("Failed to get raw content for item {item_id}: {e}");
warn!("Failed to get raw content reader for item {item_id}: {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?;
let response_content = apply_offset_length(&content, offset, length);
// Calculate the actual response length
let content_len = content_size as u64;
let start = std::cmp::min(offset, content_len);
let end = if length > 0 {
std::cmp::min(start + length, content_len)
} else {
content_len
};
let response_len = end - start;
// Create a channel to stream data between blocking thread and async runtime
let (tx, rx) = tokio::sync::mpsc::channel::<Result<Vec<u8>, std::io::Error>>(16);
// Spawn blocking task to read with offset and length
tokio::task::spawn_blocking(move || {
let mut reader = reader;
let mut buf = [0u8; crate::common::PIPESIZE];
// Apply offset by reading and discarding bytes
if offset > 0 {
let mut remaining = offset;
while remaining > 0 {
let to_read = std::cmp::min(remaining, buf.len() as u64) as usize;
match reader.read(&mut buf[..to_read]) {
Ok(0) => break,
Ok(n) => remaining -= n as u64,
Err(e) => {
let _ = tx.blocking_send(Err(e));
return;
}
}
}
}
// Read and send data up to the specified length
let mut remaining_length = length;
loop {
let to_read = if length > 0 {
std::cmp::min(remaining_length, buf.len() as u64) as usize
} else {
buf.len()
};
if to_read == 0 {
break;
}
match reader.read(&mut buf[..to_read]) {
Ok(0) => break,
Ok(n) => {
let chunk = buf[..n].to_vec();
if tx.blocking_send(Ok(chunk)).is_err() {
break;
}
if length > 0 {
remaining_length -= n as u64;
if remaining_length == 0 {
break;
}
}
}
Err(e) => {
let _ = tx.blocking_send(Err(e));
break;
}
}
}
});
// Convert the receiver into a stream
let stream = tokio_stream::wrappers::ReceiverStream::new(rx);
let body = axum::body::Body::from_stream(stream);
let response = Response::builder()
.header(header::CONTENT_TYPE, "application/octet-stream")
.header("X-Keep-Compression", &compression)
.header(header::CONTENT_LENGTH, response_content.len())
.body(axum::body::Body::from(response_content.to_vec()))
.header(header::CONTENT_LENGTH, response_len)
.body(body)
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
Ok(response)
@@ -776,24 +926,94 @@ async fn stream_item_content_response_with_metadata(
}
}
} else {
debug!("NON-STREAMING: Building full response in memory");
match data_service.get_item_content_info(item_id, None).await {
Ok((content, _, _)) => {
let response_content = apply_offset_length(&content, offset, length);
debug!("NON-STREAMING: Building full response in memory using streaming reader");
// Use streaming approach even for non-streaming response
let db = data_service.db();
let data_path = data_service.data_path().clone();
let settings = data_service.settings();
debug!(
"NON-STREAMING: Content length: {}, response length: {}",
content.len(),
response_content.len()
);
// Get streaming reader from sync service
let (reader, content_len_result) = tokio::task::spawn_blocking(move || {
let mut conn = db.blocking_lock();
let sync_service =
crate::services::SyncDataService::new(data_path, settings.as_ref().clone());
let (reader, item_with_meta) = sync_service.get_content(&mut conn, item_id)?;
let content_len = item_with_meta.item.size.unwrap_or(0);
Ok::<_, CoreError>((reader, content_len))
})
.await
.map_err(|e| {
warn!("Blocking task failed for item {item_id}: {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?
.map_err(|e| {
warn!("Failed to get content reader for item {item_id}: {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?;
ResponseBuilder::binary(response_content, &mime_type)
let content_len = content_len_result as u64;
// Calculate offset and length bounds
let start = std::cmp::min(offset, content_len);
let end = if length > 0 {
std::cmp::min(start + length, content_len)
} else {
content_len
};
let response_len = end - start;
// Read content with offset and length using fixed-size buffer
let content = tokio::task::spawn_blocking(move || {
let mut reader = reader;
let mut buf = [0u8; crate::common::PIPESIZE];
let mut result = Vec::new();
let mut bytes_read = 0u64;
// Skip offset bytes
if offset > 0 {
let mut remaining = offset;
while remaining > 0 {
let to_read = std::cmp::min(remaining, buf.len() as u64) as usize;
match reader.read(&mut buf[..to_read]) {
Ok(0) => break,
Ok(n) => remaining -= n as u64,
Err(e) => return Err(CoreError::Io(e)),
}
}
}
Err(e) => {
warn!("Failed to get content for item {item_id}: {e}");
Err(StatusCode::INTERNAL_SERVER_ERROR)
// Read up to length bytes
let mut remaining = if length > 0 { length } else { u64::MAX };
while remaining > 0 && bytes_read < response_len {
let to_read = std::cmp::min(remaining, buf.len() as u64) as usize;
match reader.read(&mut buf[..to_read]) {
Ok(0) => break,
Ok(n) => {
result.extend_from_slice(&buf[..n]);
bytes_read += n as u64;
if length > 0 {
remaining -= n as u64;
}
}
Err(e) => return Err(CoreError::Io(e)),
}
}
}
Ok::<Vec<u8>, CoreError>(result)
})
.await
.map_err(|e| {
warn!("Blocking task failed for item {item_id}: {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?
.map_err(|e| {
warn!("Failed to read content for item {item_id}: {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?;
debug!("NON-STREAMING: Response length: {}", content.len());
ResponseBuilder::binary(&content, &mime_type)
}
}
@@ -870,6 +1090,10 @@ pub async fn handle_get_item_meta(
State(state): State<AppState>,
Path(item_id): Path<i64>,
) -> Result<Json<ApiResponse<HashMap<String, String>>>, StatusCode> {
if item_id <= 0 {
return Err(StatusCode::BAD_REQUEST);
}
let data_service = create_data_service(&state);
match data_service.get_item(item_id).await {
@@ -893,6 +1117,10 @@ pub async fn handle_post_item_meta(
Path(item_id): Path<i64>,
Json(metadata): Json<HashMap<String, String>>,
) -> Result<Json<ApiResponse<()>>, StatusCode> {
if item_id <= 0 {
return Err(StatusCode::BAD_REQUEST);
}
let data_service = create_data_service(&state);
// Verify item exists
@@ -942,6 +1170,10 @@ pub async fn handle_delete_item(
State(state): State<AppState>,
Path(item_id): Path<i64>,
) -> Result<Json<ApiResponse<ItemInfo>>, StatusCode> {
if item_id <= 0 {
return Err(StatusCode::BAD_REQUEST);
}
let mut conn = state.db.lock().await;
let sync_service = crate::services::SyncDataService::new(
@@ -995,6 +1227,10 @@ pub async fn handle_get_item_info(
State(state): State<AppState>,
Path(item_id): Path<i64>,
) -> Result<Json<ApiResponse<ItemInfo>>, StatusCode> {
if item_id <= 0 {
return Err(StatusCode::BAD_REQUEST);
}
let mut conn = state.db.lock().await;
let sync_service = crate::services::SyncDataService::new(
@@ -1097,6 +1333,28 @@ pub async fn handle_diff_items(
let id_a = item_a.item.id.ok_or(StatusCode::BAD_REQUEST)?;
let id_b = item_b.item.id.ok_or(StatusCode::BAD_REQUEST)?;
// Size limit for diff operation (10MB per item)
const MAX_DIFF_SIZE: i64 = 10 * 1024 * 1024;
if let Some(size_a) = item_a.item.size
&& size_a > MAX_DIFF_SIZE
{
warn!(
"Item A ({}) exceeds diff size limit: {} > {}",
id_a, size_a, MAX_DIFF_SIZE
);
return Err(StatusCode::PAYLOAD_TOO_LARGE);
}
if let Some(size_b) = item_b.item.size
&& size_b > MAX_DIFF_SIZE
{
warn!(
"Item B ({}) exceeds diff size limit: {} > {}",
id_b, size_b, MAX_DIFF_SIZE
);
return Err(StatusCode::PAYLOAD_TOO_LARGE);
}
let (mut reader_a, _) = sync_service
.get_content(&mut conn, id_a)
.map_err(handle_item_error)?;

View File

@@ -1,72 +0,0 @@
use axum::{
extract::State,
http::StatusCode,
response::sse::{Event, KeepAlive, Sse},
};
use futures::stream::{self, Stream};
use log::{debug, info};
use std::convert::Infallible;
use std::time::Duration;
use crate::modes::server::common::AppState;
use crate::modes::server::mcp::KeepMcpServer;
#[utoipa::path(
    get,
    path = "/mcp/sse",
    operation_id = "mcp_sse",
    summary = "MCP SSE endpoint",
    description = "Server-Sent Events for Model Context Protocol. Enables AI tools to interact with Keep's storage and retrieval functions.",
    responses(
        (status = 200, description = "SSE stream established"),
        (status = 401, description = "Unauthorized"),
        (status = 500, description = "Internal server error")
    ),
    security(
        ("bearerAuth" = [])
    ),
    tag = "mcp"
)]
/// Establishes an SSE stream for MCP clients.
///
/// Sends an initial connection message, then forwards channel messages as
/// SSE events, emitting a keep-alive ping whenever 30 seconds pass without
/// traffic. The stream ends when the channel's sender side closes.
pub async fn handle_mcp_sse(
    State(state): State<AppState>,
) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, StatusCode> {
    debug!("MCP: Starting SSE endpoint");
    let _mcp_server = KeepMcpServer::new(state);
    // Channel carrying pre-formatted SSE payload strings.
    let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<String>();
    // Announce the connection to the client immediately.
    let _ = tx.send("data: {\"type\":\"connection\",\"status\":\"connected\"}\n\n".to_string());
    // NOTE: placeholder transport — a full implementation would integrate
    // with the rmcp transport layer instead of this simple channel.
    let stream = stream::unfold((rx, tx), |(mut receiver, sender)| async move {
        tokio::select! {
            received = receiver.recv() => {
                // None means the sender was dropped: terminate the stream.
                received.map(|payload| {
                    (Ok(Event::default().data(payload)), (receiver, sender))
                })
            }
            _ = tokio::time::sleep(Duration::from_secs(30)) => {
                let ping = Event::default()
                    .event("keep-alive")
                    .data("ping");
                Some((Ok(ping), (receiver, sender)))
            }
        }
    });
    info!("MCP: SSE endpoint established");
    Ok(Sse::new(stream).keep_alive(
        KeepAlive::new()
            .interval(Duration::from_secs(30))
            .text("keep-alive"),
    ))
}

View File

@@ -1,7 +1,5 @@
pub mod common;
pub mod item;
#[cfg(feature = "mcp")]
pub mod mcp;
pub mod status;
use axum::{
@@ -60,8 +58,7 @@ use utoipa_swagger_ui::SwaggerUi;
struct ApiDoc;
pub fn add_routes(router: Router<AppState>) -> Router<AppState> {
#[cfg_attr(not(feature = "mcp"), allow(unused_mut))]
let mut router = router
router
// Status endpoints
.route("/api/status", get(status::handle_status))
.route("/api/plugins/status", get(status::handle_plugins_status))
@@ -88,14 +85,7 @@ pub fn add_routes(router: Router<AppState>) -> Router<AppState> {
)
.route("/api/item/{item_id}", delete(item::handle_delete_item))
.route("/api/item/{item_id}/info", get(item::handle_get_item_info))
.route("/api/diff", get(item::handle_diff_items));
#[cfg(feature = "mcp")]
{
router = router.route("/mcp/sse", get(mcp::handle_mcp_sse));
}
router
.route("/api/diff", get(item::handle_diff_items))
}
#[cfg(feature = "swagger")]

View File

@@ -39,7 +39,7 @@ use crate::modes::server::common::{ApiResponse, AppState, StatusInfoResponse};
///
/// # Examples
///
/// ```
/// ```ignore
/// // In an Axum app:
/// async fn app() -> Result<Json<StatusInfoResponse>, StatusCode> {
/// handle_status(State(app_state)).await
@@ -60,12 +60,17 @@ pub async fn handle_status(
// Use the status service to generate status info showing configured plugins
let status_service = crate::services::status_service::StatusService::new();
let mut cmd = state.cmd.lock().await;
let status_info = status_service.generate_status(
&mut cmd,
&state.settings,
state.data_dir.clone(),
db_path.into(),
);
let status_info = status_service
.generate_status(
&mut cmd,
&state.settings,
state.data_dir.clone(),
db_path.into(),
)
.map_err(|e| {
log::warn!("Failed to generate status: {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?;
let response = StatusInfoResponse {
success: true,
@@ -112,12 +117,17 @@ pub async fn handle_plugins_status(
let status_service = crate::services::status_service::StatusService::new();
let mut cmd = state.cmd.lock().await;
let status_info = status_service.generate_status(
&mut cmd,
&state.settings,
state.data_dir.clone(),
db_path.into(),
);
let status_info = status_service
.generate_status(
&mut cmd,
&state.settings,
state.data_dir.clone(),
db_path.into(),
)
.map_err(|e| {
log::warn!("Failed to generate status: {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?;
let response_data = PluginsStatusResponse {
meta_plugins: status_info.meta_plugins,

View File

@@ -726,34 +726,33 @@ fn check_basic_auth(
}
let encoded = &auth_str[6..];
if let Ok(decoded_bytes) = base64::engine::general_purpose::STANDARD.decode(encoded) {
if let Ok(decoded_str) = String::from_utf8(decoded_bytes) {
if let Some(colon_pos) = decoded_str.find(':') {
let provided_username = &decoded_str[..colon_pos];
let provided_password = &decoded_str[colon_pos + 1..];
if let Ok(decoded_bytes) = base64::engine::general_purpose::STANDARD.decode(encoded)
&& let Ok(decoded_str) = String::from_utf8(decoded_bytes)
&& let Some(colon_pos) = decoded_str.find(':')
{
let provided_username = &decoded_str[..colon_pos];
let provided_password = &decoded_str[colon_pos + 1..];
// Check username with constant-time comparison
if !bool::from(
provided_username
.as_bytes()
.ct_eq(expected_username.as_bytes()),
) {
return false;
}
// If we have a password hash, verify against it
if let Some(hash) = expected_hash {
return pwhash::unix::verify(provided_password, hash);
}
// Otherwise, do constant-time comparison to prevent timing attacks
return bool::from(
provided_password
.as_bytes()
.ct_eq(expected_password.as_bytes()),
);
}
// Check username with constant-time comparison
if !bool::from(
provided_username
.as_bytes()
.ct_eq(expected_username.as_bytes()),
) {
return false;
}
// If we have a password hash, verify against it
if let Some(hash) = expected_hash {
return pwhash::unix::verify(provided_password, hash);
}
// Otherwise, do constant-time comparison to prevent timing attacks
return bool::from(
provided_password
.as_bytes()
.ct_eq(expected_password.as_bytes()),
);
}
false
}

View File

@@ -1,83 +0,0 @@
pub mod server;
pub mod tools;
pub use server::KeepMcpServer;
/// Module for handling MCP (Model Context Protocol) requests in the server.
///
/// Provides handlers for JSON-RPC style requests to interact with Keep's storage
/// via the API.
use axum::{Json, extract::State, http::StatusCode, response::IntoResponse};
use serde::Deserialize;
use serde_json::Value;
use crate::modes::server::common::ApiResponse;
use crate::modes::server::common::AppState;
/// Request structure for MCP JSON-RPC calls.
///
/// # Fields
///
/// * `method` - The MCP method name (e.g., "save_item").
/// * `params` - Optional JSON parameters for the method.
#[derive(Deserialize)]
pub struct McpRequest {
    /// The MCP method name (e.g., "save_item").
    pub method: String,
    /// Optional JSON parameters forwarded to the tool implementation.
    pub params: Option<Value>,
}
/// Handles an MCP request via the Axum framework.
///
/// Parses the JSON request, delegates to `KeepMcpServer`, and returns an API response.
/// Attempts to parse the result as JSON; falls back to string if invalid.
///
/// # Arguments
///
/// * `State(state)` - The application state.
/// * `Json(request)` - The deserialized MCP request.
///
/// # Returns
///
/// An `IntoResponse` with status code and JSON API response.
///
/// # Errors
///
/// Returns 400 Bad Request on handler errors.
pub async fn handle_mcp_request(
State(state): State<AppState>,
Json(request): Json<McpRequest>,
) -> impl IntoResponse {
let mcp_server = KeepMcpServer::new(state);
match mcp_server
.handle_request(&request.method, request.params)
.await
{
Ok(result) => match serde_json::from_str(&result) {
Ok(parsed_result) => {
let response = ApiResponse {
success: true,
data: Some(parsed_result),
error: None,
};
(StatusCode::OK, Json(response))
}
Err(_) => {
let response = ApiResponse {
success: true,
data: Some(serde_json::Value::String(result)),
error: None,
};
(StatusCode::OK, Json(response))
}
},
Err(e) => {
let response = ApiResponse {
success: false,
data: None,
error: Some(e.to_string()),
};
(StatusCode::BAD_REQUEST, Json(response))
}
}
}

View File

@@ -1,83 +0,0 @@
use log::debug;
use serde_json::Value;
use super::tools::{KeepTools, ToolError};
use crate::modes::server::common::AppState;
/// Server handler for MCP (Model Context Protocol) requests.
///
/// Routes requests to appropriate tools and handles responses. Clones AppState for tool usage.
///
/// # Fields
///
/// * `state` - The shared application state (DB, config, etc.).
#[derive(Clone)]
pub struct KeepMcpServer {
state: AppState,
}
/// Creates a new `KeepMcpServer` instance.
///
/// # Arguments
///
/// * `state` - The application state containing DB, config, and services.
///
/// # Returns
///
/// A new `KeepMcpServer` instance.
///
/// # Examples
///
/// ```
/// let server = KeepMcpServer::new(app_state);
/// ```
impl KeepMcpServer {
pub fn new(state: AppState) -> Self {
Self { state }
}
/// Handles an MCP request by routing to the appropriate tool.
///
/// Supports methods like "save_item", "get_item", "list_items". Logs the request and delegates to KeepTools.
///
/// # Arguments
///
/// * `method` - The MCP method name (string).
/// * `params` - Optional JSON parameters as serde_json::Value.
///
/// # Returns
///
/// `Ok(String)` with JSON-serialized response on success, or `Err(ToolError)` on failure.
///
/// # Errors
///
/// * ToolError::UnknownTool if method unsupported.
/// * Propagates tool-specific errors (e.g., invalid args, DB failures).
///
/// # Examples
///
/// ```
/// let result = server.handle_request("save_item", Some(params)).await?;
/// ```
pub async fn handle_request(
&self,
method: &str,
params: Option<Value>,
) -> Result<String, ToolError> {
debug!(
"MCP: Handling request '{}' with params: {:?}",
method, params
);
let tools = KeepTools::new(self.state.clone());
match method {
"save_item" => tools.save_item(params).await,
"get_item" => tools.get_item(params).await,
"get_latest_item" => tools.get_latest_item(params).await,
"list_items" => tools.list_items(params).await,
"search_items" => tools.search_items(params).await,
_ => Err(ToolError::UnknownTool(method.to_string())),
}
}
}

View File

@@ -1,344 +0,0 @@
use anyhow::{Result, anyhow};
use log::debug;
use serde_json::Value;
use std::collections::HashMap;
use crate::modes::server::common::AppState;
use crate::services::async_item_service::AsyncItemService;
use crate::services::error::CoreError;
/// Errors produced by MCP tool handlers.
#[derive(Debug, thiserror::Error)]
pub enum ToolError {
    /// The requested MCP method name is not supported.
    #[error("Unknown tool: {0}")]
    UnknownTool(String),
    /// Request parameters were missing or had the wrong shape/type.
    #[error("Invalid arguments: {0}")]
    InvalidArguments(String),
    /// Underlying SQLite error.
    #[error("Database error: {0}")]
    Database(#[from] rusqlite::Error),
    /// Filesystem or stream I/O failure.
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),
    /// JSON (de)serialization failure.
    #[error("JSON error: {0}")]
    Json(#[from] serde_json::Error),
    /// Enum/string parse failure from strum.
    #[error("Parse error: {0}")]
    Parse(#[from] strum::ParseError),
    /// Catch-all for errors propagated via anyhow.
    #[error("Other error: {0}")]
    Other(#[from] anyhow::Error),
}
/// Implements the individual MCP tools (save/get/list/search) on top of
/// the shared application state.
pub struct KeepTools {
    // Shared application state (DB, config, services).
    state: AppState,
}
impl KeepTools {
    /// Creates a tool set bound to the given application state.
    pub fn new(state: AppState) -> Self {
        Self { state }
    }
    /// Saves a new item from MCP arguments.
    ///
    /// Expects `args` to contain a required string `content`, plus optional
    /// `tags` (array of strings) and `metadata` (object of string values).
    /// Non-string entries in `tags`/`metadata` are silently skipped.
    ///
    /// # Errors
    ///
    /// * `ToolError::InvalidArguments` when `args` or `content` is missing.
    /// * `ToolError::Other` on save failures or a missing item ID.
    pub async fn save_item(&self, args: Option<Value>) -> Result<String, ToolError> {
        let args =
            args.ok_or_else(|| ToolError::InvalidArguments("Missing arguments".to_string()))?;
        let content = args
            .get("content")
            .and_then(|v| v.as_str())
            .ok_or_else(|| ToolError::InvalidArguments("Missing 'content' field".to_string()))?;
        // Optional tags: only string array entries are kept.
        let tags: Vec<String> = args
            .get("tags")
            .and_then(|v| v.as_array())
            .map(|arr| {
                arr.iter()
                    .filter_map(|v| v.as_str().map(|s| s.to_string()))
                    .collect()
            })
            .unwrap_or_default();
        // Optional metadata: only string-valued object entries are kept.
        let metadata: HashMap<String, String> = args
            .get("metadata")
            .and_then(|v| v.as_object())
            .map(|obj| {
                obj.iter()
                    .filter_map(|(k, v)| v.as_str().map(|s| (k.clone(), s.to_string())))
                    .collect()
            })
            .unwrap_or_default();
        debug!(
            "MCP: Saving item with {} bytes, {} tags, {} metadata entries",
            content.len(),
            tags.len(),
            metadata.len()
        );
        let service = AsyncItemService::new(
            self.state.data_dir.clone(),
            self.state.db.clone(),
            self.state.item_service.clone(),
            self.state.cmd.clone(),
            self.state.settings.clone(),
        );
        let item_with_meta = service
            .save_item_from_mcp(content.as_bytes().to_vec(), tags, metadata)
            .await
            .map_err(|e| ToolError::Other(anyhow::Error::from(e)))?;
        let item_id = item_with_meta
            .item
            .id
            .ok_or_else(|| anyhow!("Failed to get item ID"))?;
        Ok(format!("Successfully saved item with ID: {}", item_id))
    }
    /// Fetches an item by ID and returns it as pretty-printed JSON.
    ///
    /// Content is returned lossily decoded as UTF-8.
    ///
    /// # Errors
    ///
    /// * `ToolError::InvalidArguments` when `id` is missing/invalid or the
    ///   item does not exist.
    /// * `ToolError::Other` for other service failures.
    pub async fn get_item(&self, args: Option<Value>) -> Result<String, ToolError> {
        let args =
            args.ok_or_else(|| ToolError::InvalidArguments("Missing arguments".to_string()))?;
        let item_id = args.get("id").and_then(|v| v.as_i64()).ok_or_else(|| {
            ToolError::InvalidArguments("Missing or invalid 'id' field".to_string())
        })?;
        let service = AsyncItemService::new(
            self.state.data_dir.clone(),
            self.state.db.clone(),
            self.state.item_service.clone(),
            self.state.cmd.clone(),
            self.state.settings.clone(),
        );
        // Not-found maps to an arguments error so MCP clients get a clear message.
        let item_with_content = match service.get_item_content(item_id).await {
            Ok(iwc) => iwc,
            Err(CoreError::ItemNotFound(_)) => {
                return Err(ToolError::InvalidArguments(format!(
                    "Item {} not found",
                    item_id
                )));
            }
            Err(e) => return Err(ToolError::Other(anyhow::Error::from(e))),
        };
        let content = String::from_utf8_lossy(&item_with_content.content).to_string();
        let tags: Vec<String> = item_with_content
            .item_with_meta
            .tags
            .iter()
            .map(|t| t.name.clone())
            .collect();
        let metadata = item_with_content.item_with_meta.meta_as_map();
        let item = item_with_content.item_with_meta.item;
        let response = serde_json::json!({
            "id": item_id,
            "content": content,
            "timestamp": item.ts.to_rfc3339(),
            "size": item.size,
            "compression": item.compression,
            "tags": tags,
            "metadata": metadata,
        });
        Ok(serde_json::to_string_pretty(&response)?)
    }
    /// Fetches the most recent item, optionally filtered by `tags`.
    ///
    /// # Errors
    ///
    /// * `ToolError::InvalidArguments` when no items match.
    /// * `ToolError::Other` for other service failures.
    pub async fn get_latest_item(&self, args: Option<Value>) -> Result<String, ToolError> {
        let tags: Vec<String> = args
            .as_ref()
            .and_then(|v| v.get("tags"))
            .and_then(|v| v.as_array())
            .map(|arr| {
                arr.iter()
                    .filter_map(|v| v.as_str().map(|s| s.to_string()))
                    .collect()
            })
            .unwrap_or_default();
        let service = AsyncItemService::new(
            self.state.data_dir.clone(),
            self.state.db.clone(),
            self.state.item_service.clone(),
            self.state.cmd.clone(),
            self.state.settings.clone(),
        );
        let item_with_meta = match service.find_item(vec![], tags, HashMap::new()).await {
            Ok(iwm) => iwm,
            Err(CoreError::ItemNotFoundGeneric) => {
                return Err(ToolError::InvalidArguments("No items found".to_string()));
            }
            Err(e) => return Err(ToolError::Other(anyhow::Error::from(e))),
        };
        let item_id = item_with_meta
            .item
            .id
            .ok_or_else(|| anyhow!("Item missing ID after find"))?;
        // Second round-trip to load the content for the found item.
        let item_with_content = service
            .get_item_content(item_id)
            .await
            .map_err(|e| ToolError::Other(anyhow::Error::from(e)))?;
        let content = String::from_utf8_lossy(&item_with_content.content).to_string();
        let tags: Vec<String> = item_with_content
            .item_with_meta
            .tags
            .iter()
            .map(|t| t.name.clone())
            .collect();
        let metadata = item_with_content.item_with_meta.meta_as_map();
        let item = item_with_content.item_with_meta.item;
        let response = serde_json::json!({
            "id": item_id,
            "content": content,
            "timestamp": item.ts.to_rfc3339(),
            "size": item.size,
            "compression": item.compression,
            "tags": tags,
            "metadata": metadata,
        });
        Ok(serde_json::to_string_pretty(&response)?)
    }
    /// Lists items (metadata only, no content), newest first.
    ///
    /// Optional `args`: `tags` (array of strings), `limit` (default 10),
    /// `offset` (default 0). Pagination is applied after sorting.
    ///
    /// # Errors
    ///
    /// * `ToolError::Other` on listing failures.
    pub async fn list_items(&self, args: Option<Value>) -> Result<String, ToolError> {
        let args_ref = args.as_ref();
        let tags: Vec<String> = args_ref
            .and_then(|v| v.get("tags"))
            .and_then(|v| v.as_array())
            .map(|arr| {
                arr.iter()
                    .filter_map(|v| v.as_str().map(|s| s.to_string()))
                    .collect()
            })
            .unwrap_or_default();
        let limit = args_ref
            .and_then(|v| v.get("limit"))
            .and_then(|v| v.as_u64())
            .unwrap_or(10) as usize;
        let offset = args_ref
            .and_then(|v| v.get("offset"))
            .and_then(|v| v.as_u64())
            .unwrap_or(0) as usize;
        let service = AsyncItemService::new(
            self.state.data_dir.clone(),
            self.state.db.clone(),
            self.state.item_service.clone(),
            self.state.cmd.clone(),
            self.state.settings.clone(),
        );
        let mut items_with_meta = service
            .list_items(tags, HashMap::new())
            .await
            .map_err(|e| ToolError::Other(anyhow::Error::from(e)))?;
        // Sort by timestamp (newest first) and apply pagination
        items_with_meta.sort_by(|a, b| b.item.ts.cmp(&a.item.ts));
        let items_with_meta: Vec<_> = items_with_meta
            .into_iter()
            .skip(offset)
            .take(limit)
            .collect();
        let items_info: Vec<_> = items_with_meta
            .into_iter()
            .map(|item_with_meta| {
                let item_tags: Vec<String> =
                    item_with_meta.tags.iter().map(|t| t.name.clone()).collect();
                let item_meta = item_with_meta.meta_as_map();
                let item = item_with_meta.item;
                // An item straight from the DB should always have an id;
                // 0 is used as a defensive fallback.
                let item_id = item.id.unwrap_or(0);
                serde_json::json!({
                    "id": item_id,
                    "timestamp": item.ts.to_rfc3339(),
                    "size": item.size,
                    "compression": item.compression,
                    "tags": item_tags,
                    "metadata": item_meta
                })
            })
            .collect();
        let response = serde_json::json!({
            "items": items_info,
            "count": items_info.len(),
            "offset": offset,
            "limit": limit
        });
        Ok(serde_json::to_string_pretty(&response)?)
    }
    /// Searches items by `tags` and/or `metadata`, newest first.
    ///
    /// Unlike [`Self::list_items`], no pagination is applied and the search
    /// criteria are echoed back in the response.
    ///
    /// # Errors
    ///
    /// * `ToolError::Other` on listing failures.
    pub async fn search_items(&self, args: Option<Value>) -> Result<String, ToolError> {
        let tags: Vec<String> = args
            .as_ref()
            .and_then(|v| v.get("tags"))
            .and_then(|v| v.as_array())
            .map(|arr| {
                arr.iter()
                    .filter_map(|v| v.as_str().map(|s| s.to_string()))
                    .collect()
            })
            .unwrap_or_default();
        let metadata: HashMap<String, String> = args
            .as_ref()
            .and_then(|v| v.get("metadata"))
            .and_then(|v| v.as_object())
            .map(|obj| {
                obj.iter()
                    .filter_map(|(k, v)| v.as_str().map(|s| (k.clone(), s.to_string())))
                    .collect()
            })
            .unwrap_or_default();
        let service = AsyncItemService::new(
            self.state.data_dir.clone(),
            self.state.db.clone(),
            self.state.item_service.clone(),
            self.state.cmd.clone(),
            self.state.settings.clone(),
        );
        let mut items_with_meta = service
            .list_items(tags.clone(), metadata.clone())
            .await
            .map_err(|e| ToolError::Other(anyhow::Error::from(e)))?;
        // Sort by timestamp (newest first)
        items_with_meta.sort_by(|a, b| b.item.ts.cmp(&a.item.ts));
        let items_info: Vec<_> = items_with_meta
            .into_iter()
            .map(|item_with_meta| {
                let item_tags: Vec<String> =
                    item_with_meta.tags.iter().map(|t| t.name.clone()).collect();
                let item_meta = item_with_meta.meta_as_map();
                let item = item_with_meta.item;
                let item_id = item.id.unwrap_or(0);
                serde_json::json!({
                    "id": item_id,
                    "timestamp": item.ts.to_rfc3339(),
                    "size": item.size,
                    "compression": item.compression,
                    "tags": item_tags,
                    "metadata": item_meta
                })
            })
            .collect();
        let response = serde_json::json!({
            "items": items_info,
            "count": items_info.len(),
            "search_criteria": {
                "tags": tags,
                "metadata": metadata
            }
        });
        Ok(serde_json::to_string_pretty(&response)?)
    }
}

View File

@@ -1,7 +1,10 @@
use crate::config;
use crate::services::item_service::ItemService;
use anyhow::Result;
use axum::{Router, routing::post};
use axum::Router;
use axum::http::{HeaderValue, header};
use axum::middleware::Next;
use axum::response::Response;
use clap::Command;
use log::{debug, info};
use std::net::SocketAddr;
@@ -15,12 +18,26 @@ use tower_http::trace::TraceLayer;
mod api;
pub mod auth;
pub mod common;
#[cfg(feature = "mcp")]
mod mcp;
mod pages;
pub use common::{AppState, create_auth_middleware, logging_middleware};
/// Adds security headers to all responses.
///
/// Sets `X-Content-Type-Options: nosniff`, `X-Frame-Options: DENY`, and
/// `Referrer-Policy: strict-origin-when-cross-origin` on every response
/// that flows through this middleware.
async fn security_headers(req: axum::extract::Request, next: Next) -> Response {
    let mut response = next.run(req).await;
    // Header table keeps the policy in one place; insert replaces any
    // value an inner handler may have set.
    let hardening = [
        (header::X_CONTENT_TYPE_OPTIONS, "nosniff"),
        (header::X_FRAME_OPTIONS, "DENY"),
        (
            header::REFERRER_POLICY,
            "strict-origin-when-cross-origin",
        ),
    ];
    for (name, value) in hardening {
        response
            .headers_mut()
            .insert(name, HeaderValue::from_static(value));
    }
    response
}
pub fn mode_server(
cmd: &mut Command,
settings: &config::Settings,
@@ -107,23 +124,10 @@ async fn run_server(
settings: Arc::new(settings.clone()),
};
#[cfg(feature = "mcp")]
let mcp_router = Router::new()
.route("/mcp", post(mcp::handle_mcp_request))
.with_state(state.clone());
#[cfg_attr(not(feature = "mcp"), allow(unused_mut))]
let mut protected_router = Router::new()
let protected_router = Router::new()
.merge(api::add_routes(Router::new()))
.merge(pages::add_routes(Router::new()));
#[cfg(feature = "mcp")]
{
protected_router = protected_router.merge(mcp_router);
}
let protected_router =
protected_router.layer(axum::middleware::from_fn(create_auth_middleware(
.merge(pages::add_routes(Router::new()))
.layer(axum::middleware::from_fn(create_auth_middleware(
config.username.clone(),
config.password.clone(),
config.password_hash.clone(),
@@ -152,18 +156,19 @@ async fn run_server(
axum::http::Method::PUT,
axum::http::Method::DELETE,
])
.allow_headers(tower_http::cors::Any)
.allow_headers([header::CONTENT_TYPE, header::AUTHORIZATION, header::ACCEPT])
};
// Create the app with documentation routes open and others protected
let app = Router::new()
// Add documentation routes without authentication
.merge(api::add_docs_routes(Router::new()))
// Add API, pages, and MCP routes with authentication
// Add API and pages routes with authentication
.merge(protected_router)
// Apply state to all routes
.with_state(state)
// Add other middleware layers to all routes
.layer(axum::middleware::from_fn(security_headers))
.layer(axum::middleware::from_fn(logging_middleware))
.layer(
ServiceBuilder::new()

View File

@@ -6,11 +6,22 @@ use axum::{
extract::{Path, Query, State},
response::{Html, Response},
};
use html_escape::{encode_double_quoted_attribute, encode_text};
use log::debug;
use rusqlite::Connection;
use serde::Deserialize;
use std::collections::HashMap;
/// Escape text content for safe HTML insertion.
fn esc(s: &str) -> String {
    // encode_text returns Cow<str>; take ownership without an extra copy
    // when escaping was a no-op.
    encode_text(s).into_owned()
}
/// Escape attribute values for safe HTML attribute insertion.
fn esc_attr(s: &str) -> String {
    // Same Cow-based ownership trick as esc(), for double-quoted attributes.
    encode_double_quoted_attribute(s).into_owned()
}
#[derive(Deserialize)]
/// Query parameters for the item list endpoint.
///
@@ -62,7 +73,7 @@ fn default_count() -> usize {
///
/// # Examples
///
/// ```
/// ```ignore
/// let app = pages::add_routes(axum::Router::new());
/// ```
pub fn add_routes(app: axum::Router<AppState>) -> axum::Router<AppState> {
@@ -90,7 +101,9 @@ async fn list_items(
.map_err(|_| Html("<html><body>Internal Server Error</body></html>".to_string()))?;
Ok(response)
}
Err(e) => Err(Html(format!("<html><body>Error: {e}</body></html>"))),
Err(_e) => Err(Html(
"<html><body>An internal error occurred</body></html>".to_string(),
)),
}
}
@@ -121,7 +134,8 @@ fn build_item_list(
// Apply pagination
let start = params.start;
let end = std::cmp::min(start + params.count, sorted_items.len());
let count = params.count.min(10000);
let end = std::cmp::min(start + count, sorted_items.len());
let page_items = if start < sorted_items.len() {
sorted_items[start..std::cmp::min(end, sorted_items.len())].to_vec()
} else {
@@ -153,11 +167,11 @@ fn build_item_list(
// Collect all tags from all items, keeping track of their timestamps
let mut all_tags_with_time: Vec<(String, chrono::DateTime<chrono::Utc>)> = Vec::new();
for item in &sorted_items {
if let Some(item_id) = item.id {
if let Some(tags) = tags_map.get(&item_id) {
for tag in tags {
all_tags_with_time.push((tag.name.clone(), item.ts));
}
if let Some(item_id) = item.id
&& let Some(tags) = tags_map.get(&item_id)
{
for tag in tags {
all_tags_with_time.push((tag.name.clone(), item.ts));
}
}
}
@@ -184,7 +198,9 @@ fn build_item_list(
html.push_str("<p>");
for tag in recent_tags {
html.push_str(&format!(
"<a href=\"/?tags={tag}\" style=\"margin-right: 8px;\">{tag}</a>"
"<a href=\"/?tags={}\" style=\"margin-right: 8px;\">{}</a>",
esc_attr(&tag),
esc(&tag)
));
}
html.push_str("</p>");
@@ -196,7 +212,7 @@ fn build_item_list(
// Table headers
html.push_str("<tr>");
for column in columns {
html.push_str(&format!("<th>{}</th>", column.label));
html.push_str(&format!("<th>{}</th>", esc(&column.label)));
}
html.push_str("<th>Actions</th>");
html.push_str("</tr>");
@@ -229,7 +245,13 @@ fn build_item_list(
// Make sure we're using all tags for the item
let tag_links: Vec<String> = tags
.iter()
.map(|t| format!("<a href=\"/?tags={}\">{}</a>", t.name, t.name))
.map(|t| {
format!(
"<a href=\"/?tags={}\">{}</a>",
esc_attr(&t.name),
esc(&t.name)
)
})
.collect();
tag_links.join(", ")
}
@@ -268,7 +290,15 @@ fn build_item_list(
crate::config::ColumnAlignment::Center => "text-align: center;",
};
html.push_str(&format!("<td style=\"{align_style}\">{display_value}</td>"));
let rendered_value = if column.name == "tags" {
display_value // Already contains escaped HTML links
} else {
esc(&display_value)
};
html.push_str(&format!(
"<td style=\"{align_style}\">{rendered_value}</td>"
));
}
// Actions column
@@ -361,7 +391,9 @@ async fn show_item(
.map_err(|_| Html("<html><body>Internal Server Error</body></html>".to_string()))?;
Ok(response)
}
Err(e) => Err(Html(format!("<html><body>Error: {e}</body></html>"))),
Err(_e) => Err(Html(
"<html><body>An internal error occurred</body></html>".to_string(),
)),
}
}
@@ -396,7 +428,7 @@ fn build_item_details(conn: &Connection, id: i64) -> Result<String> {
));
html.push_str(&format!(
"<tr><th>Compression</th><td>{}</td></tr>",
item.compression
esc(&item.compression)
));
// Tags row
@@ -406,7 +438,13 @@ fn build_item_details(conn: &Connection, id: i64) -> Result<String> {
} else {
let tag_links: Vec<String> = tags
.iter()
.map(|t| format!("<a href=\"/?tags={}\">{}</a>", t.name, t.name))
.map(|t| {
format!(
"<a href=\"/?tags={}\">{}</a>",
esc_attr(&t.name),
esc(&t.name)
)
})
.collect();
html.push_str(&tag_links.join(", "));
}
@@ -419,7 +457,8 @@ fn build_item_details(conn: &Connection, id: i64) -> Result<String> {
for meta in metas {
html.push_str(&format!(
"<tr><th>{}</th><td>{}</td></tr>",
meta.name, meta.value
esc(&meta.name),
esc(&meta.value)
));
}
}

View File

@@ -198,7 +198,7 @@ pub fn mode_status(
let status_service = crate::services::status_service::StatusService::new();
let output_format = crate::modes::common::settings_output_format(settings);
debug!("STATUS: About to generate status info");
let status_info = status_service.generate_status(cmd, settings, data_path, db_path);
let status_info = status_service.generate_status(cmd, settings, data_path, db_path)?;
debug!("STATUS: Status info generated successfully");
match output_format {

View File

@@ -298,7 +298,7 @@ pub fn mode_status_plugins(
let status_service = crate::services::status_service::StatusService::new();
let output_format = crate::modes::common::settings_output_format(settings);
debug!("STATUS_PLUGINS: About to generate status info");
let status_info = status_service.generate_status(cmd, settings, data_path, db_path);
let status_info = status_service.generate_status(cmd, settings, data_path, db_path)?;
debug!("STATUS_PLUGINS: Status info generated successfully");
match output_format {

View File

@@ -81,24 +81,6 @@ impl AsyncDataService {
DataService::find_item(self, &mut conn, ids, tags, meta)
}
pub async fn get_item_content_info(
&self,
id: i64,
_filter: Option<String>,
) -> Result<(Vec<u8>, ItemWithMeta, bool), CoreError> {
let mut conn = self.db.lock().await;
let (mut reader, item_with_meta) = self.get_content(&mut conn, id)?;
let mut content = Vec::new();
reader.read_to_end(&mut content)?;
let is_binary = item_with_meta
.meta
.iter()
.find(|m| m.name == "text")
.map(|m| m.value == "false")
.unwrap_or(false);
Ok((content, item_with_meta, is_binary))
}
pub async fn get_item_content_info_streaming(
&self,
id: i64,
@@ -196,27 +178,29 @@ impl AsyncDataService {
Ok((Box::pin(stream), content_length))
}
/// Get raw item content without decompression.
/// Get raw item content without decompression as a streaming reader.
///
/// Reads the stored file bytes directly from disk, bypassing decompression.
/// Opens the stored file directly from disk, bypassing decompression.
/// Used when the client requests raw bytes with `decompress=false`.
pub async fn get_raw_item_content(&self, id: i64) -> Result<Vec<u8>, CoreError> {
/// Returns a boxed reader that can be used for streaming.
pub async fn get_raw_item_content_reader(
&self,
id: i64,
) -> Result<Box<dyn Read + Send>, CoreError> {
let data_path = self.data_path.clone();
tokio::task::spawn_blocking(move || {
let mut item_path = data_path;
item_path.push(id.to_string());
let mut file = std::fs::File::open(&item_path).map_err(|e| {
let file = std::fs::File::open(&item_path).map_err(|e| {
CoreError::Io(std::io::Error::new(
std::io::ErrorKind::NotFound,
format!("Item file not found: {item_path:?}: {e}"),
))
})?;
let mut content = Vec::new();
file.read_to_end(&mut content)?;
Ok(content)
Ok(Box::new(file) as Box<dyn Read + Send>)
})
.await
.map_err(|e| CoreError::Other(anyhow::anyhow!("Task join error: {}", e)))?
@@ -295,6 +279,6 @@ impl DataService for AsyncDataService {
settings,
data_path.to_path_buf(),
db_path.to_path_buf(),
))
)?)
}
}

View File

@@ -101,17 +101,6 @@ impl AsyncItemService {
.await
}
pub async fn get_item_content_info(
&self,
id: i64,
filter: Option<String>,
) -> Result<(Vec<u8>, String, bool), CoreError> {
self.execute_blocking(move |conn, item_service| {
item_service.get_item_content_info(conn, id, filter)
})
.await
}
pub async fn stream_item_content_by_id(
&self,
item_id: i64,
@@ -131,41 +120,13 @@ impl AsyncItemService {
),
CoreError,
> {
let content = self
// Use streaming approach: get reader and stream chunks in requested range
let (reader, mime_type, is_binary) = self
.execute_blocking(move |conn, item_service| {
let item_with_content = item_service.get_item_content(conn, item_id)?;
Ok::<_, CoreError>(item_with_content.content)
item_service.get_item_content_info_streaming(conn, item_id, None)
})
.await?;
// Clone content for use in the binary check closure
let content_clone = content.clone();
// Get metadata to determine MIME type and binary status
let (mime_type, is_binary) = {
let db = self.db.clone();
let item_service = self.item_service.clone();
tokio::task::spawn_blocking(move || {
let conn = db.blocking_lock();
let item_with_meta = item_service.get_item(&conn, item_id)?;
let metadata = item_with_meta.meta_as_map();
let mime_type = metadata
.get("mime_type")
.map(|s| s.to_string())
.unwrap_or_else(|| "application/octet-stream".to_string());
let is_binary = crate::common::is_binary::is_content_binary_from_metadata(
&metadata,
&content_clone,
);
Ok::<_, CoreError>((mime_type, is_binary))
})
.await
.unwrap()?
};
// Check if content is binary when allow_binary is false
if !allow_binary && is_binary {
return Err(CoreError::InvalidInput(
@@ -173,26 +134,76 @@ impl AsyncItemService {
));
}
// Create a stream that reads only the requested portion
let content_len = content.len() as u64;
// Convert the reader into an async stream with offset and length applied
use tokio_util::bytes::Bytes;
// Apply offset and length constraints
let start = std::cmp::min(offset, content_len);
let end = if length > 0 {
std::cmp::min(start + length, content_len)
} else {
content_len
};
// Create a channel to stream data between the blocking thread and async runtime
let (tx, rx) = tokio::sync::mpsc::channel::<Result<Bytes, std::io::Error>>(16);
let stream = if start < content_len {
let chunk =
tokio_util::bytes::Bytes::from(content[start as usize..end as usize].to_vec());
Box::pin(tokio_stream::iter(vec![Ok(chunk)]))
} else {
Box::pin(tokio_stream::iter(vec![]))
};
// Spawn a blocking task to read from the reader and send chunks
tokio::task::spawn_blocking(move || {
let mut reader = reader;
let mut buf = [0u8; PIPESIZE];
Ok((stream, mime_type))
// Apply offset by reading and discarding bytes
if offset > 0 {
let mut remaining = offset;
while remaining > 0 {
let to_read = std::cmp::min(remaining, buf.len() as u64) as usize;
match reader.read(&mut buf[..to_read]) {
Ok(0) => break, // EOF reached before offset
Ok(n) => remaining -= n as u64,
Err(e) => {
let _ = tx.blocking_send(Err(e));
return;
}
}
}
}
// Read and send data up to the specified length
let mut remaining_length = length;
loop {
// Determine how much to read in this iteration
let to_read = if length > 0 {
// If length is specified, don't read more than remaining_length
std::cmp::min(remaining_length, buf.len() as u64) as usize
} else {
buf.len()
};
if to_read == 0 {
break; // We've read the requested length
}
match reader.read(&mut buf[..to_read]) {
Ok(0) => break, // EOF
Ok(n) => {
let chunk = Bytes::copy_from_slice(&buf[..n]);
// Block on sending to the channel
if tx.blocking_send(Ok(chunk)).is_err() {
break; // Receiver dropped
}
if length > 0 {
remaining_length -= n as u64;
if remaining_length == 0 {
break; // Reached the requested length
}
}
}
Err(e) => {
let _ = tx.blocking_send(Err(e));
break;
}
}
}
});
// Convert the receiver into a stream
let stream = tokio_stream::wrappers::ReceiverStream::new(rx);
Ok((Box::pin(stream), mime_type))
}
pub async fn stream_item_content_by_id_with_metadata(
@@ -374,28 +385,6 @@ impl AsyncItemService {
item_service.delete_item(&mut conn, id)
})
.await
.unwrap()
}
pub async fn save_item_from_mcp(
&self,
content: Vec<u8>,
tags: Vec<String>,
metadata: HashMap<String, String>,
) -> Result<ItemWithMeta, CoreError> {
let db = self.db.clone();
let item_service = self.item_service.clone();
let cmd = self.cmd.clone();
let settings = self.settings.clone();
tokio::task::spawn_blocking(move || {
let mut conn = db.blocking_lock();
let mut cmd = cmd.blocking_lock();
let settings = settings.as_ref();
item_service
.save_item_from_mcp(&content, &tags, &metadata, &mut cmd, settings, &mut conn)
})
.await
.unwrap()
.map_err(|e| CoreError::Other(anyhow::anyhow!("task join error: {e}")))?
}
}

View File

@@ -188,11 +188,12 @@ static FILTER_PLUGIN_REGISTRY: Lazy<Mutex<HashMap<String, FilterConstructor>>> =
/// ```ignore
/// register_filter_plugin("custom_filter", || Box::new(CustomFilter::default()));
/// ```
pub fn register_filter_plugin(name: &str, constructor: FilterConstructor) {
pub fn register_filter_plugin(name: &str, constructor: FilterConstructor) -> anyhow::Result<()> {
FILTER_PLUGIN_REGISTRY
.lock()
.unwrap()
.map_err(|e| anyhow::anyhow!("plugin registry poisoned: {e}"))?
.insert(name.to_string(), constructor);
Ok(())
}
/// Retrieves a snapshot of all registered filter plugins.
@@ -214,6 +215,9 @@ pub fn register_filter_plugin(name: &str, constructor: FilterConstructor) {
/// let plugins = get_available_filter_plugins();
/// // Plugins are registered at startup via ctors; specific names may vary by configuration.
/// ```
pub fn get_available_filter_plugins() -> HashMap<String, FilterConstructor> {
FILTER_PLUGIN_REGISTRY.lock().unwrap().clone()
pub fn get_available_filter_plugins() -> anyhow::Result<HashMap<String, FilterConstructor>> {
FILTER_PLUGIN_REGISTRY
.lock()
.map_err(|e| anyhow::anyhow!("plugin registry poisoned: {e}"))
.map(|guard| guard.clone())
}

View File

@@ -106,6 +106,8 @@ impl ItemService {
/// Retrieves an item with its content, metadata, and tags.
///
/// Loads the item, its metadata/tags, and decompresses the full content.
/// This method is intended for CLI use only and has a size guard (100MB).
/// For larger items or server use, use `get_item_content_info_streaming`.
///
/// # Arguments
///
@@ -120,6 +122,7 @@ impl ItemService {
///
/// * `CoreError::ItemNotFound(id)` - If the item does not exist.
/// * `CoreError::Io(...)` - If file read or decompression fails.
/// * `CoreError::InvalidInput(...)` - If item exceeds 100MB size limit.
///
/// # Examples
///
@@ -132,6 +135,9 @@ impl ItemService {
conn: &Connection,
id: i64,
) -> Result<ItemWithContent, CoreError> {
// Size limit for loading entire content into memory (100MB)
const MAX_CONTENT_SIZE: i64 = 100 * 1024 * 1024;
debug!("ITEM_SERVICE: Getting item content for id: {id}");
let item_with_meta = self.get_item(conn, id)?;
let item_id = item_with_meta
@@ -145,6 +151,16 @@ impl ItemService {
)));
}
// Check size guard before loading content
if let Some(size) = item_with_meta.item.size
&& size > MAX_CONTENT_SIZE
{
return Err(CoreError::InvalidInput(format!(
"Item {} exceeds size limit ({} > {}). Use streaming API for large items.",
item_id, size, MAX_CONTENT_SIZE
)));
}
let mut item_path = self.data_path.clone();
item_path.push(item_id.to_string());
debug!("ITEM_SERVICE: Reading content from path: {item_path:?}");
@@ -164,47 +180,6 @@ impl ItemService {
})
}
/// Retrieves item content with binary detection and optional filtering.
///
/// Loads content, applies filters if specified, and determines MIME type and binary status.
///
/// # Arguments
///
/// * `conn` - Database connection.
/// * `id` - Item ID.
/// * `filter` - Optional filter string to apply to content.
///
/// # Returns
///
/// * `Result<(Vec<u8>, String, bool), CoreError>` - (content, MIME type, is_binary).
///
/// # Errors
///
/// * `CoreError::ItemNotFound(id)` - If item not found.
/// * Filter or compression errors.
///
/// # Examples
///
/// ```ignore
/// let (content, mime, is_binary) = item_service.get_item_content_info(&conn, 1, Some("head_lines(10)"))?;
/// ```
pub fn get_item_content_info(
&self,
conn: &Connection,
id: i64,
filter: Option<String>,
) -> Result<(Vec<u8>, String, bool), CoreError> {
// Use streaming approach to handle all filtering options consistently
let (mut reader, mime_type, is_binary) =
self.get_item_content_info_streaming(conn, id, filter)?;
// Read all the filtered content into a buffer
let mut content = Vec::new();
reader.read_to_end(&mut content)?;
Ok((content, mime_type, is_binary))
}
/// Determines if item content is binary based on metadata or sampling.
///
/// Checks existing "text" metadata first; if absent, samples the first 8192 bytes.
@@ -717,110 +692,6 @@ impl ItemService {
Ok(item)
}
/// Saves pre-loaded content as a new item, typically from MCP (Model Context Protocol) sources.
///
/// Bypasses streaming read, directly writes content and applies metadata/plugins.
///
/// # Arguments
///
/// * `content` - Byte slice of content to save.
/// * `tags` - Tags to associate.
/// * `metadata` - Initial metadata key-value pairs.
/// * `cmd` - Mutable command.
/// * `settings` - Settings.
/// * `conn` - Mutable database connection.
///
/// # Returns
///
/// * `Result<ItemWithMeta, CoreError>` - The saved item with full details.
///
/// # Errors
///
/// * `CoreError::Database(...)` - If DB insert fails.
/// * `CoreError::Io(...)` - If file write fails.
///
/// # Examples
///
/// ```ignore
/// let content = b"Hello, world!";
/// let tags = vec!["mcp".to_string()];
/// let meta = HashMap::from([("source".to_string(), "api".to_string())]);
/// let item = service.save_item_from_mcp(content, &tags, &meta, &mut cmd, &settings, &mut conn)?;
/// ```
pub fn save_item_from_mcp(
&self,
content: &[u8],
tags: &Vec<String>,
metadata: &HashMap<String, String>,
cmd: &mut Command,
settings: &Settings,
conn: &mut Connection,
) -> Result<ItemWithMeta, CoreError> {
debug!(
"ITEM_SERVICE: Starting save_item_from_mcp with {} bytes, {} tags, {} metadata entries",
content.len(),
tags.len(),
metadata.len()
);
let compression_type = CompressionType::LZ4;
let compression_engine = get_compression_engine(compression_type.clone())?;
let item_id;
let mut item;
{
item = db::create_item(conn, compression_type.clone())?;
item_id = item
.id
.ok_or_else(|| CoreError::InvalidInput("Item missing ID".to_string()))?;
debug!("ITEM_SERVICE: Created MCP item with id: {item_id}");
// Add tags
for tag in tags {
db::add_tag(conn, item_id, tag)?;
}
debug!("ITEM_SERVICE: Added {} tags to MCP item", tags.len());
// Add custom metadata
for (key, value) in metadata {
db::add_meta(conn, item_id, key, value)?;
}
debug!(
"ITEM_SERVICE: Added {} custom metadata entries to MCP item",
metadata.len()
);
}
let mut item_path = self.data_path.clone();
item_path.push(item_id.to_string());
debug!("ITEM_SERVICE: Writing MCP item to path: {item_path:?}");
let mut writer = compression_engine.create(item_path.clone())?;
writer.write_all(content)?;
drop(writer);
let mut plugins = self.meta_service.get_plugins(cmd, settings);
debug!(
"ITEM_SERVICE: Got {} configured meta plugins for MCP item",
plugins.len()
);
self.meta_service
.initialize_plugins(&mut plugins, conn, item_id);
self.meta_service
.process_chunk(&mut plugins, content, conn, item_id);
self.meta_service
.finalize_plugins(&mut plugins, conn, item_id);
debug!("ITEM_SERVICE: Processed MCP item through configured meta plugins");
item.size = Some(content.len() as i64);
db::update_item(conn, item.clone())?;
debug!("ITEM_SERVICE: MCP item saved successfully");
self.get_item(conn, item_id)
}
/// Returns a reference to the internal compression service.
///
/// # Returns

View File

@@ -74,7 +74,7 @@ impl StatusService {
settings: &Settings,
data_path: PathBuf,
db_path: PathBuf,
) -> StatusInfo {
) -> anyhow::Result<StatusInfo> {
// Get meta plugins directly from config
let meta_plugin_types: Vec<MetaPluginType> =
crate::modes::common::settings_meta_plugin_types(cmd, settings);
@@ -91,10 +91,10 @@ impl StatusService {
db_path,
&meta_plugin_types,
enabled_compression_type,
);
)?;
// Add detailed filter plugins information
let filter_plugins_map = get_available_filter_plugins();
let filter_plugins_map = get_available_filter_plugins()?;
let mut filter_plugins_info = Vec::new();
for (name, creator) in filter_plugins_map {
@@ -114,7 +114,7 @@ impl StatusService {
// Add configured meta plugins information
status_info.configured_meta_plugins = settings.meta_plugins.clone();
status_info
Ok(status_info)
}
}

View File

@@ -57,34 +57,6 @@ impl SyncDataService {
.save_item(content, cmd, settings, tags, conn)
}
pub fn save_item_with_reader<R: Read>(
&self,
conn: &mut Connection,
reader: &mut R,
tags: Vec<String>,
metadata: HashMap<String, String>,
) -> Result<ItemWithMeta, CoreError> {
let mut cmd = Command::new("keep");
let settings = &self.settings;
let mut tags = tags;
// Read content from reader
let mut content = Vec::new();
reader.read_to_end(&mut content)?;
let item = self.save_item(&*content, &mut cmd, settings, &mut tags, conn)?;
let item_id = item
.id
.ok_or_else(|| CoreError::InvalidInput("Item missing ID".to_string()))?;
// Set metadata
for (key, value) in metadata {
crate::db::add_meta(conn, item_id, &key, &value)?;
}
self.get_item(conn, item_id)
}
/// Save an item with granular control over compression and meta plugins.
///
/// This method allows clients to control whether compression and meta plugins
@@ -266,7 +238,9 @@ impl SyncDataService {
db_path: PathBuf,
) -> StatusInfo {
let status_service = StatusService::new();
status_service.generate_status(cmd, settings, data_path, db_path)
status_service
.generate_status(cmd, settings, data_path, db_path)
.unwrap_or_else(|_| StatusInfo::default())
}
}
@@ -359,6 +333,6 @@ impl DataService for SyncDataService {
settings,
data_path.to_path_buf(),
db_path.to_path_buf(),
))
)?)
}
}