From 914190e119904653e5c6c5bab6e3fb9022aaedf1 Mon Sep 17 00:00:00 2001 From: Andrew Phillips Date: Fri, 13 Mar 2026 16:48:31 -0300 Subject: [PATCH] feat: add LLM token counting meta plugin and token filters MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add tiktoken-based token counting via new 'tokens' feature flag. New components: - Shared tokenizer module wrapping tiktoken CoreBPE (cl100k_base, o200k_base) - TokensMetaPlugin: streaming token counter, tokenizes each chunk independently - head_tokens(N): stream first N tokens, split at exact boundary when mid-chunk - skip_tokens(N): skip first N tokens, stream the rest - tail_tokens(N): bounded ring buffer (~16KB), outputs last N tokens at finalize All filters are fully streaming — no full-stream buffering. Meta plugin accuracy: exact for normal text, ±1-2 tokens if long whitespace sequence spans a chunk boundary. Also: add 'client' and 'tokens' to default features, add curl to Dockerfile builder stage. --- Cargo.lock | 59 ++++ Cargo.toml | 6 +- Dockerfile | 2 +- src/filter_plugin/mod.rs | 77 ++++++ src/filter_plugin/tokens.rs | 530 ++++++++++++++++++++++++++++++++++++ src/lib.rs | 11 + src/meta_plugin/mod.rs | 4 +- src/meta_plugin/tokens.rs | 295 ++++++++++++++++++++ src/tokenizer/mod.rs | 147 ++++++++++ 9 files changed, 1128 insertions(+), 3 deletions(-) create mode 100644 src/filter_plugin/tokens.rs create mode 100644 src/meta_plugin/tokens.rs create mode 100644 src/tokenizer/mod.rs diff --git a/Cargo.lock b/Cargo.lock index 817b826..b72069f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -310,6 +310,21 @@ version = "0.22.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" +[[package]] +name = "bit-set" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0700ddab506f33b20a03b13996eccd309a48e5ff77d0d95926aa0210fb4e95f1" +dependencies = [ + "bit-vec", +] + +[[package]] +name = "bit-vec" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "349f9b6a179ed607305526ca489b34ad0a41aed5f7980fa90eb03160b69598fb" + [[package]] name = "bitflags" version = "1.3.2" @@ -354,6 +369,17 @@ dependencies = [ "opaque-debug", ] +[[package]] +name = "bstr" +version = "1.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "63044e1ae8e69f3b5a92c736ca6269b8d12fa7efe39bf34ddb06d102cf0e2cab" +dependencies = [ + "memchr", + "regex-automata", + "serde", +] + [[package]] name = "bumpalo" version = "3.19.0" @@ -929,6 +955,17 @@ version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7360491ce676a36bf9bb3c56c1aa791658183a54d2744120f27285738d90465a" +[[package]] +name = "fancy-regex" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "531e46835a22af56d1e3b66f04844bed63158bc094a628bec1d321d9b4c44bf2" +dependencies = [ + "bit-set", + "regex-automata", + "regex-syntax", +] + [[package]] name = "fastrand" version = "2.3.0" @@ -1598,6 +1635,7 @@ dependencies = [ "tempfile", "term", "thiserror 1.0.69", + "tiktoken-rs", "tokio", "tokio-stream", "tokio-util", @@ -2354,6 +2392,12 @@ version = "0.1.26" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "56f7d92ca342cea22a06f2121d944b4fd82af56988c270852495420f961d4ace" +[[package]] +name = "rustc-hash" +version = "1.1.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" + [[package]] name = "rustix" version = "1.0.8" @@ -2851,6 +2895,21 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "tiktoken-rs" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3a19830747d9034cd9da43a60eaa8e552dfda7712424aebf187b7a60126bae0d" +dependencies = [ + "anyhow", + "base64 0.22.1", + "bstr", + "fancy-regex", + "lazy_static", + "regex", + "rustc-hash", +] + [[package]] name = "time" version = "0.3.47" diff --git a/Cargo.toml b/Cargo.toml index 6a409bb..92e50ec 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -75,10 +75,11 @@ ureq = { version = "3", features = ["json"], optional = true } os_pipe = { version = "1", optional = true } axum-server = { version = "0.8", features = ["tls-rustls"], optional = true } jsonwebtoken = { version = "10", optional = true, features = ["aws_lc_rs"] } +tiktoken-rs = { version = "0.9", optional = true } [features] # Default features include core compression engines and swagger UI -default = ["magic", "lz4", "gzip"] +default = ["magic", "lz4", "gzip", "client", "tokens"] # Full #default = ["server", "magic", "lz4", "swagger"] @@ -113,6 +114,9 @@ client = ["dep:ureq", "dep:os_pipe"] # TLS feature (HTTPS server support) tls = ["dep:axum-server"] +# Token counting feature (LLM token support via tiktoken) +tokens = ["dep:tiktoken-rs"] + [dev-dependencies] tempfile = "3.3.0" rand = "0.8.5" diff --git a/Dockerfile b/Dockerfile index c7017e6..84c8aba 100644 --- a/Dockerfile +++ b/Dockerfile @@ -3,6 +3,7 @@ FROM rust:1.88-slim AS builder RUN apt-get update && apt-get install -y --no-install-recommends \ cmake \ + curl \ make \ gcc \ musl-tools \ @@ -16,7 +17,6 @@ WORKDIR /app # Copy manifests and fetch dependencies (cached layer) COPY Cargo.toml Cargo.lock ./ RUN mkdir src && echo 'fn main() {}' > src/main.rs && echo '' > src/lib.rs - RUN cargo fetch --target x86_64-unknown-linux-musl # Copy real source and build static binary diff --git a/src/filter_plugin/mod.rs b/src/filter_plugin/mod.rs index 579a09e..c420f58 100644 --- a/src/filter_plugin/mod.rs +++ b/src/filter_plugin/mod.rs @@ -26,6 +26,8 @@ pub mod head; pub mod skip; pub mod strip_ansi; pub mod tail; +#[cfg(feature = "tokens")] +pub mod tokens; pub mod utils; use std::collections::HashMap; @@ -192,6 +194,12 @@ pub enum FilterType { SkipLines, Grep, StripAnsi, + #[cfg(feature = "tokens")] + HeadTokens, + #[cfg(feature = "tokens")] + SkipTokens, + #[cfg(feature = "tokens")] + TailTokens, } /// Maximum buffer size (256 MB) for filter chain intermediate results. 
@@ -490,6 +498,12 @@ fn create_filter_with_options( FilterType::SkipBytes => skip::SkipBytesFilter::new(0).options(), FilterType::SkipLines => skip::SkipLinesFilter::new(0).options(), FilterType::StripAnsi => strip_ansi::StripAnsiFilter::new().options(), + #[cfg(feature = "tokens")] + FilterType::HeadTokens => tokens::HeadTokensFilter::new(0).options(), + #[cfg(feature = "tokens")] + FilterType::SkipTokens => tokens::SkipTokensFilter::new(0).options(), + #[cfg(feature = "tokens")] + FilterType::TailTokens => tokens::TailTokensFilter::new(0).options(), }; let mut options = HashMap::new(); @@ -658,6 +672,69 @@ fn create_specific_filter( } Ok(Box::new(strip_ansi::StripAnsiFilter::new())) } + #[cfg(feature = "tokens")] + FilterType::HeadTokens => { + let count = options + .get("count") + .and_then(|v| v.as_u64()) + .map(|n| n as usize) + .ok_or_else(|| { + std::io::Error::new( + std::io::ErrorKind::InvalidInput, + "head_tokens filter requires 'count' parameter", + ) + })?; + let encoding = crate::tokenizer::TokenEncoding::Cl100kBase; + let mut f = tokens::HeadTokensFilter::new(count); + f.tokenizer = Some( + crate::tokenizer::Tokenizer::new(encoding) + .map_err(|e| std::io::Error::other(e.to_string()))?, + ); + f.encoding = encoding; + Ok(Box::new(f)) + } + #[cfg(feature = "tokens")] + FilterType::SkipTokens => { + let count = options + .get("count") + .and_then(|v| v.as_u64()) + .map(|n| n as usize) + .ok_or_else(|| { + std::io::Error::new( + std::io::ErrorKind::InvalidInput, + "skip_tokens filter requires 'count' parameter", + ) + })?; + let encoding = crate::tokenizer::TokenEncoding::Cl100kBase; + let mut f = tokens::SkipTokensFilter::new(count); + f.tokenizer = Some( + crate::tokenizer::Tokenizer::new(encoding) + .map_err(|e| std::io::Error::other(e.to_string()))?, + ); + f.encoding = encoding; + Ok(Box::new(f)) + } + #[cfg(feature = "tokens")] + FilterType::TailTokens => { + let count = options + .get("count") + .and_then(|v| v.as_u64()) + .map(|n| n as usize) + .ok_or_else(|| { + std::io::Error::new( + std::io::ErrorKind::InvalidInput, + "tail_tokens filter requires 'count' parameter", + ) + })?; + let encoding = crate::tokenizer::TokenEncoding::Cl100kBase; + let mut f = tokens::TailTokensFilter::new(count); + f.tokenizer = Some( + crate::tokenizer::Tokenizer::new(encoding) + .map_err(|e| std::io::Error::other(e.to_string()))?, + ); + f.encoding = encoding; + Ok(Box::new(f)) + } } } diff --git a/src/filter_plugin/tokens.rs b/src/filter_plugin/tokens.rs new file mode 100644 index 0000000..d5a7e52 --- /dev/null +++ b/src/filter_plugin/tokens.rs @@ -0,0 +1,530 @@ +use super::{FilterOption, FilterPlugin}; +use crate::common::PIPESIZE; +use crate::services::filter_service::register_filter_plugin; +use crate::tokenizer::{TokenEncoding, Tokenizer}; +use std::collections::VecDeque; +use std::io::{Read, Result, Write}; + +/// Resolve the tokenizer from a JSON options map. +fn resolve_tokenizer(options: &Option<serde_json::Value>) -> Tokenizer { + let encoding = options + .as_ref() + .and_then(|v| v.as_str()) + .and_then(|s| s.parse::<TokenEncoding>().ok()) + .unwrap_or_default(); + Tokenizer::new(encoding).expect("Failed to create tokenizer") +} + +// --------------------------------------------------------------------------- +// head_tokens +// --------------------------------------------------------------------------- + +/// A filter that outputs only the first N tokens of the input stream. +/// +/// Streams bytes directly until the token limit is reached.
When the limit +/// falls mid-chunk, uses `split_by_token` to find the exact byte boundary. +pub struct HeadTokensFilter { + pub remaining: usize, + pub tokenizer: Option<Tokenizer>, + pub encoding: TokenEncoding, +} + +impl HeadTokensFilter { + pub fn new(count: usize) -> Self { + Self { + remaining: count, + tokenizer: None, + encoding: TokenEncoding::default(), + } + } +} + +impl FilterPlugin for HeadTokensFilter { + fn filter(&mut self, reader: &mut dyn Read, writer: &mut dyn Write) -> Result<()> { + if self.remaining == 0 { + return Ok(()); + } + + let tokenizer = self + .tokenizer + .as_ref() + .unwrap_or_else(|| panic!("HeadTokensFilter: tokenizer not initialized")); + + let mut buffer = vec![0u8; PIPESIZE]; + let mut total_tokens = 0usize; + + loop { + let n = reader.read(&mut buffer)?; + if n == 0 { + break; + } + + let chunk = &buffer[..n]; + let text = String::from_utf8_lossy(chunk); + let chunk_tokens = tokenizer.count(&text); + + if total_tokens + chunk_tokens <= self.remaining { + // Entire chunk fits — write it directly + writer.write_all(chunk)?; + total_tokens += chunk_tokens; + if total_tokens >= self.remaining { + break; + } + } else { + // Cutoff is within this chunk — split at exact token boundary + let tokens_to_write = self.remaining - total_tokens; + let token_strs = tokenizer + .split_by_token(&text) + .map_err(|e| std::io::Error::other(e.to_string()))?; + let mut byte_pos = 0usize; + for token_str in token_strs.iter().take(tokens_to_write) { + byte_pos += token_str.len(); + } + // Write only the bytes for the tokens we want. `byte_pos` is an + // offset into the lossy string; map it back to an offset in the + // original chunk, since from_utf8_lossy replaces each invalid + // byte with the 3-byte replacement character. + let write_len = map_lossy_pos_to_bytes(chunk, &text, byte_pos); + writer.write_all(&chunk[..write_len])?; + break; + } + } + Ok(()) + } + + fn clone_box(&self) -> Box<dyn FilterPlugin> { + Box::new(Self { + remaining: self.remaining, + tokenizer: self + .tokenizer + .as_ref() + .map(|_| Tokenizer::new(self.encoding).unwrap()), + encoding: self.encoding, + }) + } + + fn options(&self) -> Vec<FilterOption> { + vec![FilterOption { + name: "count".to_string(), + default: None, + required: true, + }] + } +} + +// --------------------------------------------------------------------------- +// skip_tokens +// --------------------------------------------------------------------------- + +/// A filter that skips the first N tokens of the input stream and outputs the rest.
+pub struct SkipTokensFilter { + pub remaining: usize, + pub tokenizer: Option<Tokenizer>, + pub encoding: TokenEncoding, +} + +impl SkipTokensFilter { + pub fn new(count: usize) -> Self { + Self { + remaining: count, + tokenizer: None, + encoding: TokenEncoding::default(), + } + } +} + +impl FilterPlugin for SkipTokensFilter { + fn filter(&mut self, reader: &mut dyn Read, writer: &mut dyn Write) -> Result<()> { + if self.remaining == 0 { + return std::io::copy(reader, writer).map(|_| ()); + } + + let tokenizer = self + .tokenizer + .as_ref() + .unwrap_or_else(|| panic!("SkipTokensFilter: tokenizer not initialized")); + + let mut buffer = vec![0u8; PIPESIZE]; + let mut total_tokens = 0usize; + let mut done_skipping = false; + + loop { + let n = reader.read(&mut buffer)?; + if n == 0 { + break; + } + + if done_skipping { + writer.write_all(&buffer[..n])?; + continue; + } + + let chunk = &buffer[..n]; + let text = String::from_utf8_lossy(chunk); + let chunk_tokens = tokenizer.count(&text); + + if total_tokens + chunk_tokens <= self.remaining { + // Entire chunk is skipped + total_tokens += chunk_tokens; + if total_tokens >= self.remaining { + done_skipping = true; + } + } else { + // Cutoff is within this chunk — skip past the boundary, write rest + let tokens_to_skip = self.remaining - total_tokens; + let token_strs = tokenizer + .split_by_token(&text) + .map_err(|e| std::io::Error::other(e.to_string()))?; + let mut byte_pos = 0usize; + for token_str in token_strs.iter().take(tokens_to_skip) { + byte_pos += token_str.len(); + } + let skip_len = map_lossy_pos_to_bytes(chunk, &text, byte_pos); + if skip_len < n { + writer.write_all(&chunk[skip_len..])?; + } + done_skipping = true; + } + } + Ok(()) + } + + fn clone_box(&self) -> Box<dyn FilterPlugin> { + Box::new(Self { + remaining: self.remaining, + tokenizer: self + .tokenizer + .as_ref() + .map(|_| Tokenizer::new(self.encoding).unwrap()), + encoding: self.encoding, + }) + } + + fn options(&self) -> Vec<FilterOption> { + vec![FilterOption { + name: "count".to_string(), + default: None, + required: true, + }] + } +} + +// --------------------------------------------------------------------------- +// tail_tokens
// --------------------------------------------------------------------------- + +/// A filter that outputs only the last N tokens of the input stream. +/// +/// Uses a bounded ring buffer (last ~2× PIPESIZE) to keep recent bytes. +/// At finalize, tokenizes the buffered content and writes only the last N tokens. +pub struct TailTokensFilter { + pub count: usize, + /// Ring buffer holding the most recent bytes from the stream.
+ pub ring: VecDeque<u8>, + pub ring_capacity: usize, + pub tokenizer: Option<Tokenizer>, + pub encoding: TokenEncoding, +} + +impl TailTokensFilter { + pub fn new(count: usize) -> Self { + // Keep enough bytes for ~2 chunks worth of data + let ring_capacity = PIPESIZE * 2; + Self { + count, + ring: VecDeque::with_capacity(ring_capacity), + ring_capacity, + tokenizer: None, + encoding: TokenEncoding::default(), + } + } +} + +impl FilterPlugin for TailTokensFilter { + fn filter(&mut self, reader: &mut dyn Read, writer: &mut dyn Write) -> Result<()> { + if self.count == 0 { + return Ok(()); + } + + let tokenizer = self + .tokenizer + .as_ref() + .unwrap_or_else(|| panic!("TailTokensFilter: tokenizer not initialized")); + + // Stream all bytes through the ring buffer + let mut buffer = vec![0u8; PIPESIZE]; + loop { + let n = reader.read(&mut buffer)?; + if n == 0 { + break; + } + for &byte in &buffer[..n] { + if self.ring.len() >= self.ring_capacity { + self.ring.pop_front(); + } + self.ring.push_back(byte); + } + } + + // Tokenize the buffered content and extract last N tokens + let buffered: Vec<u8> = self.ring.iter().copied().collect(); + let text = String::from_utf8_lossy(&buffered); + let token_strs = tokenizer + .split_by_token(&text) + .map_err(|e| std::io::Error::other(e.to_string()))?; + + if token_strs.len() <= self.count { + // All tokens fit — write everything + writer.write_all(&buffered)?; + } else { + // Write only the last N tokens + let skip = token_strs.len() - self.count; + let mut byte_offset = 0usize; + for token_str in token_strs.iter().take(skip) { + byte_offset += token_str.len(); + } + let write_len = map_lossy_pos_to_bytes(&buffered, &text, byte_offset); + if write_len < buffered.len() { + writer.write_all(&buffered[write_len..])?; + } + } + + Ok(()) + } + + fn clone_box(&self) -> Box<dyn FilterPlugin> { + Box::new(Self { + count: self.count, + ring: self.ring.clone(), + ring_capacity: self.ring_capacity, + tokenizer: self + .tokenizer + .as_ref() + .map(|_| Tokenizer::new(self.encoding).unwrap()), + encoding: self.encoding, + }) + } + + fn options(&self) -> Vec<FilterOption> { + vec![FilterOption { + name: "count".to_string(), + default: None, + required: true, + }] + } +} + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +/// Map a byte position in a lossy string back to a position in the original byte slice. +/// +/// `String::from_utf8_lossy` replaces invalid UTF-8 bytes with the Unicode +/// replacement character (U+FFFD), which encodes to 3 bytes in UTF-8. This +/// function walks both the original bytes and the lossy string in lockstep, +/// finding the original byte position that corresponds to `lossy_pos`.
+fn map_lossy_pos_to_bytes(original: &[u8], lossy: &str, lossy_pos: usize) -> usize { + if lossy_pos == 0 { + return 0; + } + + let replacement = '\u{FFFD}'; + let replacement_len = replacement.len_utf8(); // 3 bytes + + let mut orig_idx = 0usize; + let mut lossy_idx = 0usize; + let lossy_bytes = lossy.as_bytes(); + + while lossy_idx < lossy_pos && orig_idx < original.len() { + // Try to decode the next character from the original bytes + match std::str::from_utf8(&original[orig_idx..]) { + Ok("") => break, + Ok(s) => { + let ch = s.chars().next().unwrap(); + let ch_len = ch.len_utf8(); + // Check if this is a replacement character in the lossy string + if ch == replacement + && lossy_idx + replacement_len <= lossy_pos + && lossy_bytes[lossy_idx..].starts_with( + &replacement.encode_utf8(&mut [0; 4]).as_bytes()[..replacement_len], + ) + { + // Could be a real U+FFFD or a replacement of invalid bytes. + // If the original byte at this position is valid UTF-8 start, it's real. + if original[orig_idx] < 0x80 || original[orig_idx] >= 0xC0 { + // Real character + orig_idx += ch_len; + lossy_idx += ch_len; + } else { + // Invalid byte that was replaced — advance original by 1 + orig_idx += 1; + lossy_idx += replacement_len; + } + } else { + orig_idx += ch_len; + lossy_idx += ch_len; + } + } + Err(e) => { + let valid = e.valid_up_to(); + if valid > 0 { + // Some valid bytes, then invalid + orig_idx += valid; + lossy_idx += valid; + } else { + // Invalid byte — in lossy it becomes 3-byte replacement char + orig_idx += 1; + lossy_idx += replacement_len; + } + } + } + } + + orig_idx.min(original.len()) +} + +// --------------------------------------------------------------------------- +// Registration +// --------------------------------------------------------------------------- + +#[ctor::ctor] +fn register_token_filters() { + register_filter_plugin("head_tokens", || { + let mut f = HeadTokensFilter::new(0); + f.tokenizer = Some(resolve_tokenizer(&None)); + Box::new(f) + }); + register_filter_plugin("skip_tokens", || { + let mut f = SkipTokensFilter::new(0); + f.tokenizer = Some(resolve_tokenizer(&None)); + Box::new(f) + }); + register_filter_plugin("tail_tokens", || { + let mut f = TailTokensFilter::new(0); + f.tokenizer = Some(resolve_tokenizer(&None)); + Box::new(f) + }); +} + +#[cfg(test)] +mod tests { + use super::*; + use std::io::Cursor; + + fn make_tokenizer() -> Tokenizer { + Tokenizer::new(TokenEncoding::Cl100kBase).unwrap() + } + + #[test] + fn test_head_tokens_basic() { + let mut filter = HeadTokensFilter::new(3); + filter.tokenizer = Some(make_tokenizer()); + + let input = b"The quick brown fox"; + let mut output = Vec::new(); + filter.filter(&mut Cursor::new(input), &mut output).unwrap(); + + let result = String::from_utf8_lossy(&output); + // "The quick brown" is typically 3 tokens + assert!(!result.is_empty()); + assert!(result.len() <= input.len()); + } + + #[test] + fn test_head_tokens_zero() { + let mut filter = HeadTokensFilter::new(0); + filter.tokenizer = Some(make_tokenizer()); + + let input = b"The quick brown fox"; + let mut output = Vec::new(); + filter.filter(&mut Cursor::new(input), &mut output).unwrap(); + assert!(output.is_empty()); + } + + #[test] + fn test_head_tokens_more_than_available() { + let mut filter = HeadTokensFilter::new(1000); + filter.tokenizer = Some(make_tokenizer()); + + let input = b"Hello world"; + let mut output = Vec::new(); + filter.filter(&mut Cursor::new(input), &mut output).unwrap(); + assert_eq!(output, input); + } + + #[test] + fn 
test_skip_tokens_basic() { + let mut filter = SkipTokensFilter::new(2); + filter.tokenizer = Some(make_tokenizer()); + + let input = b"The quick brown fox"; + let mut output = Vec::new(); + filter.filter(&mut Cursor::new(input), &mut output).unwrap(); + + let result = String::from_utf8_lossy(&output); + // Should have skipped some tokens + assert!(result.len() < input.len()); + } + + #[test] + fn test_skip_tokens_zero() { + let mut filter = SkipTokensFilter::new(0); + filter.tokenizer = Some(make_tokenizer()); + + let input = b"Hello world"; + let mut output = Vec::new(); + filter.filter(&mut Cursor::new(input), &mut output).unwrap(); + assert_eq!(output, input); + } + + #[test] + fn test_tail_tokens_basic() { + let mut filter = TailTokensFilter::new(2); + filter.tokenizer = Some(make_tokenizer()); + + let input = b"The quick brown fox jumps over the lazy dog"; + let mut output = Vec::new(); + filter.filter(&mut Cursor::new(input), &mut output).unwrap(); + + let result = String::from_utf8_lossy(&output); + // Should only have last 2 tokens + assert!(!result.is_empty()); + assert!(result.len() < input.len()); + } + + #[test] + fn test_tail_tokens_zero() { + let mut filter = TailTokensFilter::new(0); + filter.tokenizer = Some(make_tokenizer()); + + let input = b"Hello world"; + let mut output = Vec::new(); + filter.filter(&mut Cursor::new(input), &mut output).unwrap(); + assert!(output.is_empty()); + } + + #[test] + fn test_map_lossy_pos_ascii() { + let original = b"Hello world"; + let lossy = String::from_utf8_lossy(original); + assert_eq!(map_lossy_pos_to_bytes(original, &lossy, 5), 5); + assert_eq!(map_lossy_pos_to_bytes(original, &lossy, 0), 0); + assert_eq!(map_lossy_pos_to_bytes(original, &lossy, 11), 11); + } + + #[test] + fn test_map_lossy_pos_with_invalid_utf8() { + let original = b"Hello\x80world"; + let lossy = String::from_utf8_lossy(original); + // lossy = "Hello\u{FFFD}world" (13 bytes) + // Position 5 in lossy = after "Hello" = position 5 in original + assert_eq!(map_lossy_pos_to_bytes(original, &lossy, 5), 5); + // Position 8 in lossy = "Hello\u{FFFD}" = position 6 in original + // (the invalid byte \x80 at position 5 was replaced) + assert_eq!(map_lossy_pos_to_bytes(original, &lossy, 8), 6); + } +} diff --git a/src/lib.rs b/src/lib.rs index 4e2a868..80b279e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -43,6 +43,9 @@ pub mod services; #[cfg(feature = "client")] pub mod client; +#[cfg(feature = "tokens")] +pub mod tokenizer; + // Re-export Args struct for library usage pub use args::Args; // Re-export PIPESIZE constant @@ -52,6 +55,10 @@ pub use common::PIPESIZE; #[allow(unused_imports)] use filter_plugin::{grep, head, skip, strip_ansi, tail}; +#[cfg(feature = "tokens")] +#[allow(unused_imports)] +use filter_plugin::tokens as token_filters; + use crate::meta_plugin::{ cwd, digest, env, exec, hostname, keep_pid, read_rate, read_time, shell, shell_pid, user, }; @@ -60,6 +67,10 @@ use crate::meta_plugin::{ #[allow(unused_imports)] use crate::meta_plugin::magic_file; +#[cfg(feature = "tokens")] +#[allow(unused_imports)] +use crate::meta_plugin::tokens; + /// Initializes plugins at library load time. 
/// /// Plugin registration happens automatically via `#[ctor]` constructors diff --git a/src/meta_plugin/mod.rs b/src/meta_plugin/mod.rs index 03ecf8b..aaa930e 100644 --- a/src/meta_plugin/mod.rs +++ b/src/meta_plugin/mod.rs @@ -16,8 +16,9 @@ pub mod read_time; pub mod shell; pub mod shell_pid; pub mod text; +#[cfg(feature = "tokens")] +pub mod tokens; pub mod user; -// pub mod text; // Removed duplicate pub use digest::DigestMetaPlugin; pub use exec::MetaPluginExec; @@ -232,6 +233,7 @@ pub enum MetaPluginType { Hostname, Exec, Env, + Tokens, } /// Central function to handle metadata output with name mapping. diff --git a/src/meta_plugin/tokens.rs b/src/meta_plugin/tokens.rs new file mode 100644 index 0000000..fbf2e70 --- /dev/null +++ b/src/meta_plugin/tokens.rs @@ -0,0 +1,295 @@ +use crate::common::PIPESIZE; +use crate::common::is_binary::is_binary; +use crate::meta_plugin::{MetaPlugin, MetaPluginResponse, MetaPluginType}; +use crate::tokenizer::{TokenEncoding, Tokenizer}; + +#[derive(Debug, Clone)] +pub struct TokensMetaPlugin { + /// Buffer for binary detection (up to PIPESIZE bytes). + buffer: Option<Vec<u8>>, + max_buffer_size: usize, + is_finalized: bool, + is_binary_content: Option<bool>, + /// Running token count accumulated across chunks. + token_count: usize, + /// UTF-8 boundary carry buffer. + utf8_buffer: Vec<u8>, + base: crate::meta_plugin::BaseMetaPlugin, + /// The tokenizer instance. + tokenizer: Tokenizer, +} + +impl TokensMetaPlugin { + pub fn new( + options: Option>, + outputs: Option>, + ) -> Self { + let mut base = crate::meta_plugin::BaseMetaPlugin::new(); + + base.initialize_plugin(&["token_count"], &options, &outputs); + + // Set default options + let default_options = vec![ + ( + "token_detect_size", + serde_yaml::Value::Number(PIPESIZE.into()), + ), + ( + "encoding", + serde_yaml::Value::String("cl100k_base".to_string()), + ), + ]; + + for (key, value) in default_options { + if !base.options.contains_key(key) { + base.options.insert(key.to_string(), value); + } + } + + let max_buffer_size = base + .options + .get("token_detect_size") + .and_then(|v| v.as_u64()) + .unwrap_or(PIPESIZE as u64) as usize; + + let encoding = base + .options + .get("encoding") + .and_then(|v| v.as_str()) + .and_then(|s| s.parse::<TokenEncoding>().ok()) + .unwrap_or_default(); + + let tokenizer = Tokenizer::new(encoding).expect("Failed to create tokenizer"); + + Self { + buffer: Some(Vec::new()), + max_buffer_size, + is_finalized: false, + is_binary_content: None, + token_count: 0, + utf8_buffer: Vec::new(), + base, + tokenizer, + } + } + + /// Tokenize a byte chunk, handling UTF-8 boundaries. + /// + /// Combines with any pending UTF-8 carry bytes, converts to text, + /// and adds the token count to the running total. + fn count_tokens(&mut self, data: &[u8]) { + if data.is_empty() && self.utf8_buffer.is_empty() { + return; + } + + let combined = if !self.utf8_buffer.is_empty() { + let mut c = self.utf8_buffer.clone(); + c.extend_from_slice(data); + c + } else { + data.to_vec() + }; + self.utf8_buffer.clear(); + + let text = match std::str::from_utf8(&combined) { + Ok(t) => t, + Err(e) => { + let valid = e.valid_up_to(); + if valid < combined.len() { + self.utf8_buffer.extend_from_slice(&combined[valid..]); + } + match std::str::from_utf8(&combined[..valid]) { + Ok(t) => t, + Err(_) => return, + } + } + }; + + if !text.is_empty() { + self.token_count += self.tokenizer.count(text); + } + } + + /// Perform binary detection on the buffer.
+ fn detect_binary(&mut self, buffer: &[u8]) -> bool { + let result = is_binary(buffer); + self.is_binary_content = Some(result); + result + } +} + +impl MetaPlugin for TokensMetaPlugin { + fn is_finalized(&self) -> bool { + self.is_finalized + } + + fn set_finalized(&mut self, finalized: bool) { + self.is_finalized = finalized; + } + + fn update(&mut self, data: &[u8]) -> MetaPluginResponse { + if self.is_finalized { + return MetaPluginResponse { + metadata: Vec::new(), + is_finalized: true, + }; + } + + let mut metadata = Vec::new(); + + if self.is_binary_content.is_none() { + // Add data to the buffer + let should_detect = if let Some(ref mut buffer) = self.buffer { + let remaining = self.max_buffer_size.saturating_sub(buffer.len()); + let to_take = std::cmp::min(data.len(), remaining); + buffer.extend_from_slice(&data[..to_take]); + buffer.len() >= std::cmp::min(1024, self.max_buffer_size) + } else { + false + }; + + if should_detect { + let buf_clone = self.buffer.as_ref().unwrap().clone(); + let is_binary = self.detect_binary(&buf_clone); + + if is_binary { + if let Some(md) = crate::meta_plugin::process_metadata_outputs( + "token_count", + serde_yaml::Value::Null, + self.base.outputs(), + ) { + metadata.push(md); + } + self.buffer = None; + self.is_finalized = true; + return MetaPluginResponse { + metadata, + is_finalized: true, + }; + } + + // It's text — tokenize the full accumulated buffer + self.count_tokens(&buf_clone); + + if buf_clone.len() >= self.max_buffer_size { + self.buffer = None; + } + } else if self.buffer.is_some() { + // Still building up buffer — tokenize what was just added + let remaining = self + .max_buffer_size + .saturating_sub(self.buffer.as_ref().map_or(0, |b| b.len())); + let to_take = std::cmp::min(data.len(), remaining); + self.count_tokens(&data[..to_take]); + } + } else if self.is_binary_content == Some(false) { + self.count_tokens(data); + } else if self.is_binary_content == Some(true) { + self.is_finalized = true; + return MetaPluginResponse { + metadata: Vec::new(), + is_finalized: true, + }; + } + + MetaPluginResponse { + metadata, + is_finalized: self.is_finalized, + } + } + + fn finalize(&mut self) -> MetaPluginResponse { + if self.is_finalized { + return MetaPluginResponse { + metadata: Vec::new(), + is_finalized: true, + }; + } + + let mut metadata = Vec::new(); + + // If binary detection hasn't completed, do it now + if self.is_binary_content.is_none() { + if let Some(buffer) = &self.buffer { + if !buffer.is_empty() { + let buf_clone = buffer.clone(); + let is_binary = self.detect_binary(&buf_clone); + + if is_binary { + if let Some(md) = crate::meta_plugin::process_metadata_outputs( + "token_count", + serde_yaml::Value::Null, + self.base.outputs(), + ) { + metadata.push(md); + } + self.buffer = None; + self.is_finalized = true; + return MetaPluginResponse { + metadata, + is_finalized: true, + }; + } + } + } + } + + // Process any remaining UTF-8 bytes + if !self.utf8_buffer.is_empty() { + self.count_tokens(&[]); + } + + // Emit token count + if let Some(md) = crate::meta_plugin::process_metadata_outputs( + "token_count", + serde_yaml::Value::String(self.token_count.to_string()), + self.base.outputs(), + ) { + metadata.push(md); + } + + self.buffer = None; + self.is_finalized = true; + MetaPluginResponse { + metadata, + is_finalized: true, + } + } + + fn meta_type(&self) -> MetaPluginType { + MetaPluginType::Tokens + } + + fn outputs(&self) -> &std::collections::HashMap { + self.base.outputs() + } + + fn outputs_mut( + &mut self, + ) -> 
anyhow::Result<&mut std::collections::HashMap> { + Ok(self.base.outputs_mut()) + } + + fn default_outputs(&self) -> Vec<String> { + vec!["token_count".to_string()] + } + + fn options(&self) -> &std::collections::HashMap<String, serde_yaml::Value> { + self.base.options() + } + + fn options_mut( + &mut self, + ) -> anyhow::Result<&mut std::collections::HashMap<String, serde_yaml::Value>> { + Ok(self.base.options_mut()) + } +} + +use crate::meta_plugin::register_meta_plugin; + +#[ctor::ctor] +fn register_tokens_plugin() { + register_meta_plugin(MetaPluginType::Tokens, |options, outputs| { + Box::new(TokensMetaPlugin::new(options, outputs)) + }); +} diff --git a/src/tokenizer/mod.rs b/src/tokenizer/mod.rs new file mode 100644 index 0000000..ac6f2ba --- /dev/null +++ b/src/tokenizer/mod.rs @@ -0,0 +1,147 @@ +use anyhow::{Result, bail}; + +/// Supported LLM token encodings. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum TokenEncoding { + /// cl100k_base — used by GPT-3.5, GPT-4, text-embedding-ada-002. + #[default] + Cl100kBase, + /// o200k_base — used by GPT-4o, GPT-5, o1, o3, o4 models. + O200kBase, +} + +impl std::str::FromStr for TokenEncoding { + type Err = anyhow::Error; + + fn from_str(s: &str) -> Result<Self> { + match s.to_lowercase().as_str() { + "cl100k_base" => Ok(TokenEncoding::Cl100kBase), + "o200k_base" => Ok(TokenEncoding::O200kBase), + _ => bail!("Unknown token encoding: {s}. Supported: cl100k_base, o200k_base"), + } + } +} + +impl std::fmt::Display for TokenEncoding { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + TokenEncoding::Cl100kBase => write!(f, "cl100k_base"), + TokenEncoding::O200kBase => write!(f, "o200k_base"), + } + } +} + +/// Wrapper around tiktoken BPE tokenizer. +/// +/// Provides streaming-friendly tokenization: count tokens in text, +/// split text into token strings, and decode token IDs back to text. +#[derive(Clone)] +pub struct Tokenizer { + bpe: tiktoken_rs::CoreBPE, +} + +impl std::fmt::Debug for Tokenizer { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Tokenizer").finish() + } +} + +impl Tokenizer { + /// Creates a new tokenizer for the specified encoding. + pub fn new(encoding: TokenEncoding) -> Result<Self> { + let bpe = match encoding { + TokenEncoding::Cl100kBase => tiktoken_rs::cl100k_base()?, + TokenEncoding::O200kBase => tiktoken_rs::o200k_base()?, + }; + Ok(Self { bpe }) + } + + /// Counts the number of tokens in the given text. + /// + /// Uses `encode_ordinary`, which treats the text as a single unit. + /// For streaming use, tokenizing chunks independently and summing the + /// counts matches tokenizing the full text, except when a regex match + /// spans a chunk boundary. + pub fn count(&self, text: &str) -> usize { + self.bpe.encode_ordinary(text).len() + } + + /// Splits text into individual decoded token strings. + /// + /// Each returned string corresponds to one token. Useful for finding + /// exact byte boundaries when filtering by token count. + pub fn split_by_token(&self, text: &str) -> Result<Vec<String>> { + self.bpe.split_by_token(text, false) + } + + /// Decodes a slice of token IDs back into a string.
pub fn decode(&self, tokens: &[u32]) -> Result<String> { + self.bpe.decode(tokens.to_vec()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_tokenizer_count() { + let tok = Tokenizer::new(TokenEncoding::Cl100kBase).unwrap(); + let count = tok.count("Hello, world!"); + assert!(count > 0); + } + + #[test] + fn test_tokenizer_split() { + let tok = Tokenizer::new(TokenEncoding::Cl100kBase).unwrap(); + let tokens = tok.split_by_token("Hello world").unwrap(); + assert_eq!(tokens.len(), 2); + } + + #[test] + fn test_tokenizer_roundtrip() { + let tok = Tokenizer::new(TokenEncoding::Cl100kBase).unwrap(); + let text = "The quick brown fox jumps over the lazy dog."; + let token_ids: Vec<u32> = tok + .bpe + .encode_ordinary(text) + .into_iter() + .map(|x| x as u32) + .collect(); + let decoded = tok.decode(&token_ids).unwrap(); + assert_eq!(text, decoded); + } + + #[test] + fn test_chunk_sum_close_to_full() { + let tok = Tokenizer::new(TokenEncoding::Cl100kBase).unwrap(); + let text = "The quick brown fox jumps over the lazy dog. \ + Pack my box with five dozen liquor jugs. \ + How vexingly quick daft zebras jump!"; + let full_count = tok.count(text); + + // Split into chunks at word boundaries + let mid = text.find("Pack").unwrap(); + let (a, b) = text.split_at(mid); + let chunk_sum = tok.count(a) + tok.count(b); + // Chunk-based counting may differ by 1-2 tokens when a BPE merge + // boundary falls near the chunk split point + assert!( + (full_count as isize - chunk_sum as isize).abs() <= 2, + "full={full_count}, chunk_sum={chunk_sum}" + ); + } + + #[test] + fn test_encoding_from_str() { + assert_eq!( + "cl100k_base".parse::<TokenEncoding>().unwrap(), + TokenEncoding::Cl100kBase + ); + assert_eq!( + "o200k_base".parse::<TokenEncoding>().unwrap(), + TokenEncoding::O200kBase + ); + assert!("unknown".parse::<TokenEncoding>().is_err()); + } +}
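For reference, a minimal sketch of the chunk-wise counting behaviour the meta plugin relies on, using only tiktoken-rs calls that already appear in this patch (cl100k_base, encode_ordinary); the standalone main() below is illustrative only and is not part of the change:

// Sketch: whole-text token count vs. summed per-chunk counts.
fn main() -> anyhow::Result<()> {
    let bpe = tiktoken_rs::cl100k_base()?;
    let text = "The quick brown fox jumps over the lazy dog. Pack my box with five dozen liquor jugs.";

    // One pass over the full text.
    let full = bpe.encode_ordinary(text).len();

    // Two chunks counted independently, as a streaming reader would see them.
    let (a, b) = text.split_at(text.len() / 2);
    let chunked = bpe.encode_ordinary(a).len() + bpe.encode_ordinary(b).len();

    // The sums normally agree; they can drift by a token or two when a BPE
    // merge (for example a long whitespace run) would have spanned the split.
    println!("full = {full}, chunked = {chunked}");
    Ok(())
}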