Compare commits
4 Commits
af1e0ca570
...
a07bb6b350
| Author | SHA1 | Date | |
|---|---|---|---|
| a07bb6b350 | |||
| e7d8a83369 | |||
| 914190e119 | |||
| e672ec751e |
1729
Cargo.lock
generated
1729
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
113
Cargo.toml
113
Cargo.toml
@@ -10,81 +10,83 @@ categories = ["command-line-utilities"]
|
|||||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
anyhow = "1.0.72"
|
anyhow = "1.0"
|
||||||
axum = { version = "0.8.4", optional = true }
|
axum = { version = "0.8", optional = true }
|
||||||
derive_more = { version = "2.0", features = ["full"] }
|
derive_more = { version = "2.0", features = ["full"] }
|
||||||
smart-default = "0.7"
|
smart-default = "0.7"
|
||||||
thiserror = "1.0"
|
thiserror = "2.0"
|
||||||
base64 = "0.22.1"
|
base64 = "0.22"
|
||||||
chrono = { version = "0.4.26", features = ["serde"] }
|
chrono = { version = "0.4", features = ["serde"] }
|
||||||
clap = { version = "4.3.10", features = ["derive", "env"] }
|
clap = { version = "4.6", features = ["derive", "env"] }
|
||||||
config = "0.14.0"
|
config = "0.15"
|
||||||
ctor = "0.2"
|
ctor = "0.2"
|
||||||
directories = "6.0.0"
|
directories = "6.0"
|
||||||
dns-lookup = "2.0.2"
|
dns-lookup = "3.0"
|
||||||
enum-map = "2.6.1"
|
enum-map = "2.7"
|
||||||
flate2 = { version = "1.0.27", features = ["zlib-ng-compat"], optional = true }
|
flate2 = { version = "1.0", features = ["zlib-ng-compat"], optional = true }
|
||||||
futures = "0.3"
|
futures = "0.3"
|
||||||
gethostname = "1.0.2"
|
gethostname = "1.0"
|
||||||
humansize = "2.1.3"
|
humansize = "2.1"
|
||||||
async-stream = "0.3"
|
async-stream = "0.3"
|
||||||
hyper = { version = "1.0", features = ["full"] }
|
hyper = { version = "1.0", features = ["full"] }
|
||||||
http-body-util = "0.1"
|
http-body-util = "0.1"
|
||||||
inventory = "0.3"
|
inventory = "0.3"
|
||||||
is-terminal = "0.4.9"
|
is-terminal = "0.4"
|
||||||
lazy_static = "1.4.0"
|
lazy_static = "1.5"
|
||||||
libc = "0.2.147"
|
libc = "0.2"
|
||||||
local-ip-address = "0.6.5"
|
local-ip-address = "0.6"
|
||||||
log = "0.4.19"
|
log = "0.4"
|
||||||
lz4_flex = { version = "0.11.1", optional = true }
|
lz4_flex = { version = "0.12", optional = true }
|
||||||
magic = { version = "0.13.0", optional = true }
|
magic = { version = "0.13", optional = true }
|
||||||
nix = "0.30.1"
|
nix = "0.30"
|
||||||
once_cell = "1.19.0"
|
once_cell = "1.21"
|
||||||
comfy-table = "7.2.0"
|
comfy-table = "7.2"
|
||||||
pwhash = "1.0.0"
|
pwhash = "1.0"
|
||||||
regex = "1.9.5"
|
regex = "1.10"
|
||||||
ringbuf = "0.3"
|
ringbuf = "0.4"
|
||||||
rmcp = { version = "0.2.0", features = ["server"], optional = true }
|
rmcp = { version = "0.2", features = ["server"], optional = true }
|
||||||
rusqlite = { version = "0.37.0", features = ["bundled", "array", "chrono"] }
|
rusqlite = { version = "0.37", features = ["bundled", "array", "chrono"] }
|
||||||
rusqlite_migration = "2.3.0"
|
rusqlite_migration = "2.3"
|
||||||
serde = { version = "1.0.219", features = ["derive"] }
|
serde = { version = "1.0", features = ["derive"] }
|
||||||
serde_json = "1.0.142"
|
serde_json = "1.0"
|
||||||
serde_yaml = "0.9.34"
|
serde_yaml = "0.9"
|
||||||
sha2 = "0.10.0"
|
sha2 = "0.10"
|
||||||
md5 = "0.7.0"
|
md5 = "0.7"
|
||||||
subtle = "2.6"
|
subtle = "2.6"
|
||||||
stderrlog = "0.6.0"
|
env_logger = "0.11"
|
||||||
strum = { version = "0.27.2", features = ["derive"] }
|
strum = { version = "0.27", features = ["derive"] }
|
||||||
term = "1.1.0"
|
term = "1.2"
|
||||||
tokio = { version = "1.0", features = ["full"] }
|
tokio = { version = "1.0", features = ["full"] }
|
||||||
tokio-stream = "0.1"
|
tokio-stream = "0.1"
|
||||||
tokio-util = "0.7.16"
|
tokio-util = "0.7"
|
||||||
tower = { version = "0.5.2", optional = true }
|
tower = { version = "0.5", optional = true }
|
||||||
tower-http = { version = "0.6.6", features = ["cors", "fs", "trace"], optional = true }
|
tower-http = { version = "0.6", features = ["cors", "fs", "trace"], optional = true }
|
||||||
utoipa = { version = "5.4.0", features = ["axum_extras"], optional = true }
|
utoipa = { version = "5.4", features = ["axum_extras"], optional = true }
|
||||||
utoipa-swagger-ui = { version = "9.0.2", features = ["axum"], optional = true }
|
utoipa-swagger-ui = { version = "9.0", features = ["axum"], optional = true }
|
||||||
uzers = "0.12.1"
|
uzers = "0.12"
|
||||||
which = "8.0.0"
|
which = "8.0"
|
||||||
xdg = "2.5.2"
|
xdg = "2.5"
|
||||||
strip-ansi-escapes = "0.2.1"
|
strip-ansi-escapes = "0.2"
|
||||||
pest = "2.8.1"
|
pest = "2.8"
|
||||||
pest_derive = "2.8.1"
|
pest_derive = "2.8"
|
||||||
dirs = "6.0.0"
|
dirs = "6.0"
|
||||||
similar = { version = "2.7.0", default-features = false, features = ["text"] }
|
similar = { version = "2.7", default-features = false, features = ["text"] }
|
||||||
ureq = { version = "3", features = ["json"], optional = true }
|
ureq = { version = "3", features = ["json"], optional = true }
|
||||||
os_pipe = { version = "1", optional = true }
|
os_pipe = { version = "1", optional = true }
|
||||||
axum-server = { version = "0.8", features = ["tls-rustls"], optional = true }
|
axum-server = { version = "0.8", features = ["tls-rustls"], optional = true }
|
||||||
|
jsonwebtoken = { version = "10", optional = true, features = ["aws_lc_rs"] }
|
||||||
|
tiktoken-rs = { version = "0.9", optional = true }
|
||||||
|
|
||||||
[features]
|
[features]
|
||||||
# Default features include core compression engines and swagger UI
|
# Default features include core compression engines and swagger UI
|
||||||
default = ["magic", "lz4", "gzip"]
|
default = ["magic", "lz4", "gzip", "client", "tokens"]
|
||||||
|
|
||||||
# Full
|
# Full
|
||||||
#default = ["server", "magic", "lz4", "swagger"]
|
#default = ["server", "magic", "lz4", "swagger"]
|
||||||
|
|
||||||
|
|
||||||
# Server feature (includes axum and related dependencies)
|
# Server feature (includes axum and related dependencies)
|
||||||
server = ["dep:axum", "dep:tower", "dep:tower-http", "dep:utoipa"]
|
server = ["dep:axum", "dep:tower", "dep:tower-http", "dep:utoipa", "dep:jsonwebtoken"]
|
||||||
|
|
||||||
# Compression features
|
# Compression features
|
||||||
gzip = ["flate2"]
|
gzip = ["flate2"]
|
||||||
@@ -112,6 +114,9 @@ client = ["dep:ureq", "dep:os_pipe"]
|
|||||||
# TLS feature (HTTPS server support)
|
# TLS feature (HTTPS server support)
|
||||||
tls = ["dep:axum-server"]
|
tls = ["dep:axum-server"]
|
||||||
|
|
||||||
|
# Token counting feature (LLM token support via tiktoken)
|
||||||
|
tokens = ["dep:tiktoken-rs"]
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
tempfile = "3.3.0"
|
tempfile = "3.3"
|
||||||
rand = "0.8.5"
|
rand = "0.9"
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ FROM rust:1.88-slim AS builder
|
|||||||
|
|
||||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||||
cmake \
|
cmake \
|
||||||
|
curl \
|
||||||
make \
|
make \
|
||||||
gcc \
|
gcc \
|
||||||
musl-tools \
|
musl-tools \
|
||||||
@@ -16,7 +17,6 @@ WORKDIR /app
|
|||||||
# Copy manifests and fetch dependencies (cached layer)
|
# Copy manifests and fetch dependencies (cached layer)
|
||||||
COPY Cargo.toml Cargo.lock ./
|
COPY Cargo.toml Cargo.lock ./
|
||||||
RUN mkdir src && echo 'fn main() {}' > src/main.rs && echo '' > src/lib.rs
|
RUN mkdir src && echo 'fn main() {}' > src/main.rs && echo '' > src/lib.rs
|
||||||
|
|
||||||
RUN cargo fetch --target x86_64-unknown-linux-musl
|
RUN cargo fetch --target x86_64-unknown-linux-musl
|
||||||
|
|
||||||
# Copy real source and build static binary
|
# Copy real source and build static binary
|
||||||
@@ -48,8 +48,11 @@ ENV KEEP_LIST_FORMAT="id,time,size,tags,meta:hostname"
|
|||||||
# Server options
|
# Server options
|
||||||
ENV KEEP_SERVER_ADDRESS=0.0.0.0
|
ENV KEEP_SERVER_ADDRESS=0.0.0.0
|
||||||
ENV KEEP_SERVER_PORT=21080
|
ENV KEEP_SERVER_PORT=21080
|
||||||
|
# ENV KEEP_SERVER_USERNAME="keep"
|
||||||
# ENV KEEP_SERVER_PASSWORD=""
|
# ENV KEEP_SERVER_PASSWORD=""
|
||||||
# ENV KEEP_SERVER_PASSWORD_HASH=""
|
# ENV KEEP_SERVER_PASSWORD_HASH=""
|
||||||
|
# ENV KEEP_SERVER_JWT_SECRET=""
|
||||||
|
# ENV KEEP_SERVER_JWT_SECRET_FILE=/config/jwt_secret
|
||||||
|
|
||||||
# TLS options
|
# TLS options
|
||||||
# ENV KEEP_SERVER_CERT=/certs/cert.pem
|
# ENV KEEP_SERVER_CERT=/certs/cert.pem
|
||||||
@@ -57,6 +60,8 @@ ENV KEEP_SERVER_PORT=21080
|
|||||||
|
|
||||||
# Client options
|
# Client options
|
||||||
# ENV KEEP_CLIENT_URL=""
|
# ENV KEEP_CLIENT_URL=""
|
||||||
|
# ENV KEEP_CLIENT_USERNAME="keep"
|
||||||
# ENV KEEP_CLIENT_PASSWORD=""
|
# ENV KEEP_CLIENT_PASSWORD=""
|
||||||
|
# ENV KEEP_CLIENT_JWT=""
|
||||||
|
|
||||||
ENTRYPOINT ["/keep"]
|
ENTRYPOINT ["/keep"]
|
||||||
|
|||||||
155
README.md
155
README.md
@@ -351,12 +351,17 @@ KEEP_META_build=1234 echo "data" | keep --save tag --meta env=staging
|
|||||||
| `KEEP_LIST_FORMAT` | List column format | built-in defaults |
|
| `KEEP_LIST_FORMAT` | List column format | built-in defaults |
|
||||||
| `KEEP_SERVER_ADDRESS` | Server bind address | `127.0.0.1` |
|
| `KEEP_SERVER_ADDRESS` | Server bind address | `127.0.0.1` |
|
||||||
| `KEEP_SERVER_PORT` | Server port | `21080` |
|
| `KEEP_SERVER_PORT` | Server port | `21080` |
|
||||||
|
| `KEEP_SERVER_USERNAME` | Server Basic auth username | `keep` |
|
||||||
| `KEEP_SERVER_PASSWORD` | Server password | none |
|
| `KEEP_SERVER_PASSWORD` | Server password | none |
|
||||||
| `KEEP_SERVER_PASSWORD_HASH` | Server password hash | none |
|
| `KEEP_SERVER_PASSWORD_HASH` | Server password hash | none |
|
||||||
|
| `KEEP_SERVER_JWT_SECRET` | JWT secret for token auth | none |
|
||||||
|
| `KEEP_SERVER_JWT_SECRET_FILE` | Path to JWT secret file | none |
|
||||||
| `KEEP_SERVER_CERT` | TLS certificate file path (PEM) | none |
|
| `KEEP_SERVER_CERT` | TLS certificate file path (PEM) | none |
|
||||||
| `KEEP_SERVER_KEY` | TLS private key file path (PEM) | none |
|
| `KEEP_SERVER_KEY` | TLS private key file path (PEM) | none |
|
||||||
| `KEEP_CLIENT_URL` | Remote keep server URL | none |
|
| `KEEP_CLIENT_URL` | Remote keep server URL | none |
|
||||||
|
| `KEEP_CLIENT_USERNAME` | Remote server username | `keep` |
|
||||||
| `KEEP_CLIENT_PASSWORD` | Remote server password | none |
|
| `KEEP_CLIENT_PASSWORD` | Remote server password | none |
|
||||||
|
| `KEEP_CLIENT_JWT` | JWT token for remote server | none |
|
||||||
|
|
||||||
Any config setting can be overridden with `KEEP__<SETTING>` environment variables (double underscore separator).
|
Any config setting can be overridden with `KEEP__<SETTING>` environment variables (double underscore separator).
|
||||||
|
|
||||||
@@ -409,7 +414,11 @@ meta_plugins:
|
|||||||
server:
|
server:
|
||||||
address: "127.0.0.1"
|
address: "127.0.0.1"
|
||||||
port: 21080
|
port: 21080
|
||||||
|
username: "keep"
|
||||||
password: "secret"
|
password: "secret"
|
||||||
|
# JWT authentication (takes priority over password)
|
||||||
|
# jwt_secret: "my-secret-key"
|
||||||
|
# jwt_secret_file: /path/to/jwt_secret
|
||||||
# TLS (requires tls feature)
|
# TLS (requires tls feature)
|
||||||
# cert_file: /path/to/cert.pem
|
# cert_file: /path/to/cert.pem
|
||||||
# key_file: /path/to/key.pem
|
# key_file: /path/to/key.pem
|
||||||
@@ -417,7 +426,10 @@ server:
|
|||||||
# Client settings
|
# Client settings
|
||||||
client:
|
client:
|
||||||
url: "http://localhost:21080"
|
url: "http://localhost:21080"
|
||||||
|
username: "keep"
|
||||||
password: "secret"
|
password: "secret"
|
||||||
|
# Or use JWT token
|
||||||
|
# jwt: "eyJhbGciOiJIUzI1NiIs..."
|
||||||
|
|
||||||
human_readable: true
|
human_readable: true
|
||||||
quiet: false
|
quiet: false
|
||||||
@@ -444,10 +456,117 @@ keep --server
|
|||||||
# Custom address and port
|
# Custom address and port
|
||||||
keep --server --server-address 0.0.0.0 --server-port 8080
|
keep --server --server-address 0.0.0.0 --server-port 8080
|
||||||
|
|
||||||
# With authentication
|
# With password authentication
|
||||||
keep --server --server-password mypassword
|
keep --server --server-password mypassword
|
||||||
|
|
||||||
|
# With custom username
|
||||||
|
keep --server --server-username admin --server-password mypassword
|
||||||
|
|
||||||
|
# With JWT authentication
|
||||||
|
keep --server --server-jwt-secret my-secret-key
|
||||||
```
|
```
|
||||||
|
|
||||||
|
#### JWT Authentication
|
||||||
|
|
||||||
|
JWT (JSON Web Token) authentication provides permission-based access control. When a JWT secret is configured, the server validates tokens and checks permission claims for each request.
|
||||||
|
|
||||||
|
**Configuration:**
|
||||||
|
|
||||||
|
```sh
|
||||||
|
# Via CLI flag
|
||||||
|
keep --server --server-jwt-secret my-secret-key
|
||||||
|
|
||||||
|
# Via environment variable
|
||||||
|
export KEEP_SERVER_JWT_SECRET=my-secret-key
|
||||||
|
keep --server
|
||||||
|
|
||||||
|
# Via config file (config.yml)
|
||||||
|
server:
|
||||||
|
jwt_secret: "my-secret-key"
|
||||||
|
|
||||||
|
# Via secret file (for Docker/secrets management)
|
||||||
|
keep --server --server-jwt-secret-file /path/to/secret
|
||||||
|
```
|
||||||
|
|
||||||
|
**Token format:**
|
||||||
|
|
||||||
|
JWTs must use HS256 algorithm with the following claims:
|
||||||
|
|
||||||
|
| Claim | Type | Required | Description |
|
||||||
|
|-------|------|----------|-------------|
|
||||||
|
| `sub` | string | Yes | Subject (client identifier) |
|
||||||
|
| `exp` | number | Yes | Expiration time (Unix timestamp) |
|
||||||
|
| `read` | boolean | No | Permission for GET requests (default: false) |
|
||||||
|
| `write` | boolean | No | Permission for POST/PUT requests (default: false) |
|
||||||
|
| `delete` | boolean | No | Permission for DELETE requests (default: false) |
|
||||||
|
|
||||||
|
**Permission mapping:**
|
||||||
|
|
||||||
|
| HTTP Method | Required Permission |
|
||||||
|
|-------------|-------------------|
|
||||||
|
| `GET` | `read` |
|
||||||
|
| `POST`, `PUT`, `PATCH` | `write` |
|
||||||
|
| `DELETE` | `delete` |
|
||||||
|
|
||||||
|
**Example token payload:**
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"sub": "ci-pipeline",
|
||||||
|
"exp": 1735689600,
|
||||||
|
"read": true,
|
||||||
|
"write": true,
|
||||||
|
"delete": false
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Generating tokens:**
|
||||||
|
|
||||||
|
The server does not generate tokens — use any JWT library or tool:
|
||||||
|
|
||||||
|
```sh
|
||||||
|
# Using jwt-cli (https://github.com/mike-engel/jwt-cli)
|
||||||
|
jwt encode --secret my-secret-key \
|
||||||
|
--exp=$(date -d '+24 hours' +%s) \
|
||||||
|
'{"sub":"my-client","read":true,"write":true,"delete":false}'
|
||||||
|
|
||||||
|
# Using Python
|
||||||
|
python3 -c "
|
||||||
|
import jwt, time
|
||||||
|
token = jwt.encode({
|
||||||
|
'sub': 'my-client',
|
||||||
|
'exp': int(time.time()) + 86400,
|
||||||
|
'read': True, 'write': True, 'delete': False
|
||||||
|
}, 'my-secret-key', algorithm='HS256')
|
||||||
|
print(token)
|
||||||
|
"
|
||||||
|
```
|
||||||
|
|
||||||
|
**Using tokens:**
|
||||||
|
|
||||||
|
```sh
|
||||||
|
# With curl
|
||||||
|
curl -H "Authorization: Bearer <jwt-token>" http://localhost:21080/api/item/
|
||||||
|
|
||||||
|
# The keep client uses --client-jwt for JWT tokens
|
||||||
|
keep --client-url http://server:21080 --client-jwt <jwt-token> --save my-tag
|
||||||
|
```
|
||||||
|
|
||||||
|
**Response codes:**
|
||||||
|
|
||||||
|
| Code | Meaning |
|
||||||
|
|------|---------|
|
||||||
|
| `200` | Authorized |
|
||||||
|
| `401` | Missing, invalid, or expired token |
|
||||||
|
| `403` | Valid token but insufficient permissions |
|
||||||
|
|
||||||
|
**Notes:**
|
||||||
|
|
||||||
|
- When `jwt_secret` is set, password authentication is disabled — all requests must present a valid JWT Bearer token
|
||||||
|
- JWT and password authentication are mutually exclusive — when both `jwt_secret` and `password` are configured, only JWT is used
|
||||||
|
- Permission fields default to `false` if omitted — tokens must explicitly grant permissions
|
||||||
|
- JWT authentication requires the `server` feature (jsonwebtoken is included automatically)
|
||||||
|
|
||||||
#### HTTPS / TLS
|
#### HTTPS / TLS
|
||||||
|
|
||||||
Build with the `tls` feature to enable HTTPS:
|
Build with the `tls` feature to enable HTTPS:
|
||||||
@@ -533,9 +652,16 @@ cargo build --release --features client
|
|||||||
keep --client-url http://server:21080 --save my-tag
|
keep --client-url http://server:21080 --save my-tag
|
||||||
export KEEP_CLIENT_URL=http://server:21080
|
export KEEP_CLIENT_URL=http://server:21080
|
||||||
|
|
||||||
# With authentication
|
# With password authentication
|
||||||
keep --client-url http://server:21080 --client-password mypassword --save my-tag
|
keep --client-url http://server:21080 --client-password mypassword --save my-tag
|
||||||
export KEEP_CLIENT_PASSWORD=mypassword
|
export KEEP_CLIENT_PASSWORD=mypassword
|
||||||
|
|
||||||
|
# With custom username
|
||||||
|
keep --client-url http://server:21080 --client-username admin --client-password mypassword --save my-tag
|
||||||
|
|
||||||
|
# With JWT authentication
|
||||||
|
keep --client-url http://server:21080 --client-jwt <jwt-token> --save my-tag
|
||||||
|
export KEEP_CLIENT_JWT=<jwt-token>
|
||||||
```
|
```
|
||||||
|
|
||||||
#### How Client Mode Works
|
#### How Client Mode Works
|
||||||
@@ -608,15 +734,30 @@ keep --client-url http://logserver:21080 --list --meta project=myapp
|
|||||||
|
|
||||||
#### Authentication
|
#### Authentication
|
||||||
|
|
||||||
```sh
|
The server supports three authentication modes:
|
||||||
# Bearer token
|
|
||||||
curl -H "Authorization: Bearer mypassword" http://localhost:21080/api/status
|
|
||||||
|
|
||||||
# Basic auth
|
**1. Password (HTTP Basic auth):**
|
||||||
|
|
||||||
|
```sh
|
||||||
|
# Default username is "keep"
|
||||||
curl -u keep:mypassword http://localhost:21080/api/status
|
curl -u keep:mypassword http://localhost:21080/api/status
|
||||||
|
|
||||||
|
# Custom username
|
||||||
|
curl -u admin:mypassword http://localhost:21080/api/status
|
||||||
```
|
```
|
||||||
|
|
||||||
When no password is configured, authentication is disabled.
|
**2. JWT (permission-based):**
|
||||||
|
|
||||||
|
```sh
|
||||||
|
# Valid JWT with read permission allows GET requests
|
||||||
|
curl -H "Authorization: Bearer <jwt-token>" http://localhost:21080/api/item/
|
||||||
|
```
|
||||||
|
|
||||||
|
See [JWT Authentication](#jwt-authentication) for token format and configuration.
|
||||||
|
|
||||||
|
**3. No authentication:**
|
||||||
|
|
||||||
|
When neither password nor JWT secret is configured, authentication is disabled.
|
||||||
|
|
||||||
#### Swagger UI
|
#### Swagger UI
|
||||||
|
|
||||||
|
|||||||
@@ -9,14 +9,19 @@ services:
|
|||||||
environment:
|
environment:
|
||||||
- KEEP_SERVER_ADDRESS=0.0.0.0
|
- KEEP_SERVER_ADDRESS=0.0.0.0
|
||||||
- KEEP_SERVER_PORT=21080
|
- KEEP_SERVER_PORT=21080
|
||||||
|
# - KEEP_SERVER_USERNAME=keep
|
||||||
# - KEEP_SERVER_PASSWORD=changeme
|
# - KEEP_SERVER_PASSWORD=changeme
|
||||||
# - KEEP_SERVER_PASSWORD_HASH=
|
# - KEEP_SERVER_PASSWORD_HASH=
|
||||||
|
# - KEEP_SERVER_JWT_SECRET=
|
||||||
|
# - KEEP_SERVER_JWT_SECRET_FILE=/config/jwt_secret
|
||||||
# - KEEP_COMPRESSION=lz4
|
# - KEEP_COMPRESSION=lz4
|
||||||
# - KEEP_META_PLUGINS=
|
# - KEEP_META_PLUGINS=
|
||||||
# - KEEP_FILTERS=
|
# - KEEP_FILTERS=
|
||||||
- KEEP_CONFIG=/config/config.yml
|
- KEEP_CONFIG=/config/config.yml
|
||||||
# - KEEP_SERVER_CERT=/certs/cert.pem
|
# - KEEP_SERVER_CERT=/certs/cert.pem
|
||||||
# - KEEP_SERVER_KEY=/certs/key.pem
|
# - KEEP_SERVER_KEY=/certs/key.pem
|
||||||
|
# - KEEP_CLIENT_USERNAME=keep
|
||||||
|
# - KEEP_CLIENT_JWT=""
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
# For TLS, mount certificate files:
|
# For TLS, mount certificate files:
|
||||||
# volumes:
|
# volumes:
|
||||||
|
|||||||
24
src/args.rs
24
src/args.rs
@@ -151,6 +151,20 @@ pub struct OptionsArgs {
|
|||||||
#[arg(help("Password hash for server authentication (requires --server)"))]
|
#[arg(help("Password hash for server authentication (requires --server)"))]
|
||||||
pub server_password_hash: Option<String>,
|
pub server_password_hash: Option<String>,
|
||||||
|
|
||||||
|
#[arg(long, env("KEEP_SERVER_USERNAME"))]
|
||||||
|
#[arg(help(
|
||||||
|
"Username for server Basic authentication (requires --server, defaults to 'keep')"
|
||||||
|
))]
|
||||||
|
pub server_username: Option<String>,
|
||||||
|
|
||||||
|
#[arg(long, env("KEEP_SERVER_JWT_SECRET"))]
|
||||||
|
#[arg(help("JWT secret for token-based authentication (requires --server)"))]
|
||||||
|
pub server_jwt_secret: Option<String>,
|
||||||
|
|
||||||
|
#[arg(long, env("KEEP_SERVER_JWT_SECRET_FILE"))]
|
||||||
|
#[arg(help("Path to file containing JWT secret (requires --server)"))]
|
||||||
|
pub server_jwt_secret_file: Option<PathBuf>,
|
||||||
|
|
||||||
#[cfg(feature = "client")]
|
#[cfg(feature = "client")]
|
||||||
#[arg(long, env("KEEP_CLIENT_URL"), help_heading("Client Options"))]
|
#[arg(long, env("KEEP_CLIENT_URL"), help_heading("Client Options"))]
|
||||||
#[arg(help("Remote keep server URL for client mode"))]
|
#[arg(help("Remote keep server URL for client mode"))]
|
||||||
@@ -161,6 +175,16 @@ pub struct OptionsArgs {
|
|||||||
#[arg(help("Password for remote keep server authentication"))]
|
#[arg(help("Password for remote keep server authentication"))]
|
||||||
pub client_password: Option<String>,
|
pub client_password: Option<String>,
|
||||||
|
|
||||||
|
#[cfg(feature = "client")]
|
||||||
|
#[arg(long, env("KEEP_CLIENT_USERNAME"), help_heading("Client Options"))]
|
||||||
|
#[arg(help("Username for remote keep server authentication (defaults to 'keep')"))]
|
||||||
|
pub client_username: Option<String>,
|
||||||
|
|
||||||
|
#[cfg(feature = "client")]
|
||||||
|
#[arg(long, env("KEEP_CLIENT_JWT"), help_heading("Client Options"))]
|
||||||
|
#[arg(help("JWT token for remote keep server authentication"))]
|
||||||
|
pub client_jwt: Option<String>,
|
||||||
|
|
||||||
#[arg(
|
#[arg(
|
||||||
long,
|
long,
|
||||||
help("Force output even when binary data would be sent to a TTY")
|
help("Force output even when binary data would be sent to a TTY")
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
use crate::services::error::CoreError;
|
use crate::services::error::CoreError;
|
||||||
|
use base64::Engine;
|
||||||
use serde::de::DeserializeOwned;
|
use serde::de::DeserializeOwned;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::io::Read;
|
use std::io::Read;
|
||||||
@@ -17,17 +18,26 @@ pub struct ItemInfo {
|
|||||||
pub struct KeepClient {
|
pub struct KeepClient {
|
||||||
base_url: String,
|
base_url: String,
|
||||||
agent: ureq::Agent,
|
agent: ureq::Agent,
|
||||||
|
username: Option<String>,
|
||||||
password: Option<String>,
|
password: Option<String>,
|
||||||
|
jwt: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl KeepClient {
|
impl KeepClient {
|
||||||
pub fn new(base_url: &str, password: Option<String>) -> Result<Self, CoreError> {
|
pub fn new(
|
||||||
|
base_url: &str,
|
||||||
|
username: Option<String>,
|
||||||
|
password: Option<String>,
|
||||||
|
jwt: Option<String>,
|
||||||
|
) -> Result<Self, CoreError> {
|
||||||
let base_url = base_url.trim_end_matches('/').to_string();
|
let base_url = base_url.trim_end_matches('/').to_string();
|
||||||
let agent = ureq::Agent::new_with_defaults();
|
let agent = ureq::Agent::new_with_defaults();
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
base_url,
|
base_url,
|
||||||
agent,
|
agent,
|
||||||
|
username,
|
||||||
password,
|
password,
|
||||||
|
jwt,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -35,14 +45,40 @@ impl KeepClient {
|
|||||||
&self.base_url
|
&self.base_url
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn username(&self) -> Option<&String> {
|
||||||
|
self.username.as_ref()
|
||||||
|
}
|
||||||
|
|
||||||
pub fn password(&self) -> Option<&String> {
|
pub fn password(&self) -> Option<&String> {
|
||||||
self.password.as_ref()
|
self.password.as_ref()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn jwt(&self) -> Option<&String> {
|
||||||
|
self.jwt.as_ref()
|
||||||
|
}
|
||||||
|
|
||||||
fn url(&self, path: &str) -> String {
|
fn url(&self, path: &str) -> String {
|
||||||
format!("{}{}", self.base_url, path)
|
format!("{}{}", self.base_url, path)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Get the Authorization header value for the current credentials.
|
||||||
|
///
|
||||||
|
/// JWT token is sent as `Bearer <token>`.
|
||||||
|
/// Password is sent as `Basic base64(username:password)`
|
||||||
|
/// where username defaults to "keep".
|
||||||
|
fn auth_header(&self) -> Option<String> {
|
||||||
|
if let Some(ref jwt) = self.jwt {
|
||||||
|
Some(format!("Bearer {jwt}"))
|
||||||
|
} else if let Some(ref password) = self.password {
|
||||||
|
let username = self.username.as_deref().unwrap_or("keep");
|
||||||
|
let credentials = format!("{username}:{password}");
|
||||||
|
let encoded = base64::engine::general_purpose::STANDARD.encode(&credentials);
|
||||||
|
Some(format!("Basic {encoded}"))
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn handle_error<T>(&self, result: Result<T, ureq::Error>) -> Result<T, CoreError> {
|
fn handle_error<T>(&self, result: Result<T, ureq::Error>) -> Result<T, CoreError> {
|
||||||
match result {
|
match result {
|
||||||
Ok(v) => Ok(v),
|
Ok(v) => Ok(v),
|
||||||
@@ -57,8 +93,8 @@ impl KeepClient {
|
|||||||
pub fn get_json<T: DeserializeOwned>(&self, path: &str) -> Result<T, CoreError> {
|
pub fn get_json<T: DeserializeOwned>(&self, path: &str) -> Result<T, CoreError> {
|
||||||
let url = self.url(path);
|
let url = self.url(path);
|
||||||
let mut req = self.agent.get(&url);
|
let mut req = self.agent.get(&url);
|
||||||
if let Some(ref password) = self.password {
|
if let Some(ref auth) = self.auth_header() {
|
||||||
req = req.header("Authorization", &format!("Bearer {password}"));
|
req = req.header("Authorization", auth);
|
||||||
}
|
}
|
||||||
let response = self.handle_error(req.call())?;
|
let response = self.handle_error(req.call())?;
|
||||||
let body: T = self.handle_error(response.into_body().read_json())?;
|
let body: T = self.handle_error(response.into_body().read_json())?;
|
||||||
@@ -81,8 +117,8 @@ impl KeepClient {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
let mut req = self.agent.get(&url);
|
let mut req = self.agent.get(&url);
|
||||||
if let Some(ref password) = self.password {
|
if let Some(ref auth) = self.auth_header() {
|
||||||
req = req.header("Authorization", &format!("Bearer {password}"));
|
req = req.header("Authorization", auth);
|
||||||
}
|
}
|
||||||
let response = self.handle_error(req.call())?;
|
let response = self.handle_error(req.call())?;
|
||||||
let body: T = self.handle_error(response.into_body().read_json())?;
|
let body: T = self.handle_error(response.into_body().read_json())?;
|
||||||
@@ -92,8 +128,8 @@ impl KeepClient {
|
|||||||
pub fn get_bytes(&self, path: &str) -> Result<Vec<u8>, CoreError> {
|
pub fn get_bytes(&self, path: &str) -> Result<Vec<u8>, CoreError> {
|
||||||
let url = self.url(path);
|
let url = self.url(path);
|
||||||
let mut req = self.agent.get(&url);
|
let mut req = self.agent.get(&url);
|
||||||
if let Some(ref password) = self.password {
|
if let Some(ref auth) = self.auth_header() {
|
||||||
req = req.header("Authorization", &format!("Bearer {password}"));
|
req = req.header("Authorization", auth);
|
||||||
}
|
}
|
||||||
let response = self.handle_error(req.call())?;
|
let response = self.handle_error(req.call())?;
|
||||||
let mut body = response.into_body();
|
let mut body = response.into_body();
|
||||||
@@ -135,8 +171,8 @@ impl KeepClient {
|
|||||||
}
|
}
|
||||||
|
|
||||||
let mut req = self.agent.post(&url);
|
let mut req = self.agent.post(&url);
|
||||||
if let Some(ref password) = self.password {
|
if let Some(ref auth) = self.auth_header() {
|
||||||
req = req.header("Authorization", &format!("Bearer {password}"));
|
req = req.header("Authorization", auth);
|
||||||
}
|
}
|
||||||
req = req.header("Content-Type", "application/octet-stream");
|
req = req.header("Content-Type", "application/octet-stream");
|
||||||
|
|
||||||
@@ -162,8 +198,8 @@ impl KeepClient {
|
|||||||
pub fn delete(&self, path: &str) -> Result<(), CoreError> {
|
pub fn delete(&self, path: &str) -> Result<(), CoreError> {
|
||||||
let url = self.url(path);
|
let url = self.url(path);
|
||||||
let mut req = self.agent.delete(&url);
|
let mut req = self.agent.delete(&url);
|
||||||
if let Some(ref password) = self.password {
|
if let Some(ref auth) = self.auth_header() {
|
||||||
req = req.header("Authorization", &format!("Bearer {password}"));
|
req = req.header("Authorization", auth);
|
||||||
}
|
}
|
||||||
self.handle_error(req.call())?;
|
self.handle_error(req.call())?;
|
||||||
Ok(())
|
Ok(())
|
||||||
@@ -254,8 +290,8 @@ impl KeepClient {
|
|||||||
) -> Result<(), CoreError> {
|
) -> Result<(), CoreError> {
|
||||||
let url = self.url(&format!("/api/item/{id}/meta"));
|
let url = self.url(&format!("/api/item/{id}/meta"));
|
||||||
let mut req = self.agent.post(&url);
|
let mut req = self.agent.post(&url);
|
||||||
if let Some(ref password) = self.password {
|
if let Some(ref auth) = self.auth_header() {
|
||||||
req = req.header("Authorization", &format!("Bearer {password}"));
|
req = req.header("Authorization", auth);
|
||||||
}
|
}
|
||||||
req = req.header("Content-Type", "application/json");
|
req = req.header("Content-Type", "application/json");
|
||||||
|
|
||||||
@@ -274,8 +310,8 @@ impl KeepClient {
|
|||||||
);
|
);
|
||||||
|
|
||||||
let mut req = self.agent.get(&url);
|
let mut req = self.agent.get(&url);
|
||||||
if let Some(ref password) = self.password {
|
if let Some(ref auth) = self.auth_header() {
|
||||||
req = req.header("Authorization", &format!("Bearer {password}"));
|
req = req.header("Authorization", auth);
|
||||||
}
|
}
|
||||||
|
|
||||||
let response = self.handle_error(req.call())?;
|
let response = self.handle_error(req.call())?;
|
||||||
|
|||||||
@@ -149,7 +149,7 @@ fn has_binary_signature(data: &[u8]) -> bool {
|
|||||||
|
|
||||||
/// Check if data looks like UTF-16 without BOM
|
/// Check if data looks like UTF-16 without BOM
|
||||||
fn looks_like_utf16(data: &[u8]) -> bool {
|
fn looks_like_utf16(data: &[u8]) -> bool {
|
||||||
if data.len() < 4 || data.len() % 2 != 0 {
|
if data.len() < 4 || !data.len().is_multiple_of(2) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -3,5 +3,8 @@ pub mod is_binary;
|
|||||||
/// Detects if data is binary or text based on signatures and printable ratios.
|
/// Detects if data is binary or text based on signatures and printable ratios.
|
||||||
pub mod status;
|
pub mod status;
|
||||||
|
|
||||||
|
/// Plugin schema types and discovery functions.
|
||||||
|
pub mod schema;
|
||||||
|
|
||||||
/// Standard buffer size for I/O operations (8KB)
|
/// Standard buffer size for I/O operations (8KB)
|
||||||
pub const PIPESIZE: usize = 8192;
|
pub const PIPESIZE: usize = 8192;
|
||||||
|
|||||||
166
src/common/schema.rs
Normal file
166
src/common/schema.rs
Normal file
@@ -0,0 +1,166 @@
|
|||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use strum::IntoEnumIterator;
|
||||||
|
|
||||||
|
/// Value type for a plugin option.
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||||
|
#[serde(rename_all = "lowercase")]
|
||||||
|
pub enum OptionType {
|
||||||
|
String,
|
||||||
|
Integer,
|
||||||
|
Boolean,
|
||||||
|
Any,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl OptionType {
|
||||||
|
/// Infer the option type from a YAML value.
|
||||||
|
pub fn from_yaml_value(value: &serde_yaml::Value) -> Self {
|
||||||
|
match value {
|
||||||
|
serde_yaml::Value::Bool(_) => OptionType::Boolean,
|
||||||
|
serde_yaml::Value::Number(_) => OptionType::Integer,
|
||||||
|
serde_yaml::Value::String(_) => OptionType::String,
|
||||||
|
_ => OptionType::Any,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Schema for a single plugin option.
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct OptionSchema {
|
||||||
|
pub name: String,
|
||||||
|
pub option_type: OptionType,
|
||||||
|
pub default: Option<serde_yaml::Value>,
|
||||||
|
pub required: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Schema for a single plugin output.
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct OutputSchema {
|
||||||
|
pub name: String,
|
||||||
|
pub description: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Schema describing a plugin's configuration requirements.
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct PluginSchema {
|
||||||
|
pub name: String,
|
||||||
|
pub description: String,
|
||||||
|
pub options: Vec<OptionSchema>,
|
||||||
|
pub outputs: Vec<OutputSchema>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Gathers schemas from all registered meta plugins.
|
||||||
|
///
|
||||||
|
/// Iterates all `MetaPluginType` variants, attempts to create a default instance,
|
||||||
|
/// and collects their schemas. Plugins that fail to register (e.g., feature-gated)
|
||||||
|
/// are silently skipped.
|
||||||
|
pub fn gather_meta_plugin_schemas() -> Vec<PluginSchema> {
|
||||||
|
use crate::meta_plugin::{MetaPluginType, get_meta_plugin};
|
||||||
|
|
||||||
|
let mut schemas = Vec::new();
|
||||||
|
let mut sorted_types: Vec<MetaPluginType> = MetaPluginType::iter().collect();
|
||||||
|
sorted_types.sort_by_key(|t| t.to_string());
|
||||||
|
|
||||||
|
for plugin_type in sorted_types {
|
||||||
|
let plugin = match get_meta_plugin(plugin_type.clone(), None, None) {
|
||||||
|
Ok(p) => p,
|
||||||
|
Err(_) => continue,
|
||||||
|
};
|
||||||
|
|
||||||
|
let name = plugin.meta_type().to_string();
|
||||||
|
|
||||||
|
let options: Vec<OptionSchema> = plugin
|
||||||
|
.options()
|
||||||
|
.iter()
|
||||||
|
.map(|(key, value)| {
|
||||||
|
let option_type = OptionType::from_yaml_value(value);
|
||||||
|
let (default, required) = if value.is_null() {
|
||||||
|
(None, true)
|
||||||
|
} else {
|
||||||
|
(Some(value.clone()), false)
|
||||||
|
};
|
||||||
|
OptionSchema {
|
||||||
|
name: key.clone(),
|
||||||
|
option_type,
|
||||||
|
default,
|
||||||
|
required,
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let mut outputs: Vec<OutputSchema> = Vec::new();
|
||||||
|
for (key, value) in plugin.outputs() {
|
||||||
|
if !value.is_null() {
|
||||||
|
outputs.push(OutputSchema {
|
||||||
|
name: key.clone(),
|
||||||
|
description: key.clone(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Also include default outputs if outputs map is empty
|
||||||
|
if outputs.is_empty() {
|
||||||
|
for output_name in plugin.default_outputs() {
|
||||||
|
outputs.push(OutputSchema {
|
||||||
|
name: output_name.clone(),
|
||||||
|
description: output_name,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
schemas.push(PluginSchema {
|
||||||
|
name,
|
||||||
|
description: plugin.description().to_string(),
|
||||||
|
options,
|
||||||
|
outputs,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
schemas
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Gathers schemas from all registered filter plugins.
|
||||||
|
///
|
||||||
|
/// Uses the global filter plugin registry to discover all registered filters,
|
||||||
|
/// creates a default instance of each, and collects their option schemas.
|
||||||
|
pub fn gather_filter_plugin_schemas() -> Vec<PluginSchema> {
|
||||||
|
use crate::services::filter_service::get_available_filter_plugins;
|
||||||
|
|
||||||
|
let plugins = get_available_filter_plugins();
|
||||||
|
let mut schemas: Vec<PluginSchema> = plugins
|
||||||
|
.into_iter()
|
||||||
|
.map(|(name, creator)| {
|
||||||
|
let plugin = creator();
|
||||||
|
let options: Vec<OptionSchema> = plugin
|
||||||
|
.options()
|
||||||
|
.iter()
|
||||||
|
.map(|opt| {
|
||||||
|
let option_type = match &opt.default {
|
||||||
|
Some(serde_json::Value::Bool(_)) => OptionType::Boolean,
|
||||||
|
Some(serde_json::Value::Number(_)) => OptionType::Integer,
|
||||||
|
Some(serde_json::Value::String(_)) => OptionType::String,
|
||||||
|
_ => OptionType::Any,
|
||||||
|
};
|
||||||
|
OptionSchema {
|
||||||
|
name: opt.name.clone(),
|
||||||
|
option_type,
|
||||||
|
default: opt.default.as_ref().map(|v| {
|
||||||
|
// Convert serde_json::Value to serde_yaml::Value
|
||||||
|
serde_yaml::to_value(v).unwrap_or(serde_yaml::Value::Null)
|
||||||
|
}),
|
||||||
|
required: opt.required,
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
PluginSchema {
|
||||||
|
name: name.clone(),
|
||||||
|
description: plugin.description().to_string(),
|
||||||
|
options,
|
||||||
|
outputs: Vec::new(),
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
schemas.sort_by(|a, b| a.name.cmp(&b.name));
|
||||||
|
schemas
|
||||||
|
}
|
||||||
@@ -11,12 +11,12 @@ use std::io::{Read, Write};
|
|||||||
#[cfg(feature = "gzip")]
|
#[cfg(feature = "gzip")]
|
||||||
use std::path::PathBuf;
|
use std::path::PathBuf;
|
||||||
|
|
||||||
|
#[cfg(feature = "gzip")]
|
||||||
|
use flate2::Compression;
|
||||||
#[cfg(feature = "gzip")]
|
#[cfg(feature = "gzip")]
|
||||||
use flate2::read::GzDecoder;
|
use flate2::read::GzDecoder;
|
||||||
#[cfg(feature = "gzip")]
|
#[cfg(feature = "gzip")]
|
||||||
use flate2::write::GzEncoder;
|
use flate2::write::GzEncoder;
|
||||||
#[cfg(feature = "gzip")]
|
|
||||||
use flate2::Compression;
|
|
||||||
|
|
||||||
#[cfg(feature = "gzip")]
|
#[cfg(feature = "gzip")]
|
||||||
use crate::compression_engine::CompressionEngine;
|
use crate::compression_engine::CompressionEngine;
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
use anyhow::{anyhow, Result};
|
use anyhow::{Result, anyhow};
|
||||||
use std::io;
|
use std::io;
|
||||||
use std::io::{Read, Write};
|
use std::io::{Read, Write};
|
||||||
use std::path::PathBuf;
|
use std::path::PathBuf;
|
||||||
@@ -175,6 +175,7 @@ impl Clone for Box<dyn CompressionEngine> {
|
|||||||
|
|
||||||
lazy_static! {
|
lazy_static! {
|
||||||
static ref COMPRESSION_ENGINES: EnumMap<CompressionType, Box<dyn CompressionEngine>> = {
|
static ref COMPRESSION_ENGINES: EnumMap<CompressionType, Box<dyn CompressionEngine>> = {
|
||||||
|
#[allow(unused_mut)] // mut needed when gzip/lz4 features are enabled
|
||||||
let mut em = enum_map! {
|
let mut em = enum_map! {
|
||||||
CompressionType::LZ4 => Box::new(crate::compression_engine::program::CompressionEngineProgram::new(
|
CompressionType::LZ4 => Box::new(crate::compression_engine::program::CompressionEngineProgram::new(
|
||||||
"lz4",
|
"lz4",
|
||||||
@@ -231,9 +232,6 @@ pub fn get_compression_engine(ct: CompressionType) -> Result<Box<dyn Compression
|
|||||||
if engine.is_supported() {
|
if engine.is_supported() {
|
||||||
Ok(engine.clone())
|
Ok(engine.clone())
|
||||||
} else {
|
} else {
|
||||||
Err(anyhow!(
|
Err(anyhow!("Compression engine for {ct} is not supported",))
|
||||||
"Compression engine for {} is not supported",
|
|
||||||
ct.to_string()
|
|
||||||
))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
use anyhow::{anyhow, Context, Result};
|
use anyhow::{Context, Result, anyhow};
|
||||||
use log::*;
|
use log::*;
|
||||||
use std::fs::File;
|
use std::fs::File;
|
||||||
use std::io::{Read, Write};
|
use std::io::{Read, Write};
|
||||||
|
|||||||
117
src/config.rs
117
src/config.rs
@@ -143,9 +143,12 @@ impl<'de> serde::Deserialize<'de> for ColumnConfig {
|
|||||||
pub struct ServerConfig {
|
pub struct ServerConfig {
|
||||||
pub address: Option<String>,
|
pub address: Option<String>,
|
||||||
pub port: Option<u16>,
|
pub port: Option<u16>,
|
||||||
|
pub username: Option<String>,
|
||||||
pub password_file: Option<PathBuf>,
|
pub password_file: Option<PathBuf>,
|
||||||
pub password: Option<String>,
|
pub password: Option<String>,
|
||||||
pub password_hash: Option<String>,
|
pub password_hash: Option<String>,
|
||||||
|
pub jwt_secret: Option<String>,
|
||||||
|
pub jwt_secret_file: Option<PathBuf>,
|
||||||
pub cert_file: Option<PathBuf>,
|
pub cert_file: Option<PathBuf>,
|
||||||
pub key_file: Option<PathBuf>,
|
pub key_file: Option<PathBuf>,
|
||||||
pub cors_origin: Option<String>,
|
pub cors_origin: Option<String>,
|
||||||
@@ -159,7 +162,9 @@ pub struct CompressionPluginConfig {
|
|||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
pub struct ClientConfig {
|
pub struct ClientConfig {
|
||||||
pub url: Option<String>,
|
pub url: Option<String>,
|
||||||
|
pub username: Option<String>,
|
||||||
pub password: Option<String>,
|
pub password: Option<String>,
|
||||||
|
pub jwt: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
@@ -198,7 +203,11 @@ pub struct Settings {
|
|||||||
#[serde(skip)]
|
#[serde(skip)]
|
||||||
pub client_url: Option<String>,
|
pub client_url: Option<String>,
|
||||||
#[serde(skip)]
|
#[serde(skip)]
|
||||||
|
pub client_username: Option<String>,
|
||||||
|
#[serde(skip)]
|
||||||
pub client_password: Option<String>,
|
pub client_password: Option<String>,
|
||||||
|
#[serde(skip)]
|
||||||
|
pub client_jwt: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Settings {
|
impl Settings {
|
||||||
@@ -281,6 +290,11 @@ impl Settings {
|
|||||||
.set_override("server.password_hash", server_password_hash.as_str())?;
|
.set_override("server.password_hash", server_password_hash.as_str())?;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if let Some(server_username) = &args.options.server_username {
|
||||||
|
config_builder =
|
||||||
|
config_builder.set_override("server.username", server_username.as_str())?;
|
||||||
|
}
|
||||||
|
|
||||||
if let Some(server_address) = &args.mode.server_address {
|
if let Some(server_address) = &args.mode.server_address {
|
||||||
config_builder =
|
config_builder =
|
||||||
config_builder.set_override("server.address", server_address.as_str())?;
|
config_builder.set_override("server.address", server_address.as_str())?;
|
||||||
@@ -429,11 +443,21 @@ impl Settings {
|
|||||||
.client_url
|
.client_url
|
||||||
.clone()
|
.clone()
|
||||||
.or_else(|| settings.client.as_ref().and_then(|c| c.url.clone()));
|
.or_else(|| settings.client.as_ref().and_then(|c| c.url.clone()));
|
||||||
|
settings.client_username = args
|
||||||
|
.options
|
||||||
|
.client_username
|
||||||
|
.clone()
|
||||||
|
.or_else(|| settings.client.as_ref().and_then(|c| c.username.clone()));
|
||||||
settings.client_password = args
|
settings.client_password = args
|
||||||
.options
|
.options
|
||||||
.client_password
|
.client_password
|
||||||
.clone()
|
.clone()
|
||||||
.or_else(|| settings.client.as_ref().and_then(|c| c.password.clone()));
|
.or_else(|| settings.client.as_ref().and_then(|c| c.password.clone()));
|
||||||
|
settings.client_jwt = args
|
||||||
|
.options
|
||||||
|
.client_jwt
|
||||||
|
.clone()
|
||||||
|
.or_else(|| settings.client.as_ref().and_then(|c| c.jwt.clone()));
|
||||||
}
|
}
|
||||||
|
|
||||||
debug!("CONFIG: Final settings: {settings:?}");
|
debug!("CONFIG: Final settings: {settings:?}");
|
||||||
@@ -487,6 +511,38 @@ impl Settings {
|
|||||||
self.server.as_ref().and_then(|s| s.password_hash.clone())
|
self.server.as_ref().and_then(|s| s.password_hash.clone())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn server_username(&self) -> Option<String> {
|
||||||
|
self.server.as_ref().and_then(|s| s.username.clone())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get JWT secret from jwt_secret_file or directly from config if configured
|
||||||
|
pub fn get_server_jwt_secret(&self) -> Result<Option<String>> {
|
||||||
|
if let Some(server) = &self.server {
|
||||||
|
// First check for jwt_secret_file
|
||||||
|
if let Some(jwt_secret_file) = &server.jwt_secret_file {
|
||||||
|
debug!("CONFIG: Reading JWT secret from file: {jwt_secret_file:?}");
|
||||||
|
let secret = fs::read_to_string(jwt_secret_file)
|
||||||
|
.with_context(|| {
|
||||||
|
format!("Failed to read JWT secret file: {jwt_secret_file:?}")
|
||||||
|
})?
|
||||||
|
.trim()
|
||||||
|
.to_string();
|
||||||
|
return Ok(Some(secret));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fall back to direct jwt_secret field
|
||||||
|
if let Some(secret) = &server.jwt_secret {
|
||||||
|
debug!("CONFIG: Using JWT secret from config");
|
||||||
|
return Ok(Some(secret.clone()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(None)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn server_jwt_secret(&self) -> Option<String> {
|
||||||
|
self.get_server_jwt_secret().ok().flatten()
|
||||||
|
}
|
||||||
|
|
||||||
pub fn server_address(&self) -> Option<String> {
|
pub fn server_address(&self) -> Option<String> {
|
||||||
self.server.as_ref().and_then(|s| s.address.clone())
|
self.server.as_ref().and_then(|s| s.address.clone())
|
||||||
}
|
}
|
||||||
@@ -517,4 +573,65 @@ impl Settings {
|
|||||||
.map(|plugins| plugins.iter().map(|p| p.name.clone()).collect())
|
.map(|plugins| plugins.iter().map(|p| p.name.clone()).collect())
|
||||||
.unwrap_or_default()
|
.unwrap_or_default()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Validates the configuration against plugin schemas.
|
||||||
|
///
|
||||||
|
/// Checks that:
|
||||||
|
/// - All configured meta plugin names are valid and registered
|
||||||
|
/// - Required options are present for each meta plugin
|
||||||
|
/// - Compression plugin name (if set) is a valid compression type
|
||||||
|
///
|
||||||
|
/// Returns a list of warning strings. An empty list means the config is valid.
|
||||||
|
pub fn validate_config(&self) -> Vec<String> {
|
||||||
|
use crate::common::schema::gather_meta_plugin_schemas;
|
||||||
|
use crate::compression_engine::CompressionType;
|
||||||
|
use strum::IntoEnumIterator;
|
||||||
|
|
||||||
|
let mut warnings = Vec::new();
|
||||||
|
|
||||||
|
// Validate compression plugin
|
||||||
|
if let Some(ref comp) = self.compression_plugin {
|
||||||
|
let valid_types: Vec<String> =
|
||||||
|
CompressionType::iter().map(|ct| ct.to_string()).collect();
|
||||||
|
if !valid_types.contains(&comp.name) {
|
||||||
|
warnings.push(format!(
|
||||||
|
"Unknown compression_plugin.name: '{}'. Valid types: {}",
|
||||||
|
comp.name,
|
||||||
|
valid_types.join(", ")
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate meta plugins
|
||||||
|
if let Some(ref plugins) = self.meta_plugins {
|
||||||
|
let schemas = gather_meta_plugin_schemas();
|
||||||
|
let schema_map: std::collections::HashMap<&str, &crate::common::schema::PluginSchema> =
|
||||||
|
schemas.iter().map(|s| (s.name.as_str(), s)).collect();
|
||||||
|
|
||||||
|
for plugin in plugins {
|
||||||
|
match schema_map.get(plugin.name.as_str()) {
|
||||||
|
Some(schema) => {
|
||||||
|
// Check required options
|
||||||
|
for opt in &schema.options {
|
||||||
|
if opt.required && !plugin.options.contains_key(&opt.name) {
|
||||||
|
warnings.push(format!(
|
||||||
|
"Meta plugin '{}': missing required option '{}'",
|
||||||
|
plugin.name, opt.name
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None => {
|
||||||
|
warnings.push(format!(
|
||||||
|
"Unknown meta plugin: '{}'. Available: {}",
|
||||||
|
plugin.name,
|
||||||
|
schema_map.keys().copied().collect::<Vec<_>>().join(", ")
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
warnings
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
use super::{FilterPlugin, FilterOption};
|
use super::{FilterOption, FilterPlugin};
|
||||||
use std::io::{Result, Read, Write};
|
|
||||||
use std::process::{Command, Stdio, Child};
|
|
||||||
use which::which;
|
|
||||||
use log::*;
|
use log::*;
|
||||||
|
use std::io::{Read, Result, Write};
|
||||||
|
use std::process::{Child, Command, Stdio};
|
||||||
|
use which::which;
|
||||||
|
|
||||||
/// A filter that executes an external program and pipes input through it.
|
/// A filter that executes an external program and pipes input through it.
|
||||||
///
|
///
|
||||||
@@ -43,16 +43,13 @@ impl ExecFilter {
|
|||||||
/// let filter = ExecFilter::new("grep", vec!["-i", "error"], false);
|
/// let filter = ExecFilter::new("grep", vec!["-i", "error"], false);
|
||||||
/// assert!(filter.supported);
|
/// assert!(filter.supported);
|
||||||
/// ```
|
/// ```
|
||||||
pub fn new(
|
pub fn new(program: &str, args: Vec<&str>, split_whitespace: bool) -> ExecFilter {
|
||||||
program: &str,
|
|
||||||
args: Vec<&str>,
|
|
||||||
split_whitespace: bool,
|
|
||||||
) -> ExecFilter {
|
|
||||||
let program_path = which(program);
|
let program_path = which(program);
|
||||||
let supported = program_path.is_ok();
|
let supported = program_path.is_ok();
|
||||||
|
|
||||||
ExecFilter {
|
ExecFilter {
|
||||||
program: program_path.map_or_else(|| program.to_string(), |p| p.to_string_lossy().to_string()),
|
program: program_path
|
||||||
|
.map_or_else(|| program.to_string(), |p| p.to_string_lossy().to_string()),
|
||||||
args: args.iter().map(|s| s.to_string()).collect(),
|
args: args.iter().map(|s| s.to_string()).collect(),
|
||||||
supported,
|
supported,
|
||||||
split_whitespace,
|
split_whitespace,
|
||||||
@@ -101,7 +98,10 @@ impl FilterPlugin for ExecFilter {
|
|||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
debug!("FILTER_EXEC: Executing command: {} {:?}", self.program, self.args);
|
debug!(
|
||||||
|
"FILTER_EXEC: Executing command: {} {:?}",
|
||||||
|
self.program, self.args
|
||||||
|
);
|
||||||
|
|
||||||
// Read all input first
|
// Read all input first
|
||||||
let mut input_data = Vec::new();
|
let mut input_data = Vec::new();
|
||||||
@@ -142,8 +142,7 @@ impl FilterPlugin for ExecFilter {
|
|||||||
std::io::copy(&mut stdout, writer)?;
|
std::io::copy(&mut stdout, writer)?;
|
||||||
|
|
||||||
// Wait for the child process to finish
|
// Wait for the child process to finish
|
||||||
let output = child.wait_with_output()
|
let output = child.wait_with_output().map_err(|e| {
|
||||||
.map_err(|e| {
|
|
||||||
std::io::Error::new(
|
std::io::Error::new(
|
||||||
std::io::ErrorKind::Other,
|
std::io::ErrorKind::Other,
|
||||||
format!("Failed to wait on child process: {}", e),
|
format!("Failed to wait on child process: {}", e),
|
||||||
@@ -205,6 +204,10 @@ impl FilterPlugin for ExecFilter {
|
|||||||
},
|
},
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn description(&self) -> &str {
|
||||||
|
"Pipe input through an external command"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Register the plugin at module initialization time
|
// Register the plugin at module initialization time
|
||||||
|
|||||||
@@ -132,4 +132,8 @@ impl FilterPlugin for GrepFilter {
|
|||||||
required: true,
|
required: true,
|
||||||
}]
|
}]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn description(&self) -> &str {
|
||||||
|
"Filter lines matching a regex pattern"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -140,6 +140,10 @@ impl FilterPlugin for HeadBytesFilter {
|
|||||||
required: true,
|
required: true,
|
||||||
}]
|
}]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn description(&self) -> &str {
|
||||||
|
"Read the first N bytes"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// A filter that reads the first N lines from the input stream.
|
/// A filter that reads the first N lines from the input stream.
|
||||||
@@ -270,6 +274,10 @@ impl FilterPlugin for HeadLinesFilter {
|
|||||||
required: true,
|
required: true,
|
||||||
}]
|
}]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn description(&self) -> &str {
|
||||||
|
"Read the first N lines"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Register the plugin at module initialization time
|
// Register the plugin at module initialization time
|
||||||
|
|||||||
@@ -26,6 +26,8 @@ pub mod head;
|
|||||||
pub mod skip;
|
pub mod skip;
|
||||||
pub mod strip_ansi;
|
pub mod strip_ansi;
|
||||||
pub mod tail;
|
pub mod tail;
|
||||||
|
#[cfg(feature = "tokens")]
|
||||||
|
pub mod tokens;
|
||||||
pub mod utils;
|
pub mod utils;
|
||||||
|
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
@@ -170,6 +172,15 @@ pub trait FilterPlugin: Send {
|
|||||||
/// }
|
/// }
|
||||||
/// ```
|
/// ```
|
||||||
fn options(&self) -> Vec<FilterOption>;
|
fn options(&self) -> Vec<FilterOption>;
|
||||||
|
|
||||||
|
/// Returns a human-readable description of this filter.
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
///
|
||||||
|
/// A description string (empty by default).
|
||||||
|
fn description(&self) -> &str {
|
||||||
|
""
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Enum representing the different types of filters.
|
/// Enum representing the different types of filters.
|
||||||
@@ -192,6 +203,12 @@ pub enum FilterType {
|
|||||||
SkipLines,
|
SkipLines,
|
||||||
Grep,
|
Grep,
|
||||||
StripAnsi,
|
StripAnsi,
|
||||||
|
#[cfg(feature = "tokens")]
|
||||||
|
HeadTokens,
|
||||||
|
#[cfg(feature = "tokens")]
|
||||||
|
SkipTokens,
|
||||||
|
#[cfg(feature = "tokens")]
|
||||||
|
TailTokens,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Maximum buffer size (256 MB) for filter chain intermediate results.
|
/// Maximum buffer size (256 MB) for filter chain intermediate results.
|
||||||
@@ -490,6 +507,12 @@ fn create_filter_with_options(
|
|||||||
FilterType::SkipBytes => skip::SkipBytesFilter::new(0).options(),
|
FilterType::SkipBytes => skip::SkipBytesFilter::new(0).options(),
|
||||||
FilterType::SkipLines => skip::SkipLinesFilter::new(0).options(),
|
FilterType::SkipLines => skip::SkipLinesFilter::new(0).options(),
|
||||||
FilterType::StripAnsi => strip_ansi::StripAnsiFilter::new().options(),
|
FilterType::StripAnsi => strip_ansi::StripAnsiFilter::new().options(),
|
||||||
|
#[cfg(feature = "tokens")]
|
||||||
|
FilterType::HeadTokens => tokens::HeadTokensFilter::new(0).options(),
|
||||||
|
#[cfg(feature = "tokens")]
|
||||||
|
FilterType::SkipTokens => tokens::SkipTokensFilter::new(0).options(),
|
||||||
|
#[cfg(feature = "tokens")]
|
||||||
|
FilterType::TailTokens => tokens::TailTokensFilter::new(0).options(),
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut options = HashMap::new();
|
let mut options = HashMap::new();
|
||||||
@@ -658,6 +681,72 @@ fn create_specific_filter(
|
|||||||
}
|
}
|
||||||
Ok(Box::new(strip_ansi::StripAnsiFilter::new()))
|
Ok(Box::new(strip_ansi::StripAnsiFilter::new()))
|
||||||
}
|
}
|
||||||
|
#[cfg(feature = "tokens")]
|
||||||
|
FilterType::HeadTokens => {
|
||||||
|
let count = options
|
||||||
|
.get("count")
|
||||||
|
.and_then(|v| v.as_u64())
|
||||||
|
.map(|n| n as usize)
|
||||||
|
.ok_or_else(|| {
|
||||||
|
std::io::Error::new(
|
||||||
|
std::io::ErrorKind::InvalidInput,
|
||||||
|
"head_tokens filter requires 'count' parameter",
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
let encoding = options
|
||||||
|
.get("encoding")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.and_then(|s| s.parse::<crate::tokenizer::TokenEncoding>().ok())
|
||||||
|
.unwrap_or_default();
|
||||||
|
let mut f = tokens::HeadTokensFilter::new(count);
|
||||||
|
f.tokenizer = crate::tokenizer::get_tokenizer(encoding).clone();
|
||||||
|
f.encoding = encoding;
|
||||||
|
Ok(Box::new(f))
|
||||||
|
}
|
||||||
|
#[cfg(feature = "tokens")]
|
||||||
|
FilterType::SkipTokens => {
|
||||||
|
let count = options
|
||||||
|
.get("count")
|
||||||
|
.and_then(|v| v.as_u64())
|
||||||
|
.map(|n| n as usize)
|
||||||
|
.ok_or_else(|| {
|
||||||
|
std::io::Error::new(
|
||||||
|
std::io::ErrorKind::InvalidInput,
|
||||||
|
"skip_tokens filter requires 'count' parameter",
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
let encoding = options
|
||||||
|
.get("encoding")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.and_then(|s| s.parse::<crate::tokenizer::TokenEncoding>().ok())
|
||||||
|
.unwrap_or_default();
|
||||||
|
let mut f = tokens::SkipTokensFilter::new(count);
|
||||||
|
f.tokenizer = crate::tokenizer::get_tokenizer(encoding).clone();
|
||||||
|
f.encoding = encoding;
|
||||||
|
Ok(Box::new(f))
|
||||||
|
}
|
||||||
|
#[cfg(feature = "tokens")]
|
||||||
|
FilterType::TailTokens => {
|
||||||
|
let count = options
|
||||||
|
.get("count")
|
||||||
|
.and_then(|v| v.as_u64())
|
||||||
|
.map(|n| n as usize)
|
||||||
|
.ok_or_else(|| {
|
||||||
|
std::io::Error::new(
|
||||||
|
std::io::ErrorKind::InvalidInput,
|
||||||
|
"tail_tokens filter requires 'count' parameter",
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
let encoding = options
|
||||||
|
.get("encoding")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.and_then(|s| s.parse::<crate::tokenizer::TokenEncoding>().ok())
|
||||||
|
.unwrap_or_default();
|
||||||
|
let mut f = tokens::TailTokensFilter::new(count);
|
||||||
|
f.tokenizer = crate::tokenizer::get_tokenizer(encoding).clone();
|
||||||
|
f.encoding = encoding;
|
||||||
|
Ok(Box::new(f))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -72,6 +72,10 @@ impl FilterPlugin for SkipBytesFilter {
|
|||||||
required: true,
|
required: true,
|
||||||
}]
|
}]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn description(&self) -> &str {
|
||||||
|
"Skip the first N bytes"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// A filter that skips the first N lines from the input stream.
|
/// A filter that skips the first N lines from the input stream.
|
||||||
@@ -137,6 +141,10 @@ impl FilterPlugin for SkipLinesFilter {
|
|||||||
required: true,
|
required: true,
|
||||||
}]
|
}]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn description(&self) -> &str {
|
||||||
|
"Skip the first N lines"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Register the plugin at module initialization time
|
// Register the plugin at module initialization time
|
||||||
|
|||||||
@@ -56,4 +56,8 @@ impl FilterPlugin for StripAnsiFilter {
|
|||||||
fn options(&self) -> Vec<FilterOption> {
|
fn options(&self) -> Vec<FilterOption> {
|
||||||
Vec::new() // strip_ansi doesn't take any options
|
Vec::new() // strip_ansi doesn't take any options
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn description(&self) -> &str {
|
||||||
|
"Strip ANSI escape sequences"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -82,6 +82,10 @@ impl FilterPlugin for TailBytesFilter {
|
|||||||
required: true,
|
required: true,
|
||||||
}]
|
}]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn description(&self) -> &str {
|
||||||
|
"Read the last N bytes"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// A filter that reads the last N lines from the input stream.
|
/// A filter that reads the last N lines from the input stream.
|
||||||
@@ -156,6 +160,10 @@ impl FilterPlugin for TailLinesFilter {
|
|||||||
required: true,
|
required: true,
|
||||||
}]
|
}]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn description(&self) -> &str {
|
||||||
|
"Read the last N lines"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Register the plugin at module initialization time
|
// Register the plugin at module initialization time
|
||||||
|
|||||||
502
src/filter_plugin/tokens.rs
Normal file
502
src/filter_plugin/tokens.rs
Normal file
@@ -0,0 +1,502 @@
|
|||||||
|
use super::{FilterOption, FilterPlugin};
|
||||||
|
use crate::common::PIPESIZE;
|
||||||
|
use crate::services::filter_service::register_filter_plugin;
|
||||||
|
use crate::tokenizer::{TokenEncoding, Tokenizer, get_tokenizer};
|
||||||
|
use std::io::{Read, Result, Write};
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// head_tokens
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
/// A filter that outputs only the first N tokens of the input stream.
|
||||||
|
///
|
||||||
|
/// Streams bytes directly until the token limit is reached. When the limit
|
||||||
|
/// falls mid-chunk, uses `split_by_token_iter` to find the exact byte boundary
|
||||||
|
/// without allocating token strings beyond what is needed.
|
||||||
|
pub struct HeadTokensFilter {
|
||||||
|
pub remaining: usize,
|
||||||
|
pub tokenizer: Tokenizer,
|
||||||
|
pub encoding: TokenEncoding,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl HeadTokensFilter {
|
||||||
|
pub fn new(count: usize) -> Self {
|
||||||
|
let encoding = TokenEncoding::default();
|
||||||
|
Self {
|
||||||
|
remaining: count,
|
||||||
|
tokenizer: get_tokenizer(encoding).clone(),
|
||||||
|
encoding,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FilterPlugin for HeadTokensFilter {
|
||||||
|
fn filter(&mut self, reader: &mut dyn Read, writer: &mut dyn Write) -> Result<()> {
|
||||||
|
if self.remaining == 0 {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
let tokenizer = &self.tokenizer;
|
||||||
|
let mut buffer = vec![0u8; PIPESIZE];
|
||||||
|
let mut total_tokens = 0usize;
|
||||||
|
|
||||||
|
loop {
|
||||||
|
let n = reader.read(&mut buffer)?;
|
||||||
|
if n == 0 {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
let chunk = &buffer[..n];
|
||||||
|
let text = String::from_utf8_lossy(chunk);
|
||||||
|
let chunk_tokens = tokenizer.count(&text);
|
||||||
|
|
||||||
|
if total_tokens + chunk_tokens <= self.remaining {
|
||||||
|
// Entire chunk fits — write it directly
|
||||||
|
writer.write_all(chunk)?;
|
||||||
|
total_tokens += chunk_tokens;
|
||||||
|
if total_tokens >= self.remaining {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Cutoff is within this chunk — use iterator to find exact
|
||||||
|
// boundary without allocating all token strings
|
||||||
|
let tokens_to_write = self.remaining - total_tokens;
|
||||||
|
let mut byte_pos = 0usize;
|
||||||
|
for token_str in tokenizer.split_by_token_iter(&text).take(tokens_to_write) {
|
||||||
|
byte_pos += token_str
|
||||||
|
.map_err(|e| std::io::Error::other(e.to_string()))?
|
||||||
|
.len();
|
||||||
|
}
|
||||||
|
let write_len = map_lossy_pos_to_bytes(chunk, &text, byte_pos);
|
||||||
|
writer.write_all(&chunk[..write_len])?;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn clone_box(&self) -> Box<dyn FilterPlugin> {
|
||||||
|
Box::new(Self {
|
||||||
|
remaining: self.remaining,
|
||||||
|
tokenizer: get_tokenizer(self.encoding).clone(),
|
||||||
|
encoding: self.encoding,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn options(&self) -> Vec<FilterOption> {
|
||||||
|
vec![
|
||||||
|
FilterOption {
|
||||||
|
name: "count".to_string(),
|
||||||
|
default: None,
|
||||||
|
required: true,
|
||||||
|
},
|
||||||
|
FilterOption {
|
||||||
|
name: "encoding".to_string(),
|
||||||
|
default: Some(serde_json::Value::String("cl100k_base".to_string())),
|
||||||
|
required: false,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
fn description(&self) -> &str {
|
||||||
|
"Read the first N LLM tokens"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// skip_tokens
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
/// A filter that skips the first N tokens of the input stream and outputs the rest.
pub struct SkipTokensFilter {
    /// Number of tokens still to be skipped before pass-through begins.
    pub remaining: usize,
    /// Tokenizer used to count and split tokens (cloned from the shared
    /// registry in `new`).
    pub tokenizer: Tokenizer,
    /// Encoding the tokenizer was built for; used to re-fetch it in `clone_box`.
    pub encoding: TokenEncoding,
}
|
||||||
|
|
||||||
|
impl SkipTokensFilter {
|
||||||
|
pub fn new(count: usize) -> Self {
|
||||||
|
let encoding = TokenEncoding::default();
|
||||||
|
Self {
|
||||||
|
remaining: count,
|
||||||
|
tokenizer: get_tokenizer(encoding).clone(),
|
||||||
|
encoding,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FilterPlugin for SkipTokensFilter {
    /// Streams `reader` to `writer`, suppressing output until `self.remaining`
    /// tokens have been consumed, then passing the remainder through unchanged.
    ///
    /// Reads in PIPESIZE chunks; each chunk is tokenized through a lossy UTF-8
    /// view, and byte offsets are mapped back with `map_lossy_pos_to_bytes`.
    ///
    /// NOTE(review): tokens are counted per read-chunk, so a token (or UTF-8
    /// sequence) that spans two reads is handled per-chunk — the skip boundary
    /// can differ slightly from whole-stream tokenization. Confirm acceptable.
    fn filter(&mut self, reader: &mut dyn Read, writer: &mut dyn Write) -> Result<()> {
        if self.remaining == 0 {
            // Nothing to skip — plain byte-for-byte copy.
            return std::io::copy(reader, writer).map(|_| ());
        }

        let tokenizer = &self.tokenizer;
        let mut buffer = vec![0u8; PIPESIZE];
        let mut total_tokens = 0usize;
        // Set once the skip budget is exhausted; later chunks bypass tokenization.
        let mut done_skipping = false;

        loop {
            let n = reader.read(&mut buffer)?;
            if n == 0 {
                break;
            }

            if done_skipping {
                writer.write_all(&buffer[..n])?;
                continue;
            }

            let chunk = &buffer[..n];
            let text = String::from_utf8_lossy(chunk);
            let chunk_tokens = tokenizer.count(&text);

            if total_tokens + chunk_tokens <= self.remaining {
                // Entire chunk is skipped
                total_tokens += chunk_tokens;
                if total_tokens >= self.remaining {
                    done_skipping = true;
                }
            } else {
                // Cutoff is within this chunk — use iterator to skip past
                // the boundary without allocating all token strings
                let tokens_to_skip = self.remaining - total_tokens;
                // byte_pos is an offset into the lossy `text`, not `chunk`.
                let mut byte_pos = 0usize;
                for token_str in tokenizer.split_by_token_iter(&text).take(tokens_to_skip) {
                    byte_pos += token_str
                        .map_err(|e| std::io::Error::other(e.to_string()))?
                        .len();
                }
                // Translate the lossy-string offset back to a raw-byte offset.
                let skip_len = map_lossy_pos_to_bytes(chunk, &text, byte_pos);
                if skip_len < n {
                    writer.write_all(&chunk[skip_len..])?;
                }
                done_skipping = true;
            }
        }
        Ok(())
    }

    /// Fresh boxed copy; the tokenizer is re-fetched from the registry for
    /// the stored encoding.
    fn clone_box(&self) -> Box<dyn FilterPlugin> {
        Box::new(Self {
            remaining: self.remaining,
            tokenizer: get_tokenizer(self.encoding).clone(),
            encoding: self.encoding,
        })
    }

    /// Options: required `count`, optional `encoding` (default "cl100k_base").
    fn options(&self) -> Vec<FilterOption> {
        vec![
            FilterOption {
                name: "count".to_string(),
                default: None,
                required: true,
            },
            FilterOption {
                name: "encoding".to_string(),
                default: Some(serde_json::Value::String("cl100k_base".to_string())),
                required: false,
            },
        ]
    }

    /// One-line human-readable summary of this filter.
    fn description(&self) -> &str {
        "Skip the first N LLM tokens"
    }
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// tail_tokens
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
/// A filter that outputs only the last N tokens of the input stream.
///
/// Buffers all bytes from the stream, then at finalize tokenizes the
/// content and writes only the last N tokens.
pub struct TailTokensFilter {
    /// Number of trailing tokens to emit.
    pub count: usize,
    /// Buffer holding all bytes from the stream.
    buffer: Vec<u8>,
    /// Tokenizer used for splitting (cloned from the shared registry in `new`).
    pub tokenizer: Tokenizer,
    /// Encoding the tokenizer was built for; used to re-fetch it in `clone_box`.
    pub encoding: TokenEncoding,
}
|
||||||
|
|
||||||
|
impl TailTokensFilter {
|
||||||
|
pub fn new(count: usize) -> Self {
|
||||||
|
let encoding = TokenEncoding::default();
|
||||||
|
Self {
|
||||||
|
count,
|
||||||
|
buffer: Vec::with_capacity(PIPESIZE),
|
||||||
|
tokenizer: get_tokenizer(encoding).clone(),
|
||||||
|
encoding,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FilterPlugin for TailTokensFilter {
    /// Buffers the entire stream into `self.buffer`, then writes only the
    /// last `self.count` tokens (the whole input is held in memory until EOF).
    fn filter(&mut self, reader: &mut dyn Read, writer: &mut dyn Write) -> Result<()> {
        if self.count == 0 {
            // A tail of zero tokens emits nothing (input is not drained).
            return Ok(());
        }

        let tokenizer = &self.tokenizer;

        // Buffer all bytes from the stream
        std::io::copy(reader, &mut self.buffer)?;

        if self.buffer.is_empty() {
            return Ok(());
        }

        let text = String::from_utf8_lossy(&self.buffer);
        let token_strs = tokenizer
            .split_by_token(&text)
            .map_err(|e| std::io::Error::other(e.to_string()))?;

        if token_strs.len() <= self.count {
            // All tokens fit — write everything
            writer.write_all(&self.buffer)?;
        } else {
            // Write only the last N tokens: sum the byte lengths of the
            // skipped prefix (these offsets are in the lossy `text`), then
            // map that back to an offset in the raw buffer.
            let skip = token_strs.len() - self.count;
            let mut byte_offset = 0usize;
            for token_str in token_strs.iter().take(skip) {
                byte_offset += token_str.len();
            }
            let write_len = map_lossy_pos_to_bytes(&self.buffer, &text, byte_offset);
            if write_len < self.buffer.len() {
                writer.write_all(&self.buffer[write_len..])?;
            }
        }

        Ok(())
    }

    /// Fresh boxed copy. NOTE(review): this also clones the already-buffered
    /// stream bytes into the copy — confirm that is intended rather than
    /// starting the clone with an empty buffer.
    fn clone_box(&self) -> Box<dyn FilterPlugin> {
        Box::new(Self {
            count: self.count,
            buffer: self.buffer.clone(),
            tokenizer: get_tokenizer(self.encoding).clone(),
            encoding: self.encoding,
        })
    }

    /// Options: required `count`, optional `encoding` (default "cl100k_base").
    fn options(&self) -> Vec<FilterOption> {
        vec![
            FilterOption {
                name: "count".to_string(),
                default: None,
                required: true,
            },
            FilterOption {
                name: "encoding".to_string(),
                default: Some(serde_json::Value::String("cl100k_base".to_string())),
                required: false,
            },
        ]
    }

    /// One-line human-readable summary of this filter.
    fn description(&self) -> &str {
        "Read the last N LLM tokens"
    }
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Helpers
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
/// Map a byte position in a lossy string back to a position in the original byte slice.
///
/// `String::from_utf8_lossy` replaces each invalid UTF-8 byte with the Unicode
/// replacement character (U+FFFD), which encodes to 3 bytes in UTF-8, so the
/// lossy string and the original bytes can drift apart. This walks both in
/// lockstep — `orig_idx` over the original bytes, `lossy_idx` over the lossy
/// string — and stops once `lossy_idx` reaches `lossy_pos`.
///
/// Returns the original-byte index corresponding to `lossy_pos`, clamped to
/// `original.len()`. A `lossy_pos` that falls inside a multi-byte character is
/// rounded up to that character's end (same as the previous behavior for
/// valid text).
///
/// Fix: the old code advanced over an entire valid prefix (`valid_up_to`) in
/// one step, overshooting `lossy_pos` when the target lay inside a valid run
/// that precedes an invalid byte (e.g. pos 3 in `b"Hello\x80world"` mapped
/// to 5 instead of 3). The advance is now capped at the remaining distance.
fn map_lossy_pos_to_bytes(original: &[u8], lossy: &str, lossy_pos: usize) -> usize {
    debug_assert!(lossy_pos <= lossy.len(), "lossy_pos out of range");

    if lossy_pos == 0 {
        return 0;
    }

    // Each invalid original byte becomes a 3-byte U+FFFD in the lossy output.
    const REPLACEMENT_LEN: usize = 3;

    let mut orig_idx = 0usize;
    let mut lossy_idx = 0usize;

    while lossy_idx < lossy_pos && orig_idx < original.len() {
        match std::str::from_utf8(&original[orig_idx..]) {
            Ok(s) => {
                // The remainder is entirely valid UTF-8: both streams are
                // byte-for-byte identical from here on. Advance char by char
                // so a cutoff inside a multi-byte char rounds up to its end.
                for ch in s.chars() {
                    if lossy_idx >= lossy_pos {
                        break;
                    }
                    let ch_len = ch.len_utf8();
                    orig_idx += ch_len;
                    lossy_idx += ch_len;
                }
                break;
            }
            Err(e) => {
                let valid = e.valid_up_to();
                if valid > 0 {
                    // Valid prefix: identical in both streams. Advance by at
                    // most the remaining distance so we never overshoot the
                    // target inside this run.
                    let step = valid.min(lossy_pos - lossy_idx);
                    orig_idx += step;
                    lossy_idx += step;
                } else {
                    // Invalid byte: one original byte was replaced by a
                    // 3-byte U+FFFD in the lossy string.
                    orig_idx += 1;
                    lossy_idx += REPLACEMENT_LEN;
                }
            }
        }
    }

    orig_idx.min(original.len())
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Registration
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
#[ctor::ctor]
fn register_token_filters() {
    // Runs automatically at program load via the `ctor` crate. The count
    // argument of 0 is a placeholder — presumably the real value is applied
    // later from the filter's `count` option; verify against the plugin
    // framework's option-handling code.
    register_filter_plugin("head_tokens", || Box::new(HeadTokensFilter::new(0)));
    register_filter_plugin("skip_tokens", || Box::new(SkipTokensFilter::new(0)));
    register_filter_plugin("tail_tokens", || Box::new(TailTokensFilter::new(0)));
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    /// All tests share a real cl100k_base tokenizer from the registry.
    fn make_tokenizer() -> Tokenizer {
        get_tokenizer(TokenEncoding::Cl100kBase).clone()
    }

    #[test]
    fn test_head_tokens_basic() {
        let mut filter = HeadTokensFilter::new(3);
        filter.tokenizer = make_tokenizer();

        let input = b"The quick brown fox";
        let mut output = Vec::new();
        filter.filter(&mut Cursor::new(input), &mut output).unwrap();

        let result = String::from_utf8_lossy(&output);
        // "The quick brown" is typically 3 tokens
        assert!(!result.is_empty());
        assert!(result.len() <= input.len());
    }

    #[test]
    fn test_head_tokens_zero() {
        let mut filter = HeadTokensFilter::new(0);
        filter.tokenizer = make_tokenizer();

        let input = b"The quick brown fox";
        let mut output = Vec::new();
        filter.filter(&mut Cursor::new(input), &mut output).unwrap();
        // A head of zero tokens writes nothing.
        assert!(output.is_empty());
    }

    #[test]
    fn test_head_tokens_more_than_available() {
        let mut filter = HeadTokensFilter::new(1000);
        filter.tokenizer = make_tokenizer();

        let input = b"Hello world";
        let mut output = Vec::new();
        filter.filter(&mut Cursor::new(input), &mut output).unwrap();
        // When the budget exceeds the input, everything passes through.
        assert_eq!(output, input);
    }

    #[test]
    fn test_skip_tokens_basic() {
        let mut filter = SkipTokensFilter::new(2);
        filter.tokenizer = make_tokenizer();

        let input = b"The quick brown fox";
        let mut output = Vec::new();
        filter.filter(&mut Cursor::new(input), &mut output).unwrap();

        let result = String::from_utf8_lossy(&output);
        // Should have skipped some tokens
        assert!(result.len() < input.len());
    }

    #[test]
    fn test_skip_tokens_zero() {
        let mut filter = SkipTokensFilter::new(0);
        filter.tokenizer = make_tokenizer();

        let input = b"Hello world";
        let mut output = Vec::new();
        filter.filter(&mut Cursor::new(input), &mut output).unwrap();
        // Skipping zero tokens degenerates to a plain copy.
        assert_eq!(output, input);
    }

    #[test]
    fn test_tail_tokens_basic() {
        let mut filter = TailTokensFilter::new(2);
        filter.tokenizer = make_tokenizer();

        let input = b"The quick brown fox jumps over the lazy dog";
        let mut output = Vec::new();
        filter.filter(&mut Cursor::new(input), &mut output).unwrap();

        let result = String::from_utf8_lossy(&output);
        // Should only have last 2 tokens
        assert!(!result.is_empty());
        assert!(result.len() < input.len());
    }

    #[test]
    fn test_tail_tokens_zero() {
        let mut filter = TailTokensFilter::new(0);
        filter.tokenizer = make_tokenizer();

        let input = b"Hello world";
        let mut output = Vec::new();
        filter.filter(&mut Cursor::new(input), &mut output).unwrap();
        // A tail of zero tokens writes nothing.
        assert!(output.is_empty());
    }

    #[test]
    fn test_map_lossy_pos_ascii() {
        // Pure ASCII: the lossy view equals the original, mapping is identity.
        let original = b"Hello world";
        let lossy = String::from_utf8_lossy(original);
        assert_eq!(map_lossy_pos_to_bytes(original, &lossy, 5), 5);
        assert_eq!(map_lossy_pos_to_bytes(original, &lossy, 0), 0);
        assert_eq!(map_lossy_pos_to_bytes(original, &lossy, 11), 11);
    }

    #[test]
    fn test_map_lossy_pos_with_invalid_utf8() {
        let original = b"Hello\x80world";
        let lossy = String::from_utf8_lossy(original);
        // lossy = "Hello\u{FFFD}world" (13 bytes)
        // Position 5 in lossy = after "Hello" = position 5 in original
        assert_eq!(map_lossy_pos_to_bytes(original, &lossy, 5), 5);
        // Position 8 in lossy = "Hello\u{FFFD}" = position 6 in original
        // (the invalid byte \x80 at position 5 was replaced)
        assert_eq!(map_lossy_pos_to_bytes(original, &lossy, 8), 6);
    }
}
|
||||||
11
src/lib.rs
11
src/lib.rs
@@ -43,6 +43,9 @@ pub mod services;
|
|||||||
#[cfg(feature = "client")]
|
#[cfg(feature = "client")]
|
||||||
pub mod client;
|
pub mod client;
|
||||||
|
|
||||||
|
#[cfg(feature = "tokens")]
|
||||||
|
pub mod tokenizer;
|
||||||
|
|
||||||
// Re-export Args struct for library usage
|
// Re-export Args struct for library usage
|
||||||
pub use args::Args;
|
pub use args::Args;
|
||||||
// Re-export PIPESIZE constant
|
// Re-export PIPESIZE constant
|
||||||
@@ -52,6 +55,10 @@ pub use common::PIPESIZE;
|
|||||||
#[allow(unused_imports)]
|
#[allow(unused_imports)]
|
||||||
use filter_plugin::{grep, head, skip, strip_ansi, tail};
|
use filter_plugin::{grep, head, skip, strip_ansi, tail};
|
||||||
|
|
||||||
|
#[cfg(feature = "tokens")]
|
||||||
|
#[allow(unused_imports)]
|
||||||
|
use filter_plugin::tokens as token_filters;
|
||||||
|
|
||||||
use crate::meta_plugin::{
|
use crate::meta_plugin::{
|
||||||
cwd, digest, env, exec, hostname, keep_pid, read_rate, read_time, shell, shell_pid, user,
|
cwd, digest, env, exec, hostname, keep_pid, read_rate, read_time, shell, shell_pid, user,
|
||||||
};
|
};
|
||||||
@@ -60,6 +67,10 @@ use crate::meta_plugin::{
|
|||||||
#[allow(unused_imports)]
|
#[allow(unused_imports)]
|
||||||
use crate::meta_plugin::magic_file;
|
use crate::meta_plugin::magic_file;
|
||||||
|
|
||||||
|
#[cfg(feature = "tokens")]
|
||||||
|
#[allow(unused_imports)]
|
||||||
|
use crate::meta_plugin::tokens;
|
||||||
|
|
||||||
/// Initializes plugins at library load time.
|
/// Initializes plugins at library load time.
|
||||||
///
|
///
|
||||||
/// Plugin registration happens automatically via `#[ctor]` constructors
|
/// Plugin registration happens automatically via `#[ctor]` constructors
|
||||||
|
|||||||
48
src/main.rs
48
src/main.rs
@@ -1,3 +1,6 @@
|
|||||||
|
use std::io::Write;
|
||||||
|
use std::time::Instant;
|
||||||
|
|
||||||
use anyhow::{Context, Error, Result, anyhow};
|
use anyhow::{Context, Error, Result, anyhow};
|
||||||
use clap::error::ErrorKind;
|
use clap::error::ErrorKind;
|
||||||
use clap::*;
|
use clap::*;
|
||||||
@@ -25,13 +28,36 @@ fn main() -> Result<(), Error> {
|
|||||||
cmd.error(ErrorKind::ValueValidation, e).exit();
|
cmd.error(ErrorKind::ValueValidation, e).exit();
|
||||||
}
|
}
|
||||||
|
|
||||||
stderrlog::new()
|
let start = Instant::now();
|
||||||
.module(module_path!())
|
let mut builder = env_logger::Builder::new();
|
||||||
.quiet(args.options.quiet)
|
let show_module = args.options.verbose >= 2;
|
||||||
.verbosity(usize::from(args.options.verbose + 2))
|
builder.format(move |buf, record| {
|
||||||
//.timestamp(stderrlog::Timestamp::Second)
|
let elapsed = start.elapsed();
|
||||||
.init()
|
let ts = format!("[{:>6}.{:03}]", elapsed.as_secs(), elapsed.subsec_millis());
|
||||||
.unwrap();
|
if show_module {
|
||||||
|
writeln!(
|
||||||
|
buf,
|
||||||
|
"{} {:<5} {}: {}",
|
||||||
|
ts,
|
||||||
|
record.level(),
|
||||||
|
record.module_path().unwrap_or("?"),
|
||||||
|
record.args()
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
writeln!(buf, "{} {:<5} {}", ts, record.level(), record.args())
|
||||||
|
}
|
||||||
|
});
|
||||||
|
let max_level = if args.options.quiet {
|
||||||
|
LevelFilter::Off
|
||||||
|
} else {
|
||||||
|
match args.options.verbose {
|
||||||
|
0 => LevelFilter::Warn,
|
||||||
|
1 => LevelFilter::Debug,
|
||||||
|
_ => LevelFilter::Trace,
|
||||||
|
}
|
||||||
|
};
|
||||||
|
builder.filter_module("keep", max_level);
|
||||||
|
builder.init();
|
||||||
|
|
||||||
debug!("MAIN: Start");
|
debug!("MAIN: Start");
|
||||||
|
|
||||||
@@ -188,8 +214,12 @@ fn main() -> Result<(), Error> {
|
|||||||
#[cfg(feature = "client")]
|
#[cfg(feature = "client")]
|
||||||
{
|
{
|
||||||
if let Some(ref client_url) = settings.client_url {
|
if let Some(ref client_url) = settings.client_url {
|
||||||
let client =
|
let client = keep::client::KeepClient::new(
|
||||||
keep::client::KeepClient::new(client_url, settings.client_password.clone())?;
|
client_url,
|
||||||
|
settings.client_username.clone(),
|
||||||
|
settings.client_password.clone(),
|
||||||
|
settings.client_jwt.clone(),
|
||||||
|
)?;
|
||||||
|
|
||||||
return match mode {
|
return match mode {
|
||||||
KeepModes::Save => {
|
KeepModes::Save => {
|
||||||
|
|||||||
@@ -258,6 +258,10 @@ impl MetaPlugin for DigestMetaPlugin {
|
|||||||
) -> anyhow::Result<&mut std::collections::HashMap<String, serde_yaml::Value>> {
|
) -> anyhow::Result<&mut std::collections::HashMap<String, serde_yaml::Value>> {
|
||||||
Ok(self.base.options_mut())
|
Ok(self.base.options_mut())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn parallel_safe(&self) -> bool {
|
||||||
|
true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
use crate::meta_plugin::register_meta_plugin;
|
use crate::meta_plugin::register_meta_plugin;
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ pub struct MetaPluginExec {
|
|||||||
pub supported: bool,
|
pub supported: bool,
|
||||||
pub split_whitespace: bool,
|
pub split_whitespace: bool,
|
||||||
process: Option<Child>,
|
process: Option<Child>,
|
||||||
writer: Option<Box<dyn Write>>,
|
writer: Option<Box<dyn Write + Send>>,
|
||||||
result: Option<String>,
|
result: Option<String>,
|
||||||
base: BaseMetaPlugin,
|
base: BaseMetaPlugin,
|
||||||
}
|
}
|
||||||
@@ -263,6 +263,10 @@ impl MetaPlugin for MetaPluginExec {
|
|||||||
fn default_outputs(&self) -> Vec<String> {
|
fn default_outputs(&self) -> Vec<String> {
|
||||||
vec!["exec".to_string()]
|
vec!["exec".to_string()]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn parallel_safe(&self) -> bool {
|
||||||
|
true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
use crate::meta_plugin::register_meta_plugin;
|
use crate::meta_plugin::register_meta_plugin;
|
||||||
|
|||||||
@@ -12,13 +12,36 @@ use crate::meta_plugin::{
|
|||||||
process_metadata_outputs,
|
process_metadata_outputs,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/// Wrapper around `magic::Cookie` that is Send.
|
||||||
|
///
|
||||||
|
/// Libmagic cookies are thread-safe per-instance (separate cookies have
|
||||||
|
/// independent state). The raw pointer `*mut magic_sys::magic_set` does not
|
||||||
|
/// auto-derive Send, but concurrent access to distinct cookies is safe per
|
||||||
|
/// the libmagic documentation.
|
||||||
|
#[cfg(feature = "magic")]
|
||||||
|
struct SendCookie(Cookie);
|
||||||
|
|
||||||
|
#[cfg(feature = "magic")]
|
||||||
|
// SAFETY: Each SendCookie owns a distinct libmagic instance. Libmagic
|
||||||
|
// documents that separate cookies can be used from different threads
|
||||||
|
// concurrently without synchronization.
|
||||||
|
#[allow(unsafe_code)]
|
||||||
|
unsafe impl Send for SendCookie {}
|
||||||
|
|
||||||
|
#[cfg(feature = "magic")]
|
||||||
|
impl std::fmt::Debug for SendCookie {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
f.debug_struct("SendCookie").finish()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(feature = "magic")]
|
#[cfg(feature = "magic")]
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct MagicFileMetaPluginImpl {
|
pub struct MagicFileMetaPluginImpl {
|
||||||
buffer: Vec<u8>,
|
buffer: Vec<u8>,
|
||||||
max_buffer_size: usize,
|
max_buffer_size: usize,
|
||||||
is_finalized: bool,
|
is_finalized: bool,
|
||||||
cookie: Option<Cookie>,
|
cookie: Option<SendCookie>,
|
||||||
base: BaseMetaPlugin,
|
base: BaseMetaPlugin,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -51,7 +74,8 @@ impl MagicFileMetaPluginImpl {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn get_magic_result(&self, flags: CookieFlags) -> io::Result<String> {
|
fn get_magic_result(&self, flags: CookieFlags) -> io::Result<String> {
|
||||||
if let Some(cookie) = &self.cookie {
|
if let Some(send_cookie) = &self.cookie {
|
||||||
|
let cookie = &send_cookie.0;
|
||||||
cookie
|
cookie
|
||||||
.set_flags(flags)
|
.set_flags(flags)
|
||||||
.map_err(|e| io::Error::other(format!("Failed to set magic flags: {e}")))?;
|
.map_err(|e| io::Error::other(format!("Failed to set magic flags: {e}")))?;
|
||||||
@@ -125,7 +149,7 @@ impl MetaPlugin for MagicFileMetaPluginImpl {
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
self.cookie = Some(cookie);
|
self.cookie = Some(SendCookie(cookie));
|
||||||
|
|
||||||
MetaPluginResponse {
|
MetaPluginResponse {
|
||||||
metadata: Vec::new(),
|
metadata: Vec::new(),
|
||||||
@@ -210,6 +234,10 @@ impl MetaPlugin for MagicFileMetaPluginImpl {
|
|||||||
) -> anyhow::Result<&mut std::collections::HashMap<String, serde_yaml::Value>> {
|
) -> anyhow::Result<&mut std::collections::HashMap<String, serde_yaml::Value>> {
|
||||||
Ok(self.base.options_mut())
|
Ok(self.base.options_mut())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn parallel_safe(&self) -> bool {
|
||||||
|
true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "magic")]
|
#[cfg(feature = "magic")]
|
||||||
@@ -308,17 +336,16 @@ impl FallbackMagicFileMetaPlugin {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Get human-readable file type via --brief
|
// Get human-readable file type via --brief
|
||||||
if let Some(file_type) = self.run_file_command(&["--brief"]) {
|
if let Some(file_type) = self.run_file_command(&["--brief"])
|
||||||
if !file_type.is_empty() {
|
&& !file_type.is_empty()
|
||||||
if let Some(meta_data) = process_metadata_outputs(
|
&& let Some(meta_data) = process_metadata_outputs(
|
||||||
"file_type",
|
"file_type",
|
||||||
serde_yaml::Value::String(file_type),
|
serde_yaml::Value::String(file_type),
|
||||||
self.base.outputs(),
|
self.base.outputs(),
|
||||||
) {
|
)
|
||||||
|
{
|
||||||
metadata.push(meta_data);
|
metadata.push(meta_data);
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
metadata
|
metadata
|
||||||
}
|
}
|
||||||
@@ -415,6 +442,10 @@ impl MetaPlugin for FallbackMagicFileMetaPlugin {
|
|||||||
) -> anyhow::Result<&mut std::collections::HashMap<String, serde_yaml::Value>> {
|
) -> anyhow::Result<&mut std::collections::HashMap<String, serde_yaml::Value>> {
|
||||||
Ok(self.base.options_mut())
|
Ok(self.base.options_mut())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn parallel_safe(&self) -> bool {
|
||||||
|
true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(not(feature = "magic"))]
|
#[cfg(not(feature = "magic"))]
|
||||||
|
|||||||
@@ -16,8 +16,9 @@ pub mod read_time;
|
|||||||
pub mod shell;
|
pub mod shell;
|
||||||
pub mod shell_pid;
|
pub mod shell_pid;
|
||||||
pub mod text;
|
pub mod text;
|
||||||
|
#[cfg(feature = "tokens")]
|
||||||
|
pub mod tokens;
|
||||||
pub mod user;
|
pub mod user;
|
||||||
// pub mod text; // Removed duplicate
|
|
||||||
|
|
||||||
pub use digest::DigestMetaPlugin;
|
pub use digest::DigestMetaPlugin;
|
||||||
pub use exec::MetaPluginExec;
|
pub use exec::MetaPluginExec;
|
||||||
@@ -232,6 +233,7 @@ pub enum MetaPluginType {
|
|||||||
Hostname,
|
Hostname,
|
||||||
Exec,
|
Exec,
|
||||||
Env,
|
Env,
|
||||||
|
Tokens,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Central function to handle metadata output with name mapping.
|
/// Central function to handle metadata output with name mapping.
|
||||||
@@ -316,7 +318,7 @@ pub fn process_metadata_outputs(
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub trait MetaPlugin
|
pub trait MetaPlugin: Send
|
||||||
where
|
where
|
||||||
Self: 'static,
|
Self: 'static,
|
||||||
{
|
{
|
||||||
@@ -477,6 +479,82 @@ where
|
|||||||
vec![self.meta_type().to_string()]
|
vec![self.meta_type().to_string()]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns a description of this plugin for display in config templates.
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
///
|
||||||
|
/// A description string (empty by default).
|
||||||
|
fn description(&self) -> &str {
|
||||||
|
""
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns true if this plugin can execute concurrently with other
|
||||||
|
/// parallel-safe plugins.
|
||||||
|
///
|
||||||
|
/// Plugins that do significant per-chunk work (hashing, tokenization,
|
||||||
|
/// piping to child processes) should return true. The MetaService will
|
||||||
|
/// run all parallel-safe plugins in separate threads per phase, then
|
||||||
|
/// process results sequentially.
|
||||||
|
fn parallel_safe(&self) -> bool {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Builds the schema for this plugin from its options and outputs.
|
||||||
|
///
|
||||||
|
/// Default implementation infers option types from YAML values and
|
||||||
|
/// collects enabled outputs.
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
///
|
||||||
|
/// A `PluginSchema` describing this plugin's configuration.
|
||||||
|
fn schema(&self) -> crate::common::schema::PluginSchema {
|
||||||
|
use crate::common::schema::{OptionSchema, OptionType, OutputSchema, PluginSchema};
|
||||||
|
|
||||||
|
let options: Vec<OptionSchema> = self
|
||||||
|
.options()
|
||||||
|
.iter()
|
||||||
|
.map(|(key, value)| {
|
||||||
|
let option_type = OptionType::from_yaml_value(value);
|
||||||
|
let (default, required) = if value.is_null() {
|
||||||
|
(None, true)
|
||||||
|
} else {
|
||||||
|
(Some(value.clone()), false)
|
||||||
|
};
|
||||||
|
OptionSchema {
|
||||||
|
name: key.clone(),
|
||||||
|
option_type,
|
||||||
|
default,
|
||||||
|
required,
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let mut outputs: Vec<OutputSchema> = Vec::new();
|
||||||
|
for (key, value) in self.outputs() {
|
||||||
|
if !value.is_null() {
|
||||||
|
outputs.push(OutputSchema {
|
||||||
|
name: key.clone(),
|
||||||
|
description: key.clone(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if outputs.is_empty() {
|
||||||
|
for output_name in self.default_outputs() {
|
||||||
|
outputs.push(OutputSchema {
|
||||||
|
name: output_name.clone(),
|
||||||
|
description: output_name,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
PluginSchema {
|
||||||
|
name: self.meta_type().to_string(),
|
||||||
|
description: self.description().to_string(),
|
||||||
|
options,
|
||||||
|
outputs,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Method to downcast to concrete type (for checking finalization state).
|
/// Method to downcast to concrete type (for checking finalization state).
|
||||||
///
|
///
|
||||||
/// # Returns
|
/// # Returns
|
||||||
|
|||||||
316
src/meta_plugin/tokens.rs
Normal file
316
src/meta_plugin/tokens.rs
Normal file
@@ -0,0 +1,316 @@
|
|||||||
|
use crate::common::PIPESIZE;
|
||||||
|
use crate::common::is_binary::is_binary;
|
||||||
|
use crate::meta_plugin::{MetaPlugin, MetaPluginResponse, MetaPluginType};
|
||||||
|
use crate::tokenizer::{TokenEncoding, get_tokenizer};
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
pub struct TokensMetaPlugin {
    /// Buffer for binary detection (up to PIPESIZE bytes).
    buffer: Option<Vec<u8>>,
    /// Size limit from the `token_detect_size` option (default PIPESIZE).
    max_buffer_size: usize,
    /// Whether finalization has already run.
    is_finalized: bool,
    /// Cached result of binary detection; `None` until `detect_binary` runs.
    is_binary_content: Option<bool>,
    /// Running token count accumulated across chunks.
    token_count: usize,
    /// UTF-8 boundary carry buffer.
    utf8_buffer: Vec<u8>,
    /// Shared option/output bookkeeping common to meta plugins.
    base: crate::meta_plugin::BaseMetaPlugin,
    /// The tokenizer encoding.
    encoding: TokenEncoding,
}
|
||||||
|
|
||||||
|
impl TokensMetaPlugin {
|
||||||
|
pub fn new(
|
||||||
|
options: Option<std::collections::HashMap<String, serde_yaml::Value>>,
|
||||||
|
outputs: Option<std::collections::HashMap<String, serde_yaml::Value>>,
|
||||||
|
) -> Self {
|
||||||
|
let mut base = crate::meta_plugin::BaseMetaPlugin::new();
|
||||||
|
|
||||||
|
base.initialize_plugin(&["token_count"], &options, &outputs);
|
||||||
|
|
||||||
|
// Set default options
|
||||||
|
let default_options = vec![
|
||||||
|
(
|
||||||
|
"token_detect_size",
|
||||||
|
serde_yaml::Value::Number(PIPESIZE.into()),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"encoding",
|
||||||
|
serde_yaml::Value::String("cl100k_base".to_string()),
|
||||||
|
),
|
||||||
|
];
|
||||||
|
|
||||||
|
for (key, value) in default_options {
|
||||||
|
if !base.options.contains_key(key) {
|
||||||
|
base.options.insert(key.to_string(), value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let max_buffer_size = base
|
||||||
|
.options
|
||||||
|
.get("token_detect_size")
|
||||||
|
.and_then(|v| v.as_u64())
|
||||||
|
.unwrap_or(PIPESIZE as u64) as usize;
|
||||||
|
|
||||||
|
let encoding = base
|
||||||
|
.options
|
||||||
|
.get("encoding")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.and_then(|s| s.parse::<TokenEncoding>().ok())
|
||||||
|
.unwrap_or_default();
|
||||||
|
|
||||||
|
Self {
|
||||||
|
buffer: Some(Vec::new()),
|
||||||
|
max_buffer_size,
|
||||||
|
is_finalized: false,
|
||||||
|
is_binary_content: None,
|
||||||
|
token_count: 0,
|
||||||
|
utf8_buffer: Vec::new(),
|
||||||
|
base,
|
||||||
|
encoding,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Tokenize a byte chunk, handling UTF-8 boundaries.
|
||||||
|
///
|
||||||
|
/// Combines with any pending UTF-8 carry bytes, converts to text,
|
||||||
|
/// and adds the token count to the running total.
|
||||||
|
///
|
||||||
|
/// Avoids unnecessary allocations when there is no pending UTF-8 carry
|
||||||
|
/// and the data is valid UTF-8.
|
||||||
|
fn count_tokens(&mut self, data: &[u8]) {
|
||||||
|
if data.is_empty() && self.utf8_buffer.is_empty() {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let tokenizer = get_tokenizer(self.encoding);
|
||||||
|
|
||||||
|
if self.utf8_buffer.is_empty() {
|
||||||
|
// Fast path: no pending carry — try to use data directly
|
||||||
|
match std::str::from_utf8(data) {
|
||||||
|
Ok(text) => {
|
||||||
|
if !text.is_empty() {
|
||||||
|
self.token_count += tokenizer.count(text);
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
let valid_up_to = e.valid_up_to();
|
||||||
|
if valid_up_to > 0 {
|
||||||
|
// Count the valid prefix without copying
|
||||||
|
let text =
|
||||||
|
std::str::from_utf8(&data[..valid_up_to]).expect("validated prefix");
|
||||||
|
self.token_count += tokenizer.count(text);
|
||||||
|
}
|
||||||
|
// Save invalid trailing bytes for next call
|
||||||
|
self.utf8_buffer.extend_from_slice(&data[valid_up_to..]);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Slow path: pending carry bytes — must build combined buffer
|
||||||
|
let mut combined = std::mem::take(&mut self.utf8_buffer);
|
||||||
|
combined.extend_from_slice(data);
|
||||||
|
|
||||||
|
match std::str::from_utf8(&combined) {
|
||||||
|
Ok(text) => {
|
||||||
|
if !text.is_empty() {
|
||||||
|
self.token_count += tokenizer.count(text);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
let valid_up_to = e.valid_up_to();
|
||||||
|
if valid_up_to > 0 {
|
||||||
|
let text =
|
||||||
|
std::str::from_utf8(&combined[..valid_up_to]).expect("validated prefix");
|
||||||
|
self.token_count += tokenizer.count(text);
|
||||||
|
}
|
||||||
|
self.utf8_buffer.extend_from_slice(&combined[valid_up_to..]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Perform binary detection on the buffer.
|
||||||
|
fn detect_binary(&mut self, buffer: &[u8]) -> bool {
|
||||||
|
let result = is_binary(buffer);
|
||||||
|
self.is_binary_content = Some(result);
|
||||||
|
result
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MetaPlugin for TokensMetaPlugin {
|
||||||
|
fn is_finalized(&self) -> bool {
|
||||||
|
self.is_finalized
|
||||||
|
}
|
||||||
|
|
||||||
|
fn set_finalized(&mut self, finalized: bool) {
|
||||||
|
self.is_finalized = finalized;
|
||||||
|
}
|
||||||
|
|
||||||
|
fn update(&mut self, data: &[u8]) -> MetaPluginResponse {
|
||||||
|
if self.is_finalized {
|
||||||
|
return MetaPluginResponse {
|
||||||
|
metadata: Vec::new(),
|
||||||
|
is_finalized: true,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut metadata = Vec::new();
|
||||||
|
|
||||||
|
if self.is_binary_content.is_none() {
|
||||||
|
// Add data to the buffer
|
||||||
|
let should_detect = if let Some(ref mut buffer) = self.buffer {
|
||||||
|
let remaining = self.max_buffer_size.saturating_sub(buffer.len());
|
||||||
|
let to_take = std::cmp::min(data.len(), remaining);
|
||||||
|
buffer.extend_from_slice(&data[..to_take]);
|
||||||
|
buffer.len() >= std::cmp::min(1024, self.max_buffer_size)
|
||||||
|
} else {
|
||||||
|
false
|
||||||
|
};
|
||||||
|
|
||||||
|
if should_detect {
|
||||||
|
let buffer_data = self.buffer.as_ref().unwrap().clone();
|
||||||
|
let is_binary = self.detect_binary(&buffer_data);
|
||||||
|
|
||||||
|
if is_binary {
|
||||||
|
if let Some(md) = crate::meta_plugin::process_metadata_outputs(
|
||||||
|
"token_count",
|
||||||
|
serde_yaml::Value::Null,
|
||||||
|
self.base.outputs(),
|
||||||
|
) {
|
||||||
|
metadata.push(md);
|
||||||
|
}
|
||||||
|
self.buffer = None;
|
||||||
|
self.is_finalized = true;
|
||||||
|
return MetaPluginResponse {
|
||||||
|
metadata,
|
||||||
|
is_finalized: true,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
// It's text — tokenize the full buffer (nothing was counted yet),
|
||||||
|
// then clear to avoid double-counting in finalize().
|
||||||
|
self.count_tokens(&buffer_data);
|
||||||
|
self.buffer = Some(Vec::new());
|
||||||
|
}
|
||||||
|
} else if self.is_binary_content == Some(false) {
|
||||||
|
self.count_tokens(data);
|
||||||
|
} else if self.is_binary_content == Some(true) {
|
||||||
|
self.is_finalized = true;
|
||||||
|
return MetaPluginResponse {
|
||||||
|
metadata: Vec::new(),
|
||||||
|
is_finalized: true,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
MetaPluginResponse {
|
||||||
|
metadata,
|
||||||
|
is_finalized: self.is_finalized,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn finalize(&mut self) -> MetaPluginResponse {
|
||||||
|
if self.is_finalized {
|
||||||
|
return MetaPluginResponse {
|
||||||
|
metadata: Vec::new(),
|
||||||
|
is_finalized: true,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut metadata = Vec::new();
|
||||||
|
|
||||||
|
// If binary detection hasn't completed, do it now
|
||||||
|
if self.is_binary_content.is_none()
|
||||||
|
&& let Some(buffer) = &self.buffer
|
||||||
|
&& !buffer.is_empty()
|
||||||
|
{
|
||||||
|
let buffer_data = buffer.clone();
|
||||||
|
let is_binary = self.detect_binary(&buffer_data);
|
||||||
|
|
||||||
|
if is_binary {
|
||||||
|
if let Some(md) = crate::meta_plugin::process_metadata_outputs(
|
||||||
|
"token_count",
|
||||||
|
serde_yaml::Value::Null,
|
||||||
|
self.base.outputs(),
|
||||||
|
) {
|
||||||
|
metadata.push(md);
|
||||||
|
}
|
||||||
|
self.buffer = None;
|
||||||
|
self.is_finalized = true;
|
||||||
|
return MetaPluginResponse {
|
||||||
|
metadata,
|
||||||
|
is_finalized: true,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tokenize any bytes in the buffer
|
||||||
|
if let Some(buffer) = &self.buffer {
|
||||||
|
let data = buffer.clone();
|
||||||
|
self.count_tokens(&data);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process any remaining UTF-8 bytes
|
||||||
|
if !self.utf8_buffer.is_empty() {
|
||||||
|
self.count_tokens(&[]);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Emit token count
|
||||||
|
if let Some(md) = crate::meta_plugin::process_metadata_outputs(
|
||||||
|
"token_count",
|
||||||
|
serde_yaml::Value::String(self.token_count.to_string()),
|
||||||
|
self.base.outputs(),
|
||||||
|
) {
|
||||||
|
metadata.push(md);
|
||||||
|
}
|
||||||
|
|
||||||
|
self.buffer = None;
|
||||||
|
self.is_finalized = true;
|
||||||
|
MetaPluginResponse {
|
||||||
|
metadata,
|
||||||
|
is_finalized: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn meta_type(&self) -> MetaPluginType {
|
||||||
|
MetaPluginType::Tokens
|
||||||
|
}
|
||||||
|
|
||||||
|
fn outputs(&self) -> &std::collections::HashMap<String, serde_yaml::Value> {
|
||||||
|
self.base.outputs()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn outputs_mut(
|
||||||
|
&mut self,
|
||||||
|
) -> anyhow::Result<&mut std::collections::HashMap<String, serde_yaml::Value>> {
|
||||||
|
Ok(self.base.outputs_mut())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_outputs(&self) -> Vec<String> {
|
||||||
|
vec!["token_count".to_string()]
|
||||||
|
}
|
||||||
|
|
||||||
|
fn options(&self) -> &std::collections::HashMap<String, serde_yaml::Value> {
|
||||||
|
self.base.options()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn options_mut(
|
||||||
|
&mut self,
|
||||||
|
) -> anyhow::Result<&mut std::collections::HashMap<String, serde_yaml::Value>> {
|
||||||
|
Ok(self.base.options_mut())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parallel_safe(&self) -> bool {
|
||||||
|
true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
use crate::meta_plugin::register_meta_plugin;
|
||||||
|
|
||||||
|
/// Registers the tokens meta plugin factory with the global plugin registry.
///
/// Runs automatically at program startup (before `main`) via `ctor`.
#[ctor::ctor]
fn register_tokens_plugin() {
    register_meta_plugin(MetaPluginType::Tokens, |options, outputs| {
        Box::new(TokensMetaPlugin::new(options, outputs))
    });
}
|
||||||
@@ -100,11 +100,14 @@ pub fn mode(
|
|||||||
|
|
||||||
// Streamer thread: reads compressed bytes from pipe → POST to server
|
// Streamer thread: reads compressed bytes from pipe → POST to server
|
||||||
let client_url = client.base_url().to_string();
|
let client_url = client.base_url().to_string();
|
||||||
|
let client_username = client.username().cloned();
|
||||||
let client_password = client.password().cloned();
|
let client_password = client.password().cloned();
|
||||||
|
let client_jwt = client.jwt().cloned();
|
||||||
let tags_clone = tags.clone();
|
let tags_clone = tags.clone();
|
||||||
|
|
||||||
let streamer_handle = std::thread::spawn(move || -> Result<ItemInfo> {
|
let streamer_handle = std::thread::spawn(move || -> Result<ItemInfo> {
|
||||||
let streaming_client = KeepClient::new(&client_url, client_password)?;
|
let streaming_client =
|
||||||
|
KeepClient::new(&client_url, client_username, client_password, client_jwt)?;
|
||||||
let params = [
|
let params = [
|
||||||
("compress".to_string(), server_compress.to_string()),
|
("compress".to_string(), server_compress.to_string()),
|
||||||
("meta".to_string(), "false".to_string()),
|
("meta".to_string(), "false".to_string()),
|
||||||
|
|||||||
@@ -208,14 +208,13 @@ pub fn settings_meta_plugin_types(
|
|||||||
for meta_plugin_type in MetaPluginType::iter() {
|
for meta_plugin_type in MetaPluginType::iter() {
|
||||||
if let Ok(meta_plugin) =
|
if let Ok(meta_plugin) =
|
||||||
crate::meta_plugin::get_meta_plugin(meta_plugin_type.clone(), None, None)
|
crate::meta_plugin::get_meta_plugin(meta_plugin_type.clone(), None, None)
|
||||||
|
&& meta_plugin.meta_type().to_string() == trimmed_name
|
||||||
{
|
{
|
||||||
if meta_plugin.meta_type().to_string() == trimmed_name {
|
|
||||||
meta_plugin_types.push(meta_plugin_type);
|
meta_plugin_types.push(meta_plugin_type);
|
||||||
found = true;
|
found = true;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
if !found {
|
if !found {
|
||||||
cmd.error(
|
cmd.error(
|
||||||
|
|||||||
@@ -1,81 +1,17 @@
|
|||||||
use crate::meta_plugin::MetaPlugin;
|
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use clap::Command;
|
use clap::Command;
|
||||||
use serde::{Deserialize, Serialize};
|
use std::collections::HashMap;
|
||||||
use serde_yaml;
|
use strum::IntoEnumIterator;
|
||||||
|
|
||||||
/// Mode for generating a default configuration file.
|
use crate::common::schema::{gather_filter_plugin_schemas, gather_meta_plugin_schemas};
|
||||||
///
|
use crate::compression_engine::CompressionType;
|
||||||
/// This module creates a commented YAML template with default values for settings,
|
use crate::config;
|
||||||
/// including list format, server config, compression, and meta plugins.
|
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
|
||||||
/// Default configuration structure for the generated template.
|
|
||||||
///
|
|
||||||
/// Includes core settings, list formatting, server options, compression, and meta plugins.
|
|
||||||
struct DefaultConfig {
|
|
||||||
dir: Option<String>,
|
|
||||||
list_format: Vec<ColumnConfig>,
|
|
||||||
human_readable: bool,
|
|
||||||
output_format: Option<String>,
|
|
||||||
quiet: bool,
|
|
||||||
force: bool,
|
|
||||||
server: Option<ServerConfig>,
|
|
||||||
compression_plugin: Option<CompressionPluginConfig>,
|
|
||||||
meta_plugins: Option<Vec<MetaPluginConfig>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
|
||||||
/// Configuration for a column in the list format.
|
|
||||||
struct ColumnConfig {
|
|
||||||
name: String,
|
|
||||||
label: Option<String>,
|
|
||||||
#[serde(default)]
|
|
||||||
align: ColumnAlignment,
|
|
||||||
#[serde(default)]
|
|
||||||
max_len: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize, Default)]
|
|
||||||
#[serde(rename_all = "lowercase")]
|
|
||||||
/// Alignment options for table columns.
|
|
||||||
enum ColumnAlignment {
|
|
||||||
#[default]
|
|
||||||
Left,
|
|
||||||
Right,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
|
||||||
/// Server configuration options.
|
|
||||||
struct ServerConfig {
|
|
||||||
address: Option<String>,
|
|
||||||
port: Option<u16>,
|
|
||||||
password_file: Option<String>,
|
|
||||||
password: Option<String>,
|
|
||||||
password_hash: Option<String>,
|
|
||||||
cors_origin: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
|
||||||
/// Configuration for the compression plugin.
|
|
||||||
struct CompressionPluginConfig {
|
|
||||||
name: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
|
||||||
/// Configuration for a meta plugin.
|
|
||||||
struct MetaPluginConfig {
|
|
||||||
name: String,
|
|
||||||
#[serde(default)]
|
|
||||||
options: std::collections::HashMap<String, serde_yaml::Value>,
|
|
||||||
#[serde(default)]
|
|
||||||
outputs: std::collections::HashMap<String, String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Generates and prints a default commented YAML configuration template.
|
/// Generates and prints a default commented YAML configuration template.
|
||||||
///
|
///
|
||||||
/// Creates instances of available meta plugins to populate default options and outputs,
|
/// Discovers all registered meta plugins, filter plugins, and compression engines
|
||||||
/// then serializes the config to YAML with all lines commented for easy editing.
|
/// at runtime via the plugin schema system. Outputs a commented YAML template
|
||||||
|
/// with all available plugins and their default options/outputs.
|
||||||
///
|
///
|
||||||
/// # Arguments
|
/// # Arguments
|
||||||
///
|
///
|
||||||
@@ -85,153 +21,244 @@ struct MetaPluginConfig {
|
|||||||
/// # Returns
|
/// # Returns
|
||||||
///
|
///
|
||||||
/// `Ok(())` on success.
|
/// `Ok(())` on success.
|
||||||
///
|
|
||||||
/// # Examples
|
|
||||||
///
|
|
||||||
/// ```ignore
|
|
||||||
/// // Example usage requires Command and Settings instances
|
|
||||||
/// mode_generate_config(&mut cmd, &settings)?;
|
|
||||||
/// ```
|
|
||||||
pub fn mode_generate_config(_cmd: &mut Command, _settings: &crate::config::Settings) -> Result<()> {
|
pub fn mode_generate_config(_cmd: &mut Command, _settings: &crate::config::Settings) -> Result<()> {
|
||||||
// Create instances of each meta plugin to get their default options and outputs
|
let meta_schemas = gather_meta_plugin_schemas();
|
||||||
let cwd_plugin = crate::meta_plugin::cwd::CwdMetaPlugin::new(None, None);
|
let filter_schemas = gather_filter_plugin_schemas();
|
||||||
let digest_plugin = crate::meta_plugin::digest::DigestMetaPlugin::new(None, None);
|
|
||||||
let hostname_plugin = crate::meta_plugin::hostname::HostnameMetaPlugin::new(None, None);
|
|
||||||
#[cfg(feature = "magic")]
|
|
||||||
let magic_file_plugin = crate::meta_plugin::magic_file::MagicFileMetaPlugin::new(None, None);
|
|
||||||
let env_plugin = crate::meta_plugin::env::EnvMetaPlugin::new(None, None);
|
|
||||||
|
|
||||||
// Create a default configuration
|
// Build list_format defaults matching config.rs
|
||||||
let default_config = DefaultConfig {
|
let list_format = default_list_format();
|
||||||
dir: Some("~/.local/share/keep".to_string()),
|
|
||||||
list_format: vec![
|
|
||||||
ColumnConfig {
|
|
||||||
name: "id".to_string(),
|
|
||||||
label: Some("Item".to_string()),
|
|
||||||
align: ColumnAlignment::Right,
|
|
||||||
max_len: None,
|
|
||||||
},
|
|
||||||
ColumnConfig {
|
|
||||||
name: "time".to_string(),
|
|
||||||
label: Some("Time".to_string()),
|
|
||||||
align: ColumnAlignment::Right,
|
|
||||||
max_len: None,
|
|
||||||
},
|
|
||||||
ColumnConfig {
|
|
||||||
name: "size".to_string(),
|
|
||||||
label: Some("Size".to_string()),
|
|
||||||
align: ColumnAlignment::Right,
|
|
||||||
max_len: None,
|
|
||||||
},
|
|
||||||
ColumnConfig {
|
|
||||||
name: "tags".to_string(),
|
|
||||||
label: Some("Tags".to_string()),
|
|
||||||
align: ColumnAlignment::Left,
|
|
||||||
max_len: Some("40".to_string()),
|
|
||||||
},
|
|
||||||
ColumnConfig {
|
|
||||||
name: "meta:hostname_full".to_string(),
|
|
||||||
label: Some("Hostname".to_string()),
|
|
||||||
align: ColumnAlignment::Left,
|
|
||||||
max_len: Some("28".to_string()),
|
|
||||||
},
|
|
||||||
],
|
|
||||||
human_readable: false,
|
|
||||||
output_format: Some("table".to_string()),
|
|
||||||
quiet: false,
|
|
||||||
force: false,
|
|
||||||
server: Some(ServerConfig {
|
|
||||||
address: Some("127.0.0.1".to_string()),
|
|
||||||
port: Some(8080),
|
|
||||||
password_file: None,
|
|
||||||
password: None,
|
|
||||||
password_hash: None,
|
|
||||||
cors_origin: None,
|
|
||||||
}),
|
|
||||||
compression_plugin: None,
|
|
||||||
meta_plugins: Some(vec![
|
|
||||||
MetaPluginConfig {
|
|
||||||
name: "cwd".to_string(),
|
|
||||||
options: cwd_plugin.options().clone(),
|
|
||||||
outputs: convert_outputs_to_string_map(cwd_plugin.outputs()),
|
|
||||||
},
|
|
||||||
MetaPluginConfig {
|
|
||||||
name: "digest".to_string(),
|
|
||||||
options: digest_plugin.options().clone(),
|
|
||||||
outputs: convert_outputs_to_string_map(digest_plugin.outputs()),
|
|
||||||
},
|
|
||||||
MetaPluginConfig {
|
|
||||||
name: "hostname".to_string(),
|
|
||||||
options: hostname_plugin.options().clone(),
|
|
||||||
outputs: convert_outputs_to_string_map(hostname_plugin.outputs()),
|
|
||||||
},
|
|
||||||
#[cfg(feature = "magic")]
|
|
||||||
MetaPluginConfig {
|
|
||||||
name: "magic_file".to_string(),
|
|
||||||
options: magic_file_plugin.options().clone(),
|
|
||||||
outputs: convert_outputs_to_string_map(magic_file_plugin.outputs()),
|
|
||||||
},
|
|
||||||
MetaPluginConfig {
|
|
||||||
name: "env".to_string(),
|
|
||||||
options: env_plugin.options().clone(),
|
|
||||||
outputs: convert_outputs_to_string_map(env_plugin.outputs()),
|
|
||||||
},
|
|
||||||
]),
|
|
||||||
};
|
|
||||||
|
|
||||||
// Serialize to YAML and comment out all lines
|
// Build meta_plugins with env as the default (active), rest commented
|
||||||
let yaml = serde_yaml::to_string(&default_config)?;
|
let meta_plugins = build_meta_plugins_section(&meta_schemas);
|
||||||
|
|
||||||
// Comment out every line
|
// Build the full YAML
|
||||||
let commented_yaml = yaml
|
let mut lines = Vec::with_capacity(128);
|
||||||
.lines()
|
|
||||||
.map(|line| {
|
lines.push("# Keep configuration file".to_string());
|
||||||
if line.trim().is_empty() {
|
lines.push("# Uncomment and modify the settings you need.".to_string());
|
||||||
line.to_string()
|
lines.push(String::new());
|
||||||
} else {
|
|
||||||
format!("# {line}")
|
// Core settings
|
||||||
|
lines.push("# Data directory for storing items".to_string());
|
||||||
|
lines.push("dir: ~/.local/share/keep".to_string());
|
||||||
|
lines.push(String::new());
|
||||||
|
|
||||||
|
// List format
|
||||||
|
lines.push("# Column configuration for --list output".to_string());
|
||||||
|
lines.push("list_format:".to_string());
|
||||||
|
for col in &list_format {
|
||||||
|
lines.push(format!(" - name: {}", col.name));
|
||||||
|
lines.push(format!(" label: {}", col.label));
|
||||||
|
lines.push(format!(" align: {}", col.align));
|
||||||
}
|
}
|
||||||
})
|
lines.push(String::new());
|
||||||
.collect::<Vec<String>>()
|
|
||||||
.join("\n");
|
|
||||||
|
|
||||||
println!("{commented_yaml}");
|
// Table config
|
||||||
|
lines.push("# Table display configuration".to_string());
|
||||||
|
lines.push("#table_config:".to_string());
|
||||||
|
lines.push("# style: nothing".to_string());
|
||||||
|
lines.push("# modifiers: []".to_string());
|
||||||
|
lines.push("# content_arrangement: dynamic".to_string());
|
||||||
|
lines.push("# truncination_indicator: \"\"".to_string());
|
||||||
|
lines.push(String::new());
|
||||||
|
|
||||||
|
// Other settings
|
||||||
|
lines.push("human_readable: false".to_string());
|
||||||
|
lines.push("output_format: table".to_string());
|
||||||
|
lines.push("quiet: false".to_string());
|
||||||
|
lines.push("force: false".to_string());
|
||||||
|
lines.push(String::new());
|
||||||
|
|
||||||
|
// Server config
|
||||||
|
lines.push("# Server configuration (only used with --server)".to_string());
|
||||||
|
lines.push("server:".to_string());
|
||||||
|
lines.push(" address: 127.0.0.1".to_string());
|
||||||
|
lines.push(" port: 8080".to_string());
|
||||||
|
lines.push("# username: keep".to_string());
|
||||||
|
lines.push("# password: null".to_string());
|
||||||
|
lines.push("# password_file: null".to_string());
|
||||||
|
lines.push("# password_hash: null".to_string());
|
||||||
|
lines.push("# jwt_secret: null".to_string());
|
||||||
|
lines.push("# jwt_secret_file: null".to_string());
|
||||||
|
lines.push("# cert_file: null".to_string());
|
||||||
|
lines.push("# key_file: null".to_string());
|
||||||
|
lines.push("# cors_origin: null".to_string());
|
||||||
|
lines.push(String::new());
|
||||||
|
|
||||||
|
// Compression plugin
|
||||||
|
lines.push("# Compression plugin to use".to_string());
|
||||||
|
lines.push("#compression_plugin:".to_string());
|
||||||
|
let mut comp_types: Vec<String> = CompressionType::iter().map(|ct| ct.to_string()).collect();
|
||||||
|
comp_types.sort();
|
||||||
|
for ct in &comp_types {
|
||||||
|
lines.push(format!("# name: {ct} # {}", compression_description(ct)));
|
||||||
|
}
|
||||||
|
lines.push(String::new());
|
||||||
|
|
||||||
|
// Meta plugins
|
||||||
|
lines.push("# Meta plugins to run when saving items".to_string());
|
||||||
|
lines.push("meta_plugins:".to_string());
|
||||||
|
for line in &meta_plugins {
|
||||||
|
lines.push(line.clone());
|
||||||
|
}
|
||||||
|
lines.push(String::new());
|
||||||
|
|
||||||
|
// Filter plugins reference
|
||||||
|
if !filter_schemas.is_empty() {
|
||||||
|
lines.push("# Available filter plugins (use with --filter)".to_string());
|
||||||
|
for schema in &filter_schemas {
|
||||||
|
lines.push(format!("# {}", schema.name));
|
||||||
|
if !schema.description.is_empty() {
|
||||||
|
lines.push(format!("# {}", schema.description));
|
||||||
|
}
|
||||||
|
for opt in &schema.options {
|
||||||
|
let req = if opt.required { "required" } else { "optional" };
|
||||||
|
lines.push(format!(
|
||||||
|
"# {} ({:?}, {})",
|
||||||
|
opt.name, opt.option_type, req
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
lines.push(String::new());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Client config
|
||||||
|
lines.push("# Client configuration (requires client feature)".to_string());
|
||||||
|
lines.push("#client:".to_string());
|
||||||
|
lines.push("# url: null".to_string());
|
||||||
|
lines.push("# username: null".to_string());
|
||||||
|
lines.push("# password: null".to_string());
|
||||||
|
lines.push("# jwt: null".to_string());
|
||||||
|
|
||||||
|
// Print
|
||||||
|
for line in &lines {
|
||||||
|
println!("{line}");
|
||||||
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Helper function to convert outputs from serde_yaml::Value to String.
|
struct ListColumn {
|
||||||
///
|
name: String,
|
||||||
/// Handles null (uses key), strings, and other values by serializing to YAML string.
|
label: String,
|
||||||
///
|
align: String,
|
||||||
/// # Arguments
|
}
|
||||||
///
|
|
||||||
/// * `outputs` - Reference to the outputs HashMap.
|
fn default_list_format() -> Vec<ListColumn> {
|
||||||
///
|
vec![
|
||||||
/// # Returns
|
ListColumn {
|
||||||
///
|
name: "id".into(),
|
||||||
/// A HashMap with string keys and values.
|
label: "Item".into(),
|
||||||
fn convert_outputs_to_string_map(
|
align: "right".into(),
|
||||||
outputs: &std::collections::HashMap<String, serde_yaml::Value>,
|
},
|
||||||
) -> std::collections::HashMap<String, String> {
|
ListColumn {
|
||||||
let mut result = std::collections::HashMap::new();
|
name: "time".into(),
|
||||||
for (key, value) in outputs {
|
label: "Time".into(),
|
||||||
match value {
|
align: "right".into(),
|
||||||
serde_yaml::Value::Null => {
|
},
|
||||||
// For null, use the key as the value
|
ListColumn {
|
||||||
result.insert(key.clone(), key.clone());
|
name: "size".into(),
|
||||||
}
|
label: "Size".into(),
|
||||||
serde_yaml::Value::String(s) => {
|
align: "right".into(),
|
||||||
result.insert(key.clone(), s.clone());
|
},
|
||||||
}
|
ListColumn {
|
||||||
_ => {
|
name: "meta:text_line_count".into(),
|
||||||
// Convert other values to their YAML string representation
|
label: "Lines".into(),
|
||||||
result.insert(
|
align: "right".into(),
|
||||||
key.clone(),
|
},
|
||||||
serde_yaml::to_string(value).unwrap_or_default(),
|
ListColumn {
|
||||||
);
|
name: "tags".into(),
|
||||||
}
|
label: "Tags".into(),
|
||||||
}
|
align: "left".into(),
|
||||||
}
|
},
|
||||||
result
|
ListColumn {
|
||||||
|
name: "meta:hostname_short".into(),
|
||||||
|
label: "Host".into(),
|
||||||
|
align: "left".into(),
|
||||||
|
},
|
||||||
|
ListColumn {
|
||||||
|
name: "meta:command".into(),
|
||||||
|
label: "Command".into(),
|
||||||
|
align: "left".into(),
|
||||||
|
},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Builds the `meta_plugins:` section lines for the generated config template.
///
/// The `env` plugin is emitted active (uncommented); every other plugin is
/// emitted as commented-out lines the user can enable. For each plugin
/// schema, options with a default are rendered with that default, required
/// options without a default become `null # required`, optional ones without
/// a default are omitted, and each declared output is mapped to its own name.
///
/// # Arguments
///
/// * `schemas` - Plugin schemas gathered from the registry.
///
/// # Returns
///
/// The YAML lines (without trailing newlines) ready to append to the template.
fn build_meta_plugins_section(schemas: &[crate::common::schema::PluginSchema]) -> Vec<String> {
    let mut lines = Vec::new();

    for (i, schema) in schemas.iter().enumerate() {
        // `env` is the default plugin: its entry stays uncommented.
        let is_default = schema.name == "env";
        let prefix = if is_default { "" } else { "# " };

        // Separator banner between entries (skipped before the first).
        if i > 0 {
            lines.push(format!("{prefix}# --- {name} ---", name = schema.name));
        }

        lines.push(format!("{prefix}- name: {}", schema.name));

        // Options
        if !schema.options.is_empty() {
            lines.push(format!("{prefix} options:"));
            for opt in &schema.options {
                if let Some(ref default) = opt.default {
                    let default_str = format_yaml_value(default);
                    lines.push(format!("{prefix} {}: {}", opt.name, default_str));
                } else if opt.required {
                    lines.push(format!("{prefix} {}: null # required", opt.name));
                }
            }
        } else {
            lines.push(format!("{prefix} options: {{}}"));
        }

        // Outputs
        if !schema.outputs.is_empty() {
            lines.push(format!("{prefix} outputs:"));
            for output in &schema.outputs {
                // Each output key maps to itself by default.
                lines.push(format!("{prefix} {}: {}", output.name, output.name));
            }
        } else {
            lines.push(format!("{prefix} outputs: {{}}"));
        }
    }

    lines
}
|
||||||
|
|
||||||
|
/// Renders a `serde_yaml::Value` as a single-line YAML scalar for the
/// generated config template.
///
/// Strings are double-quoted when they contain characters significant to
/// YAML (space, `:`, `#`) or characters that must be escaped inside quotes
/// (`"`, `\`). Quotes and backslashes are escaped so the emitted YAML stays
/// parseable — previously a string containing `"` was wrapped in quotes
/// without escaping, producing broken output.
///
/// Sequences, mappings, and tagged values fall back to `serde_yaml`'s own
/// serialization, trimmed of the trailing newline (note: multi-line output
/// from nested collections is returned as-is after trimming).
fn format_yaml_value(value: &serde_yaml::Value) -> String {
    match value {
        serde_yaml::Value::Null => "null".into(),
        serde_yaml::Value::Bool(b) => b.to_string(),
        serde_yaml::Value::Number(n) => n.to_string(),
        serde_yaml::Value::String(s) => {
            let needs_quoting = s.contains(' ')
                || s.contains(':')
                || s.contains('#')
                || s.contains('"')
                || s.contains('\\');
            if needs_quoting {
                // Escape backslashes first so escaped quotes stay intact.
                let escaped = s.replace('\\', "\\\\").replace('"', "\\\"");
                format!("\"{escaped}\"")
            } else {
                s.clone()
            }
        }
        serde_yaml::Value::Sequence(_) | serde_yaml::Value::Mapping(_) => {
            serde_yaml::to_string(value)
                .unwrap_or_default()
                .trim()
                .to_string()
        }
        serde_yaml::Value::Tagged(_) => serde_yaml::to_string(value)
            .unwrap_or_default()
            .trim()
            .to_string(),
    }
}
|
||||||
|
|
||||||
|
/// Returns a short human-readable description for a compression engine name.
///
/// Unknown names yield an empty string.
fn compression_description(name: &str) -> &str {
    const DESCRIPTIONS: &[(&str, &str)] = &[
        ("lz4", "Fast compression (native)"),
        ("gzip", "Good compression ratio (native)"),
        ("bzip2", "High compression (requires bzip2 binary)"),
        ("xz", "Very high compression (requires xz binary)"),
        ("zstd", "Modern fast compression (requires zstd binary)"),
        ("none", "No compression"),
    ];

    DESCRIPTIONS
        .iter()
        .find(|(key, _)| *key == name)
        .map(|(_, desc)| *desc)
        .unwrap_or("")
}
|
||||||
|
|||||||
231
src/modes/server/auth.rs
Normal file
231
src/modes/server/auth.rs
Normal file
@@ -0,0 +1,231 @@
|
|||||||
|
use axum::http::Method;
|
||||||
|
use jsonwebtoken::{DecodingKey, TokenData, Validation, decode};
|
||||||
|
use log::debug;
|
||||||
|
use serde::Deserialize;
|
||||||
|
|
||||||
|
/// JWT claims for permission-based access control.
///
/// External token generators should include these claims in the JWT payload.
/// The server validates the signature and checks permissions for each request.
///
/// # Example token payload
///
/// ```json
/// {
///   "sub": "my-client",
///   "exp": 1735689600,
///   "read": true,
///   "write": true,
///   "delete": false
/// }
/// ```
#[derive(Debug, Deserialize)]
pub struct Claims {
    /// Subject (client identifier).
    pub sub: String,
    /// Expiration time (Unix timestamp).
    pub exp: usize,
    /// Read permission (GET requests). Defaults to `false` when absent.
    #[serde(default)]
    pub read: bool,
    /// Write permission (POST/PUT requests). Defaults to `false` when absent.
    #[serde(default)]
    pub write: bool,
    /// Delete permission (DELETE requests). Defaults to `false` when absent.
    #[serde(default)]
    pub delete: bool,
}
|
||||||
|
|
||||||
|
/// Returns the required permission for an HTTP method.
|
||||||
|
///
|
||||||
|
/// # Mapping
|
||||||
|
///
|
||||||
|
/// - GET, HEAD → "read"
|
||||||
|
/// - POST, PUT, PATCH → "write"
|
||||||
|
/// - DELETE → "delete"
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
///
|
||||||
|
/// * `method` - The HTTP method of the incoming request.
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
///
|
||||||
|
/// A string slice representing the required permission.
|
||||||
|
pub fn required_permission(method: &Method) -> &'static str {
|
||||||
|
if method == Method::GET || method == Method::HEAD {
|
||||||
|
"read"
|
||||||
|
} else if method == Method::DELETE {
|
||||||
|
"delete"
|
||||||
|
} else {
|
||||||
|
"write"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Checks if the JWT claims grant the required permission.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
///
|
||||||
|
/// * `claims` - The validated JWT claims.
|
||||||
|
/// * `permission` - The required permission string ("read", "write", or "delete").
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
///
|
||||||
|
/// `true` if the claims grant the permission, `false` otherwise.
|
||||||
|
pub fn check_permission(claims: &Claims, permission: &str) -> bool {
|
||||||
|
match permission {
|
||||||
|
"read" => claims.read,
|
||||||
|
"write" => claims.write,
|
||||||
|
"delete" => claims.delete,
|
||||||
|
_ => false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Validates a JWT token and returns the claims.
|
||||||
|
///
|
||||||
|
/// Uses HMAC-SHA256 signature verification with the provided secret.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
///
|
||||||
|
/// * `token` - The JWT token string (without "Bearer " prefix).
|
||||||
|
/// * `secret` - The secret key used to verify the signature.
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
///
|
||||||
|
/// * `Ok(Claims)` - The validated claims if the token is valid.
|
||||||
|
/// * `Err(String)` - A human-readable error message if validation fails.
|
||||||
|
pub fn validate_jwt(token: &str, secret: &str) -> Result<Claims, String> {
|
||||||
|
let mut validation = Validation::new(jsonwebtoken::Algorithm::HS256);
|
||||||
|
validation.algorithms = vec![jsonwebtoken::Algorithm::HS256];
|
||||||
|
validation.set_required_spec_claims(&["exp", "sub"]);
|
||||||
|
|
||||||
|
let token_data: TokenData<Claims> = decode::<Claims>(
|
||||||
|
token,
|
||||||
|
&DecodingKey::from_secret(secret.as_bytes()),
|
||||||
|
&validation,
|
||||||
|
)
|
||||||
|
.map_err(|e| {
|
||||||
|
debug!("JWT validation failed: {e}");
|
||||||
|
match e.kind() {
|
||||||
|
jsonwebtoken::errors::ErrorKind::ExpiredSignature => "Token expired".to_string(),
|
||||||
|
jsonwebtoken::errors::ErrorKind::InvalidSignature => "Invalid token".to_string(),
|
||||||
|
jsonwebtoken::errors::ErrorKind::InvalidToken => "Malformed token".to_string(),
|
||||||
|
jsonwebtoken::errors::ErrorKind::ImmatureSignature => "Token not yet valid".to_string(),
|
||||||
|
_ => "Invalid token".to_string(),
|
||||||
|
}
|
||||||
|
})?;
|
||||||
|
|
||||||
|
Ok(token_data.claims)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use jsonwebtoken::{EncodingKey, Header, encode};
|
||||||
|
|
||||||
|
fn make_token(claims: &serde_json::Value, secret: &str) -> String {
|
||||||
|
let header = Header::new(jsonwebtoken::Algorithm::HS256);
|
||||||
|
encode(
|
||||||
|
&header,
|
||||||
|
claims,
|
||||||
|
&EncodingKey::from_secret(secret.as_bytes()),
|
||||||
|
)
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_validate_jwt_valid_token() {
|
||||||
|
let secret = "test-secret";
|
||||||
|
let claims = serde_json::json!({
|
||||||
|
"sub": "test-client",
|
||||||
|
"exp": 9999999999usize,
|
||||||
|
"read": true,
|
||||||
|
"write": true,
|
||||||
|
"delete": false
|
||||||
|
});
|
||||||
|
let token = make_token(&claims, secret);
|
||||||
|
|
||||||
|
let result = validate_jwt(&token, secret);
|
||||||
|
assert!(result.is_ok());
|
||||||
|
let claims = result.unwrap();
|
||||||
|
assert_eq!(claims.sub, "test-client");
|
||||||
|
assert!(claims.read);
|
||||||
|
assert!(claims.write);
|
||||||
|
assert!(!claims.delete);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_validate_jwt_expired_token() {
|
||||||
|
let secret = "test-secret";
|
||||||
|
let claims = serde_json::json!({
|
||||||
|
"sub": "test-client",
|
||||||
|
"exp": 1000000000usize,
|
||||||
|
"read": true
|
||||||
|
});
|
||||||
|
let token = make_token(&claims, secret);
|
||||||
|
|
||||||
|
let result = validate_jwt(&token, secret);
|
||||||
|
assert!(result.is_err());
|
||||||
|
assert_eq!(result.unwrap_err(), "Token expired");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_validate_jwt_wrong_secret() {
|
||||||
|
let claims = serde_json::json!({
|
||||||
|
"sub": "test-client",
|
||||||
|
"exp": 9999999999usize,
|
||||||
|
"read": true
|
||||||
|
});
|
||||||
|
let token = make_token(&claims, "correct-secret");
|
||||||
|
|
||||||
|
let result = validate_jwt(&token, "wrong-secret");
|
||||||
|
assert!(result.is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_validate_jwt_malformed_token() {
|
||||||
|
let result = validate_jwt("not.a.jwt", "secret");
|
||||||
|
assert!(result.is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_required_permission() {
|
||||||
|
assert_eq!(required_permission(&Method::GET), "read");
|
||||||
|
assert_eq!(required_permission(&Method::HEAD), "read");
|
||||||
|
assert_eq!(required_permission(&Method::POST), "write");
|
||||||
|
assert_eq!(required_permission(&Method::PUT), "write");
|
||||||
|
assert_eq!(required_permission(&Method::PATCH), "write");
|
||||||
|
assert_eq!(required_permission(&Method::DELETE), "delete");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_check_permission() {
|
||||||
|
let claims = Claims {
|
||||||
|
sub: "test".to_string(),
|
||||||
|
exp: 9999999999,
|
||||||
|
read: true,
|
||||||
|
write: false,
|
||||||
|
delete: true,
|
||||||
|
};
|
||||||
|
|
||||||
|
assert!(check_permission(&claims, "read"));
|
||||||
|
assert!(!check_permission(&claims, "write"));
|
||||||
|
assert!(check_permission(&claims, "delete"));
|
||||||
|
assert!(!check_permission(&claims, "unknown"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_check_permission_default_false() {
|
||||||
|
// When fields are missing from JSON, serde(default) makes them false
|
||||||
|
let secret = "test-secret";
|
||||||
|
let claims = serde_json::json!({
|
||||||
|
"sub": "test-client",
|
||||||
|
"exp": 9999999999usize
|
||||||
|
});
|
||||||
|
let token = make_token(&claims, secret);
|
||||||
|
|
||||||
|
let claims = validate_jwt(&token, secret).unwrap();
|
||||||
|
assert!(!claims.read);
|
||||||
|
assert!(!claims.write);
|
||||||
|
assert!(!claims.delete);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -15,7 +15,7 @@ use crate::services::item_service::ItemService;
|
|||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use axum::{
|
use axum::{
|
||||||
extract::{ConnectInfo, Request},
|
extract::{ConnectInfo, Request},
|
||||||
http::{HeaderMap, StatusCode},
|
http::{HeaderMap, Method, StatusCode},
|
||||||
middleware::Next,
|
middleware::Next,
|
||||||
response::Response,
|
response::Response,
|
||||||
};
|
};
|
||||||
@@ -42,8 +42,13 @@ use utoipa::ToSchema;
|
|||||||
/// let config = ServerConfig {
|
/// let config = ServerConfig {
|
||||||
/// address: "127.0.0.1".to_string(),
|
/// address: "127.0.0.1".to_string(),
|
||||||
/// port: Some(8080),
|
/// port: Some(8080),
|
||||||
|
/// username: None,
|
||||||
/// password: None,
|
/// password: None,
|
||||||
/// password_hash: None,
|
/// password_hash: None,
|
||||||
|
/// jwt_secret: None,
|
||||||
|
/// cert_file: None,
|
||||||
|
/// key_file: None,
|
||||||
|
/// cors_origin: None,
|
||||||
/// };
|
/// };
|
||||||
/// ```
|
/// ```
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
@@ -58,9 +63,13 @@ pub struct ServerConfig {
|
|||||||
/// The TCP port number to listen on. If not specified, a default port (typically
|
/// The TCP port number to listen on. If not specified, a default port (typically
|
||||||
/// 8080 or 21080) will be used.
|
/// 8080 or 21080) will be used.
|
||||||
pub port: Option<u16>,
|
pub port: Option<u16>,
|
||||||
|
/// Optional authentication username.
|
||||||
|
///
|
||||||
|
/// Username for Basic authentication. Defaults to "keep" when not specified.
|
||||||
|
pub username: Option<String>,
|
||||||
/// Optional authentication password.
|
/// Optional authentication password.
|
||||||
///
|
///
|
||||||
/// Plain text password for basic or bearer token authentication. This should be
|
/// Plain text password for Basic authentication. This should be
|
||||||
/// used only for testing or low-security environments.
|
/// used only for testing or low-security environments.
|
||||||
pub password: Option<String>,
|
pub password: Option<String>,
|
||||||
/// Optional hashed authentication password.
|
/// Optional hashed authentication password.
|
||||||
@@ -68,6 +77,11 @@ pub struct ServerConfig {
|
|||||||
/// Pre-hashed password (Unix crypt format) for secure authentication. Preferred
|
/// Pre-hashed password (Unix crypt format) for secure authentication. Preferred
|
||||||
/// over plain text password for production use.
|
/// over plain text password for production use.
|
||||||
pub password_hash: Option<String>,
|
pub password_hash: Option<String>,
|
||||||
|
/// Optional JWT secret for token-based authentication.
|
||||||
|
///
|
||||||
|
/// When set, the server validates JWT tokens (HS256) and checks permission claims
|
||||||
|
/// (read, write, delete) for each request. Takes priority over password auth.
|
||||||
|
pub jwt_secret: Option<String>,
|
||||||
/// Optional path to TLS certificate file (PEM).
|
/// Optional path to TLS certificate file (PEM).
|
||||||
///
|
///
|
||||||
/// When both cert_file and key_file are set, the server uses HTTPS.
|
/// When both cert_file and key_file are set, the server uses HTTPS.
|
||||||
@@ -633,56 +647,59 @@ pub struct CreateItemRequest {
|
|||||||
pub metadata: Option<std::collections::HashMap<String, String>>,
|
pub metadata: Option<std::collections::HashMap<String, String>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Validates bearer authentication token.
|
/// Checks authorization header for valid credentials.
|
||||||
///
|
///
|
||||||
/// This function checks if the provided authorization string is a valid Bearer token
|
/// This function inspects the HTTP Authorization header for valid Basic
|
||||||
/// matching the expected password or hash.
|
/// authentication credentials against the provided username and password or hash.
|
||||||
|
/// Bearer tokens are not checked here — JWT validation is handled separately
|
||||||
|
/// in the middleware.
|
||||||
///
|
///
|
||||||
/// # Arguments
|
/// # Arguments
|
||||||
///
|
///
|
||||||
/// * `auth_str` - The authorization string from the header.
|
/// * `headers` - HTTP headers from the request.
|
||||||
/// * `expected_password` - The expected plain text password.
|
/// * `username` - Optional expected username (defaults to "keep").
|
||||||
/// * `expected_hash` - Optional expected password hash.
|
/// * `password` - Optional expected password.
|
||||||
|
/// * `password_hash` - Optional expected password hash.
|
||||||
///
|
///
|
||||||
/// # Returns
|
/// # Returns
|
||||||
///
|
///
|
||||||
/// * `true` - If authentication succeeds.
|
/// * `true` - If authorized (or no auth required).
|
||||||
/// * `false` - Otherwise.
|
/// * `false` - If unauthorized.
|
||||||
///
|
pub fn check_auth(
|
||||||
/// # Errors
|
headers: &HeaderMap,
|
||||||
///
|
username: &Option<String>,
|
||||||
/// None; returns false on failure.
|
password: &Option<String>,
|
||||||
fn check_bearer_auth(
|
password_hash: &Option<String>,
|
||||||
auth_str: &str,
|
|
||||||
expected_password: &str,
|
|
||||||
expected_hash: &Option<String>,
|
|
||||||
) -> bool {
|
) -> bool {
|
||||||
if !auth_str.starts_with("Bearer ") {
|
// If neither password nor hash is set, no authentication required
|
||||||
return false;
|
if password.is_none() && password_hash.is_none() {
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
let provided_password = &auth_str[7..];
|
let effective_username = username.as_deref().unwrap_or("keep");
|
||||||
|
|
||||||
// If we have a password hash, verify against it
|
if let Some(auth_header) = headers.get("authorization") {
|
||||||
if let Some(hash) = expected_hash {
|
if let Ok(auth_str) = auth_header.to_str() {
|
||||||
return pwhash::unix::verify(provided_password, hash);
|
return check_basic_auth(
|
||||||
|
auth_str,
|
||||||
|
effective_username,
|
||||||
|
password.as_deref().unwrap_or(""),
|
||||||
|
password_hash,
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
// Otherwise, do constant-time comparison to prevent timing attacks
|
false
|
||||||
provided_password
|
|
||||||
.as_bytes()
|
|
||||||
.ct_eq(expected_password.as_bytes())
|
|
||||||
.into()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Validates basic authentication credentials.
|
/// Validates basic authentication credentials.
|
||||||
///
|
///
|
||||||
/// This function decodes and validates Basic Auth credentials from the authorization
|
/// This function decodes and validates Basic Auth credentials from the authorization
|
||||||
/// header against the expected password or hash.
|
/// header against the expected username and password or hash.
|
||||||
///
|
///
|
||||||
/// # Arguments
|
/// # Arguments
|
||||||
///
|
///
|
||||||
/// * `auth_str` - The authorization string from the header.
|
/// * `auth_str` - The authorization string from the header.
|
||||||
|
/// * `expected_username` - The expected username.
|
||||||
/// * `expected_password` - The expected plain text password.
|
/// * `expected_password` - The expected plain text password.
|
||||||
/// * `expected_hash` - Optional expected password hash.
|
/// * `expected_hash` - Optional expected password hash.
|
||||||
///
|
///
|
||||||
@@ -696,6 +713,7 @@ fn check_bearer_auth(
|
|||||||
/// Returns false on decode or validation failure.
|
/// Returns false on decode or validation failure.
|
||||||
fn check_basic_auth(
|
fn check_basic_auth(
|
||||||
auth_str: &str,
|
auth_str: &str,
|
||||||
|
expected_username: &str,
|
||||||
expected_password: &str,
|
expected_password: &str,
|
||||||
expected_hash: &Option<String>,
|
expected_hash: &Option<String>,
|
||||||
) -> bool {
|
) -> bool {
|
||||||
@@ -707,67 +725,35 @@ fn check_basic_auth(
|
|||||||
if let Ok(decoded_bytes) = base64::engine::general_purpose::STANDARD.decode(encoded) {
|
if let Ok(decoded_bytes) = base64::engine::general_purpose::STANDARD.decode(encoded) {
|
||||||
if let Ok(decoded_str) = String::from_utf8(decoded_bytes) {
|
if let Ok(decoded_str) = String::from_utf8(decoded_bytes) {
|
||||||
if let Some(colon_pos) = decoded_str.find(':') {
|
if let Some(colon_pos) = decoded_str.find(':') {
|
||||||
|
let provided_username = &decoded_str[..colon_pos];
|
||||||
let provided_password = &decoded_str[colon_pos + 1..];
|
let provided_password = &decoded_str[colon_pos + 1..];
|
||||||
|
|
||||||
|
// Check username with constant-time comparison
|
||||||
|
if !bool::from(
|
||||||
|
provided_username
|
||||||
|
.as_bytes()
|
||||||
|
.ct_eq(expected_username.as_bytes()),
|
||||||
|
) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
// If we have a password hash, verify against it
|
// If we have a password hash, verify against it
|
||||||
if let Some(hash) = expected_hash {
|
if let Some(hash) = expected_hash {
|
||||||
return pwhash::unix::verify(provided_password, hash);
|
return pwhash::unix::verify(provided_password, hash);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Otherwise, do constant-time comparison to prevent timing attacks
|
// Otherwise, do constant-time comparison to prevent timing attacks
|
||||||
let expected_credentials = format!("keep:{expected_password}");
|
return bool::from(
|
||||||
return decoded_str
|
provided_password
|
||||||
.as_bytes()
|
.as_bytes()
|
||||||
.ct_eq(expected_credentials.as_bytes())
|
.ct_eq(expected_password.as_bytes()),
|
||||||
.into();
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
false
|
false
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Checks authorization header for valid credentials.
|
|
||||||
///
|
|
||||||
/// This function inspects the HTTP Authorization header for valid Bearer or Basic
|
|
||||||
/// authentication credentials against the provided password or hash.
|
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
///
|
|
||||||
/// * `headers` - HTTP headers from the request.
|
|
||||||
/// * `password` - Optional expected password.
|
|
||||||
/// * `password_hash` - Optional expected password hash.
|
|
||||||
///
|
|
||||||
/// # Returns
|
|
||||||
///
|
|
||||||
/// * `true` - If authorized (or no auth required).
|
|
||||||
/// * `false` - If unauthorized.
|
|
||||||
///
|
|
||||||
/// # Examples
|
|
||||||
///
|
|
||||||
/// ```
|
|
||||||
/// if check_auth(&headers, &Some("pass".to_string()), &None) {
|
|
||||||
/// // Proceed
|
|
||||||
/// }
|
|
||||||
/// ```
|
|
||||||
pub fn check_auth(
|
|
||||||
headers: &HeaderMap,
|
|
||||||
password: &Option<String>,
|
|
||||||
password_hash: &Option<String>,
|
|
||||||
) -> bool {
|
|
||||||
// If neither password nor hash is set, no authentication required
|
|
||||||
if password.is_none() && password_hash.is_none() {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(auth_header) = headers.get("authorization") {
|
|
||||||
if let Ok(auth_str) = auth_header.to_str() {
|
|
||||||
return check_bearer_auth(auth_str, password.as_deref().unwrap_or(""), password_hash)
|
|
||||||
|| check_basic_auth(auth_str, password.as_deref().unwrap_or(""), password_hash);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
false
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Middleware for logging requests and responses.
|
/// Middleware for logging requests and responses.
|
||||||
///
|
///
|
||||||
/// This middleware logs incoming requests and outgoing responses, including method,
|
/// This middleware logs incoming requests and outgoing responses, including method,
|
||||||
@@ -830,14 +816,17 @@ pub async fn logging_middleware(
|
|||||||
|
|
||||||
/// Creates authentication middleware for the application.
|
/// Creates authentication middleware for the application.
|
||||||
///
|
///
|
||||||
/// This function returns a middleware that enforces authentication on protected routes
|
/// This function returns a middleware that enforces authentication on protected routes.
|
||||||
/// using Bearer token or Basic Auth, challenging unauthorized requests with appropriate
|
/// When `jwt_secret` is set, it validates JWT tokens and checks permission claims
|
||||||
/// headers.
|
/// (read, write, delete) based on the HTTP method. Otherwise, it falls back to
|
||||||
|
/// Basic Auth password authentication.
|
||||||
///
|
///
|
||||||
/// # Arguments
|
/// # Arguments
|
||||||
///
|
///
|
||||||
|
/// * `username` - Optional username (defaults to "keep").
|
||||||
/// * `password` - Optional plain text password.
|
/// * `password` - Optional plain text password.
|
||||||
/// * `password_hash` - Optional hashed password.
|
/// * `password_hash` - Optional hashed password.
|
||||||
|
/// * `jwt_secret` - Optional JWT secret for token-based authentication.
|
||||||
///
|
///
|
||||||
/// # Returns
|
/// # Returns
|
||||||
///
|
///
|
||||||
@@ -846,13 +835,15 @@ pub async fn logging_middleware(
|
|||||||
/// # Examples
|
/// # Examples
|
||||||
///
|
///
|
||||||
/// ```
|
/// ```
|
||||||
/// let auth_middleware = create_auth_middleware(Some("pass".to_string()), None);
|
/// let auth_middleware = create_auth_middleware(None, Some("pass".to_string()), None, None);
|
||||||
/// router.layer(auth_middleware);
|
/// router.layer(auth_middleware);
|
||||||
/// ```
|
/// ```
|
||||||
#[allow(clippy::type_complexity)]
|
#[allow(clippy::type_complexity)]
|
||||||
pub fn create_auth_middleware(
|
pub fn create_auth_middleware(
|
||||||
|
username: Option<String>,
|
||||||
password: Option<String>,
|
password: Option<String>,
|
||||||
password_hash: Option<String>,
|
password_hash: Option<String>,
|
||||||
|
jwt_secret: Option<String>,
|
||||||
) -> impl Fn(
|
) -> impl Fn(
|
||||||
ConnectInfo<SocketAddr>,
|
ConnectInfo<SocketAddr>,
|
||||||
Request,
|
Request,
|
||||||
@@ -862,13 +853,63 @@ pub fn create_auth_middleware(
|
|||||||
+ Clone
|
+ Clone
|
||||||
+ Send {
|
+ Send {
|
||||||
move |ConnectInfo(addr): ConnectInfo<SocketAddr>, request: Request, next: Next| {
|
move |ConnectInfo(addr): ConnectInfo<SocketAddr>, request: Request, next: Next| {
|
||||||
|
let username = username.clone();
|
||||||
let password = password.clone();
|
let password = password.clone();
|
||||||
let password_hash = password_hash.clone();
|
let password_hash = password_hash.clone();
|
||||||
|
let jwt_secret = jwt_secret.clone();
|
||||||
Box::pin(async move {
|
Box::pin(async move {
|
||||||
let headers = request.headers().clone();
|
let headers = request.headers().clone();
|
||||||
let uri = request.uri().clone();
|
let uri = request.uri().clone();
|
||||||
|
let method = request.method().clone();
|
||||||
|
|
||||||
if !check_auth(&headers, &password, &password_hash) {
|
// CORS preflight requests pass through without authentication
|
||||||
|
if method == Method::OPTIONS {
|
||||||
|
return Ok(next.run(request).await);
|
||||||
|
}
|
||||||
|
|
||||||
|
// JWT authentication takes priority when secret is configured
|
||||||
|
if let Some(ref secret) = jwt_secret {
|
||||||
|
if let Some(auth_header) = headers.get("authorization") {
|
||||||
|
if let Ok(auth_str) = auth_header.to_str() {
|
||||||
|
if let Some(token) = auth_str.strip_prefix("Bearer ") {
|
||||||
|
match super::auth::validate_jwt(token, secret) {
|
||||||
|
Ok(claims) => {
|
||||||
|
let required = super::auth::required_permission(&method);
|
||||||
|
if !super::auth::check_permission(&claims, required) {
|
||||||
|
warn!(
|
||||||
|
"Forbidden: {method} {uri} from {addr} \
|
||||||
|
(sub={}, missing permission: {required})",
|
||||||
|
claims.sub
|
||||||
|
);
|
||||||
|
let mut response =
|
||||||
|
Response::new(axum::body::Body::from("Forbidden"));
|
||||||
|
*response.status_mut() = StatusCode::FORBIDDEN;
|
||||||
|
return Ok(response);
|
||||||
|
}
|
||||||
|
// JWT valid and authorized, proceed
|
||||||
|
let response = next.run(request).await;
|
||||||
|
return Ok(response);
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
warn!("JWT validation failed for {uri} from {addr}: {e}");
|
||||||
|
let mut response =
|
||||||
|
Response::new(axum::body::Body::from("Unauthorized"));
|
||||||
|
*response.status_mut() = StatusCode::UNAUTHORIZED;
|
||||||
|
return Ok(response);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// JWT secret configured but no valid Bearer token provided
|
||||||
|
warn!("Missing JWT token for {uri} from {addr}");
|
||||||
|
let mut response = Response::new(axum::body::Body::from("Unauthorized"));
|
||||||
|
*response.status_mut() = StatusCode::UNAUTHORIZED;
|
||||||
|
return Ok(response);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fall back to Basic Auth password authentication
|
||||||
|
if !check_auth(&headers, &username, &password, &password_hash) {
|
||||||
warn!("Unauthorized request to {uri} from {addr}");
|
warn!("Unauthorized request to {uri} from {addr}");
|
||||||
// Add WWW-Authenticate header to trigger basic auth in browsers
|
// Add WWW-Authenticate header to trigger basic auth in browsers
|
||||||
let mut response = Response::new(axum::body::Body::from("Unauthorized"));
|
let mut response = Response::new(axum::body::Body::from("Unauthorized"));
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ use tower_http::cors::CorsLayer;
|
|||||||
use tower_http::trace::TraceLayer;
|
use tower_http::trace::TraceLayer;
|
||||||
|
|
||||||
mod api;
|
mod api;
|
||||||
|
pub mod auth;
|
||||||
pub mod common;
|
pub mod common;
|
||||||
#[cfg(feature = "mcp")]
|
#[cfg(feature = "mcp")]
|
||||||
mod mcp;
|
mod mcp;
|
||||||
@@ -50,8 +51,10 @@ pub fn mode_server(
|
|||||||
let server_config = common::ServerConfig {
|
let server_config = common::ServerConfig {
|
||||||
address: server_address,
|
address: server_address,
|
||||||
port: Some(server_port),
|
port: Some(server_port),
|
||||||
|
username: settings.server_username(),
|
||||||
password: settings.server_password(),
|
password: settings.server_password(),
|
||||||
password_hash: settings.server_password_hash(),
|
password_hash: settings.server_password_hash(),
|
||||||
|
jwt_secret: settings.server_jwt_secret(),
|
||||||
cert_file: settings.server_cert_file(),
|
cert_file: settings.server_cert_file(),
|
||||||
key_file: settings.server_key_file(),
|
key_file: settings.server_key_file(),
|
||||||
cors_origin: settings.server_cors_origin(),
|
cors_origin: settings.server_cors_origin(),
|
||||||
@@ -119,9 +122,13 @@ async fn run_server(
|
|||||||
protected_router = protected_router.merge(mcp_router);
|
protected_router = protected_router.merge(mcp_router);
|
||||||
}
|
}
|
||||||
|
|
||||||
let protected_router = protected_router.layer(axum::middleware::from_fn(
|
let protected_router =
|
||||||
create_auth_middleware(config.password.clone(), config.password_hash.clone()),
|
protected_router.layer(axum::middleware::from_fn(create_auth_middleware(
|
||||||
));
|
config.username.clone(),
|
||||||
|
config.password.clone(),
|
||||||
|
config.password_hash.clone(),
|
||||||
|
config.jwt_secret.clone(),
|
||||||
|
)));
|
||||||
|
|
||||||
// Build CORS layer - restricted by default, configurable via cors_origin setting
|
// Build CORS layer - restricted by default, configurable via cors_origin setting
|
||||||
let cors_origin = config.cors_origin.as_deref().unwrap_or("http://localhost");
|
let cors_origin = config.cors_origin.as_deref().unwrap_or("http://localhost");
|
||||||
@@ -166,16 +173,16 @@ async fn run_server(
|
|||||||
|
|
||||||
let addr: SocketAddr = bind_address.parse()?;
|
let addr: SocketAddr = bind_address.parse()?;
|
||||||
|
|
||||||
// Warn if password auth is enabled without TLS
|
// Warn if authentication is enabled without TLS
|
||||||
if config.password.is_some() || config.password_hash.is_some() {
|
if config.password.is_some() || config.password_hash.is_some() || config.jwt_secret.is_some() {
|
||||||
#[cfg(not(feature = "tls"))]
|
#[cfg(not(feature = "tls"))]
|
||||||
log::warn!(
|
log::warn!(
|
||||||
"SECURITY: Password authentication enabled but TLS support is not compiled in. Password will be transmitted in plain text!"
|
"SECURITY: Authentication enabled but TLS support is not compiled in. Credentials will be transmitted in plain text!"
|
||||||
);
|
);
|
||||||
#[cfg(feature = "tls")]
|
#[cfg(feature = "tls")]
|
||||||
if config.cert_file.is_none() || config.key_file.is_none() {
|
if config.cert_file.is_none() || config.key_file.is_none() {
|
||||||
log::warn!(
|
log::warn!(
|
||||||
"SECURITY: Password authentication enabled but TLS is not configured. Password will be transmitted in plain text!"
|
"SECURITY: Authentication enabled but TLS is not configured. Credentials will be transmitted in plain text!"
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,13 +1,27 @@
|
|||||||
use crate::config::Settings;
|
use crate::config::Settings;
|
||||||
use crate::meta_plugin::{MetaPlugin, MetaPluginType};
|
use crate::meta_plugin::{MetaPlugin, MetaPluginResponse, MetaPluginType};
|
||||||
use crate::modes::common::settings_meta_plugin_types;
|
use crate::modes::common::settings_meta_plugin_types;
|
||||||
use clap::Command;
|
use clap::Command;
|
||||||
use log::debug;
|
use log::{debug, error};
|
||||||
use rusqlite::Connection;
|
use rusqlite::Connection;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
|
||||||
pub struct MetaService;
|
pub struct MetaService;
|
||||||
|
|
||||||
|
/// Sentinel plugin used as a placeholder when extracting plugins for parallel
|
||||||
|
/// execution. The original plugin is written back immediately after the threads
|
||||||
|
/// complete. Never leaks into the DB or visible state.
|
||||||
|
struct NullMetaPlugin;
|
||||||
|
impl MetaPlugin for NullMetaPlugin {
|
||||||
|
fn meta_type(&self) -> MetaPluginType {
|
||||||
|
MetaPluginType::Digest
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn replace_plugin(plugins: &mut [Box<dyn MetaPlugin>], i: usize) -> Box<dyn MetaPlugin> {
|
||||||
|
std::mem::replace(&mut plugins[i], Box::new(NullMetaPlugin))
|
||||||
|
}
|
||||||
|
|
||||||
impl MetaService {
|
impl MetaService {
|
||||||
pub fn new() -> Self {
|
pub fn new() -> Self {
|
||||||
Self
|
Self
|
||||||
@@ -77,16 +91,14 @@ impl MetaService {
|
|||||||
|
|
||||||
for plugin in plugins.iter() {
|
for plugin in plugins.iter() {
|
||||||
let plugin_name = plugin.meta_type().to_string();
|
let plugin_name = plugin.meta_type().to_string();
|
||||||
// For each plugin, collect all the output names it might write to
|
|
||||||
for (internal_name, output_config) in plugin.outputs() {
|
for (internal_name, output_config) in plugin.outputs() {
|
||||||
let output_name = match output_config {
|
let output_name = match output_config {
|
||||||
serde_yaml::Value::String(remapped_name) => remapped_name.clone(),
|
serde_yaml::Value::String(remapped_name) => remapped_name.clone(),
|
||||||
serde_yaml::Value::Bool(true) => internal_name.clone(),
|
serde_yaml::Value::Bool(true) => internal_name.clone(),
|
||||||
serde_yaml::Value::Bool(false) => continue, // This output is disabled
|
serde_yaml::Value::Bool(false) => continue,
|
||||||
_ => internal_name.clone(), // Default to internal name for other types
|
_ => internal_name.clone(),
|
||||||
};
|
};
|
||||||
|
|
||||||
// Only track outputs that will actually be written
|
|
||||||
if !matches!(output_config, serde_yaml::Value::Bool(false)) {
|
if !matches!(output_config, serde_yaml::Value::Bool(false)) {
|
||||||
output_names
|
output_names
|
||||||
.entry(output_name)
|
.entry(output_name)
|
||||||
@@ -96,7 +108,6 @@ impl MetaService {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Print warnings for duplicate output names
|
|
||||||
for (output_name, plugin_names) in &output_names {
|
for (output_name, plugin_names) in &output_names {
|
||||||
if plugin_names.len() > 1 {
|
if plugin_names.len() > 1 {
|
||||||
log::warn!(
|
log::warn!(
|
||||||
@@ -107,9 +118,68 @@ impl MetaService {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for meta_plugin in plugins.iter_mut() {
|
// Partition into parallel-safe and sequential indices
|
||||||
let response = meta_plugin.initialize();
|
let (parallel_idx, sequential_idx): (Vec<usize>, Vec<usize>) = plugins
|
||||||
self.process_plugin_response(conn, item_id, &mut **meta_plugin, response);
|
.iter()
|
||||||
|
.enumerate()
|
||||||
|
.filter(|(_, p)| !p.is_finalized())
|
||||||
|
.map(|(i, _)| i)
|
||||||
|
.partition(|&i| plugins[i].parallel_safe());
|
||||||
|
|
||||||
|
// Run parallel-safe plugins concurrently
|
||||||
|
if !parallel_idx.is_empty() {
|
||||||
|
// Extract plugins by unique index into a flat Vec indexed by position
|
||||||
|
let mut parallel_plugins: Vec<Box<dyn MetaPlugin>> =
|
||||||
|
Vec::with_capacity(parallel_idx.len());
|
||||||
|
for &i in ¶llel_idx {
|
||||||
|
parallel_plugins.push(replace_plugin(plugins, i));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write results back to original slots sequentially (DB writes are serial)
|
||||||
|
let (results, panicked): (Vec<(usize, MetaPluginResponse)>, Vec<usize>) =
|
||||||
|
std::thread::scope(|s| {
|
||||||
|
let handles: Vec<_> = parallel_plugins
|
||||||
|
.iter_mut()
|
||||||
|
.map(|plugin| s.spawn(move || plugin.initialize()))
|
||||||
|
.collect();
|
||||||
|
let mut results = Vec::with_capacity(handles.len());
|
||||||
|
let mut panicked = Vec::new();
|
||||||
|
for (j, handle) in handles.into_iter().enumerate() {
|
||||||
|
match handle.join() {
|
||||||
|
Ok(response) => results.push((j, response)),
|
||||||
|
Err(e) => {
|
||||||
|
error!("META_SERVICE: Plugin panicked during initialize: {e:?}");
|
||||||
|
panicked.push(j);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
(results, panicked)
|
||||||
|
});
|
||||||
|
|
||||||
|
for (j, response) in results {
|
||||||
|
store_plugin_metadata(conn, item_id, &response);
|
||||||
|
let mut plugin = replace_plugin(&mut parallel_plugins, j);
|
||||||
|
if response.is_finalized {
|
||||||
|
plugin.set_finalized(true);
|
||||||
|
}
|
||||||
|
plugins[parallel_idx[j]] = plugin;
|
||||||
|
}
|
||||||
|
// Panicked plugins: restore the NullMetaPlugin sentinel and
|
||||||
|
// mark it finalized so future phases skip it cleanly.
|
||||||
|
for j in panicked {
|
||||||
|
let mut plugin = replace_plugin(&mut parallel_plugins, j);
|
||||||
|
plugin.set_finalized(true);
|
||||||
|
plugins[parallel_idx[j]] = plugin;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run sequential plugins
|
||||||
|
for &i in &sequential_idx {
|
||||||
|
let response = plugins[i].initialize();
|
||||||
|
store_plugin_metadata(conn, item_id, &response);
|
||||||
|
if response.is_finalized {
|
||||||
|
plugins[i].set_finalized(true);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -120,18 +190,64 @@ impl MetaService {
|
|||||||
conn: &Connection,
|
conn: &Connection,
|
||||||
item_id: i64,
|
item_id: i64,
|
||||||
) {
|
) {
|
||||||
for meta_plugin in plugins.iter_mut() {
|
// Partition non-finalized plugins by parallel_safe
|
||||||
// Skip plugins that are already finalized
|
let (parallel_idx, sequential_idx): (Vec<usize>, Vec<usize>) = plugins
|
||||||
if meta_plugin.is_finalized() {
|
.iter()
|
||||||
continue;
|
.enumerate()
|
||||||
|
.filter(|(_, p)| !p.is_finalized())
|
||||||
|
.map(|(i, _)| i)
|
||||||
|
.partition(|&i| plugins[i].parallel_safe());
|
||||||
|
|
||||||
|
// Run parallel-safe plugins concurrently on this chunk
|
||||||
|
if !parallel_idx.is_empty() {
|
||||||
|
// Extract plugins by unique index into a flat Vec indexed by position
|
||||||
|
let mut parallel_plugins: Vec<Box<dyn MetaPlugin>> =
|
||||||
|
Vec::with_capacity(parallel_idx.len());
|
||||||
|
for &i in ¶llel_idx {
|
||||||
|
parallel_plugins.push(replace_plugin(plugins, i));
|
||||||
}
|
}
|
||||||
|
|
||||||
let response = meta_plugin.update(chunk);
|
let (results, panicked): (Vec<(usize, MetaPluginResponse)>, Vec<usize>) =
|
||||||
self.process_plugin_response(conn, item_id, &mut **meta_plugin, response.clone());
|
std::thread::scope(|s| {
|
||||||
|
let handles: Vec<_> = parallel_plugins
|
||||||
|
.iter_mut()
|
||||||
|
.map(|plugin| s.spawn(move || plugin.update(chunk)))
|
||||||
|
.collect();
|
||||||
|
let mut results = Vec::with_capacity(handles.len());
|
||||||
|
let mut panicked = Vec::new();
|
||||||
|
for (j, handle) in handles.into_iter().enumerate() {
|
||||||
|
match handle.join() {
|
||||||
|
Ok(response) => results.push((j, response)),
|
||||||
|
Err(e) => {
|
||||||
|
error!("META_SERVICE: Plugin panicked during update: {e:?}");
|
||||||
|
panicked.push(j);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
(results, panicked)
|
||||||
|
});
|
||||||
|
|
||||||
// Set finalized flag if response indicates finalization
|
for (j, response) in results {
|
||||||
|
store_plugin_metadata(conn, item_id, &response);
|
||||||
|
let mut plugin = replace_plugin(&mut parallel_plugins, j);
|
||||||
if response.is_finalized {
|
if response.is_finalized {
|
||||||
meta_plugin.set_finalized(true);
|
plugin.set_finalized(true);
|
||||||
|
}
|
||||||
|
plugins[parallel_idx[j]] = plugin;
|
||||||
|
}
|
||||||
|
for j in panicked {
|
||||||
|
let mut plugin = replace_plugin(&mut parallel_plugins, j);
|
||||||
|
plugin.set_finalized(true);
|
||||||
|
plugins[parallel_idx[j]] = plugin;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run sequential plugins
|
||||||
|
for &i in &sequential_idx {
|
||||||
|
let response = plugins[i].update(chunk);
|
||||||
|
store_plugin_metadata(conn, item_id, &response);
|
||||||
|
if response.is_finalized {
|
||||||
|
plugins[i].set_finalized(true);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -143,57 +259,19 @@ impl MetaService {
|
|||||||
item_id: i64,
|
item_id: i64,
|
||||||
) {
|
) {
|
||||||
for meta_plugin in plugins.iter_mut() {
|
for meta_plugin in plugins.iter_mut() {
|
||||||
// Skip plugins that are already finalized
|
|
||||||
if meta_plugin.is_finalized() {
|
if meta_plugin.is_finalized() {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
let response = meta_plugin.finalize();
|
let response = meta_plugin.finalize();
|
||||||
self.process_plugin_response(conn, item_id, &mut **meta_plugin, response.clone());
|
store_plugin_metadata(conn, item_id, &response);
|
||||||
|
|
||||||
// Set finalized flag if response indicates finalization
|
|
||||||
if response.is_finalized {
|
if response.is_finalized {
|
||||||
meta_plugin.set_finalized(true);
|
meta_plugin.set_finalized(true);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Internal helper to process a meta plugin response and store metadata.
|
|
||||||
///
|
|
||||||
/// Iterates over the metadata entries in the response and stores each in the database
|
|
||||||
/// using `store_meta`. Logs warnings if storage fails.
|
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
///
|
|
||||||
/// * `conn` - Database connection.
|
|
||||||
/// * `item_id` - Item ID to associate with the metadata.
|
|
||||||
/// * `_plugin` - Reference to the plugin (unused).
|
|
||||||
/// * `response` - The plugin response containing metadata.
|
|
||||||
///
|
|
||||||
/// # Errors
|
|
||||||
///
|
|
||||||
/// Logs warnings for individual storage failures but does not return errors.
|
|
||||||
fn process_plugin_response(
|
|
||||||
&self,
|
|
||||||
conn: &Connection,
|
|
||||||
item_id: i64,
|
|
||||||
_plugin: &mut dyn MetaPlugin,
|
|
||||||
response: crate::meta_plugin::MetaPluginResponse,
|
|
||||||
) {
|
|
||||||
for meta_data in response.metadata {
|
|
||||||
// The metadata has already been processed by the plugin, so we can use it directly
|
|
||||||
// Save to database
|
|
||||||
let db_meta = crate::db::Meta {
|
|
||||||
id: item_id,
|
|
||||||
name: meta_data.name,
|
|
||||||
value: meta_data.value,
|
|
||||||
};
|
|
||||||
if let Err(e) = crate::db::store_meta(conn, db_meta) {
|
|
||||||
log::warn!("META_SERVICE: Failed to store metadata: {e}");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Collects initial metadata from environment variables and hostname.
|
/// Collects initial metadata from environment variables and hostname.
|
||||||
///
|
///
|
||||||
/// Gathers metadata from `KEEP_META_*` environment variables and adds hostname
|
/// Gathers metadata from `KEEP_META_*` environment variables and adds hostname
|
||||||
@@ -222,6 +300,26 @@ impl MetaService {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Stores metadata entries from a plugin response into the database.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
///
|
||||||
|
/// * `conn` - Database connection.
|
||||||
|
/// * `item_id` - Item ID to associate with the metadata.
|
||||||
|
/// * `response` - The plugin response containing metadata.
|
||||||
|
fn store_plugin_metadata(conn: &Connection, item_id: i64, response: &MetaPluginResponse) {
|
||||||
|
for meta_data in &response.metadata {
|
||||||
|
let db_meta = crate::db::Meta {
|
||||||
|
id: item_id,
|
||||||
|
name: meta_data.name.clone(),
|
||||||
|
value: meta_data.value.clone(),
|
||||||
|
};
|
||||||
|
if let Err(e) = crate::db::store_meta(conn, db_meta) {
|
||||||
|
log::warn!("META_SERVICE: Failed to store metadata: {e}");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl Default for MetaService {
|
impl Default for MetaService {
|
||||||
/// Provides a default `MetaService` instance.
|
/// Provides a default `MetaService` instance.
|
||||||
///
|
///
|
||||||
|
|||||||
@@ -1,92 +1,309 @@
|
|||||||
#[cfg(all(test, feature = "server"))]
|
#[cfg(all(test, feature = "server"))]
|
||||||
mod tests {
|
mod tests {
|
||||||
|
use crate::modes::server::auth::{Claims, check_permission, required_permission, validate_jwt};
|
||||||
use crate::modes::server::common::check_auth;
|
use crate::modes::server::common::check_auth;
|
||||||
use axum::http::{HeaderMap, HeaderValue};
|
use axum::http::{HeaderMap, HeaderValue, Method};
|
||||||
|
use jsonwebtoken::{EncodingKey, Header, encode};
|
||||||
|
|
||||||
|
fn make_jwt(claims: &serde_json::Value, secret: &str) -> String {
|
||||||
|
let header = Header::new(jsonwebtoken::Algorithm::HS256);
|
||||||
|
encode(
|
||||||
|
&header,
|
||||||
|
claims,
|
||||||
|
&EncodingKey::from_secret(secret.as_bytes()),
|
||||||
|
)
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn make_basic_auth(username: &str, password: &str) -> String {
|
||||||
|
use base64::Engine;
|
||||||
|
let credentials = format!("{username}:{password}");
|
||||||
|
let encoded = base64::engine::general_purpose::STANDARD.encode(&credentials);
|
||||||
|
format!("Basic {encoded}")
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Password auth tests (Basic auth) ---
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_auth_with_no_password_required() {
|
fn test_auth_with_no_password_required() {
|
||||||
let headers = HeaderMap::new();
|
let headers = HeaderMap::new();
|
||||||
let password = None;
|
assert!(check_auth(&headers, &None, &None, &None));
|
||||||
|
|
||||||
// When no password is required, auth should pass
|
|
||||||
assert!(check_auth(&headers, &password));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_auth_with_bearer_token() {
|
fn test_auth_with_basic_auth_default_username() {
|
||||||
|
let mut headers = HeaderMap::new();
|
||||||
|
headers.insert(
|
||||||
|
"authorization",
|
||||||
|
HeaderValue::from_str(&make_basic_auth("keep", "secret123")).unwrap(),
|
||||||
|
);
|
||||||
|
let password = Some("secret123".to_string());
|
||||||
|
assert!(check_auth(&headers, &None, &password, &None));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_auth_with_basic_auth_custom_username() {
|
||||||
|
let mut headers = HeaderMap::new();
|
||||||
|
headers.insert(
|
||||||
|
"authorization",
|
||||||
|
HeaderValue::from_str(&make_basic_auth("admin", "secret123")).unwrap(),
|
||||||
|
);
|
||||||
|
let username = Some("admin".to_string());
|
||||||
|
let password = Some("secret123".to_string());
|
||||||
|
assert!(check_auth(&headers, &username, &password, &None));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_auth_with_basic_auth_wrong_password() {
|
||||||
|
let mut headers = HeaderMap::new();
|
||||||
|
headers.insert(
|
||||||
|
"authorization",
|
||||||
|
HeaderValue::from_str(&make_basic_auth("keep", "wrongpass")).unwrap(),
|
||||||
|
);
|
||||||
|
let password = Some("secret123".to_string());
|
||||||
|
assert!(!check_auth(&headers, &None, &password, &None));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_auth_with_basic_auth_wrong_username() {
|
||||||
|
let mut headers = HeaderMap::new();
|
||||||
|
headers.insert(
|
||||||
|
"authorization",
|
||||||
|
HeaderValue::from_str(&make_basic_auth("wrong", "secret123")).unwrap(),
|
||||||
|
);
|
||||||
|
let username = Some("admin".to_string());
|
||||||
|
let password = Some("secret123".to_string());
|
||||||
|
assert!(!check_auth(&headers, &username, &password, &None));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_auth_with_basic_auth_wrong_password_custom_username() {
|
||||||
|
let mut headers = HeaderMap::new();
|
||||||
|
headers.insert(
|
||||||
|
"authorization",
|
||||||
|
HeaderValue::from_str(&make_basic_auth("admin", "wrongpass")).unwrap(),
|
||||||
|
);
|
||||||
|
let username = Some("admin".to_string());
|
||||||
|
let password = Some("secret123".to_string());
|
||||||
|
assert!(!check_auth(&headers, &username, &password, &None));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_auth_bearer_token_ignored_for_password_auth() {
|
||||||
let mut headers = HeaderMap::new();
|
let mut headers = HeaderMap::new();
|
||||||
headers.insert(
|
headers.insert(
|
||||||
"authorization",
|
"authorization",
|
||||||
HeaderValue::from_static("Bearer secret123"),
|
HeaderValue::from_static("Bearer secret123"),
|
||||||
);
|
);
|
||||||
|
|
||||||
let password = Some("secret123".to_string());
|
let password = Some("secret123".to_string());
|
||||||
|
// Bearer tokens are not checked for password auth — only Basic auth is valid
|
||||||
// Valid bearer token should pass
|
assert!(!check_auth(&headers, &None, &password, &None));
|
||||||
assert!(check_auth(&headers, &password));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_auth_with_invalid_bearer_token() {
|
|
||||||
let mut headers = HeaderMap::new();
|
|
||||||
headers.insert(
|
|
||||||
"authorization",
|
|
||||||
HeaderValue::from_static("Bearer wrongtoken"),
|
|
||||||
);
|
|
||||||
|
|
||||||
let password = Some("secret123".to_string());
|
|
||||||
|
|
||||||
// Invalid bearer token should fail
|
|
||||||
assert!(!check_auth(&headers, &password));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_auth_with_basic_auth() {
|
|
||||||
let mut headers = HeaderMap::new();
|
|
||||||
// Basic auth for "keep:secret123" base64 encoded
|
|
||||||
headers.insert(
|
|
||||||
"authorization",
|
|
||||||
HeaderValue::from_static("Basic a2VlcDpzZWNyZXQxMjM="),
|
|
||||||
);
|
|
||||||
|
|
||||||
let password = Some("secret123".to_string());
|
|
||||||
|
|
||||||
// Valid basic auth should pass
|
|
||||||
assert!(check_auth(&headers, &password));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_auth_with_invalid_basic_auth() {
|
|
||||||
let mut headers = HeaderMap::new();
|
|
||||||
// Basic auth for "keep:wrongpass" base64 encoded
|
|
||||||
headers.insert(
|
|
||||||
"authorization",
|
|
||||||
HeaderValue::from_static("Basic a2VlcDp3cm9uZ3Bhc3M="),
|
|
||||||
);
|
|
||||||
|
|
||||||
let password = Some("secret123".to_string());
|
|
||||||
|
|
||||||
// Invalid basic auth should fail
|
|
||||||
assert!(!check_auth(&headers, &password));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_auth_with_missing_auth_header() {
|
fn test_auth_with_missing_auth_header() {
|
||||||
let headers = HeaderMap::new();
|
let headers = HeaderMap::new();
|
||||||
let password = Some("secret123".to_string());
|
let password = Some("secret123".to_string());
|
||||||
|
assert!(!check_auth(&headers, &None, &password, &None));
|
||||||
// Missing auth header should fail when password is required
|
|
||||||
assert!(!check_auth(&headers, &password));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_auth_with_malformed_auth_header() {
|
fn test_auth_with_malformed_auth_header() {
|
||||||
let mut headers = HeaderMap::new();
|
let mut headers = HeaderMap::new();
|
||||||
headers.insert("authorization", HeaderValue::from_static("Invalid header"));
|
headers.insert("authorization", HeaderValue::from_static("Invalid header"));
|
||||||
|
|
||||||
let password = Some("secret123".to_string());
|
let password = Some("secret123".to_string());
|
||||||
|
assert!(!check_auth(&headers, &None, &password, &None));
|
||||||
|
}
|
||||||
|
|
||||||
// Malformed auth header should fail
|
// --- JWT validation tests ---
|
||||||
assert!(!check_auth(&headers, &password));
|
|
||||||
|
#[test]
|
||||||
|
fn test_validate_jwt_valid_token() {
|
||||||
|
let secret = "test-secret";
|
||||||
|
let claims = serde_json::json!({
|
||||||
|
"sub": "test-client",
|
||||||
|
"exp": 9999999999usize,
|
||||||
|
"read": true,
|
||||||
|
"write": true,
|
||||||
|
"delete": false
|
||||||
|
});
|
||||||
|
let token = make_jwt(&claims, secret);
|
||||||
|
|
||||||
|
let result = validate_jwt(&token, secret);
|
||||||
|
assert!(result.is_ok());
|
||||||
|
let claims = result.unwrap();
|
||||||
|
assert_eq!(claims.sub, "test-client");
|
||||||
|
assert!(claims.read);
|
||||||
|
assert!(claims.write);
|
||||||
|
assert!(!claims.delete);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_validate_jwt_expired() {
|
||||||
|
let secret = "test-secret";
|
||||||
|
let claims = serde_json::json!({
|
||||||
|
"sub": "test-client",
|
||||||
|
"exp": 1000000000usize,
|
||||||
|
"read": true
|
||||||
|
});
|
||||||
|
let token = make_jwt(&claims, secret);
|
||||||
|
|
||||||
|
let result = validate_jwt(&token, secret);
|
||||||
|
assert!(result.is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_validate_jwt_wrong_secret() {
|
||||||
|
let claims = serde_json::json!({
|
||||||
|
"sub": "test-client",
|
||||||
|
"exp": 9999999999usize,
|
||||||
|
"read": true
|
||||||
|
});
|
||||||
|
let token = make_jwt(&claims, "correct-secret");
|
||||||
|
|
||||||
|
let result = validate_jwt(&token, "wrong-secret");
|
||||||
|
assert!(result.is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_validate_jwt_malformed() {
|
||||||
|
let result = validate_jwt("not.a.jwt", "secret");
|
||||||
|
assert!(result.is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_validate_jwt_missing_permissions_default_false() {
|
||||||
|
let secret = "test-secret";
|
||||||
|
let claims = serde_json::json!({
|
||||||
|
"sub": "test-client",
|
||||||
|
"exp": 9999999999usize
|
||||||
|
});
|
||||||
|
let token = make_jwt(&claims, secret);
|
||||||
|
|
||||||
|
let claims = validate_jwt(&token, secret).unwrap();
|
||||||
|
assert!(!claims.read);
|
||||||
|
assert!(!claims.write);
|
||||||
|
assert!(!claims.delete);
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Permission tests ---
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_required_permission_mapping() {
|
||||||
|
assert_eq!(required_permission(&Method::GET), "read");
|
||||||
|
assert_eq!(required_permission(&Method::HEAD), "read");
|
||||||
|
assert_eq!(required_permission(&Method::POST), "write");
|
||||||
|
assert_eq!(required_permission(&Method::PUT), "write");
|
||||||
|
assert_eq!(required_permission(&Method::PATCH), "write");
|
||||||
|
assert_eq!(required_permission(&Method::DELETE), "delete");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_check_permission_granted() {
|
||||||
|
let claims = Claims {
|
||||||
|
sub: "test".to_string(),
|
||||||
|
exp: 9999999999,
|
||||||
|
read: true,
|
||||||
|
write: false,
|
||||||
|
delete: true,
|
||||||
|
};
|
||||||
|
|
||||||
|
assert!(check_permission(&claims, "read"));
|
||||||
|
assert!(!check_permission(&claims, "write"));
|
||||||
|
assert!(check_permission(&claims, "delete"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_check_permission_unknown_denied() {
|
||||||
|
let claims = Claims {
|
||||||
|
sub: "test".to_string(),
|
||||||
|
exp: 9999999999,
|
||||||
|
read: true,
|
||||||
|
write: true,
|
||||||
|
delete: true,
|
||||||
|
};
|
||||||
|
|
||||||
|
assert!(!check_permission(&claims, "unknown"));
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- End-to-end permission scenarios ---
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_read_only_token_scenarios() {
|
||||||
|
let secret = "test-secret";
|
||||||
|
let claims = serde_json::json!({
|
||||||
|
"sub": "readonly-client",
|
||||||
|
"exp": 9999999999usize,
|
||||||
|
"read": true,
|
||||||
|
"write": false,
|
||||||
|
"delete": false
|
||||||
|
});
|
||||||
|
let token = make_jwt(&claims, secret);
|
||||||
|
let claims = validate_jwt(&token, secret).unwrap();
|
||||||
|
|
||||||
|
// GET (read) should be allowed
|
||||||
|
assert!(check_permission(&claims, required_permission(&Method::GET)));
|
||||||
|
// POST (write) should be denied
|
||||||
|
assert!(!check_permission(
|
||||||
|
&claims,
|
||||||
|
required_permission(&Method::POST)
|
||||||
|
));
|
||||||
|
// DELETE should be denied
|
||||||
|
assert!(!check_permission(
|
||||||
|
&claims,
|
||||||
|
required_permission(&Method::DELETE)
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_full_access_token_scenarios() {
|
||||||
|
let secret = "test-secret";
|
||||||
|
let claims = serde_json::json!({
|
||||||
|
"sub": "admin",
|
||||||
|
"exp": 9999999999usize,
|
||||||
|
"read": true,
|
||||||
|
"write": true,
|
||||||
|
"delete": true
|
||||||
|
});
|
||||||
|
let token = make_jwt(&claims, secret);
|
||||||
|
let claims = validate_jwt(&token, secret).unwrap();
|
||||||
|
|
||||||
|
assert!(check_permission(&claims, required_permission(&Method::GET)));
|
||||||
|
assert!(check_permission(
|
||||||
|
&claims,
|
||||||
|
required_permission(&Method::POST)
|
||||||
|
));
|
||||||
|
assert!(check_permission(&claims, required_permission(&Method::PUT)));
|
||||||
|
assert!(check_permission(
|
||||||
|
&claims,
|
||||||
|
required_permission(&Method::DELETE)
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_write_only_token_scenarios() {
|
||||||
|
let secret = "test-secret";
|
||||||
|
let claims = serde_json::json!({
|
||||||
|
"sub": "writer",
|
||||||
|
"exp": 9999999999usize,
|
||||||
|
"read": false,
|
||||||
|
"write": true,
|
||||||
|
"delete": false
|
||||||
|
});
|
||||||
|
let token = make_jwt(&claims, secret);
|
||||||
|
let claims = validate_jwt(&token, secret).unwrap();
|
||||||
|
|
||||||
|
assert!(!check_permission(
|
||||||
|
&claims,
|
||||||
|
required_permission(&Method::GET)
|
||||||
|
));
|
||||||
|
assert!(check_permission(
|
||||||
|
&claims,
|
||||||
|
required_permission(&Method::POST)
|
||||||
|
));
|
||||||
|
assert!(!check_permission(
|
||||||
|
&claims,
|
||||||
|
required_permission(&Method::DELETE)
|
||||||
|
));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
198
src/tokenizer/mod.rs
Normal file
198
src/tokenizer/mod.rs
Normal file
@@ -0,0 +1,198 @@
|
|||||||
|
use anyhow::{Result, bail};
|
||||||
|
use once_cell::sync::Lazy;
|
||||||
|
|
||||||
|
/// Supported LLM token encodings.
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
|
||||||
|
pub enum TokenEncoding {
|
||||||
|
/// cl100k_base — used by GPT-3.5, GPT-4, text-embedding-ada-002.
|
||||||
|
#[default]
|
||||||
|
Cl100kBase,
|
||||||
|
/// o200k_base — used by GPT-4o, GPT-5, o1, o3, o4 models.
|
||||||
|
O200kBase,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::str::FromStr for TokenEncoding {
|
||||||
|
type Err = anyhow::Error;
|
||||||
|
|
||||||
|
fn from_str(s: &str) -> Result<Self> {
|
||||||
|
match s.to_lowercase().as_str() {
|
||||||
|
"cl100k_base" => Ok(TokenEncoding::Cl100kBase),
|
||||||
|
"o200k_base" => Ok(TokenEncoding::O200kBase),
|
||||||
|
_ => bail!("Unknown token encoding: {s}. Supported: cl100k_base, o200k_base"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Display for TokenEncoding {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
match self {
|
||||||
|
TokenEncoding::Cl100kBase => write!(f, "cl100k_base"),
|
||||||
|
TokenEncoding::O200kBase => write!(f, "o200k_base"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Wrapper around tiktoken BPE tokenizer.
|
||||||
|
///
|
||||||
|
/// Provides streaming-friendly tokenization: count tokens in text,
|
||||||
|
/// split text into token strings, and decode token IDs back to text.
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct Tokenizer {
|
||||||
|
bpe: tiktoken_rs::CoreBPE,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Debug for Tokenizer {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
f.debug_struct("Tokenizer").finish()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Static tokenizer instances — loaded once per process, shared across all plugins.
|
||||||
|
static CL100K: Lazy<Tokenizer> = Lazy::new(|| {
|
||||||
|
Tokenizer::new(TokenEncoding::Cl100kBase).expect("Failed to create cl100k_base tokenizer")
|
||||||
|
});
|
||||||
|
static O200K: Lazy<Tokenizer> = Lazy::new(|| {
|
||||||
|
Tokenizer::new(TokenEncoding::O200kBase).expect("Failed to create o200k_base tokenizer")
|
||||||
|
});
|
||||||
|
|
||||||
|
/// Returns a reference to a cached tokenizer for the given encoding.
|
||||||
|
///
|
||||||
|
/// The BPE vocabulary is loaded once per encoding and reused for the
|
||||||
|
/// lifetime of the process.
|
||||||
|
pub fn get_tokenizer(encoding: TokenEncoding) -> &'static Tokenizer {
|
||||||
|
match encoding {
|
||||||
|
TokenEncoding::Cl100kBase => &CL100K,
|
||||||
|
TokenEncoding::O200kBase => &O200K,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Tokenizer {
|
||||||
|
/// Creates a new tokenizer for the specified encoding.
|
||||||
|
pub fn new(encoding: TokenEncoding) -> Result<Self> {
|
||||||
|
let bpe = match encoding {
|
||||||
|
TokenEncoding::Cl100kBase => tiktoken_rs::cl100k_base()?,
|
||||||
|
TokenEncoding::O200kBase => tiktoken_rs::o200k_base()?,
|
||||||
|
};
|
||||||
|
Ok(Self { bpe })
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Counts the number of tokens in the given text.
|
||||||
|
///
|
||||||
|
/// Uses `encode_ordinary` which treats the text as a single unit.
|
||||||
|
/// For streaming: tokenizing chunks independently and summing gives
|
||||||
|
/// the same result as tokenizing the full text (exact when no regex
|
||||||
|
/// match spans a chunk boundary).
|
||||||
|
pub fn count(&self, text: &str) -> usize {
|
||||||
|
self.bpe.encode_ordinary(text).len()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Splits text into individual decoded token strings.
|
||||||
|
///
|
||||||
|
/// Each returned string corresponds to one token. Useful for finding
|
||||||
|
/// exact byte boundaries when filtering by token count.
|
||||||
|
pub fn split_by_token(&self, text: &str) -> Result<Vec<String>> {
|
||||||
|
self.bpe.split_by_token(text, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns an iterator over decoded token strings.
|
||||||
|
///
|
||||||
|
/// Lazily produces token strings without allocating a Vec for all tokens.
|
||||||
|
/// Use this when you only need the first N tokens (e.g., head/skip filters).
|
||||||
|
pub fn split_by_token_iter<'a>(
|
||||||
|
&'a self,
|
||||||
|
text: &'a str,
|
||||||
|
) -> impl Iterator<Item = Result<String>> + 'a {
|
||||||
|
self.bpe.split_by_token_iter(text, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Counts tokens up to `max_tokens` and returns `(token_count, byte_position)`.
|
||||||
|
///
|
||||||
|
/// Uses an iterator to stop early, avoiding allocation of token strings
|
||||||
|
/// beyond `max_tokens`. The byte_position is in the lossy UTF-8 encoding
|
||||||
|
/// of `text` — use `map_lossy_pos_to_bytes` to map back to original bytes.
|
||||||
|
pub fn count_bounded(&self, text: &str, max_tokens: usize) -> (usize, usize) {
|
||||||
|
let mut count = 0usize;
|
||||||
|
let mut byte_pos = 0usize;
|
||||||
|
for token_str in self.bpe.split_by_token_iter(text, false) {
|
||||||
|
if let Ok(s) = token_str {
|
||||||
|
byte_pos += s.len();
|
||||||
|
}
|
||||||
|
count += 1;
|
||||||
|
if count >= max_tokens {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
(count, byte_pos)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Decodes a slice of token IDs back into a string.
|
||||||
|
pub fn decode(&self, tokens: &[u32]) -> Result<String> {
|
||||||
|
self.bpe.decode(tokens.to_vec())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_tokenizer_count() {
|
||||||
|
let tok = Tokenizer::new(TokenEncoding::Cl100kBase).unwrap();
|
||||||
|
let count = tok.count("Hello, world!");
|
||||||
|
assert!(count > 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_tokenizer_split() {
|
||||||
|
let tok = Tokenizer::new(TokenEncoding::Cl100kBase).unwrap();
|
||||||
|
let tokens = tok.split_by_token("Hello world").unwrap();
|
||||||
|
assert_eq!(tokens.len(), 2);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_tokenizer_roundtrip() {
|
||||||
|
let tok = Tokenizer::new(TokenEncoding::Cl100kBase).unwrap();
|
||||||
|
let text = "The quick brown fox jumps over the lazy dog.";
|
||||||
|
let token_ids: Vec<u32> = tok
|
||||||
|
.bpe
|
||||||
|
.encode_ordinary(text)
|
||||||
|
.into_iter()
|
||||||
|
.map(|x| x as u32)
|
||||||
|
.collect();
|
||||||
|
let decoded = tok.decode(&token_ids).unwrap();
|
||||||
|
assert_eq!(text, decoded);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_chunk_sum_close_to_full() {
|
||||||
|
let tok = Tokenizer::new(TokenEncoding::Cl100kBase).unwrap();
|
||||||
|
let text = "The quick brown fox jumps over the lazy dog. \
|
||||||
|
Pack my box with five dozen liquor jugs. \
|
||||||
|
How vexingly quick daft zebras jump!";
|
||||||
|
let full_count = tok.count(text);
|
||||||
|
|
||||||
|
// Split into chunks at word boundaries
|
||||||
|
let mid = text.find("Pack").unwrap();
|
||||||
|
let (a, b) = text.split_at(mid);
|
||||||
|
let chunk_sum = tok.count(a) + tok.count(b);
|
||||||
|
// Chunk-based counting may differ by 1-2 tokens when a BPE merge
|
||||||
|
// boundary falls near the chunk split point
|
||||||
|
assert!(
|
||||||
|
(full_count as isize - chunk_sum as isize).abs() <= 2,
|
||||||
|
"full={full_count}, chunk_sum={chunk_sum}"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_encoding_from_str() {
|
||||||
|
assert_eq!(
|
||||||
|
"cl100k_base".parse::<TokenEncoding>().unwrap(),
|
||||||
|
TokenEncoding::Cl100kBase
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
"o200k_base".parse::<TokenEncoding>().unwrap(),
|
||||||
|
TokenEncoding::O200kBase
|
||||||
|
);
|
||||||
|
assert!("unknown".parse::<TokenEncoding>().is_err());
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user