diff --git a/Cargo.toml b/Cargo.toml index e969142..48617e4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -48,7 +48,9 @@ stderrlog = "0.6.0" strum = { version = "0.27.2", features = ["derive"] } strum_macros = "0.27.2" term = "1.1.0" +thiserror = "1.0" tokio = { version = "1.0", features = ["full"] } +tokio-stream = "0.1" tower = "0.5.2" tower-http = { version = "0.6.6", features = ["cors", "fs", "trace"] } utoipa = { version = "5.4.0", features = ["axum_extras"] } diff --git a/src/db.rs b/src/db.rs index 97f9967..9b9cb03 100644 --- a/src/db.rs +++ b/src/db.rs @@ -87,6 +87,37 @@ pub fn insert_item(conn: &Connection, item: Item) -> Result { Ok(conn.last_insert_rowid()) } +pub fn create_item(conn: &Connection, compression_type: crate::compression_engine::CompressionType) -> Result { + let item = Item { + id: None, + ts: chrono::Utc::now(), + size: None, + compression: compression_type.to_string(), + }; + let item_id = insert_item(conn, item.clone())?; + Ok(Item { + id: Some(item_id), + ..item + }) +} + +pub fn add_tag(conn: &Connection, item_id: i64, tag_name: &str) -> Result<()> { + let tag = Tag { + id: item_id, + name: tag_name.to_string(), + }; + insert_tag(conn, tag) +} + +pub fn add_meta(conn: &Connection, item_id: i64, name: &str, value: &str) -> Result<()> { + let meta = Meta { + id: item_id, + name: name.to_string(), + value: value.to_string(), + }; + store_meta(conn, meta) +} + pub fn update_item(conn: &Connection, item: Item) -> Result<()> { debug!("DB: Updating item: {:?}", item); conn.execute( diff --git a/src/modes/server/api/mcp.rs b/src/modes/server/api/mcp.rs index 37fa58a..d418e36 100644 --- a/src/modes/server/api/mcp.rs +++ b/src/modes/server/api/mcp.rs @@ -4,14 +4,13 @@ use axum::{ http::StatusCode, }; use futures::stream::{self, Stream}; -use log::{debug, error, info}; +use log::{debug, info}; use std::convert::Infallible; use std::time::Duration; use tokio_stream::StreamExt as _; use crate::modes::server::common::AppState; use crate::modes::server::mcp::KeepMcpServer; -use rmcp::ServiceExt; #[utoipa::path( get, @@ -34,10 +33,10 @@ pub async fn handle_mcp_sse( ) -> Result>>, StatusCode> { debug!("MCP: Starting SSE endpoint"); - let mcp_server = KeepMcpServer::new(state); + let _mcp_server = KeepMcpServer::new(state); // Create a simple message channel for SSE communication - let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::(); + let (tx, rx) = tokio::sync::mpsc::unbounded_channel::(); // Send initial connection message let _ = tx.send("data: {\"type\":\"connection\",\"status\":\"connected\"}\n\n".to_string()); diff --git a/src/modes/server/mcp/server.rs b/src/modes/server/mcp/server.rs index 8567527..191f812 100644 --- a/src/modes/server/mcp/server.rs +++ b/src/modes/server/mcp/server.rs @@ -1,21 +1,8 @@ use anyhow::Result; -use rmcp::{ - handler::server::ServerHandler, - protocol::{ - InitializeParams, InitializeResult, ServerCapabilities, ToolsCapability, - CallToolParams, CallToolResult, ListToolsParams, ListToolsResult, - Tool, TextContent, Content, - }, -}; -use serde_json::Value; -use std::collections::HashMap; -use std::path::PathBuf; -use std::sync::Arc; -use tokio::sync::Mutex; use log::{debug, warn}; +use serde_json::Value; use crate::modes::server::common::AppState; -use crate::db; use super::tools::{KeepTools, ToolError}; #[derive(Clone)] @@ -29,158 +16,19 @@ impl KeepMcpServer { } } -#[rmcp::async_trait] -impl ServerHandler for KeepMcpServer { - async fn initialize(&self, params: InitializeParams) -> Result { - debug!("MCP: Initializing Keep MCP server with client info: {:?}", params.client_info); - - Ok(InitializeResult { - protocol_version: "2024-11-05".to_string(), - capabilities: ServerCapabilities { - tools: Some(ToolsCapability { - list_changed: Some(false), - }), - ..Default::default() - }, - server_info: rmcp::protocol::ServerInfo { - name: "keep".to_string(), - version: "0.1.0".to_string(), - }, - }) - } - - async fn list_tools(&self, _params: ListToolsParams) -> Result { - debug!("MCP: Listing available tools"); - - let tools = vec![ - Tool { - name: "save_item".to_string(), - description: Some("Save content as a new item with optional tags and metadata".to_string()), - input_schema: serde_json::json!({ - "type": "object", - "properties": { - "content": { - "type": "string", - "description": "The content to save" - }, - "tags": { - "type": "array", - "items": {"type": "string"}, - "description": "Optional tags to associate with the item" - }, - "metadata": { - "type": "object", - "additionalProperties": {"type": "string"}, - "description": "Optional metadata key-value pairs" - } - }, - "required": ["content"] - }), - }, - Tool { - name: "get_item".to_string(), - description: Some("Retrieve an item by ID".to_string()), - input_schema: serde_json::json!({ - "type": "object", - "properties": { - "id": { - "type": "integer", - "description": "The ID of the item to retrieve" - } - }, - "required": ["id"] - }), - }, - Tool { - name: "get_latest_item".to_string(), - description: Some("Retrieve the most recently saved item, optionally filtered by tags".to_string()), - input_schema: serde_json::json!({ - "type": "object", - "properties": { - "tags": { - "type": "array", - "items": {"type": "string"}, - "description": "Optional tags to filter by - returns latest item with ALL specified tags" - } - } - }), - }, - Tool { - name: "list_items".to_string(), - description: Some("List stored items with optional filtering and pagination".to_string()), - input_schema: serde_json::json!({ - "type": "object", - "properties": { - "tags": { - "type": "array", - "items": {"type": "string"}, - "description": "Optional tags to filter by" - }, - "limit": { - "type": "integer", - "description": "Maximum number of items to return (default: 10)" - }, - "offset": { - "type": "integer", - "description": "Number of items to skip (default: 0)" - } - } - }), - }, - Tool { - name: "search_items".to_string(), - description: Some("Search items by tags and metadata".to_string()), - input_schema: serde_json::json!({ - "type": "object", - "properties": { - "tags": { - "type": "array", - "items": {"type": "string"}, - "description": "Tags that items must have (AND operation)" - }, - "metadata": { - "type": "object", - "additionalProperties": {"type": "string"}, - "description": "Metadata key-value pairs that items must match" - } - } - }), - }, - ]; - - Ok(ListToolsResult { tools }) - } - - async fn call_tool(&self, params: CallToolParams) -> Result { - debug!("MCP: Calling tool '{}' with arguments: {:?}", params.name, params.arguments); +impl KeepMcpServer { + pub async fn handle_request(&self, method: &str, params: Option) -> Result { + debug!("MCP: Handling request '{}' with params: {:?}", method, params); let tools = KeepTools::new(self.state.clone()); - let result = match params.name.as_str() { - "save_item" => tools.save_item(params.arguments).await, - "get_item" => tools.get_item(params.arguments).await, - "get_latest_item" => tools.get_latest_item(params.arguments).await, - "list_items" => tools.list_items(params.arguments).await, - "search_items" => tools.search_items(params.arguments).await, - _ => Err(ToolError::UnknownTool(params.name.clone())), - }; - - match result { - Ok(content) => Ok(CallToolResult { - content: vec![Content::Text(TextContent { - text: content, - })], - is_error: Some(false), - }), - Err(e) => { - warn!("MCP: Tool execution failed: {}", e); - Ok(CallToolResult { - content: vec![Content::Text(TextContent { - text: format!("Error: {}", e), - })], - is_error: Some(true), - }) - } + match method { + "save_item" => tools.save_item(params).await, + "get_item" => tools.get_item(params).await, + "get_latest_item" => tools.get_latest_item(params).await, + "list_items" => tools.list_items(params).await, + "search_items" => tools.search_items(params).await, + _ => Err(ToolError::UnknownTool(method.to_string())), } } } diff --git a/src/modes/server/mcp/tools.rs b/src/modes/server/mcp/tools.rs index a849ebe..62bfcfd 100644 --- a/src/modes/server/mcp/tools.rs +++ b/src/modes/server/mcp/tools.rs @@ -1,7 +1,7 @@ use anyhow::{Result, anyhow}; use serde_json::Value; use std::collections::HashMap; -use std::io::Write; +use std::io::{Write, Read}; use std::str::FromStr; use log::{debug, warn}; @@ -22,6 +22,8 @@ pub enum ToolError { Io(#[from] std::io::Error), #[error("JSON error: {0}")] Json(#[from] serde_json::Error), + #[error("Parse error: {0}")] + Parse(#[from] strum::ParseError), #[error("Other error: {0}")] Other(#[from] anyhow::Error), } @@ -106,10 +108,16 @@ impl KeepTools { let mut item_path = self.state.data_dir.clone(); item_path.push(item_id.to_string()); - if let Err(e) = plugin.process_file(&*conn, item_id, &item_path) { - warn!("Failed to process file with plugin {:?}: {}", plugin_type, e); - continue; - } + // Process the file content through the plugin + let mut item_path = self.state.data_dir.clone(); + item_path.push(item_id.to_string()); + + let compression_engine = get_compression_engine(CompressionType::LZ4)?; + let mut reader = compression_engine.open(item_path)?; + let mut buffer = Vec::new(); + reader.read_to_end(&mut buffer)?; + + plugin.update(&buffer, &*conn); if let Err(e) = plugin.finalize(&*conn) { warn!("Failed to finalize plugin {:?}: {}", plugin_type, e); @@ -136,12 +144,12 @@ impl KeepTools { let mut item_path = self.state.data_dir.clone(); item_path.push(item_id.to_string()); - let compression_type = crate::compression_engine::CompressionType::from_str(&item.compression)?; + let compression_type = CompressionType::from_str(&item.compression)?; let compression_engine = get_compression_engine(compression_type)?; let mut reader = compression_engine.open(item_path)?; let mut content = String::new(); - std::io::Read::read_to_string(&mut reader, &mut content)?; + reader.read_to_string(&mut content)?; // Get metadata and tags let tags = db::get_item_tags(&mut *conn, &item)?; @@ -182,12 +190,12 @@ impl KeepTools { let mut item_path = self.state.data_dir.clone(); item_path.push(item_id.to_string()); - let compression_type = crate::compression_engine::CompressionType::from_str(&item.compression)?; + let compression_type = CompressionType::from_str(&item.compression)?; let compression_engine = get_compression_engine(compression_type)?; let mut reader = compression_engine.open(item_path)?; let mut content = String::new(); - std::io::Read::read_to_string(&mut reader, &mut content)?; + reader.read_to_string(&mut content)?; // Get metadata and tags let tags = db::get_item_tags(&mut *conn, &item)?; @@ -254,7 +262,7 @@ impl KeepTools { .map(|tags| tags.iter().map(|t| &t.name).collect::>()) .unwrap_or_default(); let item_meta = meta_map.get(&item_id) - .map(|meta| meta.iter().map(|m| (&m.name, &m.value)).collect::>()) + .cloned() .unwrap_or_default(); serde_json::json!({ @@ -318,7 +326,7 @@ impl KeepTools { .map(|tags| tags.iter().map(|t| &t.name).collect::>()) .unwrap_or_default(); let item_meta = meta_map.get(&item_id) - .map(|meta| meta.iter().map(|m| (&m.name, &m.value)).collect::>()) + .cloned() .unwrap_or_default(); serde_json::json!({