feat: add LLM token counting meta plugin and token filters

Add tiktoken-based token counting via new 'tokens' feature flag.

New components:
- Shared tokenizer module wrapping tiktoken CoreBPE (cl100k_base, o200k_base)
- TokensMetaPlugin: streaming token counter, tokenizes each chunk independently
- head_tokens(N): stream first N tokens, split at exact boundary when mid-chunk
- skip_tokens(N): skip first N tokens, stream the rest
- tail_tokens(N): bounded ring buffer (~16KB), outputs last N tokens at finalize

All filters are fully streaming — no full-stream buffering.
Meta plugin accuracy: exact for normal text, ±1-2 tokens if long whitespace
sequence spans a chunk boundary.

Also: add 'client' and 'tokens' to default features, add curl to Dockerfile builder stage.
This commit is contained in:
2026-03-13 16:48:31 -03:00
parent e672ec751e
commit 914190e119
9 changed files with 1128 additions and 3 deletions

View File

@@ -26,6 +26,8 @@ pub mod head;
pub mod skip;
pub mod strip_ansi;
pub mod tail;
#[cfg(feature = "tokens")]
pub mod tokens;
pub mod utils;
use std::collections::HashMap;
@@ -192,6 +194,12 @@ pub enum FilterType {
SkipLines,
Grep,
StripAnsi,
#[cfg(feature = "tokens")]
HeadTokens,
#[cfg(feature = "tokens")]
SkipTokens,
#[cfg(feature = "tokens")]
TailTokens,
}
/// Maximum buffer size (256 MB) for filter chain intermediate results.
@@ -490,6 +498,12 @@ fn create_filter_with_options(
FilterType::SkipBytes => skip::SkipBytesFilter::new(0).options(),
FilterType::SkipLines => skip::SkipLinesFilter::new(0).options(),
FilterType::StripAnsi => strip_ansi::StripAnsiFilter::new().options(),
#[cfg(feature = "tokens")]
FilterType::HeadTokens => tokens::HeadTokensFilter::new(0).options(),
#[cfg(feature = "tokens")]
FilterType::SkipTokens => tokens::SkipTokensFilter::new(0).options(),
#[cfg(feature = "tokens")]
FilterType::TailTokens => tokens::TailTokensFilter::new(0).options(),
};
let mut options = HashMap::new();
@@ -658,6 +672,69 @@ fn create_specific_filter(
}
Ok(Box::new(strip_ansi::StripAnsiFilter::new()))
}
#[cfg(feature = "tokens")]
FilterType::HeadTokens => {
let count = options
.get("count")
.and_then(|v| v.as_u64())
.map(|n| n as usize)
.ok_or_else(|| {
std::io::Error::new(
std::io::ErrorKind::InvalidInput,
"head_tokens filter requires 'count' parameter",
)
})?;
let encoding = crate::tokenizer::TokenEncoding::Cl100kBase;
let mut f = tokens::HeadTokensFilter::new(count);
f.tokenizer = Some(
crate::tokenizer::Tokenizer::new(encoding)
.map_err(|e| std::io::Error::other(e.to_string()))?,
);
f.encoding = encoding;
Ok(Box::new(f))
}
#[cfg(feature = "tokens")]
FilterType::SkipTokens => {
let count = options
.get("count")
.and_then(|v| v.as_u64())
.map(|n| n as usize)
.ok_or_else(|| {
std::io::Error::new(
std::io::ErrorKind::InvalidInput,
"skip_tokens filter requires 'count' parameter",
)
})?;
let encoding = crate::tokenizer::TokenEncoding::Cl100kBase;
let mut f = tokens::SkipTokensFilter::new(count);
f.tokenizer = Some(
crate::tokenizer::Tokenizer::new(encoding)
.map_err(|e| std::io::Error::other(e.to_string()))?,
);
f.encoding = encoding;
Ok(Box::new(f))
}
#[cfg(feature = "tokens")]
FilterType::TailTokens => {
let count = options
.get("count")
.and_then(|v| v.as_u64())
.map(|n| n as usize)
.ok_or_else(|| {
std::io::Error::new(
std::io::ErrorKind::InvalidInput,
"tail_tokens filter requires 'count' parameter",
)
})?;
let encoding = crate::tokenizer::TokenEncoding::Cl100kBase;
let mut f = tokens::TailTokensFilter::new(count);
f.tokenizer = Some(
crate::tokenizer::Tokenizer::new(encoding)
.map_err(|e| std::io::Error::other(e.to_string()))?,
);
f.encoding = encoding;
Ok(Box::new(f))
}
}
}

530
src/filter_plugin/tokens.rs Normal file
View File

@@ -0,0 +1,530 @@
use super::{FilterOption, FilterPlugin};
use crate::common::PIPESIZE;
use crate::services::filter_service::register_filter_plugin;
use crate::tokenizer::{TokenEncoding, Tokenizer};
use std::collections::VecDeque;
use std::io::{Read, Result, Write};
/// Resolve the tokenizer from a JSON options map.
fn resolve_tokenizer(options: &Option<serde_json::Value>) -> Tokenizer {
let encoding = options
.as_ref()
.and_then(|v| v.as_str())
.and_then(|s| s.parse::<TokenEncoding>().ok())
.unwrap_or_default();
Tokenizer::new(encoding).expect("Failed to create tokenizer")
}
// ---------------------------------------------------------------------------
// head_tokens
// ---------------------------------------------------------------------------
/// A filter that outputs only the first N tokens of the input stream.
///
/// Streams bytes directly until the token limit is reached. When the limit
/// falls mid-chunk, uses `split_by_token` to find the exact byte boundary.
pub struct HeadTokensFilter {
pub remaining: usize,
pub tokenizer: Option<Tokenizer>,
pub encoding: TokenEncoding,
}
impl HeadTokensFilter {
pub fn new(count: usize) -> Self {
Self {
remaining: count,
tokenizer: None,
encoding: TokenEncoding::default(),
}
}
}
impl FilterPlugin for HeadTokensFilter {
fn filter(&mut self, reader: &mut dyn Read, writer: &mut dyn Write) -> Result<()> {
if self.remaining == 0 {
return Ok(());
}
let tokenizer = self
.tokenizer
.as_ref()
.unwrap_or_else(|| panic!("HeadTokensFilter: tokenizer not initialized"));
let mut buffer = vec![0u8; PIPESIZE];
let mut total_tokens = 0usize;
loop {
let n = reader.read(&mut buffer)?;
if n == 0 {
break;
}
let chunk = &buffer[..n];
let text = String::from_utf8_lossy(chunk);
let chunk_tokens = tokenizer.count(&text);
if total_tokens + chunk_tokens <= self.remaining {
// Entire chunk fits — write it directly
writer.write_all(chunk)?;
total_tokens += chunk_tokens;
if total_tokens >= self.remaining {
break;
}
} else {
// Cutoff is within this chunk — split at exact token boundary
let tokens_to_write = self.remaining - total_tokens;
let token_strs = tokenizer
.split_by_token(&text)
.map_err(|e| std::io::Error::other(e.to_string()))?;
let mut byte_pos = 0usize;
for token_str in token_strs.iter().take(tokens_to_write) {
byte_pos += token_str.len();
}
// Write only the bytes for the tokens we want.
// Map byte positions in the lossy string back to positions in the
// original byte slice. Since from_utf8_lossy replaces invalid
// bytes with the replacement character (3 bytes), we need to be
// careful. For simplicity, write the valid prefix of the chunk.
// We use the original bytes up to the calculated position, adjusting
// for any UTF-8 replacement character differences.
let write_len = map_lossy_pos_to_bytes(chunk, &text, byte_pos);
writer.write_all(&chunk[..write_len])?;
break;
}
}
Ok(())
}
fn clone_box(&self) -> Box<dyn FilterPlugin> {
Box::new(Self {
remaining: self.remaining,
tokenizer: self
.tokenizer
.as_ref()
.map(|_| Tokenizer::new(self.encoding).unwrap()),
encoding: self.encoding,
})
}
fn options(&self) -> Vec<FilterOption> {
vec![FilterOption {
name: "count".to_string(),
default: None,
required: true,
}]
}
}
// ---------------------------------------------------------------------------
// skip_tokens
// ---------------------------------------------------------------------------
/// A filter that skips the first N tokens of the input stream and outputs the rest.
pub struct SkipTokensFilter {
pub remaining: usize,
pub tokenizer: Option<Tokenizer>,
pub encoding: TokenEncoding,
}
impl SkipTokensFilter {
pub fn new(count: usize) -> Self {
Self {
remaining: count,
tokenizer: None,
encoding: TokenEncoding::default(),
}
}
}
impl FilterPlugin for SkipTokensFilter {
fn filter(&mut self, reader: &mut dyn Read, writer: &mut dyn Write) -> Result<()> {
if self.remaining == 0 {
return std::io::copy(reader, writer).map(|_| ());
}
let tokenizer = self
.tokenizer
.as_ref()
.unwrap_or_else(|| panic!("SkipTokensFilter: tokenizer not initialized"));
let mut buffer = vec![0u8; PIPESIZE];
let mut total_tokens = 0usize;
let mut done_skipping = false;
loop {
let n = reader.read(&mut buffer)?;
if n == 0 {
break;
}
if done_skipping {
writer.write_all(&buffer[..n])?;
continue;
}
let chunk = &buffer[..n];
let text = String::from_utf8_lossy(chunk);
let chunk_tokens = tokenizer.count(&text);
if total_tokens + chunk_tokens <= self.remaining {
// Entire chunk is skipped
total_tokens += chunk_tokens;
if total_tokens >= self.remaining {
done_skipping = true;
}
} else {
// Cutoff is within this chunk — skip past the boundary, write rest
let tokens_to_skip = self.remaining - total_tokens;
let token_strs = tokenizer
.split_by_token(&text)
.map_err(|e| std::io::Error::other(e.to_string()))?;
let mut byte_pos = 0usize;
for token_str in token_strs.iter().take(tokens_to_skip) {
byte_pos += token_str.len();
}
let skip_len = map_lossy_pos_to_bytes(chunk, &text, byte_pos);
if skip_len < n {
writer.write_all(&chunk[skip_len..])?;
}
done_skipping = true;
}
}
Ok(())
}
fn clone_box(&self) -> Box<dyn FilterPlugin> {
Box::new(Self {
remaining: self.remaining,
tokenizer: self
.tokenizer
.as_ref()
.map(|_| Tokenizer::new(self.encoding).unwrap()),
encoding: self.encoding,
})
}
fn options(&self) -> Vec<FilterOption> {
vec![FilterOption {
name: "count".to_string(),
default: None,
required: true,
}]
}
}
// ---------------------------------------------------------------------------
// tail_tokens
// ---------------------------------------------------------------------------
/// A filter that outputs only the last N tokens of the input stream.
///
/// Uses a bounded ring buffer (last ~2× PIPESIZE) to keep recent bytes.
/// At finalize, tokenizes the buffered content and writes only the last N tokens.
pub struct TailTokensFilter {
pub count: usize,
/// Ring buffer holding the most recent bytes from the stream.
pub ring: VecDeque<u8>,
pub ring_capacity: usize,
pub tokenizer: Option<Tokenizer>,
pub encoding: TokenEncoding,
}
impl TailTokensFilter {
pub fn new(count: usize) -> Self {
// Keep enough bytes for ~2 chunks worth of data
let ring_capacity = PIPESIZE * 2;
Self {
count,
ring: VecDeque::with_capacity(ring_capacity),
ring_capacity,
tokenizer: None,
encoding: TokenEncoding::default(),
}
}
}
impl FilterPlugin for TailTokensFilter {
fn filter(&mut self, reader: &mut dyn Read, writer: &mut dyn Write) -> Result<()> {
if self.count == 0 {
return Ok(());
}
let tokenizer = self
.tokenizer
.as_ref()
.unwrap_or_else(|| panic!("TailTokensFilter: tokenizer not initialized"));
// Stream all bytes through the ring buffer
let mut buffer = vec![0u8; PIPESIZE];
loop {
let n = reader.read(&mut buffer)?;
if n == 0 {
break;
}
for &byte in &buffer[..n] {
if self.ring.len() >= self.ring_capacity {
self.ring.pop_front();
}
self.ring.push_back(byte);
}
}
// Tokenize the buffered content and extract last N tokens
let buffered: Vec<u8> = self.ring.iter().copied().collect();
let text = String::from_utf8_lossy(&buffered);
let token_strs = tokenizer
.split_by_token(&text)
.map_err(|e| std::io::Error::other(e.to_string()))?;
if token_strs.len() <= self.count {
// All tokens fit — write everything
writer.write_all(&buffered)?;
} else {
// Write only the last N tokens
let skip = token_strs.len() - self.count;
let mut byte_offset = 0usize;
for token_str in token_strs.iter().take(skip) {
byte_offset += token_str.len();
}
let write_len = map_lossy_pos_to_bytes(&buffered, &text, byte_offset);
if write_len < buffered.len() {
writer.write_all(&buffered[write_len..])?;
}
}
Ok(())
}
fn clone_box(&self) -> Box<dyn FilterPlugin> {
Box::new(Self {
count: self.count,
ring: self.ring.clone(),
ring_capacity: self.ring_capacity,
tokenizer: self
.tokenizer
.as_ref()
.map(|_| Tokenizer::new(self.encoding).unwrap()),
encoding: self.encoding,
})
}
fn options(&self) -> Vec<FilterOption> {
vec![FilterOption {
name: "count".to_string(),
default: None,
required: true,
}]
}
}
// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------
/// Map a byte position in a lossy string back to a position in the original byte slice.
///
/// `String::from_utf8_lossy` replaces invalid UTF-8 bytes with the Unicode
/// replacement character (U+FFFD), which encodes to 3 bytes in UTF-8. This
/// function walks both the original bytes and the lossy string in lockstep,
/// finding the original byte position that corresponds to `lossy_pos`.
fn map_lossy_pos_to_bytes(original: &[u8], lossy: &str, lossy_pos: usize) -> usize {
if lossy_pos == 0 {
return 0;
}
let replacement = '\u{FFFD}';
let replacement_len = replacement.len_utf8(); // 3 bytes
let mut orig_idx = 0usize;
let mut lossy_idx = 0usize;
let lossy_bytes = lossy.as_bytes();
while lossy_idx < lossy_pos && orig_idx < original.len() {
// Try to decode the next character from the original bytes
match std::str::from_utf8(&original[orig_idx..]) {
Ok("") => break,
Ok(s) => {
let ch = s.chars().next().unwrap();
let ch_len = ch.len_utf8();
// Check if this is a replacement character in the lossy string
if ch == replacement
&& lossy_idx + replacement_len <= lossy_pos
&& lossy_bytes[lossy_idx..].starts_with(
&replacement.encode_utf8(&mut [0; 4]).as_bytes()[..replacement_len],
)
{
// Could be a real U+FFFD or a replacement of invalid bytes.
// If the original byte at this position is valid UTF-8 start, it's real.
if original[orig_idx] < 0x80 || original[orig_idx] >= 0xC0 {
// Real character
orig_idx += ch_len;
lossy_idx += ch_len;
} else {
// Invalid byte that was replaced — advance original by 1
orig_idx += 1;
lossy_idx += replacement_len;
}
} else {
orig_idx += ch_len;
lossy_idx += ch_len;
}
}
Err(e) => {
let valid = e.valid_up_to();
if valid > 0 {
// Some valid bytes, then invalid
orig_idx += valid;
lossy_idx += valid;
} else {
// Invalid byte — in lossy it becomes 3-byte replacement char
orig_idx += 1;
lossy_idx += replacement_len;
}
}
}
}
orig_idx.min(original.len())
}
// ---------------------------------------------------------------------------
// Registration
// ---------------------------------------------------------------------------
#[ctor::ctor]
fn register_token_filters() {
register_filter_plugin("head_tokens", || {
let mut f = HeadTokensFilter::new(0);
f.tokenizer = Some(resolve_tokenizer(&None));
Box::new(f)
});
register_filter_plugin("skip_tokens", || {
let mut f = SkipTokensFilter::new(0);
f.tokenizer = Some(resolve_tokenizer(&None));
Box::new(f)
});
register_filter_plugin("tail_tokens", || {
let mut f = TailTokensFilter::new(0);
f.tokenizer = Some(resolve_tokenizer(&None));
Box::new(f)
});
}
#[cfg(test)]
mod tests {
use super::*;
use std::io::Cursor;
fn make_tokenizer() -> Tokenizer {
Tokenizer::new(TokenEncoding::Cl100kBase).unwrap()
}
#[test]
fn test_head_tokens_basic() {
let mut filter = HeadTokensFilter::new(3);
filter.tokenizer = Some(make_tokenizer());
let input = b"The quick brown fox";
let mut output = Vec::new();
filter.filter(&mut Cursor::new(input), &mut output).unwrap();
let result = String::from_utf8_lossy(&output);
// "The quick brown" is typically 3 tokens
assert!(!result.is_empty());
assert!(result.len() <= input.len());
}
#[test]
fn test_head_tokens_zero() {
let mut filter = HeadTokensFilter::new(0);
filter.tokenizer = Some(make_tokenizer());
let input = b"The quick brown fox";
let mut output = Vec::new();
filter.filter(&mut Cursor::new(input), &mut output).unwrap();
assert!(output.is_empty());
}
#[test]
fn test_head_tokens_more_than_available() {
let mut filter = HeadTokensFilter::new(1000);
filter.tokenizer = Some(make_tokenizer());
let input = b"Hello world";
let mut output = Vec::new();
filter.filter(&mut Cursor::new(input), &mut output).unwrap();
assert_eq!(output, input);
}
#[test]
fn test_skip_tokens_basic() {
let mut filter = SkipTokensFilter::new(2);
filter.tokenizer = Some(make_tokenizer());
let input = b"The quick brown fox";
let mut output = Vec::new();
filter.filter(&mut Cursor::new(input), &mut output).unwrap();
let result = String::from_utf8_lossy(&output);
// Should have skipped some tokens
assert!(result.len() < input.len());
}
#[test]
fn test_skip_tokens_zero() {
let mut filter = SkipTokensFilter::new(0);
filter.tokenizer = Some(make_tokenizer());
let input = b"Hello world";
let mut output = Vec::new();
filter.filter(&mut Cursor::new(input), &mut output).unwrap();
assert_eq!(output, input);
}
#[test]
fn test_tail_tokens_basic() {
let mut filter = TailTokensFilter::new(2);
filter.tokenizer = Some(make_tokenizer());
let input = b"The quick brown fox jumps over the lazy dog";
let mut output = Vec::new();
filter.filter(&mut Cursor::new(input), &mut output).unwrap();
let result = String::from_utf8_lossy(&output);
// Should only have last 2 tokens
assert!(!result.is_empty());
assert!(result.len() < input.len());
}
#[test]
fn test_tail_tokens_zero() {
let mut filter = TailTokensFilter::new(0);
filter.tokenizer = Some(make_tokenizer());
let input = b"Hello world";
let mut output = Vec::new();
filter.filter(&mut Cursor::new(input), &mut output).unwrap();
assert!(output.is_empty());
}
#[test]
fn test_map_lossy_pos_ascii() {
let original = b"Hello world";
let lossy = String::from_utf8_lossy(original);
assert_eq!(map_lossy_pos_to_bytes(original, &lossy, 5), 5);
assert_eq!(map_lossy_pos_to_bytes(original, &lossy, 0), 0);
assert_eq!(map_lossy_pos_to_bytes(original, &lossy, 11), 11);
}
#[test]
fn test_map_lossy_pos_with_invalid_utf8() {
let original = b"Hello\x80world";
let lossy = String::from_utf8_lossy(original);
// lossy = "Hello\u{FFFD}world" (13 bytes)
// Position 5 in lossy = after "Hello" = position 5 in original
assert_eq!(map_lossy_pos_to_bytes(original, &lossy, 5), 5);
// Position 8 in lossy = "Hello\u{FFFD}" = position 6 in original
// (the invalid byte \x80 at position 5 was replaced)
assert_eq!(map_lossy_pos_to_bytes(original, &lossy, 8), 6);
}
}

View File

@@ -43,6 +43,9 @@ pub mod services;
#[cfg(feature = "client")]
pub mod client;
#[cfg(feature = "tokens")]
pub mod tokenizer;
// Re-export Args struct for library usage
pub use args::Args;
// Re-export PIPESIZE constant
@@ -52,6 +55,10 @@ pub use common::PIPESIZE;
#[allow(unused_imports)]
use filter_plugin::{grep, head, skip, strip_ansi, tail};
#[cfg(feature = "tokens")]
#[allow(unused_imports)]
use filter_plugin::tokens as token_filters;
use crate::meta_plugin::{
cwd, digest, env, exec, hostname, keep_pid, read_rate, read_time, shell, shell_pid, user,
};
@@ -60,6 +67,10 @@ use crate::meta_plugin::{
#[allow(unused_imports)]
use crate::meta_plugin::magic_file;
#[cfg(feature = "tokens")]
#[allow(unused_imports)]
use crate::meta_plugin::tokens;
/// Initializes plugins at library load time.
///
/// Plugin registration happens automatically via `#[ctor]` constructors

View File

@@ -16,8 +16,9 @@ pub mod read_time;
pub mod shell;
pub mod shell_pid;
pub mod text;
#[cfg(feature = "tokens")]
pub mod tokens;
pub mod user;
// pub mod text; // Removed duplicate
pub use digest::DigestMetaPlugin;
pub use exec::MetaPluginExec;
@@ -232,6 +233,7 @@ pub enum MetaPluginType {
Hostname,
Exec,
Env,
Tokens,
}
/// Central function to handle metadata output with name mapping.

295
src/meta_plugin/tokens.rs Normal file
View File

@@ -0,0 +1,295 @@
use crate::common::PIPESIZE;
use crate::common::is_binary::is_binary;
use crate::meta_plugin::{MetaPlugin, MetaPluginResponse, MetaPluginType};
use crate::tokenizer::{TokenEncoding, Tokenizer};
#[derive(Debug, Clone)]
pub struct TokensMetaPlugin {
/// Buffer for binary detection (up to PIPESIZE bytes).
buffer: Option<Vec<u8>>,
max_buffer_size: usize,
is_finalized: bool,
is_binary_content: Option<bool>,
/// Running token count accumulated across chunks.
token_count: usize,
/// UTF-8 boundary carry buffer.
utf8_buffer: Vec<u8>,
base: crate::meta_plugin::BaseMetaPlugin,
/// The tokenizer instance.
tokenizer: Tokenizer,
}
impl TokensMetaPlugin {
pub fn new(
options: Option<std::collections::HashMap<String, serde_yaml::Value>>,
outputs: Option<std::collections::HashMap<String, serde_yaml::Value>>,
) -> Self {
let mut base = crate::meta_plugin::BaseMetaPlugin::new();
base.initialize_plugin(&["token_count"], &options, &outputs);
// Set default options
let default_options = vec![
(
"token_detect_size",
serde_yaml::Value::Number(PIPESIZE.into()),
),
(
"encoding",
serde_yaml::Value::String("cl100k_base".to_string()),
),
];
for (key, value) in default_options {
if !base.options.contains_key(key) {
base.options.insert(key.to_string(), value);
}
}
let max_buffer_size = base
.options
.get("token_detect_size")
.and_then(|v| v.as_u64())
.unwrap_or(PIPESIZE as u64) as usize;
let encoding = base
.options
.get("encoding")
.and_then(|v| v.as_str())
.and_then(|s| s.parse::<TokenEncoding>().ok())
.unwrap_or_default();
let tokenizer = Tokenizer::new(encoding).expect("Failed to create tokenizer");
Self {
buffer: Some(Vec::new()),
max_buffer_size,
is_finalized: false,
is_binary_content: None,
token_count: 0,
utf8_buffer: Vec::new(),
base,
tokenizer,
}
}
/// Tokenize a byte chunk, handling UTF-8 boundaries.
///
/// Combines with any pending UTF-8 carry bytes, converts to text,
/// and adds the token count to the running total.
fn count_tokens(&mut self, data: &[u8]) {
if data.is_empty() && self.utf8_buffer.is_empty() {
return;
}
let combined = if !self.utf8_buffer.is_empty() {
let mut c = self.utf8_buffer.clone();
c.extend_from_slice(data);
c
} else {
data.to_vec()
};
self.utf8_buffer.clear();
let text = match std::str::from_utf8(&combined) {
Ok(t) => t,
Err(e) => {
let valid = e.valid_up_to();
if valid < combined.len() {
self.utf8_buffer.extend_from_slice(&combined[valid..]);
}
match std::str::from_utf8(&combined[..valid]) {
Ok(t) => t,
Err(_) => return,
}
}
};
if !text.is_empty() {
self.token_count += self.tokenizer.count(text);
}
}
/// Perform binary detection on the buffer.
fn detect_binary(&mut self, buffer: &[u8]) -> bool {
let result = is_binary(buffer);
self.is_binary_content = Some(result);
result
}
}
impl MetaPlugin for TokensMetaPlugin {
fn is_finalized(&self) -> bool {
self.is_finalized
}
fn set_finalized(&mut self, finalized: bool) {
self.is_finalized = finalized;
}
fn update(&mut self, data: &[u8]) -> MetaPluginResponse {
if self.is_finalized {
return MetaPluginResponse {
metadata: Vec::new(),
is_finalized: true,
};
}
let mut metadata = Vec::new();
if self.is_binary_content.is_none() {
// Add data to the buffer
let should_detect = if let Some(ref mut buffer) = self.buffer {
let remaining = self.max_buffer_size.saturating_sub(buffer.len());
let to_take = std::cmp::min(data.len(), remaining);
buffer.extend_from_slice(&data[..to_take]);
buffer.len() >= std::cmp::min(1024, self.max_buffer_size)
} else {
false
};
if should_detect {
let buf_clone = self.buffer.as_ref().unwrap().clone();
let is_binary = self.detect_binary(&buf_clone);
if is_binary {
if let Some(md) = crate::meta_plugin::process_metadata_outputs(
"token_count",
serde_yaml::Value::Null,
self.base.outputs(),
) {
metadata.push(md);
}
self.buffer = None;
self.is_finalized = true;
return MetaPluginResponse {
metadata,
is_finalized: true,
};
}
// It's text — tokenize the full accumulated buffer
self.count_tokens(&buf_clone);
if buf_clone.len() >= self.max_buffer_size {
self.buffer = None;
}
} else if self.buffer.is_some() {
// Still building up buffer — tokenize what was just added
let remaining = self
.max_buffer_size
.saturating_sub(self.buffer.as_ref().map_or(0, |b| b.len()));
let to_take = std::cmp::min(data.len(), remaining);
self.count_tokens(&data[..to_take]);
}
} else if self.is_binary_content == Some(false) {
self.count_tokens(data);
} else if self.is_binary_content == Some(true) {
self.is_finalized = true;
return MetaPluginResponse {
metadata: Vec::new(),
is_finalized: true,
};
}
MetaPluginResponse {
metadata,
is_finalized: self.is_finalized,
}
}
fn finalize(&mut self) -> MetaPluginResponse {
if self.is_finalized {
return MetaPluginResponse {
metadata: Vec::new(),
is_finalized: true,
};
}
let mut metadata = Vec::new();
// If binary detection hasn't completed, do it now
if self.is_binary_content.is_none() {
if let Some(buffer) = &self.buffer {
if !buffer.is_empty() {
let buf_clone = buffer.clone();
let is_binary = self.detect_binary(&buf_clone);
if is_binary {
if let Some(md) = crate::meta_plugin::process_metadata_outputs(
"token_count",
serde_yaml::Value::Null,
self.base.outputs(),
) {
metadata.push(md);
}
self.buffer = None;
self.is_finalized = true;
return MetaPluginResponse {
metadata,
is_finalized: true,
};
}
}
}
}
// Process any remaining UTF-8 bytes
if !self.utf8_buffer.is_empty() {
self.count_tokens(&[]);
}
// Emit token count
if let Some(md) = crate::meta_plugin::process_metadata_outputs(
"token_count",
serde_yaml::Value::String(self.token_count.to_string()),
self.base.outputs(),
) {
metadata.push(md);
}
self.buffer = None;
self.is_finalized = true;
MetaPluginResponse {
metadata,
is_finalized: true,
}
}
fn meta_type(&self) -> MetaPluginType {
MetaPluginType::Tokens
}
fn outputs(&self) -> &std::collections::HashMap<String, serde_yaml::Value> {
self.base.outputs()
}
fn outputs_mut(
&mut self,
) -> anyhow::Result<&mut std::collections::HashMap<String, serde_yaml::Value>> {
Ok(self.base.outputs_mut())
}
fn default_outputs(&self) -> Vec<String> {
vec!["token_count".to_string()]
}
fn options(&self) -> &std::collections::HashMap<String, serde_yaml::Value> {
self.base.options()
}
fn options_mut(
&mut self,
) -> anyhow::Result<&mut std::collections::HashMap<String, serde_yaml::Value>> {
Ok(self.base.options_mut())
}
}
use crate::meta_plugin::register_meta_plugin;
#[ctor::ctor]
fn register_tokens_plugin() {
register_meta_plugin(MetaPluginType::Tokens, |options, outputs| {
Box::new(TokensMetaPlugin::new(options, outputs))
});
}

147
src/tokenizer/mod.rs Normal file
View File

@@ -0,0 +1,147 @@
use anyhow::{Result, bail};
/// Supported LLM token encodings.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum TokenEncoding {
/// cl100k_base — used by GPT-3.5, GPT-4, text-embedding-ada-002.
#[default]
Cl100kBase,
/// o200k_base — used by GPT-4o, GPT-5, o1, o3, o4 models.
O200kBase,
}
impl std::str::FromStr for TokenEncoding {
type Err = anyhow::Error;
fn from_str(s: &str) -> Result<Self> {
match s.to_lowercase().as_str() {
"cl100k_base" => Ok(TokenEncoding::Cl100kBase),
"o200k_base" => Ok(TokenEncoding::O200kBase),
_ => bail!("Unknown token encoding: {s}. Supported: cl100k_base, o200k_base"),
}
}
}
impl std::fmt::Display for TokenEncoding {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
TokenEncoding::Cl100kBase => write!(f, "cl100k_base"),
TokenEncoding::O200kBase => write!(f, "o200k_base"),
}
}
}
/// Wrapper around tiktoken BPE tokenizer.
///
/// Provides streaming-friendly tokenization: count tokens in text,
/// split text into token strings, and decode token IDs back to text.
#[derive(Clone)]
pub struct Tokenizer {
bpe: tiktoken_rs::CoreBPE,
}
impl std::fmt::Debug for Tokenizer {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Tokenizer").finish()
}
}
impl Tokenizer {
/// Creates a new tokenizer for the specified encoding.
pub fn new(encoding: TokenEncoding) -> Result<Self> {
let bpe = match encoding {
TokenEncoding::Cl100kBase => tiktoken_rs::cl100k_base()?,
TokenEncoding::O200kBase => tiktoken_rs::o200k_base()?,
};
Ok(Self { bpe })
}
/// Counts the number of tokens in the given text.
///
/// Uses `encode_ordinary` which treats the text as a single unit.
/// For streaming: tokenizing chunks independently and summing gives
/// the same result as tokenizing the full text (exact when no regex
/// match spans a chunk boundary).
pub fn count(&self, text: &str) -> usize {
self.bpe.encode_ordinary(text).len()
}
/// Splits text into individual decoded token strings.
///
/// Each returned string corresponds to one token. Useful for finding
/// exact byte boundaries when filtering by token count.
pub fn split_by_token(&self, text: &str) -> Result<Vec<String>> {
self.bpe.split_by_token(text, false)
}
/// Decodes a slice of token IDs back into a string.
pub fn decode(&self, tokens: &[u32]) -> Result<String> {
self.bpe.decode(tokens.to_vec())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_tokenizer_count() {
let tok = Tokenizer::new(TokenEncoding::Cl100kBase).unwrap();
let count = tok.count("Hello, world!");
assert!(count > 0);
}
#[test]
fn test_tokenizer_split() {
let tok = Tokenizer::new(TokenEncoding::Cl100kBase).unwrap();
let tokens = tok.split_by_token("Hello world").unwrap();
assert_eq!(tokens.len(), 2);
}
#[test]
fn test_tokenizer_roundtrip() {
let tok = Tokenizer::new(TokenEncoding::Cl100kBase).unwrap();
let text = "The quick brown fox jumps over the lazy dog.";
let token_ids: Vec<u32> = tok
.bpe
.encode_ordinary(text)
.into_iter()
.map(|x| x as u32)
.collect();
let decoded = tok.decode(&token_ids).unwrap();
assert_eq!(text, decoded);
}
#[test]
fn test_chunk_sum_close_to_full() {
let tok = Tokenizer::new(TokenEncoding::Cl100kBase).unwrap();
let text = "The quick brown fox jumps over the lazy dog. \
Pack my box with five dozen liquor jugs. \
How vexingly quick daft zebras jump!";
let full_count = tok.count(text);
// Split into chunks at word boundaries
let mid = text.find("Pack").unwrap();
let (a, b) = text.split_at(mid);
let chunk_sum = tok.count(a) + tok.count(b);
// Chunk-based counting may differ by 1-2 tokens when a BPE merge
// boundary falls near the chunk split point
assert!(
(full_count as isize - chunk_sum as isize).abs() <= 2,
"full={full_count}, chunk_sum={chunk_sum}"
);
}
#[test]
fn test_encoding_from_str() {
assert_eq!(
"cl100k_base".parse::<TokenEncoding>().unwrap(),
TokenEncoding::Cl100kBase
);
assert_eq!(
"o200k_base".parse::<TokenEncoding>().unwrap(),
TokenEncoding::O200kBase
);
assert!("unknown".parse::<TokenEncoding>().is_err());
}
}