refactor: streaming, security hardening, and MCP removal

Major overhaul of server architecture and security posture:

- Streaming: Unified all I/O through PIPESIZE (8192-byte) buffers.
  POST bodies stream via MpscReader through the save pipeline. GET
  content streams from disk via decompression to client. Removed
  save_item_with_reader, get_item_content_info, ChannelReader.
  Requests exceeding the body limit return 413 but keep the
  partially saved item (truncation is nonfatal by design).

- Security: XSS protection in all HTML pages via html_escape crate.
  Security headers middleware (nosniff, frame deny, referrer policy).
  CORS tightened to an explicit header allowlist. Input validation
  for tags (max 256 chars), metadata keys/values (max 128/4096
  chars), and pagination (capped at 10,000 items per request). Config
  file reads use from_utf8_lossy. Generic error messages in HTML.
  Diff endpoint has 10 MB per-item cap. max_body_size config option.

- Panics eliminated: Path unwraps → proper error propagation.
  Mutex lock unwraps → map_err for shared registries; expect with a
  descriptive message for locally scoped locks.

- MCP removed: Deleted all MCP code, rmcp dependency, mcp feature.

- Docs: Updated README, DESIGN, AGENTS to reflect all changes.
This commit is contained in:
2026-03-14 00:03:42 -03:00
parent 560ba6e20c
commit 17be6abaab
51 changed files with 876 additions and 1309 deletions

View File

@@ -92,7 +92,7 @@ pub fn mode(
let digest = format!("{:x}", hasher.finalize());
// Set shared state for main thread
let mut shared = shared_reader.lock().unwrap();
let mut shared = shared_reader.lock().expect("client save mutex poisoned");
*shared = (total_bytes, digest.clone());
Ok((total_bytes, digest))
@@ -135,7 +135,7 @@ pub fn mode(
// Read results from shared state
let (uncompressed_size, digest) = {
let shared = shared.lock().unwrap();
let shared = shared.lock().expect("client save mutex poisoned");
shared.clone()
};

View File

@@ -162,7 +162,7 @@ fn show_item(
item_path_buf.push(item_id.to_string());
let path_str = item_path_buf
.to_str()
.expect("Unable to get item path")
.ok_or_else(|| anyhow::anyhow!("non-UTF-8 item path"))?
.to_string();
table.add_row(vec![
Cell::new("Path").add_attribute(Attribute::Bold),

View File

@@ -240,7 +240,11 @@ pub fn mode_list(
Ok(metadata) => format_size(metadata.len(), settings.human_readable),
Err(_) => "Missing".to_string(),
},
ColumnType::FilePath => item_path.clone().into_os_string().into_string().unwrap(),
ColumnType::FilePath => item_path
.clone()
.into_os_string()
.into_string()
.unwrap_or_else(|os| os.to_string_lossy().into_owned()),
ColumnType::Tags => tags.join(" "),
ColumnType::Meta => match meta_name {
Some(meta_name) => match meta.get(meta_name) {

View File

@@ -16,7 +16,9 @@ use axum::{
use http_body_util::BodyExt;
use log::{debug, warn};
use std::collections::HashMap;
use std::io::{Cursor, Read};
use std::io::Read;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use tokio::sync::mpsc;
use tokio::task;
@@ -24,14 +26,15 @@ use tokio::task;
///
/// Used in `spawn_blocking` contexts to consume data from an async body
/// stream as a regular reader. Blocks on each `read()` call until data
/// is available or the channel is closed.
struct ChannelReader {
/// is available or the channel is closed. No internal buffering — each
/// channel message is consumed fully before the next is fetched.
struct MpscReader {
rx: mpsc::Receiver<Result<Vec<u8>, std::io::Error>>,
current: Vec<u8>,
pos: usize,
}
impl ChannelReader {
impl MpscReader {
fn new(rx: mpsc::Receiver<Result<Vec<u8>, std::io::Error>>) -> Self {
Self {
rx,
@@ -41,9 +44,8 @@ impl ChannelReader {
}
}
impl Read for ChannelReader {
impl Read for MpscReader {
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
// If we have buffered data, return it first
if self.pos < self.current.len() {
let remaining = &self.current[self.pos..];
let n = std::cmp::min(buf.len(), remaining.len());
@@ -52,7 +54,6 @@ impl Read for ChannelReader {
return Ok(n);
}
// Need more data from the channel - block until available
match self.rx.blocking_recv() {
Some(Ok(data)) => {
let n = std::cmp::min(buf.len(), data.len());
@@ -64,7 +65,7 @@ impl Read for ChannelReader {
Ok(n)
}
Some(Err(e)) => Err(e),
None => Ok(0), // Channel closed, EOF
None => Ok(0),
}
}
}
@@ -116,23 +117,6 @@ fn get_mime_type(metadata: &HashMap<String, String>) -> String {
.unwrap_or_else(|| "application/octet-stream".to_string())
}
/// Helper function to apply offset and length to content
fn apply_offset_length(content: &[u8], offset: u64, length: u64) -> &[u8] {
let content_len = content.len() as u64;
let start = std::cmp::min(offset, content_len);
let end = if length > 0 {
std::cmp::min(start + length, content_len)
} else {
content_len
};
if start < content_len {
&content[start as usize..end as usize]
} else {
&[]
}
}
/// Helper function to handle item not found errors
fn handle_item_error(error: CoreError) -> StatusCode {
match error {
@@ -203,7 +187,7 @@ pub async fn handle_list_items(
// Apply pagination
let start = params.start.unwrap_or(0) as usize;
let count = params.count.unwrap_or(100) as usize;
let count = params.count.unwrap_or(100).min(10000) as usize;
let items_with_meta: Vec<_> = items_with_meta
.into_iter()
.skip(start)
@@ -253,7 +237,7 @@ async fn handle_as_meta_response(
handle_as_meta_response_with_metadata(data_service, item_id, &metadata, offset, length).await
}
/// Handle as_meta=true response with pre-fetched metadata
/// Handle as_meta=true response with pre-fetched metadata using streaming
async fn handle_as_meta_response_with_metadata(
data_service: &AsyncDataService,
item_id: i64,
@@ -279,61 +263,121 @@ async fn handle_as_meta_response_with_metadata(
.body(axum::body::Body::from(response_body.to_string()))
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
} else {
// Get the content as text
match data_service.get_item_content_info(item_id, None).await {
Ok((content, _, _)) => {
// Apply offset and length
let content_len = content.len() as u64;
let start = std::cmp::min(offset, content_len);
let end = if length > 0 {
std::cmp::min(start + length, content_len)
} else {
content_len
};
// Use streaming approach to read content
let db = data_service.db();
let data_path = data_service.data_path().clone();
let settings = data_service.settings();
let response_content = if start < content_len {
&content[start as usize..end as usize]
} else {
&[]
};
// Get streaming reader from sync service
let (reader, content_len_result) = tokio::task::spawn_blocking(move || {
let mut conn = db.blocking_lock();
let sync_service =
crate::services::SyncDataService::new(data_path, settings.as_ref().clone());
let (reader, item_with_meta) = sync_service.get_content(&mut conn, item_id)?;
let content_len = item_with_meta.item.size.unwrap_or(0);
Ok::<_, CoreError>((reader, content_len))
})
.await
.map_err(|e| {
warn!("Blocking task failed for item {item_id}: {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?
.map_err(|e| {
warn!("Failed to get content reader for item {item_id}: {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?;
// Convert to UTF-8 string
let content_str = match String::from_utf8(response_content.to_vec()) {
Ok(s) => s,
Err(_) => {
// This shouldn't happen since we checked is_binary, but handle it just in case
let response_body = serde_json::json!({
"metadata": metadata,
"content": serde_json::Value::Null,
"error": "Content is not valid UTF-8"
});
let content_len = content_len_result as u64;
let response = Response::builder()
.header(header::CONTENT_TYPE, "application/json")
.status(StatusCode::UNPROCESSABLE_ENTITY)
.body(axum::body::Body::from(response_body.to_string()))
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
return Ok(response);
// Calculate offset and length bounds
let start = std::cmp::min(offset, content_len);
let end = if length > 0 {
std::cmp::min(start + length, content_len)
} else {
content_len
};
let response_len = end - start;
// Read content with offset and length using fixed-size buffer
let content_str = tokio::task::spawn_blocking(move || {
let mut reader = reader;
let mut buf = [0u8; crate::common::PIPESIZE];
let mut result = Vec::new();
let mut bytes_read = 0u64;
// Skip offset bytes
if offset > 0 {
let mut remaining = offset;
while remaining > 0 {
let to_read = std::cmp::min(remaining, buf.len() as u64) as usize;
match reader.read(&mut buf[..to_read]) {
Ok(0) => break,
Ok(n) => remaining -= n as u64,
Err(e) => return Err(CoreError::Io(e)),
}
};
}
}
// Return JSON with metadata and content
// Read up to length bytes
let mut remaining = if length > 0 { length } else { u64::MAX };
while remaining > 0 && bytes_read < response_len {
let to_read = std::cmp::min(remaining, buf.len() as u64) as usize;
match reader.read(&mut buf[..to_read]) {
Ok(0) => break,
Ok(n) => {
result.extend_from_slice(&buf[..n]);
bytes_read += n as u64;
if length > 0 {
remaining -= n as u64;
}
}
Err(e) => return Err(CoreError::Io(e)),
}
}
Ok::<Vec<u8>, CoreError>(result)
})
.await
.map_err(|e| {
warn!("Blocking task failed for item {item_id}: {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?
.map_err(|e| {
warn!("Failed to read content for item {item_id}: {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?;
// Convert to UTF-8 string
let content_str = match String::from_utf8(content_str) {
Ok(s) => s,
Err(_) => {
// This shouldn't happen since we checked is_binary, but handle it just in case
let response_body = serde_json::json!({
"metadata": metadata,
"content": content_str,
"error": serde_json::Value::Null
"content": serde_json::Value::Null,
"error": "Content is not valid UTF-8"
});
Response::builder()
let response = Response::builder()
.header(header::CONTENT_TYPE, "application/json")
.status(StatusCode::UNPROCESSABLE_ENTITY)
.body(axum::body::Body::from(response_body.to_string()))
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
return Ok(response);
}
Err(e) => {
warn!("Failed to get content for item {item_id}: {e}");
Err(StatusCode::INTERNAL_SERVER_ERROR)
}
}
};
// Return JSON with metadata and content
let response_body = serde_json::json!({
"metadata": metadata,
"content": content_str,
"error": serde_json::Value::Null
});
Response::builder()
.header(header::CONTENT_TYPE, "application/json")
.body(axum::body::Body::from(response_body.to_string()))
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
}
}
@@ -380,6 +424,14 @@ pub async fn handle_post_item(
.map(parse_comma_tags)
.unwrap_or_default();
// Validate tag lengths
for tag in &tags {
if tag.len() > 256 {
warn!("Tag too long: {} chars (max 256)", tag.len());
return Err(StatusCode::BAD_REQUEST);
}
}
// Parse metadata from query parameter
let metadata: HashMap<String, String> = if let Some(ref meta_str) = params.metadata {
serde_json::from_str(meta_str).map_err(|e| {
@@ -390,91 +442,116 @@ pub async fn handle_post_item(
HashMap::new()
};
// Validate metadata key/value lengths
for (key, value) in &metadata {
if key.len() > 128 {
warn!("Metadata key too long: {} chars (max 128)", key.len());
return Err(StatusCode::BAD_REQUEST);
}
if value.len() > 4096 {
warn!("Metadata value too long: {} chars (max 4096)", value.len());
return Err(StatusCode::BAD_REQUEST);
}
}
let compress = params.compress;
let run_meta = params.meta;
// When server handles both compression and meta, save_item_with_reader
// buffers internally anyway, so collect body in memory.
// When client handles compression/meta, stream the body to avoid buffering.
let item_with_meta = if compress && run_meta {
let body_bytes = body
.collect()
.await
.map_err(|e| {
warn!("Failed to read request body: {e}");
StatusCode::BAD_REQUEST
})?
.to_bytes();
// Stream body through an mpsc channel with fixed-size frames.
// Size tracking ensures we never buffer the whole body in memory.
// Treat Some(0) as unlimited (None).
let max_body_size: Option<u64> = settings
.server
.as_ref()
.and_then(|s| s.max_body_size)
.filter(|&v| v > 0);
task::spawn_blocking(move || {
let mut conn = db.blocking_lock();
let sync_service =
crate::services::SyncDataService::new(data_dir, settings.as_ref().clone());
let mut cursor = Cursor::new(body_bytes.to_vec());
sync_service.save_item_with_reader(&mut conn, &mut cursor, tags, metadata)
})
.await
.map_err(|e| {
warn!("Failed to save item (task error): {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?
.map_err(|e| {
warn!("Failed to save item: {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?
} else {
// Stream body through a channel to avoid buffering in memory
let (tx, rx) = tokio::sync::mpsc::channel::<Result<Vec<u8>, std::io::Error>>(16);
let (tx, rx) = mpsc::channel::<Result<Vec<u8>, std::io::Error>>(16);
let body_truncated = Arc::new(AtomicBool::new(false));
let truncated_flag = body_truncated.clone();
// Task to read body frames and send through channel
tokio::spawn(async move {
let mut body = body;
loop {
match body.frame().await {
None => break, // Body complete
Some(Err(e)) => {
let _ = tx
.send(Err(std::io::Error::other(format!("Body error: {e}"))))
.await;
break;
}
Some(Ok(frame)) => {
if let Ok(data) = frame.into_data() {
if tx.send(Ok(data.to_vec())).await.is_err() {
break; // Receiver dropped
}
// Async task: read body frames, track size, stop when limit exceeded
tokio::spawn(async move {
let mut body = body;
let mut total_bytes: u64 = 0;
loop {
match body.frame().await {
None => break,
Some(Err(e)) => {
let _ = tx
.send(Err(std::io::Error::other(format!("Body error: {e}"))))
.await;
break;
}
Some(Ok(frame)) => {
if let Ok(data) = frame.into_data() {
total_bytes += data.len() as u64;
if let Some(limit) = max_body_size
&& total_bytes > limit
{
truncated_flag.store(true, Ordering::Relaxed);
break; // Drop sender → reader sees EOF
}
if tx.send(Ok(data.to_vec())).await.is_err() {
break;
}
}
}
}
});
}
});
task::spawn_blocking(move || {
let mut conn = db.blocking_lock();
let sync_service =
crate::services::SyncDataService::new(data_dir, settings.as_ref().clone());
// Blocking task: consume streaming reader, save via save_item_raw_streaming
let item_with_meta = task::spawn_blocking(move || {
let mut conn = db.blocking_lock();
let sync_service =
crate::services::SyncDataService::new(data_dir, settings.as_ref().clone());
let mut reader = MpscReader::new(rx);
sync_service.save_item_raw_streaming(
&mut conn,
&mut reader,
tags,
metadata,
compress,
run_meta,
)
})
.await
.map_err(|e| {
warn!("Failed to save item (task error): {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?
.map_err(|e| {
warn!("Failed to save item: {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?;
// Convert async mpsc receiver into a sync Read
let mut stream_reader = ChannelReader::new(rx);
sync_service.save_item_raw_streaming(
&mut conn,
&mut stream_reader,
tags,
metadata,
compress,
run_meta,
)
})
.await
.map_err(|e| {
warn!("Failed to save item (task error): {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?
.map_err(|e| {
warn!("Failed to save item: {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?
};
// If body was truncated due to size limit, return 413 but keep the partial item.
//
// Design rationale for preserving truncated data:
//
// 1. Keep is a temporary file manager — partial content is still useful data.
// The client sent it, the server received it, and it was written through the
// normal save pipeline (compression, meta plugins). Discarding it would waste
// work already done and lose data the user may want to inspect or retry from.
//
// 2. The streaming pipeline is nonfatal by design. The sender stops pushing
// frames, the channel closes cleanly (sender drop → reader sees EOF), and
// save_item_raw_streaming finishes with whatever it received. No corruption,
// no partial writes to the storage file — the compression engine flushes and
// closes normally.
//
// 3. No destructive rollback. Deleting the item would require either a second
// DB transaction (fragile under concurrency) or leaving orphaned files on
// disk. Keeping the item is simpler and safer.
//
// 4. The 413 response tells the client exactly what happened. The client can
// use GET /api/item/ to list items and find the partial one by timestamp
// if it needs to resume or clean up.
//
if body_truncated.load(Ordering::Relaxed) {
return Err(StatusCode::PAYLOAD_TOO_LARGE);
}
let compression = item_with_meta.item.compression.clone();
let tags = item_with_meta.tags.iter().map(|t| t.name.clone()).collect();
@@ -696,37 +773,110 @@ async fn stream_item_content_response(
/// Stream raw (unprocessed) content directly from the item file.
///
/// Returns the stored file bytes without decompression or filtering.
/// Returns the stored file bytes without decompression or filtering using streaming.
async fn stream_raw_content_response(
data_service: &AsyncDataService,
item_id: i64,
offset: u64,
length: u64,
) -> Result<Response, StatusCode> {
// Get item info to find the file path and compression type
// Get item info to find the compression type
let item_with_meta = data_service.get_item(item_id).await.map_err(|e| {
warn!("Failed to get item {item_id} for raw content: {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?;
let compression = item_with_meta.item.compression.clone();
let content_size = item_with_meta.item.size.unwrap_or(0);
// Read raw file bytes
let content = data_service
.get_raw_item_content(item_id)
// Get streaming reader for raw content
let reader = data_service
.get_raw_item_content_reader(item_id)
.await
.map_err(|e| {
warn!("Failed to get raw content for item {item_id}: {e}");
warn!("Failed to get raw content reader for item {item_id}: {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?;
let response_content = apply_offset_length(&content, offset, length);
// Calculate the actual response length
let content_len = content_size as u64;
let start = std::cmp::min(offset, content_len);
let end = if length > 0 {
std::cmp::min(start + length, content_len)
} else {
content_len
};
let response_len = end - start;
// Create a channel to stream data between blocking thread and async runtime
let (tx, rx) = tokio::sync::mpsc::channel::<Result<Vec<u8>, std::io::Error>>(16);
// Spawn blocking task to read with offset and length
tokio::task::spawn_blocking(move || {
let mut reader = reader;
let mut buf = [0u8; crate::common::PIPESIZE];
// Apply offset by reading and discarding bytes
if offset > 0 {
let mut remaining = offset;
while remaining > 0 {
let to_read = std::cmp::min(remaining, buf.len() as u64) as usize;
match reader.read(&mut buf[..to_read]) {
Ok(0) => break,
Ok(n) => remaining -= n as u64,
Err(e) => {
let _ = tx.blocking_send(Err(e));
return;
}
}
}
}
// Read and send data up to the specified length
let mut remaining_length = length;
loop {
let to_read = if length > 0 {
std::cmp::min(remaining_length, buf.len() as u64) as usize
} else {
buf.len()
};
if to_read == 0 {
break;
}
match reader.read(&mut buf[..to_read]) {
Ok(0) => break,
Ok(n) => {
let chunk = buf[..n].to_vec();
if tx.blocking_send(Ok(chunk)).is_err() {
break;
}
if length > 0 {
remaining_length -= n as u64;
if remaining_length == 0 {
break;
}
}
}
Err(e) => {
let _ = tx.blocking_send(Err(e));
break;
}
}
}
});
// Convert the receiver into a stream
let stream = tokio_stream::wrappers::ReceiverStream::new(rx);
let body = axum::body::Body::from_stream(stream);
let response = Response::builder()
.header(header::CONTENT_TYPE, "application/octet-stream")
.header("X-Keep-Compression", &compression)
.header(header::CONTENT_LENGTH, response_content.len())
.body(axum::body::Body::from(response_content.to_vec()))
.header(header::CONTENT_LENGTH, response_len)
.body(body)
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
Ok(response)
@@ -776,24 +926,94 @@ async fn stream_item_content_response_with_metadata(
}
}
} else {
debug!("NON-STREAMING: Building full response in memory");
match data_service.get_item_content_info(item_id, None).await {
Ok((content, _, _)) => {
let response_content = apply_offset_length(&content, offset, length);
debug!("NON-STREAMING: Building full response in memory using streaming reader");
// Use streaming approach even for non-streaming response
let db = data_service.db();
let data_path = data_service.data_path().clone();
let settings = data_service.settings();
debug!(
"NON-STREAMING: Content length: {}, response length: {}",
content.len(),
response_content.len()
);
// Get streaming reader from sync service
let (reader, content_len_result) = tokio::task::spawn_blocking(move || {
let mut conn = db.blocking_lock();
let sync_service =
crate::services::SyncDataService::new(data_path, settings.as_ref().clone());
let (reader, item_with_meta) = sync_service.get_content(&mut conn, item_id)?;
let content_len = item_with_meta.item.size.unwrap_or(0);
Ok::<_, CoreError>((reader, content_len))
})
.await
.map_err(|e| {
warn!("Blocking task failed for item {item_id}: {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?
.map_err(|e| {
warn!("Failed to get content reader for item {item_id}: {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?;
ResponseBuilder::binary(response_content, &mime_type)
let content_len = content_len_result as u64;
// Calculate offset and length bounds
let start = std::cmp::min(offset, content_len);
let end = if length > 0 {
std::cmp::min(start + length, content_len)
} else {
content_len
};
let response_len = end - start;
// Read content with offset and length using fixed-size buffer
let content = tokio::task::spawn_blocking(move || {
let mut reader = reader;
let mut buf = [0u8; crate::common::PIPESIZE];
let mut result = Vec::new();
let mut bytes_read = 0u64;
// Skip offset bytes
if offset > 0 {
let mut remaining = offset;
while remaining > 0 {
let to_read = std::cmp::min(remaining, buf.len() as u64) as usize;
match reader.read(&mut buf[..to_read]) {
Ok(0) => break,
Ok(n) => remaining -= n as u64,
Err(e) => return Err(CoreError::Io(e)),
}
}
}
Err(e) => {
warn!("Failed to get content for item {item_id}: {e}");
Err(StatusCode::INTERNAL_SERVER_ERROR)
// Read up to length bytes
let mut remaining = if length > 0 { length } else { u64::MAX };
while remaining > 0 && bytes_read < response_len {
let to_read = std::cmp::min(remaining, buf.len() as u64) as usize;
match reader.read(&mut buf[..to_read]) {
Ok(0) => break,
Ok(n) => {
result.extend_from_slice(&buf[..n]);
bytes_read += n as u64;
if length > 0 {
remaining -= n as u64;
}
}
Err(e) => return Err(CoreError::Io(e)),
}
}
}
Ok::<Vec<u8>, CoreError>(result)
})
.await
.map_err(|e| {
warn!("Blocking task failed for item {item_id}: {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?
.map_err(|e| {
warn!("Failed to read content for item {item_id}: {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?;
debug!("NON-STREAMING: Response length: {}", content.len());
ResponseBuilder::binary(&content, &mime_type)
}
}
@@ -870,6 +1090,10 @@ pub async fn handle_get_item_meta(
State(state): State<AppState>,
Path(item_id): Path<i64>,
) -> Result<Json<ApiResponse<HashMap<String, String>>>, StatusCode> {
if item_id <= 0 {
return Err(StatusCode::BAD_REQUEST);
}
let data_service = create_data_service(&state);
match data_service.get_item(item_id).await {
@@ -893,6 +1117,10 @@ pub async fn handle_post_item_meta(
Path(item_id): Path<i64>,
Json(metadata): Json<HashMap<String, String>>,
) -> Result<Json<ApiResponse<()>>, StatusCode> {
if item_id <= 0 {
return Err(StatusCode::BAD_REQUEST);
}
let data_service = create_data_service(&state);
// Verify item exists
@@ -942,6 +1170,10 @@ pub async fn handle_delete_item(
State(state): State<AppState>,
Path(item_id): Path<i64>,
) -> Result<Json<ApiResponse<ItemInfo>>, StatusCode> {
if item_id <= 0 {
return Err(StatusCode::BAD_REQUEST);
}
let mut conn = state.db.lock().await;
let sync_service = crate::services::SyncDataService::new(
@@ -995,6 +1227,10 @@ pub async fn handle_get_item_info(
State(state): State<AppState>,
Path(item_id): Path<i64>,
) -> Result<Json<ApiResponse<ItemInfo>>, StatusCode> {
if item_id <= 0 {
return Err(StatusCode::BAD_REQUEST);
}
let mut conn = state.db.lock().await;
let sync_service = crate::services::SyncDataService::new(
@@ -1097,6 +1333,28 @@ pub async fn handle_diff_items(
let id_a = item_a.item.id.ok_or(StatusCode::BAD_REQUEST)?;
let id_b = item_b.item.id.ok_or(StatusCode::BAD_REQUEST)?;
// Size limit for diff operation (10MB per item)
const MAX_DIFF_SIZE: i64 = 10 * 1024 * 1024;
if let Some(size_a) = item_a.item.size
&& size_a > MAX_DIFF_SIZE
{
warn!(
"Item A ({}) exceeds diff size limit: {} > {}",
id_a, size_a, MAX_DIFF_SIZE
);
return Err(StatusCode::PAYLOAD_TOO_LARGE);
}
if let Some(size_b) = item_b.item.size
&& size_b > MAX_DIFF_SIZE
{
warn!(
"Item B ({}) exceeds diff size limit: {} > {}",
id_b, size_b, MAX_DIFF_SIZE
);
return Err(StatusCode::PAYLOAD_TOO_LARGE);
}
let (mut reader_a, _) = sync_service
.get_content(&mut conn, id_a)
.map_err(handle_item_error)?;

View File

@@ -1,72 +0,0 @@
use axum::{
extract::State,
http::StatusCode,
response::sse::{Event, KeepAlive, Sse},
};
use futures::stream::{self, Stream};
use log::{debug, info};
use std::convert::Infallible;
use std::time::Duration;
use crate::modes::server::common::AppState;
use crate::modes::server::mcp::KeepMcpServer;
#[utoipa::path(
get,
path = "/mcp/sse",
operation_id = "mcp_sse",
summary = "MCP SSE endpoint",
description = "Server-Sent Events for Model Context Protocol. Enables AI tools to interact with Keep's storage and retrieval functions.",
responses(
(status = 200, description = "SSE stream established"),
(status = 401, description = "Unauthorized"),
(status = 500, description = "Internal server error")
),
security(
("bearerAuth" = [])
),
tag = "mcp"
)]
pub async fn handle_mcp_sse(
State(state): State<AppState>,
) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, StatusCode> {
debug!("MCP: Starting SSE endpoint");
let _mcp_server = KeepMcpServer::new(state);
// Create a simple message channel for SSE communication
let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<String>();
// Send initial connection message
let _ = tx.send("data: {\"type\":\"connection\",\"status\":\"connected\"}\n\n".to_string());
// For now, create a simple stream that sends periodic keep-alive messages
// In a full implementation, this would integrate with the rmcp transport layer
let stream = stream::unfold((rx, tx), |(mut rx, tx)| async move {
tokio::select! {
msg = rx.recv() => {
match msg {
Some(data) => {
let event = Event::default().data(data);
Some((Ok(event), (rx, tx)))
}
None => None,
}
}
_ = tokio::time::sleep(Duration::from_secs(30)) => {
let event = Event::default()
.event("keep-alive")
.data("ping");
Some((Ok(event), (rx, tx)))
}
}
});
info!("MCP: SSE endpoint established");
Ok(Sse::new(stream).keep_alive(
KeepAlive::new()
.interval(Duration::from_secs(30))
.text("keep-alive"),
))
}

View File

@@ -1,7 +1,5 @@
pub mod common;
pub mod item;
#[cfg(feature = "mcp")]
pub mod mcp;
pub mod status;
use axum::{
@@ -60,8 +58,7 @@ use utoipa_swagger_ui::SwaggerUi;
struct ApiDoc;
pub fn add_routes(router: Router<AppState>) -> Router<AppState> {
#[cfg_attr(not(feature = "mcp"), allow(unused_mut))]
let mut router = router
router
// Status endpoints
.route("/api/status", get(status::handle_status))
.route("/api/plugins/status", get(status::handle_plugins_status))
@@ -88,14 +85,7 @@ pub fn add_routes(router: Router<AppState>) -> Router<AppState> {
)
.route("/api/item/{item_id}", delete(item::handle_delete_item))
.route("/api/item/{item_id}/info", get(item::handle_get_item_info))
.route("/api/diff", get(item::handle_diff_items));
#[cfg(feature = "mcp")]
{
router = router.route("/mcp/sse", get(mcp::handle_mcp_sse));
}
router
.route("/api/diff", get(item::handle_diff_items))
}
#[cfg(feature = "swagger")]

View File

@@ -39,7 +39,7 @@ use crate::modes::server::common::{ApiResponse, AppState, StatusInfoResponse};
///
/// # Examples
///
/// ```
/// ```ignore
/// // In an Axum app:
/// async fn app() -> Result<Json<StatusInfoResponse>, StatusCode> {
/// handle_status(State(app_state)).await
@@ -60,12 +60,17 @@ pub async fn handle_status(
// Use the status service to generate status info showing configured plugins
let status_service = crate::services::status_service::StatusService::new();
let mut cmd = state.cmd.lock().await;
let status_info = status_service.generate_status(
&mut cmd,
&state.settings,
state.data_dir.clone(),
db_path.into(),
);
let status_info = status_service
.generate_status(
&mut cmd,
&state.settings,
state.data_dir.clone(),
db_path.into(),
)
.map_err(|e| {
log::warn!("Failed to generate status: {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?;
let response = StatusInfoResponse {
success: true,
@@ -112,12 +117,17 @@ pub async fn handle_plugins_status(
let status_service = crate::services::status_service::StatusService::new();
let mut cmd = state.cmd.lock().await;
let status_info = status_service.generate_status(
&mut cmd,
&state.settings,
state.data_dir.clone(),
db_path.into(),
);
let status_info = status_service
.generate_status(
&mut cmd,
&state.settings,
state.data_dir.clone(),
db_path.into(),
)
.map_err(|e| {
log::warn!("Failed to generate status: {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?;
let response_data = PluginsStatusResponse {
meta_plugins: status_info.meta_plugins,

View File

@@ -726,34 +726,33 @@ fn check_basic_auth(
}
let encoded = &auth_str[6..];
if let Ok(decoded_bytes) = base64::engine::general_purpose::STANDARD.decode(encoded) {
if let Ok(decoded_str) = String::from_utf8(decoded_bytes) {
if let Some(colon_pos) = decoded_str.find(':') {
let provided_username = &decoded_str[..colon_pos];
let provided_password = &decoded_str[colon_pos + 1..];
if let Ok(decoded_bytes) = base64::engine::general_purpose::STANDARD.decode(encoded)
&& let Ok(decoded_str) = String::from_utf8(decoded_bytes)
&& let Some(colon_pos) = decoded_str.find(':')
{
let provided_username = &decoded_str[..colon_pos];
let provided_password = &decoded_str[colon_pos + 1..];
// Check username with constant-time comparison
if !bool::from(
provided_username
.as_bytes()
.ct_eq(expected_username.as_bytes()),
) {
return false;
}
// If we have a password hash, verify against it
if let Some(hash) = expected_hash {
return pwhash::unix::verify(provided_password, hash);
}
// Otherwise, do constant-time comparison to prevent timing attacks
return bool::from(
provided_password
.as_bytes()
.ct_eq(expected_password.as_bytes()),
);
}
// Check username with constant-time comparison
if !bool::from(
provided_username
.as_bytes()
.ct_eq(expected_username.as_bytes()),
) {
return false;
}
// If we have a password hash, verify against it
if let Some(hash) = expected_hash {
return pwhash::unix::verify(provided_password, hash);
}
// Otherwise, do constant-time comparison to prevent timing attacks
return bool::from(
provided_password
.as_bytes()
.ct_eq(expected_password.as_bytes()),
);
}
false
}

View File

@@ -1,83 +0,0 @@
pub mod server;
pub mod tools;
pub use server::KeepMcpServer;
/// Module for handling MCP (Model Context Protocol) requests in the server.
///
/// Provides handlers for JSON-RPC style requests to interact with Keep's storage
/// via the API.
use axum::{Json, extract::State, http::StatusCode, response::IntoResponse};
use serde::Deserialize;
use serde_json::Value;
use crate::modes::server::common::ApiResponse;
use crate::modes::server::common::AppState;
/// Request structure for MCP JSON-RPC calls.
///
/// # Fields
///
/// * `method` - The MCP method name (e.g., "save_item").
/// * `params` - Optional JSON parameters for the method.
#[derive(Deserialize)]
pub struct McpRequest {
pub method: String,
pub params: Option<Value>,
}
/// Axum handler for MCP JSON-RPC style requests.
///
/// Delegates the call to `KeepMcpServer` and wraps the outcome in an
/// `ApiResponse`. A successful result is returned as parsed JSON when the
/// tool output is valid JSON, otherwise as a plain string payload.
///
/// # Arguments
///
/// * `State(state)` - Shared application state.
/// * `Json(request)` - Deserialized MCP request body.
///
/// # Returns
///
/// `200 OK` with the tool result on success, `400 Bad Request` with the
/// error message on failure.
pub async fn handle_mcp_request(
    State(state): State<AppState>,
    Json(request): Json<McpRequest>,
) -> impl IntoResponse {
    let server = KeepMcpServer::new(state);
    let outcome = server
        .handle_request(&request.method, request.params)
        .await;
    match outcome {
        Ok(result) => {
            // Prefer structured JSON; fall back to the raw string when the
            // tool output does not parse.
            let data = serde_json::from_str(&result)
                .unwrap_or_else(|_| serde_json::Value::String(result));
            let body = ApiResponse {
                success: true,
                data: Some(data),
                error: None,
            };
            (StatusCode::OK, Json(body))
        }
        Err(e) => {
            let body = ApiResponse {
                success: false,
                data: None,
                error: Some(e.to_string()),
            };
            (StatusCode::BAD_REQUEST, Json(body))
        }
    }
}

View File

@@ -1,83 +0,0 @@
use log::debug;
use serde_json::Value;
use super::tools::{KeepTools, ToolError};
use crate::modes::server::common::AppState;
/// Server handler for MCP (Model Context Protocol) requests.
///
/// Routes requests to appropriate tools and handles responses. Clones AppState for tool usage.
///
/// # Fields
///
/// * `state` - The shared application state (DB, config, etc.).
#[derive(Clone)]
pub struct KeepMcpServer {
    // Shared application state (DB handle, config, services).
    state: AppState,
}
/// Creates a new `KeepMcpServer` instance.
///
/// # Arguments
///
/// * `state` - The application state containing DB, config, and services.
///
/// # Returns
///
/// A new `KeepMcpServer` instance.
///
/// # Examples
///
/// ```
/// let server = KeepMcpServer::new(app_state);
/// ```
impl KeepMcpServer {
pub fn new(state: AppState) -> Self {
Self { state }
}
/// Handles an MCP request by routing to the appropriate tool.
///
/// Supports methods like "save_item", "get_item", "list_items". Logs the request and delegates to KeepTools.
///
/// # Arguments
///
/// * `method` - The MCP method name (string).
/// * `params` - Optional JSON parameters as serde_json::Value.
///
/// # Returns
///
/// `Ok(String)` with JSON-serialized response on success, or `Err(ToolError)` on failure.
///
/// # Errors
///
/// * ToolError::UnknownTool if method unsupported.
/// * Propagates tool-specific errors (e.g., invalid args, DB failures).
///
/// # Examples
///
/// ```
/// let result = server.handle_request("save_item", Some(params)).await?;
/// ```
pub async fn handle_request(
&self,
method: &str,
params: Option<Value>,
) -> Result<String, ToolError> {
debug!(
"MCP: Handling request '{}' with params: {:?}",
method, params
);
let tools = KeepTools::new(self.state.clone());
match method {
"save_item" => tools.save_item(params).await,
"get_item" => tools.get_item(params).await,
"get_latest_item" => tools.get_latest_item(params).await,
"list_items" => tools.list_items(params).await,
"search_items" => tools.search_items(params).await,
_ => Err(ToolError::UnknownTool(method.to_string())),
}
}
}

View File

@@ -1,344 +0,0 @@
use anyhow::{Result, anyhow};
use log::debug;
use serde_json::Value;
use std::collections::HashMap;
use crate::modes::server::common::AppState;
use crate::services::async_item_service::AsyncItemService;
use crate::services::error::CoreError;
/// Errors that can occur while executing MCP tools.
#[derive(Debug, thiserror::Error)]
pub enum ToolError {
    /// The requested MCP method name is not recognized.
    #[error("Unknown tool: {0}")]
    UnknownTool(String),
    /// The request arguments were missing or malformed.
    #[error("Invalid arguments: {0}")]
    InvalidArguments(String),
    /// Underlying SQLite error.
    #[error("Database error: {0}")]
    Database(#[from] rusqlite::Error),
    /// Filesystem / stream I/O error.
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),
    /// JSON (de)serialization error.
    #[error("JSON error: {0}")]
    Json(#[from] serde_json::Error),
    /// Enum/string parsing error from `strum`.
    #[error("Parse error: {0}")]
    Parse(#[from] strum::ParseError),
    /// Catch-all wrapper for any other failure.
    #[error("Other error: {0}")]
    Other(#[from] anyhow::Error),
}
/// Implementations of the individual MCP tools (save/get/list/search).
pub struct KeepTools {
    // Shared application state used to build a per-call item service.
    state: AppState,
}
impl KeepTools {
pub fn new(state: AppState) -> Self {
Self { state }
}
pub async fn save_item(&self, args: Option<Value>) -> Result<String, ToolError> {
let args =
args.ok_or_else(|| ToolError::InvalidArguments("Missing arguments".to_string()))?;
let content = args
.get("content")
.and_then(|v| v.as_str())
.ok_or_else(|| ToolError::InvalidArguments("Missing 'content' field".to_string()))?;
let tags: Vec<String> = args
.get("tags")
.and_then(|v| v.as_array())
.map(|arr| {
arr.iter()
.filter_map(|v| v.as_str().map(|s| s.to_string()))
.collect()
})
.unwrap_or_default();
let metadata: HashMap<String, String> = args
.get("metadata")
.and_then(|v| v.as_object())
.map(|obj| {
obj.iter()
.filter_map(|(k, v)| v.as_str().map(|s| (k.clone(), s.to_string())))
.collect()
})
.unwrap_or_default();
debug!(
"MCP: Saving item with {} bytes, {} tags, {} metadata entries",
content.len(),
tags.len(),
metadata.len()
);
let service = AsyncItemService::new(
self.state.data_dir.clone(),
self.state.db.clone(),
self.state.item_service.clone(),
self.state.cmd.clone(),
self.state.settings.clone(),
);
let item_with_meta = service
.save_item_from_mcp(content.as_bytes().to_vec(), tags, metadata)
.await
.map_err(|e| ToolError::Other(anyhow::Error::from(e)))?;
let item_id = item_with_meta
.item
.id
.ok_or_else(|| anyhow!("Failed to get item ID"))?;
Ok(format!("Successfully saved item with ID: {}", item_id))
}
pub async fn get_item(&self, args: Option<Value>) -> Result<String, ToolError> {
let args =
args.ok_or_else(|| ToolError::InvalidArguments("Missing arguments".to_string()))?;
let item_id = args.get("id").and_then(|v| v.as_i64()).ok_or_else(|| {
ToolError::InvalidArguments("Missing or invalid 'id' field".to_string())
})?;
let service = AsyncItemService::new(
self.state.data_dir.clone(),
self.state.db.clone(),
self.state.item_service.clone(),
self.state.cmd.clone(),
self.state.settings.clone(),
);
let item_with_content = match service.get_item_content(item_id).await {
Ok(iwc) => iwc,
Err(CoreError::ItemNotFound(_)) => {
return Err(ToolError::InvalidArguments(format!(
"Item {} not found",
item_id
)));
}
Err(e) => return Err(ToolError::Other(anyhow::Error::from(e))),
};
let content = String::from_utf8_lossy(&item_with_content.content).to_string();
let tags: Vec<String> = item_with_content
.item_with_meta
.tags
.iter()
.map(|t| t.name.clone())
.collect();
let metadata = item_with_content.item_with_meta.meta_as_map();
let item = item_with_content.item_with_meta.item;
let response = serde_json::json!({
"id": item_id,
"content": content,
"timestamp": item.ts.to_rfc3339(),
"size": item.size,
"compression": item.compression,
"tags": tags,
"metadata": metadata,
});
Ok(serde_json::to_string_pretty(&response)?)
}
pub async fn get_latest_item(&self, args: Option<Value>) -> Result<String, ToolError> {
let tags: Vec<String> = args
.as_ref()
.and_then(|v| v.get("tags"))
.and_then(|v| v.as_array())
.map(|arr| {
arr.iter()
.filter_map(|v| v.as_str().map(|s| s.to_string()))
.collect()
})
.unwrap_or_default();
let service = AsyncItemService::new(
self.state.data_dir.clone(),
self.state.db.clone(),
self.state.item_service.clone(),
self.state.cmd.clone(),
self.state.settings.clone(),
);
let item_with_meta = match service.find_item(vec![], tags, HashMap::new()).await {
Ok(iwm) => iwm,
Err(CoreError::ItemNotFoundGeneric) => {
return Err(ToolError::InvalidArguments("No items found".to_string()));
}
Err(e) => return Err(ToolError::Other(anyhow::Error::from(e))),
};
let item_id = item_with_meta
.item
.id
.ok_or_else(|| anyhow!("Item missing ID after find"))?;
let item_with_content = service
.get_item_content(item_id)
.await
.map_err(|e| ToolError::Other(anyhow::Error::from(e)))?;
let content = String::from_utf8_lossy(&item_with_content.content).to_string();
let tags: Vec<String> = item_with_content
.item_with_meta
.tags
.iter()
.map(|t| t.name.clone())
.collect();
let metadata = item_with_content.item_with_meta.meta_as_map();
let item = item_with_content.item_with_meta.item;
let response = serde_json::json!({
"id": item_id,
"content": content,
"timestamp": item.ts.to_rfc3339(),
"size": item.size,
"compression": item.compression,
"tags": tags,
"metadata": metadata,
});
Ok(serde_json::to_string_pretty(&response)?)
}
pub async fn list_items(&self, args: Option<Value>) -> Result<String, ToolError> {
let args_ref = args.as_ref();
let tags: Vec<String> = args_ref
.and_then(|v| v.get("tags"))
.and_then(|v| v.as_array())
.map(|arr| {
arr.iter()
.filter_map(|v| v.as_str().map(|s| s.to_string()))
.collect()
})
.unwrap_or_default();
let limit = args_ref
.and_then(|v| v.get("limit"))
.and_then(|v| v.as_u64())
.unwrap_or(10) as usize;
let offset = args_ref
.and_then(|v| v.get("offset"))
.and_then(|v| v.as_u64())
.unwrap_or(0) as usize;
let service = AsyncItemService::new(
self.state.data_dir.clone(),
self.state.db.clone(),
self.state.item_service.clone(),
self.state.cmd.clone(),
self.state.settings.clone(),
);
let mut items_with_meta = service
.list_items(tags, HashMap::new())
.await
.map_err(|e| ToolError::Other(anyhow::Error::from(e)))?;
// Sort by timestamp (newest first) and apply pagination
items_with_meta.sort_by(|a, b| b.item.ts.cmp(&a.item.ts));
let items_with_meta: Vec<_> = items_with_meta
.into_iter()
.skip(offset)
.take(limit)
.collect();
let items_info: Vec<_> = items_with_meta
.into_iter()
.map(|item_with_meta| {
let item_tags: Vec<String> =
item_with_meta.tags.iter().map(|t| t.name.clone()).collect();
let item_meta = item_with_meta.meta_as_map();
let item = item_with_meta.item;
let item_id = item.id.unwrap_or(0);
serde_json::json!({
"id": item_id,
"timestamp": item.ts.to_rfc3339(),
"size": item.size,
"compression": item.compression,
"tags": item_tags,
"metadata": item_meta
})
})
.collect();
let response = serde_json::json!({
"items": items_info,
"count": items_info.len(),
"offset": offset,
"limit": limit
});
Ok(serde_json::to_string_pretty(&response)?)
}
pub async fn search_items(&self, args: Option<Value>) -> Result<String, ToolError> {
let tags: Vec<String> = args
.as_ref()
.and_then(|v| v.get("tags"))
.and_then(|v| v.as_array())
.map(|arr| {
arr.iter()
.filter_map(|v| v.as_str().map(|s| s.to_string()))
.collect()
})
.unwrap_or_default();
let metadata: HashMap<String, String> = args
.as_ref()
.and_then(|v| v.get("metadata"))
.and_then(|v| v.as_object())
.map(|obj| {
obj.iter()
.filter_map(|(k, v)| v.as_str().map(|s| (k.clone(), s.to_string())))
.collect()
})
.unwrap_or_default();
let service = AsyncItemService::new(
self.state.data_dir.clone(),
self.state.db.clone(),
self.state.item_service.clone(),
self.state.cmd.clone(),
self.state.settings.clone(),
);
let mut items_with_meta = service
.list_items(tags.clone(), metadata.clone())
.await
.map_err(|e| ToolError::Other(anyhow::Error::from(e)))?;
// Sort by timestamp (newest first)
items_with_meta.sort_by(|a, b| b.item.ts.cmp(&a.item.ts));
let items_info: Vec<_> = items_with_meta
.into_iter()
.map(|item_with_meta| {
let item_tags: Vec<String> =
item_with_meta.tags.iter().map(|t| t.name.clone()).collect();
let item_meta = item_with_meta.meta_as_map();
let item = item_with_meta.item;
let item_id = item.id.unwrap_or(0);
serde_json::json!({
"id": item_id,
"timestamp": item.ts.to_rfc3339(),
"size": item.size,
"compression": item.compression,
"tags": item_tags,
"metadata": item_meta
})
})
.collect();
let response = serde_json::json!({
"items": items_info,
"count": items_info.len(),
"search_criteria": {
"tags": tags,
"metadata": metadata
}
});
Ok(serde_json::to_string_pretty(&response)?)
}
}

View File

@@ -1,7 +1,10 @@
use crate::config;
use crate::services::item_service::ItemService;
use anyhow::Result;
use axum::{Router, routing::post};
use axum::Router;
use axum::http::{HeaderValue, header};
use axum::middleware::Next;
use axum::response::Response;
use clap::Command;
use log::{debug, info};
use std::net::SocketAddr;
@@ -15,12 +18,26 @@ use tower_http::trace::TraceLayer;
mod api;
pub mod auth;
pub mod common;
#[cfg(feature = "mcp")]
mod mcp;
mod pages;
pub use common::{AppState, create_auth_middleware, logging_middleware};
/// Middleware that attaches hardening headers to every outgoing response.
///
/// Sets `X-Content-Type-Options: nosniff`, `X-Frame-Options: DENY`, and
/// `Referrer-Policy: strict-origin-when-cross-origin`.
async fn security_headers(req: axum::extract::Request, next: Next) -> Response {
    let mut resp = next.run(req).await;
    // Header name/value pairs applied uniformly to all responses.
    let pairs = [
        (header::X_CONTENT_TYPE_OPTIONS, "nosniff"),
        (header::X_FRAME_OPTIONS, "DENY"),
        (
            header::REFERRER_POLICY,
            "strict-origin-when-cross-origin",
        ),
    ];
    for (name, value) in pairs {
        resp.headers_mut()
            .insert(name, HeaderValue::from_static(value));
    }
    resp
}
pub fn mode_server(
cmd: &mut Command,
settings: &config::Settings,
@@ -107,23 +124,10 @@ async fn run_server(
settings: Arc::new(settings.clone()),
};
#[cfg(feature = "mcp")]
let mcp_router = Router::new()
.route("/mcp", post(mcp::handle_mcp_request))
.with_state(state.clone());
#[cfg_attr(not(feature = "mcp"), allow(unused_mut))]
let mut protected_router = Router::new()
let protected_router = Router::new()
.merge(api::add_routes(Router::new()))
.merge(pages::add_routes(Router::new()));
#[cfg(feature = "mcp")]
{
protected_router = protected_router.merge(mcp_router);
}
let protected_router =
protected_router.layer(axum::middleware::from_fn(create_auth_middleware(
.merge(pages::add_routes(Router::new()))
.layer(axum::middleware::from_fn(create_auth_middleware(
config.username.clone(),
config.password.clone(),
config.password_hash.clone(),
@@ -152,18 +156,19 @@ async fn run_server(
axum::http::Method::PUT,
axum::http::Method::DELETE,
])
.allow_headers(tower_http::cors::Any)
.allow_headers([header::CONTENT_TYPE, header::AUTHORIZATION, header::ACCEPT])
};
// Create the app with documentation routes open and others protected
let app = Router::new()
// Add documentation routes without authentication
.merge(api::add_docs_routes(Router::new()))
// Add API, pages, and MCP routes with authentication
// Add API and pages routes with authentication
.merge(protected_router)
// Apply state to all routes
.with_state(state)
// Add other middleware layers to all routes
.layer(axum::middleware::from_fn(security_headers))
.layer(axum::middleware::from_fn(logging_middleware))
.layer(
ServiceBuilder::new()

View File

@@ -6,11 +6,22 @@ use axum::{
extract::{Path, Query, State},
response::{Html, Response},
};
use html_escape::{encode_double_quoted_attribute, encode_text};
use log::debug;
use rusqlite::Connection;
use serde::Deserialize;
use std::collections::HashMap;
/// Escape a string for safe use as HTML text content.
fn esc(s: &str) -> String {
    // `encode_text` yields a Cow; materialize it as an owned String.
    encode_text(s).into_owned()
}
/// Escape a string for safe use inside a double-quoted HTML attribute.
fn esc_attr(s: &str) -> String {
    // `encode_double_quoted_attribute` yields a Cow; take ownership.
    encode_double_quoted_attribute(s).into_owned()
}
#[derive(Deserialize)]
/// Query parameters for the item list endpoint.
///
@@ -62,7 +73,7 @@ fn default_count() -> usize {
///
/// # Examples
///
/// ```
/// ```ignore
/// let app = pages::add_routes(axum::Router::new());
/// ```
pub fn add_routes(app: axum::Router<AppState>) -> axum::Router<AppState> {
@@ -90,7 +101,9 @@ async fn list_items(
.map_err(|_| Html("<html><body>Internal Server Error</body></html>".to_string()))?;
Ok(response)
}
Err(e) => Err(Html(format!("<html><body>Error: {e}</body></html>"))),
Err(_e) => Err(Html(
"<html><body>An internal error occurred</body></html>".to_string(),
)),
}
}
@@ -121,7 +134,8 @@ fn build_item_list(
// Apply pagination
let start = params.start;
let end = std::cmp::min(start + params.count, sorted_items.len());
let count = params.count.min(10000);
let end = std::cmp::min(start + count, sorted_items.len());
let page_items = if start < sorted_items.len() {
sorted_items[start..std::cmp::min(end, sorted_items.len())].to_vec()
} else {
@@ -153,11 +167,11 @@ fn build_item_list(
// Collect all tags from all items, keeping track of their timestamps
let mut all_tags_with_time: Vec<(String, chrono::DateTime<chrono::Utc>)> = Vec::new();
for item in &sorted_items {
if let Some(item_id) = item.id {
if let Some(tags) = tags_map.get(&item_id) {
for tag in tags {
all_tags_with_time.push((tag.name.clone(), item.ts));
}
if let Some(item_id) = item.id
&& let Some(tags) = tags_map.get(&item_id)
{
for tag in tags {
all_tags_with_time.push((tag.name.clone(), item.ts));
}
}
}
@@ -184,7 +198,9 @@ fn build_item_list(
html.push_str("<p>");
for tag in recent_tags {
html.push_str(&format!(
"<a href=\"/?tags={tag}\" style=\"margin-right: 8px;\">{tag}</a>"
"<a href=\"/?tags={}\" style=\"margin-right: 8px;\">{}</a>",
esc_attr(&tag),
esc(&tag)
));
}
html.push_str("</p>");
@@ -196,7 +212,7 @@ fn build_item_list(
// Table headers
html.push_str("<tr>");
for column in columns {
html.push_str(&format!("<th>{}</th>", column.label));
html.push_str(&format!("<th>{}</th>", esc(&column.label)));
}
html.push_str("<th>Actions</th>");
html.push_str("</tr>");
@@ -229,7 +245,13 @@ fn build_item_list(
// Make sure we're using all tags for the item
let tag_links: Vec<String> = tags
.iter()
.map(|t| format!("<a href=\"/?tags={}\">{}</a>", t.name, t.name))
.map(|t| {
format!(
"<a href=\"/?tags={}\">{}</a>",
esc_attr(&t.name),
esc(&t.name)
)
})
.collect();
tag_links.join(", ")
}
@@ -268,7 +290,15 @@ fn build_item_list(
crate::config::ColumnAlignment::Center => "text-align: center;",
};
html.push_str(&format!("<td style=\"{align_style}\">{display_value}</td>"));
let rendered_value = if column.name == "tags" {
display_value // Already contains escaped HTML links
} else {
esc(&display_value)
};
html.push_str(&format!(
"<td style=\"{align_style}\">{rendered_value}</td>"
));
}
// Actions column
@@ -361,7 +391,9 @@ async fn show_item(
.map_err(|_| Html("<html><body>Internal Server Error</body></html>".to_string()))?;
Ok(response)
}
Err(e) => Err(Html(format!("<html><body>Error: {e}</body></html>"))),
Err(_e) => Err(Html(
"<html><body>An internal error occurred</body></html>".to_string(),
)),
}
}
@@ -396,7 +428,7 @@ fn build_item_details(conn: &Connection, id: i64) -> Result<String> {
));
html.push_str(&format!(
"<tr><th>Compression</th><td>{}</td></tr>",
item.compression
esc(&item.compression)
));
// Tags row
@@ -406,7 +438,13 @@ fn build_item_details(conn: &Connection, id: i64) -> Result<String> {
} else {
let tag_links: Vec<String> = tags
.iter()
.map(|t| format!("<a href=\"/?tags={}\">{}</a>", t.name, t.name))
.map(|t| {
format!(
"<a href=\"/?tags={}\">{}</a>",
esc_attr(&t.name),
esc(&t.name)
)
})
.collect();
html.push_str(&tag_links.join(", "));
}
@@ -419,7 +457,8 @@ fn build_item_details(conn: &Connection, id: i64) -> Result<String> {
for meta in metas {
html.push_str(&format!(
"<tr><th>{}</th><td>{}</td></tr>",
meta.name, meta.value
esc(&meta.name),
esc(&meta.value)
));
}
}

View File

@@ -198,7 +198,7 @@ pub fn mode_status(
let status_service = crate::services::status_service::StatusService::new();
let output_format = crate::modes::common::settings_output_format(settings);
debug!("STATUS: About to generate status info");
let status_info = status_service.generate_status(cmd, settings, data_path, db_path);
let status_info = status_service.generate_status(cmd, settings, data_path, db_path)?;
debug!("STATUS: Status info generated successfully");
match output_format {

View File

@@ -298,7 +298,7 @@ pub fn mode_status_plugins(
let status_service = crate::services::status_service::StatusService::new();
let output_format = crate::modes::common::settings_output_format(settings);
debug!("STATUS_PLUGINS: About to generate status info");
let status_info = status_service.generate_status(cmd, settings, data_path, db_path);
let status_info = status_service.generate_status(cmd, settings, data_path, db_path)?;
debug!("STATUS_PLUGINS: Status info generated successfully");
match output_format {