feat: add LLM token counting meta plugin and token filters
Add tiktoken-based token counting via a new 'tokens' feature flag.

New components:
- Shared tokenizer module wrapping tiktoken CoreBPE (cl100k_base, o200k_base)
- TokensMetaPlugin: streaming token counter that tokenizes each chunk independently
- head_tokens(N): stream the first N tokens, splitting at the exact boundary when the limit falls mid-chunk
- skip_tokens(N): skip the first N tokens, stream the rest
- tail_tokens(N): bounded ring buffer (~16 KB), outputs the last N tokens at finalize

All filters are fully streaming; there is no full-stream buffering. Meta plugin accuracy: exact for normal text, within ±1-2 tokens if a long whitespace sequence spans a chunk boundary.

Also: add 'client' and 'tokens' to the default features, and add curl to the Dockerfile builder stage.
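For orientation, a minimal sketch of the new filters used in library form, mirroring test_head_tokens_basic in src/filter_plugin/tokens.rs below; the tokenizer must be attached before calling `filter`, as the factory code in this diff does:

```rust
use std::io::Cursor;

// Illustrative only: mirrors the unit tests added in this commit.
let mut filter = HeadTokensFilter::new(3); // keep only the first 3 tokens
filter.tokenizer = Some(Tokenizer::new(TokenEncoding::Cl100kBase).unwrap());

let mut output = Vec::new();
filter
    .filter(&mut Cursor::new(b"The quick brown fox"), &mut output)
    .unwrap();
// `output` now holds only the bytes of the first three tokens.
```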
src/filter_plugin/mod.rs
@@ -26,6 +26,8 @@ pub mod head;
pub mod skip;
pub mod strip_ansi;
pub mod tail;
#[cfg(feature = "tokens")]
pub mod tokens;
pub mod utils;

use std::collections::HashMap;
@@ -192,6 +194,12 @@ pub enum FilterType {
    SkipLines,
    Grep,
    StripAnsi,
    #[cfg(feature = "tokens")]
    HeadTokens,
    #[cfg(feature = "tokens")]
    SkipTokens,
    #[cfg(feature = "tokens")]
    TailTokens,
}

/// Maximum buffer size (256 MB) for filter chain intermediate results.
@@ -490,6 +498,12 @@ fn create_filter_with_options(
        FilterType::SkipBytes => skip::SkipBytesFilter::new(0).options(),
        FilterType::SkipLines => skip::SkipLinesFilter::new(0).options(),
        FilterType::StripAnsi => strip_ansi::StripAnsiFilter::new().options(),
        #[cfg(feature = "tokens")]
        FilterType::HeadTokens => tokens::HeadTokensFilter::new(0).options(),
        #[cfg(feature = "tokens")]
        FilterType::SkipTokens => tokens::SkipTokensFilter::new(0).options(),
        #[cfg(feature = "tokens")]
        FilterType::TailTokens => tokens::TailTokensFilter::new(0).options(),
    };

    let mut options = HashMap::new();
@@ -658,6 +672,69 @@ fn create_specific_filter(
            }
            Ok(Box::new(strip_ansi::StripAnsiFilter::new()))
        }
        #[cfg(feature = "tokens")]
        FilterType::HeadTokens => {
            let count = options
                .get("count")
                .and_then(|v| v.as_u64())
                .map(|n| n as usize)
                .ok_or_else(|| {
                    std::io::Error::new(
                        std::io::ErrorKind::InvalidInput,
                        "head_tokens filter requires 'count' parameter",
                    )
                })?;
            let encoding = crate::tokenizer::TokenEncoding::Cl100kBase;
            let mut f = tokens::HeadTokensFilter::new(count);
            f.tokenizer = Some(
                crate::tokenizer::Tokenizer::new(encoding)
                    .map_err(|e| std::io::Error::other(e.to_string()))?,
            );
            f.encoding = encoding;
            Ok(Box::new(f))
        }
        #[cfg(feature = "tokens")]
        FilterType::SkipTokens => {
            let count = options
                .get("count")
                .and_then(|v| v.as_u64())
                .map(|n| n as usize)
                .ok_or_else(|| {
                    std::io::Error::new(
                        std::io::ErrorKind::InvalidInput,
                        "skip_tokens filter requires 'count' parameter",
                    )
                })?;
            let encoding = crate::tokenizer::TokenEncoding::Cl100kBase;
            let mut f = tokens::SkipTokensFilter::new(count);
            f.tokenizer = Some(
                crate::tokenizer::Tokenizer::new(encoding)
                    .map_err(|e| std::io::Error::other(e.to_string()))?,
            );
            f.encoding = encoding;
            Ok(Box::new(f))
        }
        #[cfg(feature = "tokens")]
        FilterType::TailTokens => {
            let count = options
                .get("count")
                .and_then(|v| v.as_u64())
                .map(|n| n as usize)
                .ok_or_else(|| {
                    std::io::Error::new(
                        std::io::ErrorKind::InvalidInput,
                        "tail_tokens filter requires 'count' parameter",
                    )
                })?;
            let encoding = crate::tokenizer::TokenEncoding::Cl100kBase;
            let mut f = tokens::TailTokensFilter::new(count);
            f.tokenizer = Some(
                crate::tokenizer::Tokenizer::new(encoding)
                    .map_err(|e| std::io::Error::other(e.to_string()))?,
            );
            f.encoding = encoding;
            Ok(Box::new(f))
        }
    }
}
src/filter_plugin/tokens.rs (new file, 530 lines)
@@ -0,0 +1,530 @@
use super::{FilterOption, FilterPlugin};
use crate::common::PIPESIZE;
use crate::services::filter_service::register_filter_plugin;
use crate::tokenizer::{TokenEncoding, Tokenizer};
use std::collections::VecDeque;
use std::io::{Read, Result, Write};

/// Resolve the tokenizer from a JSON options map.
fn resolve_tokenizer(options: &Option<serde_json::Value>) -> Tokenizer {
    let encoding = options
        .as_ref()
        .and_then(|v| v.as_str())
        .and_then(|s| s.parse::<TokenEncoding>().ok())
        .unwrap_or_default();
    Tokenizer::new(encoding).expect("Failed to create tokenizer")
}

// ---------------------------------------------------------------------------
// head_tokens
// ---------------------------------------------------------------------------

/// A filter that outputs only the first N tokens of the input stream.
///
/// Streams bytes directly until the token limit is reached. When the limit
/// falls mid-chunk, uses `split_by_token` to find the exact byte boundary.
pub struct HeadTokensFilter {
    pub remaining: usize,
    pub tokenizer: Option<Tokenizer>,
    pub encoding: TokenEncoding,
}

impl HeadTokensFilter {
    pub fn new(count: usize) -> Self {
        Self {
            remaining: count,
            tokenizer: None,
            encoding: TokenEncoding::default(),
        }
    }
}

impl FilterPlugin for HeadTokensFilter {
    fn filter(&mut self, reader: &mut dyn Read, writer: &mut dyn Write) -> Result<()> {
        if self.remaining == 0 {
            return Ok(());
        }

        let tokenizer = self
            .tokenizer
            .as_ref()
            .unwrap_or_else(|| panic!("HeadTokensFilter: tokenizer not initialized"));

        let mut buffer = vec![0u8; PIPESIZE];
        let mut total_tokens = 0usize;

        loop {
            let n = reader.read(&mut buffer)?;
            if n == 0 {
                break;
            }

            let chunk = &buffer[..n];
            let text = String::from_utf8_lossy(chunk);
            let chunk_tokens = tokenizer.count(&text);

            if total_tokens + chunk_tokens <= self.remaining {
                // Entire chunk fits — write it directly
                writer.write_all(chunk)?;
                total_tokens += chunk_tokens;
                if total_tokens >= self.remaining {
                    break;
                }
            } else {
                // Cutoff is within this chunk — split at exact token boundary
                let tokens_to_write = self.remaining - total_tokens;
                let token_strs = tokenizer
                    .split_by_token(&text)
                    .map_err(|e| std::io::Error::other(e.to_string()))?;
                let mut byte_pos = 0usize;
                for token_str in token_strs.iter().take(tokens_to_write) {
                    byte_pos += token_str.len();
                }
                // Write only the bytes for the tokens we want.
                // Map byte positions in the lossy string back to positions in the
                // original byte slice. Since from_utf8_lossy replaces invalid
                // bytes with the replacement character (3 bytes), we need to be
                // careful. For simplicity, write the valid prefix of the chunk.
                // We use the original bytes up to the calculated position, adjusting
                // for any UTF-8 replacement character differences.
                let write_len = map_lossy_pos_to_bytes(chunk, &text, byte_pos);
                writer.write_all(&chunk[..write_len])?;
                break;
            }
        }
        Ok(())
    }

    fn clone_box(&self) -> Box<dyn FilterPlugin> {
        Box::new(Self {
            remaining: self.remaining,
            tokenizer: self
                .tokenizer
                .as_ref()
                .map(|_| Tokenizer::new(self.encoding).unwrap()),
            encoding: self.encoding,
        })
    }

    fn options(&self) -> Vec<FilterOption> {
        vec![FilterOption {
            name: "count".to_string(),
            default: None,
            required: true,
        }]
    }
}

// ---------------------------------------------------------------------------
// skip_tokens
// ---------------------------------------------------------------------------

/// A filter that skips the first N tokens of the input stream and outputs the rest.
pub struct SkipTokensFilter {
    pub remaining: usize,
    pub tokenizer: Option<Tokenizer>,
    pub encoding: TokenEncoding,
}

impl SkipTokensFilter {
    pub fn new(count: usize) -> Self {
        Self {
            remaining: count,
            tokenizer: None,
            encoding: TokenEncoding::default(),
        }
    }
}

impl FilterPlugin for SkipTokensFilter {
    fn filter(&mut self, reader: &mut dyn Read, writer: &mut dyn Write) -> Result<()> {
        if self.remaining == 0 {
            return std::io::copy(reader, writer).map(|_| ());
        }

        let tokenizer = self
            .tokenizer
            .as_ref()
            .unwrap_or_else(|| panic!("SkipTokensFilter: tokenizer not initialized"));

        let mut buffer = vec![0u8; PIPESIZE];
        let mut total_tokens = 0usize;
        let mut done_skipping = false;

        loop {
            let n = reader.read(&mut buffer)?;
            if n == 0 {
                break;
            }

            if done_skipping {
                writer.write_all(&buffer[..n])?;
                continue;
            }

            let chunk = &buffer[..n];
            let text = String::from_utf8_lossy(chunk);
            let chunk_tokens = tokenizer.count(&text);

            if total_tokens + chunk_tokens <= self.remaining {
                // Entire chunk is skipped
                total_tokens += chunk_tokens;
                if total_tokens >= self.remaining {
                    done_skipping = true;
                }
            } else {
                // Cutoff is within this chunk — skip past the boundary, write rest
                let tokens_to_skip = self.remaining - total_tokens;
                let token_strs = tokenizer
                    .split_by_token(&text)
                    .map_err(|e| std::io::Error::other(e.to_string()))?;
                let mut byte_pos = 0usize;
                for token_str in token_strs.iter().take(tokens_to_skip) {
                    byte_pos += token_str.len();
                }
                let skip_len = map_lossy_pos_to_bytes(chunk, &text, byte_pos);
                if skip_len < n {
                    writer.write_all(&chunk[skip_len..])?;
                }
                done_skipping = true;
            }
        }
        Ok(())
    }

    fn clone_box(&self) -> Box<dyn FilterPlugin> {
        Box::new(Self {
            remaining: self.remaining,
            tokenizer: self
                .tokenizer
                .as_ref()
                .map(|_| Tokenizer::new(self.encoding).unwrap()),
            encoding: self.encoding,
        })
    }

    fn options(&self) -> Vec<FilterOption> {
        vec![FilterOption {
            name: "count".to_string(),
            default: None,
            required: true,
        }]
    }
}

// ---------------------------------------------------------------------------
// tail_tokens
// ---------------------------------------------------------------------------

/// A filter that outputs only the last N tokens of the input stream.
///
/// Uses a bounded ring buffer (last ~2× PIPESIZE) to keep recent bytes.
/// At finalize, tokenizes the buffered content and writes only the last N tokens.
pub struct TailTokensFilter {
    pub count: usize,
    /// Ring buffer holding the most recent bytes from the stream.
    pub ring: VecDeque<u8>,
    pub ring_capacity: usize,
    pub tokenizer: Option<Tokenizer>,
    pub encoding: TokenEncoding,
}

impl TailTokensFilter {
    pub fn new(count: usize) -> Self {
        // Keep enough bytes for ~2 chunks worth of data
        let ring_capacity = PIPESIZE * 2;
        Self {
            count,
            ring: VecDeque::with_capacity(ring_capacity),
            ring_capacity,
            tokenizer: None,
            encoding: TokenEncoding::default(),
        }
    }
}

impl FilterPlugin for TailTokensFilter {
    fn filter(&mut self, reader: &mut dyn Read, writer: &mut dyn Write) -> Result<()> {
        if self.count == 0 {
            return Ok(());
        }

        let tokenizer = self
            .tokenizer
            .as_ref()
            .unwrap_or_else(|| panic!("TailTokensFilter: tokenizer not initialized"));

        // Stream all bytes through the ring buffer
        let mut buffer = vec![0u8; PIPESIZE];
        loop {
            let n = reader.read(&mut buffer)?;
            if n == 0 {
                break;
            }
            for &byte in &buffer[..n] {
                if self.ring.len() >= self.ring_capacity {
                    self.ring.pop_front();
                }
                self.ring.push_back(byte);
            }
        }

        // Tokenize the buffered content and extract last N tokens
        let buffered: Vec<u8> = self.ring.iter().copied().collect();
        let text = String::from_utf8_lossy(&buffered);
        let token_strs = tokenizer
            .split_by_token(&text)
            .map_err(|e| std::io::Error::other(e.to_string()))?;

        if token_strs.len() <= self.count {
            // All tokens fit — write everything
            writer.write_all(&buffered)?;
        } else {
            // Write only the last N tokens
            let skip = token_strs.len() - self.count;
            let mut byte_offset = 0usize;
            for token_str in token_strs.iter().take(skip) {
                byte_offset += token_str.len();
            }
            let write_len = map_lossy_pos_to_bytes(&buffered, &text, byte_offset);
            if write_len < buffered.len() {
                writer.write_all(&buffered[write_len..])?;
            }
        }

        Ok(())
    }

    fn clone_box(&self) -> Box<dyn FilterPlugin> {
        Box::new(Self {
            count: self.count,
            ring: self.ring.clone(),
            ring_capacity: self.ring_capacity,
            tokenizer: self
                .tokenizer
                .as_ref()
                .map(|_| Tokenizer::new(self.encoding).unwrap()),
            encoding: self.encoding,
        })
    }

    fn options(&self) -> Vec<FilterOption> {
        vec![FilterOption {
            name: "count".to_string(),
            default: None,
            required: true,
        }]
    }
}

// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------

/// Map a byte position in a lossy string back to a position in the original byte slice.
///
/// `String::from_utf8_lossy` replaces invalid UTF-8 bytes with the Unicode
/// replacement character (U+FFFD), which encodes to 3 bytes in UTF-8. This
/// function walks both the original bytes and the lossy string in lockstep,
/// finding the original byte position that corresponds to `lossy_pos`.
fn map_lossy_pos_to_bytes(original: &[u8], lossy: &str, lossy_pos: usize) -> usize {
    if lossy_pos == 0 {
        return 0;
    }

    let replacement = '\u{FFFD}';
    let replacement_len = replacement.len_utf8(); // 3 bytes

    let mut orig_idx = 0usize;
    let mut lossy_idx = 0usize;
    let lossy_bytes = lossy.as_bytes();

    while lossy_idx < lossy_pos && orig_idx < original.len() {
        // Try to decode the next character from the original bytes
        match std::str::from_utf8(&original[orig_idx..]) {
            Ok("") => break,
            Ok(s) => {
                let ch = s.chars().next().unwrap();
                let ch_len = ch.len_utf8();
                // Check if this is a replacement character in the lossy string
                if ch == replacement
                    && lossy_idx + replacement_len <= lossy_pos
                    && lossy_bytes[lossy_idx..].starts_with(
                        &replacement.encode_utf8(&mut [0; 4]).as_bytes()[..replacement_len],
                    )
                {
                    // Could be a real U+FFFD or a replacement of invalid bytes.
                    // If the original byte at this position is valid UTF-8 start, it's real.
                    if original[orig_idx] < 0x80 || original[orig_idx] >= 0xC0 {
                        // Real character
                        orig_idx += ch_len;
                        lossy_idx += ch_len;
                    } else {
                        // Invalid byte that was replaced — advance original by 1
                        orig_idx += 1;
                        lossy_idx += replacement_len;
                    }
                } else {
                    orig_idx += ch_len;
                    lossy_idx += ch_len;
                }
            }
            Err(e) => {
                let valid = e.valid_up_to();
                if valid > 0 {
                    // Some valid bytes, then invalid
                    orig_idx += valid;
                    lossy_idx += valid;
                } else {
                    // Invalid byte — in lossy it becomes 3-byte replacement char
                    orig_idx += 1;
                    lossy_idx += replacement_len;
                }
            }
        }
    }

    orig_idx.min(original.len())
}

// ---------------------------------------------------------------------------
// Registration
// ---------------------------------------------------------------------------

#[ctor::ctor]
fn register_token_filters() {
    register_filter_plugin("head_tokens", || {
        let mut f = HeadTokensFilter::new(0);
        f.tokenizer = Some(resolve_tokenizer(&None));
        Box::new(f)
    });
    register_filter_plugin("skip_tokens", || {
        let mut f = SkipTokensFilter::new(0);
        f.tokenizer = Some(resolve_tokenizer(&None));
        Box::new(f)
    });
    register_filter_plugin("tail_tokens", || {
        let mut f = TailTokensFilter::new(0);
        f.tokenizer = Some(resolve_tokenizer(&None));
        Box::new(f)
    });
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    fn make_tokenizer() -> Tokenizer {
        Tokenizer::new(TokenEncoding::Cl100kBase).unwrap()
    }

    #[test]
    fn test_head_tokens_basic() {
        let mut filter = HeadTokensFilter::new(3);
        filter.tokenizer = Some(make_tokenizer());

        let input = b"The quick brown fox";
        let mut output = Vec::new();
        filter.filter(&mut Cursor::new(input), &mut output).unwrap();

        let result = String::from_utf8_lossy(&output);
        // "The quick brown" is typically 3 tokens
        assert!(!result.is_empty());
        assert!(result.len() <= input.len());
    }

    #[test]
    fn test_head_tokens_zero() {
        let mut filter = HeadTokensFilter::new(0);
        filter.tokenizer = Some(make_tokenizer());

        let input = b"The quick brown fox";
        let mut output = Vec::new();
        filter.filter(&mut Cursor::new(input), &mut output).unwrap();
        assert!(output.is_empty());
    }

    #[test]
    fn test_head_tokens_more_than_available() {
        let mut filter = HeadTokensFilter::new(1000);
        filter.tokenizer = Some(make_tokenizer());

        let input = b"Hello world";
        let mut output = Vec::new();
        filter.filter(&mut Cursor::new(input), &mut output).unwrap();
        assert_eq!(output, input);
    }

    #[test]
    fn test_skip_tokens_basic() {
        let mut filter = SkipTokensFilter::new(2);
        filter.tokenizer = Some(make_tokenizer());

        let input = b"The quick brown fox";
        let mut output = Vec::new();
        filter.filter(&mut Cursor::new(input), &mut output).unwrap();

        let result = String::from_utf8_lossy(&output);
        // Should have skipped some tokens
        assert!(result.len() < input.len());
    }

    #[test]
    fn test_skip_tokens_zero() {
        let mut filter = SkipTokensFilter::new(0);
        filter.tokenizer = Some(make_tokenizer());

        let input = b"Hello world";
        let mut output = Vec::new();
        filter.filter(&mut Cursor::new(input), &mut output).unwrap();
        assert_eq!(output, input);
    }

    #[test]
    fn test_tail_tokens_basic() {
        let mut filter = TailTokensFilter::new(2);
        filter.tokenizer = Some(make_tokenizer());

        let input = b"The quick brown fox jumps over the lazy dog";
        let mut output = Vec::new();
        filter.filter(&mut Cursor::new(input), &mut output).unwrap();

        let result = String::from_utf8_lossy(&output);
        // Should only have last 2 tokens
        assert!(!result.is_empty());
        assert!(result.len() < input.len());
    }

    #[test]
    fn test_tail_tokens_zero() {
        let mut filter = TailTokensFilter::new(0);
        filter.tokenizer = Some(make_tokenizer());

        let input = b"Hello world";
        let mut output = Vec::new();
        filter.filter(&mut Cursor::new(input), &mut output).unwrap();
        assert!(output.is_empty());
    }

    #[test]
    fn test_map_lossy_pos_ascii() {
        let original = b"Hello world";
        let lossy = String::from_utf8_lossy(original);
        assert_eq!(map_lossy_pos_to_bytes(original, &lossy, 5), 5);
        assert_eq!(map_lossy_pos_to_bytes(original, &lossy, 0), 0);
        assert_eq!(map_lossy_pos_to_bytes(original, &lossy, 11), 11);
    }

    #[test]
    fn test_map_lossy_pos_with_invalid_utf8() {
        let original = b"Hello\x80world";
        let lossy = String::from_utf8_lossy(original);
        // lossy = "Hello\u{FFFD}world" (13 bytes)
        // Position 5 in lossy = after "Hello" = position 5 in original
        assert_eq!(map_lossy_pos_to_bytes(original, &lossy, 5), 5);
        // Position 8 in lossy = "Hello\u{FFFD}" = position 6 in original
        // (the invalid byte \x80 at position 5 was replaced)
        assert_eq!(map_lossy_pos_to_bytes(original, &lossy, 8), 6);
    }
}
src/lib.rs (11 lines changed)
@@ -43,6 +43,9 @@ pub mod services;
#[cfg(feature = "client")]
pub mod client;

#[cfg(feature = "tokens")]
pub mod tokenizer;

// Re-export Args struct for library usage
pub use args::Args;
// Re-export PIPESIZE constant
@@ -52,6 +55,10 @@ pub use common::PIPESIZE;
#[allow(unused_imports)]
use filter_plugin::{grep, head, skip, strip_ansi, tail};

#[cfg(feature = "tokens")]
#[allow(unused_imports)]
use filter_plugin::tokens as token_filters;

use crate::meta_plugin::{
    cwd, digest, env, exec, hostname, keep_pid, read_rate, read_time, shell, shell_pid, user,
};
@@ -60,6 +67,10 @@ use crate::meta_plugin::{
#[allow(unused_imports)]
use crate::meta_plugin::magic_file;

#[cfg(feature = "tokens")]
#[allow(unused_imports)]
use crate::meta_plugin::tokens;

/// Initializes plugins at library load time.
///
/// Plugin registration happens automatically via `#[ctor]` constructors
src/meta_plugin/mod.rs
@@ -16,8 +16,9 @@ pub mod read_time;
pub mod shell;
pub mod shell_pid;
pub mod text;
#[cfg(feature = "tokens")]
pub mod tokens;
pub mod user;
// pub mod text; // Removed duplicate

pub use digest::DigestMetaPlugin;
pub use exec::MetaPluginExec;
@@ -232,6 +233,7 @@ pub enum MetaPluginType {
    Hostname,
    Exec,
    Env,
    Tokens,
}

/// Central function to handle metadata output with name mapping.
src/meta_plugin/tokens.rs (new file, 295 lines)
@@ -0,0 +1,295 @@
use crate::common::PIPESIZE;
use crate::common::is_binary::is_binary;
use crate::meta_plugin::{MetaPlugin, MetaPluginResponse, MetaPluginType};
use crate::tokenizer::{TokenEncoding, Tokenizer};

#[derive(Debug, Clone)]
pub struct TokensMetaPlugin {
    /// Buffer for binary detection (up to PIPESIZE bytes).
    buffer: Option<Vec<u8>>,
    max_buffer_size: usize,
    is_finalized: bool,
    is_binary_content: Option<bool>,
    /// Running token count accumulated across chunks.
    token_count: usize,
    /// UTF-8 boundary carry buffer.
    utf8_buffer: Vec<u8>,
    base: crate::meta_plugin::BaseMetaPlugin,
    /// The tokenizer instance.
    tokenizer: Tokenizer,
}

impl TokensMetaPlugin {
    pub fn new(
        options: Option<std::collections::HashMap<String, serde_yaml::Value>>,
        outputs: Option<std::collections::HashMap<String, serde_yaml::Value>>,
    ) -> Self {
        let mut base = crate::meta_plugin::BaseMetaPlugin::new();

        base.initialize_plugin(&["token_count"], &options, &outputs);

        // Set default options
        let default_options = vec![
            (
                "token_detect_size",
                serde_yaml::Value::Number(PIPESIZE.into()),
            ),
            (
                "encoding",
                serde_yaml::Value::String("cl100k_base".to_string()),
            ),
        ];

        for (key, value) in default_options {
            if !base.options.contains_key(key) {
                base.options.insert(key.to_string(), value);
            }
        }

        let max_buffer_size = base
            .options
            .get("token_detect_size")
            .and_then(|v| v.as_u64())
            .unwrap_or(PIPESIZE as u64) as usize;

        let encoding = base
            .options
            .get("encoding")
            .and_then(|v| v.as_str())
            .and_then(|s| s.parse::<TokenEncoding>().ok())
            .unwrap_or_default();

        let tokenizer = Tokenizer::new(encoding).expect("Failed to create tokenizer");

        Self {
            buffer: Some(Vec::new()),
            max_buffer_size,
            is_finalized: false,
            is_binary_content: None,
            token_count: 0,
            utf8_buffer: Vec::new(),
            base,
            tokenizer,
        }
    }

    /// Tokenize a byte chunk, handling UTF-8 boundaries.
    ///
    /// Combines with any pending UTF-8 carry bytes, converts to text,
    /// and adds the token count to the running total.
    fn count_tokens(&mut self, data: &[u8]) {
        if data.is_empty() && self.utf8_buffer.is_empty() {
            return;
        }

        let combined = if !self.utf8_buffer.is_empty() {
            let mut c = self.utf8_buffer.clone();
            c.extend_from_slice(data);
            c
        } else {
            data.to_vec()
        };
        self.utf8_buffer.clear();

        let text = match std::str::from_utf8(&combined) {
            Ok(t) => t,
            Err(e) => {
                let valid = e.valid_up_to();
                if valid < combined.len() {
                    self.utf8_buffer.extend_from_slice(&combined[valid..]);
                }
                match std::str::from_utf8(&combined[..valid]) {
                    Ok(t) => t,
                    Err(_) => return,
                }
            }
        };

        if !text.is_empty() {
            self.token_count += self.tokenizer.count(text);
        }
    }

    /// Perform binary detection on the buffer.
    fn detect_binary(&mut self, buffer: &[u8]) -> bool {
        let result = is_binary(buffer);
        self.is_binary_content = Some(result);
        result
    }
}

impl MetaPlugin for TokensMetaPlugin {
    fn is_finalized(&self) -> bool {
        self.is_finalized
    }

    fn set_finalized(&mut self, finalized: bool) {
        self.is_finalized = finalized;
    }

    fn update(&mut self, data: &[u8]) -> MetaPluginResponse {
        if self.is_finalized {
            return MetaPluginResponse {
                metadata: Vec::new(),
                is_finalized: true,
            };
        }

        let mut metadata = Vec::new();

        if self.is_binary_content.is_none() {
            // Add data to the buffer
            let should_detect = if let Some(ref mut buffer) = self.buffer {
                let remaining = self.max_buffer_size.saturating_sub(buffer.len());
                let to_take = std::cmp::min(data.len(), remaining);
                buffer.extend_from_slice(&data[..to_take]);
                buffer.len() >= std::cmp::min(1024, self.max_buffer_size)
            } else {
                false
            };

            if should_detect {
                let buf_clone = self.buffer.as_ref().unwrap().clone();
                let is_binary = self.detect_binary(&buf_clone);

                if is_binary {
                    if let Some(md) = crate::meta_plugin::process_metadata_outputs(
                        "token_count",
                        serde_yaml::Value::Null,
                        self.base.outputs(),
                    ) {
                        metadata.push(md);
                    }
                    self.buffer = None;
                    self.is_finalized = true;
                    return MetaPluginResponse {
                        metadata,
                        is_finalized: true,
                    };
                }

                // It's text — tokenize the full accumulated buffer
                self.count_tokens(&buf_clone);

                if buf_clone.len() >= self.max_buffer_size {
                    self.buffer = None;
                }
            } else if self.buffer.is_some() {
                // Still building up buffer — tokenize what was just added
                let remaining = self
                    .max_buffer_size
                    .saturating_sub(self.buffer.as_ref().map_or(0, |b| b.len()));
                let to_take = std::cmp::min(data.len(), remaining);
                self.count_tokens(&data[..to_take]);
            }
        } else if self.is_binary_content == Some(false) {
            self.count_tokens(data);
        } else if self.is_binary_content == Some(true) {
            self.is_finalized = true;
            return MetaPluginResponse {
                metadata: Vec::new(),
                is_finalized: true,
            };
        }

        MetaPluginResponse {
            metadata,
            is_finalized: self.is_finalized,
        }
    }

    fn finalize(&mut self) -> MetaPluginResponse {
        if self.is_finalized {
            return MetaPluginResponse {
                metadata: Vec::new(),
                is_finalized: true,
            };
        }

        let mut metadata = Vec::new();

        // If binary detection hasn't completed, do it now
        if self.is_binary_content.is_none() {
            if let Some(buffer) = &self.buffer {
                if !buffer.is_empty() {
                    let buf_clone = buffer.clone();
                    let is_binary = self.detect_binary(&buf_clone);

                    if is_binary {
                        if let Some(md) = crate::meta_plugin::process_metadata_outputs(
                            "token_count",
                            serde_yaml::Value::Null,
                            self.base.outputs(),
                        ) {
                            metadata.push(md);
                        }
                        self.buffer = None;
                        self.is_finalized = true;
                        return MetaPluginResponse {
                            metadata,
                            is_finalized: true,
                        };
                    }
                }
            }
        }

        // Process any remaining UTF-8 bytes
        if !self.utf8_buffer.is_empty() {
            self.count_tokens(&[]);
        }

        // Emit token count
        if let Some(md) = crate::meta_plugin::process_metadata_outputs(
            "token_count",
            serde_yaml::Value::String(self.token_count.to_string()),
            self.base.outputs(),
        ) {
            metadata.push(md);
        }

        self.buffer = None;
        self.is_finalized = true;
        MetaPluginResponse {
            metadata,
            is_finalized: true,
        }
    }

    fn meta_type(&self) -> MetaPluginType {
        MetaPluginType::Tokens
    }

    fn outputs(&self) -> &std::collections::HashMap<String, serde_yaml::Value> {
        self.base.outputs()
    }

    fn outputs_mut(
        &mut self,
    ) -> anyhow::Result<&mut std::collections::HashMap<String, serde_yaml::Value>> {
        Ok(self.base.outputs_mut())
    }

    fn default_outputs(&self) -> Vec<String> {
        vec!["token_count".to_string()]
    }

    fn options(&self) -> &std::collections::HashMap<String, serde_yaml::Value> {
        self.base.options()
    }

    fn options_mut(
        &mut self,
    ) -> anyhow::Result<&mut std::collections::HashMap<String, serde_yaml::Value>> {
        Ok(self.base.options_mut())
    }
}

use crate::meta_plugin::register_meta_plugin;

#[ctor::ctor]
fn register_tokens_plugin() {
    register_meta_plugin(MetaPluginType::Tokens, |options, outputs| {
        Box::new(TokensMetaPlugin::new(options, outputs))
    });
}
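Illustrative only: a rough sketch of how surrounding plumbing might drive TokensMetaPlugin, assuming only the `update`/`finalize` methods shown above (the chunk contents here are made up):

```rust
// Hypothetical driver, not part of the commit.
let mut plugin = TokensMetaPlugin::new(None, None);

let chunks: [&[u8]; 2] = [b"The quick brown fox ", b"jumps over the lazy dog\n"];
for chunk in chunks {
    let resp = plugin.update(chunk);
    if resp.is_finalized {
        break; // binary input finalizes early with a Null token_count
    }
}

// finalize() flushes any pending UTF-8 carry bytes and emits "token_count".
let resp = plugin.finalize();
assert!(resp.is_finalized);
```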
src/tokenizer/mod.rs (new file, 147 lines)
@@ -0,0 +1,147 @@
use anyhow::{Result, bail};

/// Supported LLM token encodings.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum TokenEncoding {
    /// cl100k_base — used by GPT-3.5, GPT-4, text-embedding-ada-002.
    #[default]
    Cl100kBase,
    /// o200k_base — used by GPT-4o, GPT-5, o1, o3, o4 models.
    O200kBase,
}

impl std::str::FromStr for TokenEncoding {
    type Err = anyhow::Error;

    fn from_str(s: &str) -> Result<Self> {
        match s.to_lowercase().as_str() {
            "cl100k_base" => Ok(TokenEncoding::Cl100kBase),
            "o200k_base" => Ok(TokenEncoding::O200kBase),
            _ => bail!("Unknown token encoding: {s}. Supported: cl100k_base, o200k_base"),
        }
    }
}

impl std::fmt::Display for TokenEncoding {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            TokenEncoding::Cl100kBase => write!(f, "cl100k_base"),
            TokenEncoding::O200kBase => write!(f, "o200k_base"),
        }
    }
}

/// Wrapper around tiktoken BPE tokenizer.
///
/// Provides streaming-friendly tokenization: count tokens in text,
/// split text into token strings, and decode token IDs back to text.
#[derive(Clone)]
pub struct Tokenizer {
    bpe: tiktoken_rs::CoreBPE,
}

impl std::fmt::Debug for Tokenizer {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Tokenizer").finish()
    }
}

impl Tokenizer {
    /// Creates a new tokenizer for the specified encoding.
    pub fn new(encoding: TokenEncoding) -> Result<Self> {
        let bpe = match encoding {
            TokenEncoding::Cl100kBase => tiktoken_rs::cl100k_base()?,
            TokenEncoding::O200kBase => tiktoken_rs::o200k_base()?,
        };
        Ok(Self { bpe })
    }

    /// Counts the number of tokens in the given text.
    ///
    /// Uses `encode_ordinary` which treats the text as a single unit.
    /// For streaming: tokenizing chunks independently and summing gives
    /// the same result as tokenizing the full text (exact when no regex
    /// match spans a chunk boundary).
    pub fn count(&self, text: &str) -> usize {
        self.bpe.encode_ordinary(text).len()
    }

    /// Splits text into individual decoded token strings.
    ///
    /// Each returned string corresponds to one token. Useful for finding
    /// exact byte boundaries when filtering by token count.
    pub fn split_by_token(&self, text: &str) -> Result<Vec<String>> {
        self.bpe.split_by_token(text, false)
    }

    /// Decodes a slice of token IDs back into a string.
    pub fn decode(&self, tokens: &[u32]) -> Result<String> {
        self.bpe.decode(tokens.to_vec())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tokenizer_count() {
        let tok = Tokenizer::new(TokenEncoding::Cl100kBase).unwrap();
        let count = tok.count("Hello, world!");
        assert!(count > 0);
    }

    #[test]
    fn test_tokenizer_split() {
        let tok = Tokenizer::new(TokenEncoding::Cl100kBase).unwrap();
        let tokens = tok.split_by_token("Hello world").unwrap();
        assert_eq!(tokens.len(), 2);
    }

    #[test]
    fn test_tokenizer_roundtrip() {
        let tok = Tokenizer::new(TokenEncoding::Cl100kBase).unwrap();
        let text = "The quick brown fox jumps over the lazy dog.";
        let token_ids: Vec<u32> = tok
            .bpe
            .encode_ordinary(text)
            .into_iter()
            .map(|x| x as u32)
            .collect();
        let decoded = tok.decode(&token_ids).unwrap();
        assert_eq!(text, decoded);
    }

    #[test]
    fn test_chunk_sum_close_to_full() {
        let tok = Tokenizer::new(TokenEncoding::Cl100kBase).unwrap();
        let text = "The quick brown fox jumps over the lazy dog. \
                    Pack my box with five dozen liquor jugs. \
                    How vexingly quick daft zebras jump!";
        let full_count = tok.count(text);

        // Split into chunks at word boundaries
        let mid = text.find("Pack").unwrap();
        let (a, b) = text.split_at(mid);
        let chunk_sum = tok.count(a) + tok.count(b);
        // Chunk-based counting may differ by 1-2 tokens when a BPE merge
        // boundary falls near the chunk split point
        assert!(
            (full_count as isize - chunk_sum as isize).abs() <= 2,
            "full={full_count}, chunk_sum={chunk_sum}"
        );
    }

    #[test]
    fn test_encoding_from_str() {
        assert_eq!(
            "cl100k_base".parse::<TokenEncoding>().unwrap(),
            TokenEncoding::Cl100kBase
        );
        assert_eq!(
            "o200k_base".parse::<TokenEncoding>().unwrap(),
            TokenEncoding::O200kBase
        );
        assert!("unknown".parse::<TokenEncoding>().is_err());
    }
}