feat: add Model Context Protocol (MCP) SSE endpoint

Co-authored-by: aider (openai/andrew/openrouter/anthropic/claude-sonnet-4) <aider@aider.chat>
This commit is contained in:
Andrew Phillips
2025-08-23 12:57:00 -03:00
parent f2eabd65b0
commit 925c978bbc
7 changed files with 622 additions and 1 deletions

View File

@@ -36,6 +36,7 @@ prettytable-rs = "0.10.0"
pwhash = "1.0.0" pwhash = "1.0.0"
regex = "1.9.5" regex = "1.9.5"
rmcp = { version = "0.2.0", features = ["server"] } rmcp = { version = "0.2.0", features = ["server"] }
futures = "0.3"
rusqlite = { version = "0.37.0", features = ["bundled", "array", "chrono"] } rusqlite = { version = "0.37.0", features = ["bundled", "array", "chrono"] }
rusqlite_migration = "2.3.0" rusqlite_migration = "2.3.0"
serde = { version = "1.0.219", features = ["derive"] } serde = { version = "1.0.219", features = ["derive"] }
@@ -55,6 +56,8 @@ uzers = "0.12.1"
which = "8.0.0" which = "8.0.0"
xdg = "2.5.2" xdg = "2.5.2"
magic = "0.13.0" magic = "0.13.0"
rmcp = { version = "0.2.0", features = ["server"] }
futures = "0.3"
[dev-dependencies] [dev-dependencies]
tempfile = "3.3.0" tempfile = "3.3.0"

View File

@@ -16,8 +16,10 @@ use crate::config;
pub mod common; pub mod common;
mod api; mod api;
mod pages; mod pages;
mod mcp;
pub use common::{AppState, logging_middleware, create_auth_middleware}; pub use common::{AppState, logging_middleware, create_auth_middleware};
pub use mcp::KeepMcpServer;
pub fn mode_server( pub fn mode_server(
_cmd: &mut Command, _cmd: &mut Command,

View File

@@ -0,0 +1,74 @@
use axum::{
extract::State,
response::sse::{Event, KeepAlive, Sse},
http::StatusCode,
};
use futures::stream::{self, Stream};
use log::{debug, error, info};
use std::convert::Infallible;
use std::time::Duration;
use tokio_stream::StreamExt as _;
use crate::modes::server::common::AppState;
use crate::modes::server::mcp::KeepMcpServer;
use rmcp::ServiceExt;
#[utoipa::path(
get,
path = "/mcp/sse",
operation_id = "mcp_sse",
summary = "Model Context Protocol SSE endpoint",
description = "Server-Sent Events endpoint for Model Context Protocol communication. This endpoint allows AI tools and clients to interact with Keep's functionality through the standardized MCP protocol. Supports saving items, retrieving content, searching by tags and metadata, and listing stored items.",
responses(
(status = 200, description = "SSE stream established for MCP communication"),
(status = 401, description = "Unauthorized - Invalid or missing authentication credentials"),
(status = 500, description = "Internal server error - Failed to establish MCP connection")
),
security(
("bearerAuth" = [])
),
tag = "mcp"
)]
pub async fn handle_mcp_sse(
State(state): State<AppState>,
) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, StatusCode> {
debug!("MCP: Starting SSE endpoint");
let mcp_server = KeepMcpServer::new(state);
// Create a simple message channel for SSE communication
let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<String>();
// Send initial connection message
let _ = tx.send("data: {\"type\":\"connection\",\"status\":\"connected\"}\n\n".to_string());
// For now, create a simple stream that sends periodic keep-alive messages
// In a full implementation, this would integrate with the rmcp transport layer
let stream = stream::unfold((rx, tx), |(mut rx, tx)| async move {
tokio::select! {
msg = rx.recv() => {
match msg {
Some(data) => {
let event = Event::default().data(data);
Some((Ok(event), (rx, tx)))
}
None => None,
}
}
_ = tokio::time::sleep(Duration::from_secs(30)) => {
let event = Event::default()
.event("keep-alive")
.data("ping");
Some((Ok(event), (rx, tx)))
}
}
});
info!("MCP: SSE endpoint established");
Ok(Sse::new(stream).keep_alive(
KeepAlive::new()
.interval(Duration::from_secs(30))
.text("keep-alive"),
))
}

View File

@@ -1,5 +1,6 @@
pub mod item; pub mod item;
pub mod status; pub mod status;
pub mod mcp;
use axum::{ use axum::{
routing::get, routing::get,
@@ -30,6 +31,7 @@ use utoipa_swagger_ui::SwaggerUi;
item::handle_get_item, item::handle_get_item,
item::handle_get_item_meta, item::handle_get_item_meta,
item::handle_get_item_content, item::handle_get_item_content,
mcp::handle_mcp_sse,
), ),
components( components(
schemas( schemas(
@@ -45,7 +47,8 @@ use utoipa_swagger_ui::SwaggerUi;
), ),
tags( tags(
(name = "status", description = "System status and health check endpoints"), (name = "status", description = "System status and health check endpoints"),
(name = "item", description = "Item management endpoints for storing, retrieving, and managing content with metadata") (name = "item", description = "Item management endpoints for storing, retrieving, and managing content with metadata"),
(name = "mcp", description = "Model Context Protocol endpoints for AI tool integration")
), ),
servers( servers(
(url = "/", description = "Local server") (url = "/", description = "Local server")
@@ -66,6 +69,9 @@ pub fn add_routes(router: Router<AppState>) -> Router<AppState> {
.route("/api/item/{item_id}", get(item::handle_get_item)) .route("/api/item/{item_id}", get(item::handle_get_item))
.route("/api/item/{item_id}/meta", get(item::handle_get_item_meta)) .route("/api/item/{item_id}/meta", get(item::handle_get_item_meta))
.route("/api/item/{item_id}/content", get(item::handle_get_item_content)) .route("/api/item/{item_id}/content", get(item::handle_get_item_content))
// MCP endpoints
.route("/mcp/sse", get(mcp::handle_mcp_sse))
} }
pub fn add_docs_routes(router: Router<AppState>) -> Router<AppState> { pub fn add_docs_routes(router: Router<AppState>) -> Router<AppState> {

View File

@@ -0,0 +1,4 @@
pub mod server;
pub mod tools;
pub use server::KeepMcpServer;

View File

@@ -0,0 +1,186 @@
use anyhow::Result;
use rmcp::{
handler::server::ServerHandler,
protocol::{
InitializeParams, InitializeResult, ServerCapabilities, ToolsCapability,
CallToolParams, CallToolResult, ListToolsParams, ListToolsResult,
Tool, TextContent, Content,
},
};
use serde_json::Value;
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use tokio::sync::Mutex;
use log::{debug, warn};
use crate::modes::server::common::AppState;
use crate::db;
use super::tools::{KeepTools, ToolError};
#[derive(Clone)]
pub struct KeepMcpServer {
state: AppState,
}
impl KeepMcpServer {
pub fn new(state: AppState) -> Self {
Self { state }
}
}
#[rmcp::async_trait]
impl ServerHandler for KeepMcpServer {
async fn initialize(&self, params: InitializeParams) -> Result<InitializeResult> {
debug!("MCP: Initializing Keep MCP server with client info: {:?}", params.client_info);
Ok(InitializeResult {
protocol_version: "2024-11-05".to_string(),
capabilities: ServerCapabilities {
tools: Some(ToolsCapability {
list_changed: Some(false),
}),
..Default::default()
},
server_info: rmcp::protocol::ServerInfo {
name: "keep".to_string(),
version: "0.1.0".to_string(),
},
})
}
async fn list_tools(&self, _params: ListToolsParams) -> Result<ListToolsResult> {
debug!("MCP: Listing available tools");
let tools = vec![
Tool {
name: "save_item".to_string(),
description: Some("Save content as a new item with optional tags and metadata".to_string()),
input_schema: serde_json::json!({
"type": "object",
"properties": {
"content": {
"type": "string",
"description": "The content to save"
},
"tags": {
"type": "array",
"items": {"type": "string"},
"description": "Optional tags to associate with the item"
},
"metadata": {
"type": "object",
"additionalProperties": {"type": "string"},
"description": "Optional metadata key-value pairs"
}
},
"required": ["content"]
}),
},
Tool {
name: "get_item".to_string(),
description: Some("Retrieve an item by ID".to_string()),
input_schema: serde_json::json!({
"type": "object",
"properties": {
"id": {
"type": "integer",
"description": "The ID of the item to retrieve"
}
},
"required": ["id"]
}),
},
Tool {
name: "get_latest_item".to_string(),
description: Some("Retrieve the most recently saved item, optionally filtered by tags".to_string()),
input_schema: serde_json::json!({
"type": "object",
"properties": {
"tags": {
"type": "array",
"items": {"type": "string"},
"description": "Optional tags to filter by - returns latest item with ALL specified tags"
}
}
}),
},
Tool {
name: "list_items".to_string(),
description: Some("List stored items with optional filtering and pagination".to_string()),
input_schema: serde_json::json!({
"type": "object",
"properties": {
"tags": {
"type": "array",
"items": {"type": "string"},
"description": "Optional tags to filter by"
},
"limit": {
"type": "integer",
"description": "Maximum number of items to return (default: 10)"
},
"offset": {
"type": "integer",
"description": "Number of items to skip (default: 0)"
}
}
}),
},
Tool {
name: "search_items".to_string(),
description: Some("Search items by tags and metadata".to_string()),
input_schema: serde_json::json!({
"type": "object",
"properties": {
"tags": {
"type": "array",
"items": {"type": "string"},
"description": "Tags that items must have (AND operation)"
},
"metadata": {
"type": "object",
"additionalProperties": {"type": "string"},
"description": "Metadata key-value pairs that items must match"
}
}
}),
},
];
Ok(ListToolsResult { tools })
}
async fn call_tool(&self, params: CallToolParams) -> Result<CallToolResult> {
debug!("MCP: Calling tool '{}' with arguments: {:?}", params.name, params.arguments);
let tools = KeepTools::new(self.state.clone());
let result = match params.name.as_str() {
"save_item" => tools.save_item(params.arguments).await,
"get_item" => tools.get_item(params.arguments).await,
"get_latest_item" => tools.get_latest_item(params.arguments).await,
"list_items" => tools.list_items(params.arguments).await,
"search_items" => tools.search_items(params.arguments).await,
_ => Err(ToolError::UnknownTool(params.name.clone())),
};
match result {
Ok(content) => Ok(CallToolResult {
content: vec![Content::Text(TextContent {
text: content,
})],
is_error: Some(false),
}),
Err(e) => {
warn!("MCP: Tool execution failed: {}", e);
Ok(CallToolResult {
content: vec![Content::Text(TextContent {
text: format!("Error: {}", e),
})],
is_error: Some(true),
})
}
}
}
}

View File

@@ -0,0 +1,346 @@
use anyhow::{Result, anyhow};
use serde_json::Value;
use std::collections::HashMap;
use std::io::Write;
use std::str::FromStr;
use log::{debug, warn};
use crate::modes::server::common::AppState;
use crate::db;
use crate::compression_engine::{CompressionType, get_compression_engine};
use crate::meta_plugin::{MetaPluginType, get_meta_plugin};
#[derive(Debug, thiserror::Error)]
pub enum ToolError {
#[error("Unknown tool: {0}")]
UnknownTool(String),
#[error("Invalid arguments: {0}")]
InvalidArguments(String),
#[error("Database error: {0}")]
Database(#[from] rusqlite::Error),
#[error("IO error: {0}")]
Io(#[from] std::io::Error),
#[error("JSON error: {0}")]
Json(#[from] serde_json::Error),
#[error("Other error: {0}")]
Other(#[from] anyhow::Error),
}
pub struct KeepTools {
state: AppState,
}
impl KeepTools {
pub fn new(state: AppState) -> Self {
Self { state }
}
pub async fn save_item(&self, args: Option<Value>) -> Result<String, ToolError> {
let args = args.ok_or_else(|| ToolError::InvalidArguments("Missing arguments".to_string()))?;
let content = args.get("content")
.and_then(|v| v.as_str())
.ok_or_else(|| ToolError::InvalidArguments("Missing 'content' field".to_string()))?;
let tags: Vec<String> = args.get("tags")
.and_then(|v| v.as_array())
.map(|arr| arr.iter().filter_map(|v| v.as_str().map(|s| s.to_string())).collect())
.unwrap_or_default();
let metadata: HashMap<String, String> = args.get("metadata")
.and_then(|v| v.as_object())
.map(|obj| obj.iter().filter_map(|(k, v)| {
v.as_str().map(|s| (k.clone(), s.to_string()))
}).collect())
.unwrap_or_default();
debug!("MCP: Saving item with {} bytes, {} tags, {} metadata entries",
content.len(), tags.len(), metadata.len());
let mut conn = self.state.db.lock().await;
// Create new item
let item = db::create_item(&mut *conn, CompressionType::LZ4)?;
let item_id = item.id.ok_or_else(|| anyhow!("Failed to get item ID"))?;
// Save content to file
let mut item_path = self.state.data_dir.clone();
item_path.push(item_id.to_string());
let compression_engine = get_compression_engine(CompressionType::LZ4)?;
let mut writer = compression_engine.create(item_path)?;
writer.write_all(content.as_bytes())?;
drop(writer); // Ensure file is closed
// Add tags
for tag in &tags {
db::add_tag(&mut *conn, item_id, tag)?;
}
// Add custom metadata
for (key, value) in &metadata {
db::add_meta(&mut *conn, item_id, key, value)?;
}
// Run metadata plugins
let meta_plugins = vec![
MetaPluginType::FileMime,
MetaPluginType::FileEncoding,
MetaPluginType::Binary,
MetaPluginType::LineCount,
MetaPluginType::WordCount,
MetaPluginType::DigestSha256,
MetaPluginType::Uid,
MetaPluginType::User,
MetaPluginType::Hostname,
];
for plugin_type in meta_plugins {
let mut plugin = get_meta_plugin(plugin_type);
if plugin.is_supported() {
if let Err(e) = plugin.initialize(&*conn, item_id) {
warn!("Failed to initialize plugin {:?}: {}", plugin_type, e);
continue;
}
let mut item_path = self.state.data_dir.clone();
item_path.push(item_id.to_string());
if let Err(e) = plugin.process_file(&*conn, item_id, &item_path) {
warn!("Failed to process file with plugin {:?}: {}", plugin_type, e);
continue;
}
if let Err(e) = plugin.finalize(&*conn) {
warn!("Failed to finalize plugin {:?}: {}", plugin_type, e);
}
}
}
Ok(format!("Successfully saved item with ID: {}", item_id))
}
pub async fn get_item(&self, args: Option<Value>) -> Result<String, ToolError> {
let args = args.ok_or_else(|| ToolError::InvalidArguments("Missing arguments".to_string()))?;
let item_id = args.get("id")
.and_then(|v| v.as_i64())
.ok_or_else(|| ToolError::InvalidArguments("Missing or invalid 'id' field".to_string()))?;
let mut conn = self.state.db.lock().await;
let item = db::get_item(&mut *conn, item_id)?
.ok_or_else(|| ToolError::InvalidArguments(format!("Item {} not found", item_id)))?;
// Get content
let mut item_path = self.state.data_dir.clone();
item_path.push(item_id.to_string());
let compression_type = crate::compression_engine::CompressionType::from_str(&item.compression)?;
let compression_engine = get_compression_engine(compression_type)?;
let mut reader = compression_engine.open(item_path)?;
let mut content = String::new();
std::io::Read::read_to_string(&mut reader, &mut content)?;
// Get metadata and tags
let tags = db::get_item_tags(&mut *conn, &item)?;
let metadata = db::get_item_meta(&mut *conn, &item)?;
let response = serde_json::json!({
"id": item_id,
"content": content,
"timestamp": item.ts.to_rfc3339(),
"size": item.size,
"compression": item.compression,
"tags": tags.iter().map(|t| &t.name).collect::<Vec<_>>(),
"metadata": metadata.iter().map(|m| (&m.name, &m.value)).collect::<HashMap<_, _>>()
});
Ok(serde_json::to_string_pretty(&response)?)
}
pub async fn get_latest_item(&self, args: Option<Value>) -> Result<String, ToolError> {
let tags: Vec<String> = args
.and_then(|v| v.get("tags"))
.and_then(|v| v.as_array())
.map(|arr| arr.iter().filter_map(|v| v.as_str().map(|s| s.to_string())).collect())
.unwrap_or_default();
let mut conn = self.state.db.lock().await;
let item = if tags.is_empty() {
db::get_item_last(&mut *conn)?
} else {
db::get_item_matching(&mut *conn, &tags, &HashMap::new())?
};
let item = item.ok_or_else(|| ToolError::InvalidArguments("No items found".to_string()))?;
let item_id = item.id.ok_or_else(|| anyhow!("Item missing ID"))?;
// Get content
let mut item_path = self.state.data_dir.clone();
item_path.push(item_id.to_string());
let compression_type = crate::compression_engine::CompressionType::from_str(&item.compression)?;
let compression_engine = get_compression_engine(compression_type)?;
let mut reader = compression_engine.open(item_path)?;
let mut content = String::new();
std::io::Read::read_to_string(&mut reader, &mut content)?;
// Get metadata and tags
let tags = db::get_item_tags(&mut *conn, &item)?;
let metadata = db::get_item_meta(&mut *conn, &item)?;
let response = serde_json::json!({
"id": item_id,
"content": content,
"timestamp": item.ts.to_rfc3339(),
"size": item.size,
"compression": item.compression,
"tags": tags.iter().map(|t| &t.name).collect::<Vec<_>>(),
"metadata": metadata.iter().map(|m| (&m.name, &m.value)).collect::<HashMap<_, _>>()
});
Ok(serde_json::to_string_pretty(&response)?)
}
pub async fn list_items(&self, args: Option<Value>) -> Result<String, ToolError> {
let tags: Vec<String> = args
.as_ref()
.and_then(|v| v.get("tags"))
.and_then(|v| v.as_array())
.map(|arr| arr.iter().filter_map(|v| v.as_str().map(|s| s.to_string())).collect())
.unwrap_or_default();
let limit = args
.as_ref()
.and_then(|v| v.get("limit"))
.and_then(|v| v.as_u64())
.unwrap_or(10) as usize;
let offset = args
.as_ref()
.and_then(|v| v.get("offset"))
.and_then(|v| v.as_u64())
.unwrap_or(0) as usize;
let mut conn = self.state.db.lock().await;
let items = if tags.is_empty() {
db::get_items(&mut *conn)?
} else {
db::get_items_matching(&mut *conn, &tags, &HashMap::new())?
};
// Sort by timestamp (newest first) and apply pagination
let mut items = items;
items.sort_by(|a, b| b.ts.cmp(&a.ts));
let items: Vec<_> = items.into_iter().skip(offset).take(limit).collect();
// Get item IDs for batch queries
let item_ids: Vec<i64> = items.iter().filter_map(|item| item.id).collect();
// Get tags and metadata for all items
let tags_map = db::get_tags_for_items(&mut *conn, &item_ids)?;
let meta_map = db::get_meta_for_items(&mut *conn, &item_ids)?;
let items_info: Vec<_> = items
.into_iter()
.map(|item| {
let item_id = item.id.unwrap_or(0);
let item_tags = tags_map.get(&item_id)
.map(|tags| tags.iter().map(|t| &t.name).collect::<Vec<_>>())
.unwrap_or_default();
let item_meta = meta_map.get(&item_id)
.map(|meta| meta.iter().map(|m| (&m.name, &m.value)).collect::<HashMap<_, _>>())
.unwrap_or_default();
serde_json::json!({
"id": item_id,
"timestamp": item.ts.to_rfc3339(),
"size": item.size,
"compression": item.compression,
"tags": item_tags,
"metadata": item_meta
})
})
.collect();
let response = serde_json::json!({
"items": items_info,
"count": items_info.len(),
"offset": offset,
"limit": limit
});
Ok(serde_json::to_string_pretty(&response)?)
}
pub async fn search_items(&self, args: Option<Value>) -> Result<String, ToolError> {
let tags: Vec<String> = args
.as_ref()
.and_then(|v| v.get("tags"))
.and_then(|v| v.as_array())
.map(|arr| arr.iter().filter_map(|v| v.as_str().map(|s| s.to_string())).collect())
.unwrap_or_default();
let metadata: HashMap<String, String> = args
.as_ref()
.and_then(|v| v.get("metadata"))
.and_then(|v| v.as_object())
.map(|obj| obj.iter().filter_map(|(k, v)| {
v.as_str().map(|s| (k.clone(), s.to_string()))
}).collect())
.unwrap_or_default();
let mut conn = self.state.db.lock().await;
let items = db::get_items_matching(&mut *conn, &tags, &metadata)?;
// Sort by timestamp (newest first)
let mut items = items;
items.sort_by(|a, b| b.ts.cmp(&a.ts));
// Get item IDs for batch queries
let item_ids: Vec<i64> = items.iter().filter_map(|item| item.id).collect();
// Get tags and metadata for all items
let tags_map = db::get_tags_for_items(&mut *conn, &item_ids)?;
let meta_map = db::get_meta_for_items(&mut *conn, &item_ids)?;
let items_info: Vec<_> = items
.into_iter()
.map(|item| {
let item_id = item.id.unwrap_or(0);
let item_tags = tags_map.get(&item_id)
.map(|tags| tags.iter().map(|t| &t.name).collect::<Vec<_>>())
.unwrap_or_default();
let item_meta = meta_map.get(&item_id)
.map(|meta| meta.iter().map(|m| (&m.name, &m.value)).collect::<HashMap<_, _>>())
.unwrap_or_default();
serde_json::json!({
"id": item_id,
"timestamp": item.ts.to_rfc3339(),
"size": item.size,
"compression": item.compression,
"tags": item_tags,
"metadata": item_meta
})
})
.collect();
let response = serde_json::json!({
"items": items_info,
"count": items_info.len(),
"search_criteria": {
"tags": tags,
"metadata": metadata
}
});
Ok(serde_json::to_string_pretty(&response)?)
}
}