feat: add LLM token counting meta plugin and token filters
Add tiktoken-based token counting via new 'tokens' feature flag. New components: - Shared tokenizer module wrapping tiktoken CoreBPE (cl100k_base, o200k_base) - TokensMetaPlugin: streaming token counter, tokenizes each chunk independently - head_tokens(N): stream first N tokens, split at exact boundary when mid-chunk - skip_tokens(N): skip first N tokens, stream the rest - tail_tokens(N): bounded ring buffer (~16KB), outputs last N tokens at finalize All filters are fully streaming — no full-stream buffering. Meta plugin accuracy: exact for normal text, ±1-2 tokens if long whitespace sequence spans a chunk boundary. Also: add 'client' and 'tokens' to default features, add curl to Dockerfile builder stage.
This commit is contained in:
59
Cargo.lock
generated
59
Cargo.lock
generated
@@ -310,6 +310,21 @@ version = "0.22.1"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6"
|
checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "bit-set"
|
||||||
|
version = "0.5.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "0700ddab506f33b20a03b13996eccd309a48e5ff77d0d95926aa0210fb4e95f1"
|
||||||
|
dependencies = [
|
||||||
|
"bit-vec",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "bit-vec"
|
||||||
|
version = "0.6.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "349f9b6a179ed607305526ca489b34ad0a41aed5f7980fa90eb03160b69598fb"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "bitflags"
|
name = "bitflags"
|
||||||
version = "1.3.2"
|
version = "1.3.2"
|
||||||
@@ -354,6 +369,17 @@ dependencies = [
|
|||||||
"opaque-debug",
|
"opaque-debug",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "bstr"
|
||||||
|
version = "1.12.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "63044e1ae8e69f3b5a92c736ca6269b8d12fa7efe39bf34ddb06d102cf0e2cab"
|
||||||
|
dependencies = [
|
||||||
|
"memchr",
|
||||||
|
"regex-automata",
|
||||||
|
"serde",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "bumpalo"
|
name = "bumpalo"
|
||||||
version = "3.19.0"
|
version = "3.19.0"
|
||||||
@@ -929,6 +955,17 @@ version = "0.1.9"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "7360491ce676a36bf9bb3c56c1aa791658183a54d2744120f27285738d90465a"
|
checksum = "7360491ce676a36bf9bb3c56c1aa791658183a54d2744120f27285738d90465a"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "fancy-regex"
|
||||||
|
version = "0.13.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "531e46835a22af56d1e3b66f04844bed63158bc094a628bec1d321d9b4c44bf2"
|
||||||
|
dependencies = [
|
||||||
|
"bit-set",
|
||||||
|
"regex-automata",
|
||||||
|
"regex-syntax",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "fastrand"
|
name = "fastrand"
|
||||||
version = "2.3.0"
|
version = "2.3.0"
|
||||||
@@ -1598,6 +1635,7 @@ dependencies = [
|
|||||||
"tempfile",
|
"tempfile",
|
||||||
"term",
|
"term",
|
||||||
"thiserror 1.0.69",
|
"thiserror 1.0.69",
|
||||||
|
"tiktoken-rs",
|
||||||
"tokio",
|
"tokio",
|
||||||
"tokio-stream",
|
"tokio-stream",
|
||||||
"tokio-util",
|
"tokio-util",
|
||||||
@@ -2354,6 +2392,12 @@ version = "0.1.26"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "56f7d92ca342cea22a06f2121d944b4fd82af56988c270852495420f961d4ace"
|
checksum = "56f7d92ca342cea22a06f2121d944b4fd82af56988c270852495420f961d4ace"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "rustc-hash"
|
||||||
|
version = "1.1.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "rustix"
|
name = "rustix"
|
||||||
version = "1.0.8"
|
version = "1.0.8"
|
||||||
@@ -2851,6 +2895,21 @@ dependencies = [
|
|||||||
"cfg-if",
|
"cfg-if",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "tiktoken-rs"
|
||||||
|
version = "0.9.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "3a19830747d9034cd9da43a60eaa8e552dfda7712424aebf187b7a60126bae0d"
|
||||||
|
dependencies = [
|
||||||
|
"anyhow",
|
||||||
|
"base64 0.22.1",
|
||||||
|
"bstr",
|
||||||
|
"fancy-regex",
|
||||||
|
"lazy_static",
|
||||||
|
"regex",
|
||||||
|
"rustc-hash",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "time"
|
name = "time"
|
||||||
version = "0.3.47"
|
version = "0.3.47"
|
||||||
|
|||||||
@@ -75,10 +75,11 @@ ureq = { version = "3", features = ["json"], optional = true }
|
|||||||
os_pipe = { version = "1", optional = true }
|
os_pipe = { version = "1", optional = true }
|
||||||
axum-server = { version = "0.8", features = ["tls-rustls"], optional = true }
|
axum-server = { version = "0.8", features = ["tls-rustls"], optional = true }
|
||||||
jsonwebtoken = { version = "10", optional = true, features = ["aws_lc_rs"] }
|
jsonwebtoken = { version = "10", optional = true, features = ["aws_lc_rs"] }
|
||||||
|
tiktoken-rs = { version = "0.9", optional = true }
|
||||||
|
|
||||||
[features]
|
[features]
|
||||||
# Default features include core compression engines and swagger UI
|
# Default features include core compression engines and swagger UI
|
||||||
default = ["magic", "lz4", "gzip"]
|
default = ["magic", "lz4", "gzip", "client", "tokens"]
|
||||||
|
|
||||||
# Full
|
# Full
|
||||||
#default = ["server", "magic", "lz4", "swagger"]
|
#default = ["server", "magic", "lz4", "swagger"]
|
||||||
@@ -113,6 +114,9 @@ client = ["dep:ureq", "dep:os_pipe"]
|
|||||||
# TLS feature (HTTPS server support)
|
# TLS feature (HTTPS server support)
|
||||||
tls = ["dep:axum-server"]
|
tls = ["dep:axum-server"]
|
||||||
|
|
||||||
|
# Token counting feature (LLM token support via tiktoken)
|
||||||
|
tokens = ["dep:tiktoken-rs"]
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
tempfile = "3.3.0"
|
tempfile = "3.3.0"
|
||||||
rand = "0.8.5"
|
rand = "0.8.5"
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ FROM rust:1.88-slim AS builder
|
|||||||
|
|
||||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||||
cmake \
|
cmake \
|
||||||
|
curl \
|
||||||
make \
|
make \
|
||||||
gcc \
|
gcc \
|
||||||
musl-tools \
|
musl-tools \
|
||||||
@@ -16,7 +17,6 @@ WORKDIR /app
|
|||||||
# Copy manifests and fetch dependencies (cached layer)
|
# Copy manifests and fetch dependencies (cached layer)
|
||||||
COPY Cargo.toml Cargo.lock ./
|
COPY Cargo.toml Cargo.lock ./
|
||||||
RUN mkdir src && echo 'fn main() {}' > src/main.rs && echo '' > src/lib.rs
|
RUN mkdir src && echo 'fn main() {}' > src/main.rs && echo '' > src/lib.rs
|
||||||
|
|
||||||
RUN cargo fetch --target x86_64-unknown-linux-musl
|
RUN cargo fetch --target x86_64-unknown-linux-musl
|
||||||
|
|
||||||
# Copy real source and build static binary
|
# Copy real source and build static binary
|
||||||
|
|||||||
@@ -26,6 +26,8 @@ pub mod head;
|
|||||||
pub mod skip;
|
pub mod skip;
|
||||||
pub mod strip_ansi;
|
pub mod strip_ansi;
|
||||||
pub mod tail;
|
pub mod tail;
|
||||||
|
#[cfg(feature = "tokens")]
|
||||||
|
pub mod tokens;
|
||||||
pub mod utils;
|
pub mod utils;
|
||||||
|
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
@@ -192,6 +194,12 @@ pub enum FilterType {
|
|||||||
SkipLines,
|
SkipLines,
|
||||||
Grep,
|
Grep,
|
||||||
StripAnsi,
|
StripAnsi,
|
||||||
|
#[cfg(feature = "tokens")]
|
||||||
|
HeadTokens,
|
||||||
|
#[cfg(feature = "tokens")]
|
||||||
|
SkipTokens,
|
||||||
|
#[cfg(feature = "tokens")]
|
||||||
|
TailTokens,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Maximum buffer size (256 MB) for filter chain intermediate results.
|
/// Maximum buffer size (256 MB) for filter chain intermediate results.
|
||||||
@@ -490,6 +498,12 @@ fn create_filter_with_options(
|
|||||||
FilterType::SkipBytes => skip::SkipBytesFilter::new(0).options(),
|
FilterType::SkipBytes => skip::SkipBytesFilter::new(0).options(),
|
||||||
FilterType::SkipLines => skip::SkipLinesFilter::new(0).options(),
|
FilterType::SkipLines => skip::SkipLinesFilter::new(0).options(),
|
||||||
FilterType::StripAnsi => strip_ansi::StripAnsiFilter::new().options(),
|
FilterType::StripAnsi => strip_ansi::StripAnsiFilter::new().options(),
|
||||||
|
#[cfg(feature = "tokens")]
|
||||||
|
FilterType::HeadTokens => tokens::HeadTokensFilter::new(0).options(),
|
||||||
|
#[cfg(feature = "tokens")]
|
||||||
|
FilterType::SkipTokens => tokens::SkipTokensFilter::new(0).options(),
|
||||||
|
#[cfg(feature = "tokens")]
|
||||||
|
FilterType::TailTokens => tokens::TailTokensFilter::new(0).options(),
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut options = HashMap::new();
|
let mut options = HashMap::new();
|
||||||
@@ -658,6 +672,69 @@ fn create_specific_filter(
|
|||||||
}
|
}
|
||||||
Ok(Box::new(strip_ansi::StripAnsiFilter::new()))
|
Ok(Box::new(strip_ansi::StripAnsiFilter::new()))
|
||||||
}
|
}
|
||||||
|
#[cfg(feature = "tokens")]
|
||||||
|
FilterType::HeadTokens => {
|
||||||
|
let count = options
|
||||||
|
.get("count")
|
||||||
|
.and_then(|v| v.as_u64())
|
||||||
|
.map(|n| n as usize)
|
||||||
|
.ok_or_else(|| {
|
||||||
|
std::io::Error::new(
|
||||||
|
std::io::ErrorKind::InvalidInput,
|
||||||
|
"head_tokens filter requires 'count' parameter",
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
let encoding = crate::tokenizer::TokenEncoding::Cl100kBase;
|
||||||
|
let mut f = tokens::HeadTokensFilter::new(count);
|
||||||
|
f.tokenizer = Some(
|
||||||
|
crate::tokenizer::Tokenizer::new(encoding)
|
||||||
|
.map_err(|e| std::io::Error::other(e.to_string()))?,
|
||||||
|
);
|
||||||
|
f.encoding = encoding;
|
||||||
|
Ok(Box::new(f))
|
||||||
|
}
|
||||||
|
#[cfg(feature = "tokens")]
|
||||||
|
FilterType::SkipTokens => {
|
||||||
|
let count = options
|
||||||
|
.get("count")
|
||||||
|
.and_then(|v| v.as_u64())
|
||||||
|
.map(|n| n as usize)
|
||||||
|
.ok_or_else(|| {
|
||||||
|
std::io::Error::new(
|
||||||
|
std::io::ErrorKind::InvalidInput,
|
||||||
|
"skip_tokens filter requires 'count' parameter",
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
let encoding = crate::tokenizer::TokenEncoding::Cl100kBase;
|
||||||
|
let mut f = tokens::SkipTokensFilter::new(count);
|
||||||
|
f.tokenizer = Some(
|
||||||
|
crate::tokenizer::Tokenizer::new(encoding)
|
||||||
|
.map_err(|e| std::io::Error::other(e.to_string()))?,
|
||||||
|
);
|
||||||
|
f.encoding = encoding;
|
||||||
|
Ok(Box::new(f))
|
||||||
|
}
|
||||||
|
#[cfg(feature = "tokens")]
|
||||||
|
FilterType::TailTokens => {
|
||||||
|
let count = options
|
||||||
|
.get("count")
|
||||||
|
.and_then(|v| v.as_u64())
|
||||||
|
.map(|n| n as usize)
|
||||||
|
.ok_or_else(|| {
|
||||||
|
std::io::Error::new(
|
||||||
|
std::io::ErrorKind::InvalidInput,
|
||||||
|
"tail_tokens filter requires 'count' parameter",
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
let encoding = crate::tokenizer::TokenEncoding::Cl100kBase;
|
||||||
|
let mut f = tokens::TailTokensFilter::new(count);
|
||||||
|
f.tokenizer = Some(
|
||||||
|
crate::tokenizer::Tokenizer::new(encoding)
|
||||||
|
.map_err(|e| std::io::Error::other(e.to_string()))?,
|
||||||
|
);
|
||||||
|
f.encoding = encoding;
|
||||||
|
Ok(Box::new(f))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
530
src/filter_plugin/tokens.rs
Normal file
530
src/filter_plugin/tokens.rs
Normal file
@@ -0,0 +1,530 @@
|
|||||||
|
use super::{FilterOption, FilterPlugin};
|
||||||
|
use crate::common::PIPESIZE;
|
||||||
|
use crate::services::filter_service::register_filter_plugin;
|
||||||
|
use crate::tokenizer::{TokenEncoding, Tokenizer};
|
||||||
|
use std::collections::VecDeque;
|
||||||
|
use std::io::{Read, Result, Write};
|
||||||
|
|
||||||
|
/// Resolve the tokenizer from a JSON options map.
|
||||||
|
fn resolve_tokenizer(options: &Option<serde_json::Value>) -> Tokenizer {
|
||||||
|
let encoding = options
|
||||||
|
.as_ref()
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.and_then(|s| s.parse::<TokenEncoding>().ok())
|
||||||
|
.unwrap_or_default();
|
||||||
|
Tokenizer::new(encoding).expect("Failed to create tokenizer")
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// head_tokens
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
/// A filter that outputs only the first N tokens of the input stream.
|
||||||
|
///
|
||||||
|
/// Streams bytes directly until the token limit is reached. When the limit
|
||||||
|
/// falls mid-chunk, uses `split_by_token` to find the exact byte boundary.
|
||||||
|
pub struct HeadTokensFilter {
|
||||||
|
pub remaining: usize,
|
||||||
|
pub tokenizer: Option<Tokenizer>,
|
||||||
|
pub encoding: TokenEncoding,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl HeadTokensFilter {
|
||||||
|
pub fn new(count: usize) -> Self {
|
||||||
|
Self {
|
||||||
|
remaining: count,
|
||||||
|
tokenizer: None,
|
||||||
|
encoding: TokenEncoding::default(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FilterPlugin for HeadTokensFilter {
|
||||||
|
fn filter(&mut self, reader: &mut dyn Read, writer: &mut dyn Write) -> Result<()> {
|
||||||
|
if self.remaining == 0 {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
let tokenizer = self
|
||||||
|
.tokenizer
|
||||||
|
.as_ref()
|
||||||
|
.unwrap_or_else(|| panic!("HeadTokensFilter: tokenizer not initialized"));
|
||||||
|
|
||||||
|
let mut buffer = vec![0u8; PIPESIZE];
|
||||||
|
let mut total_tokens = 0usize;
|
||||||
|
|
||||||
|
loop {
|
||||||
|
let n = reader.read(&mut buffer)?;
|
||||||
|
if n == 0 {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
let chunk = &buffer[..n];
|
||||||
|
let text = String::from_utf8_lossy(chunk);
|
||||||
|
let chunk_tokens = tokenizer.count(&text);
|
||||||
|
|
||||||
|
if total_tokens + chunk_tokens <= self.remaining {
|
||||||
|
// Entire chunk fits — write it directly
|
||||||
|
writer.write_all(chunk)?;
|
||||||
|
total_tokens += chunk_tokens;
|
||||||
|
if total_tokens >= self.remaining {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Cutoff is within this chunk — split at exact token boundary
|
||||||
|
let tokens_to_write = self.remaining - total_tokens;
|
||||||
|
let token_strs = tokenizer
|
||||||
|
.split_by_token(&text)
|
||||||
|
.map_err(|e| std::io::Error::other(e.to_string()))?;
|
||||||
|
let mut byte_pos = 0usize;
|
||||||
|
for token_str in token_strs.iter().take(tokens_to_write) {
|
||||||
|
byte_pos += token_str.len();
|
||||||
|
}
|
||||||
|
// Write only the bytes for the tokens we want.
|
||||||
|
// Map byte positions in the lossy string back to positions in the
|
||||||
|
// original byte slice. Since from_utf8_lossy replaces invalid
|
||||||
|
// bytes with the replacement character (3 bytes), we need to be
|
||||||
|
// careful. For simplicity, write the valid prefix of the chunk.
|
||||||
|
// We use the original bytes up to the calculated position, adjusting
|
||||||
|
// for any UTF-8 replacement character differences.
|
||||||
|
let write_len = map_lossy_pos_to_bytes(chunk, &text, byte_pos);
|
||||||
|
writer.write_all(&chunk[..write_len])?;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn clone_box(&self) -> Box<dyn FilterPlugin> {
|
||||||
|
Box::new(Self {
|
||||||
|
remaining: self.remaining,
|
||||||
|
tokenizer: self
|
||||||
|
.tokenizer
|
||||||
|
.as_ref()
|
||||||
|
.map(|_| Tokenizer::new(self.encoding).unwrap()),
|
||||||
|
encoding: self.encoding,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn options(&self) -> Vec<FilterOption> {
|
||||||
|
vec![FilterOption {
|
||||||
|
name: "count".to_string(),
|
||||||
|
default: None,
|
||||||
|
required: true,
|
||||||
|
}]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// skip_tokens
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
/// A filter that skips the first N tokens of the input stream and outputs the rest.
|
||||||
|
pub struct SkipTokensFilter {
|
||||||
|
pub remaining: usize,
|
||||||
|
pub tokenizer: Option<Tokenizer>,
|
||||||
|
pub encoding: TokenEncoding,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl SkipTokensFilter {
|
||||||
|
pub fn new(count: usize) -> Self {
|
||||||
|
Self {
|
||||||
|
remaining: count,
|
||||||
|
tokenizer: None,
|
||||||
|
encoding: TokenEncoding::default(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FilterPlugin for SkipTokensFilter {
|
||||||
|
fn filter(&mut self, reader: &mut dyn Read, writer: &mut dyn Write) -> Result<()> {
|
||||||
|
if self.remaining == 0 {
|
||||||
|
return std::io::copy(reader, writer).map(|_| ());
|
||||||
|
}
|
||||||
|
|
||||||
|
let tokenizer = self
|
||||||
|
.tokenizer
|
||||||
|
.as_ref()
|
||||||
|
.unwrap_or_else(|| panic!("SkipTokensFilter: tokenizer not initialized"));
|
||||||
|
|
||||||
|
let mut buffer = vec![0u8; PIPESIZE];
|
||||||
|
let mut total_tokens = 0usize;
|
||||||
|
let mut done_skipping = false;
|
||||||
|
|
||||||
|
loop {
|
||||||
|
let n = reader.read(&mut buffer)?;
|
||||||
|
if n == 0 {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
if done_skipping {
|
||||||
|
writer.write_all(&buffer[..n])?;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let chunk = &buffer[..n];
|
||||||
|
let text = String::from_utf8_lossy(chunk);
|
||||||
|
let chunk_tokens = tokenizer.count(&text);
|
||||||
|
|
||||||
|
if total_tokens + chunk_tokens <= self.remaining {
|
||||||
|
// Entire chunk is skipped
|
||||||
|
total_tokens += chunk_tokens;
|
||||||
|
if total_tokens >= self.remaining {
|
||||||
|
done_skipping = true;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Cutoff is within this chunk — skip past the boundary, write rest
|
||||||
|
let tokens_to_skip = self.remaining - total_tokens;
|
||||||
|
let token_strs = tokenizer
|
||||||
|
.split_by_token(&text)
|
||||||
|
.map_err(|e| std::io::Error::other(e.to_string()))?;
|
||||||
|
let mut byte_pos = 0usize;
|
||||||
|
for token_str in token_strs.iter().take(tokens_to_skip) {
|
||||||
|
byte_pos += token_str.len();
|
||||||
|
}
|
||||||
|
let skip_len = map_lossy_pos_to_bytes(chunk, &text, byte_pos);
|
||||||
|
if skip_len < n {
|
||||||
|
writer.write_all(&chunk[skip_len..])?;
|
||||||
|
}
|
||||||
|
done_skipping = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn clone_box(&self) -> Box<dyn FilterPlugin> {
|
||||||
|
Box::new(Self {
|
||||||
|
remaining: self.remaining,
|
||||||
|
tokenizer: self
|
||||||
|
.tokenizer
|
||||||
|
.as_ref()
|
||||||
|
.map(|_| Tokenizer::new(self.encoding).unwrap()),
|
||||||
|
encoding: self.encoding,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn options(&self) -> Vec<FilterOption> {
|
||||||
|
vec![FilterOption {
|
||||||
|
name: "count".to_string(),
|
||||||
|
default: None,
|
||||||
|
required: true,
|
||||||
|
}]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// tail_tokens
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
/// A filter that outputs only the last N tokens of the input stream.
|
||||||
|
///
|
||||||
|
/// Uses a bounded ring buffer (last ~2× PIPESIZE) to keep recent bytes.
|
||||||
|
/// At finalize, tokenizes the buffered content and writes only the last N tokens.
|
||||||
|
pub struct TailTokensFilter {
|
||||||
|
pub count: usize,
|
||||||
|
/// Ring buffer holding the most recent bytes from the stream.
|
||||||
|
pub ring: VecDeque<u8>,
|
||||||
|
pub ring_capacity: usize,
|
||||||
|
pub tokenizer: Option<Tokenizer>,
|
||||||
|
pub encoding: TokenEncoding,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TailTokensFilter {
|
||||||
|
pub fn new(count: usize) -> Self {
|
||||||
|
// Keep enough bytes for ~2 chunks worth of data
|
||||||
|
let ring_capacity = PIPESIZE * 2;
|
||||||
|
Self {
|
||||||
|
count,
|
||||||
|
ring: VecDeque::with_capacity(ring_capacity),
|
||||||
|
ring_capacity,
|
||||||
|
tokenizer: None,
|
||||||
|
encoding: TokenEncoding::default(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FilterPlugin for TailTokensFilter {
|
||||||
|
fn filter(&mut self, reader: &mut dyn Read, writer: &mut dyn Write) -> Result<()> {
|
||||||
|
if self.count == 0 {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
let tokenizer = self
|
||||||
|
.tokenizer
|
||||||
|
.as_ref()
|
||||||
|
.unwrap_or_else(|| panic!("TailTokensFilter: tokenizer not initialized"));
|
||||||
|
|
||||||
|
// Stream all bytes through the ring buffer
|
||||||
|
let mut buffer = vec![0u8; PIPESIZE];
|
||||||
|
loop {
|
||||||
|
let n = reader.read(&mut buffer)?;
|
||||||
|
if n == 0 {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
for &byte in &buffer[..n] {
|
||||||
|
if self.ring.len() >= self.ring_capacity {
|
||||||
|
self.ring.pop_front();
|
||||||
|
}
|
||||||
|
self.ring.push_back(byte);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tokenize the buffered content and extract last N tokens
|
||||||
|
let buffered: Vec<u8> = self.ring.iter().copied().collect();
|
||||||
|
let text = String::from_utf8_lossy(&buffered);
|
||||||
|
let token_strs = tokenizer
|
||||||
|
.split_by_token(&text)
|
||||||
|
.map_err(|e| std::io::Error::other(e.to_string()))?;
|
||||||
|
|
||||||
|
if token_strs.len() <= self.count {
|
||||||
|
// All tokens fit — write everything
|
||||||
|
writer.write_all(&buffered)?;
|
||||||
|
} else {
|
||||||
|
// Write only the last N tokens
|
||||||
|
let skip = token_strs.len() - self.count;
|
||||||
|
let mut byte_offset = 0usize;
|
||||||
|
for token_str in token_strs.iter().take(skip) {
|
||||||
|
byte_offset += token_str.len();
|
||||||
|
}
|
||||||
|
let write_len = map_lossy_pos_to_bytes(&buffered, &text, byte_offset);
|
||||||
|
if write_len < buffered.len() {
|
||||||
|
writer.write_all(&buffered[write_len..])?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn clone_box(&self) -> Box<dyn FilterPlugin> {
|
||||||
|
Box::new(Self {
|
||||||
|
count: self.count,
|
||||||
|
ring: self.ring.clone(),
|
||||||
|
ring_capacity: self.ring_capacity,
|
||||||
|
tokenizer: self
|
||||||
|
.tokenizer
|
||||||
|
.as_ref()
|
||||||
|
.map(|_| Tokenizer::new(self.encoding).unwrap()),
|
||||||
|
encoding: self.encoding,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn options(&self) -> Vec<FilterOption> {
|
||||||
|
vec![FilterOption {
|
||||||
|
name: "count".to_string(),
|
||||||
|
default: None,
|
||||||
|
required: true,
|
||||||
|
}]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Helpers
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
/// Map a byte position in a lossy string back to a position in the original byte slice.
|
||||||
|
///
|
||||||
|
/// `String::from_utf8_lossy` replaces invalid UTF-8 bytes with the Unicode
|
||||||
|
/// replacement character (U+FFFD), which encodes to 3 bytes in UTF-8. This
|
||||||
|
/// function walks both the original bytes and the lossy string in lockstep,
|
||||||
|
/// finding the original byte position that corresponds to `lossy_pos`.
|
||||||
|
fn map_lossy_pos_to_bytes(original: &[u8], lossy: &str, lossy_pos: usize) -> usize {
|
||||||
|
if lossy_pos == 0 {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
let replacement = '\u{FFFD}';
|
||||||
|
let replacement_len = replacement.len_utf8(); // 3 bytes
|
||||||
|
|
||||||
|
let mut orig_idx = 0usize;
|
||||||
|
let mut lossy_idx = 0usize;
|
||||||
|
let lossy_bytes = lossy.as_bytes();
|
||||||
|
|
||||||
|
while lossy_idx < lossy_pos && orig_idx < original.len() {
|
||||||
|
// Try to decode the next character from the original bytes
|
||||||
|
match std::str::from_utf8(&original[orig_idx..]) {
|
||||||
|
Ok("") => break,
|
||||||
|
Ok(s) => {
|
||||||
|
let ch = s.chars().next().unwrap();
|
||||||
|
let ch_len = ch.len_utf8();
|
||||||
|
// Check if this is a replacement character in the lossy string
|
||||||
|
if ch == replacement
|
||||||
|
&& lossy_idx + replacement_len <= lossy_pos
|
||||||
|
&& lossy_bytes[lossy_idx..].starts_with(
|
||||||
|
&replacement.encode_utf8(&mut [0; 4]).as_bytes()[..replacement_len],
|
||||||
|
)
|
||||||
|
{
|
||||||
|
// Could be a real U+FFFD or a replacement of invalid bytes.
|
||||||
|
// If the original byte at this position is valid UTF-8 start, it's real.
|
||||||
|
if original[orig_idx] < 0x80 || original[orig_idx] >= 0xC0 {
|
||||||
|
// Real character
|
||||||
|
orig_idx += ch_len;
|
||||||
|
lossy_idx += ch_len;
|
||||||
|
} else {
|
||||||
|
// Invalid byte that was replaced — advance original by 1
|
||||||
|
orig_idx += 1;
|
||||||
|
lossy_idx += replacement_len;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
orig_idx += ch_len;
|
||||||
|
lossy_idx += ch_len;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
let valid = e.valid_up_to();
|
||||||
|
if valid > 0 {
|
||||||
|
// Some valid bytes, then invalid
|
||||||
|
orig_idx += valid;
|
||||||
|
lossy_idx += valid;
|
||||||
|
} else {
|
||||||
|
// Invalid byte — in lossy it becomes 3-byte replacement char
|
||||||
|
orig_idx += 1;
|
||||||
|
lossy_idx += replacement_len;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
orig_idx.min(original.len())
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Registration
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
#[ctor::ctor]
|
||||||
|
fn register_token_filters() {
|
||||||
|
register_filter_plugin("head_tokens", || {
|
||||||
|
let mut f = HeadTokensFilter::new(0);
|
||||||
|
f.tokenizer = Some(resolve_tokenizer(&None));
|
||||||
|
Box::new(f)
|
||||||
|
});
|
||||||
|
register_filter_plugin("skip_tokens", || {
|
||||||
|
let mut f = SkipTokensFilter::new(0);
|
||||||
|
f.tokenizer = Some(resolve_tokenizer(&None));
|
||||||
|
Box::new(f)
|
||||||
|
});
|
||||||
|
register_filter_plugin("tail_tokens", || {
|
||||||
|
let mut f = TailTokensFilter::new(0);
|
||||||
|
f.tokenizer = Some(resolve_tokenizer(&None));
|
||||||
|
Box::new(f)
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use std::io::Cursor;
|
||||||
|
|
||||||
|
fn make_tokenizer() -> Tokenizer {
|
||||||
|
Tokenizer::new(TokenEncoding::Cl100kBase).unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_head_tokens_basic() {
|
||||||
|
let mut filter = HeadTokensFilter::new(3);
|
||||||
|
filter.tokenizer = Some(make_tokenizer());
|
||||||
|
|
||||||
|
let input = b"The quick brown fox";
|
||||||
|
let mut output = Vec::new();
|
||||||
|
filter.filter(&mut Cursor::new(input), &mut output).unwrap();
|
||||||
|
|
||||||
|
let result = String::from_utf8_lossy(&output);
|
||||||
|
// "The quick brown" is typically 3 tokens
|
||||||
|
assert!(!result.is_empty());
|
||||||
|
assert!(result.len() <= input.len());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_head_tokens_zero() {
|
||||||
|
let mut filter = HeadTokensFilter::new(0);
|
||||||
|
filter.tokenizer = Some(make_tokenizer());
|
||||||
|
|
||||||
|
let input = b"The quick brown fox";
|
||||||
|
let mut output = Vec::new();
|
||||||
|
filter.filter(&mut Cursor::new(input), &mut output).unwrap();
|
||||||
|
assert!(output.is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_head_tokens_more_than_available() {
|
||||||
|
let mut filter = HeadTokensFilter::new(1000);
|
||||||
|
filter.tokenizer = Some(make_tokenizer());
|
||||||
|
|
||||||
|
let input = b"Hello world";
|
||||||
|
let mut output = Vec::new();
|
||||||
|
filter.filter(&mut Cursor::new(input), &mut output).unwrap();
|
||||||
|
assert_eq!(output, input);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_skip_tokens_basic() {
|
||||||
|
let mut filter = SkipTokensFilter::new(2);
|
||||||
|
filter.tokenizer = Some(make_tokenizer());
|
||||||
|
|
||||||
|
let input = b"The quick brown fox";
|
||||||
|
let mut output = Vec::new();
|
||||||
|
filter.filter(&mut Cursor::new(input), &mut output).unwrap();
|
||||||
|
|
||||||
|
let result = String::from_utf8_lossy(&output);
|
||||||
|
// Should have skipped some tokens
|
||||||
|
assert!(result.len() < input.len());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_skip_tokens_zero() {
|
||||||
|
let mut filter = SkipTokensFilter::new(0);
|
||||||
|
filter.tokenizer = Some(make_tokenizer());
|
||||||
|
|
||||||
|
let input = b"Hello world";
|
||||||
|
let mut output = Vec::new();
|
||||||
|
filter.filter(&mut Cursor::new(input), &mut output).unwrap();
|
||||||
|
assert_eq!(output, input);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_tail_tokens_basic() {
|
||||||
|
let mut filter = TailTokensFilter::new(2);
|
||||||
|
filter.tokenizer = Some(make_tokenizer());
|
||||||
|
|
||||||
|
let input = b"The quick brown fox jumps over the lazy dog";
|
||||||
|
let mut output = Vec::new();
|
||||||
|
filter.filter(&mut Cursor::new(input), &mut output).unwrap();
|
||||||
|
|
||||||
|
let result = String::from_utf8_lossy(&output);
|
||||||
|
// Should only have last 2 tokens
|
||||||
|
assert!(!result.is_empty());
|
||||||
|
assert!(result.len() < input.len());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_tail_tokens_zero() {
|
||||||
|
let mut filter = TailTokensFilter::new(0);
|
||||||
|
filter.tokenizer = Some(make_tokenizer());
|
||||||
|
|
||||||
|
let input = b"Hello world";
|
||||||
|
let mut output = Vec::new();
|
||||||
|
filter.filter(&mut Cursor::new(input), &mut output).unwrap();
|
||||||
|
assert!(output.is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_map_lossy_pos_ascii() {
|
||||||
|
let original = b"Hello world";
|
||||||
|
let lossy = String::from_utf8_lossy(original);
|
||||||
|
assert_eq!(map_lossy_pos_to_bytes(original, &lossy, 5), 5);
|
||||||
|
assert_eq!(map_lossy_pos_to_bytes(original, &lossy, 0), 0);
|
||||||
|
assert_eq!(map_lossy_pos_to_bytes(original, &lossy, 11), 11);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_map_lossy_pos_with_invalid_utf8() {
|
||||||
|
let original = b"Hello\x80world";
|
||||||
|
let lossy = String::from_utf8_lossy(original);
|
||||||
|
// lossy = "Hello\u{FFFD}world" (13 bytes)
|
||||||
|
// Position 5 in lossy = after "Hello" = position 5 in original
|
||||||
|
assert_eq!(map_lossy_pos_to_bytes(original, &lossy, 5), 5);
|
||||||
|
// Position 8 in lossy = "Hello\u{FFFD}" = position 6 in original
|
||||||
|
// (the invalid byte \x80 at position 5 was replaced)
|
||||||
|
assert_eq!(map_lossy_pos_to_bytes(original, &lossy, 8), 6);
|
||||||
|
}
|
||||||
|
}
|
||||||
11
src/lib.rs
11
src/lib.rs
@@ -43,6 +43,9 @@ pub mod services;
|
|||||||
#[cfg(feature = "client")]
|
#[cfg(feature = "client")]
|
||||||
pub mod client;
|
pub mod client;
|
||||||
|
|
||||||
|
#[cfg(feature = "tokens")]
|
||||||
|
pub mod tokenizer;
|
||||||
|
|
||||||
// Re-export Args struct for library usage
|
// Re-export Args struct for library usage
|
||||||
pub use args::Args;
|
pub use args::Args;
|
||||||
// Re-export PIPESIZE constant
|
// Re-export PIPESIZE constant
|
||||||
@@ -52,6 +55,10 @@ pub use common::PIPESIZE;
|
|||||||
#[allow(unused_imports)]
|
#[allow(unused_imports)]
|
||||||
use filter_plugin::{grep, head, skip, strip_ansi, tail};
|
use filter_plugin::{grep, head, skip, strip_ansi, tail};
|
||||||
|
|
||||||
|
#[cfg(feature = "tokens")]
|
||||||
|
#[allow(unused_imports)]
|
||||||
|
use filter_plugin::tokens as token_filters;
|
||||||
|
|
||||||
use crate::meta_plugin::{
|
use crate::meta_plugin::{
|
||||||
cwd, digest, env, exec, hostname, keep_pid, read_rate, read_time, shell, shell_pid, user,
|
cwd, digest, env, exec, hostname, keep_pid, read_rate, read_time, shell, shell_pid, user,
|
||||||
};
|
};
|
||||||
@@ -60,6 +67,10 @@ use crate::meta_plugin::{
|
|||||||
#[allow(unused_imports)]
|
#[allow(unused_imports)]
|
||||||
use crate::meta_plugin::magic_file;
|
use crate::meta_plugin::magic_file;
|
||||||
|
|
||||||
|
#[cfg(feature = "tokens")]
|
||||||
|
#[allow(unused_imports)]
|
||||||
|
use crate::meta_plugin::tokens;
|
||||||
|
|
||||||
/// Initializes plugins at library load time.
|
/// Initializes plugins at library load time.
|
||||||
///
|
///
|
||||||
/// Plugin registration happens automatically via `#[ctor]` constructors
|
/// Plugin registration happens automatically via `#[ctor]` constructors
|
||||||
|
|||||||
@@ -16,8 +16,9 @@ pub mod read_time;
|
|||||||
pub mod shell;
|
pub mod shell;
|
||||||
pub mod shell_pid;
|
pub mod shell_pid;
|
||||||
pub mod text;
|
pub mod text;
|
||||||
|
#[cfg(feature = "tokens")]
|
||||||
|
pub mod tokens;
|
||||||
pub mod user;
|
pub mod user;
|
||||||
// pub mod text; // Removed duplicate
|
|
||||||
|
|
||||||
pub use digest::DigestMetaPlugin;
|
pub use digest::DigestMetaPlugin;
|
||||||
pub use exec::MetaPluginExec;
|
pub use exec::MetaPluginExec;
|
||||||
@@ -232,6 +233,7 @@ pub enum MetaPluginType {
|
|||||||
Hostname,
|
Hostname,
|
||||||
Exec,
|
Exec,
|
||||||
Env,
|
Env,
|
||||||
|
Tokens,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Central function to handle metadata output with name mapping.
|
/// Central function to handle metadata output with name mapping.
|
||||||
|
|||||||
295
src/meta_plugin/tokens.rs
Normal file
295
src/meta_plugin/tokens.rs
Normal file
@@ -0,0 +1,295 @@
|
|||||||
|
use crate::common::PIPESIZE;
|
||||||
|
use crate::common::is_binary::is_binary;
|
||||||
|
use crate::meta_plugin::{MetaPlugin, MetaPluginResponse, MetaPluginType};
|
||||||
|
use crate::tokenizer::{TokenEncoding, Tokenizer};
|
||||||
|
|
||||||
|
/// Meta plugin that counts LLM tokens in a streamed payload.
///
/// Buffers the head of the stream for binary detection; binary streams
/// report a null token count, text streams are tokenized chunk by chunk.
#[derive(Debug, Clone)]
pub struct TokensMetaPlugin {
    /// Buffer for binary detection (up to PIPESIZE bytes).
    /// `None` once detection is done and the buffer is no longer needed.
    buffer: Option<Vec<u8>>,
    /// Detection-buffer capacity; taken from the `token_detect_size` option.
    max_buffer_size: usize,
    /// True once this plugin has emitted its final metadata.
    is_finalized: bool,
    /// `None` until binary detection has run, then `Some(is_binary)`.
    is_binary_content: Option<bool>,
    /// Running token count accumulated across chunks.
    token_count: usize,
    /// UTF-8 boundary carry buffer.
    utf8_buffer: Vec<u8>,
    /// Shared option/output plumbing common to all meta plugins.
    base: crate::meta_plugin::BaseMetaPlugin,
    /// The tokenizer instance.
    tokenizer: Tokenizer,
}
|
||||||
|
|
||||||
|
impl TokensMetaPlugin {
|
||||||
|
pub fn new(
|
||||||
|
options: Option<std::collections::HashMap<String, serde_yaml::Value>>,
|
||||||
|
outputs: Option<std::collections::HashMap<String, serde_yaml::Value>>,
|
||||||
|
) -> Self {
|
||||||
|
let mut base = crate::meta_plugin::BaseMetaPlugin::new();
|
||||||
|
|
||||||
|
base.initialize_plugin(&["token_count"], &options, &outputs);
|
||||||
|
|
||||||
|
// Set default options
|
||||||
|
let default_options = vec![
|
||||||
|
(
|
||||||
|
"token_detect_size",
|
||||||
|
serde_yaml::Value::Number(PIPESIZE.into()),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"encoding",
|
||||||
|
serde_yaml::Value::String("cl100k_base".to_string()),
|
||||||
|
),
|
||||||
|
];
|
||||||
|
|
||||||
|
for (key, value) in default_options {
|
||||||
|
if !base.options.contains_key(key) {
|
||||||
|
base.options.insert(key.to_string(), value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let max_buffer_size = base
|
||||||
|
.options
|
||||||
|
.get("token_detect_size")
|
||||||
|
.and_then(|v| v.as_u64())
|
||||||
|
.unwrap_or(PIPESIZE as u64) as usize;
|
||||||
|
|
||||||
|
let encoding = base
|
||||||
|
.options
|
||||||
|
.get("encoding")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.and_then(|s| s.parse::<TokenEncoding>().ok())
|
||||||
|
.unwrap_or_default();
|
||||||
|
|
||||||
|
let tokenizer = Tokenizer::new(encoding).expect("Failed to create tokenizer");
|
||||||
|
|
||||||
|
Self {
|
||||||
|
buffer: Some(Vec::new()),
|
||||||
|
max_buffer_size,
|
||||||
|
is_finalized: false,
|
||||||
|
is_binary_content: None,
|
||||||
|
token_count: 0,
|
||||||
|
utf8_buffer: Vec::new(),
|
||||||
|
base,
|
||||||
|
tokenizer,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Tokenize a byte chunk, handling UTF-8 boundaries.
|
||||||
|
///
|
||||||
|
/// Combines with any pending UTF-8 carry bytes, converts to text,
|
||||||
|
/// and adds the token count to the running total.
|
||||||
|
fn count_tokens(&mut self, data: &[u8]) {
|
||||||
|
if data.is_empty() && self.utf8_buffer.is_empty() {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let combined = if !self.utf8_buffer.is_empty() {
|
||||||
|
let mut c = self.utf8_buffer.clone();
|
||||||
|
c.extend_from_slice(data);
|
||||||
|
c
|
||||||
|
} else {
|
||||||
|
data.to_vec()
|
||||||
|
};
|
||||||
|
self.utf8_buffer.clear();
|
||||||
|
|
||||||
|
let text = match std::str::from_utf8(&combined) {
|
||||||
|
Ok(t) => t,
|
||||||
|
Err(e) => {
|
||||||
|
let valid = e.valid_up_to();
|
||||||
|
if valid < combined.len() {
|
||||||
|
self.utf8_buffer.extend_from_slice(&combined[valid..]);
|
||||||
|
}
|
||||||
|
match std::str::from_utf8(&combined[..valid]) {
|
||||||
|
Ok(t) => t,
|
||||||
|
Err(_) => return,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
if !text.is_empty() {
|
||||||
|
self.token_count += self.tokenizer.count(text);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Perform binary detection on the buffer.
|
||||||
|
fn detect_binary(&mut self, buffer: &[u8]) -> bool {
|
||||||
|
let result = is_binary(buffer);
|
||||||
|
self.is_binary_content = Some(result);
|
||||||
|
result
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MetaPlugin for TokensMetaPlugin {
|
||||||
|
fn is_finalized(&self) -> bool {
|
||||||
|
self.is_finalized
|
||||||
|
}
|
||||||
|
|
||||||
|
fn set_finalized(&mut self, finalized: bool) {
|
||||||
|
self.is_finalized = finalized;
|
||||||
|
}
|
||||||
|
|
||||||
|
fn update(&mut self, data: &[u8]) -> MetaPluginResponse {
|
||||||
|
if self.is_finalized {
|
||||||
|
return MetaPluginResponse {
|
||||||
|
metadata: Vec::new(),
|
||||||
|
is_finalized: true,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut metadata = Vec::new();
|
||||||
|
|
||||||
|
if self.is_binary_content.is_none() {
|
||||||
|
// Add data to the buffer
|
||||||
|
let should_detect = if let Some(ref mut buffer) = self.buffer {
|
||||||
|
let remaining = self.max_buffer_size.saturating_sub(buffer.len());
|
||||||
|
let to_take = std::cmp::min(data.len(), remaining);
|
||||||
|
buffer.extend_from_slice(&data[..to_take]);
|
||||||
|
buffer.len() >= std::cmp::min(1024, self.max_buffer_size)
|
||||||
|
} else {
|
||||||
|
false
|
||||||
|
};
|
||||||
|
|
||||||
|
if should_detect {
|
||||||
|
let buf_clone = self.buffer.as_ref().unwrap().clone();
|
||||||
|
let is_binary = self.detect_binary(&buf_clone);
|
||||||
|
|
||||||
|
if is_binary {
|
||||||
|
if let Some(md) = crate::meta_plugin::process_metadata_outputs(
|
||||||
|
"token_count",
|
||||||
|
serde_yaml::Value::Null,
|
||||||
|
self.base.outputs(),
|
||||||
|
) {
|
||||||
|
metadata.push(md);
|
||||||
|
}
|
||||||
|
self.buffer = None;
|
||||||
|
self.is_finalized = true;
|
||||||
|
return MetaPluginResponse {
|
||||||
|
metadata,
|
||||||
|
is_finalized: true,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
// It's text — tokenize the full accumulated buffer
|
||||||
|
self.count_tokens(&buf_clone);
|
||||||
|
|
||||||
|
if buf_clone.len() >= self.max_buffer_size {
|
||||||
|
self.buffer = None;
|
||||||
|
}
|
||||||
|
} else if self.buffer.is_some() {
|
||||||
|
// Still building up buffer — tokenize what was just added
|
||||||
|
let remaining = self
|
||||||
|
.max_buffer_size
|
||||||
|
.saturating_sub(self.buffer.as_ref().map_or(0, |b| b.len()));
|
||||||
|
let to_take = std::cmp::min(data.len(), remaining);
|
||||||
|
self.count_tokens(&data[..to_take]);
|
||||||
|
}
|
||||||
|
} else if self.is_binary_content == Some(false) {
|
||||||
|
self.count_tokens(data);
|
||||||
|
} else if self.is_binary_content == Some(true) {
|
||||||
|
self.is_finalized = true;
|
||||||
|
return MetaPluginResponse {
|
||||||
|
metadata: Vec::new(),
|
||||||
|
is_finalized: true,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
MetaPluginResponse {
|
||||||
|
metadata,
|
||||||
|
is_finalized: self.is_finalized,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn finalize(&mut self) -> MetaPluginResponse {
|
||||||
|
if self.is_finalized {
|
||||||
|
return MetaPluginResponse {
|
||||||
|
metadata: Vec::new(),
|
||||||
|
is_finalized: true,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut metadata = Vec::new();
|
||||||
|
|
||||||
|
// If binary detection hasn't completed, do it now
|
||||||
|
if self.is_binary_content.is_none() {
|
||||||
|
if let Some(buffer) = &self.buffer {
|
||||||
|
if !buffer.is_empty() {
|
||||||
|
let buf_clone = buffer.clone();
|
||||||
|
let is_binary = self.detect_binary(&buf_clone);
|
||||||
|
|
||||||
|
if is_binary {
|
||||||
|
if let Some(md) = crate::meta_plugin::process_metadata_outputs(
|
||||||
|
"token_count",
|
||||||
|
serde_yaml::Value::Null,
|
||||||
|
self.base.outputs(),
|
||||||
|
) {
|
||||||
|
metadata.push(md);
|
||||||
|
}
|
||||||
|
self.buffer = None;
|
||||||
|
self.is_finalized = true;
|
||||||
|
return MetaPluginResponse {
|
||||||
|
metadata,
|
||||||
|
is_finalized: true,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process any remaining UTF-8 bytes
|
||||||
|
if !self.utf8_buffer.is_empty() {
|
||||||
|
self.count_tokens(&[]);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Emit token count
|
||||||
|
if let Some(md) = crate::meta_plugin::process_metadata_outputs(
|
||||||
|
"token_count",
|
||||||
|
serde_yaml::Value::String(self.token_count.to_string()),
|
||||||
|
self.base.outputs(),
|
||||||
|
) {
|
||||||
|
metadata.push(md);
|
||||||
|
}
|
||||||
|
|
||||||
|
self.buffer = None;
|
||||||
|
self.is_finalized = true;
|
||||||
|
MetaPluginResponse {
|
||||||
|
metadata,
|
||||||
|
is_finalized: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn meta_type(&self) -> MetaPluginType {
|
||||||
|
MetaPluginType::Tokens
|
||||||
|
}
|
||||||
|
|
||||||
|
fn outputs(&self) -> &std::collections::HashMap<String, serde_yaml::Value> {
|
||||||
|
self.base.outputs()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn outputs_mut(
|
||||||
|
&mut self,
|
||||||
|
) -> anyhow::Result<&mut std::collections::HashMap<String, serde_yaml::Value>> {
|
||||||
|
Ok(self.base.outputs_mut())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_outputs(&self) -> Vec<String> {
|
||||||
|
vec!["token_count".to_string()]
|
||||||
|
}
|
||||||
|
|
||||||
|
fn options(&self) -> &std::collections::HashMap<String, serde_yaml::Value> {
|
||||||
|
self.base.options()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn options_mut(
|
||||||
|
&mut self,
|
||||||
|
) -> anyhow::Result<&mut std::collections::HashMap<String, serde_yaml::Value>> {
|
||||||
|
Ok(self.base.options_mut())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
use crate::meta_plugin::register_meta_plugin;

/// Registers the `tokens` meta plugin factory at program startup
/// (runs before `main` via the `ctor` constructor attribute).
#[ctor::ctor]
fn register_tokens_plugin() {
    register_meta_plugin(MetaPluginType::Tokens, |options, outputs| {
        Box::new(TokensMetaPlugin::new(options, outputs))
    });
}
|
||||||
147
src/tokenizer/mod.rs
Normal file
147
src/tokenizer/mod.rs
Normal file
@@ -0,0 +1,147 @@
|
|||||||
|
use anyhow::{Result, bail};
|
||||||
|
|
||||||
|
/// Supported LLM token encodings.
///
/// Parse from a string (case-insensitive) via [`std::str::FromStr`];
/// the canonical lowercase name is produced by the `Display` impl.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum TokenEncoding {
    /// cl100k_base — used by GPT-3.5, GPT-4, text-embedding-ada-002.
    #[default]
    Cl100kBase,
    /// o200k_base — used by GPT-4o, GPT-5, o1, o3, o4 models.
    O200kBase,
}
|
||||||
|
|
||||||
|
impl std::str::FromStr for TokenEncoding {
|
||||||
|
type Err = anyhow::Error;
|
||||||
|
|
||||||
|
fn from_str(s: &str) -> Result<Self> {
|
||||||
|
match s.to_lowercase().as_str() {
|
||||||
|
"cl100k_base" => Ok(TokenEncoding::Cl100kBase),
|
||||||
|
"o200k_base" => Ok(TokenEncoding::O200kBase),
|
||||||
|
_ => bail!("Unknown token encoding: {s}. Supported: cl100k_base, o200k_base"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Display for TokenEncoding {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
match self {
|
||||||
|
TokenEncoding::Cl100kBase => write!(f, "cl100k_base"),
|
||||||
|
TokenEncoding::O200kBase => write!(f, "o200k_base"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Wrapper around tiktoken BPE tokenizer.
///
/// Provides streaming-friendly tokenization: count tokens in text,
/// split text into token strings, and decode token IDs back to text.
#[derive(Clone)]
pub struct Tokenizer {
    // The underlying tiktoken engine (cl100k_base or o200k_base,
    // selected in `Tokenizer::new`).
    bpe: tiktoken_rs::CoreBPE,
}
|
||||||
|
|
||||||
|
impl std::fmt::Debug for Tokenizer {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
f.debug_struct("Tokenizer").finish()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Tokenizer {
|
||||||
|
/// Creates a new tokenizer for the specified encoding.
|
||||||
|
pub fn new(encoding: TokenEncoding) -> Result<Self> {
|
||||||
|
let bpe = match encoding {
|
||||||
|
TokenEncoding::Cl100kBase => tiktoken_rs::cl100k_base()?,
|
||||||
|
TokenEncoding::O200kBase => tiktoken_rs::o200k_base()?,
|
||||||
|
};
|
||||||
|
Ok(Self { bpe })
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Counts the number of tokens in the given text.
|
||||||
|
///
|
||||||
|
/// Uses `encode_ordinary` which treats the text as a single unit.
|
||||||
|
/// For streaming: tokenizing chunks independently and summing gives
|
||||||
|
/// the same result as tokenizing the full text (exact when no regex
|
||||||
|
/// match spans a chunk boundary).
|
||||||
|
pub fn count(&self, text: &str) -> usize {
|
||||||
|
self.bpe.encode_ordinary(text).len()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Splits text into individual decoded token strings.
|
||||||
|
///
|
||||||
|
/// Each returned string corresponds to one token. Useful for finding
|
||||||
|
/// exact byte boundaries when filtering by token count.
|
||||||
|
pub fn split_by_token(&self, text: &str) -> Result<Vec<String>> {
|
||||||
|
self.bpe.split_by_token(text, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Decodes a slice of token IDs back into a string.
|
||||||
|
pub fn decode(&self, tokens: &[u32]) -> Result<String> {
|
||||||
|
self.bpe.decode(tokens.to_vec())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tokenizer_count() {
        // Any non-empty text must yield at least one token.
        let tok = Tokenizer::new(TokenEncoding::Cl100kBase).unwrap();
        assert!(tok.count("Hello, world!") > 0);
    }

    #[test]
    fn test_tokenizer_split() {
        // "Hello world" tokenizes to exactly two pieces under cl100k_base.
        let tok = Tokenizer::new(TokenEncoding::Cl100kBase).unwrap();
        let pieces = tok.split_by_token("Hello world").unwrap();
        assert_eq!(pieces.len(), 2);
    }

    #[test]
    fn test_tokenizer_roundtrip() {
        // encode → decode must reproduce the original text exactly.
        let tok = Tokenizer::new(TokenEncoding::Cl100kBase).unwrap();
        let text = "The quick brown fox jumps over the lazy dog.";
        let ids: Vec<u32> = tok
            .bpe
            .encode_ordinary(text)
            .into_iter()
            .map(|t| t as u32)
            .collect();
        assert_eq!(tok.decode(&ids).unwrap(), text);
    }

    #[test]
    fn test_chunk_sum_close_to_full() {
        let tok = Tokenizer::new(TokenEncoding::Cl100kBase).unwrap();
        let text = "The quick brown fox jumps over the lazy dog. Pack my box with five dozen liquor jugs. How vexingly quick daft zebras jump!";
        let full_count = tok.count(text);

        // Split at a word boundary and count each half independently.
        let split_at = text.find("Pack").unwrap();
        let (head, rest) = text.split_at(split_at);
        let chunk_sum = tok.count(head) + tok.count(rest);

        // Chunk-based counting may differ by 1-2 tokens when a BPE merge
        // boundary falls near the split point.
        let diff = (full_count as isize - chunk_sum as isize).abs();
        assert!(diff <= 2, "full={full_count}, chunk_sum={chunk_sum}");
    }

    #[test]
    fn test_encoding_from_str() {
        // Both supported names parse; anything else is an error.
        assert!(matches!(
            "cl100k_base".parse::<TokenEncoding>(),
            Ok(TokenEncoding::Cl100kBase)
        ));
        assert!(matches!(
            "o200k_base".parse::<TokenEncoding>(),
            Ok(TokenEncoding::O200kBase)
        ));
        assert!("unknown".parse::<TokenEncoding>().is_err());
    }
}
|
||||||
Reference in New Issue
Block a user