diff --git a/.gitignore b/.gitignore
index 8cfd041ff..446cdb256 100644
--- a/.gitignore
+++ b/.gitignore
@@ -32,6 +32,7 @@ wheels/
/temp
MANIFEST
.locks/
+.temp/
# PyInstaller
# Usually these files are written by a python script from a template
@@ -131,8 +132,7 @@ replace.sh
result.png
result.jpg
result.mp4
-output/
-outputs/
+output*
wandb/
*.out
benchmarks/
diff --git a/Dockerfile b/Dockerfile
index 8107ebcb3..debd1b9b2 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,51 +1,21 @@
-FROM modelscope-registry.cn-hangzhou.cr.aliyuncs.com/modelscope-repo/modelscope:ubuntu22.04-cuda12.8.1-py311-torch2.9.1-1.35.0
+FROM modelscope-registry.cn-hangzhou.cr.aliyuncs.com/modelscope-repo/modelscope:ubuntu22.04-cuda12.9.1-py312-torch2.10.0-vllm0.19.1-modelscope1.35.4-swift4.1.3
-# Install miniconda with Python 3.12
-RUN curl -O https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \
- bash Miniconda3-latest-Linux-x86_64.sh -b -p /opt/conda && \
- rm Miniconda3-latest-Linux-x86_64.sh
-ENV PATH="/opt/conda/bin:${PATH}"
-RUN conda create -n twinkle python=3.12 -y --override-channels -c conda-forge
-ENV PATH="/opt/conda/envs/twinkle/bin:${PATH}"
+# Forward-compat user-mode CUDA driver shim, then pip cuDNN ahead of apt cuDNN (transformer_engine undefined-symbol fix). Path must match base image's Python version.
+ENV LD_LIBRARY_PATH="/usr/local/cuda/compat:/usr/local/lib/python3.12/dist-packages/nvidia/cudnn/lib:${LD_LIBRARY_PATH}"
-ENV SETUPTOOLS_USE_DISTUTILS=local
+# Only twinkle-specific deps; everything else (torch, vllm, TE, flash-attn, megatron-core, mcore-bridge, transformers, peft, accelerate) ships in the base.
+RUN pip install --no-cache-dir \
+ flash-linear-attention \
+ tinker==0.16.1 \
+ "ray[serve]" && \
+ rm -rf /root/.cache /tmp/*
-# Install base packages
-RUN pip install --upgrade peft accelerate transformers "modelscope[framework]" --no-cache-dir
-
-# Install vllm
-RUN pip install --upgrade vllm --no-cache-dir
-
-# Install transformer_engine and megatron_core
-RUN SITE_PACKAGES=$(python -c "import site; print(site.getsitepackages()[0])") && \
- CUDNN_PATH=$SITE_PACKAGES/nvidia/cudnn \
- CPLUS_INCLUDE_PATH=$SITE_PACKAGES/nvidia/cudnn/include \
- pip install --no-build-isolation "transformer_engine[pytorch]" --no-cache-dir
-
-RUN pip install megatron_core mcore_bridge --no-cache-dir
-
-# Install flash-attention (default arch 8.0;9.0, override via build-arg if needed)
-ARG TORCH_CUDA_ARCH_LIST="8.0;9.0"
-RUN TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST}" \
- MAX_JOBS=8 \
- FLASH_ATTENTION_FORCE_BUILD=TRUE \
- pip install flash-attn --no-build-isolation --no-cache-dir
-
-RUN pip install flash-linear-attention -U --no-cache-dir
-
-# Install numpy
-RUN pip install numpy==2.2 --no-cache-dir
-
-# Install tinker, ray, and other deps
-RUN pip install --no-cache-dir tinker==0.16.1 "ray[serve]" transformers peft<=0.18 accelerate -U
-
-# Clone and install twinkle, checkout to latest v-tag
-RUN git clone https://github.com/modelscope/twinkle.git
-WORKDIR /twinkle
-RUN echo "Available release branches:" && git branch -r -l 'origin/release/*' --sort=-v:refname && \
+# Clone latest release and install twinkle in editable mode.
+RUN git clone https://github.com/modelscope/twinkle.git /twinkle && \
+ cd /twinkle && \
LATEST_RELEASE=$(git branch -r -l 'origin/release/*' --sort=-v:refname | head -n 1 | tr -d ' ') && \
- echo "Checking out: $LATEST_RELEASE" && \
- git checkout --track "$LATEST_RELEASE"
+ if [ -n "$LATEST_RELEASE" ]; then echo "Checking out: $LATEST_RELEASE" && git checkout --track "$LATEST_RELEASE"; else echo "No release branch found, staying on default branch"; fi && \
+ pip install --no-cache-dir --no-build-isolation -e . && \
+ rm -rf /root/.cache
-# Install twinkle itself
-RUN pip install -e . --no-build-isolation
+WORKDIR /twinkle
diff --git a/README.md b/README.md
index 26edb400e..4dd203cfb 100644
--- a/README.md
+++ b/README.md
@@ -61,7 +61,7 @@ pip install -e .
### Use our docker image:
```text
-modelscope-registry.cn-hangzhou.cr.aliyuncs.com/modelscope-repo/modelscope:twinkle-0.2.1
+modelscope-registry.cn-hangzhou.cr.aliyuncs.com/modelscope-repo/modelscope:twinkle-0.3.0
```
If you need to use Twinkle's Client, you can use our one-click installation script:
diff --git a/README_ZH.md b/README_ZH.md
index 845c42e18..5d588b393 100644
--- a/README_ZH.md
+++ b/README_ZH.md
@@ -54,7 +54,7 @@ pip install -e .
### 使用docker镜像:
```text
-modelscope-registry.cn-hangzhou.cr.aliyuncs.com/modelscope-repo/modelscope:twinkle-0.2.1
+modelscope-registry.cn-hangzhou.cr.aliyuncs.com/modelscope-repo/modelscope:twinkle-0.3.0
```
如果你需要使用Twinkle的Client,可以使用我们的一键安装脚本:
diff --git a/cookbook/exp/cold_start/train_cold_start.py b/cookbook/exp/cold_start/train_cold_start.py
new file mode 100644
index 000000000..da7149bba
--- /dev/null
+++ b/cookbook/exp/cold_start/train_cold_start.py
@@ -0,0 +1,332 @@
+import json
+import os
+from functools import partial
+from pathlib import Path
+from typing import Any, Dict, Iterator, List
+
+from peft import LoraConfig
+
+import twinkle
+from twinkle import DeviceMesh, DeviceGroup, get_device_placement, get_logger
+from twinkle.dataloader import DataLoader
+from twinkle.dataset import Dataset, PackingDataset
+from twinkle.dataset.base import DatasetMeta
+from twinkle.model import MegatronModel
+from twinkle_agentic.preprocessor import (
+ QualityPreprocessor, SamplerBackend,
+ IntentClassifier, HardFilter, RefuseFilter, DeadLoopFilter, TokenSoupFilter, MessageSanityFilter,
+ SpecialCharsFilter, ModelFilter, DedupFilter,
+ MessageNormalizer,
+)
+
+logger = get_logger()
+
+# ── Model ────────────────────────────────────────────────────────────────────
+MODEL_ID = 'ms://Qwen/Qwen3-4B'
+TEMPLATE_NAME = 'Template'
+MAX_LENGTH = 80000
+
+# ── GPU allocation ───────────────────────────────────────────────────────────
+MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 8))
+SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS', 0))
+NUM_GPUS = MODEL_GPUS + SAMPLER_GPUS
+
+# ── Training ─────────────────────────────────────────────────────────────────
+BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 1))
+LEARNING_RATE = float(os.environ.get('LR', 1e-5))
+GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRAD_ACCUM', 4))
+LOG_INTERVAL = 1
+SAVE_INTERVAL = 500
+NUM_STEPS = int(os.environ.get('NUM_STEPS', 5000))
+
+# ── Output ───────────────────────────────────────────────────────────────────
+OUTPUT_DIR = './output/streaming_sft'
+TRAINED_DATA_PATH = os.path.join(OUTPUT_DIR, 'trained_data.jsonl')
+DROPPED_DATA_PATH = os.path.join(OUTPUT_DIR, 'dropped_data.jsonl')
+ADAPTER_NAME = 'default'
+
+# ── Data source ──────────────────────────────────────────────────────────────
+CSV_PATH = os.environ.get('CSV_PATH')
+DATASET_TOTAL = int(os.environ.get('DATASET_TOTAL', 10000)) # 0 = full materialized dataset
+# Worker count for HF Dataset.map(num_proc=N); spawn start method is forced in twinkle.dataset.base.
+MAP_NUM_PROC = int(os.environ.get('MAP_NUM_PROC', 16))
+
+
+def _canonicalize_tool_call(tc: Any) -> Dict[str, Any]:
+ """Coerce ``tool_calls[i]`` to a fixed-schema dict for stable Arrow inference.
+
+ Keeps ``function.arguments`` as the OpenAI-native JSON string so every row
+ sees a uniform ``string`` field; any string→dict decoding is the
+ chat_template's concern (see ``Template._apply_chat_template``).
+
+ The decoded form is enforced to be a JSON object so the chat_template's
+ ``|items`` filter never receives list/scalar/null — those originate from
+ dirty CSV rows and are coerced to ``{}`` here, the ingestion boundary.
+ """
+ tc = tc if isinstance(tc, dict) else {}
+ fn = tc.get('function') if isinstance(tc.get('function'), dict) else {}
+ args = fn.get('arguments')
+ if isinstance(args, dict):
+ args_str = json.dumps(args, ensure_ascii=False)
+ elif isinstance(args, str) and args.strip():
+ try:
+ decoded = json.loads(args)
+ except json.JSONDecodeError:
+ decoded = {}
+ if not isinstance(decoded, dict):
+ decoded = {}
+ args_str = json.dumps(decoded, ensure_ascii=False)
+ else:
+ args_str = '{}'
+ return {
+ 'id': str(tc.get('id') or ''),
+ 'type': str(tc.get('type') or 'function'),
+ 'function': {
+ 'name': str(fn.get('name') or ''),
+ 'arguments': args_str,
+ },
+ }
+
+
+def _stream_csv_rows(csv_path: str, max_rows: int = 0) -> Iterator[Dict[str, Any]]:
+ """Stream the custom CSV: each line is `ts,model,req_id,messages_json` (no quoting).
+
+ The first 3 fields are scalar; the remainder of the line is a JSON array of
+ chat messages, possibly containing commas — so we split on the first 3 commas only.
+ ``max_rows`` caps the yielded rows at ingestion time so Arrow never materializes
+ the unused tail.
+ """
+ emitted = 0
+ with open(csv_path, 'rb') as f:
+ bad_bytes = 0
+ for raw in f:
+ try:
+ line = raw.decode('utf-8').rstrip('\n').rstrip('\r')
+ except UnicodeDecodeError:
+ bad_bytes += 1
+ continue
+ if not line:
+ continue
+ parts = line.split(',', 3)
+ if len(parts) < 4:
+ continue
+ ts, _model, req_id, msgs_raw = parts
+ try:
+ raw_msgs = json.loads(msgs_raw)
+ except json.JSONDecodeError:
+ continue
+ messages: List[Dict[str, Any]] = []
+ for m in raw_msgs:
+ role = m.get('role', '')
+ content = m.get('content')
+ # User content arrives as [{'type':'text','text':...}, ...]; flatten to plain string.
+ if isinstance(content, list):
+ content = ''.join(
+ p.get('text', '') for p in content
+ if isinstance(p, dict) and p.get('type') == 'text')
+ if content is None:
+ content = ''
+ if not isinstance(content, str):
+ continue
+ raw_tcs = m.get('tool_calls') if role == 'assistant' else None
+ tc_list = [_canonicalize_tool_call(tc) for tc in raw_tcs] if raw_tcs else []
+ if role == 'assistant':
+ if not content and not tc_list:
+ continue
+ if m.get('reasoning_content'):
+ content = f"{m['reasoning_content']}{content}"
+ elif role == 'tool':
+ pass
+ elif not content:
+ continue
+ # tool_calls stored as JSON string (empty -> ''): keeps Arrow schema as a
+ # stable Value(string) regardless of empty-list / heterogeneous-struct shards.
+ # Template._apply_chat_template decodes it back to list before jinja render.
+ messages.append({
+ 'role': role,
+ 'content': content,
+ 'tool_calls': json.dumps(tc_list, ensure_ascii=False) if tc_list else '',
+ 'tool_call_id': str(m.get('tool_call_id') or '') if role == 'tool' else '',
+ })
+ if not messages:
+ continue
+ yield {
+ 'id': f'csv__{ts}__{req_id}',
+ 'source': Path(csv_path).stem,
+ 'model_id': _model,
+ 'messages': messages,
+ 'user_data': [],
+ }
+ emitted += 1
+ if max_rows and emitted >= max_rows:
+ break
+
+
+# ── QualityPreprocessor config ───────────────────────────────────────────────
+SENSITIVE_WORDS_FILE = str(
+ Path(__file__).resolve().parent.parent.parent / 'sensitive_words.txt')
+# chr_min cutoff: keep round if chr_min < threshold (low chr_min = hard).
+CHR_MIN_THRESHOLD = float(os.environ.get('CHR_MIN_THRESHOLD', 0.5))
+REFINE_TEMPERATURE = float(os.environ.get('REFINE_TEMPERATURE', 0.6))
+REFINE_MAX_TOKENS = int(os.environ.get('REFINE_MAX_TOKENS', 4096))
+
+# ── Pass@4 LLM-as-judge (grades each diagnostic rollout vs GT) ───────────────
+# Set JUDGE_MODEL='' to disable; otherwise judge runs over every diagnostic round.
+JUDGE_MODEL = os.environ.get('JUDGE_MODEL', 'qwen3.7-max')
+JUDGE_BASE_URL = os.environ.get('JUDGE_BASE_URL', 'https://dashscope.aliyuncs.com/compatible-mode/v1')
+JUDGE_API_KEY = os.environ.get('JUDGE_API_KEY', 'EMPTY')
+JUDGE_TEMPERATURE = float(os.environ.get('JUDGE_TEMPERATURE', 0.3))
+JUDGE_MAX_TOKENS = int(os.environ.get('JUDGE_MAX_TOKENS', 32000))
+JUDGE_MAX_WORKERS = int(os.environ.get('JUDGE_MAX_WORKERS', 16))
+
+
+def build_dataset(backend: SamplerBackend) -> Dataset:
+ """Materialize the local CSV, convert to SFT messages format, run QualityPreprocessor.
+
+ Switched from streaming IterableDataset to in-memory Dataset so HF
+ `Dataset.map(num_proc=N)` can parallelize the QualityPreprocessor pipeline.
+ """
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
+
+ # Custom CSV format (commas inside JSON) — feed framework via callable, not csv loader.
+ meta = DatasetMeta(
+ dataset_id=Path(CSV_PATH).stem,
+ data=partial(_stream_csv_rows, csv_path=CSV_PATH, max_rows=DATASET_TOTAL),
+ )
+ dataset = PackingDataset(meta)
+
+ qp = QualityPreprocessor(
+ pipeline=[
+ ModelFilter(),
+ MessageNormalizer(),
+ HardFilter(
+ min_user_chars_cjk=14, min_user_chars=24,
+ system_deny_keywords=[
+ '角色扮演', '扮演', '人设', 'roleplay', 'role play', 'cosplay',
+ '群聊模拟', '虚拟角色', '二次元', 'OC设定',
+ ],
+ max_rounds=30,
+ ),
+ RefuseFilter(),
+ DeadLoopFilter(),
+ MessageSanityFilter(sensitive_words_file='.temp/sensitive_words.txt'),
+ SpecialCharsFilter(max_ratio=0.6),
+ TokenSoupFilter(max_chars=8000),
+ IntentClassifier(),
+ # ScoreFilter(
+ # template=template,
+ # backend=backend,
+ # scorers=[
+ # ChrMinScorer(),
+ # ],
+ # ),
+ # PIIPresidioFilter(languages=('en', 'zh')),
+ ],
+ dropped_log_path=DROPPED_DATA_PATH,
+ )
+ dataset.map(qp, num_proc=8, load_from_cache_file=True)
+ dataset.map(
+ QualityPreprocessor(pipeline=[DedupFilter()]),
+ num_proc=1,
+ batch_size=len(dataset.dataset),
+ load_from_cache_file=True,
+ )
+
+ print(len(dataset.dataset))
+ dataset.set_template(
+ TEMPLATE_NAME,
+ model_id=MODEL_ID,
+ max_length=MAX_LENGTH,
+ truncation_strategy='delete',
+ enable_thinking=False,
+ )
+ dataset.encode(num_proc=16, load_from_cache_file=True)
+ dataset.pack_dataset()
+ return dataset
+
+
+def save_checkpoint(model: MegatronModel, checkpoint_name: str, dataloader: DataLoader):
+ model.save(
+ checkpoint_name,
+ output_dir=OUTPUT_DIR,
+ adapter_name=ADAPTER_NAME,
+ save_optimizer=True,
+ consumed_train_samples=dataloader.get_state()['consumed_train_samples'],
+ )
+
+
+def train():
+ # ── Ray mode: GPUs 0-3 for training, GPUs 4-7 for vLLMSampler ────────────
+ device_groups = [
+ DeviceGroup(name='model', ranks=list(range(MODEL_GPUS)), device_type='GPU'),
+ # DeviceGroup(name='sampler', ranks=list(range(MODEL_GPUS, NUM_GPUS)), device_type='GPU', gpus_per_worker=2),
+ ]
+ model_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=1, cp_size=8)
+ # sampler_mesh = DeviceMesh.from_sizes(world_size=SAMPLER_GPUS, dp_size=SAMPLER_GPUS // 2, tp_size=2)
+ twinkle.initialize(mode='local', nproc_per_node=NUM_GPUS, groups=device_groups,
+ global_device_mesh=model_mesh, lazy_collect=False)
+
+ # ── vLLMSampler on GPUs 4-7 (Ray actor, no HTTP overhead) ────────────────
+ # sampler = vLLMSampler(
+ # model_id=MODEL_ID,
+ # engine_args={
+ # 'gpu_memory_utilization': 0.6,
+ # 'max_model_len': MAX_LENGTH,
+ # },
+ # device_mesh=sampler_mesh,
+ # remote_group='sampler',
+ # )
+ # sampler.set_template(TEMPLATE_NAME, model_id=MODEL_ID)
+ # backend = SamplerBackend(sampler)
+ # logger.info(f'vLLMSampler ready on GPUs {MODEL_GPUS}-{NUM_GPUS - 1}')
+
+ # ── Dataset with full QualityPreprocessor (uses SamplerBackend) ───────────
+ dataset = build_dataset(None)
+ dataloader = DataLoader(
+ dataset=dataset,
+ batch_size=BATCH_SIZE,
+ )
+
+ # ── Model (LoRA on 4 GPUs) ────────────────────────────────────────────────
+ model = MegatronModel(
+ model_id=MODEL_ID,
+ device_mesh=model_mesh,
+ # remote_group='model',
+ # attn_implementation='flash_attention_2',
+ )
+
+ lora_config = LoraConfig(r=16, lora_alpha=32, target_modules='all-linear')
+ model.add_adapter_to_model(
+ ADAPTER_NAME, lora_config,
+ gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS)
+ model.set_optimizer(optimizer_cls='default', lr=LEARNING_RATE)
+ model.set_lr_scheduler(
+ scheduler_cls='default',
+ lr_warmup_steps=2,
+ lr_decay_steps=len(dataloader))
+
+ logger.info(get_device_placement())
+ logger.info(model.get_train_configs())
+ logger.info(f'Total steps: {NUM_STEPS}, model GPUs: {MODEL_GPUS}, sampler GPUs: {SAMPLER_GPUS}')
+
+ for cur_step, batch in enumerate(dataloader):
+ model.forward_backward(inputs=batch)
+ model.clip_grad_and_step()
+
+ if cur_step % LOG_INTERVAL == 0:
+ metric = model.calculate_metric(is_training=True)
+ logger.info(f'Step {cur_step}/{NUM_STEPS}, metric: {metric}')
+
+ if cur_step % SAVE_INTERVAL == 0:
+ save_checkpoint(model, f'step-{cur_step}', dataloader)
+
+ if cur_step >= NUM_STEPS:
+ break
+
+ save_checkpoint(model, 'last-checkpoint', dataloader)
+ logger.info(f'Training complete. Trained data saved to: {TRAINED_DATA_PATH}')
+ logger.info(f'Dropped data saved to: {DROPPED_DATA_PATH}')
+
+
+if __name__ == '__main__':
+ train()
diff --git a/cookbook/exp/condenser/dataset.py b/cookbook/exp/condenser/dataset.py
new file mode 100644
index 000000000..32c30de4b
--- /dev/null
+++ b/cookbook/exp/condenser/dataset.py
@@ -0,0 +1,459 @@
+import hashlib
+import json
+import os
+import re
+from pathlib import Path
+from typing import Any, Dict, List, Optional
+from datasets import Features, Value
+from modelscope import dataset_snapshot_download
+
+from twinkle.dataset import Dataset, DatasetMeta
+from twinkle.preprocessor import Preprocessor
+
+_TARGET_FEATURES = Features({
+ 'id': Value('string'),
+ 'source': Value('string'),
+ 'messages': [{'role': Value('string'), 'content': Value('string')}],
+})
+
+
+def _hash_id(prefix: str, content: str) -> str:
+ """Stable id from MD5 of content; collision-free for textual datasets."""
+ return f'{prefix}__{hashlib.md5(content.encode("utf-8")).hexdigest()[:16]}'
+
+
+def _register(dataset, processor_cls, meta: DatasetMeta, init_args: Optional[Dict[str, Any]] = None,
+ load_from_cache_file: bool = True) -> None:
+ """Add dataset and run preprocessor; auto-strip every input column to enforce
+ the universal ``{id, source, messages}`` output schema."""
+ dataset.add_dataset(meta)
+ cols = list(dataset.datasets[meta.get_id()].column_names)
+ dataset.map(
+ processor_cls,
+ dataset_meta=meta,
+ init_args=init_args or {},
+ remove_columns=cols,
+ load_from_cache_file=load_from_cache_file,
+ features=_TARGET_FEATURES,
+ )
+
+
+# ===== MuSiQue =====
+MUSIQUE_REPO = 'voidful/MuSiQue'
+
+
+class MusiqueProcessor(Preprocessor):
+ """MuSiQue raw row → multiple ``{id, source, messages}`` rows, one per paragraph."""
+
+ def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
+ rows = self.map_col_to_row(rows)
+ out: List[Dict[str, Any]] = []
+ for row in rows:
+ if row.get('answerable') is False:
+ continue
+ parent = str(row.get('id', ''))
+ for idx, p in enumerate(row.get('paragraphs') or []):
+ text = (p.get('paragraph_text') or '').strip()
+ if not text:
+ continue
+ out.append({
+ 'id': f'musique__{parent}__{idx}',
+ 'source': 'musique',
+ 'messages': [{'role': 'assistant', 'content': text}],
+ })
+ return self.map_row_to_col(out, keys=['id', 'source', 'messages'])
+
+
+# Repo 仅含原始 JSONL 无 HF 元数据,必须先快照下载再以文件路径注册。
+_musique_jsonl = Path(dataset_snapshot_download(MUSIQUE_REPO)) / 'musique_ans_v1.0_train.jsonl'
+if not _musique_jsonl.is_file():
+ raise FileNotFoundError(f'MuSiQue raw file not found: {_musique_jsonl}')
+
+
+# ===== swift/github-code =====
+GITHUB_CODE_REPO = 'ms://swift/github-code'
+
+
+class GithubCodeProcessor(Preprocessor):
+ """github-code row → ``{id, source, messages}``;按代码长度均匀采样。
+
+ 把 ``[length_min, length_max)`` 切 ``n_buckets`` 桶,每桶配额 ``target/n_buckets``,
+ 桶满或超界即丢;近似得到 ``target`` 条且长度均匀分布的样本。
+ 依赖 batched map 单进程下实例状态跨 batch 共享(``num_proc>1`` 会失效)。
+ """
+
+ def __init__(self, target: int = 30000, length_min: int = 500,
+ length_max: int = 40000, n_buckets: int = 30):
+ self.length_min = length_min
+ self.length_max = length_max
+ self.n_buckets = n_buckets
+ self.bucket_quota = max(1, target // n_buckets)
+ self.bucket_count = [0] * n_buckets
+
+ def _bucket(self, n: int) -> int:
+ if n < self.length_min or n >= self.length_max:
+ return -1
+ idx = int((n - self.length_min) / (self.length_max - self.length_min) * self.n_buckets)
+ return min(idx, self.n_buckets - 1)
+
+ def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
+ rows = self.map_col_to_row(rows)
+ out: List[Dict[str, Any]] = []
+ for row in rows:
+ code = row.get('code') or ''
+ if not isinstance(code, str):
+ continue
+ b = self._bucket(len(code))
+ if b < 0 or self.bucket_count[b] >= self.bucket_quota:
+ continue
+ self.bucket_count[b] += 1
+ lang = row.get('language') or 'unknown'
+ out.append({
+ 'id': _hash_id(f'github_code__{lang}', code),
+ 'source': 'github-code',
+ 'messages': [{'role': 'assistant', 'content': code}],
+ })
+ return self.map_row_to_col(out, keys=['id', 'source', 'messages'])
+
+
+# ===== modelscope/competition_math =====
+COMPETITION_MATH_REPO = 'ms://modelscope/competition_math'
+
+
+class MathProcessor(Preprocessor):
+ """competition_math row → ``{id, source, messages}`` (user/assistant pair)."""
+
+ def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
+ rows = self.map_col_to_row(rows)
+ out: List[Dict[str, Any]] = []
+ for row in rows:
+ problem = (row.get('problem') or '').strip()
+ solution = (row.get('solution') or '').strip()
+ if not problem or not solution:
+ continue
+ out.append({
+ 'id': _hash_id('math', f'{problem}\n{solution}'),
+ 'source': 'competition_math',
+ 'messages': [
+ {'role': 'assistant', 'content': solution},
+ ],
+ })
+ return self.map_row_to_col(out, keys=['id', 'source', 'messages'])
+
+
+# ===== nampdn-ai/tiny-textbooks =====
+TINY_TEXTBOOKS_REPO = 'ms://AI-ModelScope/tiny-textbooks'
+
+
+class TinyTextbooksProcessor(Preprocessor):
+ """tiny-textbooks row → ``{id, source, messages}`` (user/assistant pair)."""
+
+ def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
+ rows = self.map_col_to_row(rows)
+ out: List[Dict[str, Any]] = []
+ for row in rows:
+ text = (row.get('text') or '').strip()
+ textbook = (row.get('textbook') or '').strip()
+ if not text or not textbook:
+ continue
+ out.append({
+ 'id': _hash_id('tinytb', f'{text}\n{textbook}'),
+ 'source': 'tiny-textbooks',
+ 'messages': [
+ {'role': 'assistant', 'content': textbook},
+ ],
+ })
+ return self.map_row_to_col(out, keys=['id', 'source', 'messages'])
+
+
+# ===== Passage Explosion for Compression Distillation =====
+# Each message content >= threshold becomes a standalone row: messages=[{role:user, content:X}]
+
+_MIN_PASSAGE_LEN = 500 # CJK-equivalent units
+
+
+def _effective_len(text: str) -> int:
+ """CJK chars count double; threshold 500 ≈ 500 Chinese chars ≈ 1000 Latin chars."""
+ cjk = sum(1 for c in text if '\u4e00' <= c <= '\u9fff' or '\u3000' <= c <= '\u303f')
+ return cjk * 2 + (len(text) - cjk)
+
+
+def _extract_content(msg: dict) -> str:
+ """Extract text content from a message dict, handling multimodal list-content."""
+ content = msg.get('content')
+ if isinstance(content, list):
+ content = '\n'.join(
+ p.get('text', '') if isinstance(p, dict) else str(p) for p in content)
+ if not isinstance(content, str):
+ return ''
+ return content.strip()
+
+
+class PassageExplodeProcessor(Preprocessor):
+ """Explode multi-turn messages into individual long passages for compression distillation."""
+
+ def __init__(self, source: str):
+ self.source = source
+
+ def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
+ rows = self.map_col_to_row(rows)
+ out: List[Dict[str, Any]] = []
+ for row in rows:
+ messages = row.get('messages')
+ if isinstance(messages, str):
+ try:
+ messages = json.loads(messages)
+ except (ValueError, TypeError):
+ continue
+ if not isinstance(messages, list):
+ continue
+ for msg in messages:
+ if not isinstance(msg, dict):
+ continue
+ role = msg.get('role') or ''
+ if role == 'system':
+ continue
+ content = _extract_content(msg)
+ if not content or _effective_len(content) < _MIN_PASSAGE_LEN:
+ continue
+ out.append({
+ 'id': _hash_id(self.source, content),
+ 'source': self.source,
+ 'messages': [{'role': 'assistant', 'content': content}],
+ })
+ return self.map_row_to_col(out, keys=['id', 'source', 'messages'])
+
+
+# ===== Reasoning / CoT datasets — explode query and assistant separately =====
+_THINK_RE = re.compile(r'(.*?)', re.DOTALL)
+
+
+class CotExplodeProcessor(Preprocessor):
+ """Base for CoT datasets: explode query and full assistant content as separate passages."""
+
+ def _extract_rows(self, rows: List[Dict[str, Any]]) -> List[tuple]:
+ """Subclass returns list of (query, cot, response) tuples."""
+ raise NotImplementedError
+
+ def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
+ rows_list = self.map_col_to_row(rows)
+ out: List[Dict[str, Any]] = []
+ for query, cot, response, source in self._extract_rows(rows_list):
+ if cot:
+ response = _THINK_RE.sub('', response).strip()
+ assistant_content = f'{cot}{response}' if cot else response
+ for text in (query, assistant_content):
+ if not text or _effective_len(text) < _MIN_PASSAGE_LEN:
+ continue
+ out.append({
+ 'id': _hash_id(source, text),
+ 'source': source,
+ 'messages': [{'role': 'assistant', 'content': text}],
+ })
+ return self.map_row_to_col(out, keys=['id', 'source', 'messages'])
+
+
+# -- Chinese-DeepSeek-R1-Distill-data-110k --
+CN_R1_DISTILL_REPO = 'ms://AI-ModelScope/Chinese-DeepSeek-R1-Distill-data-110k'
+
+
+class ChineseR1DistillProcessor(CotExplodeProcessor):
+ """input → query, reasoning_content → cot, content → response."""
+
+ def _extract_rows(self, rows):
+ for row in rows:
+ query = (row.get('input') or '').strip()
+ cot = (row.get('reasoning_content') or '').strip()
+ response = (row.get('content') or '').strip()
+ if not query or not response:
+ continue
+ yield query, cot, response, 'Chinese-DeepSeek-R1-Distill-data-110k'
+
+
+# -- Opus-4.6-Reasoning-3000x-filtered --
+OPUS_REASONING_REPO = 'ms://nohurry/Opus-4.6-Reasoning-3000x-filtered'
+
+
+class OpusReasoningProcessor(CotExplodeProcessor):
+ """problem → query, thinking → cot, solution → response."""
+
+ def _extract_rows(self, rows):
+ for row in rows:
+ query = (row.get('problem') or '').strip()
+ cot = (row.get('thinking') or '').strip()
+ response = (row.get('solution') or '').strip()
+ if not query or not response:
+ continue
+ yield query, cot, response, 'Opus-4.6-Reasoning-3000x-filtered'
+
+
+# -- claude-opus-4.6-10000x --
+CLAUDE_OPUS_REPO = 'ms://Roman1111111/claude-opus-4.6-10000x'
+
+
+class ClaudeOpusProcessor(CotExplodeProcessor):
+ """messages (OpenAI format) → extract user/assistant, split or reasoning field."""
+
+ def _extract_rows(self, rows):
+ for row in rows:
+ messages = row.get('messages')
+ if not isinstance(messages, list):
+ continue
+ query = ''
+ assistant_text = ''
+ reasoning = ''
+ for msg in messages:
+ if not isinstance(msg, dict):
+ continue
+ role = msg.get('role') or ''
+ content = msg.get('content') or ''
+ if not isinstance(content, str):
+ continue
+ if role == 'user' and not query:
+ query = content.strip()
+ elif role == 'assistant' and not assistant_text:
+ assistant_text = content.strip()
+ reasoning = (msg.get('reasoning') or '').strip()
+ break
+ if not query or not assistant_text:
+ continue
+ cot = reasoning
+ if not cot:
+ m = _THINK_RE.search(assistant_text)
+ if m:
+ cot = m.group(1).strip()
+ assistant_text = assistant_text[m.end():].strip()
+ response = assistant_text if not reasoning else _THINK_RE.sub('', assistant_text).strip()
+ if not response:
+ continue
+ yield query, cot, response, 'claude-opus-4.6-10000x'
+
+
+# -- angrygiraffe-claude-opus-4.6-4.7-reasoning-8.7k --
+ANGRYGIRAFFE_REPO = 'ms://hf/angrygiraffe-claude-opus-4.6-4.7-reasoning-8.7k'
+
+
+class AngrygiraffeOpusReasoningProcessor(CotExplodeProcessor):
+ """messages (OpenAI format) → extract first user/assistant, split tag."""
+
+ def _extract_rows(self, rows):
+ for row in rows:
+ messages = row.get('messages')
+ if not isinstance(messages, list):
+ continue
+ query = ''
+ assistant_text = ''
+ for msg in messages:
+ if not isinstance(msg, dict):
+ continue
+ role = msg.get('role') or ''
+ content = msg.get('content') or ''
+ if not isinstance(content, str):
+ continue
+ if role == 'user' and not query:
+ query = content.strip()
+ elif role == 'assistant' and not assistant_text:
+ assistant_text = content.strip()
+ break
+ if not query or not assistant_text:
+ continue
+ m = _THINK_RE.search(assistant_text)
+ if m:
+ cot = m.group(1).strip()
+ response = assistant_text[m.end():].strip()
+ else:
+ cot = ''
+ response = assistant_text
+ if not response:
+ continue
+ yield query, cot, response, 'angrygiraffe-claude-opus-4.6-4.7-reasoning-8.7k'
+
+
+_BASE_SIZES = {
+ 'tiny_textbooks': 10000,
+ 'musique': 1000,
+ 'github_code': 30000,
+ 'competition_math': 7500,
+ 'toucan': 10000,
+ 'swe_smith': 1000,
+ 'cn_r1_distill': 10000,
+ 'opus_reasoning': 3000,
+ 'claude_opus': 10000,
+ 'angrygiraffe': 20000,
+}
+
+
+def _scaled_sizes(total: Optional[int]) -> Dict[str, int]:
+ if total is None:
+ return dict(_BASE_SIZES)
+ scale = total / sum(_BASE_SIZES.values())
+ return {k: max(1, int(round(v * scale))) for k, v in _BASE_SIZES.items()}
+
+
+def get_dataset(total: Optional[int] = None, load_from_cache_file: bool = True) -> Dataset:
+ """Build the unified compression-distillation dataset.
+
+ If ``total`` is given, every per-source row count in ``_BASE_SIZES`` is
+ scaled proportionally so the input-row sum approximates ``total``.
+ """
+ sizes = _scaled_sizes(total)
+ dataset = Dataset()
+
+ _register(dataset, TinyTextbooksProcessor,
+ DatasetMeta(dataset_id=TINY_TEXTBOOKS_REPO, split='train',
+ data_slice=range(sizes['tiny_textbooks'])),
+ load_from_cache_file=load_from_cache_file)
+
+ _register(dataset, MusiqueProcessor,
+ DatasetMeta(str(_musique_jsonl), data_slice=range(sizes['musique'])),
+ load_from_cache_file=load_from_cache_file)
+
+ _register(dataset, GithubCodeProcessor,
+ DatasetMeta(dataset_id=GITHUB_CODE_REPO, subset_name='all-apache-2.0', split='train'),
+ init_args={'target': sizes['github_code']},
+ load_from_cache_file=load_from_cache_file)
+
+ _register(dataset, MathProcessor,
+ DatasetMeta(dataset_id=COMPETITION_MATH_REPO, subset_name='default', split='train',
+ data_slice=range(sizes['competition_math'])),
+ load_from_cache_file=load_from_cache_file)
+
+ _register(dataset, PassageExplodeProcessor,
+ DatasetMeta(dataset_id='ms://Agent-Ark/Toucan-1.5M', subset_name='Kimi-K2', split='train',
+ data_slice=range(sizes['toucan'])),
+ init_args={'source': 'toucan'},
+ load_from_cache_file=load_from_cache_file)
+
+ _register(dataset, PassageExplodeProcessor,
+ DatasetMeta(dataset_id='ms://SWE-bench/SWE-smith-trajectories', split='tool',
+ data_slice=range(sizes['swe_smith'])),
+ init_args={'source': 'swe-smith'},
+ load_from_cache_file=load_from_cache_file)
+
+ _register(dataset, ChineseR1DistillProcessor,
+ DatasetMeta(dataset_id=CN_R1_DISTILL_REPO, split='train',
+ data_slice=range(sizes['cn_r1_distill'])),
+ load_from_cache_file=load_from_cache_file)
+
+ _register(dataset, OpusReasoningProcessor,
+ DatasetMeta(dataset_id=OPUS_REASONING_REPO, split='train',
+ data_slice=range(sizes['opus_reasoning'])),
+ load_from_cache_file=load_from_cache_file)
+
+ _register(dataset, ClaudeOpusProcessor,
+ DatasetMeta(dataset_id=CLAUDE_OPUS_REPO, split='train',
+ data_slice=range(sizes['claude_opus'])),
+ load_from_cache_file=load_from_cache_file)
+
+ _register(dataset, AngrygiraffeOpusReasoningProcessor,
+ DatasetMeta(dataset_id=ANGRYGIRAFFE_REPO, split='train',
+ data_slice=range(sizes['angrygiraffe'])),
+ load_from_cache_file=load_from_cache_file)
+
+ dataset.mix_dataset(False)
+ return dataset
+
+
+if __name__ == '__main__':
+ dataset = get_dataset(load_from_cache_file=True)
+ print(len(dataset))
diff --git a/cookbook/exp/condenser/make_condenser_dataset.py b/cookbook/exp/condenser/make_condenser_dataset.py
new file mode 100644
index 000000000..cf56a44e3
--- /dev/null
+++ b/cookbook/exp/condenser/make_condenser_dataset.py
@@ -0,0 +1,737 @@
+import argparse
+import hashlib
+import json
+import os
+import random
+import re
+import sys
+import threading
+from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait
+from typing import Any, Dict, Iterator, List, Optional, Set
+
+from tqdm import tqdm
+
+from twinkle.data_format.sampling import SamplingParams
+from twinkle_agentic.protocol.openai import OpenAI
+
+
+# ═══════════════════════════════════════════════════════════════════════════════
+# Prompts
+# ═══════════════════════════════════════════════════════════════════════════════
+
+QUERY_GEN_SYSTEM = """\
+You are a query designer. Given a source passage, enumerate distinct information \
+queries a reader might ask of it. Each query must steer toward a meaningfully \
+DIFFERENT compression of the same source — different facets, not rephrasings of \
+the same need.
+
+Category hints (not exhaustive — combine or invent as fits the source):
+- Interface extraction (code): class / method signatures, parameter and return types
+- Functional summary: what the passage accomplishes at a high level
+- Error & pitfall analysis: bugs, anti-patterns, failure modes, edge cases
+- Experience distillation: lessons learned, best practices, do's and don'ts
+- Skill extraction (knowledge-as-skill): WHAT this passage lets you do, HOW to \
+apply it as reusable steps, WHEN to invoke it (trigger conditions / use cases)
+- Abstract analysis: design patterns, architectural decisions, trade-offs
+- Information summary: key facts, entities, numbers, relationships
+- Dependency & context: prerequisites, imports, environment, related modules
+
+Rules:
+1. SHAPE — each query is one short imperative or interrogative sentence (e.g. \
+"List all public method signatures with parameter and return types", "What race \
+conditions does this code contain?").
+2. DISTINCT — reject any pair whose answers would substantially overlap; \
+rephrasings of the same information need do NOT count as separate queries.
+3. SKILL FOR KNOWLEDGE — when the source reads as tutorial / experience / \
+how-to / domain knowledge, ALWAYS include exactly one skill-style query asking \
+what the reader can accomplish with it and how to apply it (phrased in the \
+source language).
+4. ANSWERABLE — skip queries the source cannot actually answer, and skip \
+trivial queries that would just reproduce the source verbatim.
+5. SCALE — short / single-purpose → 1; medium → 2; rich / multi-topic → 3–4. \
+Do not pad.
+6. LANGUAGE — query language MUST match the source language.
+7. OUTPUT — a single JSON array of strings; no preamble, no code fences, \
+nothing else.\
+"""
+
+QUERY_GEN_USER = 'Analyze the following text and return a JSON array of queries.\n\n{text}'
+
+COMPRESS_SYSTEM = """\
+You are a compression assistant. For the (query, source) pair, emit a Markdown \
+answer with TWO sections, designed to pair with the `extract_compressed` tool: \
+the reader absorbs `## Summary` directly, then calls `extract_compressed` \
+on any topic-key listed under `## More` to recover its \
+fuller content.
+
+ `## Summary` — extreme-density text the reader reads directly.
+ `## More` — a topic index whose keys are valid arguments \
+to `extract_compressed` for recovering material not captured inline.
+
+Together the two sections must form a COMPLETE, NON-DISTORTING inventory of the \
+source for the query — nothing essential lost, nothing implied that the source \
+does not support. NO preamble, NO meta-commentary, NO code fences wrapping the \
+whole output.
+
+Output skeleton:
+
+## Summary
+Topic:
+
+
+## More
+- :
+- ...
+
+Format selection for the inline body (pick the MOST COMPACT form per query, mix \
+when helpful):
+- Interface / signature → code notation directly: `func(a:int)->str`
+- Factual / entity → telegraphic prose; drop function words; ":" for "is", "," \
+for "has"
+- Skill / how-to / usage → lead with `Use when: `; numbered telegraphic \
+steps `1.do X 2.then Y`; close with `Output: ` when relevant
+- Procedural → numbered short steps
+- Analytical / design → hierarchical bullets with abbreviations
+
+`## Summary` rules:
+1. TOPIC LINE — line 1 is ALWAYS `Topic: `, even when the \
+query is narrow. Anchors both the reader and the tool.
+2. DENSITY — every token in the body carries query-relevant signal; cut filler.
+3. PRIMARY-COMPLETE — never silently drop a fact essential to answering the \
+query. Anything cut for length MUST appear as a key under \
+`## More`.
+4. NON-MISLEADING — phrasing must not let the reader infer anything the source \
+does not support; partial truths that mislead are worse than honest omissions \
+flagged in the index.
+5. SELF-CONTAINED — the reader can act on the answer without re-opening the source.
+6. FAITHFUL — only content the source supports; no fabrication, no extrapolation.
+7. LANGUAGE — match the source language.
+8. NO outer code fences around the whole answer; no meta-commentary.
+
+`## More` rules (MANDATORY — this section is never omitted):
+1. FORMAT — each bullet is `- : `:
+ • topic-key — short, unambiguous, grounded in source vocabulary so the \
+`extract_compressed` tool can locate the aspect (e.g. `decorators`, \
+`error handling`, `pitfalls`).
+ • hint — tells WHAT the reader gains by expanding (concrete numbers, code \
+listings, secondary cases, edge details, related context, …); do NOT restate \
+the inline answer.
+2. CRITERION — each bullet names an aspect that EXISTS in the source but is \
+NOT fully captured inline. Material that genuinely fits inline without \
+distortion MUST NOT be duplicated here.
+3. FAITHFUL — hints must be grounded in the source; never speculate or invent.
+4. ORDER — by relevance to the query, then by importance.
+5. EMPTY CASE — if the source is so short / single-purpose that everything \
+fits inline, write a single line `- (none)`.
+
+Examples:
+
+Query: List all public method signatures with parameter and return types
+Source: (a Python HTTP client class with retry decorator, structured logging, \
+and request helpers)
+## Summary
+Topic: Python HTTP client class — public surface of retried request helpers.
+retry_request(url:str, max_retries:int=3, timeout:float=10.0) -> Response
+fetch_json(endpoint:str, params:dict|None=None) -> dict
+post_data(endpoint:str, payload:dict, headers:dict|None=None) -> Response
+
+## More
+- decorators: @retry config — exponential backoff (base=2.0, max=60s)
+- logging: structured per-request logs with request_id and latency_ms
+- private helpers: _build_headers, _parse_error — not in public surface
+───
+Query: What can this passage help you accomplish, and how to use it?
+Source: (a tutorial on configuring Linux cgroups v2 caps for a systemd service)
+## Summary
+Topic: Linux cgroups v2 — per-service CPU / memory caps via systemd slice units.
+Use when: needing per-service CPU/memory caps on systemd hosts.
+1.create slice unit /etc/systemd/system/.slice with CPUQuota=, MemoryMax=
+2.attach service via Slice=.slice in [Service]
+3.systemctl daemon-reload + restart service
+4.verify: systemctl status shows Tasks/CPU/Memory inside slice
+Output: hard caps enforced by kernel cgroup v2.
+
+## More
+- pitfalls: cgroup v1/v2 mode detection, MemorySwapMax behavior on OOM
+- delegation: Delegate=yes for nested controllers in container managers
+- examples: nginx and postgres slice templates with concrete numeric caps
+- diagnostics: systemd-cgls / systemd-cgtop walkthrough
+───
+Query: 总结这段代码的错误和改进经验
+Source: (一段有 race condition 和未关闭资源的 Go 代码)
+## Summary
+Topic: Go HTTP fetch 循环 — 并发写共享 map + 未关闭响应体导致的稳定性缺陷。
+1.race: 并发写 map 未锁 → sync.RWMutex 或 sync.Map
+2.泄漏: resp.Body 未 Close → 请求后立即 defer resp.Body.Close()
+3.吞错: err 未检查 → 每处 err!=nil 必处理或上抛
+
+## More
+- (none)
+
+Now begin.\
+"""
+
+COMPRESS_USER = '## Query\n{query}\n\n## Source\n{text}'
+
+# Short system prompt embedded in emitted SFT samples — the long COMPRESS_SYSTEM
+# is for data generation only; training samples carry only the binding contract.
+COMPRESS_SYSTEM_TRAIN = """\
+You are a compression assistant. For the (query, source) pair, emit a Markdown \
+answer with TWO sections, designed to pair with the `extract_compressed` tool: \
+the reader absorbs `## Summary` directly, then calls `extract_compressed` \
+on any topic-key listed under `## More` to recover its \
+fuller content.
+
+Output skeleton:
+
+## Summary
+Topic:
+
+
+## More
+- :
+- ...
+
+Rules:
+1. Line 1 of `## Summary` is ALWAYS `Topic: ...`.
+2. Body is maximally dense; every token carries query-relevant signal.
+3. Never silently drop a fact — anything cut for length MUST appear as a key \
+under `## More` (do not duplicate inline material here).
+4. No fabrication, no extrapolation, no misleading partial truths.
+5. Match the source language. No outer code fences, no meta-commentary.\
+"""
+
+# Fixed queries — used directly (no Phase-1 LLM generation) for a proportion of items.
+FIXED_QUERY_NEED = (
+ 'What problem does this passage address, and what skill or method is needed? '
+ 'Topic must name the specific pattern, never generic labels. '
+ 'Compress into a retrieval-friendly need description.')
+FIXED_QUERY_SKILL = (
+ 'Extract the reusable skill: trigger conditions, key steps, and expected output. '
+ 'Topic names the method/pattern; format as "Use when: ...", numbered steps, '
+ '"Output: ...". Compress into a standardized procedure for retrieval.')
+FIXED_QUERIES = [FIXED_QUERY_NEED, FIXED_QUERY_SKILL]
+FIXED_QUERY_RATIO = 0.3
+
+
+# ═══════════════════════════════════════════════════════════════════════════════
+# Core logic
+# ═══════════════════════════════════════════════════════════════════════════════
+
+def _extract_json_array(text: str) -> Optional[List[str]]:
+ """Best-effort extraction of a JSON string array from LLM output."""
+ text = text.strip()
+ # Try direct parse first
+ if text.startswith('['):
+ try:
+ arr = json.loads(text)
+ if isinstance(arr, list) and all(isinstance(x, str) for x in arr):
+ return arr
+ except json.JSONDecodeError:
+ pass
+ # Fallback: find first [...] block
+ m = re.search(r'\[.*\]', text, re.DOTALL)
+ if m:
+ try:
+ arr = json.loads(m.group())
+ if isinstance(arr, list) and all(isinstance(x, str) for x in arr):
+ return arr
+ except json.JSONDecodeError:
+ pass
+ return None
+
+
+def generate_queries(api: OpenAI, text: str) -> List[str]:
+ """Phase 1: ask the LLM what queries can be asked about ``text``."""
+ trajectory = {
+ 'messages': [
+ {'role': 'system', 'content': QUERY_GEN_SYSTEM},
+ {'role': 'user', 'content': QUERY_GEN_USER.format(text=text)},
+ ]
+ }
+ sp = SamplingParams(temperature=0.7, max_tokens=1024)
+ for attempt in range(2):
+ try:
+ reply = api(trajectory, sp, extra_body={'enable_thinking': True})
+ except Exception as exc:
+ sys.stderr.write(f'[query_gen] error: {exc}\n')
+ return []
+ content = reply.get('content') or ''
+ queries = _extract_json_array(content)
+ if queries:
+ return queries
+ if attempt == 0:
+ sys.stderr.write('[query_gen] retry: failed to parse JSON array\n')
+ return []
+
+
+def compress_for_query(api: OpenAI, text: str, query: str,
+ thinking_budget: int = 1024) -> Optional[str]:
+ """Phase 2: compress ``text`` w.r.t. ``query``. Returns compressed content or None."""
+ trajectory = {
+ 'messages': [
+ {'role': 'system', 'content': COMPRESS_SYSTEM},
+ {'role': 'user', 'content': COMPRESS_USER.format(query=query, text=text)},
+ ]
+ }
+ sp = SamplingParams(temperature=0.3, max_tokens=16384)
+ for attempt in range(2):
+ try:
+ reply = api(trajectory, sp, extra_body={
+ 'enable_thinking': False,
+ 'thinking_budget': thinking_budget,
+ })
+ except Exception as exc:
+ sys.stderr.write(f'[compress] error: {exc}\n')
+ return None
+ content = (reply.get('content') or '').strip()
+ if not content:
+ if attempt == 0:
+ sys.stderr.write('[compress] retry: empty response\n')
+ continue
+ # Strip whole-answer code fence if present.
+ m = re.match(r'^```[a-zA-Z]*\n(.*?)\n```\s*$', content, re.DOTALL)
+ if m:
+ content = m.group(1).strip()
+ if not (re.search(r'(?im)^##\s*Summary\b', content)
+ and re.search(r'(?im)^##\s*More\b', content)):
+ if attempt == 0:
+ sys.stderr.write('[compress] retry: missing required sections\n')
+ continue
+ return content
+ return None
+
+
+def _query_hash(query: str) -> str:
+ """Stable short hash of a query string — embedded in sample id for resume."""
+ return hashlib.md5(query.strip().encode('utf-8')).hexdigest()[:8]
+
+
+def process_item(
+ api: OpenAI,
+ item: Dict[str, Any],
+ done_sample_ids: Optional[Set[str]] = None,
+ thinking_budget: int = 1024,
+ fixed_query_ratio: float = FIXED_QUERY_RATIO,
+) -> List[Dict[str, Any]]:
+ """Run both phases on one dataset item. Returns list of SFT samples.
+
+ Input rows come from ``dataset.py`` (single assistant message) or
+ ``dataset_think.py`` (user query + assistant with reasoning_content).
+ For thinking-data rows, ``FIXED_QUERY_NEED`` is applied to the query
+ and ``FIXED_QUERY_SKILL`` to the CoT, skipping Phase-1 generation.
+
+ ``done_sample_ids`` (full sample ids already on disk for this item)
+ lets resume skip queries that were already emitted, keyed by query
+ content hash so a phase-1 reorder still resolves correctly.
+ """
+ done = done_sample_ids or set()
+ messages = item.get('messages') or []
+
+ # Detect thinking-data: user message + assistant with reasoning_content
+ user_query = ''
+ cot_text = ''
+ assistant_text = ''
+ for m in messages:
+ if not isinstance(m, dict):
+ continue
+ role = m.get('role', '')
+ if role == 'user' and not user_query:
+ user_query = (m.get('content') or '').strip()
+ elif role == 'assistant':
+ cot_text = (m.get('reasoning_content') or '').strip()
+ assistant_text = (m.get('content') or '').strip()
+ break
+
+ item_id = item.get('id')
+ if not item_id:
+ return []
+ source = item.get('source', 'unknown')
+
+ # Thinking-data path: compress query and CoT separately with fixed queries
+ if user_query and cot_text:
+ pairs = [(user_query, FIXED_QUERY_NEED), (cot_text, FIXED_QUERY_SKILL)]
+ samples: List[Dict[str, Any]] = []
+ for text, query in pairs:
+ if len(text) < 100:
+ continue
+ sample_id = f'{item_id}__{_query_hash(query)}'
+ if sample_id in done:
+ continue
+ compressed = compress_for_query(api, text, query, thinking_budget=thinking_budget)
+ if not compressed:
+ continue
+ sft_messages = [
+ {'role': 'system', 'content': COMPRESS_SYSTEM_TRAIN},
+ {'role': 'user', 'content': COMPRESS_USER.format(query=query, text=text)},
+ {'role': 'assistant', 'content': compressed},
+ ]
+ samples.append({
+ 'id': sample_id,
+ 'source': source,
+ 'query': query,
+ 'original_len': len(text),
+ 'compressed_len': len(compressed),
+ 'original_tokens': 0,
+ 'compressed_tokens': 0,
+ 'messages': sft_messages,
+ '__src': text,
+ '__cmp': compressed,
+ })
+ return samples
+
+ # Plain-data path: single assistant message
+ text = assistant_text
+ if not text or len(text) < 100:
+ return []
+
+ queries = generate_queries(api, text)
+ if not queries:
+ return []
+ queries = queries[:2]
+
+ # Mix in fixed queries for a proportion of items
+ if random.random() < fixed_query_ratio:
+ queries = list(FIXED_QUERIES)
+
+ samples: List[Dict[str, Any]] = []
+ for query in queries:
+ sample_id = f'{item_id}__{_query_hash(query)}'
+ if sample_id in done:
+ continue
+ compressed = compress_for_query(api, text, query, thinking_budget=thinking_budget)
+ if not compressed:
+ continue
+ sft_messages = [
+ {'role': 'system', 'content': COMPRESS_SYSTEM_TRAIN},
+ {'role': 'user', 'content': COMPRESS_USER.format(query=query, text=text)},
+ {'role': 'assistant', 'content': compressed},
+ ]
+ samples.append({
+ 'id': sample_id,
+ 'source': source,
+ 'query': query,
+ 'original_len': len(text),
+ 'compressed_len': len(compressed),
+ 'original_tokens': 0,
+ 'compressed_tokens': 0,
+ 'messages': sft_messages,
+ # Stashed for sparse tokenization on main thread; popped before write.
+ '__src': text,
+ '__cmp': compressed,
+ })
+ return samples
+
+
+def process_failure(
+ api: OpenAI,
+ item: Dict[str, Any],
+ thinking_budget: int = 1024,
+) -> List[Dict[str, Any]]:
+ """Re-compress a single failure record (id, query, text already pinned).
+
+ Used by ``--failures`` mode: query and source passage are taken verbatim
+ from the original failure entry, so Phase-1 generation is skipped and the
+ output id matches the original sample id.
+ """
+ sid = item.get('id') or ''
+ query = (item.get('query') or '').strip()
+ text = (item.get('text') or '').strip()
+ if not sid or not query or not text:
+ return []
+ compressed = compress_for_query(api, text, query, thinking_budget=thinking_budget)
+ if not compressed:
+ return []
+ sft_messages = [
+ {'role': 'system', 'content': COMPRESS_SYSTEM_TRAIN},
+ {'role': 'user', 'content': COMPRESS_USER.format(query=query, text=text)},
+ {'role': 'assistant', 'content': compressed},
+ ]
+ return [{
+ 'id': sid,
+ 'source': item.get('source', 'failure_regen'),
+ 'query': query,
+ 'original_len': len(text),
+ 'compressed_len': len(compressed),
+ 'original_tokens': 0,
+ 'compressed_tokens': 0,
+ 'messages': sft_messages,
+ '__src': text,
+ '__cmp': compressed,
+ }]
+
+
+# ═══════════════════════════════════════════════════════════════════════════════
+# I/O helpers
+# ═══════════════════════════════════════════════════════════════════════════════
+
+def iter_input(path: str) -> Iterator[Dict[str, Any]]:
+ """Stream JSONL dataset row-by-row (no full-file load)."""
+ with open(path, 'r', encoding='utf-8') as fh:
+ for line in fh:
+ line = line.strip()
+ if not line:
+ continue
+ try:
+ yield json.loads(line)
+ except json.JSONDecodeError:
+ continue
+
+
+def iter_dataset_py(total: Optional[int], load_from_cache_file: bool) -> Iterator[Dict[str, Any]]:
+ """Stream rows directly from ``dataset.py::get_dataset`` without any JSONL hop."""
+ # Lazy import: dataset.py triggers HF / ModelScope downloads at module load.
+ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
+ from cookbook.exp.condenser.dataset import get_dataset
+ hf = get_dataset(total=total, load_from_cache_file=load_from_cache_file)
+ sys.stderr.write(f'Loaded dataset.py::get_dataset: {len(hf)} rows\n')
+ for row in hf:
+ yield row
+
+
+def iter_dataset_think_py(total: Optional[int], load_from_cache_file: bool) -> Iterator[Dict[str, Any]]:
+ """Stream rows from ``dataset_think.py::get_dataset`` (query + CoT data)."""
+ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
+ from dataset_think import get_dataset
+ hf = get_dataset(total=total, load_from_cache_file=load_from_cache_file)
+ sys.stderr.write(f'Loaded dataset_think.py::get_dataset: {len(hf)} rows\n')
+ for row in hf:
+ yield row
+
+
+def iter_failures(path: str, skip_ids: Optional[Set[str]] = None) -> Iterator[Dict[str, Any]]:
+ """Stream records from a ``failures.jsonl`` for re-compression.
+
+ Each input record carries a full sample id, the original query, and a
+ user message whose body embeds the source passage after a ``## Passage``
+ or ``## Source`` header. The yielded item is shaped for ``process_failure``
+ (id, source, query, text). Items whose id is in ``skip_ids`` are skipped.
+ """
+ skip = skip_ids or set()
+ n_total = n_skipped = n_yielded = n_bad = 0
+ with open(path, 'r', encoding='utf-8') as fh:
+ for line in fh:
+ line = line.strip()
+ if not line:
+ continue
+ n_total += 1
+ try:
+ obj = json.loads(line)
+ except json.JSONDecodeError:
+ n_bad += 1
+ continue
+ sid = obj.get('id') or ''
+ if not sid:
+ n_bad += 1
+ continue
+ if sid in skip:
+ n_skipped += 1
+ continue
+ query = (obj.get('query') or '').strip()
+ user_content = ''
+ for m in obj.get('messages') or []:
+ if isinstance(m, dict) and m.get('role') == 'user':
+ user_content = m.get('content') or ''
+ break
+ text = ''
+ for sep in ('## Passage\n', '## Source\n'):
+ if sep in user_content:
+ text = user_content.split(sep, 1)[1].strip()
+ break
+ if not query or not text:
+ sys.stderr.write(f'[failures] skip {sid}: missing query/passage\n')
+ n_bad += 1
+ continue
+ n_yielded += 1
+ yield {
+ 'id': sid,
+ 'source': obj.get('source', 'failure_regen'),
+ 'query': query,
+ 'text': text,
+ }
+ sys.stderr.write(
+ f'[failures] total={n_total} yielded={n_yielded} '
+ f'resume_skipped={n_skipped} malformed={n_bad}\n')
+
+
+def load_done_sample_ids(path: str) -> Set[str]:
+ """Collect already-written full sample ids (``base__hash``) for resume."""
+ if not os.path.exists(path):
+ return set()
+ done: Set[str] = set()
+ with open(path, 'r', encoding='utf-8') as fh:
+ for line in fh:
+ try:
+ obj = json.loads(line)
+ except json.JSONDecodeError:
+ continue
+ sid = obj.get('id', '')
+ if sid:
+ done.add(sid)
+ return done
+
+
+# ═══════════════════════════════════════════════════════════════════════════════
+# Main
+# ═══════════════════════════════════════════════════════════════════════════════
+
+def main() -> None:
+ parser = argparse.ArgumentParser(
+ description='Two-phase query-diverse condenser dataset builder.')
+ parser.add_argument('--input', default=None,
+ help='Optional JSONL override; default uses dataset.py::get_dataset')
+ parser.add_argument('--output', required=True,
+ help='Output JSONL file for SFT samples')
+ parser.add_argument('--total', type=int, default=0,
+ help='Total input rows for proportional scaling in dataset.py (0 = base sizes)')
+ parser.add_argument('--no-cache', action='store_true',
+ help='Disable load_from_cache_file when calling dataset.py::get_dataset')
+ parser.add_argument('--model', required=True,
+ help='API model name')
+ parser.add_argument('--api-key', default=os.environ.get('OPENAI_API_KEY'))
+ parser.add_argument('--base-url', default=os.environ.get('OPENAI_BASE_URL'))
+ parser.add_argument('--concurrency', type=int, default=32,
+ help='Number of parallel workers')
+ parser.add_argument('--limit', type=int, default=0,
+ help='Max items to process (0 = all)')
+ parser.add_argument('--thinking-budget', type=int, default=1024,
+ help='Max thinking tokens for phase-2 compress (shorter = faster, cheaper)')
+ parser.add_argument('--tokenizer', default='Qwen/Qwen3.5-4B',
+ help='HF/ModelScope tokenizer id for sparse token-ratio probe')
+ parser.add_argument('--tokenize-every', type=int, default=1000,
+ help='Tokenize one sample every N writes; others get tokens=0')
+ parser.add_argument('--fixed-query-ratio', type=float, default=FIXED_QUERY_RATIO,
+ help='Proportion of plain-data items using fixed queries instead of LLM-generated ones')
+ parser.add_argument('--source', choices=['think', 'plain', 'both'], default='think',
+ help='Data source: think=dataset_think.py (query+CoT), plain=dataset.py, both=chain both')
+ parser.add_argument('--failures', default=None,
+ help='Path to a failures.jsonl; when set, re-generate compressions for every record '
+ 'using its original (query, passage) pair and ignore --input/--source.')
+ args = parser.parse_args()
+
+ out_dir = os.path.dirname(args.output)
+ if out_dir:
+ os.makedirs(out_dir, exist_ok=True)
+
+ done_sample_ids = load_done_sample_ids(args.output)
+ # Group done sample ids by base item id so each worker only sees its slice.
+ done_per_item: Dict[str, Set[str]] = {}
+ for sid in done_sample_ids:
+ if '__' in sid:
+ base = sid.rsplit('__', 1)[0]
+ done_per_item.setdefault(base, set()).add(sid)
+ sys.stderr.write(
+ f'Resume: {len(done_sample_ids)} samples on disk across '
+ f'{len(done_per_item)} items.\n')
+
+ api = OpenAI(model=args.model, api_key=args.api_key, base_url=args.base_url)
+
+ from modelscope import AutoTokenizer
+ tokenizer = AutoTokenizer.from_pretrained(args.tokenizer, trust_remote_code=True)
+
+ def iter_pending() -> Iterator[Dict[str, Any]]:
+ if args.failures:
+ source_iter = iter_failures(args.failures, done_sample_ids)
+ elif args.input:
+ source_iter = iter_input(args.input)
+ else:
+ import itertools
+ sources = []
+ if args.source in ('plain', 'both'):
+ sources.append(iter_dataset_py(
+ total=args.total or None,
+ load_from_cache_file=not args.no_cache,
+ ))
+ if args.source in ('think', 'both'):
+ sources.append(iter_dataset_think_py(
+ total=args.total or None,
+ load_from_cache_file=not args.no_cache,
+ ))
+ source_iter = itertools.chain(*sources)
+ emitted = 0
+ for it in source_iter:
+ iid = it.get('id')
+ if not iid:
+ sys.stderr.write('[skip] row missing "id" field\n')
+ continue
+ if args.limit > 0 and emitted >= args.limit:
+ return
+ yield it
+ emitted += 1
+
+ write_lock = threading.Lock()
+ out_fh = open(args.output, 'a', encoding='utf-8')
+ items_done = 0
+ items_failed = 0
+ samples_emitted = 0
+ pbar = tqdm(desc='condense', unit='item', dynamic_ncols=True)
+
+ items_iter = iter_pending()
+ in_flight: Dict[Any, str] = {}
+ # Sliding window: keep ~2x concurrency tasks queued so the pool never starves.
+ window = max(args.concurrency * 2, args.concurrency + 4)
+
+ try:
+ with ThreadPoolExecutor(max_workers=args.concurrency) as ex:
+ exhausted = False
+ while True:
+ while not exhausted and len(in_flight) < window:
+ try:
+ it = next(items_iter)
+ except StopIteration:
+ exhausted = True
+ break
+ iid = it['id']
+ if args.failures:
+ fut = ex.submit(
+ process_failure, api, it, args.thinking_budget,
+ )
+ else:
+ fut = ex.submit(
+ process_item, api, it, done_per_item.get(iid),
+ args.thinking_budget, args.fixed_query_ratio,
+ )
+ in_flight[fut] = iid
+ if not in_flight:
+ break
+ done, _ = wait(list(in_flight.keys()), return_when=FIRST_COMPLETED)
+ for fut in done:
+ iid = in_flight.pop(fut)
+ try:
+ samples = fut.result()
+ except Exception as exc:
+ sys.stderr.write(f'[item {iid}] crashed: {exc}\n')
+ items_failed += 1
+ pbar.update(1)
+ continue
+ if not samples:
+ items_failed += 1
+ pbar.update(1)
+ continue
+ with write_lock:
+ for s in samples:
+ src = s.pop('__src', '')
+ cmp = s.pop('__cmp', '')
+ samples_emitted += 1
+ if (samples_emitted - 1) % args.tokenize_every == 0:
+ s['original_tokens'] = len(tokenizer(src).input_ids)
+ s['compressed_tokens'] = len(tokenizer(cmp).input_ids)
+ out_fh.write(json.dumps(s, ensure_ascii=False) + '\n')
+ out_fh.flush()
+ items_done += 1
+ pbar.set_postfix(
+ done=items_done, failed=items_failed,
+ samples=samples_emitted, refresh=False,
+ )
+ pbar.update(1)
+ finally:
+ out_fh.close()
+ pbar.close()
+
+ sys.stderr.write(
+ f'Done. items_done={items_done}, samples={samples_emitted}, '
+ f'failed={items_failed}\n')
+
+
+if __name__ == '__main__':
+ main()
diff --git a/cookbook/exp/condenser/train_condenser_ddp.py b/cookbook/exp/condenser/train_condenser_ddp.py
new file mode 100644
index 000000000..997235781
--- /dev/null
+++ b/cookbook/exp/condenser/train_condenser_ddp.py
@@ -0,0 +1,100 @@
+"""Ray LoRA SFT for the condenser model on condense_300K.
+
+Launch:
+ python cookbook/exp/train_condenser_ddp.py
+"""
+from pathlib import Path
+
+from peft import LoraConfig
+from tqdm import tqdm
+
+import twinkle
+from twinkle import DeviceGroup, DeviceMesh, get_device_placement, get_logger
+from twinkle.dataloader import DataLoader
+from twinkle.dataset import Dataset, DatasetMeta
+from twinkle.model import TransformersModel
+from twinkle.preprocessor import Preprocessor
+
+logger = get_logger()
+
+MODEL_ID = 'ms://Qwen/Qwen3.5-4B'
+DATASET_ID = 'ms://twinkle-kit/condense_300K'
+TEMPLATE_NAME = 'Qwen3_5Template'
+
+DP_SIZE = 8
+BATCH_SIZE = 8
+LEARNING_RATE = 1e-5
+GRADIENT_ACCUMULATION_STEPS = 8
+LOG_INTERVAL = 20
+EVAL_INTERVAL = 200
+EVAL_SAMPLES = 100
+NUM_EPOCHS = 1
+
+OUTPUT_DIR = './output/condenser_ddp'
+RESUME_FROM_CHECKPOINT = None
+RESUME_ONLY_MODEL = False
+IGNORE_DATA_SKIP = False
+ADAPTER_NAME = 'default'
+
+class LegacySectionRenameProcessor(Preprocessor):
+ """Rewrite legacy `## Read inline` / `## Call extract_compressed for` headers to `## Summary` / `## More`."""
+
+ _REPLACEMENTS = (
+ ('## Read inline', '## Summary'),
+ ('## Call extract_compressed for', '## More'),
+ )
+
+ def __call__(self, batch):
+ new_messages = []
+ for msgs in batch['messages']:
+ patched = []
+ for m in msgs:
+ content = m.get('content', '') or ''
+ for old, new in self._REPLACEMENTS:
+ content = content.replace(old, new)
+ patched.append({**m, 'content': content})
+ new_messages.append(patched)
+ return {'messages': new_messages}
+
+
+def build_dataset() -> Dataset:
+ dataset = Dataset(dataset_meta=DatasetMeta('/mnt/workspace/yzhao/tastelikefeet/condense_300K/train.jsonl'))
+ dataset.map(LegacySectionRenameProcessor(), remove_columns=[], num_proc=16)
+ dataset.set_template(TEMPLATE_NAME, model_id=MODEL_ID, max_length=40000, enable_thinking=False, truncation_strategy='delete')
+ dataset.encode(load_from_cache_file=True, num_proc=64)
+ return dataset
+
+
+def train():
+ device_groups = [DeviceGroup(name='model', ranks=DP_SIZE, device_type='GPU')]
+ model_mesh = DeviceMesh.from_sizes(world_size=DP_SIZE, dp_size=4, fsdp_size=2)
+ twinkle.initialize(mode='ray', nproc_per_node=DP_SIZE, groups=device_groups, global_device_mesh=model_mesh)
+
+ dataset = build_dataset()
+ dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=True)
+
+ model = TransformersModel(model_id=MODEL_ID, device_mesh=model_mesh, remote_group='model')
+
+ model.set_optimizer(optimizer_cls='AdamW', lr=LEARNING_RATE)
+ total_optim_steps = (len(dataloader) * NUM_EPOCHS) // GRADIENT_ACCUMULATION_STEPS
+ model.set_lr_scheduler(
+ scheduler_cls='CosineWarmupScheduler', num_warmup_steps=50, num_training_steps=total_optim_steps)
+
+ logger.info(get_device_placement())
+ logger.info(model.get_train_configs())
+ logger.info(f'Total micro-steps: {len(dataloader) * NUM_EPOCHS}, optim steps: {total_optim_steps}')
+
+ for i in range(NUM_EPOCHS):
+ for cur_step, batch in enumerate(dataloader):
+ model.forward_backward(inputs=batch)
+ model.clip_grad_and_step(gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS)
+ if cur_step % LOG_INTERVAL == 0:
+ metric = model.calculate_metric(is_training=True)
+ logger.info(f'Step {cur_step}/{len(dataloader) * NUM_EPOCHS}, metric: {metric}')
+ if cur_step % 4000 == 0:
+ model.save(f'step_{cur_step}', output_dir=OUTPUT_DIR)
+ model.save('last_checkpoint', output_dir=OUTPUT_DIR)
+
+
+if __name__ == '__main__':
+ train()
diff --git a/cookbook/exp/condenser/untested/eval_condensed.py b/cookbook/exp/condenser/untested/eval_condensed.py
new file mode 100644
index 000000000..730aaf3a8
--- /dev/null
+++ b/cookbook/exp/condenser/untested/eval_condensed.py
@@ -0,0 +1,382 @@
+"""Evaluation: native (full ctx) vs condensed (chunk → condense → extract_condensed tool).
+
+Reuses the training-time data shape and prompt so the comparison is apples-to-apples.
+
+Launch:
+ # native baseline (full HotpotQA context, no compression, no tool)
+ python cookbook/exp/eval_condensed.py --mode native \\
+ --dataset /path/to/hotpot_dev_fullwiki.jsonl
+
+ # condensed (chunk → condense via Qwen3.5-4B-Condenser → extract_condensed tool)
+ python cookbook/exp/eval_condensed.py --mode condensed \\
+ --dataset /path/to/hotpot_dev_fullwiki.jsonl
+
+Outputs (under --out_dir / _/):
+ predictions.jsonl one row per sample with pred / gold / f1 / em / token-counts / tool-calls
+ summary.json aggregate metrics
+"""
+import argparse
+import json
+import os
+import re
+import time
+import uuid
+from collections import Counter
+from typing import Any, Dict, List, Optional
+
+import twinkle
+from twinkle import DeviceGroup, DeviceMesh, get_logger
+from twinkle.data_format import Message, SamplingParams, Trajectory
+from twinkle.dataloader import DataLoader
+from twinkle.dataset import Dataset, DatasetMeta
+from twinkle.sampler import vLLMSampler
+from twinkle.template import Qwen3_5Template
+from twinkle_agentic.chunker.native import NativeChunker
+from twinkle_agentic.condenser import ModelCondenser
+from twinkle_agentic.reward.f1 import _f1_score
+from twinkle_agentic.rollout.multi_turn import MultiTurnRollout
+from twinkle_agentic.rollout.multi_turn_condense import MultiTurnCondenseRollout
+from twinkle_agentic.tools.tool_manager import ToolManager
+from twinkle.preprocessor.base import Preprocessor
+
+# Reuse training assets so eval and train share data shape + condensed prompt.
+from cookbook.exp.legacy.grpo_condensed import (
+ SYSTEM_PROMPT as CONDENSED_SYSTEM_PROMPT,
+ HotpotQAProcessor,
+ _BOXED_RE,
+ _last_assistant_text,
+)
+
+
+class MuSiQueProcessor(Preprocessor):
+ """MuSiQue-Ans → Trajectory adapter.
+
+ MuSiQue native schema (per row):
+ id, question, paragraphs=[{idx, title, paragraph_text, is_supporting}], answer,
+ answer_aliases=[...], answerable, question_decomposition=[...]
+
+ Maps to the same Trajectory(messages, user_data) shape that
+ :class:`HotpotQAProcessor` produces, so downstream rollout code is
+ schema-agnostic. ``ground_truth`` carries answer + answer_aliases.
+ """
+
+ def __init__(self, system: str):
+ self.system = system
+
+ def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
+ rows = self.map_col_to_row(rows)
+ out = [self.preprocess(r) for r in rows]
+ out = [r for r in out if r is not None]
+ return self.map_row_to_col(out)
+
+ @staticmethod
+ def _format_context(paragraphs: List[Dict[str, Any]]) -> str:
+ lines = []
+ for p in paragraphs or []:
+ title = (p.get('title') or '').strip()
+ body = (p.get('paragraph_text') or '').strip()
+ if not body:
+ continue
+ lines.append(f'{title}: {body}' if title else body)
+ return '\n\n'.join(lines)
+
+ def preprocess(self, row: Dict[str, Any]) -> Optional[Trajectory]:
+ if row.get('answerable') is False:
+ return None
+ question = (row.get('question') or '').strip()
+ if not question:
+ return None
+ gold_main = (row.get('answer') or '').strip()
+ aliases = row.get('answer_aliases') or []
+ gold = [g for g in dict.fromkeys([gold_main] + list(aliases)) if g]
+ if not gold:
+ return None
+ paragraphs = row.get('paragraphs') or []
+ context_block = self._format_context(paragraphs)
+ user_msg = f'Question: {question}\n\nContext:\n\n{context_block}'
+ messages = [
+ Message(role='system', content=self.system),
+ Message(role='user', content=user_msg),
+ ]
+ sf_titles = list(dict.fromkeys(
+ (p.get('title') or '').strip()
+ for p in paragraphs
+ if p.get('is_supporting') and (p.get('title') or '').strip()))
+ user_data = [('ground_truth', g) for g in gold] + [('sf_title', t) for t in sf_titles]
+ return Trajectory(messages=messages, user_data=user_data)
+
+logger = get_logger()
+
+NATIVE_SYSTEM_PROMPT = """You are a careful multi-hop QA assistant.
+
+The user message contains a Question and a Context. Read both, reason step by step,
+then commit to a final answer.
+
+## Output Format
+End your final response with \\boxed{answer}.
+Keep the boxed text short: a name, entity, date, or "yes"/"no".
+Answers not inside \\boxed{} will not be scored."""
+
+
+def parse_args():
+ p = argparse.ArgumentParser()
+ p.add_argument('--mode', choices=['native', 'condensed'], required=True)
+ p.add_argument('--dataset', required=True,
+ help='Eval set jsonl. HotpotQA or MuSiQue-Ans schema (see --dataset_format).')
+ p.add_argument('--dataset_format', choices=['hotpotqa', 'musique'], default='musique',
+ help='Schema of --dataset. MuSiQue-Ans (default) is harder multi-hop and OOD vs training.')
+ p.add_argument('--model_id', default='ms://Qwen/Qwen3.5-4B')
+ p.add_argument('--lora_path', default=None,
+ help='Optional LoRA adapter on top of model_id (e.g. trained QA LoRA).')
+ p.add_argument('--condenser_lora', default='ms://twinkle-kit/Qwen3.5-4B-Condenser')
+ p.add_argument('--limit', type=int, default=500)
+ p.add_argument('--num_gpus', type=int, default=4)
+ p.add_argument('--batch_size', type=int, default=8)
+ p.add_argument('--max_model_len', type=int, default=32768)
+ p.add_argument('--max_new_tokens', type=int, default=2048)
+ p.add_argument('--max_turns', type=int, default=4)
+ p.add_argument('--max_trajectory_tokens', type=int, default=8192)
+ p.add_argument('--chunk_size', type=int, default=1024)
+ p.add_argument('--temperature', type=float, default=0.0)
+ p.add_argument('--out_dir', default='eval_out')
+ p.add_argument('--seed', type=int, default=42)
+ return p.parse_args()
+
+
+def build_dataset(path: str, dataset_format: str, model_id: str,
+ max_length: int, limit: int, system: str) -> Dataset:
+ """Load eval JSONL and produce Trajectory rows tagged with ground_truth user_data."""
+ ds = Dataset()
+ ds.add_dataset(DatasetMeta(path))
+ if limit > 0 and len(ds) > limit:
+ ds = ds.select(range(limit))
+ ds.set_template(
+ 'Qwen3_5Template', model_id=model_id, max_length=max_length,
+ truncation_strategy='delete', enable_thinking=False)
+ if dataset_format == 'musique':
+ # MuSiQue-Ans cols (drop everything; we keep only the produced messages/user_data)
+ cols = ['id', 'question', 'paragraphs', 'answer', 'answer_aliases',
+ 'answerable', 'question_decomposition']
+ ds.map(MuSiQueProcessor(system=system), remove_columns=cols)
+ else:
+ cols = ['id', 'question', 'question_fixed', 'answers', 'original_answer',
+ 'type', 'level', 'verdict', 'reasoning', 'supporting_facts', 'context']
+ ds.map(HotpotQAProcessor(system=system), remove_columns=cols)
+ return ds
+
+
+def extract_boxed(text: str) -> Optional[str]:
+ """Pull the inner text of the LAST `\\boxed{...}` marker, brace-balanced enough for short answers."""
+ if not text:
+ return None
+ matches = _BOXED_RE.findall(text)
+ if not matches:
+ return None
+ last = matches[-1]
+ return last[len(r'\boxed{'):-1].strip()
+
+
+def best_f1_em(pred: str, golds: List[str]) -> Dict[str, float]:
+ """Max-over-references SQuAD-style F1 / EM, reusing the training reward's normalizer."""
+ if not golds:
+ return {'f1': 0.0, 'em': 0.0}
+ if not pred:
+ return {'f1': 0.0, 'em': 0.0}
+ best_f1, best_em = 0.0, 0.0
+ for g in golds:
+ f1, em = _f1_score(pred, g)
+ if f1 > best_f1:
+ best_f1 = f1
+ if em > best_em:
+ best_em = em
+ return {'f1': best_f1, 'em': best_em}
+
+
+def _user_text(traj_or_msg) -> str:
+ """Concat all text parts of the first user message — used to count original context tokens."""
+ msgs = traj_or_msg if isinstance(traj_or_msg, list) else (traj_or_msg.get('messages') or [])
+ for m in msgs:
+ role = m.get('role') if isinstance(m, dict) else getattr(m, 'role', None)
+ if role != 'user':
+ continue
+ content = m.get('content') if isinstance(m, dict) else getattr(m, 'content', None)
+ if isinstance(content, str):
+ return content
+ if isinstance(content, list):
+ return ''.join(p.get('text') or '' for p in content if isinstance(p, dict) and p.get('type') == 'text')
+ return ''
+ return ''
+
+
+def _count_tool_calls(traj: Dict[str, Any]) -> int:
+ return sum(len(m.get('tool_calls') or [])
+ for m in (traj.get('messages') or []) if m.get('role') == 'assistant')
+
+
+def main():
+ args = parse_args()
+ run_id = time.strftime('%Y%m%d_%H%M%S') + '_' + uuid.uuid4().hex[:6]
+ out_dir = os.path.join(args.out_dir, f'{args.mode}_{run_id}')
+ os.makedirs(out_dir, exist_ok=True)
+
+ device_groups = [DeviceGroup(name='sampler', ranks=list(range(args.num_gpus)), device_type='GPU')]
+ sampler_mesh = DeviceMesh.from_sizes(world_size=args.num_gpus, dp_size=args.num_gpus)
+ twinkle.initialize(mode='ray', nproc_per_node=args.num_gpus,
+ groups=device_groups, lazy_collect=False)
+
+ system = CONDENSED_SYSTEM_PROMPT if args.mode == 'condensed' else NATIVE_SYSTEM_PROMPT
+ ds = build_dataset(args.dataset, args.dataset_format, args.model_id,
+ args.max_model_len, args.limit, system)
+ logger.info('Eval dataset: %d rows from %s (mode=%s, format=%s)',
+ len(ds), args.dataset, args.mode, args.dataset_format)
+
+ sampler = vLLMSampler(
+ model_id=args.model_id,
+ engine_args={
+ 'gpu_memory_utilization': 0.85, 'max_model_len': args.max_model_len,
+ 'max_lora_rank': 32, 'enable_lora': True,
+ 'enable_tower_connector_lora': True, 'max_loras': 5,
+ 'seed': args.seed,
+ },
+ device_mesh=sampler_mesh, remote_group='sampler')
+ sampler.set_template('Qwen3_5Template', model_id=args.model_id,
+ enable_thinking=False, max_length=args.max_model_len)
+ template = Qwen3_5Template(args.model_id, max_length=args.max_model_len, enable_thinking=False)
+
+ # stop=[''] only matters for condensed mode where the model issues tool calls
+ sampling_params = SamplingParams(
+ max_tokens=args.max_new_tokens, num_samples=1,
+ temperature=args.temperature, top_p=0.95,
+ stop=[''] if args.mode == 'condensed' else None,
+ )
+
+ if args.mode == 'condensed':
+ chunker = NativeChunker(chunk_size=args.chunk_size, passage_boundary_re=r'(?<=\n\n)')
+ # Chunk-level extraction of the question line; \A anchor avoids matching "Question:" inside passages.
+ _q_re = re.compile(r'\AQuestion:\s*(.+)')
+
+ def _q_from_chunk(chunk):
+ c = chunk.get('content')
+ if chunk.get('type') != 'text' or not isinstance(c, str):
+ return None
+ m = _q_re.search(c)
+ return m.group(1).strip() if m else None
+
+ condenser = ModelCondenser(
+ sampler=sampler, compression_ratio=2.0,
+ sampling_params=SamplingParams(max_tokens=1024, num_samples=1,
+ temperature=0.4, top_p=0.9),
+ min_chars=200, template=template,
+ lora_path=args.condenser_lora, skip_pattern=r'^Question:',
+ related_query=_q_from_chunk,
+ )
+ rollout = MultiTurnCondenseRollout(
+ sampler=sampler, template=template, tool_manager=ToolManager(),
+ chunker=chunker, condenser=condenser,
+ sampling_params=sampling_params,
+ max_turns=args.max_turns, max_trajectory_tokens=args.max_trajectory_tokens,
+ )
+ else:
+ # max_turns=1, no tools: reduces to single-turn QA over the full original context
+ rollout = MultiTurnRollout(
+ sampler=sampler, template=template, tool_manager=ToolManager(),
+ sampling_params=sampling_params,
+ max_turns=1, max_trajectory_tokens=args.max_trajectory_tokens,
+ )
+
+ dataloader = DataLoader(dataset=ds, batch_size=args.batch_size,
+ min_batch_size=1, shuffle=False)
+
+ pred_path = os.path.join(out_dir, 'predictions.jsonl')
+ pf = open(pred_path, 'w', encoding='utf-8')
+
+ agg = Counter()
+ sums = {'f1': 0.0, 'em': 0.0,
+ 'prompt_tok': 0, 'comp_tok': 0, 'orig_ctx_tok': 0,
+ 'turns': 0, 'tool_calls': 0}
+ t0 = time.time()
+
+ for batch in dataloader:
+ trajs = rollout(batch)
+
+ for src, traj in zip(batch, trajs):
+ text = _last_assistant_text(traj) or ''
+ pred = extract_boxed(text) or ''
+ golds = [v for k, v in (src.user_data or []) if k == 'ground_truth' and v]
+
+ scores = best_f1_em(pred, golds)
+ ids = traj.get('input_ids') or []
+ comp_tok = sum(1 for l in (traj.get('labels') or []) if l != -100)
+ prompt_tok = max(0, len(ids) - comp_tok)
+ tool_calls = _count_tool_calls(traj)
+
+ # Original (uncondensed) context size — feed only the user msg, not the system prompt,
+ # so the compression ratio stays comparable across modes.
+ orig_user = _user_text(src.messages)
+ orig_ctx_tok = len(template.tokenizer.encode(orig_user)) if orig_user else 0
+
+ agg['n'] += 1
+ agg['no_box'] += int(_BOXED_RE.search(text) is None)
+ agg['tool_use'] += int(tool_calls > 0)
+ sums['f1'] += scores['f1']
+ sums['em'] += scores['em']
+ sums['prompt_tok'] += prompt_tok
+ sums['comp_tok'] += comp_tok
+ sums['orig_ctx_tok'] += orig_ctx_tok
+ sums['turns'] += int(traj.get('turns') or 1)
+ sums['tool_calls'] += tool_calls
+
+ pf.write(json.dumps({
+ 'pred': pred,
+ 'gold': golds,
+ 'f1': scores['f1'],
+ 'em': scores['em'],
+ 'prompt_tok': prompt_tok,
+ 'comp_tok': comp_tok,
+ 'orig_ctx_tok': orig_ctx_tok,
+ 'tool_calls': tool_calls,
+ 'turns': int(traj.get('turns') or 1),
+ 'no_boxed': _BOXED_RE.search(text) is None,
+ 'response': text,
+ }, ensure_ascii=False) + '\n')
+
+ logger.info('[eval] %d / %d processed', agg['n'], len(ds))
+
+ pf.close()
+ wall = time.time() - t0
+ n = max(1, agg['n'])
+ summary = {
+ 'mode': args.mode,
+ 'dataset_format': args.dataset_format,
+ 'model_id': args.model_id,
+ 'lora_path': args.lora_path,
+ 'condenser_lora': args.condenser_lora if args.mode == 'condensed' else None,
+ 'dataset': args.dataset,
+ 'n_samples': agg['n'],
+ # quality
+ 'f1': sums['f1'] / n,
+ 'em': sums['em'] / n,
+ 'no_boxed_rate': agg['no_box'] / n,
+ # cost
+ 'avg_prompt_tokens': sums['prompt_tok'] / n,
+ 'avg_completion_tokens': sums['comp_tok'] / n,
+ 'avg_orig_context_tokens': sums['orig_ctx_tok'] / n,
+ 'compression_ratio': (sums['prompt_tok'] / sums['orig_ctx_tok']
+ if sums['orig_ctx_tok'] else None),
+ # tool / multi-turn behavior
+ 'avg_turns': sums['turns'] / n,
+ 'avg_tool_calls': sums['tool_calls'] / n,
+ 'tool_use_rate': agg['tool_use'] / n,
+ # wall
+ 'wall_time_sec': wall,
+ 'samples_per_sec': agg['n'] / wall if wall > 0 else 0.0,
+ }
+ with open(os.path.join(out_dir, 'summary.json'), 'w', encoding='utf-8') as f:
+ json.dump(summary, f, indent=2, ensure_ascii=False)
+
+ logger.info('Done. Output: %s', out_dir)
+ logger.info('Summary: %s', json.dumps(summary, indent=2, ensure_ascii=False))
+
+
+if __name__ == '__main__':
+ main()
diff --git a/cookbook/exp/condenser/untested/eval_condensed_compressed.sh b/cookbook/exp/condenser/untested/eval_condensed_compressed.sh
new file mode 100755
index 000000000..5567a1a3b
--- /dev/null
+++ b/cookbook/exp/condenser/untested/eval_condensed_compressed.sh
@@ -0,0 +1,29 @@
+#!/usr/bin/env bash
+# Compressed run: chunk → condense via Qwen3.5-4B-Condenser LoRA → extract_condensed tool loop.
+# Identical --dataset / --limit / --model_id as eval_condensed_native.sh for an A/B comparison.
+set -euo pipefail
+
+DATASET="${DATASET:-/mnt/data/yzhao/datasets/musique_ans_v1.0_dev.jsonl}"
+MODEL_ID="${MODEL_ID:-ms://Qwen/Qwen3.5-4B}"
+CONDENSER_LORA="${CONDENSER_LORA:-ms://twinkle-kit/Qwen3.5-4B-Condenser}"
+LIMIT="${LIMIT:-500}"
+NUM_GPUS="${NUM_GPUS:-4}"
+OUT_DIR="${OUT_DIR:-eval_out}"
+
+CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES:-0,1,2,3} \
+python cookbook/exp/eval_condensed.py \
+ --mode condensed \
+ --dataset_format musique \
+ --dataset "${DATASET}" \
+ --model_id "${MODEL_ID}" \
+ --condenser_lora "${CONDENSER_LORA}" \
+ --limit "${LIMIT}" \
+ --num_gpus "${NUM_GPUS}" \
+ --batch_size 8 \
+ --max_model_len 32768 \
+ --max_new_tokens 2048 \
+ --max_turns 4 \
+ --max_trajectory_tokens 8192 \
+ --chunk_size 1024 \
+ --temperature 0.0 \
+ --out_dir "${OUT_DIR}"
diff --git a/cookbook/exp/condenser/untested/eval_condensed_native.sh b/cookbook/exp/condenser/untested/eval_condensed_native.sh
new file mode 100755
index 000000000..0849e9378
--- /dev/null
+++ b/cookbook/exp/condenser/untested/eval_condensed_native.sh
@@ -0,0 +1,25 @@
+#!/usr/bin/env bash
+# Native baseline: full original context, single-turn QA, no compression, no tools.
+# Compare against eval_condensed_compressed.sh on identical --dataset / --limit / --model_id.
+set -euo pipefail
+
+DATASET="${DATASET:-/mnt/data/yzhao/datasets/musique_ans_v1.0_dev.jsonl}"
+MODEL_ID="${MODEL_ID:-ms://Qwen/Qwen3.5-4B}"
+LIMIT="${LIMIT:-500}"
+NUM_GPUS="${NUM_GPUS:-4}"
+OUT_DIR="${OUT_DIR:-eval_out}"
+
+CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES:-0,1,2,3} \
+python cookbook/exp/eval_condensed.py \
+ --mode native \
+ --dataset_format musique \
+ --dataset "${DATASET}" \
+ --model_id "${MODEL_ID}" \
+ --limit "${LIMIT}" \
+ --num_gpus "${NUM_GPUS}" \
+ --batch_size 8 \
+ --max_model_len 32768 \
+ --max_new_tokens 2048 \
+ --max_trajectory_tokens 8192 \
+ --temperature 0.0 \
+ --out_dir "${OUT_DIR}"
diff --git a/cookbook/exp/embedding/dataset_think.py b/cookbook/exp/embedding/dataset_think.py
new file mode 100644
index 000000000..38618ced1
--- /dev/null
+++ b/cookbook/exp/embedding/dataset_think.py
@@ -0,0 +1,456 @@
+import hashlib
+import re
+from typing import Any, Dict, List, Optional
+
+from twinkle.dataset import Dataset, DatasetMeta
+from twinkle.preprocessor import Preprocessor
+
+_THINK_RE = re.compile(r'(.*?)', re.DOTALL)
+
+
+def _hash_id(prefix: str, content: str) -> str:
+ return f'{prefix}__{hashlib.md5(content.encode("utf-8")).hexdigest()[:16]}'
+
+
+def _register(dataset, processor_cls, meta: DatasetMeta, init_args: Optional[Dict[str, Any]] = None,
+ load_from_cache_file: bool = True) -> None:
+ """Add dataset and run preprocessor; auto-strip every input column to enforce
+ the universal ``{id, source, query, cot, response}`` output schema."""
+ dataset.add_dataset(meta)
+ cols = list(dataset.datasets[meta.get_id()].column_names)
+ dataset.map(
+ processor_cls,
+ dataset_meta=meta,
+ init_args=init_args or {},
+ remove_columns=cols,
+ load_from_cache_file=load_from_cache_file,
+ )
+
+
+# ===== Modotte/CodeX-2M-Thinking =====
+CODEX_THINKING_REPO = 'ms://Modotte/CodeX-2M-Thinking'
+
+
+class CodeXThinkingProcessor(Preprocessor):
+ """CodeX-2M-Thinking row → ``{id, source, query, cot, response}``。
+
+ 输入 schema: ``input``(问题)、``output``(含 ``...`` + 答案)。
+ 拆分 output 为 cot(think 标签内容)和 response(标签之后的正文)。
+ 丢弃缺失 input/output 或无法解析 think 标签的行。
+ """
+
+ def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
+ rows = self.map_col_to_row(rows)
+ out: List[Dict[str, Any]] = []
+ for row in rows:
+ query = (row.get('input') or '').strip()
+ output = (row.get('output') or '').strip()
+ if not query or not output:
+ continue
+ m = _THINK_RE.search(output)
+ if not m:
+ continue
+ cot = m.group(1).strip()
+ response = output[m.end():].strip()
+ if not cot or not response:
+ continue
+ out.append({
+ 'id': _hash_id('codex_think', f'{query}\n{response}'),
+ 'source': 'CodeX-2M-Thinking',
+ 'query': query,
+ 'cot': cot,
+ 'response': response,
+ })
+ return self.map_row_to_col(out)
+
+
+# ===== open-thoughts/OpenThoughts3-1.2M =====
+OPEN_THOUGHTS_REPO = 'ms://open-thoughts/OpenThoughts3-1.2M'
+
+
+class OpenThoughtsProcessor(Preprocessor):
+ """OpenThoughts3 row → ``{id, source, query, cot, response}``。
+
+ 输入 schema: ``conversations`` (messages 格式 list[{from/value}])。
+ 取第一个 human 作 query,第一个 gpt 的 value 按 ``...`` 拆 cot/response。
+ """
+
+ def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
+ rows = self.map_col_to_row(rows)
+ out: List[Dict[str, Any]] = []
+ for row in rows:
+ convs = row.get('conversations')
+ if not isinstance(convs, list):
+ continue
+ query = ''
+ assistant_text = ''
+ for msg in convs:
+ if not isinstance(msg, dict):
+ continue
+ role = msg.get('from') or msg.get('role') or ''
+ value = msg.get('value') or msg.get('content') or ''
+ if role in ('human', 'user') and not query:
+ query = value.strip()
+ elif role in ('gpt', 'assistant') and not assistant_text:
+ assistant_text = value.strip()
+ break
+ if not query or not assistant_text:
+ continue
+ m = _THINK_RE.search(assistant_text)
+ if not m:
+ continue
+ cot = m.group(1).strip()
+ response = assistant_text[m.end():].strip()
+ if not cot or not response:
+ continue
+ out.append({
+ 'id': _hash_id('openthoughts', f'{query}\n{response}'),
+ 'source': 'OpenThoughts3-1.2M',
+ 'query': query,
+ 'cot': cot,
+ 'response': response,
+ })
+ return self.map_row_to_col(out)
+
+
+# ===== GAIR/LIMO-v2 =====
+LIMO_REPO = 'ms://GAIR/LIMO-v2'
+
+
+class LIMOProcessor(Preprocessor):
+ """LIMO-v2 row → ``{id, source, query, cot, response}``。
+
+ 输入 schema: ``question``、``solution``(含 ``...`` + 答案)。
+ 拆分 solution 为 cot 和 response。
+ """
+
+ def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
+ rows = self.map_col_to_row(rows)
+ out: List[Dict[str, Any]] = []
+ for row in rows:
+ query = (row.get('question') or '').strip()
+ solution = (row.get('solution') or '').strip()
+ if not query or not solution:
+ continue
+ m = _THINK_RE.search(solution)
+ if m:
+ cot = m.group(1).strip()
+ response = solution[m.end():].strip()
+ else:
+ # 无 think 标签时,solution 整体作为 response,cot 留空
+ cot = ''
+ response = solution
+ if not response:
+ continue
+ out.append({
+ 'id': _hash_id('limo', f'{query}\n{response}'),
+ 'source': 'LIMO-v2',
+ 'query': query,
+ 'cot': cot,
+ 'response': response,
+ })
+ return self.map_row_to_col(out)
+
+
+# ===== AI-ModelScope/Chinese-DeepSeek-R1-Distill-data-110k =====
+CN_R1_DISTILL_REPO = 'ms://AI-ModelScope/Chinese-DeepSeek-R1-Distill-data-110k'
+
+
+class ChineseR1DistillProcessor(Preprocessor):
+ """Chinese-DeepSeek-R1-Distill row → ``{id, source, query, cot, response}``。
+
+ 输入已有三列: ``input`` → query, ``reasoning_content`` → cot, ``content`` → response。
+ """
+
+ def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
+ rows = self.map_col_to_row(rows)
+ out: List[Dict[str, Any]] = []
+ for row in rows:
+ query = (row.get('input') or '').strip()
+ cot = (row.get('reasoning_content') or '').strip()
+ response = (row.get('content') or '').strip()
+ if not query or not response:
+ continue
+ if cot:
+ response = _THINK_RE.sub('', response).strip()
+ if not response:
+ continue
+ out.append({
+ 'id': _hash_id('cn_r1_distill', f'{query}\n{response}'),
+ 'source': 'Chinese-DeepSeek-R1-Distill-data-110k',
+ 'query': query,
+ 'cot': cot,
+ 'response': response,
+ })
+ return self.map_row_to_col(out)
+
+
+# ===== nohurry/Opus-4.6-Reasoning-3000x-filtered =====
+OPUS_REASONING_REPO = 'ms://nohurry/Opus-4.6-Reasoning-3000x-filtered'
+
+
+class OpusReasoningProcessor(Preprocessor):
+ """Opus-4.6-Reasoning-3000x-filtered row → ``{id, source, query, cot, response}``。
+
+ 输入已有三列: ``problem`` → query, ``thinking`` → cot, ``solution`` → response。
+ """
+
+ def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
+ rows = self.map_col_to_row(rows)
+ out: List[Dict[str, Any]] = []
+ for row in rows:
+ query = (row.get('problem') or '').strip()
+ cot = (row.get('thinking') or '').strip()
+ response = (row.get('solution') or '').strip()
+ if not query or not response:
+ continue
+ if cot:
+ response = _THINK_RE.sub('', response).strip()
+ if not response:
+ continue
+ out.append({
+ 'id': _hash_id('opus_reasoning', f'{query}\n{response}'),
+ 'source': 'Opus-4.6-Reasoning-3000x-filtered',
+ 'query': query,
+ 'cot': cot,
+ 'response': response,
+ })
+ return self.map_row_to_col(out)
+
+
+# ===== Roman1111111/claude-opus-4.6-10000x =====
+CLAUDE_OPUS_REPO = 'ms://Roman1111111/claude-opus-4.6-10000x'
+
+
+class ClaudeOpusProcessor(Preprocessor):
+ """claude-opus-4.6-10000x row → ``{id, source, query, cot, response}``。
+
+ 输入 schema: ``messages`` (OpenAI 格式 list[{role, content}])。
+ 取首个 user 作 query,首个 assistant 按 ``...`` 拆 cot/response。
+ """
+
+ def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
+ rows = self.map_col_to_row(rows)
+ out: List[Dict[str, Any]] = []
+ for row in rows:
+ messages = row.get('messages')
+ if not isinstance(messages, list):
+ continue
+ query = ''
+ assistant_text = ''
+ for msg in messages:
+ if not isinstance(msg, dict):
+ continue
+ role = msg.get('role') or ''
+ content = msg.get('content') or ''
+ if not isinstance(content, str):
+ continue
+ if role == 'user' and not query:
+ query = content.strip()
+ elif role == 'assistant' and not assistant_text:
+ assistant_text = content.strip()
+ break
+ if not query or not assistant_text:
+ continue
+ m = _THINK_RE.search(assistant_text)
+ if m:
+ cot = m.group(1).strip()
+ response = assistant_text[m.end():].strip()
+ else:
+ cot = ''
+ response = assistant_text
+ if not response:
+ continue
+ out.append({
+ 'id': _hash_id('claude_opus', f'{query}\n{response}'),
+ 'source': 'claude-opus-4.6-10000x',
+ 'query': query,
+ 'cot': cot,
+ 'response': response,
+ })
+ return self.map_row_to_col(out)
+
+
+ANGRYGIRAFFE_REPO = 'ms://hf/angrygiraffe-claude-opus-4.6-4.7-reasoning-8.7k'
+
+
+class AngrygiraffeOpusReasoningProcessor(Preprocessor):
+ """angrygiraffe/claude-opus-4.6-4.7-reasoning-8.7k row → ``{id, source, query, cot, response}``。
+
+ 输入 schema: ``messages`` (OpenAI 格式 list[{role, content}])。
+ 取首个 user 作 query,首个 assistant 按 ``...`` 拆 cot/response,仅用头一轮。
+ """
+
+ def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
+ rows = self.map_col_to_row(rows)
+ out: List[Dict[str, Any]] = []
+ for row in rows:
+ messages = row.get('messages')
+ if not isinstance(messages, list):
+ continue
+ query = ''
+ assistant_text = ''
+ for msg in messages:
+ if not isinstance(msg, dict):
+ continue
+ role = msg.get('role') or ''
+ content = msg.get('content') or ''
+ if not isinstance(content, str):
+ continue
+ if role == 'user' and not query:
+ query = content.strip()
+ elif role == 'assistant' and not assistant_text:
+ assistant_text = content.strip()
+ break
+ if not query or not assistant_text:
+ continue
+ m = _THINK_RE.search(assistant_text)
+ if m:
+ cot = m.group(1).strip()
+ response = assistant_text[m.end():].strip()
+ else:
+ cot = ''
+ response = assistant_text
+ if not response:
+ continue
+ out.append({
+ 'id': _hash_id('angrygiraffe_opus', f'{query}\n{response}'),
+ 'source': 'angrygiraffe-claude-opus-4.6-4.7-reasoning-8.7k',
+ 'query': query,
+ 'cot': cot,
+ 'response': response,
+ })
+ return self.map_row_to_col(out)
+
+
+_BASE_SIZES = {
+ 'codex_think': 100000,
+ 'open_thoughts': 400000,
+ 'cn_r1_distill': 100000,
+ 'opus_reasoning': 3000,
+ 'claude_opus': 10000,
+ 'angrygiraffe': 38000,
+}
+
+
+def _scaled_sizes(total: Optional[int]) -> Dict[str, int]:
+ if total is None:
+ return dict(_BASE_SIZES)
+ scale = total / sum(_BASE_SIZES.values())
+ return {k: max(1, int(round(v * scale))) for k, v in _BASE_SIZES.items()}
+
+
+def _build_dataset(total: Optional[int] = None, load_from_cache_file: bool = True) -> Dataset:
+ sizes = _scaled_sizes(total)
+ dataset = Dataset()
+
+ _register(dataset, CodeXThinkingProcessor,
+ DatasetMeta(dataset_id=CODEX_THINKING_REPO, split='train',
+ data_slice=range(sizes['codex_think'])),
+ load_from_cache_file=load_from_cache_file)
+
+ _register(dataset, OpenThoughtsProcessor,
+ DatasetMeta(dataset_id=OPEN_THOUGHTS_REPO, split='train',
+ data_slice=range(sizes['open_thoughts'])),
+ load_from_cache_file=load_from_cache_file)
+
+ _register(dataset, LIMOProcessor,
+ DatasetMeta(dataset_id=LIMO_REPO, split='train'),
+ load_from_cache_file=load_from_cache_file)
+
+ _register(dataset, ChineseR1DistillProcessor,
+ DatasetMeta(dataset_id=CN_R1_DISTILL_REPO, split='train',
+ data_slice=range(sizes['cn_r1_distill'])),
+ load_from_cache_file=load_from_cache_file)
+
+ _register(dataset, OpusReasoningProcessor,
+ DatasetMeta(dataset_id=OPUS_REASONING_REPO, split='train',
+ data_slice=range(sizes['opus_reasoning'])),
+ load_from_cache_file=load_from_cache_file)
+
+ _register(dataset, ClaudeOpusProcessor,
+ DatasetMeta(dataset_id=CLAUDE_OPUS_REPO, split='train',
+ data_slice=range(sizes['claude_opus'])),
+ load_from_cache_file=load_from_cache_file)
+
+ _register(dataset, AngrygiraffeOpusReasoningProcessor,
+ DatasetMeta(dataset_id=ANGRYGIRAFFE_REPO, split='train',
+ data_slice=range(sizes['angrygiraffe'])),
+ load_from_cache_file=load_from_cache_file)
+
+ dataset.mix_dataset(False)
+ return dataset
+
+
+class ToMessagesProcessor(Preprocessor):
+ """Convert {query, cot, response} → {id, source, messages}."""
+
+ def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
+ rows = self.map_col_to_row(rows)
+ out: List[Dict[str, Any]] = []
+ for row in rows:
+ query = row.get('query') or ''
+ cot = row.get('cot') or ''
+ response = row.get('response') or ''
+ if not cot:
+ continue
+ assistant_content = f'{cot}'
+ out.append({
+ 'id': row.get('id', ''),
+ 'source': row.get('source', ''),
+ 'messages': [
+ {'role': 'user', 'content': query},
+ {'role': 'assistant', 'content': assistant_content,
+ 'reasoning_content': cot},
+ ],
+ })
+ return self.map_row_to_col(out, keys=['id', 'source', 'messages'])
+
+
+def get_dataset(total: Optional[int] = None, dropped_log: Optional[str] = None,
+ load_from_cache_file: bool = True) -> Dataset:
+ """Build, convert to messages format, and quality-filter the CoT dataset.
+
+ If ``total`` is given, every per-source row count in ``_BASE_SIZES`` is
+ scaled proportionally so the input-row sum approximates ``total``.
+ """
+ from twinkle_agentic.preprocessor import (
+ DeadLoopFilter,
+ FixUnicodeFilter,
+ HardFilter,
+ IntentClassifier,
+ MessageSanityFilter,
+ QualityPreprocessor,
+ RefuseFilter,
+ RemoveRepeatSentencesFilter,
+ TokenNumFilter,
+ TokenSoupFilter,
+ )
+
+ dataset = _build_dataset(total=total, load_from_cache_file=load_from_cache_file)
+ dataset.map(ToMessagesProcessor(), remove_columns=['query', 'cot', 'response'],
+ load_from_cache_file=load_from_cache_file)
+ qp = QualityPreprocessor(
+ pipeline=[
+ HardFilter(),
+ RefuseFilter(),
+ DeadLoopFilter(),
+ TokenSoupFilter(),
+ MessageSanityFilter(min_turns=1, max_msg_chars=200000),
+ FixUnicodeFilter(),
+ RemoveRepeatSentencesFilter(),
+ TokenNumFilter(max_num=32768),
+ ],
+ dropped_log_path=dropped_log or '',
+ )
+ dataset.map(qp, num_proc=32, load_from_cache_file=load_from_cache_file)
+ return dataset
+
+
+if __name__ == '__main__':
+ import os
+ dropped_log = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'dropped.jsonl')
+ if os.path.exists(dropped_log):
+ os.remove(dropped_log)
+ dataset = get_dataset(load_from_cache_file=False)
+ print(len(dataset))
diff --git a/cookbook/exp/embedding/train_embedding_full_ddp.py b/cookbook/exp/embedding/train_embedding_full_ddp.py
new file mode 100644
index 000000000..f5a7bc53a
--- /dev/null
+++ b/cookbook/exp/embedding/train_embedding_full_ddp.py
@@ -0,0 +1,743 @@
+"""LoRA embedding training with online condenser self-improvement.
+
+Architecture (8 GPUs total):
+ - Ranks 0-3 (``model``): Trainable embedding model with LoRA, InfoNCE loss.
+ - Ranks 4-5 (``condenser_sampler``): Frozen vLLM condenser for online compression.
+ - Ranks 6-7 (``condenser_model``): Trainable condenser with LoRA for self-improvement.
+
+When the condenser sampler truncates (stop_reason='length'), an external OpenAI-
+compatible API produces the correct compression. The failure is logged as SFT
+training data. A background thread retrains the condenser on accumulated failures
+mixed with condense_300K, then syncs weights back to the sampler.
+
+Launch:
+ python cookbook/exp/train_embedding_lora_ddp.py
+"""
+import hashlib
+import json
+import os
+import re
+import sys
+import threading
+from concurrent.futures import ThreadPoolExecutor
+from pathlib import Path
+from typing import Any, Dict, List, Literal, Optional
+
+import swanlab
+
+import twinkle
+from twinkle import DeviceGroup, DeviceMesh, get_device_placement, get_logger
+from twinkle.checkpoint_engine import CheckpointEngineManager
+from twinkle.data_format import SamplingParams
+from twinkle.dataloader import DataLoader
+from twinkle.loss import InfonceLoss
+from twinkle.metric import EmbeddingMetric
+from twinkle.model import TransformersModel
+from twinkle.processor import InputProcessor
+from twinkle.sampler import vLLMSampler
+from twinkle.template import Template
+from twinkle.utils.parallel import PosixFileLock
+from twinkle_agentic.protocol.openai import OpenAI as OpenAIClient
+
+sys.path.insert(0, str(Path(__file__).resolve().parent))
+from cookbook.exp.embedding.dataset_think import get_dataset # noqa: E402
+
+logger = get_logger()
+
+# -- Backend selection --------------------------------------------------------
+BACKEND: Literal['transformers', 'megatron'] = 'transformers'
+
+# Condenser (online compression + LoRA self-improvement); embedding model trains LoRA on top of MODEL_ID.
+CONDENSE_MODEL_ID = os.environ.get('CONDENSE_MODEL_ID', 'ms://twinkle-kit/Qwen3.5-4B-CM-v2')
+MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen3.5-4B')
+TEMPLATE_NAME = 'Qwen3_5Template'
+
+# -- GPU placement (8 total) --------------------------------------------------
+MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 4))
+CONDENSER_SAMPLER_GPUS = int(os.environ.get('CONDENSER_SAMPLER_GPUS', 2))
+CONDENSER_MODEL_GPUS = int(os.environ.get('CONDENSER_MODEL_GPUS', 2))
+NUM_GPUS = MODEL_GPUS + CONDENSER_SAMPLER_GPUS + CONDENSER_MODEL_GPUS
+
+# -- Embedding training hyper-params ------------------------------------------
+EMB_MAX_LENGTH = 8192
+HARD_NEGATIVES = None
+TEMPERATURE = 0.03
+
+BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 32))
+LEARNING_RATE = 1.5e-6
+GRADIENT_ACCUMULATION_STEPS = 1
+LOG_INTERVAL = 2
+SAVE_INTERVAL = 4000
+NUM_EPOCHS = 2
+
+TOTAL_SAMPLES: Optional[int] = None
+
+# -- Resume from checkpoint ---------------------------------------------------
+RESUME_CHECKPOINT = os.environ.get(
+ 'RESUME_CHECKPOINT',
+ './output/embedding_lora_transformers/step_16000')
+RESUME_STEP = int(os.environ.get('RESUME_STEP', 16000))
+
+# -- Online-compression knobs -------------------------------------------------
+# Below this length, condenser fabricates content for open-ended short prompts;
+# query passes through as qr verbatim and cot rows are dropped from training.
+MIN_TEXT_CHARS = 256
+DATASET_MAX_TOKENS = 32768
+COMPRESS_TEMPERATURE = 0.2
+COMPRESS_TOP_P = 0.5
+COMPRESS_MAX_MODEL_LEN = 32768
+
+# -- OpenAI API fallback for truncated compressions ---------------------------
+COMPRESS_API_KEY = os.environ.get('COMPRESS_API_KEY', '')
+COMPRESS_BASE_URL = os.environ.get('COMPRESS_BASE_URL', 'https://dashscope.aliyuncs.com/compatible-mode/v1')
+COMPRESS_MODEL = os.environ.get('COMPRESS_MODEL', 'qwen3.7-max')
+
+# -- Condenser retraining knobs -----------------------------------------------
+CONDENSER_DATASET_ID = 'ms://twinkle-kit/condense_300K'
+CONDENSER_RETRAIN_SAMPLES = 128
+CONDENSER_RETRAIN_EPOCHS = 3
+CONDENSER_RETRAIN_LR = 1e-5
+
+# -- Output paths -------------------------------------------------------------
+OUTPUT_DIR = f'./output/embedding_lora_{BACKEND}'
+RESPONSE_LOG = os.environ.get('RESPONSE_LOG', f'./output/embedding_lora_{BACKEND}/responses.jsonl')
+FAILURE_LOG = os.environ.get('FAILURE_LOG', f'./output/embedding_lora_{BACKEND}/failures.jsonl')
+
+
+# =============================================================================
+# Prompts (from make_condenser_dataset.py — "## Summary" format)
+# =============================================================================
+
+COMPRESS_SYSTEM = """\
+You are a compression and summary assistant. For the (query, source) pair, emit a Markdown \
+answer with TWO sections, designed to pair with the `extract_compressed` tool: \
+the reader absorbs `## Summary` directly, then calls `extract_compressed` \
+on any topic-key listed under `## More` to recover its \
+fuller content.
+
+ `## Summary` — extreme-density text the reader reads directly.
+ `## More` — a topic index whose keys are valid arguments \
+to `extract_compressed` for recovering material not captured inline.
+
+Together the two sections must form a COMPLETE, NON-DISTORTING inventory of the \
+source for the query — nothing essential lost, nothing implied that the source \
+does not support. NO preamble, NO meta-commentary, NO code fences wrapping the \
+whole output.
+
+Output skeleton:
+
+## Summary
+Topic:
+
+
+## More
+- :
+- ...
+
+Format selection for the inline body (pick the MOST COMPACT form per query, mix \
+when helpful):
+- Interface / signature → code notation directly: `func(a:int)->str`
+- Factual / entity → telegraphic prose; drop function words; ":" for "is", "," \
+for "has"
+- Skill / how-to / usage → lead with `Use when: `; numbered telegraphic \
+steps `1.do X 2.then Y`; close with `Output: ` when relevant
+- Procedural → numbered short steps
+- Analytical / design → hierarchical bullets with abbreviations
+
+`## Summary` rules:
+1. TOPIC LINE — line 1 is ALWAYS `Topic: `, even when the \
+query is narrow. Anchors both the reader and the tool.
+2. DENSITY — every token in the body carries query-relevant signal; cut filler.
+3. PRIMARY-COMPLETE — never silently drop a fact essential to answering the \
+query. Anything cut for length MUST appear as a key under \
+`## More`.
+4. NON-MISLEADING — phrasing must not let the reader infer anything the source \
+does not support; partial truths that mislead are worse than honest omissions \
+flagged in the index.
+5. SELF-CONTAINED — the reader can act on the answer without re-opening the source.
+6. FAITHFUL — only content the source supports; no fabrication, no extrapolation.
+7. LANGUAGE — match the source language.
+8. NO outer code fences around the whole answer; no meta-commentary.
+
+`## More` rules (MANDATORY — this section is never omitted):
+1. FORMAT — each bullet is `- : `:
+ • topic-key — short, unambiguous, grounded in source vocabulary so the \
+`extract_compressed` tool can locate the aspect (e.g. `decorators`, \
+`error handling`, `pitfalls`).
+ • hint — tells WHAT the reader gains by expanding (concrete numbers, code \
+listings, secondary cases, edge details, related context, …); do NOT restate \
+the inline answer.
+2. CRITERION — each bullet names an aspect that EXISTS in the source but is \
+NOT fully captured inline. Material that genuinely fits inline without \
+distortion MUST NOT be duplicated here.
+3. FAITHFUL — hints must be grounded in the source; never speculate or invent.
+4. ORDER — by relevance to the query, then by importance.
+5. EMPTY CASE — if the source is so short / single-purpose that everything \
+fits inline, write a single line `- (none)`.
+
+Now begin.\
+"""
+
+COMPRESS_USER = (
+ 'Downstream model will read your compressed block to decide whether to '
+ 'expand it. Compress faithfully: preserve the passage topic + core facts. '
+ 'Do NOT invent facts. Do NOT drop major facts. Do NOT write meta-commentary '
+ 'about the Query (never write "Query info: absent", "no X mention", etc.); '
+ 'if the passage does not address the Query, still summarize the passage. '
+ 'CRITICAL LANGUAGE RULE: detect the dominant language of the Passage '
+ '(NOT the Query, NOT this instruction) and write the ENTIRE output in that '
+ 'same language; English passage → English output, Chinese passage → '
+ 'Chinese output, Japanese passage → Japanese output. NEVER translate, '
+ 'NEVER mix languages, NEVER copy these instructions into the output.\n\n'
+ '## Query (ordering hint only — still summarize the whole passage)\n{query}\n\n'
+ '## Passage\n{text}')
+
+
+# =============================================================================
+# Logging helpers
+# =============================================================================
+
+_response_lock: Optional[PosixFileLock] = None
+_failure_lock: Optional[PosixFileLock] = None
+
+# Monotonic global sample id; per-batch index would alias across batches.
+_sample_counter = 0
+_sample_counter_lock = threading.Lock()
+
+
+def _next_sample_id() -> int:
+ global _sample_counter
+ with _sample_counter_lock:
+ sid = _sample_counter
+ _sample_counter += 1
+ return sid
+
+
+def _log_responses(query_resp_text: str, cot_resp_text: str, idx: int,
+ query_raw: str = '', cot_raw: str = ''):
+ global _response_lock
+ if _response_lock is None:
+ os.makedirs(os.path.dirname(RESPONSE_LOG) or '.', exist_ok=True)
+ _response_lock = PosixFileLock(RESPONSE_LOG + '.lock')
+
+ record = {
+ 'idx': idx,
+ 'query_raw': query_raw,
+ 'cot_raw': cot_raw,
+ 'query_compressed': query_resp_text,
+ 'cot_compressed': cot_resp_text,
+ }
+ line = json.dumps(record, ensure_ascii=False, default=str) + '\n'
+ with _response_lock:
+ with open(RESPONSE_LOG, 'a', encoding='utf-8') as f:
+ f.write(line)
+
+
+def _log_failure(source_text: str, query: str, compressed: str, batch_idx: int):
+ global _failure_lock
+ if _failure_lock is None:
+ os.makedirs(os.path.dirname(FAILURE_LOG) or '.', exist_ok=True)
+ _failure_lock = PosixFileLock(FAILURE_LOG + '.lock')
+
+ qhash = hashlib.md5(query.strip().encode('utf-8')).hexdigest()[:8]
+ record = {
+ 'id': f'{batch_idx}__{qhash}',
+ 'source': 'online_failure',
+ 'query': query,
+ 'original_len': len(source_text),
+ 'compressed_len': len(compressed),
+ 'messages': [
+ {'role': 'system', 'content': COMPRESS_SYSTEM},
+ {'role': 'user', 'content': COMPRESS_USER.format(query=query, text=source_text)},
+ {'role': 'assistant', 'content': compressed},
+ ],
+ }
+ line = json.dumps(record, ensure_ascii=False, default=str) + '\n'
+ with _failure_lock:
+ with open(FAILURE_LOG, 'a', encoding='utf-8') as f:
+ f.write(line)
+
+
+# =============================================================================
+# Model builders
+# =============================================================================
+
+def build_model(device_mesh: DeviceMesh):
+ model_id = RESUME_CHECKPOINT if RESUME_CHECKPOINT else MODEL_ID
+ if BACKEND == 'transformers':
+ model = TransformersModel(
+ model_id=model_id,
+ device_mesh=device_mesh,
+ remote_group='model',
+ ddp_config={'find_unused_parameters': True},
+ )
+ from twinkle.patch.no_split_modules import NoSplitModulesPatch
+ model.apply_patch(NoSplitModulesPatch({'Qwen3_5DecoderLayer'}))
+ return model
+ if BACKEND == 'megatron':
+ from twinkle.model import MegatronModel
+ return MegatronModel(
+ model_id=MODEL_ID,
+ device_mesh=device_mesh,
+ remote_group='model',
+ mixed_precision='bf16',
+ variable_seq_lengths=True,
+ )
+ raise ValueError(f'Unknown BACKEND={BACKEND!r}')
+
+
+def setup_optimizer(model, total_steps: int):
+ if BACKEND == 'transformers':
+ model.set_optimizer(optimizer_cls='AdamW', lr=LEARNING_RATE)
+ model.set_lr_scheduler(
+ scheduler_cls='CosineWarmupScheduler',
+ num_warmup_steps=200,
+ num_training_steps=total_steps,
+ )
+ return
+ if BACKEND == 'megatron':
+ model.set_optimizer(optimizer_cls='default', lr=LEARNING_RATE)
+ model.set_lr_scheduler(
+ scheduler_cls='default',
+ lr_warmup_steps=50,
+ lr_decay_steps=total_steps,
+ )
+ return
+ raise ValueError(f'Unknown BACKEND={BACKEND!r}')
+
+
+def save_checkpoint(model, name: str):
+ model.save(name, output_dir=OUTPUT_DIR)
+
+
+# =============================================================================
+# Compression prompt building
+# =============================================================================
+
+EMBED_QUERY_Q = (
+ 'What problem does this passage address, and what skill or method is needed? '
+ 'Topic must name the specific pattern, never generic labels. '
+ 'Compress into a retrieval-friendly need description.')
+EMBED_QUERY_COT = (
+ 'Extract the reusable skill: trigger conditions, key steps, and expected output. '
+ 'Topic names the method/pattern; format as "Use when: ...", numbered steps, '
+ '"Output: ...". Compress into a standardized procedure for retrieval.')
+
+
+def _extract_query_cot(row: Dict[str, Any]):
+ messages = row.get('messages') or []
+ query, cot = '', ''
+ for m in messages:
+ if not isinstance(m, dict):
+ continue
+ role = m.get('role') or ''
+ if role == 'user' and not query:
+ query = (m.get('content') or '').strip()
+ elif role == 'assistant':
+ cot = (m.get('reasoning_content') or '').strip()
+ break
+ return query, cot
+
+
+def _build_compress_prompts(rows: List[Dict[str, Any]]) -> tuple:
+ """Build prompts for compressing both query and cot per row.
+
+ Returns (prompts, valid_indices, raw_pairs, prompt_queries, passthrough) where:
+ - prompts: flat-interleaved [query_0, cot_0, query_1, cot_1, ...]; ``None`` means
+ passthrough (use raw text directly, do not call sampler)
+ - valid_indices: which rows passed the min-length filter
+ - raw_pairs: [(query, cot), ...]
+ - prompt_queries: the query string used for each prompt (for failure logging)
+ - passthrough: parallel to prompts; non-None text means "use this verbatim as qc"
+ """
+ prompts: List[Optional[Dict[str, Any]]] = []
+ valid_indices: List[int] = []
+ raw_pairs: List[tuple] = []
+ prompt_queries: List[str] = []
+ passthrough: List[Optional[str]] = []
+ for i, row in enumerate(rows):
+ query, cot = _extract_query_cot(row)
+ if not query or len(cot) < MIN_TEXT_CHARS:
+ continue
+ valid_indices.append(i)
+ raw_pairs.append((query, cot))
+ # Short query bypasses condenser to avoid skeleton-induced hallucination.
+ if len(query) < MIN_TEXT_CHARS:
+ prompts.append(None)
+ passthrough.append(query)
+ else:
+ user = COMPRESS_USER.format(query=EMBED_QUERY_Q, text=query)
+ prompts.append({'messages': [
+ {'role': 'system', 'content': COMPRESS_SYSTEM},
+ {'role': 'user', 'content': user},
+ ]})
+ passthrough.append(None)
+ prompt_queries.append(EMBED_QUERY_Q)
+ user = COMPRESS_USER.format(query=EMBED_QUERY_COT, text=cot)
+ prompts.append({'messages': [
+ {'role': 'system', 'content': COMPRESS_SYSTEM},
+ {'role': 'user', 'content': user},
+ ]})
+ prompt_queries.append(EMBED_QUERY_COT)
+ passthrough.append(None)
+ return prompts, valid_indices, raw_pairs, prompt_queries, passthrough
+
+
+def _get_first_feature(decoded_text: str, template: Template, role: str) -> Optional[Dict[str, Any]]:
+ if not decoded_text:
+ return None
+ if role == 'anchor':
+ feat = template.encode({'messages': [
+ {'role': 'user', 'content': decoded_text},
+ {'role': 'assistant', 'content': 'Match the correct response here.'},
+ ]})
+ feat['labels'] = [1]
+ else:
+ feat = template.encode({'messages': [
+ {'role': 'user', 'content': 'Match the correct query here.'},
+ {'role': 'assistant', 'content': decoded_text},
+ ]})
+ feat['labels'] = [0]
+ return feat
+
+
+# =============================================================================
+# OpenAI API fallback
+# =============================================================================
+
+def _is_truncated_compression(text: str) -> bool:
+ """Detect structurally incomplete output that vLLM may report as stop_reason='stop'.
+
+ The condenser sometimes emits a chat-template token mid-skeleton (which we then
+ strip), so the visible text ends mid-sentence even though stop_reason!='length'.
+ The COMPRESS_SYSTEM skeleton mandates a `## More` section ending in a bullet list;
+ its absence is an unambiguous truncation signal.
+ """
+ if not text or not text.strip():
+ return True
+ if '## More' not in text or '## Summary' not in text:
+ return True
+ after_more = text.split('## More', 1)[1].strip()
+ if not after_more:
+ return True
+ last_line = after_more.splitlines()[-1].strip()
+ if not (last_line.startswith('-') or last_line.endswith(')')):
+ return True
+ return False
+
+
+def _api_compress(api_client: OpenAIClient, prompt: Dict[str, Any]) -> Optional[str]:
+ """Call external API to compress when vLLM truncates."""
+ trajectory = {'messages': prompt['messages']}
+ # Cap max_tokens to leave ample prompt headroom inside the API model context.
+ sp = SamplingParams(temperature=0.2, max_tokens=8192)
+ try:
+ reply = api_client(trajectory, sp, extra_body={'enable_thinking': False})
+ except Exception as exc:
+ logger.warning(f'[api_fallback] error: {exc}')
+ return None
+ content = (reply.get('content') or '').strip()
+ if not content:
+ return None
+ # Strip outer code fence if present
+ m = re.match(r'^```[a-zA-Z]*\n(.*?)\n```\s*$', content, re.DOTALL)
+ if m:
+ content = m.group(1).strip()
+ return content
+
+
+# =============================================================================
+# Condenser Retrainer (background thread)
+# =============================================================================
+
+class CondenserRetrainer:
+ """Async condenser self-improvement: retrains from failures, syncs to sampler."""
+
+ def __init__(self, condenser_model, ckpt_manager: CheckpointEngineManager,
+ condenser_sampler):
+ self._model = condenser_model
+ self._ckpt_manager = ckpt_manager
+ self._sampler = condenser_sampler
+ self._signal = threading.Event()
+ self._stop = threading.Event()
+ self._thread = threading.Thread(target=self._loop, daemon=True)
+ self._condense_300k_cache = None
+ self._retrain_count = 0
+ # Prevents sample() and sync_weights() from running concurrently
+ self.sampler_lock = threading.Lock()
+
+ def start(self):
+ self._thread.start()
+
+ def stop(self):
+ self._stop.set()
+ self._signal.set()
+ self._thread.join(timeout=10)
+
+ def notify_failure(self):
+ self._signal.set()
+
+ def _loop(self):
+ while not self._stop.is_set():
+ self._signal.wait(timeout=60)
+ if self._stop.is_set():
+ break
+ if not self._signal.is_set():
+ continue
+ self._signal.clear()
+ try:
+ self._retrain_and_sync()
+ except Exception as exc:
+ logger.error(f'[condenser_retrain] crashed: {exc}')
+
+ def _retrain_and_sync(self):
+ # Retrain + sync temporarily disabled; failures.jsonl is written directly by _log_failure.
+ pass
+
+
+# =============================================================================
+# Main training
+# =============================================================================
+
+def train():
+ # -------- Device groups (3 groups) ----------------------------------------
+ device_groups = [
+ DeviceGroup(name='model',
+ ranks=list(range(MODEL_GPUS)),
+ device_type='GPU'),
+ DeviceGroup(name='condenser_sampler',
+ ranks=list(range(MODEL_GPUS, MODEL_GPUS + CONDENSER_SAMPLER_GPUS)),
+ device_type='GPU'),
+ DeviceGroup(name='condenser_model',
+ ranks=list(range(MODEL_GPUS + CONDENSER_SAMPLER_GPUS, NUM_GPUS)),
+ device_type='GPU'),
+ ]
+ model_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=MODEL_GPUS)
+ condenser_sampler_mesh = DeviceMesh.from_sizes(
+ world_size=CONDENSER_SAMPLER_GPUS, dp_size=CONDENSER_SAMPLER_GPUS)
+ condenser_model_mesh = DeviceMesh.from_sizes(
+ world_size=CONDENSER_MODEL_GPUS, dp_size=1, fsdp_size=CONDENSER_MODEL_GPUS)
+
+ twinkle.initialize(mode='ray', nproc_per_node=NUM_GPUS, groups=device_groups)
+
+ # -------- Data -----------------------------------------------------------
+ dataset = get_dataset(total=TOTAL_SAMPLES, load_from_cache_file=True)
+ dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=True)
+ total_forward_steps = len(dataloader) * NUM_EPOCHS
+ optimizer_steps = total_forward_steps // GRADIENT_ACCUMULATION_STEPS
+
+ # -------- Embedding model (4 GPU) ----------------------------------------
+ model = build_model(model_mesh)
+ model.set_processor(InputProcessor)
+ model.set_loss(InfonceLoss, temperature=TEMPERATURE, use_batch=True,
+ hard_negatives=HARD_NEGATIVES)
+ setup_optimizer(model, optimizer_steps)
+ model.add_metric(EmbeddingMetric, is_training=True)
+
+ # -------- Condenser sampler (2 GPU, vLLM) --------------------------------
+ emb_template = Template(model_id=MODEL_ID, max_length=EMB_MAX_LENGTH, enable_thinking=False)
+ # Special tokens come from the condenser tokenizer because the leak we strip is in its decoded output.
+ condenser_template = Template(model_id=CONDENSE_MODEL_ID, max_length=DATASET_MAX_TOKENS,
+ enable_thinking=False)
+ _special_tokens = set(condenser_template.processor.all_special_tokens)
+ condenser_sampler = vLLMSampler(
+ model_id=CONDENSE_MODEL_ID,
+ engine_args={
+ 'gpu_memory_utilization': 0.8,
+ 'max_model_len': COMPRESS_MAX_MODEL_LEN,
+ },
+ device_mesh=condenser_sampler_mesh,
+ remote_group='condenser_sampler',
+ )
+ condenser_sampler.set_template(
+ TEMPLATE_NAME, model_id=CONDENSE_MODEL_ID, enable_thinking=False,
+ truncation_strategy='delete', max_length=DATASET_MAX_TOKENS)
+ compress_params = SamplingParams(
+ max_tokens=8192,
+ temperature=COMPRESS_TEMPERATURE,
+ top_p=COMPRESS_TOP_P,
+ num_samples=1,
+ )
+
+ # -------- Condenser model (2 GPU, trainable full-param) -------------------
+ condenser_model = TransformersModel(
+ model_id=CONDENSE_MODEL_ID,
+ device_mesh=condenser_model_mesh,
+ remote_group='condenser_model',
+ )
+ condenser_model.set_optimizer(optimizer_cls='AdamW', lr=CONDENSER_RETRAIN_LR)
+
+ # -------- CheckpointEngineManager: condenser_model → condenser_sampler ---
+ condenser_ckpt_manager = CheckpointEngineManager(
+ model=condenser_model, sampler=condenser_sampler)
+ condenser_ckpt_manager.sync_weights()
+
+ # -------- Background retrainer -------------------------------------------
+ retrainer = CondenserRetrainer(condenser_model, condenser_ckpt_manager,
+ condenser_sampler)
+ retrainer.start()
+
+ # -------- OpenAI API client for fallback ---------------------------------
+ api_client = OpenAIClient(
+ model=COMPRESS_MODEL,
+ api_key=COMPRESS_API_KEY,
+ base_url=COMPRESS_BASE_URL,
+ )
+
+ logger.info(get_device_placement())
+ logger.info(model.get_train_configs())
+ logger.info(f'Total forward steps: {total_forward_steps}, optimizer steps: {optimizer_steps}')
+ if RESUME_STEP > 0:
+ logger.info(f'Resuming from step {RESUME_STEP}, checkpoint: {RESUME_CHECKPOINT}')
+ logger.info(f'Starting at epoch {RESUME_STEP // (total_forward_steps // NUM_EPOCHS)}, '
+ f'skipping {RESUME_STEP - (RESUME_STEP // (total_forward_steps // NUM_EPOCHS)) * (total_forward_steps // NUM_EPOCHS)} batches')
+
+ swanlab.init(project='twinkle', config={
+ 'backend': BACKEND,
+ 'model_id': MODEL_ID,
+ 'condense_model_id': CONDENSE_MODEL_ID,
+ 'batch_size': BATCH_SIZE,
+ 'lr': LEARNING_RATE,
+ 'temperature': TEMPERATURE,
+ 'emb_max_length': EMB_MAX_LENGTH,
+ 'DATASET_MAX_TOKENS': DATASET_MAX_TOKENS,
+ })
+
+ # -------- Train loop -----------------------------------------------------
+ def _sample_batch(raw_batch):
+ """Compress via vLLM sampler; fall back to API on truncation."""
+ compress_prompts, valid_indices, raw_pairs, prompt_queries, passthrough = \
+ _build_compress_prompts(raw_batch)
+ if not compress_prompts:
+ return None
+
+ # Only submit non-passthrough prompts to the sampler.
+ sampler_input = [p for p in compress_prompts if p is not None]
+ sampler_pos = [ri for ri, p in enumerate(compress_prompts) if p is not None]
+ if sampler_input:
+ with retrainer.sampler_lock:
+ sampler_responses = condenser_sampler.sample(sampler_input, compress_params)
+ else:
+ sampler_responses = []
+ responses = [None] * len(compress_prompts)
+ for resp, pos in zip(sampler_responses, sampler_pos):
+ responses[pos] = resp
+
+ # Extract decoded texts; detect truncations and fall back to API
+ decoded_texts: List[str] = []
+ for ri in range(len(compress_prompts)):
+ if passthrough[ri] is not None:
+ decoded_texts.append(passthrough[ri])
+ continue
+ resp = responses[ri]
+ seq = resp.sequences[0] if resp and resp.sequences else None
+ text = ''
+ if seq and seq.stop_reason != 'length' and seq.decoded:
+ text = seq.decoded
+ for tok in _special_tokens:
+ text = text.replace(tok, '')
+ text = text.rstrip()
+
+ # Premature-EOS: model emits chat-template token mid-skeleton, vLLM reports
+ # stop_reason='stop' but the stripped text is structurally incomplete.
+ needs_fallback = (not seq or seq.stop_reason == 'length'
+ or _is_truncated_compression(text))
+ if not needs_fallback:
+ decoded_texts.append(text)
+ continue
+
+ api_result = _api_compress(api_client, compress_prompts[ri])
+ # Skip logging when the API itself produced truncated output: an incomplete
+ # gold answer would teach the condenser to imitate broken outputs.
+ if api_result and not _is_truncated_compression(api_result):
+ decoded_texts.append(api_result)
+ pair_idx = ri // 2
+ q_raw, c_raw = raw_pairs[pair_idx]
+ source_text = q_raw if ri % 2 == 0 else c_raw
+ _log_failure(source_text, prompt_queries[ri], api_result,
+ valid_indices[pair_idx])
+ retrainer.notify_failure()
+ else:
+ decoded_texts.append('')
+
+ # Build embedding features from decoded texts
+ emb_features: List[Dict[str, Any]] = []
+ for i in range(0, len(decoded_texts), 2):
+ q_text = decoded_texts[i]
+ c_text = decoded_texts[i + 1]
+ q_raw, c_raw = raw_pairs[i // 2]
+ _log_responses(q_text, c_text, _next_sample_id(),
+ query_raw=q_raw, cot_raw=c_raw)
+ feat_q = _get_first_feature(q_text, emb_template, role='anchor')
+ feat_c = _get_first_feature(c_text, emb_template, role='positive')
+ if feat_q and feat_c:
+ emb_features.append(feat_q)
+ emb_features.append(feat_c)
+
+ if len(emb_features) < 4:
+ return None
+ return emb_features
+
+ cur_step = RESUME_STEP
+ # Compute which epoch and how many batches to skip within that epoch
+ _batches_per_epoch = len(dataloader)
+ _start_epoch = cur_step // _batches_per_epoch if cur_step > 0 else 0
+ _skip_batches_in_epoch = cur_step - _start_epoch * _batches_per_epoch if cur_step > 0 else 0
+
+ prefetch_executor = ThreadPoolExecutor(max_workers=1)
+ for epoch in range(_start_epoch, NUM_EPOCHS):
+ # Skip consumed samples for the resume epoch (shuffle order won't match
+ # exactly, but the correct number of samples is skipped).
+ if _skip_batches_in_epoch > 0:
+ dataloader.skip_consumed_samples(_skip_batches_in_epoch * BATCH_SIZE)
+ batch_iter = iter(dataloader)
+ # Reset skip after first resumed epoch
+ _skip_batches_in_epoch = 0
+ prefetch_future = None
+ first_batch = next(batch_iter, None)
+ if first_batch is not None:
+ prefetch_future = prefetch_executor.submit(_sample_batch, first_batch)
+
+ for raw_batch in batch_iter:
+ emb_features = prefetch_future.result() if prefetch_future else None
+ prefetch_future = prefetch_executor.submit(_sample_batch, raw_batch)
+
+ if emb_features is None:
+ continue
+
+ model.forward_backward(inputs=emb_features, task='embedding')
+ model.clip_grad_and_step(gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS)
+ cur_step += 1
+
+ if cur_step % LOG_INTERVAL == 0:
+ metric = model.calculate_metric(is_training=True)
+ logger.info(
+ f'Epoch {epoch} Step {cur_step}/{total_forward_steps}, metric: {metric}')
+ log_dict = {}
+ for k, v in metric.items():
+ if not v:
+ continue
+ try:
+ log_dict[k] = float(v)
+ except (ValueError, TypeError):
+ pass
+ log_dict['epoch'] = epoch
+ swanlab.log(log_dict, step=cur_step)
+ if cur_step % SAVE_INTERVAL == 0:
+ save_checkpoint(model, f'step_{cur_step}')
+
+ # # Drain last prefetched batch
+ # if prefetch_future is not None:
+ # emb_features = prefetch_future.result()
+ # if emb_features is not None:
+ # model.forward_backward(inputs=emb_features, task='embedding')
+ # model.clip_grad_and_step()
+ # cur_step += 1
+
+ prefetch_executor.shutdown(wait=False)
+ retrainer.stop()
+ save_checkpoint(model, 'last-checkpoint')
+
+
+if __name__ == '__main__':
+ train()
diff --git a/cookbook/exp/legacy/grpo_baseline.py b/cookbook/exp/legacy/grpo_baseline.py
new file mode 100644
index 000000000..237f9b065
--- /dev/null
+++ b/cookbook/exp/legacy/grpo_baseline.py
@@ -0,0 +1,593 @@
+"""HotpotQA GRPO baseline — full context, no chunking, no compression, no tools.
+
+This is the **control group** for ``grpo_condensed.py``. Both scripts share:
+ * dataset (HotpotQA fullwiki, hard split)
+ * preprocessing (``HotpotQAProcessor`` with ``[K] Title: ...`` passages)
+ * GRPO infra (model / sampler / device mesh / hyperparams)
+ * rollout class (``MultiTurnRollout`` from ``multi_turn.py``)
+
+The only differences are intentional:
+ * no ``NativeChunker`` / ``ModelCondenser`` (full passages go in verbatim)
+ * no tools registered (``ToolManager()`` is empty)
+ * ``max_turns=1`` so the rollout is effectively single-turn
+ * simplified system prompt (no ```` / ``extract_condensed`` syntax)
+ * ``F1Reward + CoTReward`` only (no ``ToolExploreReward``)
+ * traces → ``rollout_trace_baseline.jsonl``
+ * checkpoints prefixed ``hotpotqa-grpo-baseline-*``
+
+Keeping the same ``MultiTurnRollout`` code path on both sides means any
+training-loop-level discrepancy between the two runs is attributable to
+the chunk+condense pipeline, not to differences in rollout plumbing.
+"""
+
+import math
+import os
+import re
+from typing import Any, Dict, List, Optional
+
+import swanlab
+from peft import LoraConfig
+
+import twinkle
+from twinkle import DeviceMesh, DeviceGroup, get_logger
+from twinkle.advantage import GRPOAdvantage
+from twinkle.checkpoint_engine import CheckpointEngineManager
+from twinkle.data_format import Message, SamplingParams, Trajectory
+from twinkle.dataloader import DataLoader
+from twinkle.dataset import Dataset, DatasetMeta
+from twinkle.metric import CompletionRewardMetric
+from twinkle.model import TransformersModel
+from twinkle.preprocessor.base import Preprocessor
+from twinkle.processor import InputProcessor
+from twinkle.sampler import vLLMSampler
+from twinkle.template import Qwen3_5Template
+from twinkle_agentic.reward import F1Reward, CoTReward
+from twinkle_agentic.rollout.multi_turn import MultiTurnRollout
+from twinkle_agentic.tools.tool_manager import ToolManager
+
+logger = get_logger()
+
+MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen3.5-4B')
+USE_MEGATRON = bool(int(os.environ.get('USE_MEGATRON', '1')))
+
+MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 4))
+SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS', 4))
+NUM_GPUS = MODEL_GPUS + SAMPLER_GPUS
+
+NUM_GENERATIONS = int(os.environ.get('NUM_GENERATIONS', 8))
+MAX_NEW_TOKENS = int(os.environ.get('MAX_NEW_TOKENS', 4096))
+LEARNING_RATE = float(os.environ.get('LR', 1e-5))
+NUM_EPOCHS = int(os.environ.get('NUM_EPOCHS', 1))
+MAX_STEPS = int(os.environ.get('MAX_STEPS', 0))
+BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 8))
+MINI_BATCH_SIZE = int(os.environ.get('MINI_BATCH_SIZE', 8))
+MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 2))
+GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 1))
+ADAPTER_NAME = 'default'
+SAVE_STEPS = int(os.environ.get('SAVE_STEPS', 1000))
+LORA_RANK = int(os.environ.get('LORA_RANK', 16))
+
+# Single-turn baseline; tools are not registered, but we keep MultiTurnRollout
+# to share the rollout code path with the condensed variant. ``max_turns=1``
+# guarantees the loop runs exactly one sampling pass per trajectory.
+MAX_TURNS = int(os.environ.get('MAX_TURNS', 1))
+
+HOTPOTQA_NUM_PROC = int(os.environ.get('HOTPOTQA_NUM_PROC', 16))
+HOTPOTQA_MAX_LENGTH = int(os.environ.get('HOTPOTQA_MAX_LENGTH', 64000))
+
+F1_REWARD_WEIGHT = float(os.environ.get('F1_REWARD_WEIGHT', 1.0))
+COT_REWARD_WEIGHT = float(os.environ.get('COT_REWARD_WEIGHT', 0.2))
+
+# KL penalty coefficient; 0 disables KL (and skips the ref forward pass entirely).
+KL_BETA = float(os.environ.get('KL_BETA', 0.02))
+
+# Entropy bonus coefficient; 0 disables entropy compute path.
+ENTROPY_COEF = float(os.environ.get('ENTROPY_COEF', 0.0))
+
+# CISPO token-level IS clamp thresholds (asymmetric: 0.2 / 0.28).
+CISPO_EPS_LOW = float(os.environ.get('CISPO_EPS_LOW', 0.2))
+CISPO_EPS_HIGH = float(os.environ.get('CISPO_EPS_HIGH', 0.2))
+
+# High-KL token capture: top-K per microbatch dumped into log_dict['_high_kl_records']. 0 = disabled.
+HIGH_KL_TOPK = int(os.environ.get('HIGH_KL_TOPK', 0))
+
+DATASET_PATH = os.environ.get(
+ 'DATASET_PATH',
+ os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))),
+ 'hotpotqa_fullwiki_reannotated_12k.jsonl'))
+F1_BINARY_THRESHOLD = float(os.environ.get('F1_BINARY_THRESHOLD', 0.5))
+
+_ROLLOUT_TRACE_DIR = os.environ.get(
+ 'ROLLOUT_TRACE_BASELINE_DIR', 'rollout_trace_baseline')
+
+SYSTEM_PROMPT = """You are a careful multi-hop QA assistant.
+
+You will receive a question and a set of supporting passages. Each passage \
+is shown inline as plain text in the form `[K] Title: ...`, where `K` is the \
+passage index. All passages are already complete — there is no extraction \
+or expansion step.
+
+## Workflow
+
+Step 1: Read every passage and identify which ones are relevant to the question.
+Step 2: Reason step by step, citing the passage indices you used.
+ Step N: From passage [K], I learn that [fact A].
+ Step N+1: From passage [M], I learn that [fact B].
+ Step N+2: Combining these, the answer is ...
+Step 3: Emit the final answer in `\\boxed{...}`.
+
+Only answer when you are confident in the supporting facts.
+
+## Output Format
+End your final response with \\boxed{answer}, e.g. \\boxed{Delhi}.
+Keep the boxed text short: a name, entity, date, or "yes"/"no".
+Answers not inside \\boxed{} will not be scored."""
+
+
+_F1_REWARD: Optional[F1Reward] = F1Reward()
+_COT_REWARD: Optional[CoTReward] = CoTReward()
+
+
+def compute_rewards(trajectories: List[Dict[str, Any]]):
+ f1_raw = _F1_REWARD(trajectories)
+ f1 = [1.0 if v >= F1_BINARY_THRESHOLD else 0.0 for v in f1_raw] if F1_BINARY_THRESHOLD > 0 else f1_raw
+ cot = _COT_REWARD(trajectories)
+ total = [
+ F1_REWARD_WEIGHT * a + COT_REWARD_WEIGHT * c
+ for a, c in zip(f1, cot)
+ ]
+ return total, f1, cot
+
+
+class HotpotQAProcessor(Preprocessor):
+ """Preprocessor for the reannotated HotpotQA JSONL. Passages are emitted
+ as ``[K] Title: ...`` lines. Rows with ``verdict='drop'`` are excluded;
+ ``question_fixed`` is used in place of ``question`` when present."""
+
+ def __init__(self, system: str = SYSTEM_PROMPT):
+ self.system = system
+
+ def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
+ rows = self.map_col_to_row(rows)
+ rows = [self.preprocess(row) for row in rows]
+ rows = [r for r in rows if r is not None]
+ rows = self.map_row_to_col(rows)
+ return rows
+
+ @staticmethod
+ def _format_context(context: Dict[str, Any]) -> str:
+ titles = context.get('title', []) or []
+ sentences = context.get('sentences', []) or []
+ lines = []
+ for i, (title, sents) in enumerate(zip(titles, sentences), start=1):
+ if isinstance(sents, list):
+ body = ' '.join(s.strip() for s in sents if s and s.strip())
+ else:
+ body = str(sents).strip()
+ lines.append(f'[{i}] {title}: {body}')
+ return '\n\n'.join(lines)
+
+ def preprocess(self, row: Dict[str, Any]) -> Optional[Trajectory]:
+ if (row.get('verdict') or '').strip().lower() == 'drop':
+ return None
+ question = row.get('question_fixed') or row['question']
+ answers = row.get('answers')
+ if isinstance(answers, list) and answers:
+ golds = [str(a).strip() for a in answers if str(a).strip()]
+ else:
+ golds = [s for s in [(row.get('answer', '') or '').strip()] if s]
+ context_block = self._format_context(row.get('context', {}) or {})
+ user_msg = f'Question: {question}\n\nContext:\n\n{context_block}'
+ messages = [
+ Message(role='system', content=self.system),
+ Message(role='user', content=user_msg),
+ ]
+ return Trajectory(messages=messages, user_data=[('ground_truth', g) for g in golds])
+
+
+def create_hotpotqa_dataset() -> Dataset:
+ dataset = Dataset()
+ dataset.add_dataset(DatasetMeta(DATASET_PATH))
+ logger.info('[dataset] loaded %s: %d rows', DATASET_PATH, len(dataset))
+
+ dataset.set_template(
+ 'Qwen3_5Template', model_id=MODEL_ID, max_length=HOTPOTQA_MAX_LENGTH,
+ truncation_strategy='delete', enable_thinking=False)
+ _HOTPOTQA_COLS = ['id', 'question', 'question_fixed', 'answers',
+ 'original_answer', 'type', 'level', 'verdict',
+ 'reasoning', 'supporting_facts', 'context']
+ dataset.map(HotpotQAProcessor(system=SYSTEM_PROMPT),
+ remove_columns=_HOTPOTQA_COLS)
+ return dataset
+
+
+# Matches a LaTeX ``\boxed{...}`` final-answer marker — used to flag
+# rollouts that never committed an answer. Brace-balanced is overkill for
+# a logging heuristic; a non-greedy ``[^}]*`` is good enough.
+_BOXED_RE = re.compile(r'\\boxed\{[^}]*\}')
+
+# Pulls the leading number out of pre-formatted metric strings such as
+# ``'0.03 iters/s'`` / ``'1.000000e-05'`` / ``'30 seconds'`` emitted by
+# ``TrainMetric`` and ``GRPOMetric``. We use this in ``_coerce_for_swanlab``
+# so swanlab can build line charts instead of dropping those keys with a
+# ``failed to create chart for key '...': invalid value type`` warning.
+_LEADING_NUMBER_RE = re.compile(r'[-+]?\d*\.?\d+(?:[eE][-+]?\d+)?')
+
+
+def _coerce_for_swanlab(log_dict: Dict[str, Any]) -> Dict[str, Any]:
+ """Cast string-valued metrics to float for swanlab line charts.
+
+ ``TrainMetric.calculate()`` and ``GRPOMetric.calculate()`` return
+ pre-formatted strings (``'0.03 iters/s'``, ``'1.000000e-05'``,
+ ``'30 seconds'``, ``'0.8321'``). swanlab cannot build a line chart
+ from a string value and emits one warning per key per step. We extract
+ the leading number where possible; keys whose value can't be parsed
+ as a scalar are left as-is so they still show up in the text log.
+ """
+ coerced: Dict[str, Any] = {}
+ for k, v in log_dict.items():
+ if isinstance(v, bool) or isinstance(v, (int, float)):
+ coerced[k] = v
+ continue
+ if isinstance(v, str):
+ m = _LEADING_NUMBER_RE.search(v)
+ if m:
+ try:
+ coerced[k] = float(m.group())
+ continue
+ except ValueError:
+ pass
+ coerced[k] = v
+ return coerced
+
+
+def _last_assistant_text(trajectory: Dict[str, Any]) -> Optional[str]:
+ """Return the text of the last ``assistant`` message, or ``None``.
+
+ ``content`` can be ``str`` | ``None`` | ``dict`` (single multimodal
+ part) | ``list[dict]`` (multiple parts). The downstream caller feeds
+ this into ``_BOXED_RE.search(...)``, so we collapse the visible text
+ into a single string and ignore non-text parts (images etc.).
+ """
+ for m in reversed(trajectory.get('messages', [])):
+ if m.get('role') != 'assistant':
+ continue
+ c = m.get('content')
+ if c is None:
+ return None
+ if isinstance(c, str):
+ return c
+ if isinstance(c, dict):
+ return c.get('text') if c.get('type') == 'text' else None
+ if isinstance(c, list):
+ parts = [p.get('text') or '' for p in c
+ if isinstance(p, dict) and p.get('type') == 'text']
+ return '\n'.join(parts) if parts else None
+ return str(c)
+ return None
+
+
+def _compute_rollout_diagnostics(
+ trajectories: List[Dict[str, Any]],
+ n_turns_per_rollout: List[int],
+ per_rollout_completion_length: List[int],
+ f1_rewards: Optional[List[float]] = None,
+ old_logps: Optional[List[List[float]]] = None,
+) -> Dict[str, float]:
+ """Aggregate rollout diagnostics for swanlab logging.
+
+ Stripped-down version of the condensed variant's diagnostics — without
+ chunking we only care about (a) the longest non-trainable prefix
+ (system prompt + full passages), and (b) whether the rollout produced
+ a `\\boxed{}` final answer at all. ``avg_turns`` is logged for symmetry
+ even though it should be exactly 1.0 with ``MAX_TURNS=1``.
+ """
+ out: Dict[str, float] = {}
+ if n_turns_per_rollout:
+ out['avg_turns'] = sum(n_turns_per_rollout) / len(n_turns_per_rollout)
+
+ _max_non_trainable = 0
+ for t, comp_len in zip(trajectories, per_rollout_completion_length):
+ ids = t.get('input_ids') or []
+ non_trainable = max(0, len(ids) - int(comp_len or 0))
+ if non_trainable > _max_non_trainable:
+ _max_non_trainable = non_trainable
+ out['non_trainable_tokens'] = _max_non_trainable
+
+ if trajectories:
+ n_no_boxed = sum(
+ 0 if _BOXED_RE.search(_last_assistant_text(t) or '') else 1
+ for t in trajectories)
+ out['no_boxed_rate'] = n_no_boxed / len(trajectories)
+
+ def _content_chars(c: Any) -> int:
+ if not c:
+ return 0
+ if isinstance(c, str):
+ return len(c)
+ if isinstance(c, dict):
+ if c.get('type') == 'text':
+ return len(c.get('text') or '')
+ return 0
+ if isinstance(c, list):
+ total = 0
+ for part in c:
+ if isinstance(part, dict) and part.get('type') == 'text':
+ total += len(part.get('text') or '')
+ elif isinstance(part, str):
+ total += len(part)
+ return total
+ # Unknown shape -- fall back to ``str()`` length rather than
+ # crashing, so a template quirk never breaks metric logging.
+ return len(str(c))
+
+ msg_chars_total, prompt_chars, asst_chars = [], [], []
+ for t in trajectories:
+ total_i = prompt_i = asst_i = 0
+ for m in (t.get('messages') or []):
+ role = m.get('role')
+ if role == 'system':
+ continue
+ n = _content_chars(m.get('content'))
+ total_i += n
+ if role in ('user', 'tool'):
+ prompt_i += n
+ elif role == 'assistant':
+ asst_i += n
+ msg_chars_total.append(total_i)
+ prompt_chars.append(prompt_i)
+ asst_chars.append(asst_i)
+ out['avg_chars_total_no_sys'] = sum(msg_chars_total) / len(msg_chars_total)
+ out['avg_chars_prompt_no_sys'] = sum(prompt_chars) / len(prompt_chars)
+ out['avg_chars_assistant'] = sum(asst_chars) / len(asst_chars)
+
+ if f1_rewards is not None and old_logps is not None and f1_rewards:
+ per_traj_mean = [(sum(lp) / len(lp)) if lp else 0.0 for lp in old_logps]
+ pos_logp = [m for m, f1 in zip(per_traj_mean, f1_rewards) if f1 > 0]
+ zero_logp = [m for m, f1 in zip(per_traj_mean, f1_rewards) if f1 <= 0]
+ out['f1_correct_rate'] = len(pos_logp) / len(f1_rewards)
+ out['f1_zero_rate'] = len(zero_logp) / len(f1_rewards)
+ out['mean_old_logp_f1_pos'] = (sum(pos_logp) / len(pos_logp)) if pos_logp else 0.0
+ out['mean_old_logp_f1_zero'] = (sum(zero_logp) / len(zero_logp)) if zero_logp else 0.0
+ out['policy_confidence_f1_pos'] = math.exp(out['mean_old_logp_f1_pos'])
+ out['policy_confidence_f1_zero'] = math.exp(out['mean_old_logp_f1_zero'])
+ return out
+
+
+def main():
+ swanlab.init(project='twinkle')
+
+ device_groups = [
+ DeviceGroup(name='model', ranks=list(range(MODEL_GPUS)), device_type='GPU'),
+ DeviceGroup(name='sampler', ranks=list(range(MODEL_GPUS, NUM_GPUS)), device_type='GPU'),
+ ]
+ model_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=MODEL_GPUS)
+ sampler_mesh = DeviceMesh.from_sizes(world_size=SAMPLER_GPUS, dp_size=SAMPLER_GPUS)
+ twinkle.initialize(mode='ray', nproc_per_node=NUM_GPUS,
+ groups=device_groups, lazy_collect=False)
+
+ logger.info('Building HotpotQA dataset (baseline, full context)')
+ _prebuilt_dataset = create_hotpotqa_dataset()
+ logger.info('Dataset ready: %d rows', len(_prebuilt_dataset))
+
+ GLOBAL_BATCH_SIZE = BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS
+ batches_per_epoch = max(1, len(_prebuilt_dataset) // GLOBAL_BATCH_SIZE)
+ # Single-turn baseline: every rollout produces exactly one assistant
+ # turn, so the per-batch optim-step count equals
+ # ceil(GLOBAL_BATCH_SIZE * NUM_GENERATIONS / MINI_BATCH_SIZE).
+ optim_steps_per_batch = max(1, (GLOBAL_BATCH_SIZE * NUM_GENERATIONS
+ + MINI_BATCH_SIZE - 1) // MINI_BATCH_SIZE)
+ steps_per_epoch = batches_per_epoch * optim_steps_per_batch
+ derived_total_steps = NUM_EPOCHS * steps_per_epoch
+ total_steps = min(MAX_STEPS, derived_total_steps) if MAX_STEPS > 0 else derived_total_steps
+ logger.info('Training horizon: %d steps (%d epochs × %d batches × %d steps/batch)',
+ total_steps, NUM_EPOCHS, batches_per_epoch, optim_steps_per_batch)
+
+ lora_config = LoraConfig(
+ target_modules='all-linear', r=LORA_RANK,
+ lora_alpha=LORA_RANK * 2, lora_dropout=0.05)
+
+ if USE_MEGATRON:
+ from twinkle.model.megatron import MegatronModel
+ model = MegatronModel(
+ model_id=MODEL_ID, device_mesh=model_mesh, remote_group='model',
+ mixed_precision='bf16', variable_seq_lengths=True)
+ else:
+ model = TransformersModel(
+ model_id=MODEL_ID, device_mesh=model_mesh, remote_group='model')
+
+ model.add_adapter_to_model(ADAPTER_NAME, lora_config,
+ gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS)
+ if USE_MEGATRON:
+ model.set_optimizer('default', lr=LEARNING_RATE)
+ model.set_lr_scheduler('default', lr_decay_steps=total_steps, max_lr=LEARNING_RATE)
+ else:
+ model.set_optimizer('AdamW', lr=LEARNING_RATE)
+ model.set_lr_scheduler('CosineAnnealingLR', T_max=total_steps, eta_min=0)
+
+ model.set_loss('GRPOLoss', epsilon=CISPO_EPS_LOW, epsilon_high=CISPO_EPS_HIGH,
+ beta=KL_BETA, entropy_coef=ENTROPY_COEF)
+ model.set_processor(InputProcessor, padding_free=True)
+ model.set_template('Qwen3_5Template', model_id=MODEL_ID, enable_thinking=False, max_length=HOTPOTQA_MAX_LENGTH)
+
+ model.add_metric('GRPOMetric', is_training=True,
+ epsilon=CISPO_EPS_LOW, epsilon_high=CISPO_EPS_HIGH,
+ top_k_kl=HIGH_KL_TOPK)
+
+ sampler = vLLMSampler(
+ model_id=MODEL_ID,
+ engine_args={
+ 'gpu_memory_utilization': 0.8, 'max_model_len': 32768,
+ 'max_lora_rank': 32, 'enable_lora': True,
+ 'enable_tower_connector_lora': True,
+ },
+ device_mesh=sampler_mesh, remote_group='sampler')
+ sampler.set_template('Qwen3_5Template', model_id=MODEL_ID, enable_thinking=False, max_length=HOTPOTQA_MAX_LENGTH)
+ rollout_template = Qwen3_5Template(
+ MODEL_ID, max_length=HOTPOTQA_MAX_LENGTH, enable_thinking=False)
+
+ ckpt_manager = CheckpointEngineManager(model=model, sampler=sampler)
+
+ dataloader = DataLoader(
+ dataset=lambda: _prebuilt_dataset,
+ batch_size=GLOBAL_BATCH_SIZE, min_batch_size=GLOBAL_BATCH_SIZE)
+
+ advantage_fn = GRPOAdvantage()
+ metrics = CompletionRewardMetric()
+ sampling_params = SamplingParams(
+ max_tokens=MAX_NEW_TOKENS, num_samples=1, logprobs=1,
+ temperature=1.0, top_p=0.95)
+
+ def _trace_should_store(traj):
+ return True
+
+ def _trace_is_success(traj):
+ return _F1_REWARD([traj])[0] > 0.0
+
+ rollout = MultiTurnRollout(
+ sampler=sampler,
+ template=rollout_template,
+ tool_manager=ToolManager(),
+ sampling_params=sampling_params,
+ max_turns=MAX_TURNS,
+ trace_dir=_ROLLOUT_TRACE_DIR or None,
+ trace_callback=_trace_should_store,
+ success_callback=_trace_is_success,
+ )
+
+ optim_step = 0
+ logger.info('Starting HotpotQA GRPO baseline (no chunk / no condense / no tools)')
+
+ def _epoch_cycle(dl, n_epochs):
+ for ep in range(1, n_epochs + 1):
+ logger.info(f'=== Epoch {ep}/{n_epochs} (step={optim_step}/{total_steps}) ===')
+ for batch in dl:
+ yield batch
+
+ for batch in _epoch_cycle(dataloader, NUM_EPOCHS):
+ if optim_step >= total_steps:
+ break
+
+ # Single source of truth for the step shown in swanlab / logger / rollout-trace filename.
+ batch_step = optim_step
+
+ metrics.reset()
+ expand_prompts = [p for prompt in batch for p in [prompt] * NUM_GENERATIONS]
+
+ ckpt_manager.sync_weights(merge_and_sync=False)
+ sampler.reset_prefix_cache()
+
+ # Single batched rollout: each trajectory produces exactly one
+ # assistant turn (tools are unregistered, ``max_turns=1``).
+ all_trajectories: List[Dict[str, Any]] = rollout(expand_prompts)
+ n_turns_per_rollout = [int(t.get('turns') or 0) for t in all_trajectories]
+ per_rollout_completion_length = [
+ sum(1 for l in (t.get('labels') or []) if l != -100)
+ for t in all_trajectories]
+
+ total_rewards, f1_rewards, cot_rewards = compute_rewards(all_trajectories)
+
+ rollout_advantages = advantage_fn(
+ total_rewards, num_generations=NUM_GENERATIONS, scale='group').tolist()
+
+ all_f1_labels: List[bool] = [f > 0 for f in f1_rewards]
+ n_pos = sum(1 for p in all_f1_labels if p)
+ n_neg = sum(1 for p in all_f1_labels if not p)
+ pos_with_neg_adv = sum(1 for p, a in zip(all_f1_labels, rollout_advantages) if p and a < 0)
+ neg_with_pos_adv = sum(1 for p, a in zip(all_f1_labels, rollout_advantages) if not p and a > 0)
+
+ all_old_logps: List[List[float]] = [
+ [lp[0][1] for lp in (t.get('logprobs') or [])] for t in all_trajectories]
+
+ # Skip homogeneous groups where gradient signal is meaningless
+ f1_pos_rate = n_pos / len(f1_rewards) if f1_rewards else 0.5
+ if f1_pos_rate > 0.9 or f1_pos_rate < 0.1:
+ logger.info('[skip-homogeneous] f1_pos_rate=%.3f, skipping training update', f1_pos_rate)
+ metrics.accumulate(
+ completion_lengths=per_rollout_completion_length,
+ rewards={'total': total_rewards, 'f1': f1_rewards, 'cot': cot_rewards})
+ log_dict = metrics.calculate()
+ log_dict.update(_compute_rollout_diagnostics(
+ all_trajectories, n_turns_per_rollout, per_rollout_completion_length,
+ f1_rewards=f1_rewards, old_logps=all_old_logps))
+ log_dict['skipped'] = True
+ log_dict['pos_neg_adv_rate'] = pos_with_neg_adv / n_pos if n_pos else 0.0
+ log_dict['neg_pos_adv_rate'] = neg_with_pos_adv / n_neg if n_neg else 0.0
+ log_dict['adv_max'] = max(rollout_advantages) if rollout_advantages else 0.0
+ log_dict['adv_min'] = min(rollout_advantages) if rollout_advantages else 0.0
+ swanlab.log(_coerce_for_swanlab(log_dict), step=batch_step)
+ metrics.reset()
+ logger.info(f'[Step {batch_step}/{total_steps}] [SKIPPED] {log_dict}')
+ optim_step += optim_steps_per_batch
+ continue
+
+ metrics.accumulate(
+ completion_lengths=per_rollout_completion_length,
+ rewards={'total': total_rewards, 'f1': f1_rewards, 'cot': cot_rewards})
+
+ all_input_data: List[Any] = list(all_trajectories)
+ advantages: List[float] = list(rollout_advantages)
+
+ total_completions = len(all_input_data)
+ aligned_completions = (total_completions // MODEL_GPUS) * MODEL_GPUS
+ if aligned_completions < total_completions:
+ logger.info(
+ '[dp-align] dropping %d tail sample(s): total=%d -> aligned=%d (dp=%d)',
+ total_completions - aligned_completions,
+ total_completions, aligned_completions, MODEL_GPUS)
+ for mb_start in range(0, aligned_completions, MINI_BATCH_SIZE):
+ mb_end = min(mb_start + MINI_BATCH_SIZE, aligned_completions)
+ mb_inputs = all_input_data[mb_start:mb_end]
+ # Reference log-probs for KL: same policy with LoRA disabled (= base model).
+ ref_logps = None
+ if KL_BETA > 0.0:
+ ref_outputs = model.forward_only(inputs=mb_inputs, disable_lora=True)
+ ref_logps = ref_outputs.get('logps') if isinstance(ref_outputs, dict) else getattr(ref_outputs, 'logps', None)
+ model.forward_backward(
+ inputs=mb_inputs,
+ old_logps=all_old_logps[mb_start:mb_end],
+ advantages=advantages[mb_start:mb_end],
+ ref_logps=ref_logps,
+ positive_mask=all_f1_labels[mb_start:mb_end],
+ micro_batch_size=MICRO_BATCH_SIZE)
+ model.clip_grad_and_step()
+ optim_step += 1
+ if optim_step >= total_steps:
+ break
+ if optim_step % SAVE_STEPS == 0:
+ model.save(f'hotpotqa-grpo-baseline-checkpoint-{optim_step}')
+
+ log_dict = metrics.calculate()
+ log_dict.update(model.calculate_metric(is_training=True))
+ log_dict.update(_compute_rollout_diagnostics(
+ all_trajectories, n_turns_per_rollout, per_rollout_completion_length,
+ f1_rewards=f1_rewards, old_logps=all_old_logps))
+ log_dict['pos_neg_adv_rate'] = pos_with_neg_adv / n_pos if n_pos else 0.0
+ log_dict['neg_pos_adv_rate'] = neg_with_pos_adv / n_neg if n_neg else 0.0
+ log_dict['adv_max'] = max(rollout_advantages) if rollout_advantages else 0.0
+ log_dict['adv_min'] = min(rollout_advantages) if rollout_advantages else 0.0
+ # Pop high-KL token records before swanlab.log: list-of-dict won't render as a chart.
+ _hk = log_dict.pop('_high_kl_records', None)
+ if _hk:
+ _tok = rollout_template.tokenizer
+ for r in _hk:
+ gsi = r.get('gsi')
+ tid = all_trajectories[gsi].get('id') if gsi is not None and 0 <= gsi < len(all_trajectories) else None
+ try:
+ tok_text = _tok.decode([r['token_id']])
+ except Exception:
+ tok_text = None
+ logger.info(
+ '[high-kl] step=%d gsi=%s tid=%s pos=%s tok=%r kl=%.4f r=%.4f lp_new=%.4f lp_old=%.4f',
+ batch_step, gsi, tid, r.get('pos'), tok_text,
+ r.get('kl'), r.get('ratio'), r.get('logp_new'), r.get('logp_old'))
+ swanlab.log(_coerce_for_swanlab(log_dict), step=batch_step)
+ metrics.reset()
+ logger.info(f'[Step {batch_step}/{total_steps}] {log_dict}')
+
+ logger.info(f'Training completed. optim_steps={optim_step}')
+ model.save('hotpotqa-grpo-baseline-final')
+
+
+if __name__ == '__main__':
+ main()
diff --git a/cookbook/exp/legacy/grpo_condensed.py b/cookbook/exp/legacy/grpo_condensed.py
new file mode 100644
index 000000000..83eb49ac7
--- /dev/null
+++ b/cookbook/exp/legacy/grpo_condensed.py
@@ -0,0 +1,955 @@
+import copy
+import math
+import os
+import re
+from typing import Any, Dict, List, Optional
+
+import torch
+import swanlab
+from peft import LoraConfig
+
+import twinkle
+from twinkle import DeviceMesh, DeviceGroup, get_logger
+from twinkle.advantage import GRPOAdvantage
+from twinkle.checkpoint_engine import CheckpointEngineManager
+from twinkle.data_format import Message, SamplingParams, Trajectory
+from twinkle.dataloader import DataLoader
+from twinkle.dataset import Dataset, DatasetMeta
+from twinkle.metric import CompletionRewardMetric
+from twinkle.model import TransformersModel
+from twinkle.preprocessor.base import Preprocessor
+from twinkle.processor import InputProcessor
+from twinkle.sampler import vLLMSampler
+from twinkle.template import Qwen3_5Template
+from twinkle_agentic.chunker.native import NativeChunker
+from twinkle_agentic.condenser import ModelCondenser
+from twinkle_agentic.reward import F1Reward, CoTReward, ToolExploreReward
+from twinkle_agentic.rollout.multi_turn_condense import MultiTurnCondenseRollout
+from twinkle_agentic.tools.tool_manager import ToolManager
+
+logger = get_logger()
+
+MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen3.5-4B')
+USE_MEGATRON = bool(int(os.environ.get('USE_MEGATRON', '0')))
+
+MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 4))
+SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS', 4))
+NUM_GPUS = MODEL_GPUS + SAMPLER_GPUS
+
+NUM_GENERATIONS = int(os.environ.get('NUM_GENERATIONS', 8))
+MAX_NEW_TOKENS = int(os.environ.get('MAX_NEW_TOKENS', 4096))
+LEARNING_RATE = float(os.environ.get('LR', 1e-5))
+NUM_EPOCHS = int(os.environ.get('NUM_EPOCHS', 1))
+MAX_STEPS = int(os.environ.get('MAX_STEPS', 0))
+BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 8))
+MINI_BATCH_SIZE = int(os.environ.get('MINI_BATCH_SIZE', 8))
+MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 2))
+GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 1))
+ADAPTER_NAME = 'default'
+SAVE_STEPS = int(os.environ.get('SAVE_STEPS', 1000))
+LORA_RANK = int(os.environ.get('LORA_RANK', 16))
+
+MAX_TURNS = int(os.environ.get('MAX_TURNS', 4))
+MAX_TRAJECTORY_TOKENS = int(os.environ.get('MAX_TRAJECTORY_TOKENS', 8192))
+CHUNK_SIZE = int(os.environ.get('CHUNK_SIZE', 1024))
+
+HOTPOTQA_NUM_PROC = int(os.environ.get('HOTPOTQA_NUM_PROC', 16))
+HOTPOTQA_MAX_LENGTH = int(os.environ.get('HOTPOTQA_MAX_LENGTH', 64000))
+
+F1_REWARD_WEIGHT = float(os.environ.get('F1_REWARD_WEIGHT', 1.0))
+COT_REWARD_WEIGHT = float(os.environ.get('COT_REWARD_WEIGHT', 0))
+TOOL_BONUS_WEIGHT = float(os.environ.get('TOOL_BONUS_WEIGHT', 0.0))
+TOOL_BONUS_F1_THRESHOLD = float(
+ os.environ.get('TOOL_BONUS_F1_THRESHOLD', 0.5))
+
+# KL penalty coefficient; 0 disables KL (and skips the ref forward pass entirely).
+# CISPO is token-level and DOES support per-token KL — small positive value (e.g. 0.005) recommended as anchor.
+KL_BETA = float(os.environ.get('KL_BETA', 0.01))
+
+# Entropy bonus coefficient; 0 disables the entropy compute path entirely.
+# Typical GRPO values: 0.001–0.01. Loss is: L = L_PPO + beta*KL - entropy_coef*H.
+ENTROPY_COEF = float(os.environ.get('ENTROPY_COEF', 0.0))
+
+# Per-token oracle bonus coefficient; 0 disables. Typical: 0.05–0.2.
+# Loss becomes: L = L_PPO + beta*KL - entropy_coef*H - token_bonus_coef*(oracle_logps - rollout_logps)
+ORACLE_BONUS_COEF = float(os.environ.get('ORACLE_BONUS_COEF', 0.0))
+
+# CISPO token-level IS clamp thresholds (MiniMax CISPO defaults: 0.2 / 0.28 asymmetric).
+CISPO_EPS_LOW = float(os.environ.get('CISPO_EPS_LOW', 0.2))
+CISPO_EPS_HIGH = float(os.environ.get('CISPO_EPS_HIGH', 0.2))
+
+# High-KL token capture: top-K per microbatch dumped into log_dict['_high_kl_records']. 0 = disabled.
+HIGH_KL_TOPK = int(os.environ.get('HIGH_KL_TOPK', 0))
+
+INIT_LORA_PATH = os.environ.get('INIT_LORA_PATH', 'output/condensed_sft_ddp/last-checkpoint')
+DATASET_PATH = os.environ.get(
+ 'DATASET_PATH',
+ os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))),
+ 'hotpotqa_fullwiki_reannotated_12k.jsonl'))
+F1_BINARY_THRESHOLD = float(os.environ.get('F1_BINARY_THRESHOLD', 0.5))
+
+_ROLLOUT_TRACE_DIR = os.environ.get('ROLLOUT_TRACE_DIR', 'rollout_trace')
+ORACLE_HINT = bool(int(os.environ.get('ORACLE_HINT', '0')))
+
+
+# [EXP-ORACLE] staged hint injection — appended to the Question line so skip_pattern keeps it uncompressed.
+def _oracle_hint_stage(step: int, total_steps: int) -> int:
+ """0 = explicit titles, 1 = vague count, 2 = no hint."""
+ return 0
+ # if total_steps <= 0:
+ # return 0
+ # third = max(1, total_steps // 3)
+ # if step < third:
+ # return 0
+ # if step < 2 * third:
+ # return 1
+ # return 2
+
+
+
+def _make_oracle_hint_callback(total_steps: int):
+ """Return a post_compress_callback that injects oracle hints with actual block IDs.
+
+ Called by MultiTurnCondenseRollout after compression + metadata merge, so
+ ``compressed['user_data']`` carries sf_titles and ``chunks`` carries the
+ condensed/raw status of each passage.
+
+ Stages (determined by global_step / total_steps):
+ 0 — explicit block IDs for supporting-fact passages
+ 1 — block count only (no IDs)
+ 2 — no hint
+ """
+ _q_split = re.compile(r'(Question:\s*.+?)(\n\nContext:)', re.DOTALL)
+
+ def _callback(compressed, chunks, **kwargs):
+ step = kwargs.get('global_step', 0)
+ stage = _oracle_hint_stage(step, total_steps)
+ if stage == 2:
+ return compressed
+
+ user_data = compressed.get('user_data') or []
+ sf_titles = [v for k, v in user_data if k == 'sf_title' and v]
+ if not sf_titles:
+ return compressed
+ sf_set = set(sf_titles)
+
+ # Map sf_titles → block IDs by walking condensed chunks
+ block_id = 0
+ sf_block_ids = []
+ for c in chunks.chunks:
+ if c.get('type') != 'text':
+ continue
+ content = c.get('content')
+ if not isinstance(content, str) or not content:
+ continue
+ if c.get('role') == 'tool':
+ continue
+ raw = c.get('raw')
+ if not (isinstance(raw, dict) and raw.get('condensed')):
+ continue
+ block_id += 1
+ original = raw.get('original', '')
+ if isinstance(original, str):
+ for title in sf_set:
+ if original.startswith(f'{title}: ') or original.startswith(f'{title}:'):
+ sf_block_ids.append(block_id)
+ break
+
+ if stage == 0:
+ if sf_block_ids:
+ ids_str = ', '.join(str(b) for b in sf_block_ids)
+ hint = (f'\n[Oracle Hint] Block {ids_str} contain(s) the supporting facts. '
+ 'Call `extract_condensed` to expand them if you need more detail information.')
+ else:
+ n = len(sf_set)
+ word = {1: 'One', 2: 'Two', 3: 'Three'}.get(n, str(n))
+ hint = (f'\n[Oracle Hint] {word} short passage(s) contain the supporting facts; '
+ 'they are uncompressed — read them directly.')
+ else:
+ hint = (f'\n[Oracle Hint] Some compressed block(s) contain the supporting facts; '
+ 'call `extract_condensed` to expand them if you need more detail information.')
+
+ for m in (compressed.get('messages') or []):
+ if m.get('role') != 'user':
+ continue
+ c = m.get('content')
+ if isinstance(c, str):
+ m['content'] = _q_split.sub(
+ lambda g: g.group(1) + hint + g.group(2), c, count=1)
+ elif isinstance(c, list):
+ for part in c:
+ if isinstance(part, dict) and part.get('type') == 'text':
+ part['text'] = _q_split.sub(
+ lambda g: g.group(1) + hint + g.group(2),
+ part.get('text') or '', count=1)
+ break
+ break
+ return compressed
+
+ return _callback
+
+SYSTEM_PROMPT = """You are a careful multi-hop QA assistant.
+
+## Context Format (Mixed)
+The context you receive is a **mix of two forms**:
+
+1. **Compressed blocks** — long passages wrapped in `...`, \
+ displayed as a Markdown digest in **telegraphic style** (no \
+ articles / "is" / "are"; colons and commas mean "is" / "has") \
+ with two sections:
+ - **Summary**: overview plus facts strongly related to the question, stated explicitly.
+ - **More**: a collapsed INDEX of category keywords hinting at extra details hidden in the full text (call `extract_condensed` to see them).
+ Reading example: `India: 7th largest by area. Borders: Pakistan, \
+ China.` means "India is the 7th largest country by area and \
+ shares borders with Pakistan and China."
+2. **Raw passages** — short passages shown inline as plain text (`Title: \
+ body`) **without** any `` wrapping. These are already the full \
+ text; nothing is hidden.
+
+Only the ``-wrapped blocks are compressed and can be expanded. \
+Block ids `N` are 1-based and assigned in the order compressed blocks \
+appear in the context, so they are always contiguous (``, \
+``, ``, ...). Raw passages have no block id and cannot \
+be extracted — they are already complete.
+
+## Workflow
+
+### Phase 1 — Scan and Decide
+Step 1: Read each compressed block's Summary, and read raw \
+passages directly, to get an overview.
+Step 2: For compressed blocks, check the More keywords to judge whether \
+hidden details are needed.
+Step 3: Decide which compressed blocks to expand, then call \
+`extract_condensed` with their block ids. Raw passages need no extraction.
+
+### Phase 2 — Reason and Answer
+After the tool returns the full text, continue stepping through the evidence:
+Step N: From block X (or the raw passage titled "..."), I learn that [fact A].
+Step N+1: From block Y, I need to call `extract_condensed` to get more information, because this block is related to...
+Step N+2: Combining these, the answer is ...
+\\boxed{answer}
+
+You may call `extract_condensed` several times to expand more blocks if the information is not enough, only answer the question if you are sure about the facts.
+The `blocks` parameter accepts **exactly one integer** per call (e.g. `3`); lists are rejected. Expand additional blocks by issuing separate `extract_condensed` calls, one per block. Only pass ids that actually appear as `` in the context, and do **not** request the same block twice — its text is already in the conversation after the first expansion.
+
+## Tool Call Format
+
+
+
+3
+
+
+
+
+## Output Format
+End your final response with \\boxed{answer}, e.g. \\boxed{Delhi}.
+Keep the boxed text short: a name, entity, date, or "yes"/"no".
+Answers not inside \\boxed{} will not be scored."""
+
+
+_F1_REWARD: Optional[F1Reward] = F1Reward()
+_COT_REWARD: Optional[CoTReward] = CoTReward()
+_TOOL_EXPLORE_REWARD: Optional[ToolExploreReward] = ToolExploreReward(
+ f1_threshold=TOOL_BONUS_F1_THRESHOLD)
+
+
+def compute_rewards(trajectories: List[Dict[str, Any]]):
+ f1_raw = _F1_REWARD(trajectories)
+ f1 = [1.0 if v >= F1_BINARY_THRESHOLD else 0.0 for v in f1_raw] if F1_BINARY_THRESHOLD > 0 else f1_raw
+ cot = _COT_REWARD(trajectories)
+ tool_explore = _TOOL_EXPLORE_REWARD(trajectories)
+ total = [
+ F1_REWARD_WEIGHT * a + COT_REWARD_WEIGHT * c + TOOL_BONUS_WEIGHT * te
+ for a, c, te in zip(f1, cot, tool_explore)
+ ]
+ return total, f1, cot, tool_explore
+
+
+class HotpotQAProcessor(Preprocessor):
+ def __init__(self, system: str = SYSTEM_PROMPT):
+ self.system = system
+
+ def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
+ rows = self.map_col_to_row(rows)
+ rows = [self.preprocess(row) for row in rows]
+ rows = [r for r in rows if r is not None]
+ rows = self.map_row_to_col(rows)
+ return rows
+
+ @staticmethod
+ def _format_context(context: Dict[str, Any]) -> str:
+ titles = context.get('title', []) or []
+ sentences = context.get('sentences', []) or []
+ lines = []
+ for title, sents in zip(titles, sentences):
+ if isinstance(sents, list):
+ body = ' '.join(s.strip() for s in sents if s and s.strip())
+ else:
+ body = str(sents).strip()
+ lines.append(f'{title}: {body}')
+ return '\n\n'.join(lines)
+
+ def preprocess(self, row: Dict[str, Any]) -> Optional[Trajectory]:
+ if (row.get('verdict') or '').strip().lower() == 'drop':
+ return None
+ question = row.get('question_fixed') or row['question']
+ answers = row.get('answers')
+ if isinstance(answers, list) and answers:
+ gold = [str(a).strip() for a in answers if str(a).strip()]
+ else:
+ gold = [s for s in [(row.get('answer', '') or '').strip()] if s]
+ context_block = self._format_context(row.get('context', {}) or {})
+ user_msg = f'Question: {question}\n\nContext:\n\n{context_block}'
+ messages = [
+ Message(role='system', content=self.system),
+ Message(role='user', content=user_msg),
+ ]
+ # [EXP-ORACLE] carry supporting_facts titles via user_data; rollout injects post-compression block hint
+ sf = row.get('supporting_facts') or {}
+ sf_titles = sf.get('title') or []
+ sf_unique = list(dict.fromkeys(t for t in sf_titles if t))
+ user_data = [('ground_truth', g) for g in gold] + [('sf_title', t) for t in sf_unique]
+ return Trajectory(messages=messages, user_data=user_data)
+
+
+def create_hotpotqa_dataset() -> Dataset:
+ dataset = Dataset()
+ dataset.add_dataset(DatasetMeta(DATASET_PATH))
+ logger.info('[dataset] loaded %s: %d rows', DATASET_PATH, len(dataset))
+
+ dataset.set_template(
+ 'Qwen3_5Template', model_id=MODEL_ID, max_length=HOTPOTQA_MAX_LENGTH,
+ truncation_strategy='delete', enable_thinking=False)
+ _HOTPOTQA_COLS = ['id', 'question', 'question_fixed', 'answers',
+ 'original_answer', 'type', 'level', 'verdict',
+ 'reasoning', 'supporting_facts', 'context']
+ dataset.map(HotpotQAProcessor(system=SYSTEM_PROMPT), remove_columns=_HOTPOTQA_COLS)
+ return dataset
+
+
+# Matches a LaTeX ``\boxed{...}`` final-answer marker — used to flag
+# rollouts that never committed an answer. Brace-balanced is overkill for
+# a logging heuristic; a non-greedy ``[^}]*`` is good enough.
+_BOXED_RE = re.compile(r'\\boxed\{[^}]*\}')
+
+# Pulls the leading number out of pre-formatted metric strings such as
+# ``'0.03 iters/s'`` / ``'1.000000e-05'`` / ``'30 seconds'`` emitted by
+# ``TrainMetric`` and ``GRPOMetric``. We use this in ``_coerce_for_swanlab``
+# so swanlab can build line charts instead of dropping those keys with a
+# ``failed to create chart for key '...': invalid value type`` warning.
+_LEADING_NUMBER_RE = re.compile(r'[-+]?\d*\.?\d+(?:[eE][-+]?\d+)?')
+
+
+def _coerce_for_swanlab(log_dict: Dict[str, Any]) -> Dict[str, Any]:
+ """Cast string-valued metrics to float for swanlab line charts.
+
+ ``TrainMetric.calculate()`` and ``GRPOMetric.calculate()`` return
+ pre-formatted strings (``'0.03 iters/s'``, ``'1.000000e-05'``,
+ ``'30 seconds'``, ``'0.8321'``). swanlab cannot build a line chart
+ from a string value and emits one warning per key per step. We extract
+ the leading number where possible; keys whose value can't be parsed
+ as a scalar are left as-is so they still show up in the text log.
+ """
+ coerced: Dict[str, Any] = {}
+ for k, v in log_dict.items():
+ if isinstance(v, bool) or isinstance(v, (int, float)):
+ coerced[k] = v
+ continue
+ if isinstance(v, str):
+ m = _LEADING_NUMBER_RE.search(v)
+ if m:
+ try:
+ coerced[k] = float(m.group())
+ continue
+ except ValueError:
+ pass
+ coerced[k] = v
+ return coerced
+
+
+def _last_assistant_text(trajectory: Dict[str, Any]) -> Optional[str]:
+ """Return the text of the last ``assistant`` message, or ``None``.
+
+ ``content`` can be ``str`` | ``None`` | ``dict`` (single multimodal
+ part) | ``list[dict]`` (multiple parts). The downstream caller feeds
+ this into ``_BOXED_RE.search(...)``, so we collapse the visible text
+ into a single string and ignore non-text parts (images etc.).
+ """
+ for m in reversed(trajectory.get('messages', [])):
+ if m.get('role') != 'assistant':
+ continue
+ c = m.get('content')
+ if c is None:
+ return None
+ if isinstance(c, str):
+ return c
+ if isinstance(c, dict):
+ return c.get('text') if c.get('type') == 'text' else None
+ if isinstance(c, list):
+ parts = [p.get('text') or '' for p in c
+ if isinstance(p, dict) and p.get('type') == 'text']
+ return '\n'.join(parts) if parts else None
+ return str(c)
+ return None
+
+
+def _compute_rollout_diagnostics(
+ trajectories: List[Dict[str, Any]],
+ n_turns_per_rollout: List[int],
+ per_rollout_completion_length: List[int],
+ f1_rewards: Optional[List[float]] = None,
+ old_logps: Optional[List[List[float]]] = None,
+) -> Dict[str, float]:
+ """Aggregate rollout diagnostics for swanlab logging.
+
+ All inputs are already flat:
+ * ``trajectories[i]`` is the merged trajectory dict returned by
+ :class:`MultiTurnCondenseRollout` (contains ``messages``,
+ ``input_ids``, ``labels``, ``turns`` at top level).
+ * ``n_turns_per_rollout[i] == trajectories[i]['turns']``.
+ * ``per_rollout_completion_length[i]`` == number of trainable
+ tokens in the trajectory (labels != -100).
+ """
+ out: Dict[str, float] = {}
+ if n_turns_per_rollout:
+ out['avg_turns'] = sum(n_turns_per_rollout) / len(n_turns_per_rollout)
+
+ # ``non_trainable_tokens`` is the longest non-trainable prefix across
+ # the batch: ``len(input_ids) - sum(1 for l in labels if l != -100)``.
+ # Tracks how much the condensed context + system prompt is eating the
+ # context budget (it does NOT equal the first-turn prompt length
+ # because multi-turn runs also contribute non-trainable tokens from
+ # the ``tool`` observations between assistant turns).
+ _max_non_trainable = 0
+ for t, comp_len in zip(trajectories, per_rollout_completion_length):
+ ids = t.get('input_ids') or []
+ non_trainable = max(0, len(ids) - int(comp_len or 0))
+ if non_trainable > _max_non_trainable:
+ _max_non_trainable = non_trainable
+ out['non_trainable_tokens'] = _max_non_trainable
+
+ if trajectories:
+ tool_counts = [
+ sum(len(m.get('tool_calls') or [])
+ for m in t.get('messages', []) if m.get('role') == 'assistant')
+ for t in trajectories]
+ out['avg_tool_calls'] = sum(tool_counts) / len(tool_counts)
+ out['tool_use_rate'] = sum(1 for c in tool_counts if c > 0) / len(tool_counts)
+ n_no_boxed = sum(
+ 0 if _BOXED_RE.search(_last_assistant_text(t) or '') else 1
+ for t in trajectories)
+ out['no_boxed_rate'] = n_no_boxed / len(trajectories)
+ def _content_chars(c: Any) -> int:
+ if not c:
+ return 0
+ if isinstance(c, str):
+ return len(c)
+ if isinstance(c, dict):
+ if c.get('type') == 'text':
+ return len(c.get('text') or '')
+ return 0
+ if isinstance(c, list):
+ total = 0
+ for part in c:
+ if isinstance(part, dict) and part.get('type') == 'text':
+ total += len(part.get('text') or '')
+ elif isinstance(part, str):
+ total += len(part)
+ return total
+ # Unknown shape -- fall back to ``str()`` length rather than
+ # crashing, so a template quirk never breaks metric logging.
+ return len(str(c))
+
+ msg_chars_total, prompt_chars, asst_chars = [], [], []
+ for t in trajectories:
+ total_i = prompt_i = asst_i = 0
+ for m in (t.get('messages') or []):
+ role = m.get('role')
+ if role == 'system':
+ continue
+ n = _content_chars(m.get('content'))
+ total_i += n
+ if role in ('user', 'tool'):
+ prompt_i += n
+ elif role == 'assistant':
+ asst_i += n
+ msg_chars_total.append(total_i)
+ prompt_chars.append(prompt_i)
+ asst_chars.append(asst_i)
+ out['avg_chars_total_no_sys'] = sum(msg_chars_total) / len(msg_chars_total)
+ out['avg_chars_prompt_no_sys'] = sum(prompt_chars) / len(prompt_chars)
+ out['avg_chars_assistant'] = sum(asst_chars) / len(asst_chars)
+
+ if f1_rewards is not None and old_logps is not None and f1_rewards:
+ per_traj_mean = [
+ (sum(lp) / len(lp)) if lp else 0.0 for lp in old_logps]
+ pos_logp = [m for m, f1 in zip(per_traj_mean, f1_rewards) if f1 > 0]
+ zero_logp = [m for m, f1 in zip(per_traj_mean, f1_rewards) if f1 <= 0]
+ out['f1_correct_rate'] = len(pos_logp) / len(f1_rewards)
+ out['f1_zero_rate'] = len(zero_logp) / len(f1_rewards)
+ out['mean_old_logp_f1_pos'] = (sum(pos_logp) / len(pos_logp)) if pos_logp else 0.0
+ out['mean_old_logp_f1_zero'] = (sum(zero_logp) / len(zero_logp)) if zero_logp else 0.0
+ out['policy_confidence_f1_pos'] = math.exp(out['mean_old_logp_f1_pos'])
+ out['policy_confidence_f1_zero'] = math.exp(out['mean_old_logp_f1_zero'])
+ return out
+
+
+def _build_oracle_inputs(
+ mb_inputs: List[Dict[str, Any]],
+ f1_labels: List[bool],
+ template,
+) -> Optional[List[Dict[str, Any]]]:
+ """Build oracle-context inputs at the TOKEN level for per-token bonus computation.
+
+ The approach:
+ 1. Find ``first_trainable`` from labels (first position != -100).
+ Due to NTP shift, input_ids[first_trainable] is the last prefix token (e.g. \\n
+ after ``assistant``) and labels[first_trainable] is the first response token target.
+ 2. Construct oracle messages: [system, user_with_oracle_suffix].
+ 3. Encode with template (add_generation_prompt=True) → oracle_prefix_ids ending with
+ the same assistant header token.
+ 4. Concatenate: oracle_prefix_ids + input_ids[first_trainable+1:] (response tokens).
+ 5. Labels: [-100]*(len(oracle_prefix)-1) + labels[first_trainable:] so the last prefix
+ position predicts the first response token.
+
+ For F1=0 samples: copied unchanged (bonus zeroed by _compute_token_bonus).
+ """
+ _q_line_re = re.compile(r'Question:\s*(.+?)(?:\n|$)', re.DOTALL)
+ oracle_inputs = []
+ any_modified = False
+
+ for inp, is_pos in zip(mb_inputs, f1_labels):
+ if not is_pos:
+ oracle_inputs.append(inp)
+ continue
+
+ user_data = inp.get('user_data') or []
+ sf_titles = [v for k, v in user_data if k == 'sf_title' and v]
+ gts = [v for k, v in user_data if k == 'ground_truth' and v]
+ if not sf_titles and not gts:
+ oracle_inputs.append(inp)
+ continue
+
+ labels = inp.get('labels') or []
+ input_ids = inp.get('input_ids') or []
+ if not labels or not input_ids:
+ oracle_inputs.append(inp)
+ continue
+
+ # 1. Find first trainable position
+ first_trainable = None
+ for i, l in enumerate(labels):
+ if l != -100:
+ first_trainable = i
+ break
+
+ assert first_trainable is not None
+
+ # 2. Extract question from first user message
+ question = None
+ msgs = inp.get('messages') or []
+ for m in msgs:
+ if m.get('role') != 'user':
+ continue
+ c = m.get('content')
+ text = c if isinstance(c, str) else (
+ next((p.get('text') for p in c if isinstance(p, dict) and p.get('type') == 'text'), '')
+ if isinstance(c, list) else '')
+ q_match = _q_line_re.match(text or '')
+ if q_match:
+ question = q_match.group(1).strip()
+ break
+
+ if not question:
+ oracle_inputs.append(inp)
+ continue
+
+ # 3. Build oracle user message (concise: question + oracle hints only)
+ hint_parts = []
+ if sf_titles:
+ hint_parts.append('Supporting passages: ' + ', '.join(f'"{t}"' for t in sf_titles))
+ if gts:
+ hint_parts.append('Answer: ' + '; '.join(gts))
+ hint_parts.append('You must call `extract_condensed` to read the right original passage from the condensed block with thinking steps, and give the final correct answer')
+ oracle_suffix = '\n[Oracle Context] ' + '. '.join(hint_parts) + '.'
+ oracle_user_content = f'Question: {question}{oracle_suffix}'
+
+ oracle_msgs = [
+ Message(role='system', content=SYSTEM_PROMPT),
+ Message(role='user', content=oracle_user_content),
+ ]
+
+ # 4. Encode oracle prefix (ends with <|im_start|>assistant\n)
+ oracle_feature = template.encode(
+ Trajectory(messages=oracle_msgs), add_generation_prompt=True)
+ oracle_prefix_ids = list(oracle_feature['input_ids'])
+
+ # 5. Splice: oracle_prefix + response_tokens
+ response_tokens = list(input_ids[first_trainable + 1:])
+ response_labels = list(labels[first_trainable:])
+
+ oracle_input_ids = oracle_prefix_ids + response_tokens
+ # Last position of oracle prefix predicts first response token
+ oracle_labels = [-100] * (len(oracle_prefix_ids) - 1) + response_labels
+
+ assert len(oracle_input_ids) == len(oracle_labels)
+ seq_len = len(oracle_input_ids)
+ # Start from original keys to keep collator-compatible shape
+ oi = dict(inp)
+ oi['input_ids'] = oracle_input_ids
+ oi['labels'] = oracle_labels
+ oi['attention_mask'] = [1] * seq_len
+ oi['messages'] = None
+ oi['length'] = seq_len
+ # Replicate mrope position_ids shape from original input
+ orig_pos = inp.get('position_ids')
+ if isinstance(orig_pos, torch.Tensor) and orig_pos.dim() == 3:
+ n_dims = orig_pos.shape[0]
+ pos_range = torch.arange(seq_len).unsqueeze(0).unsqueeze(0)
+ oi['position_ids'] = pos_range.expand(n_dims, 1, seq_len)
+ else:
+ oi['position_ids'] = list(range(seq_len))
+ if 'mm_token_type_ids' in inp:
+ oi['mm_token_type_ids'] = torch.zeros(1, seq_len)
+ oracle_inputs.append(oi)
+ any_modified = True
+
+ return oracle_inputs if any_modified else None
+
+
+def _compute_token_bonus(
+ oracle_logps: Any,
+ old_logps: List[List[float]],
+ f1_labels: List[bool],
+ oracle_inputs: List[Dict[str, Any]],
+) -> List[List[float]]:
+ """Compute per-token bonus = oracle_logps - rollout_logps, zeroed for F1=0 samples.
+
+ oracle_logps is full-sequence form [batch, padded_seq] from forward_only + collector.
+ We extract valid positions using oracle_inputs[i]['labels'] mask to get response-only
+ logps aligned 1:1 with old_logps.
+ """
+ import torch
+
+ if isinstance(oracle_logps, torch.Tensor):
+ oracle_logps = oracle_logps.float().cpu()
+
+ bonus = []
+ for i, (is_pos, old_lp) in enumerate(zip(f1_labels, old_logps)):
+ if not is_pos or not old_lp:
+ bonus.append([0.0] * len(old_lp) if old_lp else [])
+ continue
+
+ n = len(old_lp)
+ oracle_labels = oracle_inputs[i].get('labels') or []
+
+ # Build mask from oracle labels to extract valid (trainable) positions
+ if isinstance(oracle_logps, torch.Tensor):
+ orc_row = oracle_logps[i]
+ mask = torch.tensor([l != -100 for l in oracle_labels], dtype=torch.bool)
+ seq_len = min(len(mask), orc_row.numel())
+ orc_valid = orc_row[:seq_len][mask[:seq_len]].tolist()
+ else:
+ orc_row = oracle_logps[i] if i < len(oracle_logps) else []
+ if isinstance(orc_row, torch.Tensor):
+ orc_row = orc_row.float().cpu().tolist()
+ elif not isinstance(orc_row, (list, tuple)):
+ orc_row = []
+ orc_valid = [v for v, l in zip(orc_row, oracle_labels) if l != -100]
+
+ assert len(orc_valid) == n
+ bonus.append([o - r for o, r in zip(orc_valid, old_lp)])
+ return bonus
+
+
+def main():
+ swanlab.init(project='twinkle')
+
+ device_groups = [
+ DeviceGroup(name='model', ranks=list(range(MODEL_GPUS)), device_type='GPU'),
+ DeviceGroup(name='sampler', ranks=list(range(MODEL_GPUS, NUM_GPUS)), device_type='GPU'),
+ ]
+ model_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=MODEL_GPUS)
+ sampler_mesh = DeviceMesh.from_sizes(world_size=SAMPLER_GPUS, dp_size=SAMPLER_GPUS)
+ twinkle.initialize(mode='ray', nproc_per_node=NUM_GPUS,
+ groups=device_groups, lazy_collect=False)
+
+ logger.info('Building HotpotQA dataset')
+ _prebuilt_dataset = create_hotpotqa_dataset()
+ logger.info('Dataset ready: %d rows', len(_prebuilt_dataset))
+
+ GLOBAL_BATCH_SIZE = BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS
+ batches_per_epoch = max(1, len(_prebuilt_dataset) // GLOBAL_BATCH_SIZE)
+ optim_steps_per_batch = max(1, (GLOBAL_BATCH_SIZE * NUM_GENERATIONS
+ + MINI_BATCH_SIZE - 1) // MINI_BATCH_SIZE)
+ steps_per_epoch = batches_per_epoch * optim_steps_per_batch
+ derived_total_steps = NUM_EPOCHS * steps_per_epoch
+ total_steps = min(MAX_STEPS, derived_total_steps) if MAX_STEPS > 0 else derived_total_steps
+ logger.info('Training horizon: %d steps (%d epochs × %d batches × %d steps/batch)',
+ total_steps, NUM_EPOCHS, batches_per_epoch, optim_steps_per_batch)
+
+ lora_config = LoraConfig(
+ target_modules='all-linear', r=LORA_RANK,
+ lora_alpha=LORA_RANK * 2, lora_dropout=0.05)
+
+ if USE_MEGATRON:
+ from twinkle.model.megatron import MegatronModel
+ model = MegatronModel(
+ model_id=MODEL_ID, device_mesh=model_mesh, remote_group='model',
+ mixed_precision='bf16', variable_seq_lengths=True)
+ else:
+ model = TransformersModel(
+ model_id=MODEL_ID, device_mesh=model_mesh, remote_group='model')
+
+ model.add_adapter_to_model(ADAPTER_NAME, lora_config,
+ gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS)
+ if INIT_LORA_PATH:
+ model.load(INIT_LORA_PATH, adapter_name=ADAPTER_NAME)
+ logger.info('Loaded cold-start LoRA from %s', INIT_LORA_PATH)
+ if USE_MEGATRON:
+ model.set_optimizer('default', lr=LEARNING_RATE)
+ model.set_lr_scheduler('default', lr_decay_steps=total_steps, max_lr=LEARNING_RATE)
+ else:
+ model.set_optimizer('AdamW', lr=LEARNING_RATE)
+ model.set_lr_scheduler('CosineAnnealingLR', T_max=total_steps, eta_min=0)
+
+ model.set_loss('GRPOLoss', epsilon=CISPO_EPS_LOW, epsilon_high=CISPO_EPS_HIGH,
+ beta=KL_BETA, entropy_coef=ENTROPY_COEF, token_bonus_coef=ORACLE_BONUS_COEF)
+ model.set_processor(InputProcessor, padding_free=True)
+ model.set_template('Qwen3_5Template', model_id=MODEL_ID, enable_thinking=False, max_length=HOTPOTQA_MAX_LENGTH)
+
+ model.add_metric('GRPOMetric', is_training=True,
+ epsilon=CISPO_EPS_LOW, epsilon_high=CISPO_EPS_HIGH,
+ top_k_kl=HIGH_KL_TOPK)
+
+ sampler = vLLMSampler(
+ model_id=MODEL_ID,
+ engine_args={
+ 'gpu_memory_utilization': 0.8, 'max_model_len': 32768,
+ 'max_lora_rank': 32, 'enable_lora': True,
+ 'enable_tower_connector_lora': True,
+ 'max_loras': 5
+ },
+ device_mesh=sampler_mesh, remote_group='sampler')
+ sampler.set_template('Qwen3_5Template', model_id=MODEL_ID, enable_thinking=False, max_length=HOTPOTQA_MAX_LENGTH)
+ rollout_template = Qwen3_5Template(
+ MODEL_ID, max_length=HOTPOTQA_MAX_LENGTH, enable_thinking=False)
+
+ ckpt_manager = CheckpointEngineManager(model=model, sampler=sampler)
+ chunker = NativeChunker(
+ chunk_size=CHUNK_SIZE,
+ passage_boundary_re=r'(?<=\n\n)',
+ )
+ # ``\A`` anchor: prevents a ``Question:`` line inside a passage from being misread as the query.
+ _question_re = re.compile(r'\AQuestion:\s*(.+)')
+
+ def _extract_question(chunk):
+ content = chunk.get('content')
+ if chunk.get('type') != 'text' or not isinstance(content, str):
+ return None
+ m = _question_re.search(content)
+ return m.group(1).strip() if m else None
+
+ condenser = ModelCondenser(
+ sampler=sampler,
+ compression_ratio=2.0,
+ sampling_params=SamplingParams(
+ max_tokens=1024, num_samples=1, temperature=0.4, top_p=0.9),
+ min_chars=200,
+ template=rollout_template,
+ lora_path='ms://twinkle-kit/Qwen3.5-4B-Condenser',
+ skip_pattern=r'^Question:',
+ related_query=_extract_question,
+ )
+
+ dataloader = DataLoader(
+ dataset=lambda: _prebuilt_dataset,
+ batch_size=GLOBAL_BATCH_SIZE, min_batch_size=GLOBAL_BATCH_SIZE)
+
+ advantage_fn = GRPOAdvantage()
+ metrics = CompletionRewardMetric()
+ sampling_params = SamplingParams(
+ max_tokens=MAX_NEW_TOKENS, num_samples=1, logprobs=1,
+ temperature=1.0, top_p=0.95,
+ stop=[''])
+
+ def _trace_should_store(traj):
+ return _F1_REWARD([traj])[0] == 0.0
+
+ def _trace_is_success(traj):
+ return _F1_REWARD([traj])[0] > 0.0
+
+ rollout = MultiTurnCondenseRollout(
+ sampler=sampler,
+ template=rollout_template,
+ tool_manager=ToolManager(),
+ chunker=chunker,
+ condenser=condenser,
+ sampling_params=sampling_params,
+ max_turns=MAX_TURNS,
+ max_trajectory_tokens=MAX_TRAJECTORY_TOKENS,
+ trace_dir=_ROLLOUT_TRACE_DIR or None,
+ trace_callback=_trace_should_store,
+ success_callback=_trace_is_success,
+ post_compress_callback=(
+ _make_oracle_hint_callback(total_steps) if ORACLE_HINT else None),
+ )
+
+ optim_step = 0
+ logger.info('Starting HotpotQA GRPO training (LLM condenser variant)')
+
+ def _epoch_cycle(dl, n_epochs):
+ for ep in range(1, n_epochs + 1):
+ logger.info(f'=== Epoch {ep}/{n_epochs} (step={optim_step}/{total_steps}) ===')
+ for batch in dl:
+ yield batch
+
+ for batch in _epoch_cycle(dataloader, NUM_EPOCHS):
+ if optim_step >= total_steps:
+ break
+
+ # Single source of truth for the step shown in swanlab / logger / rollout-trace filename.
+ # Equals the number of optimizer updates already completed when this rollout was sampled.
+ batch_step = optim_step
+
+ metrics.reset()
+ expand_prompts = [p for prompt in batch for p in [prompt] * NUM_GENERATIONS]
+
+ ckpt_manager.sync_weights(merge_and_sync=False)
+ sampler.reset_prefix_cache()
+
+ # Batched multi-turn rollout with chunk+condense pre-processing.
+ # Each returned trajectory is a flat dict containing ``messages``,
+ # ``input_ids``, ``labels``, ``attention_mask``, ``position_ids``,
+ # ``turns``, ``logprobs``, ``stop_reason``, ``truncated``.
+ all_trajectories: List[Dict[str, Any]] = rollout(expand_prompts, global_step=batch_step)
+ n_turns_per_rollout = [int(t.get('turns') or 0) for t in all_trajectories]
+ per_rollout_completion_length = [
+ sum(1 for l in (t.get('labels') or []) if l != -100)
+ for t in all_trajectories]
+
+ total_rewards, f1_rewards, cot_rewards, tool_explore_rewards = \
+ compute_rewards(all_trajectories)
+
+ rollout_advantages = advantage_fn(
+ total_rewards, num_generations=NUM_GENERATIONS, scale='group').tolist()
+
+ all_f1_labels: List[bool] = [f > 0 for f in f1_rewards]
+ n_pos = sum(1 for p in all_f1_labels if p)
+ n_neg = sum(1 for p in all_f1_labels if not p)
+ pos_with_neg_adv = sum(1 for p, a in zip(all_f1_labels, rollout_advantages) if p and a < 0)
+ neg_with_pos_adv = sum(1 for p, a in zip(all_f1_labels, rollout_advantages) if not p and a > 0)
+
+ # Skip homogeneous groups where gradient signal is meaningless
+ f1_pos_rate = n_pos / len(f1_rewards) if f1_rewards else 0.5
+ if f1_pos_rate > 0.9 or f1_pos_rate < 0.1:
+ logger.info('[skip-homogeneous] f1_pos_rate=%.3f, skipping training update', f1_pos_rate)
+ metrics.accumulate(
+ completion_lengths=per_rollout_completion_length,
+ rewards={'total': total_rewards, 'f1': f1_rewards,
+ 'cot': cot_rewards, 'tool_explore': tool_explore_rewards})
+ log_dict = metrics.calculate()
+ log_dict.update(_compute_rollout_diagnostics(
+ all_trajectories, n_turns_per_rollout, per_rollout_completion_length,
+ f1_rewards=f1_rewards, old_logps=[[lp[0][1] for lp in (t.get('logprobs') or [])] for t in all_trajectories]))
+ log_dict['skipped'] = True
+ log_dict['pos_neg_adv_rate'] = pos_with_neg_adv / n_pos if n_pos else 0.0
+ log_dict['neg_pos_adv_rate'] = neg_with_pos_adv / n_neg if n_neg else 0.0
+ log_dict['adv_max'] = max(rollout_advantages) if rollout_advantages else 0.0
+ log_dict['adv_min'] = min(rollout_advantages) if rollout_advantages else 0.0
+ swanlab.log(_coerce_for_swanlab(log_dict), step=batch_step)
+ metrics.reset()
+ logger.info(f'[Step {batch_step}/{total_steps}] [SKIPPED] {log_dict}')
+ optim_step += optim_steps_per_batch
+ continue
+
+ metrics.accumulate(
+ completion_lengths=per_rollout_completion_length,
+ rewards={'total': total_rewards, 'f1': f1_rewards,
+ 'cot': cot_rewards, 'tool_explore': tool_explore_rewards})
+
+ all_input_data: List[Any] = []
+ all_old_logps: List[List[float]] = []
+ advantages: List[float] = []
+ for t, adv in zip(all_trajectories, rollout_advantages):
+ all_input_data.append(t)
+ all_old_logps.append([lp[0][1] for lp in (t.get('logprobs') or [])])
+ advantages.append(adv)
+
+ total_completions = len(all_input_data)
+ aligned_completions = (total_completions // MODEL_GPUS) * MODEL_GPUS
+ if aligned_completions < total_completions:
+ logger.info(
+ '[dp-align] dropping %d tail sample(s): total=%d -> aligned=%d (dp=%d)',
+ total_completions - aligned_completions,
+ total_completions, aligned_completions, MODEL_GPUS)
+ for mb_start in range(0, aligned_completions, MINI_BATCH_SIZE):
+ mb_end = min(mb_start + MINI_BATCH_SIZE, aligned_completions)
+ mb_inputs = all_input_data[mb_start:mb_end]
+ # Reference log-probs for KL: same policy model with LoRA adapter disabled (= base model).
+ # Skipped when KL_BETA == 0 to save one extra forward per mini-batch.
+ ref_logps = None
+ if KL_BETA > 0.0:
+ ref_outputs = model.forward_only(inputs=mb_inputs, disable_lora=True)
+ ref_logps = ref_outputs.get('logps') if isinstance(ref_outputs, dict) else getattr(ref_outputs, 'logps', None)
+ # [EXP-ORACLE] per-token bonus: forward with oracle context, diff against rollout logps
+ mb_token_bonus = None
+ if ORACLE_BONUS_COEF > 0.0:
+ mb_oracle_inputs = _build_oracle_inputs(
+ mb_inputs, all_f1_labels[mb_start:mb_end], rollout_template)
+ if mb_oracle_inputs is not None:
+ oracle_outputs = model.forward_only(inputs=mb_oracle_inputs)
+ oracle_logps = oracle_outputs.get('logps') if isinstance(oracle_outputs, dict) else getattr(oracle_outputs, 'logps', None)
+ if oracle_logps is not None:
+ mb_token_bonus = _compute_token_bonus(
+ oracle_logps, all_old_logps[mb_start:mb_end],
+ all_f1_labels[mb_start:mb_end], mb_oracle_inputs)
+ model.forward_backward(
+ inputs=mb_inputs,
+ old_logps=all_old_logps[mb_start:mb_end],
+ advantages=advantages[mb_start:mb_end],
+ ref_logps=ref_logps,
+ token_bonus=mb_token_bonus,
+ positive_mask=all_f1_labels[mb_start:mb_end],
+ micro_batch_size=MICRO_BATCH_SIZE)
+ model.clip_grad_and_step()
+ optim_step += 1
+ if optim_step >= total_steps:
+ break
+ if optim_step % SAVE_STEPS == 0:
+ model.save(f'hotpotqa-grpo-tools-llmcondense-checkpoint-{optim_step}')
+
+ log_dict = metrics.calculate()
+ log_dict.update(model.calculate_metric(is_training=True))
+ log_dict.update(_compute_rollout_diagnostics(
+ all_trajectories, n_turns_per_rollout, per_rollout_completion_length,
+ f1_rewards=f1_rewards, old_logps=all_old_logps))
+ log_dict['pos_neg_adv_rate'] = pos_with_neg_adv / n_pos if n_pos else 0.0
+ log_dict['neg_pos_adv_rate'] = neg_with_pos_adv / n_neg if n_neg else 0.0
+ log_dict['adv_max'] = max(rollout_advantages) if rollout_advantages else 0.0
+ log_dict['adv_min'] = min(rollout_advantages) if rollout_advantages else 0.0
+ # Pop high-KL token records before swanlab.log: list-of-dict won't render as a chart.
+ _hk = log_dict.pop('_high_kl_records', None)
+ if _hk:
+ _tok = rollout_template.tokenizer
+ for r in _hk:
+ gsi = r.get('gsi')
+ tid = all_trajectories[gsi].get('id') if gsi is not None and 0 <= gsi < len(all_trajectories) else None
+ try:
+ tok_text = _tok.decode([r['token_id']])
+ except Exception:
+ tok_text = None
+ logger.info(
+ '[high-kl] step=%d gsi=%s tid=%s pos=%s tok=%r kl=%.4f r=%.4f lp_new=%.4f lp_old=%.4f',
+ batch_step, gsi, tid, r.get('pos'), tok_text,
+ r.get('kl'), r.get('ratio'), r.get('logp_new'), r.get('logp_old'))
+ swanlab.log(_coerce_for_swanlab(log_dict), step=batch_step)
+ metrics.reset()
+ logger.info(f'[Step {batch_step}/{total_steps}] {log_dict}')
+
+ logger.info(f'Training completed. optim_steps={optim_step}')
+ model.save('hotpotqa-grpo-tools-llmcondense-final')
+
+
+if __name__ == '__main__':
+ main()
diff --git a/cookbook/exp/legacy/make_condensed_sft.py b/cookbook/exp/legacy/make_condensed_sft.py
new file mode 100644
index 000000000..3b9855ac2
--- /dev/null
+++ b/cookbook/exp/legacy/make_condensed_sft.py
@@ -0,0 +1,945 @@
+"""Cold-start SFT dataset builder for the condensed multi-hop QA task.
+
+Pipeline per HotpotQA distractor row:
+ 1. Build the standard system + user-with-context trajectory using the
+ production ``SYSTEM_PROMPT`` and ``_format_context`` from
+ ``cookbook/rl/grpo_condensed.py`` so the offline data matches what
+ the policy sees at training/inference time.
+ 2. Run the production ``NativeChunker`` + ``ModelCondenser`` on the
+ row to produce ``...`` compressed text.
+ 3. **Validation pass** (super-LLM, ``enable_thinking=True``, no oracle,
+ no tools): judge whether the question / supporting_facts / GT are
+ well-formed against the raw passages; return strict JSON
+ ``{"verdict": "ok"|"fix"|"drop", ...}`` with fixed SF + GT when
+ applicable. ``drop`` skips the row.
+ 4. **Oracle rollout pass** via :class:`APIMultiTurnRollout` with a
+ trajectory-bound :class:`ExtractCondensed` tool. The oracle hint
+ (SF titles + GT) is injected into the system prompt **only for
+ the API call**; it is stripped before saving. The model emits
+ OpenAI-shape ``tool_calls`` for ``extract_condensed``, the rollout
+ dispatches them through :class:`ToolManager` and feeds back the
+ pre-compression passage text as a ``tool`` message, looping until
+ the model finalises with ``\\boxed{...}`` or hits ``MAX_TURNS``.
+ 5. Accept iff F1(boxed, used_gt) >= ``F1_ACCEPT_THRESHOLD``. On miss,
+ retry once with a higher temperature.
+ 6. Convert OpenAI-shape ``tool_calls`` into the textual
+ ``N``
+ format consumed by the training chat template (mirrors
+ ``grpo_condensed.SYSTEM_PROMPT`` L232-239), restore the clean
+ system prompt, and emit one JSONL line.
+
+Run::
+
+ python cookbook/rl/make_condensed_sft.py \\
+ --output hotpotqa_sft_coldstart.jsonl \\
+ --model --api-key $KEY --base-url $URL \\
+ --total 9000 --easy 1500 --medium 3000 --hard 4500 \\
+ --concurrency 16 --seed 42 \\
+ --condenser-model-id ms://Qwen/Qwen3.5-4B \\
+ --condenser-lora ms://twinkle-kit/Qwen3.5-4B-Condenser
+"""
+from __future__ import annotations
+
+import argparse
+import json
+import os
+import random
+import re
+import sys
+import threading
+from concurrent.futures import ThreadPoolExecutor
+from typing import Any, Dict, List, Optional, Tuple
+
+from datasets import load_dataset
+
+from twinkle.data_format.sampling import SamplingParams
+from twinkle.sampler import vLLMSampler
+from twinkle.template import Qwen3_5Template
+from twinkle_agentic.chunker.native import NativeChunker
+from twinkle_agentic.condenser import ModelCondenser
+from twinkle_agentic.data_format import Chunks
+from twinkle_agentic.protocol.openai import OpenAI
+from twinkle_agentic.reward.f1 import _extract_final_answer, _f1_score
+from twinkle_agentic.rollout import APIMultiTurnRollout
+from twinkle_agentic.tools.extract_condensed import ExtractCondensed
+from twinkle_agentic.tools.tool_manager import ToolManager
+
+
+# --------------------------------------------------------------------------
+# Constants mirrored from grpo_condensed.py so the SFT data matches the
+# runtime contract byte-for-byte. Re-import would pull the whole training
+# module; copying these few strings keeps the builder standalone.
+# --------------------------------------------------------------------------
+SYSTEM_PROMPT = """You are a careful multi-hop QA assistant.
+
+## Context Format (Mixed)
+The context you receive is a **mix of two forms**:
+
+1. **Compressed blocks** — long passages wrapped in `...`, \
+displayed as a Markdown digest in **telegraphic style** (no \
+articles / "is" / "are"; colons and commas mean "is" / "has") \
+with two sections:
+ - **Summary**: overview plus facts strongly related to the question, stated explicitly.
+ - **More**: a collapsed INDEX of category keywords hinting at extra details hidden in the full text (call `extract_condensed` to see them).
+ Reading example: `India: 7th largest by area. Borders: Pakistan, \
+China.` means "India is the 7th largest country by area and \
+shares borders with Pakistan and China."
+2. **Raw passages** — short passages shown inline as plain text (`Title: \
+body`) **without** any `` wrapping. These are already the full \
+text; nothing is hidden.
+
+Only the ``-wrapped blocks are compressed and can be expanded. \
+Block ids `N` are 1-based and assigned in the order compressed blocks \
+appear in the context, so they are always contiguous (``, \
+``, ``, ...). Raw passages have no block id and cannot \
+be extracted — they are already complete.
+
+## Workflow
+
+### Phase 1 — Scan and Decide
+Step 1: Read each compressed block's Summary, and read raw \
+passages directly, to get an overview.
+Step 2: For compressed blocks, check the More keywords to judge whether \
+hidden details are needed.
+Step 3: Decide which compressed blocks to expand, then call \
+`extract_condensed` with their block ids. Raw passages need no extraction.
+
+### Phase 2 — Reason and Answer
+After the tool returns the full text, continue stepping through the evidence:
+Step N: From block X (or the raw passage titled "..."), I learn that [fact A].
+Step N+1: From block Y, I need to call `extract_condensed` to get more information, because this block is related to...
+Step N+2: Combining these, the answer is ...
+\\boxed{answer}
+
+You may call `extract_condensed` several times to expand more blocks if the information is not enough, only answer the question if you are sure about the facts.
+The `blocks` parameter accepts **exactly one integer** per call (e.g. `3`); lists are rejected. Expand additional blocks by issuing separate `extract_condensed` calls, one per block. Only pass ids that actually appear as `` in the context, and do **not** request the same block twice — its text is already in the conversation after the first expansion.
+
+## Tool Call Format
+
+
+
+3
+
+
+
+
+## Output Format
+End your final response with \\boxed{answer}, e.g. \\boxed{Delhi}.
+Keep the boxed text short: a name, entity, date, or "yes"/"no".
+Answers not inside \\boxed{} will not be scored."""
+
+
+# Oracle suffix appended ONLY for API generation; stripped before save.
+_ORACLE_HINT_TEMPLATE = (
+ '\n\n## Oracle hint (PRIVATE — do NOT quote verbatim)\n'
+ 'The following supporting-fact titles and ground-truth answer are '
+ 'provided to make your final answer reliable. Use them as a signpost '
+ 'while you reason from the context; your final `\\boxed{{...}}` MUST '
+ 'paraphrase the ground truth using evidence from the blocks (after '
+ 'expanding compressed blocks when needed), not just echo it.\n'
+ 'Supporting facts (titles): {sf}\n'
+ 'Ground truth: {gt}\n'
+ 'You MUST still call `extract_condensed` on EVERY compressed block '
+ 'whose Summary or More keywords touch any supporting-fact title, even '
+ 'if the Summary already seems to state the answer — the compressed '
+ 'Summary occasionally loses pronoun referents or attribution and the '
+ 'raw passage is the authoritative source.'
+)
+
+
+VALIDATION_SYSTEM = (
+ 'You are a HotpotQA annotation auditor. Read the raw passages, the '
+ 'question, the supplied supporting-fact titles and the supplied '
+ 'ground-truth answer. Decide whether this row is usable for training '
+ 'a multi-hop QA model.\n\n'
+ 'Pathologies to catch (drop or fix):\n'
+ ' - question template leakage: the question literally contains the '
+ 'answer, references a passage id, or is malformed;\n'
+ ' - subject/answer mismatch: the GT does not actually answer the '
+ 'question given the passages (e.g. the question asks about an event '
+ 'X but GT is from a sibling event Y);\n'
+ ' - GT entity not present in any passage AND not directly inferable '
+ 'by a 2-hop bridge from the passages;\n'
+ ' - supporting-fact titles obviously incomplete for a 2-hop question.\n'
+ '\n'
+ 'Return STRICT JSON ONLY (no markdown fence, no preamble) with this '
+ 'exact shape:\n'
+ ' {"verdict": "ok"|"fix"|"drop", "reason": "", '
+ '"fixed_supporting_facts": ["", ...], '
+ '"fixed_ground_truth": ""}\n'
+ 'Use verdict "ok" when the supplied SF + GT are correct (then '
+ '"fixed_supporting_facts" and "fixed_ground_truth" MAY be empty). '
+ 'Use verdict "fix" when the question is answerable but SF or GT are '
+ 'wrong/incomplete -- fill the fixed fields with the corrected values, '
+ 'titles drawn verbatim from the passage titles below. Use verdict '
+ '"drop" when the question itself is invalid or unanswerable from the '
+ 'given passages.'
+)
+
+
+VALIDATION_USER_TEMPLATE = (
+ 'Question: {question}\n'
+ '\n'
+ 'Supplied supporting-fact titles: {sf}\n'
+ 'Supplied ground truth: {gt}\n'
+ '\n'
+ 'Passage titles (verbatim):\n{titles}\n'
+ '\n'
+ 'Passages (raw, uncompressed):\n\n{passages}'
+)
+
+
+# JSON Schema for the OpenAI API; the in-process ExtractCondensed tool's
+# tool_info() emits a free-form description that the OpenAI SDK rejects.
+EXTRACT_CONDENSED_TOOL: Dict[str, Any] = {
+ 'type': 'function',
+ 'function': {
+ 'name': 'extract_condensed',
+ 'description': (
+ 'Recover the full, uncompressed text of ONE previously '
+ 'condensed passage, identified by its tag. Use '
+ 'this tool whenever you need to re-read the original detail '
+ 'of a compressed block. Each call expands exactly one block; '
+ 'issue separate calls for additional blocks, and do not '
+ 'request the same block twice.'),
+ 'parameters': {
+ 'type': 'object',
+ 'properties': {
+ 'blocks': {
+ 'type': 'integer',
+ 'description': (
+ 'The 1-indexed block number N appearing inside '
+ '.... Exactly one block per '
+ 'call (e.g. 3); lists are rejected.'),
+ },
+ },
+ 'required': ['blocks'],
+ },
+ },
+}
+
+
+F1_ACCEPT_THRESHOLD: float = 0.5
+ROLLOUT_MAX_TURNS: int = 8
+ROLLOUT_MAX_TOKENS: int = 2048
+VALIDATION_MAX_TOKENS: int = 1024
+ROLLOUT_TEMPERATURE_LADDER: Tuple[float, ...] = (0.4, 0.7)
+
+
+# --------------------------------------------------------------------------
+# Trajectory + chunk helpers (mirror HotpotQAProcessor + production prompt).
+# --------------------------------------------------------------------------
+def _format_passage(title: str, sentences: Any) -> str:
+ if isinstance(sentences, list):
+ body = ' '.join(s.strip() for s in sentences if s and s.strip())
+ else:
+ body = str(sentences).strip()
+ return f'{title}: {body}'
+
+
+def _format_context(titles: List[str], sentences_list: List[Any]) -> str:
+ return '\n\n'.join(
+ _format_passage(t, s) for t, s in zip(titles, sentences_list))
+
+
+def _build_initial_trajectory(row: Dict[str, Any]) -> Dict[str, Any]:
+ """Build the pre-compression trajectory dict the chunker expects."""
+ ctx = row.get('context') or {}
+ titles = list(ctx.get('title') or [])
+ sentences_list = list(ctx.get('sentences') or [])
+ user_msg = (
+ f"Question: {row['question']}\n\n"
+ f'Context:\n\n{_format_context(titles, sentences_list)}')
+ return {
+ 'messages': [
+ {'role': 'system', 'content': SYSTEM_PROMPT},
+ {'role': 'user', 'content': user_msg},
+ ],
+ }
+
+
+def _extract_question_from_chunk(chunk):
+ content = chunk.get('content')
+ if chunk.get('type') != 'text' or not isinstance(content, str):
+ return None
+ m = re.search(r'\AQuestion:\s*(.+)', content)
+ return m.group(1).strip() if m else None
+
+
+# --------------------------------------------------------------------------
+# Per-batch compression (re-use MultiTurnCondenseRollout's batching trick:
+# merge all per-row chunks into ONE Chunks so the sampler sees a packed batch).
+# --------------------------------------------------------------------------
+def compress_rows(
+ rows: List[Dict[str, Any]],
+ chunker: NativeChunker,
+ condenser: ModelCondenser,
+) -> List[Tuple[Dict[str, Any], Chunks]]:
+ """Return ``[(compressed_trajectory_dict, per_row_Chunks), ...]``.
+
+ ``compressed_trajectory_dict`` already has ``...``
+ wrapping in its user message (see :meth:`Chunks.to_trajectory`).
+ ``per_row_Chunks`` carries ``raw.original`` snapshots so
+ :class:`ExtractCondensed` can return the pre-compression text.
+ """
+ if not rows:
+ return []
+ initial = [_build_initial_trajectory(r) for r in rows]
+ per_row_chunks = [chunker(t) for t in initial]
+ merged_list: List[Any] = []
+ boundaries: List[int] = []
+ for ck in per_row_chunks:
+ merged_list.extend(ck.chunks)
+ boundaries.append(len(merged_list))
+ merged = condenser(Chunks(chunks=merged_list))
+ out: List[Tuple[Dict[str, Any], Chunks]] = []
+ start = 0
+ for end in boundaries:
+ slc = Chunks(chunks=list(merged.chunks[start:end]))
+ out.append((slc.to_trajectory(), slc))
+ start = end
+ return out
+
+
+# --------------------------------------------------------------------------
+# Stage 1: validation pass.
+# --------------------------------------------------------------------------
+_JSON_FENCE_RE = re.compile(r'```(?:json)?\s*\n(.*?)\n```', re.DOTALL)
+
+
+def _extract_json_object(text: str) -> Optional[Dict[str, Any]]:
+ """Best-effort JSON parse: strip fence, then locate first ``{...}`` block."""
+ if not text:
+ return None
+ candidate = text.strip()
+ m = _JSON_FENCE_RE.search(candidate)
+ if m:
+ candidate = m.group(1).strip()
+ depth = 0
+ start = -1
+ for i, ch in enumerate(candidate):
+ if ch == '{':
+ if depth == 0:
+ start = i
+ depth += 1
+ elif ch == '}':
+ depth -= 1
+ if depth == 0 and start != -1:
+ blob = candidate[start:i + 1]
+ try:
+ return json.loads(blob)
+ except json.JSONDecodeError:
+ start = -1
+ continue
+ return None
+
+
+def validate_row(
+ api: OpenAI, row: Dict[str, Any], original_gt: List[str], sf_titles: List[str],
+) -> Optional[Dict[str, Any]]:
+ """Return parsed JSON verdict, or ``None`` on unrecoverable parse failure."""
+ ctx = row.get('context') or {}
+ titles = list(ctx.get('title') or [])
+ sentences_list = list(ctx.get('sentences') or [])
+ passages = _format_context(titles, sentences_list)
+ user = VALIDATION_USER_TEMPLATE.format(
+ question=row['question'],
+ sf=json.dumps(sf_titles, ensure_ascii=False),
+ gt=json.dumps(original_gt, ensure_ascii=False),
+ titles='\n'.join(f'- {t}' for t in titles),
+ passages=passages,
+ )
+ trajectory = {
+ 'messages': [
+ {'role': 'system', 'content': VALIDATION_SYSTEM},
+ {'role': 'user', 'content': user},
+ ],
+ }
+ sp = SamplingParams(
+ temperature=0.0, max_tokens=VALIDATION_MAX_TOKENS, num_samples=1)
+ for attempt in range(2):
+ try:
+ reply = api(
+ trajectory, sp, extra_body={'enable_thinking': True})
+ except Exception as exc:
+ sys.stderr.write(f'[validate] row={row.get("id")} attempt={attempt} api error: {exc}\n')
+ return None
+ content = reply.get('content') or ''
+ parsed = _extract_json_object(content)
+ if parsed and parsed.get('verdict') in ('ok', 'fix', 'drop'):
+ return parsed
+ return None
+
+
+def resolve_validation(
+ verdict: Dict[str, Any], original_gt: List[str], sf_titles: List[str],
+) -> Tuple[List[str], List[str]]:
+ """Pick the SF + GT list to use downstream based on verdict."""
+ v = verdict.get('verdict')
+ if v == 'fix':
+ fixed_gt = verdict.get('fixed_ground_truth') or ''
+ fixed_sf = verdict.get('fixed_supporting_facts') or []
+ gt_list: List[str] = []
+ if isinstance(fixed_gt, list):
+ gt_list = [str(x).strip() for x in fixed_gt if str(x).strip()]
+ elif isinstance(fixed_gt, str) and fixed_gt.strip():
+ gt_list = [fixed_gt.strip()]
+ if not gt_list:
+ gt_list = original_gt
+ sf_list = (
+ [str(x).strip() for x in fixed_sf if str(x).strip()]
+ if isinstance(fixed_sf, list) else sf_titles)
+ if not sf_list:
+ sf_list = sf_titles
+ return gt_list, sf_list
+ return original_gt, sf_titles
+
+
+# --------------------------------------------------------------------------
+# Stage 2 prep: build oracle trajectory + per-trajectory ToolManager.
+# --------------------------------------------------------------------------
+def _oracle_system_prompt(sf_titles: List[str], gt_list: List[str]) -> str:
+ sf_render = ', '.join(repr(t) for t in sf_titles) if sf_titles else '(none)'
+ gt_render = ' | '.join(gt_list) if gt_list else '(unknown)'
+ return SYSTEM_PROMPT + _ORACLE_HINT_TEMPLATE.format(
+ sf=sf_render, gt=gt_render)
+
+
+def _build_oracle_trajectory(
+ compressed_traj: Dict[str, Any],
+ sf_titles: List[str],
+ gt_list: List[str],
+) -> Dict[str, Any]:
+ """Replace the system message with the oracle-suffixed variant and
+ attach the JSON-schema tools field consumed by the OpenAI API."""
+ oracle_sp = _oracle_system_prompt(sf_titles, gt_list)
+ out_messages: List[Dict[str, Any]] = []
+ sys_inserted = False
+ for m in compressed_traj.get('messages') or []:
+ if m.get('role') == 'system' and not sys_inserted:
+ out_messages.append({'role': 'system', 'content': oracle_sp})
+ sys_inserted = True
+ else:
+ out_messages.append(dict(m))
+ if not sys_inserted:
+ out_messages.insert(0, {'role': 'system', 'content': oracle_sp})
+ return {
+ 'messages': out_messages,
+ 'tools': [EXTRACT_CONDENSED_TOOL],
+ }
+
+
+def _make_tool_manager(chunks: Chunks) -> ToolManager:
+ """One ToolManager + ExtractCondensed per trajectory; the tool keeps
+ a ``_already_expanded`` set, so reusing across trials would lie to
+ the model on retry."""
+ tm = ToolManager()
+ tm.register(ExtractCondensed(chunks))
+ return tm
+
+
+# --------------------------------------------------------------------------
+# Stage 3 + 4: F1 acceptance + conversion to training-runtime format.
+# --------------------------------------------------------------------------
+def boxed_f1(boxed: str, gt_list: List[str]) -> float:
+ if not boxed or not gt_list:
+ return 0.0
+ return max(_f1_score(boxed, g)[0] for g in gt_list)
+
+
+def _last_assistant_text(messages: List[Dict[str, Any]]) -> str:
+ for m in reversed(messages):
+ if m.get('role') == 'assistant' and isinstance(m.get('content'), str):
+ return m['content']
+ return ''
+
+
+def _format_tool_call_text(blocks: int) -> str:
+ return (
+ '\n'
+ '\n'
+ '\n'
+ f'{blocks}\n'
+ '\n'
+ '\n'
+ ''
+ )
+
+
+def convert_to_runtime_messages(
+ api_messages: List[Dict[str, Any]],
+) -> List[Dict[str, Any]]:
+ """OpenAI tool_calls -> textual format consumed by the
+ training chat template. The first system message has its oracle
+ suffix stripped (we just replace it with the clean SYSTEM_PROMPT).
+ """
+ out: List[Dict[str, Any]] = []
+ sys_done = False
+ for m in api_messages:
+ role = m.get('role')
+ if role == 'system' and not sys_done:
+ out.append({'role': 'system', 'content': SYSTEM_PROMPT})
+ sys_done = True
+ continue
+ if role == 'assistant':
+ content = m.get('content') or ''
+ tool_calls = m.get('tool_calls') or []
+ if tool_calls:
+ pieces = [content.rstrip()] if content else []
+ for tc in tool_calls:
+ fn = tc.get('function') or {}
+ args_raw = fn.get('arguments')
+ try:
+ args = (
+ json.loads(args_raw) if isinstance(args_raw, str)
+ else (args_raw or {}))
+ except json.JSONDecodeError:
+ args = {}
+ blocks_val = args.get('blocks', args.get('block'))
+ try:
+ n = int(blocks_val)
+ except (TypeError, ValueError):
+ continue
+ pieces.append(_format_tool_call_text(n))
+ text = '\n\n'.join(p for p in pieces if p)
+ out.append({'role': 'assistant', 'content': text})
+ else:
+ out.append({'role': 'assistant', 'content': content})
+ continue
+ if role == 'tool':
+ out.append({'role': 'tool', 'content': m.get('content') or ''})
+ continue
+ out.append({k: v for k, v in m.items() if k in ('role', 'content')})
+ return out
+
+
+def trajectory_achieved_ratio(chunks: Chunks) -> float:
+ total_src = 0
+ total_cmp = 0
+ for c in chunks.chunks:
+ if c.get('type') != 'text':
+ continue
+ raw = c.get('raw')
+ if not (isinstance(raw, dict) and raw.get('condensed')):
+ continue
+ original = raw.get('original')
+ compressed = c.get('content')
+ if isinstance(original, str) and isinstance(compressed, str):
+ total_src += len(original)
+ total_cmp += len(compressed)
+ return round(total_cmp / total_src, 4) if total_src else 0.0
+
+
+def build_record(
+ row: Dict[str, Any],
+ runtime_messages: List[Dict[str, Any]],
+ chunks: Chunks,
+ verdict: Dict[str, Any],
+ original_gt: List[str],
+ used_gt: List[str],
+ used_sf: List[str],
+ boxed: str,
+ f1: float,
+ num_tool_calls: int,
+) -> Dict[str, Any]:
+ ctx = row.get('context') or {}
+ titles = list(ctx.get('title') or [])
+ sentences_list = list(ctx.get('sentences') or [])
+ raw_passages = [
+ {
+ 'title': t,
+ 'sentences': list(s) if isinstance(s, list) else [str(s)],
+ }
+ for t, s in zip(titles, sentences_list)
+ ]
+ sf_full = row.get('supporting_facts') or {}
+ return {
+ 'id': row['id'],
+ 'level': row.get('level'),
+ 'type': row.get('type'),
+ 'messages': runtime_messages,
+ 'tools': [EXTRACT_CONDENSED_TOOL],
+ 'meta': {
+ 'num_tool_calls': num_tool_calls,
+ 'achieved_ratio': trajectory_achieved_ratio(chunks),
+ 'validation_verdict': verdict.get('verdict'),
+ 'validation_reason': verdict.get('reason'),
+ 'original_question': row.get('question'),
+ 'original_answer': row.get('answer'),
+ 'original_gt': original_gt,
+ 'used_gt': used_gt,
+ 'used_supporting_facts': used_sf,
+ 'original_supporting_facts': {
+ 'title': list(sf_full.get('title') or []),
+ 'sent_id': list(sf_full.get('sent_id') or []),
+ },
+ 'original_passages': raw_passages,
+ 'f1': round(f1, 4),
+ 'boxed': boxed,
+ },
+ }
+
+
+# --------------------------------------------------------------------------
+# Per-batch pipeline orchestration.
+# --------------------------------------------------------------------------
+def _extract_original_gt_sf(row: Dict[str, Any]) -> Tuple[List[str], List[str]]:
+ answers = row.get('answers')
+ if isinstance(answers, list) and answers:
+ original_gt = [str(a).strip() for a in answers if str(a).strip()]
+ else:
+ original_gt = [(row.get('answer', '') or '').strip()]
+ original_gt = [g for g in original_gt if g]
+ sf = row.get('supporting_facts') or {}
+ sf_titles = list(dict.fromkeys(t for t in (sf.get('title') or []) if t))
+ return original_gt, sf_titles
+
+
+def _validate_in_parallel(
+ api: OpenAI, batch: List[Dict[str, Any]], pool: ThreadPoolExecutor,
+) -> Tuple[List[Optional[Dict[str, Any]]], List[Tuple[List[str], List[str]]]]:
+ """Run ``validate_row`` for every row in parallel (one OpenAI call each)."""
+ futures = []
+ payloads: List[Tuple[List[str], List[str]]] = []
+ for row in batch:
+ original_gt, sf_titles = _extract_original_gt_sf(row)
+ payloads.append((original_gt, sf_titles))
+ futures.append(pool.submit(
+ validate_row, api, row, original_gt, sf_titles))
+ verdicts: List[Optional[Dict[str, Any]]] = [f.result() for f in futures]
+ return verdicts, payloads
+
+
+def _num_tool_calls(messages: List[Dict[str, Any]]) -> int:
+ return sum(
+ len(m.get('tool_calls') or [])
+ for m in messages if m.get('role') == 'assistant')
+
+
+def process_batch(
+ api: OpenAI,
+ rollout: APIMultiTurnRollout,
+ batch: List[Dict[str, Any]],
+ chunker: NativeChunker,
+ condenser: ModelCondenser,
+ validation_pool: ThreadPoolExecutor,
+) -> List[Dict[str, Any]]:
+ """Validate -> compress -> rollout (T-ladder) -> accept. Returns the
+ list of accepted JSONL records for the batch."""
+ if not batch:
+ return []
+ # 1. Validation in parallel.
+ verdicts, payloads = _validate_in_parallel(api, batch, validation_pool)
+
+ survivors_meta: List[Dict[str, Any]] = []
+ for row, verdict, (original_gt, sf_titles) in zip(batch, verdicts, payloads):
+ if verdict is None or verdict.get('verdict') == 'drop':
+ continue
+ if not original_gt:
+ continue
+ used_gt, used_sf = resolve_validation(verdict, original_gt, sf_titles)
+ if not used_gt:
+ continue
+ survivors_meta.append({
+ 'row': row, 'verdict': verdict,
+ 'original_gt': original_gt,
+ 'used_gt': used_gt, 'used_sf': used_sf,
+ })
+ if not survivors_meta:
+ return []
+
+ # 2. Compress survivors (one packed batch through ModelCondenser).
+ survivor_rows = [m['row'] for m in survivors_meta]
+ try:
+ compressed = compress_rows(survivor_rows, chunker, condenser)
+ except Exception as exc:
+ sys.stderr.write(f'[compress] batch crashed: {exc}\n')
+ return []
+
+ # 3. Build oracle trajectories + per-trajectory ToolManagers.
+ trajs: List[Dict[str, Any]] = []
+ chunks_list: List[Chunks] = []
+ for meta, (compressed_traj, chunks) in zip(survivors_meta, compressed):
+ trajs.append(_build_oracle_trajectory(
+ compressed_traj, meta['used_sf'], meta['used_gt']))
+ chunks_list.append(chunks)
+
+ # 4. Temperature ladder. Each rung gets fresh ExtractCondensed tools so
+ # a retry does not see the previous attempt's already-expanded set.
+ accepted: List[Dict[str, Any]] = []
+ pending_idx = list(range(len(trajs)))
+ for temperature in ROLLOUT_TEMPERATURE_LADDER:
+ if not pending_idx:
+ break
+ sp = SamplingParams(
+ temperature=temperature, max_tokens=ROLLOUT_MAX_TOKENS, num_samples=1)
+ run_trajs = [trajs[i] for i in pending_idx]
+ run_tms = [_make_tool_manager(chunks_list[i]) for i in pending_idx]
+ try:
+ outs = rollout(
+ run_trajs, tool_manager=run_tms, sampling_params=sp)
+ except Exception as exc:
+ sys.stderr.write(f'[rollout] batch crashed at T={temperature}: {exc}\n')
+ return accepted
+ next_pending: List[int] = []
+ for local_pos, traj_idx in enumerate(pending_idx):
+ out_traj = outs[local_pos]
+ if out_traj.get('stop_reason') == 'api_error':
+ continue # hard-drop API failures, do not retry
+ messages = out_traj.get('messages') or []
+ boxed = _extract_final_answer(_last_assistant_text(messages))
+ meta = survivors_meta[traj_idx]
+ f1 = boxed_f1(boxed, meta['used_gt'])
+ if f1 >= F1_ACCEPT_THRESHOLD:
+ runtime_messages = convert_to_runtime_messages(messages)
+ accepted.append(build_record(
+ row=meta['row'],
+ runtime_messages=runtime_messages,
+ chunks=chunks_list[traj_idx],
+ verdict=meta['verdict'],
+ original_gt=meta['original_gt'],
+ used_gt=meta['used_gt'],
+ used_sf=meta['used_sf'],
+ boxed=boxed, f1=f1,
+ num_tool_calls=_num_tool_calls(messages)))
+ else:
+ next_pending.append(traj_idx)
+ pending_idx = next_pending
+ return accepted
+
+
+# --------------------------------------------------------------------------
+# Stratified sampling + resume.
+# --------------------------------------------------------------------------
+LEVELS: Tuple[str, str, str] = ('easy', 'medium', 'hard')
+
+
+def stratified_sample(
+ ds, per_level: Dict[str, int], seed: int,
+) -> List[Dict[str, Any]]:
+ rng = random.Random(seed)
+ buckets: Dict[str, List[int]] = {lv: [] for lv in LEVELS}
+ for i, lv in enumerate(ds['level']):
+ if lv in buckets:
+ buckets[lv].append(i)
+ picked: List[int] = []
+ for lv in LEVELS:
+ need = per_level[lv]
+ pool = buckets[lv]
+ if len(pool) < need:
+ raise RuntimeError(
+ f'level={lv} has only {len(pool)} rows, need {need}')
+ picked.extend(rng.sample(pool, need))
+ rng.shuffle(picked)
+ return [ds[int(i)] for i in picked]
+
+
+def load_done_ids(path: str) -> set:
+ if not os.path.exists(path):
+ return set()
+ done = set()
+ with open(path, 'r', encoding='utf-8') as fh:
+ for line in fh:
+ try:
+ obj = json.loads(line)
+ except json.JSONDecodeError:
+ continue
+ rid = obj.get('id')
+ if rid:
+ done.add(rid)
+ return done
+
+
+def apply_reannotation_overlay(
+ rows: List[Dict[str, Any]], path: str,
+) -> List[Dict[str, Any]]:
+ """Drop verdict=drop ids; overlay ``question_fixed`` and multi-form ``answers``.
+
+ The validation stage in ``process_batch`` still runs on every survivor
+ because the audit ran on a different HF subset (fullwiki) than this
+ builder's default (distractor) and passage contexts differ.
+ """
+ overrides: Dict[str, Dict[str, Any]] = {}
+ drop_ids: set = set()
+ with open(path, 'r', encoding='utf-8') as fh:
+ for line in fh:
+ line = line.strip()
+ if not line:
+ continue
+ try:
+ obj = json.loads(line)
+ except json.JSONDecodeError:
+ continue
+ rid = obj.get('id')
+ if not rid:
+ continue
+ if obj.get('verdict') == 'drop':
+ drop_ids.add(rid)
+ else:
+ overrides[rid] = obj
+ out: List[Dict[str, Any]] = []
+ overridden = 0
+ for row in rows:
+ rid = row.get('id')
+ if rid in drop_ids:
+ continue
+ ov = overrides.get(rid)
+ if ov is not None:
+ row = dict(row)
+ qfix = (ov.get('question_fixed') or '').strip()
+ if qfix:
+ row['question'] = qfix
+ ans = [str(a).strip() for a in (ov.get('answers') or []) if str(a).strip()]
+ if ans:
+ row['answers'] = ans
+ overridden += 1
+ out.append(row)
+ sys.stderr.write(
+ f'[REANNOTATED] {path}: {len(rows)} -> {len(out)} rows '
+ f'(dropped={len(drop_ids)}, overridden={overridden})\n')
+ return out
+
+
+# --------------------------------------------------------------------------
+# CLI + main loop.
+# --------------------------------------------------------------------------
+def parse_args() -> argparse.Namespace:
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--output', required=True)
+ parser.add_argument('--model', required=True,
+ help='Super-LLM model name (OpenAI-protocol).')
+ parser.add_argument('--api-key', default=os.environ.get('OPENAI_API_KEY'))
+ parser.add_argument('--base-url', default=os.environ.get('OPENAI_BASE_URL'))
+ parser.add_argument('--total', type=int, default=12000)
+ parser.add_argument('--easy', type=int, default=2000)
+ parser.add_argument('--medium', type=int, default=4000)
+ parser.add_argument('--hard', type=int, default=6000)
+ parser.add_argument('--concurrency', type=int, default=16)
+ parser.add_argument('--seed', type=int, default=42)
+ parser.add_argument('--reannotated', default=os.environ.get('REANNOTATED_FILE', ''),
+ help='Path to wrong_ids_reannotated.jsonl. Drops verdict=drop ids and overlays question_fixed + multi-form answers. Validation stage still runs because the audit was on a different HF subset.')
+ parser.add_argument('--hf-subset', default='distractor')
+ parser.add_argument('--hf-split', default='train')
+ parser.add_argument('--condenser-model-id',
+ default=os.environ.get('MODEL_ID', 'ms://Qwen/Qwen3.5-4B'))
+ parser.add_argument('--condenser-lora',
+ default='ms://twinkle-kit/Qwen3.5-4B-Condenser')
+ parser.add_argument('--chunk-size', type=int, default=1024)
+ parser.add_argument('--hotpotqa-max-length', type=int, default=64000)
+ parser.add_argument('--compress-batch-size', type=int, default=32,
+ help='How many rows to feed to ModelCondenser at once.')
+ parser.add_argument('--gpu-memory-utilization', type=float, default=0.8)
+ return parser.parse_args()
+
+
+def build_condenser(args: argparse.Namespace) -> Tuple[NativeChunker, ModelCondenser]:
+ sampler = vLLMSampler(
+ model_id=args.condenser_model_id,
+ engine_args={
+ 'gpu_memory_utilization': args.gpu_memory_utilization,
+ 'max_model_len': max(8192, args.hotpotqa_max_length),
+ 'max_lora_rank': 32,
+ 'enable_lora': True,
+ 'max_loras': 2,
+ },
+ )
+ sampler.set_template(
+ 'Qwen3_5Template', model_id=args.condenser_model_id,
+ enable_thinking=False, max_length=args.hotpotqa_max_length)
+ rollout_template = Qwen3_5Template(
+ args.condenser_model_id, max_length=args.hotpotqa_max_length,
+ enable_thinking=False)
+ chunker = NativeChunker(
+ chunk_size=args.chunk_size,
+ passage_boundary_re=r'(?<=\n\n)',
+ )
+ condenser = ModelCondenser(
+ sampler=sampler,
+ compression_ratio=2.0,
+ sampling_params=SamplingParams(
+ max_tokens=1024, num_samples=1, temperature=0.4, top_p=0.9),
+ min_chars=200,
+ template=rollout_template,
+ lora_path=args.condenser_lora or None,
+ skip_pattern=r'^Question:',
+ related_query=_extract_question_from_chunk,
+ )
+ return chunker, condenser
+
+
+def main() -> None:
+ args = parse_args()
+ if args.easy + args.medium + args.hard != args.total:
+ raise ValueError(
+ f'--easy + --medium + --hard ({args.easy + args.medium + args.hard}) '
+ f'must equal --total ({args.total})')
+ per_level = {'easy': args.easy, 'medium': args.medium, 'hard': args.hard}
+
+ sys.stderr.write(
+ f'Loading hotpotqa/hotpot_qa:{args.hf_subset}:{args.hf_split}...\n')
+ ds = load_dataset(
+ 'hotpotqa/hotpot_qa', args.hf_subset, split=args.hf_split)
+
+ rows = stratified_sample(ds, per_level=per_level, seed=args.seed)
+ if args.reannotated.strip():
+ rows = apply_reannotation_overlay(rows, args.reannotated.strip())
+ done = load_done_ids(args.output)
+ sys.stderr.write(f'Resume: {len(done)} rows already emitted.\n')
+ pending = [r for r in rows if r['id'] not in done]
+ sys.stderr.write(f'Pending: {len(pending)} / {len(rows)}\n')
+
+ chunker, condenser = build_condenser(args)
+ api = OpenAI(
+ model=args.model, api_key=args.api_key, base_url=args.base_url)
+
+ # APIMultiTurnRollout itself owns the per-trajectory thread pool. The
+ # validation phase runs on a separate pool of equal size; both phases
+ # are network-bound so we never need more threads than ``concurrency``.
+ rollout = APIMultiTurnRollout(
+ api=api,
+ tool_manager=ToolManager(), # placeholder; per-call list overrides
+ sampling_params=SamplingParams(
+ temperature=ROLLOUT_TEMPERATURE_LADDER[0],
+ max_tokens=ROLLOUT_MAX_TOKENS, num_samples=1),
+ max_turns=ROLLOUT_MAX_TURNS,
+ concurrency=args.concurrency,
+ extra_body={'enable_thinking': False},
+ )
+
+ write_lock = threading.Lock()
+ out_fh = open(args.output, 'a', encoding='utf-8')
+ accepted_total = 0
+ seen_total = 0
+
+ with ThreadPoolExecutor(max_workers=args.concurrency) as validation_pool:
+ try:
+ for start in range(0, len(pending), args.compress_batch_size):
+ batch = pending[start:start + args.compress_batch_size]
+ seen_total += len(batch)
+ try:
+ records = process_batch(
+ api, rollout, batch, chunker, condenser,
+ validation_pool)
+ except Exception as exc:
+ sys.stderr.write(
+ f'[batch {start}-{start + len(batch)}] crashed: {exc}\n')
+ continue
+ with write_lock:
+ for record in records:
+ out_fh.write(
+ json.dumps(record, ensure_ascii=False) + '\n')
+ out_fh.flush()
+ accepted_total += len(records)
+ sys.stderr.write(
+ f'[progress] seen={seen_total}/{len(pending)} '
+ f'accepted={accepted_total} '
+ f'(+{len(records)} from this batch)\n')
+ finally:
+ out_fh.close()
+
+ sys.stderr.write(
+ f'Done. accepted={accepted_total} total_pending={len(pending)}\n')
+
+
+if __name__ == '__main__':
+ main()
diff --git a/cookbook/exp/legacy/reannotate_groundtruth.py b/cookbook/exp/legacy/reannotate_groundtruth.py
new file mode 100644
index 000000000..137ebb4b9
--- /dev/null
+++ b/cookbook/exp/legacy/reannotate_groundtruth.py
@@ -0,0 +1,389 @@
+"""Re-annotate HotpotQA ground truth using a super-LLM to ensure correctness.
+
+The original HotpotQA dataset has annotation issues:
+ - GT doesn't match the question type (asks "where", GT gives a name)
+ - Partial/incomplete answers for multi-hop questions
+ - Single form when multiple valid forms exist (e.g. "2" vs "two")
+ - Question itself malformed (wrong question word, truncation, presupposition
+ mismatch with the answer type)
+
+This script:
+ 1. Loads HotpotQA fullwiki train split.
+ 2. By default (--only-forced), re-annotates ONLY the IDs listed in
+ wrong_ids.txt (the 340 known-bad cases).
+ Pass --no-only-forced to fall back to stratified 3000-per-level sampling
+ with wrong_ids force-included.
+ 3. For each row, sends question + full context + original GT to a super-LLM.
+ 4. The LLM emits one of four verdicts and (when applicable) a multi-form
+ answer list and/or a repaired question:
+ - keep: original Q + A are both correct
+ - fix_answer: Q is fine; A is wrong/incomplete
+ - fix_question: Q is malformed but repairable into a well-formed Q
+ that the same passages answer with the same gold facts
+ - drop: Q cannot be repaired without changing the fact, OR
+ passages do not support any answer
+ 5. Outputs ONE JSONL file containing all rows (including drop). Each row has
+ verdict, question, question_fixed, answers, reasoning. Downstream filters
+ by verdict.
+
+Run (re-clean wrong_ids.txt only, default):
+ python reannotate_groundtruth.py \
+ --model qwen-max --api-key $OPENAI_API_KEY \
+ --base-url https://dashscope.aliyuncs.com/compatible-mode/v1 \
+ --output hotpotqa_reannotated_wrong.jsonl --concurrency 16
+"""
+import argparse
+import json
+import os
+import random
+import re
+import sys
+import threading
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from typing import Any, Dict, List, Optional, Tuple
+
+from datasets import load_dataset
+
+from twinkle.data_format.sampling import SamplingParams
+from twinkle_agentic.protocol.openai import OpenAI
+
+
+VERIFY_SYSTEM = """You are a dataset quality auditor for a multi-hop QA benchmark (HotpotQA).
+
+Given a Question, supporting Context passages, and the dataset's Original Answer, output ONE of four verdicts and a multi-form answer list grounded in the passages.
+
+VERDICTS
+- "keep": original question + original answer are both correct.
+- "fix_answer": question is fine; original answer is wrong/incomplete.
+- "fix_question": question is malformed (wrong question word, broken grammar, truncated, or presupposition mismatch with the answer type) but can be REPAIRED into a well-formed question that the SAME passages answer with the SAME gold facts.
+- "drop": question cannot be repaired without changing the underlying fact, OR the passages do not support any answer.
+
+MULTI-FORM ANSWER RULES (apply to keep / fix_answer / fix_question)
+1. Output ALL acceptable surface forms whenever applicable:
+ - Number variants: arabic + english word + hyphen-prefix form (e.g. "3", "three", "three-door", "3-door")
+ - Range variants: start, end, and full range string (e.g. "1901", "1902", "1901-1902", "1901-2")
+ - Location variants: city / state-or-province / country (e.g. "Everett", "Washington", "WA", "United States")
+ - Person variants: legal name / nickname / full name (e.g. "Allan", "Heywood", "Allan Stewart Konigsberg")
+ - Entity-role pairs for role-of-X questions: BOTH the role AND the entity (e.g. "chauffeur", "Hitler's chauffeur")
+ - Show-vs-character pairs for best-known-for questions: BOTH the show AND the character (e.g. "M*A*S*H", "Major Frank Burns")
+ - Common abbreviations (e.g. "NYC", "New York City", "New York")
+ - With/without titles (e.g. "Dr. Smith", "Smith")
+ - Different date formats if applicable (e.g. "July 4, 1776", "4 July 1776")
+2. Each answer is SHORT (a name, entity, number, date, or yes/no).
+3. yes/no answers MUST be lowercase ["yes"] or ["no"].
+4. Do NOT hallucinate. Every answer must be grounded in the provided passages.
+
+QUESTION REWRITE RULES (verdict = fix_question)
+1. question_fixed MUST be answerable by the SAME passages and yield the SAME factual answer as the original gold facts.
+2. Allowed edits: swap question word (Where -> Did / Who / What), repair grammar, complete truncation, align question word with the answer type.
+3. FORBIDDEN: changing intent, injecting the answer into the question, adding facts not in the passages.
+4. If you cannot satisfy these constraints, downgrade to "drop".
+
+DROP RULES (verdict = drop)
+- answers MUST be [] and question_fixed MUST be null.
+
+OUTPUT FORMAT (JSON only, no markdown fence, no explanation)
+{"verdict": "keep|fix_answer|fix_question|drop", "question_fixed": "..." | null, "answers": ["..."], "reasoning": "one sentence"}"""
+
+VERIFY_USER = """## Question
+{question}
+
+## Original Answer (may be wrong)
+{original_answer}
+
+## Supporting Passages
+{context}
+
+## Task
+Audit the row per the system rules. Pick exactly one verdict (keep / fix_answer / fix_question / drop), produce the multi-form answers list (or [] for drop), and write a one-sentence reasoning. If verdict=fix_question, also produce question_fixed; otherwise set it to null.
+Return a single JSON object only."""
+
+
+LEVELS: Tuple[str, str, str] = ('easy', 'medium', 'hard')
+
+
+def _format_context(context: Dict[str, Any]) -> str:
+ titles = context.get('title', []) or []
+ sentences = context.get('sentences', []) or []
+ lines = []
+ for i, (title, sents) in enumerate(zip(titles, sentences), start=1):
+ if isinstance(sents, list):
+ body = ' '.join(s.strip() for s in sents if s and s.strip())
+ else:
+ body = str(sents).strip()
+ lines.append(f'[{i}] {title}: {body}')
+ return '\n\n'.join(lines)
+
+
+_JSON_RE = re.compile(r'\{[^{}]*"verdict"\s*:\s*"[^"]+"[^{}]*"answers"\s*:\s*\[.*?\][^{}]*\}', re.DOTALL)
+
+_VALID_VERDICTS = ('keep', 'fix_answer', 'fix_question', 'drop')
+
+
+def _parse_response(text: str) -> Optional[Dict[str, Any]]:
+ text = text.strip()
+ if text.startswith('```'):
+ first_nl = text.find('\n')
+ last_fence = text.rfind('```')
+ if first_nl != -1 and last_fence > first_nl:
+ text = text[first_nl + 1:last_fence].strip()
+ try:
+ obj = json.loads(text)
+ if isinstance(obj, dict) and 'answers' in obj:
+ return obj
+ except json.JSONDecodeError:
+ pass
+ m = _JSON_RE.search(text)
+ if m:
+ try:
+ return json.loads(m.group(0))
+ except json.JSONDecodeError:
+ pass
+ return None
+
+
+def _validate_verdict(
+ verdict: Optional[str], answers: List[str],
+ qfix: Optional[str], original_question: str,
+) -> bool:
+ if verdict not in _VALID_VERDICTS:
+ return False
+ if verdict == 'drop':
+ return not answers and qfix is None
+ if not answers:
+ return False
+ if verdict == 'fix_question':
+ return bool(qfix) and qfix.strip() != original_question.strip()
+ return qfix is None
+
+
+def verify_answer(
+ api: OpenAI, model: str, row: Dict[str, Any],
+) -> Optional[Dict[str, Any]]:
+ question = row['question']
+ original_answer = row.get('answer', '') or ''
+ context_str = _format_context(row.get('context', {}) or {})
+
+ user_content = VERIFY_USER.format(
+ question=question,
+ original_answer=original_answer,
+ context=context_str)
+
+ trajectory = {
+ 'messages': [
+ {'role': 'system', 'content': VERIFY_SYSTEM},
+ {'role': 'user', 'content': user_content},
+ ]
+ }
+ sp = SamplingParams(temperature=0.1, max_tokens=512)
+
+ for attempt in range(3):
+ try:
+ reply = api(trajectory, sp, extra_body={'enable_thinking': True})
+ except Exception as exc:
+ sys.stderr.write(f'[verify] {row["id"]}: API error: {exc}\n')
+ if attempt < 2:
+ continue
+ return None
+
+ content = reply.get('content') or ''
+ parsed = _parse_response(content)
+ if parsed:
+ verdict = parsed.get('verdict')
+ answers_raw = parsed.get('answers')
+ answers = (
+ [str(a).strip() for a in answers_raw if str(a).strip()]
+ if isinstance(answers_raw, list) else [])
+ qfix_raw = parsed.get('question_fixed')
+ qfix = (qfix_raw.strip() or None) if isinstance(qfix_raw, str) else None
+ if _validate_verdict(verdict, answers, qfix, question):
+ return {
+ 'id': row['id'],
+ 'verdict': verdict,
+ 'question': question,
+ 'question_fixed': qfix,
+ 'original_answer': original_answer,
+ 'answers': answers,
+ 'reasoning': parsed.get('reasoning', ''),
+ 'level': row.get('level', ''),
+ 'type': row.get('type', ''),
+ 'context': row.get('context', {}),
+ 'supporting_facts': row.get('supporting_facts', {}),
+ }
+ sys.stderr.write(
+ f'[verify retry {attempt+1}] {row["id"]}: '
+ f'parse failed, content={content[:200]!r}\n')
+
+ sys.stderr.write(f'[verify drop] {row["id"]}: all attempts failed\n')
+ return None
+
+
+def stratified_sample_with_forced(
+ ds, per_level: Dict[str, int], forced_ids: frozenset, seed: int,
+) -> List[Dict[str, Any]]:
+ rng = random.Random(seed)
+ buckets: Dict[str, List[int]] = {lv: [] for lv in LEVELS}
+ forced_indices: List[int] = []
+ forced_levels: Dict[str, int] = {lv: 0 for lv in LEVELS}
+
+ for i in range(len(ds)):
+ row_id = ds[i]['id']
+ level = (ds[i].get('level') or '').strip().lower()
+ if row_id in forced_ids:
+ forced_indices.append(i)
+ if level in forced_levels:
+ forced_levels[level] += 1
+ elif level in buckets:
+ buckets[level].append(i)
+
+ picked_set = set(forced_indices)
+ for lv in LEVELS:
+ need = max(0, per_level[lv] - forced_levels[lv])
+ pool = [idx for idx in buckets[lv] if idx not in picked_set]
+ if len(pool) < need:
+ sys.stderr.write(
+ f'Warning: level={lv} has {len(pool)} available, need {need}\n')
+ need = len(pool)
+ sampled = rng.sample(pool, need)
+ picked_set.update(sampled)
+
+ picked = sorted(picked_set)
+ rng.shuffle(picked)
+ return [ds[int(i)] for i in picked]
+
+
+def select_forced_only(ds, forced_ids: frozenset, seed: int) -> List[Dict[str, Any]]:
+ """Pick exactly the rows whose id is in forced_ids; warn on missing."""
+ indices: List[int] = []
+ found: set = set()
+ for i in range(len(ds)):
+ rid = ds[i]['id']
+ if rid in forced_ids:
+ indices.append(i)
+ found.add(rid)
+ missing = forced_ids - found
+ if missing:
+ sys.stderr.write(
+ f'Warning: {len(missing)} forced ids not found in dataset, '
+ f'e.g. {sorted(missing)[:5]}\n')
+ rng = random.Random(seed)
+ rng.shuffle(indices)
+ return [ds[int(i)] for i in indices]
+
+
+def load_done_ids(path: str) -> set:
+ if not os.path.exists(path):
+ return set()
+ done = set()
+ with open(path, 'r', encoding='utf-8') as fh:
+ for line in fh:
+ try:
+ obj = json.loads(line)
+ except json.JSONDecodeError:
+ continue
+ rid = obj.get('id')
+ if rid:
+ done.add(rid)
+ return done
+
+
+def main() -> None:
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--output', required=True)
+ parser.add_argument('--model', required=True)
+ parser.add_argument('--api-key', default=os.environ.get('OPENAI_API_KEY'))
+ parser.add_argument('--base-url', default=os.environ.get('OPENAI_BASE_URL'))
+ parser.add_argument('--total', type=int, default=12000)
+ parser.add_argument('--easy', type=int, default=2000)
+ parser.add_argument('--medium', type=int, default=4000)
+ parser.add_argument('--hard', type=int, default=6000)
+ parser.add_argument('--concurrency', type=int, default=16)
+ parser.add_argument('--seed', type=int, default=42)
+ parser.add_argument('--wrong-ids', default='cookbook/rl/wrong_ids.txt')
+ parser.add_argument('--hf-subset', default='fullwiki')
+ parser.add_argument('--hf-split', default='train')
+ parser.add_argument(
+ '--only-forced', action=argparse.BooleanOptionalAction, default=False,
+ help='If set, re-annotate ONLY IDs in --wrong-ids; default is stratified sampling with wrong_ids force-included.')
+ args = parser.parse_args()
+
+ forced_ids: frozenset = frozenset()
+ if args.wrong_ids and os.path.exists(args.wrong_ids):
+ with open(args.wrong_ids, 'r', encoding='utf-8') as fh:
+ forced_ids = frozenset(ln.strip() for ln in fh if ln.strip())
+ sys.stderr.write(f'Forced IDs loaded: {len(forced_ids)}\n')
+
+ if args.only_forced and not forced_ids:
+ raise ValueError(
+ f'--only-forced is set but no IDs loaded from {args.wrong_ids!r}')
+
+ sys.stderr.write(
+ f'Loading hotpotqa/hotpot_qa:{args.hf_subset}:{args.hf_split}...\n')
+ ds = load_dataset(
+ 'hotpotqa/hotpot_qa', args.hf_subset, split=args.hf_split)
+
+ if args.only_forced:
+ rows = select_forced_only(ds, forced_ids=forced_ids, seed=args.seed)
+ sys.stderr.write(
+ f'Selected {len(rows)} rows (only-forced mode, '
+ f'requested={len(forced_ids)})\n')
+ else:
+ if args.easy + args.medium + args.hard != args.total:
+ raise ValueError(
+ f'--easy + --medium + --hard ({args.easy + args.medium + args.hard}) '
+ f'must equal --total ({args.total})')
+ per_level = {'easy': args.easy, 'medium': args.medium, 'hard': args.hard}
+ rows = stratified_sample_with_forced(
+ ds, per_level=per_level, forced_ids=forced_ids, seed=args.seed)
+ sys.stderr.write(
+ f'Selected {len(rows)} rows (stratified per_level={per_level}, '
+ f'forced={len(forced_ids)})\n')
+
+ done = load_done_ids(args.output)
+ sys.stderr.write(f'Resume: {len(done)} rows already done, skipping.\n')
+ pending = [row for row in rows if row['id'] not in done]
+ sys.stderr.write(f'Pending: {len(pending)} / {len(rows)}\n')
+
+ api = OpenAI(
+ model=args.model, api_key=args.api_key, base_url=args.base_url)
+
+ write_lock = threading.Lock()
+ out_fh = open(args.output, 'a', encoding='utf-8')
+ rows_done = 0
+ rows_failed = 0
+ try:
+ with ThreadPoolExecutor(max_workers=args.concurrency) as ex:
+ futures = {
+ ex.submit(verify_answer, api, args.model, row): row['id']
+ for row in pending
+ }
+ for fut in as_completed(futures):
+ rid = futures[fut]
+ try:
+ result = fut.result()
+ except Exception as exc:
+ sys.stderr.write(f'[row {rid}] crashed: {exc}\n')
+ rows_failed += 1
+ continue
+ if result is None:
+ rows_failed += 1
+ continue
+ with write_lock:
+ out_fh.write(
+ json.dumps(result, ensure_ascii=False) + '\n')
+ out_fh.flush()
+ rows_done += 1
+ if rows_done % 100 == 0:
+ sys.stderr.write(
+ f'[progress] done={rows_done} '
+ f'failed={rows_failed}\n')
+ finally:
+ out_fh.close()
+
+ sys.stderr.write(
+ f'Done. rows_done={rows_done}, failed={rows_failed}, '
+ f'total_pending={len(pending)}\n')
+
+
+if __name__ == '__main__':
+ main()
diff --git a/cookbook/exp/legacy/train_extract_ddp.py b/cookbook/exp/legacy/train_extract_ddp.py
new file mode 100644
index 000000000..38d3c1f5f
--- /dev/null
+++ b/cookbook/exp/legacy/train_extract_ddp.py
@@ -0,0 +1,119 @@
+"""DDP LoRA SFT for the policy on hotpotqa_distractor_reannotated_sft_12k.jsonl.
+
+The JSONL is the output of ``cookbook/rl/make_condensed_sft.py``: each row
+already carries ``messages`` (system / user / assistant with textual
+```` blocks / tool) plus an OpenAI-shape ``tools`` schema, ready
+for ``Qwen3_5Template`` to render. ``enable_thinking=False`` matches the
+RL runtime contract.
+
+Launch:
+ torchrun --nproc_per_node=8 cookbook/rl/train_condensed_sft_ddp.py
+"""
+from pathlib import Path
+
+from peft import LoraConfig
+
+import twinkle
+from twinkle import DeviceMesh, get_device_placement, get_logger
+from twinkle.dataloader import DataLoader
+from twinkle.dataset import Dataset, DatasetMeta
+from twinkle.model import TransformersModel
+
+logger = get_logger()
+
+MODEL_ID = 'ms://Qwen/Qwen3.5-4B'
+DATASET_PATH = str(
+ Path(__file__).resolve().parent.parent.parent
+ / 'hotpotqa_distractor_reannotated_sft_12k.jsonl')
+TEMPLATE_NAME = 'Qwen3_5Template'
+# Multi-hop with compressed context + multi-turn extract_condensed CoT;
+# raw audit: most samples land well under 16k after condensation.
+MAX_LENGTH = 32000
+
+DP_SIZE = 8
+BATCH_SIZE = 16
+LEARNING_RATE = 1e-4
+GRADIENT_ACCUMULATION_STEPS = 2
+LOG_INTERVAL = 20
+NUM_EPOCHS = 2
+
+OUTPUT_DIR = './output/condensed_sft_ddp'
+RESUME_FROM_CHECKPOINT = None
+RESUME_ONLY_MODEL = False
+IGNORE_DATA_SKIP = False
+ADAPTER_NAME = 'default'
+
+device_mesh = DeviceMesh.from_sizes(dp_size=DP_SIZE)
+twinkle.initialize(mode='local', global_device_mesh=device_mesh)
+
+
+def build_dataset(num_samples: int = None) -> Dataset:
+ meta_kwargs = {}
+ if num_samples is not None:
+ meta_kwargs['data_slice'] = range(num_samples)
+ dataset = Dataset(dataset_meta=DatasetMeta(DATASET_PATH, **meta_kwargs))
+ # ``truncation_strategy='delete'`` drops overlong rows instead of slicing —
+ # a sliced multi-turn trajectory would lose `\boxed{}` and break SFT signal.
+ dataset.set_template(
+ TEMPLATE_NAME,
+ model_id=MODEL_ID,
+ max_length=MAX_LENGTH,
+ truncation_strategy='delete',
+ enable_thinking=False)
+ dataset.encode(load_from_cache_file=True, num_proc=16)
+ return dataset
+
+
+def save_checkpoint(model: TransformersModel, checkpoint_name: str, dataloader: DataLoader):
+ model.save(
+ checkpoint_name,
+ output_dir=OUTPUT_DIR,
+ adapter_name=ADAPTER_NAME,
+ save_optimizer=True,
+ consumed_train_samples=dataloader.get_state()['consumed_train_samples'],
+ )
+
+
+def train():
+ dataset = build_dataset()
+ dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE)
+
+ model = TransformersModel(model_id=MODEL_ID, ddp_config={'find_unused_parameters': True})
+ model.model._no_split_modules = {'Qwen3_5DecoderLayer'}
+
+ lora_config = LoraConfig(r=16, lora_alpha=32, target_modules='all-linear')
+ model.add_adapter_to_model(ADAPTER_NAME, lora_config, gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS)
+ model.set_optimizer(optimizer_cls='AdamW', lr=LEARNING_RATE)
+ model.set_lr_scheduler(
+ scheduler_cls='CosineWarmupScheduler',
+ num_warmup_steps=50,
+ num_training_steps=len(dataloader) * NUM_EPOCHS // GRADIENT_ACCUMULATION_STEPS)
+
+ if RESUME_FROM_CHECKPOINT:
+ checkpoint_path = Path(RESUME_FROM_CHECKPOINT).expanduser().resolve()
+ kwargs = {'adapter_name': ADAPTER_NAME} if ADAPTER_NAME else {}
+ progress = model.resume_from_checkpoint(
+ str(checkpoint_path), resume_only_model=RESUME_ONLY_MODEL, **kwargs)
+ if not IGNORE_DATA_SKIP:
+ dataloader.resume_from_checkpoint(progress['consumed_train_samples'])
+
+ logger.info(get_device_placement())
+ logger.info(model.get_train_configs())
+ logger.info(f'Total steps: {len(dataloader) * NUM_EPOCHS}')
+
+ optimizer_group = model.optimizer_group[ADAPTER_NAME]
+
+ for epoch in range(NUM_EPOCHS):
+ for batch in dataloader:
+ model.forward_backward(inputs=batch)
+ model.clip_grad_and_step()
+ cur_step = optimizer_group.cur_step
+ if cur_step % LOG_INTERVAL == 0:
+ metric = model.calculate_metric(is_training=True)
+ logger.info(f'Epoch {epoch} Step {cur_step}/{len(dataloader) * NUM_EPOCHS}, metric: {metric}')
+ save_checkpoint(model, f'epoch-{epoch}', dataloader)
+ save_checkpoint(model, 'last-checkpoint', dataloader)
+
+
+if __name__ == '__main__':
+ train()
diff --git a/cookbook/sample/emb_sample.py b/cookbook/sample/emb_sample.py
new file mode 100644
index 000000000..6d7e4c599
--- /dev/null
+++ b/cookbook/sample/emb_sample.py
@@ -0,0 +1,543 @@
+"""Embedding quality validation: compress (query, cot) pairs via vLLM condenser,
+extract embeddings via TransformersModel.forward_only(task='embedding'),
+report cosine similarity.
+
+Covers three domains: basic math, code logic, open-ended reasoning.
+
+Architecture (2 GPUs):
+ - GPU 0: vLLM condenser (compression)
+ - GPU 1: TransformersModel (embedding, same path as training)
+
+Launch:
+ python cookbook/sample/emb_sample.py
+ EMB_MODEL=./output/embedding_lora_transformers/step_16000 python cookbook/sample/emb_sample.py
+"""
+import os
+import re
+from typing import Any, Dict, List, Optional
+
+import torch
+import torch.nn.functional as F
+
+import twinkle
+from twinkle import DeviceGroup, DeviceMesh, get_device_placement, get_logger
+from twinkle.data_format import SamplingParams
+from twinkle.loss import InfonceLoss
+from twinkle.model import TransformersModel
+from twinkle.processor import InputProcessor
+from twinkle.sampler import vLLMSampler
+from twinkle.template import Template
+
+logger = get_logger()
+
+# -- Config -------------------------------------------------------------------
+CONDENSE_MODEL_ID = os.environ.get('CONDENSE_MODEL_ID', 'ms://twinkle-kit/Qwen3.5-4B-CM-v2')
+EMB_MODEL_ID = os.environ.get('EMB_MODEL', 'ms://twinkle-kit/Qwen3.5-4B-QA-emb')
+SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS', 1))
+EMB_GPUS = int(os.environ.get('EMB_GPUS', 1))
+EMB_MAX_LENGTH = 8192
+
+# -- Prompts (aligned with train_embedding_full_ddp.py) -----------------------
+COMPRESS_SYSTEM = """\
+You are a compression and summary assistant. For the (query, source) pair, emit a Markdown \
+answer with TWO sections, designed to pair with the `extract_compressed` tool: \
+the reader absorbs `## Summary` directly, then calls `extract_compressed` \
+on any topic-key listed under `## More` to recover its \
+fuller content.
+
+ `## Summary` — extreme-density text the reader reads directly.
+ `## More` — a topic index whose keys are valid arguments \
+to `extract_compressed` for recovering material not captured inline.
+
+Together the two sections must form a COMPLETE, NON-DISTORTING inventory of the \
+source for the query — nothing essential lost, nothing implied that the source \
+does not support. NO preamble, NO meta-commentary, NO code fences wrapping the \
+whole output.
+
+Output skeleton:
+
+## Summary
+Topic:
+
+
+## More
+- :
+- ...
+
+Now begin.\
+"""
+
+COMPRESS_USER = (
+ 'Compress faithfully: preserve the passage topic + core facts. '
+ 'Do NOT invent facts. Do NOT drop major facts.\n\n'
+ '## Query (ordering hint only)\n{query}\n\n'
+ '## Passage\n{text}')
+
+EMBED_QUERY_Q = (
+ 'What problem does this passage address, and what skill or method is needed? '
+ 'Compress into a retrieval-friendly need description.')
+EMBED_QUERY_COT = (
+ 'Extract the reusable skill: trigger conditions, key steps, and expected output. '
+ 'Compress into a standardized procedure for retrieval.')
+
+# =============================================================================
+# Test pairs: 4 special categories to probe embedding quality
+# cat1 same-query-different-approach: same query paired with two CoTs that
+# solve it via different methods. Both q-c sims should be high; cross-CoT
+# sim reveals whether method-difference is reflected in embedding.
+# cat2 odd-domain: niche topics likely under-represented in pretraining,
+# tests generalization beyond mainstream STEM/web text.
+# cat3 reusable-basic: foundational facts/lemmas reusable across many queries,
+# check whether they get globally similar to lots of unrelated queries.
+# cat4 mutually-interfering: pairs of queries with overlapping vocabulary or
+# confusable concepts; intra-group cross-sim should remain discriminative.
+# Each entry may carry an optional 'group' key for intra-group cross analysis.
+# =============================================================================
+TEST_PAIRS: List[Dict[str, str]] = [
+ # --- cat1: same query, different approach -------------------------------
+ {
+ 'domain': 'cat1-sum-gauss',
+ 'group': 'g1-sum',
+ 'query': '计算 1+2+3+...+100 的和,要求使用高斯求和公式',
+ 'cot': (
+ '使用高斯求和公式 S = n(n+1)/2,n=100 → S = 100×101/2 = 5050。'
+ '此公式由首末项配对推导:(1+100)+(2+99)+...+(50+51) = 50×101 = 5050。'
+ '通用形式 S = n(a₁+aₙ)/2 适用于任意等差数列。'),
+ },
+ {
+ 'domain': 'cat1-sum-loop',
+ 'group': 'g1-sum',
+ 'query': '计算 1+2+3+...+100 的和,要求用 Python 循环累加',
+ 'cot': (
+ 'Python 循环累加:\n'
+ 'total = 0\n'
+ 'for i in range(1, 101):\n'
+ ' total += i\n'
+ 'print(total) # 5050\n'
+ '时间复杂度 O(n),空间 O(1)。也可用 sum(range(1, 101)) 一行写完。'),
+ },
+ {
+ 'domain': 'cat1-palindrome-twoptr',
+ 'group': 'g1-palindrome',
+ 'query': '如何使用双指针原地判断一个字符串是否是回文?要求 O(1) 额外空间。',
+ 'cot': (
+ '双指针方法:l=0, r=len(s)-1。循环 while l0,n₀>0,∀n≥n₀, |f(n)|≤c·|g(n)|。描述增长上界。'
+ '常见复杂度 O(1) List[str]:
+ """Compress a list of texts using the vLLM condenser."""
+ prompts = []
+ for text in texts:
+ user_msg = COMPRESS_USER.format(query=query_hint, text=text)
+ prompts.append({'messages': [
+ {'role': 'system', 'content': COMPRESS_SYSTEM},
+ {'role': 'user', 'content': user_msg},
+ ]})
+
+ params = SamplingParams(max_tokens=8192, temperature=0.2, top_p=0.5, num_samples=1)
+ responses = sampler.sample(prompts, params)
+
+ results = []
+ for resp in responses:
+ seq = resp.sequences[0] if resp and resp.sequences else None
+ text = ''
+ if seq and seq.decoded:
+ text = seq.decoded
+ text = re.sub(r'<\|[^|]+\|>', '', text).rstrip()
+ results.append(text)
+ return results
+
+
+# =============================================================================
+# Embedding extraction (TransformersModel, same path as training)
+# =============================================================================
+
+def _build_features(texts: List[str], template: Template, role: str) -> List[Dict[str, Any]]:
+ """Encode texts into embedding features, matching _get_first_feature in training."""
+ features = []
+ for text in texts:
+ if not text.strip():
+ continue
+ if role == 'anchor':
+ feat = template.encode({'messages': [
+ {'role': 'user', 'content': text},
+ {'role': 'assistant', 'content': 'Match the correct response here.'},
+ ]})
+ feat['labels'] = [1]
+ else:
+ feat = template.encode({'messages': [
+ {'role': 'user', 'content': 'Match the correct query here.'},
+ {'role': 'assistant', 'content': text},
+ ]})
+ feat['labels'] = [0]
+ features.append(feat)
+ return features
+
+
+def get_embeddings(model: TransformersModel, template: Template,
+ texts: List[str], role: str = 'anchor') -> torch.Tensor:
+ """Get embeddings via forward_only(task='embedding'), same code path as training."""
+ features = _build_features(texts, template, role)
+ if not features:
+ return torch.zeros(0)
+ outputs = model.forward_only(inputs=features, task='embedding', return_logits=True)
+ return outputs['embeddings']
+
+
+# =============================================================================
+# Group analysis helper
+# =============================================================================
+
+def print_group_matrix(pairs, q_embs, c_embs, title: str):
+ """Print intra-group cross-similarity (rows=query, cols=cot) for every
+ 'group' tag found in pairs. Highlights how well an embedding distinguishes
+ near-neighbour queries (cat1 same-query-different-approach, cat4
+ mutually-interfering)."""
+ from collections import defaultdict
+ groups = defaultdict(list)
+ for i, p in enumerate(pairs):
+ g = p.get('group')
+ if g:
+ groups[g].append(i)
+ if not groups:
+ return
+ logger.info(f'\n{"=" * 80}\n{title}\n{"=" * 80}')
+ for gname, idxs in groups.items():
+ logger.info(f'\n[Group: {gname}] rows=query, cols=cot, * = matched pair')
+ header = ' ' * 28
+ for j in idxs:
+ header += f' {pairs[j]["domain"][-14:]:>15}'
+ logger.info(header)
+ for i in idxs:
+ row = f'{pairs[i]["domain"][-28:]:<28}'
+ for j in idxs:
+ s = F.cosine_similarity(q_embs[i:i + 1], c_embs[j:j + 1]).item()
+ mark = '*' if i == j else ' '
+ row += f' {s:>13.4f}{mark}'
+ logger.info(row)
+
+
+# =============================================================================
+# Main
+# =============================================================================
+
+def main():
+ NUM_GPUS = SAMPLER_GPUS + EMB_GPUS
+
+ # 1. Initialize Twinkle with both device groups
+ device_groups = [
+ DeviceGroup(name='sampler',
+ ranks=list(range(SAMPLER_GPUS)),
+ device_type='GPU',
+ gpus_per_worker=SAMPLER_GPUS),
+ DeviceGroup(name='emb_model',
+ ranks=list(range(SAMPLER_GPUS, NUM_GPUS)),
+ device_type='GPU'),
+ ]
+ sampler_mesh = DeviceMesh.from_sizes(world_size=SAMPLER_GPUS, tp_size=SAMPLER_GPUS)
+ emb_mesh = DeviceMesh.from_sizes(world_size=EMB_GPUS, dp_size=EMB_GPUS)
+ twinkle.initialize(mode='ray', nproc_per_node=NUM_GPUS, groups=device_groups, lazy_collect=False)
+
+ # 2. vLLM condenser sampler
+ sampler = vLLMSampler(
+ model_id=CONDENSE_MODEL_ID,
+ engine_args={
+ 'gpu_memory_utilization': 0.8,
+ 'max_model_len': 32768,
+ },
+ device_mesh=sampler_mesh,
+ remote_group='sampler',
+ )
+ sampler.set_template(
+ 'Qwen3_5Template', model_id=CONDENSE_MODEL_ID,
+ enable_thinking=False, max_length=32768)
+
+ # 3. Embedding model (same as training: TransformersModel + InputProcessor + InfonceLoss)
+ emb_model = TransformersModel(
+ model_id=EMB_MODEL_ID,
+ device_mesh=emb_mesh,
+ remote_group='emb_model',
+ )
+ emb_model.set_processor(InputProcessor)
+ emb_model.set_loss(InfonceLoss, temperature=0.03, use_batch=True)
+ emb_template = Template(model_id=EMB_MODEL_ID, max_length=EMB_MAX_LENGTH, enable_thinking=False)
+
+ logger.info(get_device_placement())
+
+ # 4. Compress all pairs
+ all_pairs = TEST_PAIRS + NEGATIVE_PAIRS
+ queries = [p['query'] for p in all_pairs]
+ cots = [p['cot'] for p in all_pairs]
+
+ logger.info(f'Compressing {len(queries)} queries ...')
+ compressed_queries = compress_texts(sampler, queries, EMBED_QUERY_Q)
+ logger.info(f'Compressing {len(cots)} CoTs ...')
+ compressed_cots = compress_texts(sampler, cots, EMBED_QUERY_COT)
+
+ # Print compression results
+ for i, pair in enumerate(all_pairs):
+ qc, cc = compressed_queries[i], compressed_cots[i]
+ logger.info(
+ f'\n{"=" * 70}\n'
+ f'[{pair["domain"]}] Query ({len(pair["query"])}→{len(qc)} chars):\n'
+ f' Raw: {pair["query"][:80]}...\n'
+ f' Compressed: {qc[:120]}...\n'
+ f'CoT ({len(pair["cot"])}→{len(cc)} chars):\n'
+ f' Compressed: {cc[:120]}...')
+
+ # 5. Get embeddings via TransformersModel.forward_only(task='embedding')
+ logger.info('Computing query embeddings ...')
+ q_embs = get_embeddings(emb_model, emb_template, compressed_queries, role='anchor')
+ logger.info('Computing CoT embeddings ...')
+ c_embs = get_embeddings(emb_model, emb_template, compressed_cots, role='positive')
+
+ logger.info('Computing raw query embeddings (no compression) ...')
+ raw_q_embs = get_embeddings(emb_model, emb_template, queries, role='anchor')
+ logger.info('Computing raw CoT embeddings (no compression) ...')
+ raw_c_embs = get_embeddings(emb_model, emb_template, cots, role='positive')
+
+ # 6. Compute similarities
+ n_positive = len(TEST_PAIRS)
+ n_negative = len(NEGATIVE_PAIRS)
+
+ logger.info(f'\n{"=" * 70}')
+ logger.info('RESULTS: Cosine Similarity (compressed query ↔ compressed CoT)')
+ logger.info(f'{"=" * 70}')
+ logger.info(f'{"Domain":<30} {"Compressed":>12} {"Raw":>12} {"Δ":>8}')
+ logger.info('-' * 70)
+
+ pos_sims_compressed, pos_sims_raw = [], []
+ neg_sims_compressed, neg_sims_raw = [], []
+
+ for i, pair in enumerate(all_pairs):
+ sim_c = F.cosine_similarity(q_embs[i:i+1], c_embs[i:i+1]).item()
+ sim_r = F.cosine_similarity(raw_q_embs[i:i+1], raw_c_embs[i:i+1]).item()
+ delta = sim_c - sim_r
+ marker = '✓' if i < n_positive else '✗'
+ logger.info(f' {marker} {pair["domain"]:<28} {sim_c:>10.4f} {sim_r:>10.4f} {delta:>+.4f}')
+
+ if i < n_positive:
+ pos_sims_compressed.append(sim_c)
+ pos_sims_raw.append(sim_r)
+ else:
+ neg_sims_compressed.append(sim_c)
+ neg_sims_raw.append(sim_r)
+
+ # Summary statistics
+ avg_pos_c = sum(pos_sims_compressed) / len(pos_sims_compressed)
+ avg_pos_r = sum(pos_sims_raw) / len(pos_sims_raw)
+ avg_neg_c = sum(neg_sims_compressed) / len(neg_sims_compressed) if neg_sims_compressed else 0
+ avg_neg_r = sum(neg_sims_raw) / len(neg_sims_raw) if neg_sims_raw else 0
+
+ logger.info(f'\n{"=" * 70}')
+ logger.info('SUMMARY')
+ logger.info(f' Positive pairs (matched): compressed={avg_pos_c:.4f} raw={avg_pos_r:.4f}')
+ logger.info(f' Negative pairs (mismatched): compressed={avg_neg_c:.4f} raw={avg_neg_r:.4f}')
+ logger.info(f' Margin (pos - neg): compressed={avg_pos_c - avg_neg_c:.4f} '
+ f'raw={avg_pos_r - avg_neg_r:.4f}')
+ logger.info(f'{"=" * 70}')
+
+ # 6. Group analysis: intra-group cross-sim for cat1/cat4 pairs.
+ print_group_matrix(all_pairs, q_embs, c_embs,
+ 'GROUP ANALYSIS (compressed, intra-group cross-sim)')
+ print_group_matrix(all_pairs, raw_q_embs, raw_c_embs,
+ 'GROUP ANALYSIS (raw, intra-group cross-sim)')
+
+ # 7. Global cross-similarity matrix across all pairs (compressed).
+ # Useful for cat3 (reusable basics) to spot whether a 'general' CoT lights
+ # up against unrelated queries.
+ logger.info(f'\n{"=" * 80}')
+ logger.info('GLOBAL cross-similarity matrix (compressed); * = matched diagonal')
+ logger.info(f'{"=" * 80}')
+ n_all = len(all_pairs)
+ cross_sim = F.cosine_similarity(
+ q_embs[:n_all].unsqueeze(1),
+ c_embs[:n_all].unsqueeze(0), dim=2)
+ header = ' ' * 30
+ for j in range(n_all):
+ header += f' {all_pairs[j]["domain"][-7:]:>8}'
+ logger.info(header)
+ for i in range(n_all):
+ row = f'{all_pairs[i]["domain"][-30:]:<30}'
+ for j in range(n_all):
+ val = cross_sim[i, j].item()
+ mark = '*' if i == j else ' '
+ row += f' {val:>7.4f}{mark}'
+ logger.info(row)
+
+ logger.info('Done.')
+
+
+if __name__ == '__main__':
+ main()
diff --git a/cookbook/transformers/fsdp2.py b/cookbook/transformers/fsdp2.py
index d7d001d36..ad4c917f9 100644
--- a/cookbook/transformers/fsdp2.py
+++ b/cookbook/transformers/fsdp2.py
@@ -5,6 +5,7 @@
import twinkle
from twinkle import DeviceMesh, get_device_placement, get_logger
+from twinkle.cli import CLI
from twinkle.dataloader import DataLoader
from twinkle.dataset import Dataset, DatasetMeta
from twinkle.model import TransformersModel
@@ -13,38 +14,19 @@
from twinkle.kernel import kernelize_model
logger = get_logger()
+args = CLI.from_args()
-MODEL_ID = 'ms://Qwen/Qwen3.5-4B'
-DATASET_ID = 'ms://swift/self-cognition'
-TEMPLATE_NAME = 'Qwen3_5Template'
-MODEL_NAME = 'twinkle大模型'
-MODEL_AUTHOR = 'ModelScope社区'
-FSDP_SIZE = 2
-DP_SIZE = 4
-BATCH_SIZE = 8
-LEARNING_RATE = 1e-4
-GRADIENT_ACCUMULATION_STEPS = 2
-LOG_INTERVAL = 20
-EVAL_INTERVAL = 40
-EVAL_SAMPLES = 100
-TRAIN_SAMPLES = 1000
-
-OUTPUT_DIR = './output/fsdp2'
-RESUME_FROM_CHECKPOINT = None
-RESUME_ONLY_MODEL = False
-IGNORE_DATA_SKIP = False
-ADAPTER_NAME = 'default'
-
-# Construct a device_mesh
-device_mesh = DeviceMesh.from_sizes(fsdp_size=FSDP_SIZE, dp_size=DP_SIZE)
-# use torchrun mode
-twinkle.initialize(mode='local', global_device_mesh=device_mesh)
+device_mesh = DeviceMesh.from_sizes(fsdp_size=args.infra.fsdp_size, dp_size=args.infra.dp_size)
+twinkle.initialize(mode=args.infra.mode, global_device_mesh=device_mesh)
def build_dataset(num_samples: int) -> Dataset:
- dataset = Dataset(dataset_meta=DatasetMeta(DATASET_ID, data_slice=range(num_samples)))
- dataset.set_template(TEMPLATE_NAME, model_id=MODEL_ID)
- dataset.map(SelfCognitionProcessor(MODEL_NAME, MODEL_AUTHOR))
+ dataset = Dataset(dataset_meta=DatasetMeta(args.dataset.dataset_id, data_slice=range(num_samples)))
+ dataset.set_template(args.template.template_cls, model_id=args.model.model_id)
+ dataset.map(SelfCognitionProcessor(
+ args.extra.get('model_name', 'twinkle大模型'),
+ args.extra.get('model_author', 'ModelScope社区'),
+ ))
dataset.encode()
return dataset
@@ -52,15 +34,16 @@ def build_dataset(num_samples: int) -> Dataset:
def save_checkpoint(model: TransformersModel, checkpoint_name: str, dataloader: DataLoader):
model.save(
checkpoint_name,
- output_dir=OUTPUT_DIR,
- adapter_name=ADAPTER_NAME,
+ output_dir=args.training.output_dir,
+ adapter_name=args.lora.adapter_name,
save_optimizer=True,
consumed_train_samples=dataloader.get_state()['consumed_train_samples'],
)
def evaluate(model):
- dataloader = DataLoader(dataset=build_dataset(EVAL_SAMPLES), batch_size=BATCH_SIZE)
+ eval_samples = args.training.eval_samples or 100
+ dataloader = DataLoader(dataset=build_dataset(eval_samples), batch_size=args.training.batch_size)
for batch in tqdm(dataloader):
model.forward_only(inputs=batch)
model.calculate_loss()
@@ -68,54 +51,50 @@ def evaluate(model):
def train():
- dataset = build_dataset(TRAIN_SAMPLES)
- # Global batch size = 8, for GPUs, so 1 sample per GPU
- dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE)
- # Use a TransformersModel
- model = TransformersModel(model_id=MODEL_ID)
+ train_samples = int(args.extra.get('train_samples', 1000))
+ dataset = build_dataset(train_samples)
+ dataloader = DataLoader(dataset=dataset, batch_size=args.training.batch_size)
+ model = TransformersModel(model_id=args.model.model_id)
model.model._no_split_modules = {'Qwen3_5DecoderLayer'}
# npu patch
if Torch.is_npu_available():
model = kernelize_model(model, mode='train', device='npu')
- lora_config = LoraConfig(r=8, lora_alpha=32, target_modules='all-linear')
- # Add a lora to model, with name `default`
- model.add_adapter_to_model(ADAPTER_NAME, lora_config, gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS)
- # Add Optimizer for lora `default`
- model.set_optimizer(optimizer_cls='AdamW', lr=LEARNING_RATE)
+ lora_config = LoraConfig(**args.get_lora_args())
+ model.add_adapter_to_model(
+ args.lora.adapter_name, lora_config,
+ gradient_accumulation_steps=args.training.gradient_accumulation_steps)
+ model.set_optimizer(optimizer_cls=args.optimizer.optimizer_cls, lr=args.optimizer.learning_rate)
+
# Add LRScheduler for lora `default`
model.set_lr_scheduler(
- scheduler_cls='CosineWarmupScheduler', num_warmup_steps=5, num_training_steps=len(dataloader))
+ scheduler_cls=args.scheduler.scheduler_cls,
+ num_warmup_steps=args.scheduler.num_warmup_steps,
+ num_training_steps=len(dataloader))
- if RESUME_FROM_CHECKPOINT:
- checkpoint_path = Path(RESUME_FROM_CHECKPOINT).expanduser().resolve()
- kwargs = {}
- if ADAPTER_NAME:
- kwargs['adapter_name'] = ADAPTER_NAME
+ if args.training.resume_from_checkpoint:
+ checkpoint_path = Path(args.training.resume_from_checkpoint).expanduser().resolve()
progress = model.resume_from_checkpoint(
- str(checkpoint_path), resume_only_model=RESUME_ONLY_MODEL, **kwargs)
- if not IGNORE_DATA_SKIP:
+ str(checkpoint_path),
+ resume_only_model=args.training.resume_only_model,
+ adapter_name=args.lora.adapter_name)
+ if not args.training.ignore_data_skip:
dataloader.resume_from_checkpoint(progress['consumed_train_samples'])
logger.info(get_device_placement())
- # Print the training config
logger.info(model.get_train_configs())
logger.info(f'Total steps: {len(dataloader)}')
- optimizer_group = model.optimizer_group[ADAPTER_NAME]
+ optimizer_group = model.optimizer_group[args.lora.adapter_name]
best_loss = float('inf')
- # lora: 8G * 8
- # full: 18G * 8
+ eval_interval = args.training.eval_interval or 40
for batch in dataloader:
- # Do forward and backward
model.forward_backward(inputs=batch)
- # Step
model.clip_grad_and_step()
cur_step = optimizer_group.cur_step
- if cur_step % LOG_INTERVAL == 0:
- # Print metric
+ if cur_step % args.training.log_interval == 0:
metric = model.calculate_metric(is_training=True)
logger.info(f'Current is step {cur_step} of {len(dataloader)}, metric: {metric}')
- if cur_step > 0 and cur_step % EVAL_INTERVAL == 0:
+ if cur_step > 0 and cur_step % eval_interval == 0:
metrics = evaluate(model)
logger.info(f'Eval metric: {metrics}')
metrics['step'] = cur_step
diff --git a/cookbook/transformers/fsdp2.sh b/cookbook/transformers/fsdp2.sh
index 93c531a98..bbe269629 100644
--- a/cookbook/transformers/fsdp2.sh
+++ b/cookbook/transformers/fsdp2.sh
@@ -1 +1,25 @@
-CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 fsdp2.py
+#!/usr/bin/env bash
+# All training config passed as CLI flags. Override at invocation, e.g.:
+# bash fsdp2.sh --batch-size 16 --lr 5e-5
+
+CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES:-0,1,2,3,4,5,6,7} \
+ torchrun --nproc_per_node=8 fsdp2.py \
+ --model-id ms://Qwen/Qwen3.5-4B \
+ --dataset-id ms://swift/self-cognition \
+ --template-cls Qwen3_5Template \
+ --fsdp-size 2 \
+ --dp-size 4 \
+ --batch-size 8 \
+ --lr 1e-4 \
+ --gradient-accumulation-steps 2 \
+ --log-interval 20 \
+ --eval-interval 40 \
+ --eval-samples 100 \
+ --output-dir ./output/fsdp2 \
+ --adapter-name default \
+ --scheduler-cls CosineWarmupScheduler \
+ --num-warmup-steps 5 \
+ --train-samples 1000 \
+ --model-name twinkle大模型 \
+ --model-author ModelScope社区 \
+ "$@"
diff --git a/docs/source_en/Usage Guide/Installation.md b/docs/source_en/Usage Guide/Installation.md
index c8f87ba23..5dd5342e5 100644
--- a/docs/source_en/Usage Guide/Installation.md
+++ b/docs/source_en/Usage Guide/Installation.md
@@ -21,7 +21,7 @@ pip install -e .
You can also use our pre-built Docker image:
```text
-modelscope-registry.cn-hangzhou.cr.aliyuncs.com/modelscope-repo/modelscope:twinkle-0.2.1
+modelscope-registry.cn-hangzhou.cr.aliyuncs.com/modelscope-repo/modelscope:twinkle-0.3.0
```
## Client Installation
diff --git "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\345\256\211\350\243\205.md" "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\345\256\211\350\243\205.md"
index 71c6bee64..dd9d9a18a 100644
--- "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\345\256\211\350\243\205.md"
+++ "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\345\256\211\350\243\205.md"
@@ -21,7 +21,7 @@ pip install -e .
你也可以使用我们的预构建 Docker 镜像:
```text
-modelscope-registry.cn-hangzhou.cr.aliyuncs.com/modelscope-repo/modelscope:twinkle-0.2.1
+modelscope-registry.cn-hangzhou.cr.aliyuncs.com/modelscope-repo/modelscope:twinkle-0.3.0
```
## 客户端安装
diff --git a/pyproject.toml b/pyproject.toml
index 964a7548c..26a7db554 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -26,6 +26,8 @@ kernels = ["kernels"]
megatron = ["megatron-core>=0.12.0", "transformer-engine[pytorch]", "mcore_bridge"]
vllm = ["vllm>=0.11"]
ray = ["ray[serve]"]
+pyodps = ["pyodps"]
+datajuicer = ["py-data-juicer"]
tinker = ["tinker==0.14.0"]
docs = [
"sphinx>=5.3.0,<6.0.0",
diff --git a/src/twinkle/__init__.py b/src/twinkle/__init__.py
index f64917a5e..6e2c04512 100644
--- a/src/twinkle/__init__.py
+++ b/src/twinkle/__init__.py
@@ -16,7 +16,8 @@
'framework_util', 'torch_util', 'exists', 'requires', 'Platform', 'GPU', 'NPU', 'find_node_ip',
'find_free_port', 'trust_remote_code', 'check_unsafe', 'DeviceMesh', 'Plugin', 'DeviceGroup', 'get_logger'
],
- 'infra': ['initialize', 'remote_class', 'remote_function', 'get_device_placement', 'is_master'],
+ 'infra':
+ ['initialize', 'remote_class', 'remote_function', 'get_device_placement', 'is_master'],
}
import sys
diff --git a/src/twinkle/checkpoint_engine/manager.py b/src/twinkle/checkpoint_engine/manager.py
index cde5c519d..2a22f7bae 100644
--- a/src/twinkle/checkpoint_engine/manager.py
+++ b/src/twinkle/checkpoint_engine/manager.py
@@ -116,12 +116,12 @@ def sync_weights(self, merge_and_sync=True):
peft_config = None
if self.base_sync_done and not merge_and_sync:
if self._peft_config is None:
- self._peft_config = self.model.get_peft_config_dict()
+ self._peft_config = self.model.get_peft_config_dict(lazy_collect=False)
peft_config = self._peft_config
if self._model_keys is None:
if hasattr(self.sampler, 'get_state_keys'):
- self._model_keys = self.sampler.get_state_keys()
+ self._model_keys = self.sampler.get_state_keys(lazy_collect=False)
if self._model_keys is None:
self._model_keys = []
diff --git a/src/twinkle/cli/__init__.py b/src/twinkle/cli/__init__.py
new file mode 100644
index 000000000..4dcc1d2a5
--- /dev/null
+++ b/src/twinkle/cli/__init__.py
@@ -0,0 +1,30 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+from .cli import (CLI, Args, CheckpointArgs, CLISource, ConfigResolver, ConfigSource, DatasetArgs, DotEnvSource,
+ EnvVarSource, InfraArgs, LoraArgs, LossArgs, ModelArgs, OptimizerArgs, RLArgs, SamplerArgs,
+ SamplingArgs, SchedulerArgs, ServerArgs, TemplateArgs, TrainingArgs, ValueCaster, YamlSource)
+
+__all__ = [
+ 'CLI',
+ 'Args',
+ 'ConfigSource',
+ 'ConfigResolver',
+ 'ValueCaster',
+ 'DotEnvSource',
+ 'EnvVarSource',
+ 'YamlSource',
+ 'CLISource',
+ 'ModelArgs',
+ 'LoraArgs',
+ 'DatasetArgs',
+ 'TemplateArgs',
+ 'TrainingArgs',
+ 'OptimizerArgs',
+ 'SchedulerArgs',
+ 'LossArgs',
+ 'SamplerArgs',
+ 'SamplingArgs',
+ 'InfraArgs',
+ 'ServerArgs',
+ 'RLArgs',
+ 'CheckpointArgs',
+]
diff --git a/src/twinkle/cli/cli.py b/src/twinkle/cli/cli.py
new file mode 100644
index 000000000..0cbfc97d1
--- /dev/null
+++ b/src/twinkle/cli/cli.py
@@ -0,0 +1,607 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+from __future__ import annotations
+
+import os
+import sys
+from abc import ABC, abstractmethod
+from dataclasses import dataclass, field, fields
+from pathlib import Path
+from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple, Type, Union
+
+# ────────────────────────────────────────────────────────────────────────────────
+# Arg group dataclasses
+# ────────────────────────────────────────────────────────────────────────────────
+
+
+@dataclass
+class ModelArgs:
+ model_id: str | None = field(default=None, metadata={'primary': True})
+ model_cls: str | None = None
+ tokenizer_id: str | None = None
+ mixed_precision: Literal['no', 'fp8', 'fp16', 'bf16'] = 'bf16'
+ strategy: Literal['accelerate', 'native_fsdp'] = field(
+ default='accelerate', metadata={'aliases': ('use_megatron', )})
+ memory_efficient_init: bool = False
+ gradient_checkpointing: bool = True
+ trust_remote_code: bool = True
+ ddp_config: dict[str, Any] | None = None
+ fsdp_config: dict[str, Any] | None = None
+ grad_scaler_config: dict[str, Any] | None = None
+
+
+@dataclass
+class LoraArgs:
+ use_lora: bool = False
+ lora_r: int = 8
+ lora_alpha: int = 32
+ lora_dropout: float = 0.05
+ lora_target_modules: list[str] | None = None
+ adapter_name: str = 'default'
+
+
+@dataclass
+class DatasetArgs:
+ dataset_id: str = ''
+ subset_name: str = 'default'
+ split: str = 'train'
+ streaming: bool = False
+ num_proc: int | None = None
+ data_slice: str | None = None
+ revision: str | None = None
+
+
+@dataclass
+class TemplateArgs:
+ template_cls: str | None = None
+ model_id: str | None = None
+ max_length: int = 8192
+ truncation_strategy: Literal['raise', 'left', 'right', 'split', 'delete'] = 'raise'
+ use_chat_template: bool = True
+ enable_thinking: bool = True
+ default_system: str | None = None
+
+
+@dataclass
+class TrainingArgs:
+ max_steps: int = 200
+ num_train_epochs: int | None = None
+ batch_size: int = 8
+ mini_batch_size: int | None = None
+ micro_batch_size: int = 2
+ gradient_accumulation_steps: int = 1
+ output_dir: str = './output'
+ save_steps: int = 50
+ save_total_limit: int | None = None
+ log_interval: int = 10
+ eval_interval: int | None = None
+ eval_samples: int | None = None
+ resume_from_checkpoint: str | None = None
+ resume_only_model: bool = False
+ ignore_data_skip: bool = False
+ seed: int = field(default=42, metadata={'primary': True})
+ full_determinism: bool = False
+ padding_free: bool = False
+
+
+@dataclass
+class OptimizerArgs:
+ optimizer_cls: str = 'AdamW'
+ learning_rate: float = field(default=1e-5, metadata={'aliases': ('lr', )})
+ weight_decay: float = 0.0
+ adam_beta1: float = 0.9
+ adam_beta2: float = 0.999
+ adam_epsilon: float = 1e-8
+ max_grad_norm: float = 1.0
+
+
+@dataclass
+class SchedulerArgs:
+ scheduler_cls: str = 'CosineAnnealingLR'
+ num_warmup_steps: int = 0
+ num_training_steps: int | None = None
+ t_max: int | None = None
+ eta_min: float = 0.0
+ lr_decay_steps: int | None = None
+ max_lr: float | None = None
+
+
+@dataclass
+class LossArgs:
+ loss_cls: str = 'CrossEntropyLoss'
+ epsilon: float = 0.2
+ epsilon_high: float | None = None
+ beta: float = 0.0
+ entropy_coef: float = 0.0
+ ignore_index: int = -100
+
+
+@dataclass
+class SamplerArgs:
+ sampler_type: str = 'vLLMSampler'
+ gpu_memory_utilization: float = 0.8
+ max_model_len: int | None = None
+ tensor_parallel_size: int | None = None
+ enable_lora: bool = False
+ max_lora_rank: int = 32
+ enforce_eager: bool = False
+
+
+@dataclass
+class SamplingArgs:
+ max_tokens: int | None = field(default=None, metadata={'aliases': ('max_new_tokens', )})
+ temperature: float = 1.0
+ top_k: int = -1
+ top_p: float = 1.0
+ repetition_penalty: float = 1.0
+ num_samples: int = 1
+ logprobs: int | None = None
+ seed: int | None = None
+ stop: str | None = None
+
+
+@dataclass
+class InfraArgs:
+ mode: Literal['local', 'ray'] = 'local'
+ nproc_per_node: int = field(default=8, metadata={'aliases': ('num_gpus', )})
+ ncpu_proc_per_node: int = 8
+ model_gpus: int | None = None
+ sampler_gpus: int | None = None
+ dp_size: int | None = None
+ fsdp_size: int | None = None
+ tp_size: int | None = None
+ cp_size: int | None = None
+ ep_size: int | None = None
+ ulysses_size: int | None = None
+ lazy_collect: bool = True
+
+
+@dataclass
+class ServerArgs:
+ config: str | None = None
+ ray_namespace: str = 'twinkle_cluster'
+ host: str = '0.0.0.0'
+ port: int = 8000
+ log_level: str = 'INFO'
+
+
+@dataclass
+class RLArgs:
+ num_generations: int = 8
+ advantage_type: str = 'GRPOAdvantage'
+ advantage_scale: Literal['group', 'batch', 'none'] = 'group'
+ reward_fns: list[str] | None = None
+
+
+@dataclass
+class CheckpointArgs:
+ save_optimizer: bool = True
+ merge_and_sync: bool = True
+ platform: str = 'GPU'
+
+
+# ────────────────────────────────────────────────────────────────────────────────
+# ConfigSource hierarchy
+# ────────────────────────────────────────────────────────────────────────────────
+
+
+class ConfigSource(ABC):
+ """Base class for all configuration sources."""
+
+ @abstractmethod
+ def load(self) -> dict[str, Any]:
+ """Return raw key-value pairs from this source."""
+ ...
+
+
+class DotEnvSource(ConfigSource):
+
+ def __init__(self, path: str | Path | None = None):
+ self._path = path
+
+ def load(self) -> dict[str, str]:
+ path = self._resolve_path()
+ if path is None:
+ return {}
+ result: dict[str, str] = {}
+ with open(path) as f:
+ for line in f:
+ line = line.strip()
+ if not line or line.startswith('#'):
+ continue
+ if '=' not in line:
+ continue
+ key, _, value = line.partition('=')
+ result[key.strip()] = value.strip().strip('"').strip("'")
+ return result
+
+ def _resolve_path(self) -> Path | None:
+ if self._path is not None:
+ p = Path(self._path)
+ return p if p.is_file() else None
+ for name in ('.env', '.env.local'):
+ p = Path.cwd() / name
+ if p.is_file():
+ return p
+ return None
+
+
+class EnvVarSource(ConfigSource):
+ """Reads os.environ; recognizes TWINKLE_ prefix and any key known to the registry."""
+
+ def __init__(self, registry: ConfigRegistry):
+ self._registry = registry
+
+ def load(self) -> dict[str, str]:
+ result: dict[str, str] = {}
+ for key, value in os.environ.items():
+ if key.startswith('TWINKLE_'):
+ result[key[8:]] = value
+ elif self._registry.resolve(key) is not None:
+ result[key] = value
+ return result
+
+
+class YamlSource(ConfigSource):
+
+ def __init__(self, path: str | Path):
+ self._path = Path(path)
+
+ def load(self) -> dict[str, Any]:
+ from omegaconf import OmegaConf
+ if not self._path.is_file():
+ raise FileNotFoundError(f'Config file not found: {self._path}')
+ cfg = OmegaConf.load(self._path)
+ return OmegaConf.to_container(cfg, resolve=True)
+
+
+class CLISource(ConfigSource):
+
+ def __init__(self, argv: list[str] | None = None):
+ self._argv = argv if argv is not None else sys.argv[1:]
+
+ def load(self) -> dict[str, Any]:
+ result: dict[str, Any] = {}
+ i = 0
+ argv = self._argv
+ while i < len(argv):
+ token = argv[i]
+ if not token.startswith('--'):
+ i += 1
+ continue
+ token = token[2:]
+ if token.startswith('no_') or token.startswith('no-'):
+ result[token[3:]] = False
+ i += 1
+ continue
+ if '=' in token:
+ key, _, value = token.partition('=')
+ result[key] = value
+ i += 1
+ continue
+ if i + 1 < len(argv) and not argv[i + 1].startswith('--'):
+ result[token] = argv[i + 1]
+ i += 2
+ else:
+ result[token] = True
+ i += 1
+ return result
+
+
+# ────────────────────────────────────────────────────────────────────────────────
+# ConfigRegistry: maps normalized keys to (group_name, field_name)
+# ────────────────────────────────────────────────────────────────────────────────
+
+
+class ConfigRegistry:
+ """Introspects Args dataclass groups to build a case-insensitive key→field map."""
+
+ # Same field name in 2+ groups — the winning group must declare metadata={'primary': True}
+
+ def __init__(self, groups: dict[str, Any]):
+ self._field_map: dict[str, tuple[str, str]] = {}
+ self._alias_map: dict[str, str] = {}
+ self._groups = groups
+ self._build(groups)
+
+ def _build(self, groups: dict[str, Any]) -> None:
+ owners: dict[str, list[tuple[str, bool]]] = {}
+ for group_name, group_obj in groups.items():
+ for f in fields(group_obj):
+ is_primary = f.metadata.get('primary', False)
+ owners.setdefault(f.name.lower(), []).append((group_name, is_primary))
+ for alias in f.metadata.get('aliases', ()): # field-local aliases
+ self._alias_map[alias.lower()] = f.name.lower()
+ for key, owner_list in owners.items():
+ if len(owner_list) == 1:
+ self._field_map[key] = (owner_list[0][0], key)
+ continue
+ primaries = [g for g, p in owner_list if p]
+ if len(primaries) != 1:
+ all_groups = [g for g, _ in owner_list]
+ raise ValueError(f'Field {key!r} exists in groups {all_groups}; '
+ f"exactly one must declare metadata={{'primary': True}}, found {len(primaries)}")
+ self._field_map[key] = (primaries[0], key)
+
+ def resolve(self, key: str) -> tuple[str, str] | None:
+ normalized = key.lower().replace('-', '_')
+ canonical = self._alias_map.get(normalized, normalized)
+ if canonical in self._field_map:
+ return self._field_map[canonical]
+ # prefix-based fallback: model_xxx → group=model, field=xxx
+ for group_name in self._groups:
+ prefix = group_name + '_'
+ if canonical.startswith(prefix):
+ stripped = canonical[len(prefix):]
+ if stripped and (group_name, stripped) in ((g, f.name) for g, obj in self._groups.items()
+ for f in fields(obj)):
+ return (group_name, stripped)
+ return None
+
+ def all_keys(self) -> Iterator[str]:
+ return iter(self._field_map)
+
+
+# ────────────────────────────────────────────────────────────────────────────────
+# Args: unified container
+# ────────────────────────────────────────────────────────────────────────────────
+
+
+@dataclass
+class Args:
+ """Unified argument container. Access groups directly or via get_*_args() dicts."""
+
+ model: ModelArgs = field(default_factory=ModelArgs)
+ lora: LoraArgs = field(default_factory=LoraArgs)
+ dataset: DatasetArgs = field(default_factory=DatasetArgs)
+ template: TemplateArgs = field(default_factory=TemplateArgs)
+ training: TrainingArgs = field(default_factory=TrainingArgs)
+ optimizer: OptimizerArgs = field(default_factory=OptimizerArgs)
+ scheduler: SchedulerArgs = field(default_factory=SchedulerArgs)
+ loss: LossArgs = field(default_factory=LossArgs)
+ sampler: SamplerArgs = field(default_factory=SamplerArgs)
+ sampling: SamplingArgs = field(default_factory=SamplingArgs)
+ infra: InfraArgs = field(default_factory=InfraArgs)
+ server: ServerArgs = field(default_factory=ServerArgs)
+ rl: RLArgs = field(default_factory=RLArgs)
+ checkpoint: CheckpointArgs = field(default_factory=CheckpointArgs)
+ extra: dict[str, Any] = field(default_factory=dict)
+
+ def get_model_args(self) -> dict[str, Any]:
+ d = self._to_dict(self.model)
+ if not d.get('model_id') and self.template.model_id:
+ d['model_id'] = self.template.model_id
+ return d
+
+ def get_lora_args(self) -> dict[str, Any]:
+ return {
+ 'target_modules': self.lora.lora_target_modules or 'all-linear',
+ 'r': self.lora.lora_r,
+ 'lora_alpha': self.lora.lora_alpha,
+ 'lora_dropout': self.lora.lora_dropout,
+ }
+
+ def get_dataset_args(self) -> dict[str, Any]:
+ return self._to_dict(self.dataset)
+
+ def get_template_args(self) -> dict[str, Any]:
+ d = self._to_dict(self.template)
+ if not d.get('model_id') and self.model.model_id:
+ d['model_id'] = self.model.model_id
+ return d
+
+ def get_training_args(self) -> dict[str, Any]:
+ return self._to_dict(self.training)
+
+ def get_optimizer_args(self) -> dict[str, Any]:
+ d = self._to_dict(self.optimizer)
+ d['lr'] = d.pop('learning_rate', 1e-5)
+ return d
+
+ def get_scheduler_args(self) -> dict[str, Any]:
+ return self._to_dict(self.scheduler)
+
+ def get_loss_args(self) -> dict[str, Any]:
+ return self._to_dict(self.loss)
+
+ def get_sampler_args(self) -> dict[str, Any]:
+ return self._to_dict(self.sampler)
+
+ def get_sampling_args(self) -> dict[str, Any]:
+ return self._to_dict(self.sampling)
+
+ def get_infra_args(self) -> dict[str, Any]:
+ return self._to_dict(self.infra)
+
+ def get_server_args(self) -> dict[str, Any]:
+ return self._to_dict(self.server)
+
+ def get_rl_args(self) -> dict[str, Any]:
+ return self._to_dict(self.rl)
+
+ def get_checkpoint_args(self) -> dict[str, Any]:
+ return self._to_dict(self.checkpoint)
+
+ def get(self, key: str, default: Any = None) -> Any:
+ for f in fields(self):
+ if f.name == 'extra':
+ continue
+ group = getattr(self, f.name)
+ if hasattr(group, key):
+ return getattr(group, key)
+ return self.extra.get(key, default)
+
+ def __getitem__(self, key: str) -> Any:
+ val = self.get(key, _SENTINEL)
+ if val is _SENTINEL:
+ raise KeyError(key)
+ return val
+
+ def to_dict(self) -> dict[str, Any]:
+ result = {}
+ for f in fields(self):
+ if f.name == 'extra':
+ continue
+ result.update(self._to_dict(getattr(self, f.name)))
+ result.update(self.extra)
+ return result
+
+ @staticmethod
+ def _to_dict(obj: Any) -> dict[str, Any]:
+ return {f.name: getattr(obj, f.name) for f in fields(obj) if getattr(obj, f.name) is not None}
+
+
+_SENTINEL = object()
+
+# ────────────────────────────────────────────────────────────────────────────────
+# ValueCaster: type coercion
+# ────────────────────────────────────────────────────────────────────────────────
+
+
+class ValueCaster:
+
+ @staticmethod
+ def auto_cast(value: Any) -> Any:
+ if not isinstance(value, str):
+ return value
+ low = value.lower()
+ if low in ('true', 'yes', 'on'):
+ return True
+ if low in ('false', 'no', 'off'):
+ return False
+ if low in ('none', 'null', '~'):
+ return None
+ try:
+ return int(value)
+ except ValueError:
+ pass
+ try:
+ return float(value)
+ except ValueError:
+ pass
+ if ',' in value:
+ return [ValueCaster.auto_cast(v.strip()) for v in value.split(',')]
+ return value
+
+ @staticmethod
+ def coerce_to_field(obj: Any, field_name: str, value: Any) -> Any:
+ current = getattr(obj, field_name, None)
+ if current is None or value is None:
+ return value
+ target_type = type(current)
+ if target_type is bool:
+ if isinstance(value, bool):
+ return value
+ return ValueCaster.auto_cast(str(value))
+ if target_type is int and not isinstance(value, int):
+ try:
+ return int(float(value)) if isinstance(value, str) else int(value)
+ except (ValueError, TypeError):
+ return value
+ if target_type is float and not isinstance(value, (int, float)):
+ try:
+ return float(value)
+ except (ValueError, TypeError):
+ return value
+ if target_type is list and isinstance(value, str):
+ return [v.strip() for v in value.split(',')]
+ return value
+
+
+# ────────────────────────────────────────────────────────────────────────────────
+# ConfigResolver: merges sources
+# ────────────────────────────────────────────────────────────────────────────────
+
+
+class ConfigResolver:
+
+ def __init__(self, args: Args):
+ self._args = args
+ self._groups = {f.name: getattr(args, f.name) for f in fields(args) if f.name != 'extra'}
+ self._registry = ConfigRegistry(self._groups)
+
+ @property
+ def registry(self) -> ConfigRegistry:
+ return self._registry
+
+ def apply(self, source: dict[str, Any], cast_strings: bool = False) -> None:
+ flat = self._flatten(source)
+ for raw_key, raw_value in flat.items():
+ key = raw_key.lower().replace('-', '_')
+ value = ValueCaster.auto_cast(raw_value) if cast_strings else raw_value
+ # handle use_megatron alias
+ if key == 'use_megatron':
+ if ValueCaster.auto_cast(str(value)):
+ self._set('model', 'strategy', 'native_fsdp')
+ continue
+ resolved = self._registry.resolve(key)
+ if resolved:
+ group_name, field_name = resolved
+ group = self._groups[group_name]
+ coerced = ValueCaster.coerce_to_field(group, field_name, value)
+ setattr(group, field_name, coerced)
+ else:
+ self._args.extra[key] = value
+
+ def _set(self, group_name: str, field_name: str, value: Any) -> None:
+ group = self._groups[group_name]
+ setattr(group, field_name, value)
+
+ def _flatten(self, d: Any, prefix: str = '') -> dict[str, Any]:
+ if not isinstance(d, dict):
+ return {prefix: d} if prefix else {}
+ result: dict[str, Any] = {}
+ for key, value in d.items():
+ full_key = f'{prefix}_{key}' if prefix else key
+ if isinstance(value, dict):
+ result.update(self._flatten(value, full_key))
+ else:
+ result[full_key] = value
+ return result
+
+
+# ────────────────────────────────────────────────────────────────────────────────
+# CLI: top-level entry point
+# ────────────────────────────────────────────────────────────────────────────────
+
+
+class CLI:
+ """Unified configuration parser.
+
+ Resolution order (later wins):
+ 1. Dataclass defaults
+ 2. .env file
+ 3. Environment variables (TWINKLE_ prefix or bare)
+ 4. YAML config file (--config / explicit)
+ 5. CLI overrides (--key value)
+
+ All keys are case-insensitive and dash/underscore equivalent:
+ --model-id, MODEL_ID, TWINKLE_MODEL_ID, model_id: in .yaml all resolve the same.
+ """
+
+ @staticmethod
+ def from_args(
+ argv: list[str] | None = None,
+ env_file: str | Path | None = None,
+ config_file: str | Path | None = None,
+ ) -> Args:
+ args = Args()
+ resolver = ConfigResolver(args)
+
+ # 1. .env
+ resolver.apply(DotEnvSource(env_file).load(), cast_strings=True)
+
+ # 2. Environment variables
+ resolver.apply(EnvVarSource(resolver.registry).load(), cast_strings=True)
+
+ # 3. CLI (first pass to extract --config)
+ cli_data = CLISource(argv).load()
+ yaml_path = config_file or cli_data.pop('config', None)
+
+ # 4. YAML
+ if yaml_path:
+ resolver.apply(YamlSource(yaml_path).load(), cast_strings=False)
+
+ # 5. CLI overrides (highest priority, values are strings from argv)
+ resolver.apply(cli_data, cast_strings=True)
+
+ return args
diff --git a/src/twinkle/data_format/__init__.py b/src/twinkle/data_format/__init__.py
index 2b2c3cf04..c93bebd2d 100644
--- a/src/twinkle/data_format/__init__.py
+++ b/src/twinkle/data_format/__init__.py
@@ -3,4 +3,4 @@
from .message import Message, Tool, ToolCall
from .output import LossOutput, ModelOutput
from .sampling import SampledSequence, SampleResponse, SamplingParams
-from .trajectory import Trajectory
+from .trajectory import Trajectory, pack_value, user_data_get
diff --git a/src/twinkle/data_format/trajectory.py b/src/twinkle/data_format/trajectory.py
index 2462aae0a..617e5ac7b 100644
--- a/src/twinkle/data_format/trajectory.py
+++ b/src/twinkle/data_format/trajectory.py
@@ -1,4 +1,5 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
+import json
import sys
from typing import Any, List, Optional, Tuple, Union
@@ -14,8 +15,31 @@
class Trajectory(TypedDict, total=False):
messages: List[Message]
tools: List[Tool]
- user_data: List[Tuple[str, Any]]
+ # PyArrow-stable encoding: each entry is (key, json.dumps(value)). Use the helpers below.
+ user_data: List[Tuple[str, str]]
images: Optional[List[Union[str, Any]]]
videos: Optional[List[Union[str, Any]]]
audios: Optional[List[Union[str, Any]]]
prompt: Optional[str]
+
+
+def pack_value(value: Any) -> str:
+ """Encode a single user_data value to a JSON string."""
+ return json.dumps(value, ensure_ascii=False, default=str)
+
+
+def user_data_get(items: Any, key: str, default: Any = None) -> Any:
+ """Look up the first value matching ``key`` in packed user_data, decoded."""
+ if not isinstance(items, list):
+ return default
+ for entry in items:
+ if isinstance(entry, (list, tuple)) and len(entry) == 2 and entry[0] == key:
+ v = entry[1]
+ if not isinstance(v, str):
+ return v
+ try:
+ return json.loads(v)
+ except (json.JSONDecodeError, ValueError):
+ return v
+ return default
+
diff --git a/src/twinkle/dataset/base.py b/src/twinkle/dataset/base.py
index d44856b79..6c02c82cb 100644
--- a/src/twinkle/dataset/base.py
+++ b/src/twinkle/dataset/base.py
@@ -1,10 +1,14 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
+import json as _json
import os.path
+import threading
from collections.abc import Iterable, Mapping
from dataclasses import dataclass
from datasets import DatasetDict, IterableDataset, concatenate_datasets, interleave_datasets, load_dataset
+from queue import Queue
from torch.utils.data import Dataset as TorchDataset
-from typing import Any, Callable, Dict, Type, Union
+from torch.utils.data import IterableDataset as TorchIterableDataset
+from typing import Any, Callable, Dict, List, Optional, Type, Union
import twinkle
from twinkle import preprocessor
@@ -13,6 +17,7 @@
from twinkle.preprocessor import DataFilter, Preprocessor
from twinkle.template import Template
from twinkle.utils import construct_class, processing_lock
+from twinkle.utils.parallel import PosixFileLock
try:
import multiprocess
@@ -27,20 +32,34 @@ class DatasetMeta:
The dataset meta-information, used to describe a dataset.
"""
# The dataset id or local path
- dataset_id: str
+ dataset_id: str = ''
# The subset name
subset_name: str = 'default'
# The split
split: str = 'train'
# Pick a data slice
data_slice: Iterable = None
+ # In-memory / in-process data source. Supports:
+ # - List[Dict] (row-oriented, eager)
+ # - Dict[str, List] (column-oriented, eager)
+ # - Callable (generator function; routed to HF from_generator,
+ # streaming vs eager picked from `streaming` kwarg.
+ # Bind args via functools.partial.)
+ # - HFDataset / HFIterableDataset (already-constructed, passed through)
+ data: Any = None
def get_id(self):
+ if self.data is not None:
+ return f'__memory_{self._uid}__:' + self.subset_name + ':' + self.split
return self.dataset_id.replace(os.sep, '_').replace('.', '_') + ':' + self.subset_name + ':' + self.split
def __post_init__(self):
+ import uuid
+ self._uid = uuid.uuid4().hex[:8]
if self.data_slice is not None and not isinstance(self.data_slice, Iterable):
raise ValueError('data_slice must be an iterable')
+ if not self.dataset_id and self.data is None:
+ raise ValueError('Either dataset_id or data must be provided')
@remote_class(execute='first')
@@ -58,6 +77,7 @@ class Dataset(TorchDataset):
def __init__(self, dataset_meta: DatasetMeta = None, **kwargs):
self.template = None
+ self._mixed = False
if dataset_meta is None:
self.datasets = {}
self.dataset = None
@@ -79,6 +99,16 @@ def set_template(self, template_func: Union[Template, Type[Template], str], **kw
"""
self.template = construct_class(template_func, Template, twinkle.template, **kwargs)
+ @staticmethod
+ def _normalize_cache_kwargs(target, kwargs: Dict[str, Any]) -> Dict[str, Any]:
+ kw = dict(kwargs)
+ # Streaming datasets (HF IterableDataset / torch IterableDataset wrappers) reject load_from_cache_file.
+ if isinstance(target, (IterableDataset, TorchIterableDataset)):
+ kw.pop('load_from_cache_file', None)
+ else:
+ kw.setdefault('load_from_cache_file', False)
+ return kw
+
@remote_function()
def encode(self, add_generation_prompt: bool = False, **kwargs):
"""An inplace operation to encode the dataset.
@@ -90,18 +120,16 @@ def encode(self, add_generation_prompt: bool = False, **kwargs):
**kwargs: The mapping and filter kwargs of the `datasets.map`.
"""
kwargs['batched'] = True # Only supported batched, because a single row may explode to several rows
- if 'load_from_cache_file' not in kwargs:
- # By default, we don't use load_from_cache_file, because read cache will not consider
- # the changes in the same file,
- # which will cause unexpected behaviors.
- kwargs['load_from_cache_file'] = False
+ kwargs = self._normalize_cache_kwargs(self.dataset, kwargs)
from functools import partial
encode_fn = partial(self.template.batch_encode, add_generation_prompt=add_generation_prompt)
+ # Dataset.filter() does not accept map-only kwargs (e.g. remove_columns); split them off.
+ filter_kwargs = {k: v for k, v in kwargs.items() if k != 'remove_columns'}
with processing_lock('dataset'):
# use a default lock because encode is to all datasets
self.dataset = self.dataset.map(encode_fn, **kwargs).filter(
lambda batch: [True] * len(next(iter(batch.values())))
- if 'input_ids' not in batch else [len(x) > 0 for x in batch['input_ids']], **kwargs)
+ if 'input_ids' not in batch else [len(x) > 0 for x in batch['input_ids']], **filter_kwargs)
@remote_function()
def check(self, **kwargs):
@@ -111,9 +139,7 @@ def check(self, **kwargs):
**kwargs: The mapping and filter kwargs of the `datasets.map`.
"""
kwargs['batched'] = True # Only supported batched, because a single row may explode to several rows
- # check depends on template/tokenizer behavior; cached filter results can keep old empty outputs.
- # Disable cache here to avoid the "silent stop" caused by stale empty cache.
- kwargs.setdefault('load_from_cache_file', False)
+ kwargs = self._normalize_cache_kwargs(self.dataset, kwargs)
with processing_lock('dataset'):
# use a default lock because check is to all datasets
def _check_batch(batch):
@@ -126,6 +152,23 @@ def _check_batch(batch):
@staticmethod
def _load_dataset(dataset_meta: DatasetMeta, **kwargs):
+ # In-memory / in-process data path
+ if dataset_meta.data is not None:
+ from datasets import Dataset as HFDataset
+ from datasets import IterableDataset as HFIterableDataset
+ d = dataset_meta.data
+ if isinstance(d, (HFDataset, HFIterableDataset)):
+ return d
+ if isinstance(d, list):
+ return HFDataset.from_list(d)
+ if isinstance(d, dict):
+ return HFDataset.from_dict(d)
+ if callable(d):
+ cls = HFIterableDataset if kwargs.get('streaming') else HFDataset
+ return cls.from_generator(d)
+ raise ValueError(f'DatasetMeta.data must be list, dict, callable, or HF Dataset/IterableDataset, '
+ f'got {type(d).__name__}')
+
dataset_id = dataset_meta.dataset_id
subset_name = dataset_meta.subset_name
split = dataset_meta.split
@@ -168,6 +211,9 @@ def _load_dataset(dataset_meta: DatasetMeta, **kwargs):
raise KeyError(f"Split '{split}' not found for dataset '{dataset_id}'. "
f'Available splits: {available_splits}')
+ if hasattr(dataset, 'to_hf_dataset'):
+ dataset = dataset.to_hf_dataset()
+
if isinstance(dataset_meta.data_slice, Iterable) and hasattr(dataset, '__len__'):
iter_list = []
@@ -209,22 +255,22 @@ def map(self,
**kwargs: The kwargs of the `datasets.map`.
"""
init_args = init_args or {}
- if 'load_from_cache_file' not in kwargs:
- # By default, we don't use load_from_cache_file, because read cache will not consider
- # the changes in the same file,
- # which will cause unexpected behaviors.
- kwargs['load_from_cache_file'] = False
preprocess_func = construct_class(preprocess_func, Preprocessor, twinkle.preprocessor, **init_args)
- if dataset_meta is None:
- assert len(self.datasets) == 1
- key = next(iter(self.datasets.keys()))
- else:
- key = dataset_meta.get_id()
kwargs['batched'] = True
- with processing_lock(key):
- self.datasets[key] = self.datasets[key].map(preprocess_func, **kwargs)
- if len(self.datasets) == 1:
- self.dataset = self.datasets[key]
+
+ if self._mixed:
+ self.dataset = self.dataset.map(preprocess_func, **self._normalize_cache_kwargs(self.dataset, kwargs))
+ else:
+ if dataset_meta is None:
+ assert len(self.datasets) == 1
+ key = next(iter(self.datasets.keys()))
+ else:
+ key = dataset_meta.get_id()
+ with processing_lock(key):
+ kw = self._normalize_cache_kwargs(self.datasets[key], kwargs)
+ self.datasets[key] = self.datasets[key].map(preprocess_func, **kw)
+ if len(self.datasets) == 1:
+ self.dataset = self.datasets[key]
@remote_function()
def filter(self,
@@ -242,16 +288,20 @@ def filter(self,
"""
init_args = init_args or {}
filter_func = construct_class(filter_func, DataFilter, twinkle.preprocessor, **init_args)
- if dataset_meta is None:
- assert len(self.datasets) == 1
- key = next(iter(self.datasets.keys()))
+ if self._mixed:
+ kwargs['batched'] = False
+ self.dataset = self.dataset.filter(filter_func, **kwargs)
else:
- key = dataset_meta.get_id()
- kwargs['batched'] = False
- with processing_lock(key):
- self.datasets[key] = self.datasets[key].filter(filter_func, **kwargs)
- if len(self.datasets) == 1:
- self.dataset = self.datasets[key]
+ if dataset_meta is None:
+ assert len(self.datasets) == 1
+ key = next(iter(self.datasets.keys()))
+ else:
+ key = dataset_meta.get_id()
+ kwargs['batched'] = False
+ with processing_lock(key):
+ self.datasets[key] = self.datasets[key].filter(filter_func, **kwargs)
+ if len(self.datasets) == 1:
+ self.dataset = self.datasets[key]
@remote_function()
def add_dataset(self, dataset_meta: DatasetMeta, **kwargs):
@@ -279,15 +329,302 @@ def mix_dataset(self, interleave=True):
dataset_types = [isinstance(ds, IterableDataset) for ds in self.datasets]
assert all(
dataset_types) or not any(dataset_types), 'All datasets must be all streaming=True or streaming=False'
+ if not any(dataset_types):
+ dsets = list(self.datasets.values())
+ # Align features
+ ref_features = dsets[0].features
+ aligned = []
+ for ds in dsets:
+ if ds.features != ref_features:
+ ds = ds.cast(ref_features)
+ aligned.append(ds)
+ else:
+ aligned = list(self.datasets.values())
if interleave:
- self.dataset = interleave_datasets(list(self.datasets.values()))
+ self.dataset = interleave_datasets(aligned)
else:
- self.dataset = concatenate_datasets(list(self.datasets.values()))
+ self.dataset = concatenate_datasets(aligned)
+ self._mixed = True
+
+ @remote_function()
+ def save_as(self,
+ output_path: str,
+ format: Optional[str] = None,
+ batch_size: int = 1000,
+ mode: str = 'immediate',
+ **kwargs) -> None:
+ """Save the merged dataset to a local file.
+
+ Args:
+ output_path: Target file path. Extension determines format if `format` is None.
+ format: One of 'jsonl', 'json', 'csv', 'parquet'. Auto-detected from extension if None.
+ batch_size: Batch size for buffered writing.
+ mode: 'immediate' to save all data now; 'training' to write-through as data is
+ consumed by __iter__/__getitem__ — call flush_save() when training ends.
+ **kwargs: Extra args passed to the underlying HF export method (immediate bulk only).
+ """
+ if self.dataset is None:
+ raise ValueError('No dataset to save.')
+ if len(self.datasets) > 1 and not self._mixed:
+ raise ValueError('Call mix_dataset() before save_as() when multiple datasets are loaded.')
+
+ fmt = format or self._infer_format(output_path)
+ if fmt not in ('jsonl', 'json', 'csv', 'parquet'):
+ raise ValueError(f"Unsupported format: '{fmt}'. Use jsonl/json/csv/parquet.")
+
+ dir_path = os.path.dirname(os.path.abspath(output_path))
+ os.makedirs(dir_path, exist_ok=True)
+
+ if mode == 'training':
+ self._save_state = _SaveState(output_path, fmt, batch_size)
+ return
+
+ if self._should_materialize():
+ self._save_incremental(output_path, fmt, batch_size)
+ else:
+ self._save_bulk(output_path, fmt, **kwargs)
+
+ @remote_function()
+ def flush_save(self) -> None:
+ """Finalize and close the training-mode writer opened by save_as(mode='training')."""
+ state = getattr(self, '_save_state', None)
+ if state is not None:
+ state.close()
+ self._save_state = None
+
+ def _write_through(self, row):
+ """If training-mode save is active, persist the row."""
+ state = getattr(self, '_save_state', None)
+ if state is not None:
+ state.write(row)
+ return row
+
+ @staticmethod
+ def _infer_format(path: str) -> str:
+ ext = os.path.splitext(path)[1].lstrip('.').lower()
+ return {
+ 'jsonl': 'jsonl',
+ 'json': 'jsonl',
+ 'csv': 'csv',
+ 'parquet': 'parquet',
+ 'pq': 'parquet'
+ }.get(ext, 'jsonl')
+
+ def _should_materialize(self) -> bool:
+ if isinstance(self.dataset, IterableDataset):
+ return True
+ if hasattr(self, 'do_encode') and self.do_encode:
+ return True
+ if getattr(self, '_lazy_map_ops', None) or getattr(self, '_global_map_ops', None):
+ return True
+ return False
+
+ def _save_bulk(self, path: str, fmt: str, **kwargs) -> None:
+ if fmt in ('jsonl', 'json'):
+ self.dataset.to_json(path, **kwargs)
+ elif fmt == 'csv':
+ self.dataset.to_csv(path, **kwargs)
+ elif fmt == 'parquet':
+ self.dataset.to_parquet(path, **kwargs)
+
+ def _save_incremental(self, path: str, fmt: str, batch_size: int) -> None:
+ iterator = self._row_iterator()
+ if fmt in ('jsonl', 'json'):
+ self._write_jsonl(path, iterator)
+ elif fmt == 'csv':
+ self._write_csv(path, iterator, batch_size)
+ elif fmt == 'parquet':
+ self._write_parquet(path, iterator, batch_size)
+
+ def _row_iterator(self):
+ if isinstance(self.dataset, IterableDataset):
+ yield from self.dataset
+ else:
+ for i in range(len(self)):
+ yield self[i]
+
+ @staticmethod
+ def _write_jsonl(path: str, iterator) -> None:
+ with open(path, 'w', encoding='utf-8') as f:
+ for row in iterator:
+ f.write(_json.dumps(row, ensure_ascii=False, default=_default_serializer) + '\n')
+
+ @staticmethod
+ def _write_csv(path: str, iterator, batch_size: int) -> None:
+ import pandas as pd
+ first = True
+ batch: List[Dict] = []
+ for row in iterator:
+ batch.append(row)
+ if len(batch) >= batch_size:
+ pd.DataFrame(batch).to_csv(path, mode='a', header=first, index=False)
+ first = False
+ batch = []
+ if batch:
+ pd.DataFrame(batch).to_csv(path, mode='a', header=first, index=False)
+
+ @staticmethod
+ def _write_parquet(path: str, iterator, batch_size: int) -> None:
+ import pyarrow as pa
+ import pyarrow.parquet as pq
+ writer = None
+ batch: List[Dict] = []
+ for row in iterator:
+ batch.append(row)
+ if len(batch) >= batch_size:
+ table = pa.Table.from_pylist(batch)
+ if writer is None:
+ writer = pq.ParquetWriter(path, table.schema)
+ writer.write_table(table)
+ batch = []
+ if batch:
+ table = pa.Table.from_pylist(batch)
+ if writer is None:
+ writer = pq.ParquetWriter(path, table.schema)
+ writer.write_table(table)
+ if writer:
+ writer.close()
@remote_function()
def __getitem__(self, idx):
- return self.dataset[idx]
+ item = self.dataset[idx]
+ self._write_through(item)
+ return item
@remote_function()
def __len__(self):
return len(self.dataset)
+
+
+def _default_serializer(obj):
+ """Handle numpy types in JSON serialization."""
+ import numpy as np
+ if isinstance(obj, np.integer):
+ return int(obj)
+ if isinstance(obj, np.floating):
+ return float(obj)
+ if isinstance(obj, np.ndarray):
+ return obj.tolist()
+ raise TypeError(f'Object of type {type(obj).__name__} is not JSON serializable')
+
+
+_SENTINEL = object()
+
+
+class _SaveState:
+ """Async persistent writer for training-mode save_as.
+
+ Writes happen on a background daemon thread so the training loop is never blocked.
+ Uses fcntl file-lock for cross-process safety when multiple ranks write one file.
+ """
+
+ def __init__(self, path: str, fmt: str, batch_size: int):
+
+ self._path = path
+ self._fmt = fmt
+ self._batch_size = batch_size
+ self._queue: Queue = Queue(maxsize=batch_size * 4)
+ self._lock = PosixFileLock(path + '.lock')
+ self._error = None
+
+ self._thread = threading.Thread(target=self._writer_loop, daemon=True)
+ self._thread.start()
+
+ def write(self, row: Dict) -> None:
+ self._queue.put(row)
+
+ def close(self) -> None:
+ self._queue.put(_SENTINEL)
+ self._thread.join()
+ self._lock.close()
+ if self._error:
+ raise self._error
+
+ def _writer_loop(self) -> None:
+ try:
+ if self._fmt in ('jsonl', 'json'):
+ self._loop_jsonl()
+ elif self._fmt == 'csv':
+ self._loop_csv()
+ elif self._fmt == 'parquet':
+ self._loop_parquet()
+ except Exception as e:
+ self._error = e
+
+ def _acquire_lock(self):
+ self._lock.acquire()
+
+ def _release_lock(self):
+ self._lock.release()
+
+ def _loop_jsonl(self) -> None:
+ with open(self._path, 'a', encoding='utf-8') as f:
+ while True:
+ item = self._queue.get()
+ if item is _SENTINEL:
+ return
+ line = _json.dumps(item, ensure_ascii=False, default=_default_serializer) + '\n'
+ self._acquire_lock()
+ try:
+ f.write(line)
+ f.flush()
+ finally:
+ self._release_lock()
+
+ def _loop_csv(self) -> None:
+ import pandas as pd
+ header_written = False
+ buffer: List[Dict] = []
+ while True:
+ item = self._queue.get()
+ if item is _SENTINEL:
+ if buffer:
+ self._acquire_lock()
+ try:
+ pd.DataFrame(buffer).to_csv(self._path, mode='a', header=not header_written, index=False)
+ finally:
+ self._release_lock()
+ return
+ buffer.append(item)
+ if len(buffer) >= self._batch_size:
+ self._acquire_lock()
+ try:
+ pd.DataFrame(buffer).to_csv(self._path, mode='a', header=not header_written, index=False)
+ header_written = True
+ finally:
+ self._release_lock()
+ buffer = []
+
+ def _loop_parquet(self) -> None:
+ import pyarrow as pa
+ import pyarrow.parquet as pq
+ writer = None
+ buffer: List[Dict] = []
+ try:
+ while True:
+ item = self._queue.get()
+ if item is _SENTINEL:
+ if buffer:
+ table = pa.Table.from_pylist(buffer)
+ if writer is None:
+ writer = pq.ParquetWriter(self._path, table.schema)
+ self._acquire_lock()
+ try:
+ writer.write_table(table)
+ finally:
+ self._release_lock()
+ return
+ buffer.append(item)
+ if len(buffer) >= self._batch_size:
+ table = pa.Table.from_pylist(buffer)
+ if writer is None:
+ writer = pq.ParquetWriter(self._path, table.schema)
+ self._acquire_lock()
+ try:
+ writer.write_table(table)
+ finally:
+ self._release_lock()
+ buffer = []
+ finally:
+ if writer:
+ writer.close()
diff --git a/src/twinkle/dataset/iterable_dataset.py b/src/twinkle/dataset/iterable_dataset.py
index 21ae82f88..b985d83e9 100644
--- a/src/twinkle/dataset/iterable_dataset.py
+++ b/src/twinkle/dataset/iterable_dataset.py
@@ -29,6 +29,6 @@ def __getitem__(self, idx):
@remote_function()
def __iter__(self):
- # TODO if this class passed through actor handler, an error will occur:
- # a global single dataset, multiple dataloaders, the self._iter will cover each other
- return self.dataset.__iter__()
+ for row in self.dataset:
+ self._write_through(row)
+ yield row
diff --git a/src/twinkle/dataset/lazy_dataset.py b/src/twinkle/dataset/lazy_dataset.py
index 29f8f678e..383f85d7a 100644
--- a/src/twinkle/dataset/lazy_dataset.py
+++ b/src/twinkle/dataset/lazy_dataset.py
@@ -186,6 +186,7 @@ def __getitem__(self, idx):
elif self.do_check:
item = self.template.check(item)
+ self._write_through(item)
return item
@remote_function()
diff --git a/src/twinkle/dataset/packing_dataset.py b/src/twinkle/dataset/packing_dataset.py
index a940fdf6b..fa4acbd57 100644
--- a/src/twinkle/dataset/packing_dataset.py
+++ b/src/twinkle/dataset/packing_dataset.py
@@ -117,7 +117,9 @@ def __getitem__(self, index):
output = {}
for key in rows[0]:
output[key] = [r[key] for r in rows]
- if isinstance(rows[0][key], (list, np.ndarray)) and isinstance(rows[0][key][0], (int, float, np.number)):
+ if key in ('mm_token_type_ids', 'position_ids'):
+ output[key] = np.concatenate([np.asarray(v) for v in output[key]], axis=-1).tolist()
+ elif isinstance(rows[0][key], (list, np.ndarray)) and isinstance(rows[0][key][0], (int, float, np.number)):
output[key] = [v for lst in output[key] for v in lst]
return output
diff --git a/src/twinkle/infra/__init__.py b/src/twinkle/infra/__init__.py
index 83158a281..a6d288015 100644
--- a/src/twinkle/infra/__init__.py
+++ b/src/twinkle/infra/__init__.py
@@ -1,10 +1,13 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
import functools
import inspect
+import itertools
import json
import numpy as np
import os
-from typing import Any, Callable, List, Literal, Optional, TypeVar, Union
+import random
+import sys
+from typing import Any, AsyncIterator, Callable, List, Literal, Optional, TypeVar, Union
from twinkle.notifier import Notifier, notify_exception
from twinkle.utils import DeviceGroup, DeviceMesh, Platform, check_unsafe, framework_util, get_logger, requires
@@ -36,6 +39,30 @@
_TWINKLE_NOTIFIER_ENV = 'TWINKLE_NOTIFIER'
+def _capture_caller() -> Optional[str]:
+ """Return ``file:line`` of the first frame outside this module, or ``None``."""
+ f = sys._getframe(1)
+ while f and f.f_code.co_filename == __file__:
+ f = f.f_back
+ return f'{f.f_code.co_filename}:{f.f_lineno}' if f else None
+
+
+def _tag_exc(exc: BaseException, caller: Optional[str]) -> None:
+ """Stamp driver-caller location onto exc for both traceback and str(exc)."""
+ if not caller:
+ return
+ try:
+ marker = f'[twinkle] driver caller: {caller}'
+ if marker not in (getattr(exc, '__notes__', None) or []):
+ exc.add_note(marker)
+ if not getattr(exc, '_twinkle_caller_augmented', False):
+ prefix = f'[twinkle driver caller: {caller}] '
+ exc.args = (prefix + str(exc.args[0]), *exc.args[1:]) if exc.args else (prefix.rstrip(),)
+ exc._twinkle_caller_augmented = True
+ except Exception: # noqa: BLE001
+ pass
+
+
def _maybe_load_worker_notifier() -> None:
"""Lazily reconstruct notifier + name on ray workers from inherited env vars."""
global _notifier, _name
@@ -384,13 +411,18 @@ def dispatch_func(arg, n):
# Comment this because remote_class supports `first``
# assert device_mesh.world_size == len(workers)
length = len(workers)
+ # Map actor index to global_rank: with gpus_per_worker>1, consecutive
+ # global ranks belong to the same actor (TP peers).
+ _mesh_world = device_mesh.world_size if device_mesh is not None else length
+ _rank_stride = max(1, _mesh_world // length)
def dispatch_func(arg, n):
import torch
if isinstance(arg, list) or isinstance(arg, torch.Tensor):
_args = []
for i in range(n):
- _args.append(arg[device_mesh.get_slice(len(arg), device_mesh.get_data_rank_from_global_rank(i))])
+ _args.append(arg[device_mesh.get_slice(
+ len(arg), device_mesh.get_data_rank_from_global_rank(i * _rank_stride))])
return _args
elif isinstance(arg, dict):
_args = [{} for _ in range(n)]
@@ -487,15 +519,19 @@ def decorator(cls):
@functools.wraps(init_method)
def new_init(self, *args, **kwargs):
+ _caller = _capture_caller()
_ctx = f'{cls.__name__}.__init__'
+ if _caller:
+ _ctx = f'{_ctx} <- {_caller}'
try:
_maybe_load_worker_notifier()
- _new_init_body(self, *args, **kwargs)
+ _new_init_body(self, _caller, *args, **kwargs)
except Exception as _e: # noqa: BLE001
+ _tag_exc(_e, _caller)
notify_exception(_notifier, _ctx, _e, _name)
raise
- def _new_init_body(self, *args, **kwargs):
+ def _new_init_body(self, _caller, *args, **kwargs):
if _mode == 'local':
# Get the actual device_mesh
device_mesh = _get_device_mesh_param(args, kwargs)
@@ -519,10 +555,11 @@ def _new_init_body(self, *args, **kwargs):
from ._ray import RayHelper
# In case the same class created twice in the same device group
- # Try to get the caller's line
- frame = inspect.currentframe().f_back
- caller_file = frame.f_code.co_filename.replace(os.sep, '_').replace('.', '_')
- caller_line = frame.f_lineno
+ # Try to get the caller's line (resolved in ``new_init`` so it points
+ # at user code, not at the wrapper itself).
+ _cf, _, _cl = (_caller or f'{__file__}:0').rpartition(':')
+ caller_file = _cf.replace(os.sep, '_').replace('.', '_')
+ caller_line = _cl
# Pass an instance_id is recommended
instance_id = kwargs.pop('instance_id', '') + f'{caller_file}_{caller_line}'
remote_group = kwargs.get('remote_group')
@@ -688,6 +725,10 @@ def decorator(func: Callable[..., T1]) -> Callable[..., T1]:
@functools.wraps(func)
def wrapper(self, *args, **kwargs) -> T1:
_ctx = f'{type(self).__name__}.{func.__name__}'
+ # Only capture caller on driver side; worker frames are Ray internals
+ _caller = _capture_caller() if hasattr(self, '_actors') else None
+ if _caller:
+ _ctx = f'{_ctx} <- {_caller}'
try:
device_mesh = getattr(self, 'device_mesh', None)
if _mode == 'local':
@@ -766,6 +807,7 @@ def _notifying_result_func(*rargs, **rkwargs):
try:
return _orig_result_func(*rargs, **rkwargs)
except Exception as _e: # noqa: BLE001
+ _tag_exc(_e, _caller)
notify_exception(_notifier, _ctx, _e, _name)
raise
@@ -779,6 +821,7 @@ def _notifying_result_func(*rargs, **rkwargs):
except StopIteration:
raise
except Exception as _e: # noqa: BLE001
+ _tag_exc(_e, _caller)
notify_exception(_notifier, _ctx, _e, _name)
raise
@@ -790,3 +833,27 @@ def _notifying_result_func(*rargs, **rkwargs):
return wrapper
return decorator
+
+
+async def _wrap_async_iter_with_notify(gen: AsyncIterator, ctx: str, caller: Optional[str] = None) -> AsyncIterator:
+ """Re-emit chunks from a local async generator and forward exceptions to the notifier."""
+ try:
+ async for chunk in gen:
+ yield chunk
+ except Exception as _e: # noqa: BLE001
+ _tag_exc(_e, caller)
+ notify_exception(_notifier, ctx, _e, _name)
+ raise
+
+
+async def _wrap_objrefgen_with_notify(ref_gen: Any, ctx: str, caller: Optional[str] = None) -> AsyncIterator:
+ """Drain a Ray ObjectRefGenerator chunk-by-chunk; forward exceptions to the notifier."""
+ import ray
+ try:
+ async for ref in ref_gen:
+ yield await ref
+ except Exception as _e: # noqa: BLE001
+ _tag_exc(_e, caller)
+ notify_exception(_notifier, ctx, _e, _name)
+ raise
+
diff --git a/src/twinkle/loss/__init__.py b/src/twinkle/loss/__init__.py
index 4e4d0e82b..8e1d0e2ad 100644
--- a/src/twinkle/loss/__init__.py
+++ b/src/twinkle/loss/__init__.py
@@ -5,6 +5,7 @@
from .dpo import CPOLoss, DPOLoss, ORPOLoss, SimPOLoss
from .gkd import GKDLoss
from .grpo import BNPOLoss, CISPOLoss, DRGRPOLoss, GRPOLoss, GSPOLoss, SAPOLoss
+from .infonce import InfonceLoss
from .mse import MSELoss
torch_loss_mapping = {
@@ -25,4 +26,6 @@
'simpo': SimPOLoss,
'cpo': CPOLoss,
'orpo': ORPOLoss,
+ # Embedding / contrastive losses
+ 'infonce': InfonceLoss,
}
diff --git a/src/twinkle/loss/base.py b/src/twinkle/loss/base.py
index 334d5eddc..5fd046ae7 100644
--- a/src/twinkle/loss/base.py
+++ b/src/twinkle/loss/base.py
@@ -6,6 +6,7 @@ class Loss:
require_logits = False
require_entropy = False
+ require_logps = True
def __call__(self, inputs: InputFeature, outputs: ModelOutput, **kwargs) -> LossOutput:
...
diff --git a/src/twinkle/loss/cross_entropy.py b/src/twinkle/loss/cross_entropy.py
index abcc9591d..c1b5225d6 100644
--- a/src/twinkle/loss/cross_entropy.py
+++ b/src/twinkle/loss/cross_entropy.py
@@ -4,37 +4,28 @@
class CrossEntropyLoss(Loss):
- """Calculate CE from logps"""
+ """Calculate CE from logps, with optional DFT (arxiv 2508.05629) entropy weighting."""
- def __init__(self, ignore_index: int = -100, reduction='mean', **kwargs):
+ def __init__(self, ignore_index: int = -100, reduction='mean', dft: bool = False, **kwargs):
super().__init__()
self.ignore_index = ignore_index
self.reduction = reduction
+ self.dft = dft
def __call__(self, inputs, outputs, **kwargs):
labels = inputs['labels']
logps = outputs.get('logps')
- logits = outputs.get('logits')
- if logps is not None:
- loss_mask = (labels != self.ignore_index).float()
- if self.reduction != 'sum':
- return LossOutput(
- loss=(-logps * loss_mask).sum() / loss_mask.sum().clamp(min=1),
- num_tokens=0,
- )
- else:
- return LossOutput(
- loss=(-logps * loss_mask).sum(),
- num_tokens=loss_mask.sum().clamp(min=1),
- )
- else:
- import torch
- assert logits is not None
- logits = logits.view(-1, logits.shape[-1])
+ if logps is None:
+ import torch.nn.functional as F
+ logits = outputs['logits'].view(-1, outputs['logits'].shape[-1])
labels = labels.view(-1)
- loss = torch.nn.CrossEntropyLoss(reduction=self.reduction)(logits, labels)
- if self.reduction != 'sum':
- return LossOutput(loss=loss, num_tokens=0)
- else:
- return LossOutput(loss=loss, num_tokens=(labels != self.ignore_index).sum())
+ logps = F.log_softmax(logits, dim=-1).gather(-1, labels.clamp(min=0).unsqueeze(-1)).squeeze(-1)
+
+ mask = (labels != self.ignore_index).float()
+ # DFT: -p·log(p) instead of -log(p)
+ per_token = -logps * logps.exp() if self.dft else -logps
+
+ if self.reduction != 'sum':
+ return LossOutput(loss=(per_token * mask).sum() / mask.sum().clamp(min=1), num_tokens=0)
+ return LossOutput(loss=(per_token * mask).sum(), num_tokens=mask.sum().clamp(min=1))
diff --git a/src/twinkle/loss/infonce.py b/src/twinkle/loss/infonce.py
new file mode 100644
index 000000000..0bbb97288
--- /dev/null
+++ b/src/twinkle/loss/infonce.py
@@ -0,0 +1,272 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+"""Embedding / contrastive losses for Twinkle.
+
+Inputs convention:
+ inputs['labels']: pair / multi-negative grouping labels (see each class docstring).
+ outputs['embeddings']: sentence embeddings produced by the model
+ (shape ``[B, D]``). Falls back to ``outputs['logits']`` for
+ backward-compatibility with the legacy hook-side pooling layout.
+
+All classes return :class:`LossOutput` with ``num_tokens=0`` (no per-token
+normalization, matching the convention used by ``DPOLoss``/``GRPOLoss``).
+"""
+import numpy as np
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+from enum import Enum
+from torch import nn
+from typing import Optional
+
+from twinkle.data_format import LossOutput
+from .base import Loss
+
+
+# Borrowed from sentence_transformers.
+class SiameseDistanceMetric(Enum):
+ """Distance metrics available to the pairwise contrastive losses."""
+
+ EUCLIDEAN = lambda x, y: F.pairwise_distance(x, y, p=2) # noqa
+ MANHATTAN = lambda x, y: F.pairwise_distance(x, y, p=1) # noqa
+ COSINE_DISTANCE = lambda x, y: 1 - F.cosine_similarity(x, y) # noqa
+
+
+def _extract_sentences(outputs) -> torch.Tensor:
+ """Return [B, D] sentence embeddings from postprocess_tensor_sp output.
+
+ Prefers the canonical ``embeddings`` key (post-pooling); falls back to
+ ``logits`` (legacy hook-side pooling) and applies CLS pooling for 3-D.
+ """
+ sentences = outputs.get('embeddings')
+ if sentences is None:
+ sentences = outputs['logits']
+ if sentences.dim() == 3:
+ sentences = sentences[:, 0]
+ return sentences
+
+
+def _parse_pair_sentence(outputs):
+ """Split an interleaved [s1_0, s2_0, s1_1, s2_1, ...] tensor into (s1, s2)."""
+ sentences = _extract_sentences(outputs)
+ return sentences[0::2], sentences[1::2]
+
+
+def _parse_multi_negative_sentences(sentences: torch.Tensor,
+ labels: torch.Tensor,
+ hard_negatives: Optional[int] = None):
+ """Split a flat embedding tensor into per-sample groups.
+
+ ``labels`` is a 1-D mask where ``1`` marks the start of a new
+ ``anchor(1)+positive(1)+negatives(n)`` group; the inserted offsets account for
+ the anchor sitting immediately before each positive in the flat layout.
+ """
+ split_indices = torch.nonzero(labels, as_tuple=False).squeeze().tolist()
+ if isinstance(split_indices, int):
+ split_indices = [split_indices]
+ split_indices.append(len(labels))
+ split_tensors = []
+ for i in range(len(split_indices) - 1):
+ start, end = split_indices[i], split_indices[i + 1]
+ split_part = sentences[start:end]
+ if hard_negatives is not None:
+ negatives = len(split_part) - 2
+ assert negatives > 0
+ if negatives > hard_negatives:
+ split_part = split_part[:hard_negatives + 2]
+ elif negatives < hard_negatives:
+ # upsample negatives with replacement; skip index 0 (positive)
+ selected = np.random.choice(list(range(negatives)), size=hard_negatives - negatives, replace=True) + 1
+ split_part = torch.cat((split_part, split_part[selected]), dim=0)
+ split_tensors.append(split_part)
+ return split_tensors
+
+
+class InfonceLoss(Loss):
+ """InfoNCE contrastive loss with optional cross-DP gathering.
+
+ Each sample is laid out as ``anchor(1) + positive(1) + negatives(n)``;
+ ``inputs['labels']`` is a 1-D mask where ``1`` marks the start of every
+ such group. Setting ``use_batch=True`` enables in-batch negatives and,
+ when distributed is initialized, gathers embeddings from all DP ranks
+ (only the local shard keeps gradients).
+
+ Args:
+ temperature: Logit scaling factor.
+ use_batch: Include cross-sample (and cross-rank) in-batch negatives.
+ hard_negatives: Fix the per-sample negative count via truncation/upsampling.
+ ``None`` keeps the original variable counts.
+ mask_fake_negative: Mask any logit greater than ``positive + fake_neg_margin``.
+ fake_neg_margin: Threshold offset above the positive logit when masking.
+ include_qq: Append the query-query similarity block (self diagonal masked).
+ include_dd: Append the positive-doc to all-docs block (self positive masked).
+ process_group: Distributed process group used for the all-gather.
+ When ``None``, the default group (``dist.group.WORLD``) is used.
+ """
+
+ require_logits = True
+ require_entropy = False
+ require_logps = False
+
+ def __init__(
+ self,
+ temperature: float = 0.1,
+ use_batch: bool = True,
+ hard_negatives: Optional[int] = None,
+ mask_fake_negative: bool = False,
+ fake_neg_margin: float = 0.1,
+ include_qq: bool = False,
+ include_dd: bool = False,
+ process_group=None,
+ **kwargs,
+ ):
+ self.temperature = temperature
+ self.use_batch = use_batch
+ self.hard_negatives = hard_negatives
+ self.mask_fake_negative = mask_fake_negative
+ self.fake_neg_margin = fake_neg_margin
+ self.include_qq = include_qq
+ self.include_dd = include_dd
+ self.process_group = process_group
+
+ def _gather_across_dp(self, sentences: torch.Tensor, labels: torch.Tensor):
+ """All-gather embeddings & labels across DP ranks; only local shard keeps grad."""
+ if not (dist.is_available() and dist.is_initialized()):
+ return sentences, labels
+ world_size = dist.get_world_size(group=self.process_group)
+ if world_size <= 1:
+ return sentences, labels
+ rank = dist.get_rank(group=self.process_group)
+
+ # variable per-rank shapes require communicating shape first
+ local_shape = sentences.new_tensor(sentences.shape, dtype=torch.long)
+ shapes = [torch.empty_like(local_shape) for _ in range(world_size)]
+ dist.all_gather(shapes, local_shape, group=self.process_group)
+ all_sentences = [sentences.new_empty(shape.tolist()) for shape in shapes]
+ dist.all_gather(all_sentences, sentences.contiguous(), group=self.process_group)
+
+ local_label_shape = labels.new_tensor(labels.shape, dtype=torch.long)
+ label_shapes = [torch.empty_like(local_label_shape) for _ in range(world_size)]
+ dist.all_gather(label_shapes, local_label_shape, group=self.process_group)
+ all_labels = [labels.new_empty(shape.tolist()) for shape in label_shapes]
+ dist.all_gather(all_labels, labels.contiguous(), group=self.process_group)
+
+ # keep the local shard differentiable; detach others
+ all_sentences[rank] = sentences
+ for idx in range(world_size):
+ if idx != rank:
+ all_sentences[idx] = all_sentences[idx].detach()
+ return torch.cat(all_sentences, dim=0), torch.cat(all_labels, dim=0)
+
+ def __call__(self, inputs, outputs, **kwargs) -> LossOutput:
+ labels = inputs['labels'].view(-1)
+ sentences = _extract_sentences(outputs)
+
+ if self.use_batch:
+ sentences, labels = self._gather_across_dp(sentences, labels)
+
+ split_tensors = _parse_multi_negative_sentences(sentences, labels, self.hard_negatives)
+ can_batched = self.hard_negatives is not None or len({s.shape[0] for s in split_tensors}) == 1
+
+ if not self.use_batch:
+ loss = self._intra_sample_loss(split_tensors, can_batched)
+ else:
+ loss = self._in_batch_loss(split_tensors, can_batched)
+ return LossOutput(loss=loss, num_tokens=0)
+
+ def _intra_sample_loss(self, split_tensors, can_batched) -> torch.Tensor:
+ """InfoNCE with only the per-sample negatives (no cross-sample sharing)."""
+ if can_batched:
+ sentences = torch.stack(split_tensors, dim=0) # [B, neg+2, D]
+ similarity_matrix = torch.matmul(sentences[:, 0:1], sentences[:, 1:].transpose(1, 2)) / self.temperature
+ labels = torch.zeros(len(split_tensors), dtype=torch.int64, device=sentences.device)
+ return nn.CrossEntropyLoss()(similarity_matrix.squeeze(1), labels)
+
+ loss = 0
+ for tensor in split_tensors:
+ similarity_matrix = torch.matmul(tensor[0], tensor[1:].T) / self.temperature
+ labels = torch.tensor(0, device=tensor.device)
+ loss = loss + nn.CrossEntropyLoss()(similarity_matrix, labels)
+ return loss / len(split_tensors)
+
+ def _in_batch_loss(self, split_tensors, can_batched) -> torch.Tensor:
+ """InfoNCE with cross-sample (and optionally cross-rank) negatives."""
+ if can_batched:
+ return self._in_batch_loss_batched(split_tensors)
+ return self._in_batch_loss_unbatched(split_tensors)
+
+ def _in_batch_loss_batched(self, split_tensors) -> torch.Tensor:
+ sentences = torch.stack(split_tensors, dim=0) # [B, neg+2, D]
+ queries = sentences[:, 0] # [B, D]
+ docs_all = sentences[:, 1:].reshape(-1, sentences.size(2)) # [B*(neg+1), D]
+ qd_matrix = torch.matmul(queries, docs_all.T) # [B, B*(neg+1)]
+ # each row's positive sits at column row_idx * (neg+1)
+ block = sentences.size(1) - 1
+ labels = torch.arange(0, sentences.size(0) * block, block, device=sentences.device)
+
+ logits_list = [qd_matrix]
+
+ if self.include_qq:
+ qq_matrix = torch.matmul(queries, queries.T).clone()
+ qq_matrix.fill_diagonal_(float('-inf'))
+ logits_list.append(qq_matrix)
+
+ if self.include_dd:
+ pos_docs = sentences[:, 1] # [B, D]
+ dd_matrix = torch.matmul(pos_docs, docs_all.T) # [B, B*(neg+1)]
+ if block > 0:
+ row_idx = torch.arange(dd_matrix.size(0), device=dd_matrix.device)
+ dd_matrix[row_idx, row_idx * block] = float('-inf')
+ logits_list.append(dd_matrix)
+
+ if self.mask_fake_negative:
+ row_idx = torch.arange(qd_matrix.size(0), device=qd_matrix.device)
+ thresholds = (qd_matrix[row_idx, labels].view(-1, 1).detach() + self.fake_neg_margin)
+
+ qd_block = qd_matrix.clone()
+ qd_block[qd_block > thresholds] = float('-inf')
+ components = [qd_block]
+ if self.include_qq:
+ qq_block = logits_list[1].clone()
+ qq_block[qq_block > thresholds] = float('-inf')
+ components.append(qq_block)
+ if self.include_dd:
+ # align with Qwen3-Embedding: no threshold masking on d-d block
+ components.append(logits_list[-1])
+ similarity_matrix = torch.cat(components, dim=1)
+ else:
+ similarity_matrix = torch.cat(logits_list, dim=1)
+
+ return nn.CrossEntropyLoss()(similarity_matrix / self.temperature, labels)
+
+ def _in_batch_loss_unbatched(self, split_tensors) -> torch.Tensor:
+ # docs from every sample concatenated as a shared negative bank
+ docs_bank = torch.cat([t[1:] for t in split_tensors], dim=0)
+ queries_all = torch.stack([t[0] for t in split_tensors], dim=0) if self.include_qq else None
+
+ loss = 0
+ length = 0
+ for idx, tensor in enumerate(split_tensors):
+ qd_vec = torch.matmul(tensor[0], docs_bank.T)
+ target = torch.tensor(length, device=tensor.device)
+ threshold = qd_vec[target].detach() + self.fake_neg_margin
+
+ qd_masked = torch.where(qd_vec > threshold, qd_vec.new_full(
+ (), float('-inf')), qd_vec) if self.mask_fake_negative else qd_vec
+ logits_parts = [qd_masked]
+
+ if self.include_qq:
+ qq_vec = torch.matmul(tensor[0], queries_all.T).clone()
+ qq_vec[idx] = float('-inf')
+ if self.mask_fake_negative:
+ qq_vec = torch.where(qq_vec > threshold, qq_vec.new_full((), float('-inf')), qq_vec)
+ logits_parts.append(qq_vec)
+
+ if self.include_dd:
+ dd_vec = torch.matmul(tensor[1], docs_bank.T)
+ dd_vec[length] = float('-inf')
+ logits_parts.append(dd_vec)
+
+ logits_row = torch.cat(logits_parts, dim=-1) / self.temperature
+ loss = loss + nn.CrossEntropyLoss()(logits_row.unsqueeze(0), target.unsqueeze(0))
+ length += tensor.size(0) - 1
+ return loss / len(split_tensors)
diff --git a/src/twinkle/loss_scale/base.py b/src/twinkle/loss_scale/base.py
deleted file mode 100644
index 5cef9f1d4..000000000
--- a/src/twinkle/loss_scale/base.py
+++ /dev/null
@@ -1,6 +0,0 @@
-# Copyright (c) ModelScope Contributors. All rights reserved.
-
-
-class LossScale:
-
- pass
diff --git a/src/twinkle/metric/__init__.py b/src/twinkle/metric/__init__.py
index ad244e1db..baeb6c1c9 100644
--- a/src/twinkle/metric/__init__.py
+++ b/src/twinkle/metric/__init__.py
@@ -3,6 +3,7 @@
from .base import Metric
from .completion_and_reward import CompletionRewardMetric
from .dpo import DPOMetric
+from .embedding import EmbeddingMetric
from .grpo import CISPOMetric, GRPOMetric, GSPOMetric
from .loss import LossMetric
from .train_metric import TrainMetric
diff --git a/src/twinkle/metric/embedding.py b/src/twinkle/metric/embedding.py
new file mode 100644
index 000000000..9fb3aed8c
--- /dev/null
+++ b/src/twinkle/metric/embedding.py
@@ -0,0 +1,107 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+from typing import List, Union
+
+from twinkle.data_format import InputFeature, ModelOutput
+from .base import Metric
+
+
+class EmbeddingMetric(Metric):
+ """Embedding similarity metric for InfoNCE training.
+
+ Reports anchor-positive cosine similarity stats (mean/min/max) and
+ average anchor-to-other-positives (in-batch negative) similarity.
+ Performs an extra all_gather to compute cross-rank statistics.
+ """
+
+ def __init__(self, device_mesh, process_group, **kwargs):
+ super().__init__(device_mesh, process_group, **kwargs)
+ self.reset()
+
+ def reset(self):
+ self.pos_sim_sum = 0.0
+ self.pos_sim_min = float('inf')
+ self.pos_sim_max = float('-inf')
+ self.pos_count = 0
+ self.neg_sim_sum = 0.0
+ self.neg_count = 0
+ self.total_loss = 0.0
+ self.total_count = 0
+ self.grad_norm = 0.0
+
+ def accumulate(self, inputs: Union[InputFeature, List[InputFeature]], outputs: ModelOutput, **kwargs):
+ sentences = outputs.get('embeddings')
+ if sentences is None:
+ sentences = outputs.get('logits')
+ if sentences is None:
+ return
+ if sentences.dim() == 3:
+ sentences = sentences[:, 0]
+
+ if not isinstance(inputs, list):
+ inputs = [inputs]
+ labels = torch.cat([inp['labels'].view(-1) for inp in inputs], dim=0)
+
+ # Gather embeddings and labels across DP for in-batch stats
+ if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
+ world_size = dist.get_world_size()
+ local_shape = sentences.new_tensor(sentences.shape, dtype=torch.long)
+ shapes = [torch.empty_like(local_shape) for _ in range(world_size)]
+ dist.all_gather(shapes, local_shape)
+ all_sentences = [sentences.new_empty(s.tolist()) for s in shapes]
+ dist.all_gather(all_sentences, sentences.contiguous())
+ sentences = torch.cat(all_sentences, dim=0)
+
+ local_lshape = labels.new_tensor(labels.shape, dtype=torch.long)
+ lshapes = [torch.empty_like(local_lshape) for _ in range(world_size)]
+ dist.all_gather(lshapes, local_lshape)
+ all_labels = [labels.new_empty(s.tolist()) for s in lshapes]
+ dist.all_gather(all_labels, labels.contiguous())
+ labels = torch.cat(all_labels, dim=0)
+
+ anchor_idx = torch.nonzero(labels, as_tuple=False).squeeze(-1)
+ if anchor_idx.numel() == 0:
+ return
+
+ anchors = sentences[anchor_idx]
+ positives = sentences[anchor_idx + 1]
+
+ # Anchor-positive cosine similarity
+ pos_cos = F.cosine_similarity(anchors, positives, dim=1)
+ self.pos_sim_sum += pos_cos.sum().item()
+ self.pos_sim_min = min(self.pos_sim_min, pos_cos.min().item())
+ self.pos_sim_max = max(self.pos_sim_max, pos_cos.max().item())
+ self.pos_count += pos_cos.numel()
+
+ # Anchor vs all other positives (in-batch negatives)
+ if anchors.size(0) > 1:
+ sim_matrix = torch.matmul(anchors, positives.T)
+ mask = ~torch.eye(sim_matrix.size(0), dtype=torch.bool, device=sim_matrix.device)
+ neg_sims = sim_matrix[mask]
+ self.neg_sim_sum += neg_sims.sum().item()
+ self.neg_count += neg_sims.numel()
+
+ loss = outputs.get('loss')
+ if loss is not None:
+ self.total_loss += loss.item() if hasattr(loss, 'item') else loss
+ self.total_count += 1
+ grad_norm = kwargs.get('grad_norm')
+ if grad_norm is not None:
+ self.grad_norm = grad_norm
+
+ def calculate(self):
+ results = {}
+ if self.pos_count > 0:
+ results['pos_sim'] = f'{self.pos_sim_sum / self.pos_count:.4f}'
+ results['pos_sim_min'] = f'{self.pos_sim_min:.4f}'
+ results['pos_sim_max'] = f'{self.pos_sim_max:.4f}'
+ if self.neg_count > 0:
+ results['neg_sim'] = f'{self.neg_sim_sum / self.neg_count:.4f}'
+ if self.total_count > 0:
+ results['loss'] = f'{self.total_loss / self.total_count:.4f}'
+ if self.grad_norm > 0:
+ results['grad_norm'] = f'{self.grad_norm:.6f}'
+ self.reset()
+ return results
diff --git a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py
index 96dace652..5a5426a68 100644
--- a/src/twinkle/model/megatron/megatron.py
+++ b/src/twinkle/model/megatron/megatron.py
@@ -1,5 +1,6 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
import asyncio
+import contextlib
import json
import logging
import numpy as np
@@ -31,7 +32,7 @@
from twinkle.metric import LossMetric, Metric, TrainMetric
from twinkle.model.base import TwinkleModel
from twinkle.model.optimizer_group import BaseOptimizerGroup, TrainStatus
-from twinkle.patch import Patch, apply_patch
+from twinkle.patch import Patch, apply_context, apply_patch
from twinkle.processor import InputProcessor
from twinkle.template import Template
from twinkle.utils import construct_class, get_logger, selective_log_softmax
@@ -41,6 +42,22 @@
logger = get_logger()
+def _resolve_task_context(model, task):
+ """Return a context manager that applies the right per-forward Patch for ``task``.
+
+ Mirrors the transformers backend: 'causal_lm' (default) is a no-op, while
+ 'embedding' installs :class:`MegatronEmbeddingPatch` which swaps the
+ ``output_layer`` for identity (with TP/SP gather) and registers a hook that
+ handles CP gather + last-token pooling, returning ``[n_seqs, hidden]``.
+ """
+ if task in (None, 'causal_lm'):
+ return contextlib.nullcontext()
+ if task == 'embedding':
+ from twinkle.patch.megatron_emb import MegatronEmbeddingPatch
+ return apply_context(model, MegatronEmbeddingPatch())
+ raise ValueError(f'Unknown task={task!r}; expected one of: causal_lm, embedding.')
+
+
@dataclass
class MegatronOptimizerGroup(BaseOptimizerGroup):
"""Optimizer group for Megatron training.
@@ -286,6 +303,7 @@ def forward_backward(self,
temperature = float(kwargs.pop('temperature', 1.0))
forward_only = kwargs.pop('forward_only', False)
return_logits = kwargs.pop('return_logits', False)
+ task = kwargs.pop('task', 'causal_lm')
optimizer_config = self.optimizer_group[adapter_name]
loss_instance = self.optimizer_group[adapter_name].loss_instance
if not inputs:
@@ -349,14 +367,17 @@ def forward_backward(self,
_mb_counter = [0] # mutable counter for closure
- def post_loss_function(output_tensor, inputs, logps, unpacked_logits=None, entropies=None):
+ def post_loss_function(output_tensor, inputs, logps, unpacked_logits=None, entropies=None, embeddings=None):
mb_idx = _mb_counter[0]
_mb_counter[0] += 1
current_kwargs = loss_extra_kwargs_per_mb[mb_idx % len(loss_extra_kwargs_per_mb)]
- logits = unpacked_logits if unpacked_logits is not None else output_tensor
- outputs = ModelOutput(logits=logits, logps=logps)
- if entropies is not None:
- outputs['entropies'] = entropies
+ if embeddings is not None:
+ outputs = ModelOutput(embeddings=embeddings)
+ else:
+ logits = unpacked_logits if unpacked_logits is not None else output_tensor
+ outputs = ModelOutput(logits=logits, logps=logps)
+ if entropies is not None:
+ outputs['entropies'] = entropies
result = loss_instance(inputs, outputs, **current_kwargs)
if unpacked_logits is not None:
outputs.pop('logits', None)
@@ -390,21 +411,29 @@ def forward_step_func(data_iterator, model):
logps = None
unpacked_logits = None
entropies = None
+ embeddings = None
_loss_instance = loss_instance
- if labels is not None and mpu.is_pipeline_last_stage(False, unwrapped_model.vp_stage):
- loss_mask = (labels != -100).bool()
- masked_labels = labels.clone()
- masked_labels[~loss_mask] = 0
- output_tensor.div_(temperature)
+ is_last_pp = mpu.is_pipeline_last_stage(False, unwrapped_model.vp_stage)
+ if task == 'embedding':
+ # MegatronEmbeddingPatch already pooled output to [n_seqs, hidden] on last PP stage.
+ if is_last_pp:
+ embeddings = output_tensor
+ elif labels is not None and is_last_pp:
+ _loss_require_logps = getattr(_loss_instance, 'require_logps', True)
_loss_require_entropy = (hasattr(_loss_instance, 'require_entropy') and _loss_instance.require_entropy)
- if _loss_require_entropy:
- logps, entropies = selective_log_softmax(output_tensor, masked_labels, return_entropy=True)
- else:
- logps = selective_log_softmax(output_tensor, masked_labels)
- # Reconstruct full-length tensors from CP-split shards
- logps = processor.postprocess_tensor_cp(logps)
- if entropies is not None:
- entropies = processor.postprocess_tensor_cp(entropies)
+ if _loss_require_logps:
+ loss_mask = (labels != -100).bool()
+ masked_labels = labels.clone()
+ masked_labels[~loss_mask] = 0
+ output_tensor.div_(temperature)
+ if _loss_require_entropy:
+ logps, entropies = selective_log_softmax(output_tensor, masked_labels, return_entropy=True)
+ else:
+ logps = selective_log_softmax(output_tensor, masked_labels)
+ # Reconstruct full-length tensors from CP-split shards
+ logps = processor.postprocess_tensor_cp(logps)
+ if entropies is not None:
+ entropies = processor.postprocess_tensor_cp(entropies)
batch['labels'] = processor.postprocess_tensor_cp(labels)
if 'position_ids' in batch:
pos = batch['position_ids']
@@ -427,6 +456,7 @@ def forward_step_func(data_iterator, model):
logps=logps,
unpacked_logits=unpacked_logits,
entropies=entropies,
+ embeddings=embeddings,
)
# Get Megatron's forward-backward function
@@ -446,15 +476,16 @@ def forward_step_func(data_iterator, model):
# Run forward-backward with Megatron's scheduler
# Megatron handles all communication internally using proper process groups
- losses = forward_backward_func(
- forward_step_func=forward_step_func,
- data_iterator=data_iter,
- model=self.model,
- num_microbatches=len(inputs),
- seq_length=seq_length,
- micro_batch_size=micro_batch_size,
- forward_only=forward_only,
- )
+ with _resolve_task_context(self.model, task):
+ losses = forward_backward_func(
+ forward_step_func=forward_step_func,
+ data_iterator=data_iter,
+ model=self.model,
+ num_microbatches=len(inputs),
+ seq_length=seq_length,
+ micro_batch_size=micro_batch_size,
+ forward_only=forward_only,
+ )
# Extract loss from results (only last PP stage returns non-empty)
loss = torch.tensor(0.0).to(Platform.get_local_device())
@@ -565,7 +596,7 @@ def step(self, **kwargs):
success, grad_norm, num_zeros = optimizer.step()
# Store grad_norm for later retrieval
- optimizer_config._last_grad_norm = grad_norm if grad_norm is not None else 0.0
+ optimizer_config._last_grad_norm = grad_norm.detach().cpu().item() if grad_norm is not None else 0.0
optimizer_config._last_step_success = success
def _is_model_ddp_wrapped(self) -> bool:
@@ -920,12 +951,11 @@ def load(self, name: str, output_dir: Optional[str] = None, **kwargs):
)
else:
bridge = self.strategy.bridge
- for _model in self.strategy.unwrap_model(self.model):
- bridge.load_weights(
- _model,
- checkpoint_dir,
- peft_format=(adapter_name != _default_adapter_name),
- )
+ bridge.load_weights(
+ self.strategy.unwrap_model(self.model),
+ checkpoint_dir,
+ peft_format=(adapter_name != _default_adapter_name),
+ )
if dist.is_initialized():
dist.barrier()
diff --git a/src/twinkle/model/transformers/strategy/sequence_parallel/__init__.py b/src/twinkle/model/transformers/strategy/sequence_parallel/__init__.py
index 8388c8e43..9915d2038 100644
--- a/src/twinkle/model/transformers/strategy/sequence_parallel/__init__.py
+++ b/src/twinkle/model/transformers/strategy/sequence_parallel/__init__.py
@@ -987,6 +987,21 @@ def _trim_gathered_sequence_padding(tensor: torch.Tensor, real_position_ids: tor
return torch.cat(pieces, dim=1).contiguous() if pieces else tensor[:, :0].contiguous()
return tensor[:, :real_position_ids.shape[-1]].contiguous()
+ def gather_features(self, features: torch.Tensor) -> torch.Tensor:
+ """All-gather SP-sharded per-token features ``[B, T_local, H]`` -> ``[B, T_real, H]``.
+
+ Mirrors the gather + trim path used for logps but operates directly on
+ hidden_states, so embedding pooling can run on the full sequence with
+ the same ``real_position_ids`` source of truth.
+ """
+ if features is None or not torch.is_tensor(features):
+ return features
+ if not self.enabled or self.ulysses_size <= 1:
+ return features
+ real_position_ids = sequence_parallel.real_position_ids
+ gathered, _ = GatherLoss.apply(features, None, 1, real_position_ids)
+ return self._trim_gathered_sequence_padding(gathered, real_position_ids)
+
def gather_loss_tensors(
self,
inputs: Dict[str, Any],
diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py
index 424d3298c..b153a2054 100644
--- a/src/twinkle/model/transformers/transformers.py
+++ b/src/twinkle/model/transformers/transformers.py
@@ -36,7 +36,7 @@
from twinkle.model.optimizer_group import BaseOptimizerGroup, TrainStatus
from twinkle.model.transformers.moe import apply_expert_parallel
from twinkle.model.transformers.strategy import AccelerateStrategy, NativeFSDPStrategy
-from twinkle.patch import Patch, apply_patch
+from twinkle.patch import Patch, apply_context, apply_patch
from twinkle.processor import InputProcessor
from twinkle.template import Template
from twinkle.utils import construct_class, get_logger, selective_log_softmax, torch_util
@@ -47,6 +47,22 @@
logger = get_logger()
+def _resolve_task_context(model, task):
+ """Return a context manager that applies the right per-forward Patch for ``task``.
+
+ 'causal_lm' (default) keeps the model untouched (returns ``nullcontext``).
+ 'embedding' swaps lm_head for identity + installs a feature-extraction hook so
+ downstream pooling can run inside
+ ``InputProcessor.postprocess_tensor_sp(task='embedding', ...)``.
+ """
+ if task in (None, 'causal_lm'):
+ return contextlib.nullcontext()
+ if task == 'embedding':
+ from twinkle.patch.transformers_emb import TransformersEmbeddingPatch
+ return apply_context(model, TransformersEmbeddingPatch())
+ raise ValueError(f'Unknown task={task!r}; expected one of: causal_lm, embedding.')
+
+
@dataclass
class OptimizerGroup(BaseOptimizerGroup):
"""Optimizer group for Transformers training."""
@@ -106,6 +122,8 @@ def accumulate_metrics(self, is_training):
self._ensure_dp_group()
status = self.train_status if is_training else self.eval_status
if len(status.metrics) > 0 and status.inputs is not None and status.outputs is not None:
+ forward_kwargs = copy(status.forward_kwargs)
+ forward_kwargs.pop('gradient_accumulation_steps', None)
for metric in status.metrics:
metric.accumulate(
status.inputs,
@@ -115,7 +133,7 @@ def accumulate_metrics(self, is_training):
gradient_accumulation_steps=self.gradient_accumulation_steps,
grad_norm=self._last_grad_norm,
loss_reduction=getattr(self.loss_instance, 'reduction', 'mean'),
- **status.forward_kwargs)
+ **forward_kwargs)
_default_adapter_name = ''
@@ -380,6 +398,7 @@ def forward(self, *, inputs: Union[InputFeature, List[InputFeature], List[Trajec
adapter_name = kwargs.pop('adapter_name', self._get_default_group())
temperature = float(kwargs.pop('temperature', 1.0))
return_logits = kwargs.pop('return_logits', False)
+ task = kwargs.pop('task', 'causal_lm')
optimizer_config = self.optimizer_group[adapter_name]
self._lazy_wrap_model()
if not inputs:
@@ -397,6 +416,7 @@ def forward(self, *, inputs: Union[InputFeature, List[InputFeature], List[Trajec
loss_instance = optimizer_config.loss_instance
loss_require_logits = (hasattr(loss_instance, 'require_logits') and loss_instance.require_logits)
loss_require_entropy = (hasattr(loss_instance, 'require_entropy') and loss_instance.require_entropy)
+ loss_require_logps = getattr(loss_instance, 'require_logps', True)
assert isinstance(processor, InputProcessor), 'Set a correct `InputProcessor` before forwarding'
inputs: Dict[str, Any] = processor(
inputs,
@@ -407,9 +427,10 @@ def forward(self, *, inputs: Union[InputFeature, List[InputFeature], List[Trajec
)
labels: torch.Tensor = inputs.pop('labels', None)
optimizer_config.accumulate_metrics(True)
- outputs = self.model(**inputs)
+ with _resolve_task_context(self.model, task):
+ outputs = self.model(**inputs)
inputs['labels'] = labels
- if labels is not None:
+ if labels is not None and loss_require_logps:
loss_mask = (labels != -100).bool()
masked_labels = labels.clone()
masked_labels[~loss_mask] = 0
@@ -424,8 +445,8 @@ def forward(self, *, inputs: Union[InputFeature, List[InputFeature], List[Trajec
outputs['past_key_values'] = None
if not (return_logits or loss_require_logits):
outputs['logits'] = None
- inputs, outputs = processor.postprocess_tensor_sp(inputs, outputs, sp_strategy=self.sp_strategy)
- inputs, outputs = processor.unpack_packed_sequences(inputs, outputs)
+ inputs, outputs = processor.postprocess_tensor_sp(inputs, outputs, sp_strategy=self.sp_strategy, task=task)
+ inputs, outputs = processor.unpack_packed_sequences(inputs, outputs, task=task)
optimizer_config.train_status.inputs = inputs
optimizer_config.train_status.outputs = outputs
optimizer_config.train_status.forward_kwargs = kwargs
@@ -451,6 +472,7 @@ def forward_only(self, *, inputs: Union[InputFeature, List[InputFeature], List[T
disable_lora = kwargs.pop('disable_lora', False)
temperature = float(kwargs.pop('temperature', 1.0))
return_logits = kwargs.pop('return_logits', False)
+ task = kwargs.pop('task', 'causal_lm')
optimizer_config = self.optimizer_group[adapter_name]
self._lazy_wrap_model()
if not inputs:
@@ -470,6 +492,7 @@ def forward_only(self, *, inputs: Union[InputFeature, List[InputFeature], List[T
loss_instance = optimizer_config.loss_instance
loss_require_logits = (hasattr(loss_instance, 'require_logits') and loss_instance.require_logits)
loss_require_entropy = (hasattr(loss_instance, 'require_entropy') and loss_instance.require_entropy)
+ loss_require_logps = getattr(loss_instance, 'require_logps', True)
inputs: Dict[str, Any] = processor(
inputs,
sp_strategy=self.sp_strategy,
@@ -480,13 +503,13 @@ def forward_only(self, *, inputs: Union[InputFeature, List[InputFeature], List[T
labels = inputs.pop('labels', None)
optimizer_config.accumulate_metrics(False)
unwrapped_model = self.strategy.unwrap_model(self.model)
- if disable_lora and isinstance(unwrapped_model, PeftModel):
- with unwrapped_model.disable_adapter():
- outputs = self.model(**inputs)
- else:
+ lora_ctx = (
+ unwrapped_model.disable_adapter()
+ if disable_lora and isinstance(unwrapped_model, PeftModel) else contextlib.nullcontext())
+ with _resolve_task_context(self.model, task), lora_ctx:
outputs = self.model(**inputs)
inputs['labels'] = labels
- if labels is not None:
+ if labels is not None and loss_require_logps:
loss_mask = (labels != -100).bool()
masked_labels = labels.clone()
masked_labels[~loss_mask] = 0
@@ -501,8 +524,8 @@ def forward_only(self, *, inputs: Union[InputFeature, List[InputFeature], List[T
outputs['past_key_values'] = None
if not (return_logits or loss_require_logits):
outputs['logits'] = None
- inputs, outputs = processor.postprocess_tensor_sp(inputs, outputs, sp_strategy=self.sp_strategy)
- inputs, outputs = processor.unpack_packed_sequences(inputs, outputs)
+ inputs, outputs = processor.postprocess_tensor_sp(inputs, outputs, sp_strategy=self.sp_strategy, task=task)
+ inputs, outputs = processor.unpack_packed_sequences(inputs, outputs, task=task)
optimizer_config.eval_status.inputs = inputs
optimizer_config.eval_status.outputs = outputs
optimizer_config.eval_status.forward_kwargs = kwargs
@@ -582,7 +605,7 @@ def backward(self, **kwargs):
scaler = optimizer_config.scaler
optimizer_config.cur_step += 1
- should_sync = optimizer_config.do_grad_sync()
+ should_sync = optimizer_config.do_grad_sync(kwargs.get('gradient_accumulation_steps'))
import contextlib
no_sync_ctx = contextlib.nullcontext()
diff --git a/src/twinkle/patch/__init__.py b/src/twinkle/patch/__init__.py
index 76d42eb94..da7a0165c 100644
--- a/src/twinkle/patch/__init__.py
+++ b/src/twinkle/patch/__init__.py
@@ -1,14 +1,30 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
import sys
+from contextlib import contextmanager
from typing import Any, Type, Union
from .base import Patch
+def _resolve(patch_cls: Union[Patch, Type[Patch], str]) -> Patch:
+ from twinkle.utils import construct_class
+ return construct_class(patch_cls, Patch, sys.modules[__name__])
+
+
def apply_patch(module: Any, patch_cls: Union[Patch, Type[Patch], str], *args, **kwargs):
- from ..utils import construct_class
- patch_ins = construct_class(patch_cls, Patch, sys.modules[__name__])
+ patch_ins = _resolve(patch_cls)
return patch_ins(module, *args, **kwargs)
-__all__ = ['apply_patch', 'Patch']
+@contextmanager
+def apply_context(module: Any, patch_cls: Union[Patch, Type[Patch], str], *args, **kwargs):
+ # Apply patch on enter; revert via subclass-implemented unpatch on exit (even on exception).
+ patch_ins = _resolve(patch_cls)
+ result = patch_ins(module, *args, **kwargs)
+ try:
+ yield result
+ finally:
+ patch_ins.unpatch(module, *args, **kwargs)
+
+
+__all__ = ['apply_patch', 'apply_context', 'Patch']
diff --git a/src/twinkle/patch/base.py b/src/twinkle/patch/base.py
index 08982ba92..3a9b8c079 100644
--- a/src/twinkle/patch/base.py
+++ b/src/twinkle/patch/base.py
@@ -9,3 +9,6 @@ class Patch:
def __call__(self, module: Union['torch.nn.Module', List['torch.nn.Module'], Any], *args, **kwargs):
...
+
+ def unpatch(self, module: Union['torch.nn.Module', List['torch.nn.Module'], Any], *args, **kwargs):
+ raise NotImplementedError()
diff --git a/src/twinkle/patch/megatron_emb.py b/src/twinkle/patch/megatron_emb.py
new file mode 100644
index 000000000..f9621509a
--- /dev/null
+++ b/src/twinkle/patch/megatron_emb.py
@@ -0,0 +1,137 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+"""Patch a Megatron causal LM into a sentence-embedding model.
+
+Two mutations applied to every pipeline-last-stage chunk (``post_process=True``):
+
+1. ``output_layer.forward`` (a ``ColumnParallelLinear``) is replaced with an
+ identity that returns ``(hidden_states, None)``. When ``sequence_parallel``
+ is enabled, the gather across the TP group that ``ColumnParallelLinear``
+ normally performs is mirrored, so the chunk's forward hook always sees a
+ full-length ``[s, b, h]`` tensor.
+2. A forward hook on the chunk gathers across CP (when ``cp_size > 1``),
+ pools the last valid token (per-segment via ``packed_seq_params.cu_seqlens_q``
+ for padding-free batches; per-row via ``position_ids`` for padded batches),
+ L2-normalises and returns ``[n_seqs, hidden]`` embeddings.
+
+Intermediate PP stages (``post_process=False``) are left untouched.
+
+Both mutations are reverted by ``unpatch``.
+"""
+import torch
+import torch.nn.functional as F
+from types import MethodType
+from typing import List, Optional
+
+from twinkle.patch import Patch
+from twinkle.utils.torch_utils import gather_cp_load_balanced
+
+
+def _last_valid_from_position_ids(position_ids: torch.Tensor) -> torch.Tensor:
+ if position_ids.dim() == 3:
+ position_ids = position_ids[0]
+ valid = (position_ids >= 0).int()
+ seq_len = valid.shape[-1]
+ return seq_len - 1 - torch.fliplr(valid).argmax(dim=-1)
+
+
+def _last_valid_from_attention_mask(attention_mask: torch.Tensor) -> torch.Tensor:
+ seq_len = attention_mask.shape[1]
+ return seq_len - 1 - torch.fliplr(attention_mask).argmax(dim=1)
+
+
+def _resolve_cp_group(module) -> Optional[object]:
+ cp_group = getattr(module, 'cp_group', None)
+ if cp_group is None:
+ pg = getattr(module, 'pg_collection', None)
+ cp_group = getattr(pg, 'cp', None) if pg is not None else None
+ return cp_group
+
+
+def _output_embedding_hook(module, args, kwargs, output):
+ if not torch.is_tensor(output) or output.dim() != 3:
+ return output
+
+ cp_group = _resolve_cp_group(module)
+ if cp_group is not None and cp_group.size() > 1:
+ output = gather_cp_load_balanced(output, cp_group, seq_dim=1)
+
+ packed_seq_params = kwargs.get('packed_seq_params', None)
+ if packed_seq_params is not None:
+ cu = getattr(packed_seq_params, 'cu_seqlens_q', None)
+ if cu is not None and cu.numel() >= 2:
+ # cu is full-seq based (built before CP split), so it indexes the gathered output directly.
+ last_idx = (cu[1:].long() - 1).to(output.device)
+ embeddings = output[0, last_idx]
+ return F.normalize(embeddings, p=2, dim=1).contiguous()
+
+ position_ids = kwargs.get('position_ids', None)
+ attention_mask = kwargs.get('attention_mask', None)
+ if position_ids is not None and cp_group is not None and cp_group.size() > 1:
+ position_ids = gather_cp_load_balanced(
+ position_ids if position_ids.dim() >= 2 else position_ids.unsqueeze(0),
+ cp_group,
+ seq_dim=1,
+ )
+
+ if position_ids is not None:
+ last_idx = _last_valid_from_position_ids(position_ids)
+ elif attention_mask is not None and attention_mask.dim() == 2:
+ last_idx = _last_valid_from_attention_mask(attention_mask)
+ else:
+ last_idx = torch.full((output.shape[0], ), output.shape[1] - 1, device=output.device, dtype=torch.long)
+
+ last_idx = last_idx.to(device=output.device, dtype=torch.long)
+ embeddings = output[torch.arange(output.shape[0], device=output.device), last_idx]
+ return F.normalize(embeddings, p=2, dim=1).contiguous()
+
+
+def _identity_output_layer(self, hidden_states, weight=None, runtime_gather_output=None, **kwargs):
+ # Mirror ColumnParallelLinear's seq-parallel gather so the hook sees full [s, b, h].
+ if getattr(self, 'sequence_parallel', False):
+ from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region
+ hidden_states = gather_from_sequence_parallel_region(
+ hidden_states, tensor_parallel_output_grad=True, group=self.tp_group)
+ return hidden_states, None
+
+
+def _iter_chunks(module) -> List[torch.nn.Module]:
+ if isinstance(module, (list, tuple)):
+ return [m for m in module if isinstance(m, torch.nn.Module)]
+ return [module]
+
+
+def _find_post_process_owner(chunk: torch.nn.Module) -> Optional[torch.nn.Module]:
+ """Locate the GPTModel-like owner of ``output_layer`` inside a chunk.
+
+ Walks all submodules so it transparently handles DDP/Float16Module/PeftModel wrappers.
+ """
+ for sub in chunk.modules():
+ layer = getattr(sub, 'output_layer', None)
+ post_process = getattr(sub, 'post_process', None)
+ if isinstance(layer, torch.nn.Module) and (post_process is None or post_process):
+ return sub
+ return None
+
+
+class MegatronEmbeddingPatch(Patch):
+ """Convert a Megatron causal LM into a sentence-embedding model. Reversible via ``unpatch``."""
+
+ def __call__(self, module, *args, **kwargs):
+ self._patched = []
+ for chunk in _iter_chunks(module):
+ owner = _find_post_process_owner(chunk)
+ if owner is None:
+ continue
+ output_layer = owner.output_layer
+ origin_forward = output_layer.forward
+ output_layer.forward = MethodType(_identity_output_layer, output_layer)
+ hook_handle = owner.register_forward_hook(_output_embedding_hook, with_kwargs=True)
+ self._patched.append((output_layer, origin_forward, hook_handle))
+ return module
+
+ def unpatch(self, module, *args, **kwargs):
+ for output_layer, origin_forward, hook_handle in self._patched:
+ hook_handle.remove()
+ output_layer.forward = origin_forward
+ self._patched = []
+ return module
diff --git a/src/twinkle/patch/no_split_modules.py b/src/twinkle/patch/no_split_modules.py
new file mode 100644
index 000000000..7d8aee58f
--- /dev/null
+++ b/src/twinkle/patch/no_split_modules.py
@@ -0,0 +1,21 @@
+from typing import Set, Union
+
+from twinkle.patch import Patch
+
+
+class NoSplitModulesPatch(Patch):
+ """Set _no_split_modules on a model so FSDP2 respects layer boundaries."""
+
+ def __init__(self, module_names: Union[Set[str], str] = frozenset({'Qwen3_5DecoderLayer'})):
+ if isinstance(module_names, str):
+ module_names = {module_names}
+ self._names = set(module_names)
+
+ def __call__(self, module, *args, **kwargs):
+ module._no_split_modules = self._names
+ return module
+
+ def unpatch(self, module, *args, **kwargs):
+ if hasattr(module, '_no_split_modules'):
+ del module._no_split_modules
+ return module
diff --git a/src/twinkle/patch/qwen3_chat_template.py b/src/twinkle/patch/qwen3_chat_template.py
index 822f8e8ea..85af99bd5 100644
--- a/src/twinkle/patch/qwen3_chat_template.py
+++ b/src/twinkle/patch/qwen3_chat_template.py
@@ -50,6 +50,14 @@
" {%- set content = _parts[1].lstrip('\\n') %}\n"
' {%- endif %}')
+_OLD_TAIL = ('{%- if ns.multi_step_tool %}\n'
+ " {{- raise_exception('No user query found in messages.') }}\n"
+ '{%- endif %}')
+
+_NEW_TAIL = ('{%- if ns.multi_step_tool %}\n'
+ ' {#- patched: tool-tail prefix allowed (Qwen3AllowToolTailTemplate) -#}\n'
+ '{%- endif %}')
+
class Qwen3ChatTemplate(Patch):
"""Patch tokenizer.chat_template in-place to fix Qwen3.x parse defects.
@@ -81,3 +89,34 @@ def __call__(self, tokenizer, *args, **kwargs):
return False
tokenizer.chat_template = tmpl.replace(_OLD, _NEW, 1)
return True
+
+
+class Qwen3AllowToolTailTemplate(Patch):
+ """Relax Qwen3.x ``multi_step_tool`` check so prefixes ending in ``tool``
+ (or whose only user messages are ```` wrappers) render
+ instead of raising ``No user query found in messages``.
+
+ Required by ScoreFilter when scoring intermediate assistant turns of
+ multi-turn agent rollouts: the slice ``messages[:asst_idx]`` legitimately
+ ends with a ``tool`` message, and skipping such rounds would silently
+ discard exactly the turns where tool-call accuracy lives.
+ """
+
+ def __call__(self, tokenizer, *args, **kwargs):
+ tmpl = getattr(tokenizer, 'chat_template', None)
+ if not tmpl or not isinstance(tmpl, str):
+ return False
+ if _NEW_TAIL in tmpl:
+ return False
+ if _OLD_TAIL not in tmpl:
+ warnings.warn(
+ 'Qwen3AllowToolTailTemplate patch: expected OLD multi_step_tool '
+ 'block not found in tokenizer.chat_template. Upstream template '
+ 'may have diverged; skipping patch. ScoreFilter on multi-turn '
+ 'agent prefixes will likely raise TemplateError.',
+ RuntimeWarning,
+ stacklevel=2,
+ )
+ return False
+ tokenizer.chat_template = tmpl.replace(_OLD_TAIL, _NEW_TAIL, 1)
+ return True
diff --git a/src/twinkle/patch/transformers_emb.py b/src/twinkle/patch/transformers_emb.py
new file mode 100644
index 000000000..f311c00af
--- /dev/null
+++ b/src/twinkle/patch/transformers_emb.py
@@ -0,0 +1,87 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+"""Patch a HF transformers causal LM into a sentence-embedding model.
+
+Two mutations applied to the model:
+
+1. ``lm_head.forward`` is replaced with identity, so the wrapped model returns
+ the final hidden states under ``output.logits``.
+2. A forward hook on the lm-head-bearing submodule L2-normalizes per-token
+ hidden states and stores them under ``outputs['features']`` (shape
+ ``[B, T, H]`` or ``[B, T_local, H]`` under SP).
+
+Last-token pooling (incl. padding-free, SP gather) is **deferred** to
+``InputProcessor.postprocess_tensor_sp(task='embedding', ...)`` so this patch
+stays SP/CP/packed-agnostic and the dispatch sits in one place.
+
+Both mutations are reverted by ``unpatch``.
+"""
+from types import MethodType
+from typing import TYPE_CHECKING, Optional
+
+from twinkle.patch import Patch
+
+if TYPE_CHECKING:
+ import torch
+
+_LM_HEADS = ['lm_head', 'output', 'embed_out', 'output_layer']
+
+
+def get_lm_head_model(module, lm_heads=None):
+ from peft import PeftModel
+ from torch.nn import Module
+ if isinstance(module, PeftModel):
+ module = module.model
+ if lm_heads is None:
+ lm_heads = _LM_HEADS
+ for sub in module.modules():
+ for name in lm_heads:
+ child = getattr(sub, name, None)
+ if isinstance(child, Module):
+ return sub
+ return module
+
+
+def _output_features_hook(module, args, kwargs, output):
+ import torch.nn.functional as F
+ hidden_states = output.logits
+ return {'features': F.normalize(hidden_states, p=2, dim=-1).contiguous()}
+
+
+def _identity_forward(self, hidden_states):
+ return hidden_states
+
+
+class TransformersEmbeddingPatch(Patch):
+ """Convert a causal LM into a sentence-embedding feature extractor. Reversible via ``unpatch``."""
+
+ def __call__(self, module, *args, **kwargs):
+ from torch.nn import Module
+ lm_head_model = get_lm_head_model(module, lm_heads=_LM_HEADS)
+
+ head: Optional[Module] = None
+ for name in _LM_HEADS:
+ if hasattr(lm_head_model, name):
+ head = getattr(lm_head_model, name)
+ break
+ assert head is not None, 'Cannot find the proper lm_head name'
+
+ # Save originals BEFORE mutation so unpatch can restore them verbatim.
+ self._head = head
+ self._origin_forward = head.forward
+ head.forward = MethodType(_identity_forward, head)
+ self._hook_handle = lm_head_model.register_forward_hook(_output_features_hook, with_kwargs=True)
+ return module
+
+ def unpatch(self, module, *args, **kwargs):
+ handle = getattr(self, '_hook_handle', None)
+ if handle is not None:
+ handle.remove()
+ self._hook_handle = None
+
+ head = getattr(self, '_head', None)
+ origin = getattr(self, '_origin_forward', None)
+ if head is not None and origin is not None:
+ head.forward = origin
+ self._origin_forward = None
+ self._head = None
+ return module
diff --git a/src/twinkle/preprocessor/base.py b/src/twinkle/preprocessor/base.py
index 06ad06baa..0225d3c1e 100644
--- a/src/twinkle/preprocessor/base.py
+++ b/src/twinkle/preprocessor/base.py
@@ -7,7 +7,9 @@
class Preprocessor:
@staticmethod
- def map_col_to_row(rows: Dict[str, List[Any]]) -> List[Dict[str, Any]]:
+ def map_col_to_row(rows) -> List[Dict[str, Any]]:
+ if isinstance(rows, list):
+ return rows
if not rows:
return []
_new_rows = []
@@ -20,12 +22,14 @@ def map_col_to_row(rows: Dict[str, List[Any]]) -> List[Dict[str, Any]]:
return _new_rows
@staticmethod
- def map_row_to_col(rows: List[Dict[str, Any]]) -> Dict[str, List[Any]]:
+ def map_row_to_col(rows, keys: List[str] = None) -> Dict[str, List[Any]]:
+ if isinstance(rows, dict):
+ return rows
if not rows:
- return {}
+ return {k: [] for k in keys} if keys else {}
columns: Dict[str, List[Any]] = {}
- keys = rows[0].keys()
+ keys = keys or rows[0].keys()
for key in keys:
columns[key] = [row[key] for row in rows]
diff --git a/src/twinkle/processor/base.py b/src/twinkle/processor/base.py
index f63833a3a..69942b23f 100644
--- a/src/twinkle/processor/base.py
+++ b/src/twinkle/processor/base.py
@@ -143,12 +143,98 @@ def postprocess_tensor_sp(self, inputs: Dict[str, Any], outputs: Dict[str, Any],
After this call, logps and labels are in per-sequence batch format
``[num_sequences, max_seq_len]`` when the input was packed, or left
unchanged for normal (non-packed) batches.
+
+ For ``task='embedding'`` this also performs the last-valid-token
+ pooling (with padding-free / SP gather awareness) and writes the
+ pooled ``[n_seqs, H]`` tensor to ``outputs['embeddings']``; the raw
+ per-token ``outputs['features']`` is consumed and removed.
"""
sp_strategy = kwargs.get('sp_strategy')
+ task = kwargs.get('task', 'causal_lm')
+ if task == 'embedding':
+ return self._postprocess_embedding(inputs, outputs, sp_strategy=sp_strategy)
if self.framework == 'transformers' and sp_strategy is not None:
return sp_strategy.gather_loss_tensors(inputs, outputs)
return inputs, outputs
+ @staticmethod
+ def _packed_last_indices(position_ids: torch.Tensor, total_len: int) -> torch.Tensor:
+ """For padding-free batches: per-segment last-token indices into a [1, total] sequence."""
+ flat = position_ids.squeeze(0) if position_ids.dim() == 2 else position_ids
+ starts = (flat == 0).nonzero(as_tuple=False).squeeze(-1)
+ end_anchor = torch.tensor([total_len], device=flat.device, dtype=starts.dtype)
+ boundaries = torch.cat([starts, end_anchor])
+ return (boundaries[1:] - 1).long()
+
+ def _postprocess_embedding(self,
+ inputs: Dict[str, Any],
+ outputs: Dict[str, Any],
+ sp_strategy=None) -> tuple[Dict[str, Any], Dict[str, Any]]:
+ """Pool per-token features to per-sequence embeddings (last-valid-token).
+
+ Build a one-hot end-token mask in the un-padded global frame, route it
+ through the same pad+split as ``input_ids`` so it aligns with local
+ features, pool locally, then ``all_reduce`` only the ``[n_seqs, H]``
+ tensor across SP × RP. No feature gather; uniform across
+ DP / Ulysses / zigzag-ring / padding-free.
+ """
+ import torch.distributed as dist
+ from copy import copy
+
+ features = outputs.get('features')
+ assert features is not None
+
+ sp_enabled = (
+ self.framework == 'transformers' and sp_strategy is not None and getattr(sp_strategy, 'enabled', False)
+ and getattr(sp_strategy, 'world_size', 1) > 1)
+
+ ref_pos = sp_strategy.real_position_ids if sp_enabled else inputs['position_ids']
+ if ref_pos.dim() == 3:
+ ref_pos = ref_pos[0]
+ cu_seq_lens_q = inputs.get('cu_seq_lens_q')
+
+ is_packed = (
+ features.shape[0] == 1 and (cu_seq_lens_q is not None or int((ref_pos.reshape(-1) == 0).sum()) > 1))
+
+ device, dtype = features.device, features.dtype
+ T_real = ref_pos.shape[-1]
+
+ if is_packed:
+ if torch.is_tensor(cu_seq_lens_q) and cu_seq_lens_q.numel() >= 2:
+ end_idx = (cu_seq_lens_q[1:].long() - 1).to(device)
+ else:
+ end_idx = self._packed_last_indices(ref_pos, T_real).to(device)
+ n_seqs = end_idx.shape[0]
+ mask = torch.zeros(1, T_real, n_seqs, dtype=dtype, device=device)
+ mask[0, end_idx, torch.arange(n_seqs, device=device)] = 1.0
+ else:
+ B = ref_pos.shape[0]
+ end_idx = (ref_pos >= 0).long().sum(-1) - 1
+ mask = torch.zeros(B, T_real, 1, dtype=dtype, device=device)
+ mask[torch.arange(B, device=device), end_idx, 0] = 1.0
+
+ if sp_enabled:
+ # Route mask through the same pad+split as input_ids to align with local features.
+ rp = sp_strategy.real_position_ids
+ rp_padded = sp_strategy.pad(rp, padding_value=-1, position_ids=rp, dim=-1)
+ mask = sp_strategy.pad(mask, padding_value=0, position_ids=rp, dim=1)
+ mask = sp_strategy.split(mask, dim=1, position_ids=rp_padded)
+
+ embeddings = (
+ torch.einsum('th,tn->nh', features.squeeze(0), mask.squeeze(0)) if is_packed else
+ (features * mask).sum(dim=1))
+
+ if sp_enabled and dist.is_available() and dist.is_initialized():
+ for grp_attr, size_attr in (('_sp_group', 'sp_world_size'), ('_rp_group', 'rp_world_size')):
+ grp = getattr(sp_strategy, grp_attr, None)
+ if grp is not None and getattr(sp_strategy, size_attr, 1) > 1:
+ dist.all_reduce(embeddings, op=dist.ReduceOp.SUM, group=grp)
+
+ outputs = copy(outputs)
+ outputs.pop('features', None)
+ outputs['embeddings'] = embeddings.contiguous()
+ return inputs, outputs
+
def pad_cp(self, inputs: List[InputFeature], **kwargs) -> List[InputFeature]:
if self.device_mesh is None:
@@ -262,8 +348,10 @@ def split_cp_inputs(inputs: torch.Tensor, cu_seqlens: Optional[torch.Tensor], di
return torch.cat(new_inputs, dim=dim)
if cp_size > 1:
- input_ids = split_cp_inputs(input_ids, cu_seqlens_q, dim=1)
- position_ids = split_cp_inputs(position_ids, cu_seqlens_q, dim=1)
+ if position_ids.shape[0] == 1:
+ # mm input_ids will do split inside of the mcore_bridge
+ input_ids = split_cp_inputs(input_ids, cu_seqlens_q, dim=1)
+ position_ids = split_cp_inputs(position_ids, cu_seqlens_q, dim=-1)
# attention_mask = split_cp_inputs(attention_mask, cu_seqlens_q, dim=1)
batch_labels = split_cp_inputs(batch_labels, cu_seqlens_q, dim=1)
@@ -469,6 +557,7 @@ def unpack_packed_sequences(
self,
inputs: Dict[str, Any],
outputs: Optional[Dict[str, Any]] = None,
+ task: str = 'causal_lm',
) -> tuple[Dict[str, Any], Optional[Dict[str, Any]]]:
"""Unpack packed (padding_free) sequences into per-sequence batch format.
@@ -476,7 +565,12 @@ def unpack_packed_sequences(
Unpacks ``labels`` and any present output keys (``logps``, ``logits``)
from ``[1, total_tokens, ...]`` to ``[num_sequences, max_seq_len, ...]``.
Keys that are ``None`` are silently skipped.
+
+ For ``task='embedding'`` the outputs are already pooled to ``[n_seqs, H]``
+ by ``postprocess_tensor_sp``, so this is a no-op.
"""
+ if task == 'embedding':
+ return inputs, outputs
labels = inputs.get('labels')
position_ids = inputs.get('position_ids')
@@ -654,46 +748,14 @@ def collate_fn(self,
return outputs
def postprocess_tensor_cp(self, tensor):
- """All-gather and reconstruct full sequence from CP-split tensor.
-
- Uses load-balanced split pattern: each CP rank holds chunks [rank] and
- [2*cp_size - rank - 1] from the original 2*cp_size chunks.
-
- Only the current rank's slice retains the original tensor (and its
- gradient graph); other ranks' slices are plain copies. This means
- backward through the reconstructed tensor only produces gradients for
- the local chunk, naturally distributing the gradient across CP ranks
- without extra scaling.
+ """All-gather and reconstruct full sequence from a CP load-balanced shard.
- Args:
- tensor: [batch_size, seq_len/cp_size] CP-split tensor
-
- Returns:
- [batch_size, full_seq_len] reconstructed full tensor
+ Thin wrapper over :func:`twinkle.utils.torch_utils.gather_cp_load_balanced`
+ that resolves the CP group via Megatron's ``parallel_state``.
"""
if self.device_mesh.cp_world_size <= 1:
return tensor
-
from megatron.core import parallel_state as mpu
- cp_size = mpu.get_context_parallel_world_size()
- cp_rank = mpu.get_context_parallel_rank()
- cp_group = mpu.get_context_parallel_group()
-
- gathered = [torch.empty_like(tensor) for _ in range(cp_size)]
- torch.distributed.all_gather(gathered, tensor.contiguous(), group=cp_group)
- gathered[cp_rank] = tensor
-
- batch_size = tensor.shape[0]
- seq_len_per_cp = tensor.shape[1]
- full_seq_len = seq_len_per_cp * cp_size
- chunk_len = full_seq_len // (2 * cp_size)
- half_len = seq_len_per_cp // 2
-
- output = tensor.new_zeros(batch_size, full_seq_len)
- for j in range(cp_size):
- o = gathered[j]
- output[:, j * chunk_len:(j + 1) * chunk_len] = o[:, :half_len]
- reverse_idx = 2 * cp_size - j - 1
- output[:, reverse_idx * chunk_len:(reverse_idx + 1) * chunk_len] = o[:, half_len:]
-
- return output
+
+ from twinkle.utils.torch_utils import gather_cp_load_balanced
+ return gather_cp_load_balanced(tensor, mpu.get_context_parallel_group(), seq_dim=1)
diff --git a/src/twinkle/sampler/base.py b/src/twinkle/sampler/base.py
index d8222ead1..e0c012a2d 100644
--- a/src/twinkle/sampler/base.py
+++ b/src/twinkle/sampler/base.py
@@ -1,7 +1,7 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
from abc import ABC, abstractmethod
from peft import PeftConfig
-from typing import Any, List, Optional, Type, Union
+from typing import Any, AsyncIterator, Dict, List, Optional, Type, Union
import twinkle
from twinkle import remote_function
@@ -47,6 +47,25 @@ def sample(
def apply_patch(self, patch_cls: Union[Patch, Type[Patch], str], **kwargs) -> None:
...
+ def astream_one(
+ self,
+ trajectory: Trajectory,
+ sampling_params: Optional[SamplingParams] = None,
+ adapter_name: str = '',
+ adapter_path: Optional[str] = None,
+ *,
+ use_base_model: bool = False,
+ ) -> AsyncIterator[Dict[str, Any]]:
+ """Stream OpenAI-shape delta chunks for a single trajectory.
+
+ Default implementation raises ``NotImplementedError``; backend samplers
+ opt in by overriding (e.g. ``vLLMSampler``).
+
+ Yields:
+ Dicts shaped ``{'index': int, 'delta': {...}, 'finish_reason': ...}``.
+ """
+ raise NotImplementedError(f'{type(self).__name__} does not support streaming')
+
@staticmethod
def _not_encoded(inputs: Any) -> bool:
"""Check if inputs are not yet encoded (i.e., is Trajectory, not InputFeature).
diff --git a/src/twinkle/sampler/vllm_sampler/vllm_engine.py b/src/twinkle/sampler/vllm_sampler/vllm_engine.py
index 4965f7b3d..29b1a73cc 100644
--- a/src/twinkle/sampler/vllm_sampler/vllm_engine.py
+++ b/src/twinkle/sampler/vllm_sampler/vllm_engine.py
@@ -337,6 +337,57 @@ async def sample(self,
topk_prompt_logprobs=result_topk_prompt_logprobs,
)
+ async def astream(self,
+ prompt: Union[List[int], str],
+ sampling_params: Union[SamplingParams, Dict[str, Any]],
+ lora_request: Optional[Any] = None,
+ request_id: Optional[str] = None,
+ priority: int = 0,
+ *,
+ multi_modal_data: Optional[Dict[str, Any]] = None,
+ mm_processor_kwargs: Optional[Dict[str, Any]] = None,
+ disable_lora: bool = False,
+ **kwargs):
+ """Streaming counterpart of :meth:`sample`. Yields raw vLLM ``RequestOutput``
+ deltas as they arrive from the engine — no aggregation.
+
+ Caller is responsible for diffing token_ids across frames.
+ """
+ from vllm.inputs import TextPrompt, TokensPrompt
+
+ if isinstance(sampling_params, dict):
+ sampling_params = SamplingParams.from_dict(sampling_params)
+ vllm_params = sampling_params.to_vllm(**kwargs)
+
+ if request_id is None:
+ request_id = uuid.uuid4().hex
+ if isinstance(prompt, str):
+ prompt = TextPrompt(prompt=prompt)
+ else:
+ prompt = TokensPrompt(prompt_token_ids=prompt)
+ if multi_modal_data:
+ prompt['multi_modal_data'] = multi_modal_data
+ if mm_processor_kwargs:
+ prompt['mm_processor_kwargs'] = mm_processor_kwargs
+
+ if lora_request is not None and not self.enable_lora:
+ logger.warning('lora_request provided but enable_lora is False — ignored')
+ lora_request = None
+ if disable_lora:
+ lora_request = None
+ elif lora_request is None and self._synced_lora_request is not None:
+ lora_request = self._synced_lora_request
+
+ generator = self.engine.generate(
+ prompt=prompt,
+ sampling_params=vllm_params,
+ request_id=request_id,
+ lora_request=lora_request,
+ priority=priority,
+ )
+ async for output in generator:
+ yield output
+
# -----------------------------------------------------------------
# RL-training synced LoRA helpers
# -----------------------------------------------------------------
diff --git a/src/twinkle/sampler/vllm_sampler/vllm_sampler.py b/src/twinkle/sampler/vllm_sampler/vllm_sampler.py
index 79db15db0..ace9fdd63 100644
--- a/src/twinkle/sampler/vllm_sampler/vllm_sampler.py
+++ b/src/twinkle/sampler/vllm_sampler/vllm_sampler.py
@@ -25,7 +25,7 @@
import os
import threading
from copy import copy
-from typing import Any, Dict, List, Optional, Type, Union
+from typing import Any, AsyncIterator, Dict, List, Optional, Type, Union
from twinkle import DeviceMesh, get_logger, remote_class, remote_function, requires
from twinkle.checkpoint_engine import CheckpointEngineMixin
@@ -251,6 +251,7 @@ async def _sample_single(
else:
feat['input_ids'] = response.prompt_token_ids
feat['labels'] = [-100] * len(response.prompt_token_ids)
+
if not logprobs_only:
# response.sequences contains num_samples sequences for this prompt
sequences = []
@@ -333,13 +334,12 @@ def sample(
sampling_params = copy(sampling_params)
sampling_params.max_tokens = 1
logprobs_only = True
- assert not is_trajectory, 'Logprobs only not supported for Trajectory inputs'
multi_modal_data_list = []
for feat in inputs_list:
multi_modal_data_list.append(self._extract_multi_modal_data(feat))
- if is_trajectory and not logprobs_only:
+ if is_trajectory:
template = self.template
assert template is not None, \
'Use set_template to add a template when trying to input Trajectory'
diff --git a/src/twinkle/server/sampler/twinkle_handlers.py b/src/twinkle/server/sampler/twinkle_handlers.py
index b27ec23d0..7e9d9af8e 100644
--- a/src/twinkle/server/sampler/twinkle_handlers.py
+++ b/src/twinkle/server/sampler/twinkle_handlers.py
@@ -6,9 +6,13 @@
"""
from __future__ import annotations
+import json
+import time
import traceback
+import uuid
from fastapi import Depends, FastAPI, HTTPException, Request
-from typing import TYPE_CHECKING, Callable
+from fastapi.responses import StreamingResponse
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple
from twinkle_client.common.serialize import deserialize_object
diff --git a/src/twinkle/template/__init__.py b/src/twinkle/template/__init__.py
index b1ab1d213..6c4bdddd2 100644
--- a/src/twinkle/template/__init__.py
+++ b/src/twinkle/template/__init__.py
@@ -1,5 +1,4 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
from .base import Template
from .deepseek_v4 import DeepseekV4Template
-from .qwen import QwenTemplate
from .qwen3_5_vl import Qwen3_5Template
diff --git a/src/twinkle/template/base.py b/src/twinkle/template/base.py
index 189961c90..bef50c682 100644
--- a/src/twinkle/template/base.py
+++ b/src/twinkle/template/base.py
@@ -1,15 +1,17 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
import inspect
+import json
import numpy as np
import os
from collections.abc import Mapping
from copy import copy, deepcopy
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Set, Union
from twinkle import remote_class
-from twinkle.data_format import InputFeature, Message, Trajectory
+from twinkle.data_format import InputFeature, Message, Trajectory, user_data_get
from twinkle.hub import HubOperation
from twinkle.utils import load_image, to_device
+from .tools import ToolCallRegistry, trailing_prefix_of
from .utils import TokenizeByRound, transfer_to_standard_message
if TYPE_CHECKING:
@@ -34,7 +36,7 @@ def __init__(self,
model_id: str,
use_chat_template: bool = True,
max_length: Optional[int] = 8192,
- truncation_strategy: Literal['raise', 'left', 'right', 'split'] = 'raise',
+ truncation_strategy: Literal['raise', 'left', 'right', 'split', 'delete'] = 'raise',
default_system: Optional[str] = None,
enable_thinking: bool = True,
**kwargs):
@@ -72,31 +74,138 @@ def __init__(self,
def parse_tool_call(self, decoded: str) -> List[Dict[str, Any]]:
"""Parse tool calls from the assistant's decoded output.
- Dispatches by model family on ``self.model_id``; the actual
- wire-format logic lives in :mod:`.tool_call_parser`.
+ Polls registered :class:`ToolCallParser` in order; first parser whose
+ ``detect`` matches takes ownership and produces the result. Other
+ parsers are not invoked on the same text — prevents nested re-extraction.
"""
- mid = (self.model_id or '').lower()
- if 'qwen' in mid:
- from .qwen import QwenTemplate
- return QwenTemplate.parse(self, decoded)
- if 'deepseek' in mid:
- from .deepseek_v4 import DeepseekV4Template
- return DeepseekV4Template.parse(self, decoded)
- # TODO: Other models (Llama3, OpenAI JSON, …) — add a parser in
- # ``tool_call_parser.py`` and extend this dispatch.
- return []
+ parser = ToolCallRegistry.detect_first(decoded or '')
+ return parser.parse(decoded) if parser else []
def clean_tool_call(self, decoded: str) -> str:
- """Strip family-specific tool-call markup from assistant text."""
- mid = (self.model_id or '').lower()
- if 'qwen' in mid:
- from .qwen import QwenTemplate
- return QwenTemplate.clean(self, decoded)
- if 'deepseek' in mid:
- from .deepseek_v4 import DeepseekV4Template
- return DeepseekV4Template.clean(self, decoded)
- # TODO: Other models
- return (decoded or '').rstrip()
+ """Strip tool-call markup using the same parser that ``parse_tool_call`` would pick."""
+ parser = ToolCallRegistry.detect_first(decoded or '')
+ return parser.clean(decoded) if parser else (decoded or '').rstrip()
+
+ def parse_tool_call_stream(
+ self,
+ state: Dict[str, Any],
+ new_text: str,
+ finished: bool = False,
+ ) -> List[Dict[str, Any]]:
+ """Convert incremental decoded text into OpenAI streaming ``delta`` parts.
+
+ Selects a parser once (cached on ``state``) by ``model_id``. If that
+ parser declares ``open_marker``/``close_marker`` (e.g. Hermes/Qwen),
+ runs the generic block-buffer state machine: holds back partial
+ markers, parses each closed block via ``parser.parse``, emits one
+ ``tool_calls`` delta per parsed call. Otherwise streams plain content.
+
+ Args:
+ state: Per-sequence opaque dict; caller allocates ``{}`` once.
+ new_text: Incremental decoded text since the previous call.
+ finished: True on the final call so partial buffers can flush.
+
+ Returns:
+ List of delta dicts; each carries at most one of ``content`` /
+ ``tool_calls``.
+ """
+ parser = state.get('parser')
+ if 'parser' not in state:
+ parser = ToolCallRegistry.select_for_model(self.model_id)
+ state['parser'] = parser
+ if parser is None or not parser.open_marker:
+ return [{'content': new_text}] if new_text else []
+ return self._stream_marker_blocks(state, new_text, finished, parser)
+
+ def _stream_marker_blocks(
+ self,
+ state: Dict[str, Any],
+ new_text: str,
+ finished: bool,
+ parser,
+ ) -> List[Dict[str, Any]]:
+ """Generic open/close marker streaming protocol.
+
+ Buffers partial markup until ``parser.close_marker`` arrives, then
+ parses the block via ``parser.parse``. Used by Hermes/Qwen and any
+ future block-style format (Mistral ``[TOOL_CALLS]``, etc.).
+ """
+ open_marker, close_marker = parser.open_marker, parser.close_marker
+ state.setdefault('pending', '')
+ state.setdefault('tc_count', 0)
+ if new_text:
+ state['pending'] += new_text
+
+ events: List[Dict[str, Any]] = []
+ while True:
+ buf = state['pending']
+ if not buf:
+ break
+ open_idx = buf.find(open_marker)
+ if open_idx == -1:
+ partial = 0 if finished else trailing_prefix_of(buf, open_marker)
+ emit = buf[:-partial] if partial else buf
+ state['pending'] = buf[-partial:] if partial else ''
+ if emit:
+ events.append({'content': emit})
+ break
+ if open_idx > 0:
+ events.append({'content': buf[:open_idx]})
+ state['pending'] = buf[open_idx:]
+ continue
+ close_idx = buf.find(close_marker)
+ if close_idx == -1:
+ if finished:
+ # EOF with unclosed block — let parser.parse handle the truncation.
+ try:
+ parsed = parser.parse(buf) or []
+ except Exception:
+ import logging
+ logging.getLogger(__name__).exception(
+ 'tool-call parse failed for unclosed streamed block; emitting as raw content')
+ events.append({'content': buf})
+ state['pending'] = ''
+ break
+ if parsed:
+ for tc in parsed:
+ events.append({'tool_calls': [self._format_tc_delta(state, tc)]})
+ else:
+ events.append({'content': buf})
+ state['pending'] = ''
+ break
+ block_end = close_idx + len(close_marker)
+ block = buf[:block_end]
+ try:
+ parsed = parser.parse(block) or []
+ except Exception:
+ logger.warn('tool-call parse failed for streamed block; emitting as raw content')
+ events.append({'content': block})
+ state['pending'] = buf[block_end:]
+ continue
+ for tc in parsed:
+ events.append({'tool_calls': [self._format_tc_delta(state, tc)]})
+ state['pending'] = buf[block_end:]
+ return events
+
+ @staticmethod
+ def _format_tc_delta(state: Dict[str, Any], tc: Dict[str, Any]) -> Dict[str, Any]:
+ """Format a parsed tool_call dict as an OpenAI streaming delta entry.
+
+ ``arguments`` is encoded as JSON string for the wire format (OpenAI
+ streaming spec); ``index`` and ``id`` are auto-assigned from ``state``.
+ """
+ fn = dict(tc.get('function') or {})
+ args = fn.get('arguments')
+ if isinstance(args, dict):
+ fn['arguments'] = json.dumps(args, ensure_ascii=False)
+ delta = {
+ 'index': state['tc_count'],
+ 'id': tc.get('id') or f'call_{state["tc_count"]}',
+ 'type': tc.get('type') or 'function',
+ 'function': fn,
+ }
+ state['tc_count'] += 1
+ return delta
@property
def tokenizer(self):
@@ -250,6 +359,10 @@ def _extract_reasoning_content(messages: list[Message]) -> List[Message]:
message['reasoning_content'] = reasoning_content
message['content'] = new_content
+ # Always emit string (never None/missing) — keeps PyArrow struct schema
+ # stable across shards; empty string renders identically to None in jinja.
+ if not isinstance(message.get('reasoning_content'), str):
+ message['reasoning_content'] = ''
result.append(message)
@@ -278,6 +391,8 @@ def _truncate_feature(self, feature: InputFeature, strategy: str) -> InputFeatur
result['labels'] = result['labels'][:self.max_length]
if 'mm_token_type_ids' in result:
result['mm_token_type_ids'] = result['mm_token_type_ids'][..., :self.max_length]
+ else:
+ raise ValueError(f'Unsupported truncation_strategy={strategy!r}.')
return InputFeature(**result)
def set_mm_position_ids(self, input_feature: InputFeature):
@@ -306,6 +421,12 @@ def _check_max_length(self, input_feature: InputFeature) -> List[InputFeature]:
results.append(InputFeature(**feat))
return results
+ # Drop oversized samples entirely; downstream must tolerate empty list (sample skipped).
+ if strategy == 'delete':
+ if len(input_feature['input_ids']) > self.max_length:
+ return []
+ return [input_feature]
+
# left/right/raise
return [self._truncate_feature(input_feature, strategy)]
@@ -491,7 +612,9 @@ def _build_standard_messages(self, trajectory: Trajectory) -> List[Trajectory]:
trajectory['messages'] = self._process_mm_messages(trajectory['messages'], images, videos, audios)
if not self.is_mm:
for message in trajectory['messages']:
- message['content'] = message['content'][0]['text']
+ c = message.get('content')
+ if isinstance(c, list):
+ message['content'] = c[0]['text'] if c else ''
return [trajectory]
def _apply_chat_template(self, trajectory: Trajectory, add_generation_prompt: bool = False, **kwargs):
@@ -506,6 +629,20 @@ def _apply_chat_template(self, trajectory: Trajectory, add_generation_prompt: bo
k: v
for k, v in b.items() if v is not None
} for b in msg['content'] if isinstance(b, dict)]
+ for msg in messages:
+ tcs = msg.get('tool_calls')
+ if isinstance(tcs, str):
+ tcs = json.loads(tcs) if tcs else []
+ msg['tool_calls'] = tcs
+ if not tcs:
+ continue
+ new_tcs = []
+ for tc in tcs:
+ fn = tc['function']
+ args = fn['arguments']
+ decoded = json.loads(args) if args.strip() else {}
+ new_tcs.append({**tc, 'function': {**fn, 'arguments': decoded}})
+ msg['tool_calls'] = new_tcs
# ``tool_calls`` / ``tools`` are already OpenAI-shaped (see
# :mod:`twinkle.data_format.message`); pass them through verbatim.
tools = list(trajectory.get('tools') or [])
@@ -561,10 +698,20 @@ def _apply_chat_template(self, trajectory: Trajectory, add_generation_prompt: bo
**kwargs)
return inputs
+ @staticmethod
+ def _get_train_indices(trajectory: Trajectory) -> Optional[Set[int]]:
+ """Extract key-round assistant indices from trajectory's packed ``user_data``."""
+ kr = user_data_get(trajectory.get('user_data'), 'key_rounds')
+ if isinstance(kr, list) and kr:
+ return set(kr)
+ return None
+
def _encode_messages(self, trajectory: Trajectory, add_generation_prompt: bool = False, **kwargs) -> InputFeature:
"""Encode a single trajectory's messages into InputFeature."""
labels = None
input_ids = None
+ # key-round selective training
+ train_indices = self._get_train_indices(trajectory) if not add_generation_prompt else None
if self.use_chat_template:
if add_generation_prompt:
# For inference: just get input_ids with generation prompt, no labels needed
@@ -574,6 +721,13 @@ def _encode_messages(self, trajectory: Trajectory, add_generation_prompt: bool =
if hasattr(input_ids, 'squeeze'):
input_ids = input_ids.squeeze(0)
labels = np.full_like(input_ids, -100) # No labels for inference
+ elif train_indices is not None:
+ # key-round-only: always use TokenizeByRound with filtered indices
+ if kwargs.get('tokenize', True):
+ input_ids, labels, encoded = TokenizeByRound.tokenize_with_assistant_labels(
+ self.tokenizer, self._apply_chat_template, trajectory, train_indices=train_indices, **kwargs)
+ else:
+ encoded = self._apply_chat_template(trajectory, **kwargs)
elif self._template_support_assistant_tokens_mask:
encoded = self._apply_chat_template(
trajectory, return_assistant_tokens_mask=kwargs.get('tokenize', True), **kwargs)
diff --git a/src/twinkle/template/deepseek_v4.py b/src/twinkle/template/deepseek_v4.py
index 3b17ce819..7a396f0dc 100644
--- a/src/twinkle/template/deepseek_v4.py
+++ b/src/twinkle/template/deepseek_v4.py
@@ -110,7 +110,7 @@ def __init__(
model_id: str,
use_chat_template: bool = True,
max_length: Optional[int] = 8192,
- truncation_strategy: Literal['raise', 'left', 'right', 'split'] = 'raise',
+ truncation_strategy: Literal['raise', 'left', 'right', 'split', 'delete'] = 'raise',
default_system: Optional[str] = None,
enable_thinking: bool = True,
**kwargs,
diff --git a/src/twinkle/template/qwen3_5_vl.py b/src/twinkle/template/qwen3_5_vl.py
index c8332f49b..b61265f5e 100644
--- a/src/twinkle/template/qwen3_5_vl.py
+++ b/src/twinkle/template/qwen3_5_vl.py
@@ -7,8 +7,7 @@
from twinkle import remote_class, requires
from twinkle.data_format import InputFeature
-from twinkle.template.base import ImageInput, VideoInput
-from twinkle.template.qwen import QwenTemplate
+from twinkle.template.base import ImageInput, Template, VideoInput
from twinkle.template.utils import get_inputs_embeds_hf
_ROPE_INDEX_CACHE: Dict[str, Callable] = {}
@@ -31,7 +30,7 @@ def _build_rope_index_func(config) -> Callable:
@remote_class()
-class Qwen3_5Template(QwenTemplate):
+class Qwen3_5Template(Template):
"""
Processor for Qwen VL series.
@@ -44,8 +43,15 @@ def __init__(self, *args, **kwargs):
# Fix upstream Qwen3 chat_template parse bugs (orphan handling).
# Deferred import to avoid cycles; idempotent across Ray actor re-init.
from twinkle.patch import apply_patch
- from twinkle.patch.qwen3_chat_template import Qwen3ChatTemplate
+ from twinkle.patch.qwen3_chat_template import Qwen3AllowToolTailTemplate, Qwen3ChatTemplate
apply_patch(self.tokenizer, Qwen3ChatTemplate)
+ # Allow ScoreFilter to render multi-turn agent prefixes ending in `tool`.
+ apply_patch(self.tokenizer, Qwen3AllowToolTailTemplate)
+ # Qwen3VLProcessor carries its own chat_template; _apply_chat_template
+ # routes through self.processor, so the patch must be applied there too.
+ if self.processor is not self.tokenizer:
+ apply_patch(self.processor, Qwen3ChatTemplate)
+ apply_patch(self.processor, Qwen3AllowToolTailTemplate)
self._patch_size: Optional[int] = None
self._merge_size: Optional[int] = None
self._init_vision_config()
diff --git a/src/twinkle/template/tools/__init__.py b/src/twinkle/template/tools/__init__.py
new file mode 100644
index 000000000..243774bdf
--- /dev/null
+++ b/src/twinkle/template/tools/__init__.py
@@ -0,0 +1,29 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+"""Tool-call parser registry.
+
+Importing this package auto-registers every parser. Order matters:
+narrower / stronger formats first so round-robin detection prefers them
+over weaker fallbacks.
+"""
+from .base import ToolCallParser, ToolCallRegistry, trailing_prefix_of
+from .cline import ClineParser
+from .qwen import HermesQwenParser
+from .react import ReActParser
+from .vcp import VCPParser
+
+# Order: strongest/most-specific markers first. Hermes owns ````
+# (also denied by Cline), so its detection wins for shared-XML inputs.
+ToolCallRegistry.register(HermesQwenParser())
+ToolCallRegistry.register(ClineParser())
+ToolCallRegistry.register(VCPParser())
+ToolCallRegistry.register(ReActParser())
+
+__all__ = [
+ 'ToolCallParser',
+ 'ToolCallRegistry',
+ 'trailing_prefix_of',
+ 'HermesQwenParser',
+ 'ClineParser',
+ 'VCPParser',
+ 'ReActParser',
+]
diff --git a/src/twinkle/template/tools/base.py b/src/twinkle/template/tools/base.py
new file mode 100644
index 000000000..9f0855fca
--- /dev/null
+++ b/src/twinkle/template/tools/base.py
@@ -0,0 +1,87 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+from abc import ABC, abstractmethod
+from typing import Any, Dict, List, Optional
+
+
+class ToolCallParser(ABC):
+ """Single-format tool-call parser."""
+
+ name: str = ''
+ open_marker: Optional[str] = None
+ close_marker: Optional[str] = None
+
+ def matches_model(self, model_id: str) -> bool:
+ """Return True if this parser is the canonical choice for ``model_id``.
+
+ Used for streaming where we must commit to a parser before any text
+ has arrived. Default False — parser is text-detection-only.
+ """
+ return False
+
+ @abstractmethod
+ def detect(self, text: str) -> bool:
+ """Cheap pre-check: does ``text`` carry this format's markup?"""
+
+ @abstractmethod
+ def parse(self, text: str) -> List[Dict[str, Any]]:
+ """Return OpenAI-shape tool_calls. ``arguments`` is a dict (jinja-friendly)."""
+
+ @abstractmethod
+ def clean(self, text: str) -> str:
+ """Strip parser-specific markup; return plain content text."""
+
+ def detect_result(self, text: str) -> bool:
+ """Does ``text`` look like a tool-result message for this protocol?"""
+ return False
+
+ def parse_result(self, text: str) -> str:
+ """Strip protocol-specific result prefix; return the raw tool output body."""
+ return text
+
+
+class ToolCallRegistry:
+ """Global ordered registry of :class:`ToolCallParser` instances."""
+
+ _parsers: List[ToolCallParser] = []
+
+ @classmethod
+ def register(cls, parser: ToolCallParser) -> ToolCallParser:
+ for p in cls._parsers:
+ if p.name == parser.name:
+ return p
+ cls._parsers.append(parser)
+ return parser
+
+ @classmethod
+ def parsers(cls) -> List[ToolCallParser]:
+ return list(cls._parsers)
+
+ @classmethod
+ def select_for_model(cls, model_id: Optional[str]) -> Optional[ToolCallParser]:
+ mid = (model_id or '').lower()
+ for p in cls._parsers:
+ if p.matches_model(mid):
+ return p
+ return None
+
+ @classmethod
+ def detect_first(cls, text: str) -> Optional[ToolCallParser]:
+ if not text:
+ return None
+ for p in cls._parsers:
+ if p.detect(text):
+ return p
+ return None
+
+
+def trailing_prefix_of(buf: str, marker: str) -> int:
+ """Length of trailing chars of ``buf`` that form a strict prefix of ``marker``.
+
+ Used by streaming protocols to hold back the tail when it could be the
+ start of an upcoming open tag, preventing mid-marker splits.
+ """
+ upper = min(len(marker) - 1, len(buf))
+ for k in range(upper, 0, -1):
+ if buf.endswith(marker[:k]):
+ return k
+ return 0
diff --git a/src/twinkle/template/tools/cline.py b/src/twinkle/template/tools/cline.py
new file mode 100644
index 000000000..3f14694e6
--- /dev/null
+++ b/src/twinkle/template/tools/cline.py
@@ -0,0 +1,164 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+"""Cline / OpenClaw text-embedded XML tool-call format.
+
+Wire format (Layer-B agent app protocol — lives in plain ``content``,
+not in the OpenAI ``tool_calls`` field):
+
+ src/foo.py
+
+ ls -la
+ false
+
+
+Detection is **structural** (no hardcoded tool-name whitelist):
+
+* outer tag is snake_case ``[a-z][a-z0-9_]*`` and not in :data:`_DENY`
+* outer block contains at least one nested ``VAL`` child
+
+Streaming: ``open_marker``/``close_marker`` are ``None`` because the
+outer tag varies per call. The base ``parse_tool_call_stream`` therefore
+falls back to plain content passthrough; recognised blocks are extracted
+only on full-text :meth:`parse` (e.g. by ``AgentTraceFilter`` after
+trajectory assembly).
+"""
+from __future__ import annotations
+
+import re
+from typing import Any, Dict, List
+
+from .base import ToolCallParser
+
+# Common HTML-like / template tags that are NOT Cline tool calls. Outer
+# tags falling here are skipped to prevent false positives.
+_DENY = frozenset({
+ # twinkle-internal / model-internal markers
+ 'think',
+ 'answer',
+ 'tool_call',
+ 'tool_response',
+ 'function',
+ 'parameter',
+ 'parameters',
+ 'tools',
+ 'tool',
+ 'system',
+ 'user',
+ 'assistant',
+ 'message',
+ 'messages',
+ 'content',
+ 'response',
+ 'output',
+ 'role',
+ 'reasoning_content',
+ # html / markdown
+ 'p',
+ 'a',
+ 'b',
+ 'i',
+ 'em',
+ 'strong',
+ 'div',
+ 'span',
+ 'pre',
+ 'code',
+ 'br',
+ 'hr',
+ 'ul',
+ 'ol',
+ 'li',
+ 'h1',
+ 'h2',
+ 'h3',
+ 'h4',
+ 'h5',
+ 'h6',
+ 'table',
+ 'tr',
+ 'td',
+ 'th',
+ 'tbody',
+ 'thead',
+ 'img',
+ 'video',
+ 'audio',
+})
+
+# Outer tool-call block: matched-pair via backreference. Body is non-greedy.
+_BLOCK_RE = re.compile(r'<(?P[a-z][a-z0-9_]*)>(?P[\s\S]*?)(?P=tool)>')
+# Inner parameter: matched-pair via backreference.
+_PARAM_RE = re.compile(r'<(?P[a-z][a-z0-9_]*)>(?P[\s\S]*?)(?P=key)>')
+
+# Cline tool-result: [tool_name for 'path/args'] Result:
+_RESULT_RE = re.compile(
+ r'^\[(?P[a-z][a-z0-9_]*)\s+for\s+\'[^\']*\'\]\s*Result:\s*',
+ re.DOTALL,
+)
+
+
+class ClineParser(ToolCallParser):
+ name = 'cline'
+ # Outer tag varies per tool — no fixed marker; streaming uses passthrough.
+ open_marker = None
+ close_marker = None
+
+ def matches_model(self, model_id: str) -> bool:
+ # Cline is an app-level prompt protocol, not bound to any model family.
+ return False
+
+ def detect(self, text: str) -> bool:
+ if not text or '<' not in text:
+ return False
+ for m in _BLOCK_RE.finditer(text):
+ if m.group('tool') in _DENY:
+ continue
+ if _PARAM_RE.search(m.group('body')):
+ return True
+ return False
+
+ def parse(self, text: str) -> list[dict[str, Any]]:
+ calls: list[dict[str, Any]] = []
+ for m in _BLOCK_RE.finditer(text or ''):
+ tool = m.group('tool')
+ if tool in _DENY:
+ continue
+ args: dict[str, Any] = {}
+ for pm in _PARAM_RE.finditer(m.group('body')):
+ args[pm.group('key')] = pm.group('val').strip()
+ if not args:
+ continue
+ calls.append({
+ 'type': 'function',
+ 'function': {
+ 'name': tool,
+ 'arguments': args
+ },
+ })
+ return calls
+
+ def clean(self, text: str) -> str:
+ if not text:
+ return text or ''
+ spans: list[tuple] = []
+ for m in _BLOCK_RE.finditer(text):
+ if m.group('tool') in _DENY:
+ continue
+ if not _PARAM_RE.search(m.group('body')):
+ continue
+ spans.append((m.start(), m.end()))
+ if not spans:
+ return text.rstrip()
+ out: list[str] = []
+ last = 0
+ for s, e in spans:
+ out.append(text[last:s])
+ last = e
+ out.append(text[last:])
+ return ''.join(out).rstrip()
+
+ def detect_result(self, text: str) -> bool:
+ return bool(_RESULT_RE.match(text or ''))
+
+ def parse_result(self, text: str) -> str:
+ m = _RESULT_RE.match(text or '')
+ return text[m.end():] if m else text
diff --git a/src/twinkle/template/qwen.py b/src/twinkle/template/tools/qwen.py
similarity index 61%
rename from src/twinkle/template/qwen.py
rename to src/twinkle/template/tools/qwen.py
index 4c68ab3a8..12361b737 100644
--- a/src/twinkle/template/qwen.py
+++ b/src/twinkle/template/tools/qwen.py
@@ -3,21 +3,28 @@
import re
from typing import Any, Dict, List
-from twinkle import remote_class
-from twinkle.template import Template
+from .base import ToolCallParser
-@remote_class()
-class QwenTemplate(Template):
+class HermesQwenParser(ToolCallParser):
+ name = 'hermes_qwen'
+ open_marker = ''
+ close_marker = ''
_BLOCK_RE = re.compile(r'\s*([\s\S]*?)\s*(?:|\Z)')
_FUNCTION_RE = re.compile(r']+)>([\s\S]*?)')
_PARAMETER_RE = re.compile(r']+)>\s*([\s\S]*?)\s*')
_STRIP_RE = re.compile(r'[\s\S]*?(?:|\Z)')
- def parse(self, decoded: str) -> List[Dict[str, Any]]:
+ def matches_model(self, model_id: str) -> bool:
+ return 'qwen' in model_id
+
+ def detect(self, text: str) -> bool:
+ return self.open_marker in text
+
+ def parse(self, text: str) -> List[Dict[str, Any]]:
calls: List[Dict[str, Any]] = []
- for block_m in self._BLOCK_RE.finditer(decoded or ''):
+ for block_m in self._BLOCK_RE.finditer(text or ''):
block = block_m.group(1)
func_m = self._FUNCTION_RE.search(block)
if func_m:
@@ -37,7 +44,6 @@ def parse(self, decoded: str) -> List[Dict[str, Any]]:
},
})
continue
- # JSON fallback: ``{"name": ..., "arguments": ...}`` inside the block.
try:
data = json.loads(block)
except json.JSONDecodeError:
@@ -60,26 +66,5 @@ def parse(self, decoded: str) -> List[Dict[str, Any]]:
})
return calls
- def clean(self, decoded: str) -> str:
- return self._STRIP_RE.sub('', decoded or '').rstrip()
-
- def parse_tool_call(self, decoded: str) -> List[Dict[str, Any]]:
- """Parse tool calls from the assistant's decoded output.
-
- Dispatches by model family on ``self.model_id``; the actual
- wire-format logic lives in :mod:`.tool_call_parser`.
- """
- mid = (self.model_id or '').lower()
- if 'qwen' in mid:
- return self.parse(decoded)
- # TODO: Other models (Llama3, OpenAI JSON, …) — add a parser in
- # ``tool_call_parser.py`` and extend this dispatch.
- return []
-
- def clean_tool_call(self, decoded: str) -> str:
- """Strip family-specific tool-call markup from assistant text."""
- mid = (self.model_id or '').lower()
- if 'qwen' in mid:
- return self.clean(decoded)
- # TODO: Other models
- return (decoded or '').rstrip()
+ def clean(self, text: str) -> str:
+ return self._STRIP_RE.sub('', text or '').rstrip()
diff --git a/src/twinkle/template/tools/react.py b/src/twinkle/template/tools/react.py
new file mode 100644
index 000000000..16e369cd6
--- /dev/null
+++ b/src/twinkle/template/tools/react.py
@@ -0,0 +1,34 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+import re
+from typing import Any, Dict, List
+
+from .base import ToolCallParser
+
+_ACTION_RE = re.compile(
+ r'^\s*Action\s*:\s*(?P[\w\-./]+)\s*\[(?P.*?)\]\s*$',
+ re.MULTILINE,
+)
+
+
+class ReActParser(ToolCallParser):
+ name = 'react'
+
+ def detect(self, text: str) -> bool:
+ return bool(_ACTION_RE.search(text or ''))
+
+ def parse(self, text: str) -> List[Dict[str, Any]]:
+ calls: List[Dict[str, Any]] = []
+ for m in _ACTION_RE.finditer(text or ''):
+ calls.append({
+ 'type': 'function',
+ 'function': {
+ 'name': m.group('name'),
+ 'arguments': {
+ 'input': m.group('args')
+ },
+ },
+ })
+ return calls
+
+ def clean(self, text: str) -> str:
+ return _ACTION_RE.sub('', text or '').rstrip()
diff --git a/src/twinkle/template/tools/vcp.py b/src/twinkle/template/tools/vcp.py
new file mode 100644
index 000000000..5e030f9d5
--- /dev/null
+++ b/src/twinkle/template/tools/vcp.py
@@ -0,0 +1,65 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+import re
+from typing import Any, Dict, List
+
+from .base import ToolCallParser
+
+_VCP_OPEN = '<<<[TOOL_REQUEST]>>>'
+_VCP_CLOSE = '<<<[END_TOOL_REQUEST]>>>'
+
+_VCP_BLOCK_RE = re.compile(
+ r'<<<\[TOOL_REQUEST\]>>>(.*?)<<<\[END_TOOL_REQUEST\]>>>',
+ re.DOTALL,
+)
+
+# `「始ESCAPE」...「末ESCAPE」` is the nesting-safe variant; pair them strictly
+# so an escaped value is not closed by a bare `「末」` from an inner block.
+_VCP_KV_RE = re.compile(
+ r'(?P[A-Za-z_]\w*)\s*:\s*'
+ r'(?:「始ESCAPE」(?P.*?)「末ESCAPE」'
+ r'|「始」(?P.*?)「末」)',
+ re.DOTALL,
+)
+
+
+class VCPParser(ToolCallParser):
+ """VCPChat / VCPSystem custom tool-call format.
+
+ Outer markers ``<<<[TOOL_REQUEST]>>> ... <<<[END_TOOL_REQUEST]>>>`` wrap
+ one call; parameters use full-width brackets ``「始」value「末」`` (escape
+ variant ``「始ESCAPE」...「末ESCAPE」`` permits nested outer markers).
+ The canonical function name lives in the ``tool_name`` field.
+ """
+
+ name = 'vcp'
+ open_marker = _VCP_OPEN
+ close_marker = _VCP_CLOSE
+
+ def detect(self, text: str) -> bool:
+ return _VCP_OPEN in (text or '')
+
+ def parse(self, text: str) -> List[Dict[str, Any]]:
+ calls: List[Dict[str, Any]] = []
+ for block in _VCP_BLOCK_RE.findall(text or ''):
+ args: Dict[str, Any] = {}
+ name = ''
+ for m in _VCP_KV_RE.finditer(block):
+ k = m.group('key')
+ v = m.group('val_esc') if m.group('val_esc') is not None else m.group('val')
+ if k == 'tool_name':
+ name = (v or '').strip()
+ else:
+ args[k] = v
+ if not name:
+ continue
+ calls.append({
+ 'type': 'function',
+ 'function': {
+ 'name': name,
+ 'arguments': args,
+ },
+ })
+ return calls
+
+ def clean(self, text: str) -> str:
+ return _VCP_BLOCK_RE.sub('', text or '').rstrip()
diff --git a/src/twinkle/template/utils.py b/src/twinkle/template/utils.py
index 72975d78b..5648a4770 100644
--- a/src/twinkle/template/utils.py
+++ b/src/twinkle/template/utils.py
@@ -193,7 +193,10 @@ class TokenizeByRound:
"""
@staticmethod
- def tokenize_with_assistant_labels(tokenizer: 'PreTrainedTokenizer', encode_func: Callable, trajectory: Trajectory,
+ def tokenize_with_assistant_labels(tokenizer: 'PreTrainedTokenizer',
+ encode_func: Callable,
+ trajectory: Trajectory,
+ train_indices: Optional[set] = None,
**kwargs) -> Tuple[List[int], List[int], Dict[str, Any]]:
"""Tokenize trajectory and generate labels for assistant turns.
@@ -201,6 +204,8 @@ def tokenize_with_assistant_labels(tokenizer: 'PreTrainedTokenizer', encode_func
tokenizer: The tokenizer (unused, kept for interface compatibility).
encode_func: Function to encode a trajectory. Must support add_generation_prompt.
trajectory: The trajectory containing messages.
+ train_indices: If provided, only label assistant messages whose
+ message index is in this set. ``None`` means label all.
Returns:
Tuple of (input_ids, labels, extra_encoded_fields).
@@ -225,6 +230,8 @@ def tokenize_with_assistant_labels(tokenizer: 'PreTrainedTokenizer', encode_func
for i, msg in enumerate(messages):
if msg['role'] != 'assistant':
continue
+ if train_indices is not None and i not in train_indices:
+ continue
# Get position AFTER assistant prefix:
# encode(messages[:i], add_generation_prompt=True) includes the prefix
diff --git a/src/twinkle/utils/parallel.py b/src/twinkle/utils/parallel.py
index ba3b63e3a..a235c136c 100644
--- a/src/twinkle/utils/parallel.py
+++ b/src/twinkle/utils/parallel.py
@@ -87,6 +87,55 @@ def _try_create_claim(path: str, session: str, payload: str) -> bool:
return True
+class PosixFileLock:
+ """POSIX advisory file lock with persistent fd for repeated acquire/release.
+
+ Fork-safe: reopens its fd lazily when used from a child process so each
+ worker owns its own descriptor.
+ """
+
+ def __init__(self, path: str):
+ import fcntl
+ self._path = path
+ self._fcntl = fcntl
+ self._fd = open(path, 'w')
+ self._pid = os.getpid()
+
+ def _ensure_fd(self):
+ # After fork, child must reopen so it doesn't share parent's fd state.
+ pid = os.getpid()
+ if pid != self._pid:
+ self._fd = open(self._path, 'w')
+ self._pid = pid
+
+ def acquire(self):
+ self._ensure_fd()
+ self._fcntl.flock(self._fd, self._fcntl.LOCK_EX)
+
+ def release(self):
+ self._fcntl.flock(self._fd, self._fcntl.LOCK_UN)
+
+ def close(self):
+ self._fd.close()
+
+ def __enter__(self):
+ self.acquire()
+ return self
+
+ def __exit__(self, *exc):
+ self.release()
+
+ def __getstate__(self):
+ return {'_path': self._path}
+
+ def __setstate__(self, state):
+ import fcntl
+ self._path = state['_path']
+ self._fcntl = fcntl
+ self._fd = open(self._path, 'w')
+ self._pid = os.getpid()
+
+
@contextmanager
def processing_lock(lock_file: str):
"""A file lock to prevent parallel operations to one file.
diff --git a/src/twinkle/utils/torch_utils.py b/src/twinkle/utils/torch_utils.py
index deb788dbb..a2aa8ad9d 100644
--- a/src/twinkle/utils/torch_utils.py
+++ b/src/twinkle/utils/torch_utils.py
@@ -268,6 +268,45 @@ def pad_and_stack_tensors(tensors: List['torch.Tensor'], pad_value: float = -200
return torch.stack(padded_tensors, dim=0)
+def gather_cp_load_balanced(tensor: 'torch.Tensor', cp_group, seq_dim: int = 1) -> 'torch.Tensor':
+ """All-gather a CP-load-balanced shard along ``seq_dim`` into the full sequence.
+
+ Inverse of :func:`split_cp_inputs`: each CP rank ``r`` holds chunks ``[r, 2*cp - r - 1]``
+ of the original ``2*cp`` sequence chunks. The local rank's slice keeps autograd;
+ other ranks' slices are detached copies, so backward through the gathered tensor
+ only produces gradients for the local chunk.
+ """
+ import torch
+ cp_size = cp_group.size()
+ if cp_size <= 1:
+ return tensor
+ cp_rank = torch.distributed.get_rank(group=cp_group)
+ gathered = [torch.empty_like(tensor) for _ in range(cp_size)]
+ torch.distributed.all_gather(gathered, tensor.contiguous(), group=cp_group)
+ gathered[cp_rank] = tensor
+ seq_local = tensor.shape[seq_dim]
+ half_len = seq_local // 2
+ full_seq = seq_local * cp_size
+ chunk_len = full_seq // (2 * cp_size)
+ out_shape = list(tensor.shape)
+ out_shape[seq_dim] = full_seq
+ output = tensor.new_zeros(*out_shape)
+ for j in range(cp_size):
+ o = gathered[j]
+ front = [slice(None)] * tensor.ndim
+ front[seq_dim] = slice(j * chunk_len, (j + 1) * chunk_len)
+ rev = 2 * cp_size - j - 1
+ back = [slice(None)] * tensor.ndim
+ back[seq_dim] = slice(rev * chunk_len, (rev + 1) * chunk_len)
+ local_front = [slice(None)] * tensor.ndim
+ local_front[seq_dim] = slice(0, half_len)
+ local_back = [slice(None)] * tensor.ndim
+ local_back[seq_dim] = slice(half_len, seq_local)
+ output[tuple(front)] = o[tuple(local_front)]
+ output[tuple(back)] = o[tuple(local_back)]
+ return output
+
+
def split_cp_inputs(inputs: 'torch.Tensor', cu_seqlens: Optional['torch.Tensor'], dim: int):
import torch
from megatron.core import mpu
diff --git a/src/twinkle_agentic/data_format/__init__.py b/src/twinkle_agentic/data_format/__init__.py
index 6298015c8..008c3ed7c 100644
--- a/src/twinkle_agentic/data_format/__init__.py
+++ b/src/twinkle_agentic/data_format/__init__.py
@@ -1 +1,2 @@
from .chunks import Chunk, Chunks
+from .score import RoundContext, Scorer, ScoreResult
diff --git a/src/twinkle_agentic/data_format/score.py b/src/twinkle_agentic/data_format/score.py
new file mode 100644
index 000000000..205509400
--- /dev/null
+++ b/src/twinkle_agentic/data_format/score.py
@@ -0,0 +1,35 @@
+from dataclasses import dataclass, field
+from typing import Any, Dict, List, Optional, Protocol
+
+
+@dataclass
+class RoundContext:
+ """Per-round payload passed to scorers."""
+ row_idx: int
+ rnd_idx: int
+ asst_idx: int
+ row: Dict[str, Any]
+ intent: Optional[str]
+ messages: List[Dict[str, Any]]
+ context_messages: List[Dict[str, Any]]
+ cond_ids: List[int]
+ n_prompt: int
+ asst_ids: List[int]
+ asst_text: str
+ user_prompt: str
+ features: Dict[str, Any] = field(default_factory=dict)
+
+
+@dataclass
+class ScoreResult:
+ score: Optional[float] = None
+ passed: bool = True
+ extras: Dict[str, Any] = field(default_factory=dict)
+
+
+class Scorer(Protocol):
+ name: str
+ requires_logprobs: bool
+
+ def score(self, contexts: List[RoundContext]) -> List[ScoreResult]:
+ ...
diff --git a/src/twinkle_agentic/preprocessor/__init__.py b/src/twinkle_agentic/preprocessor/__init__.py
new file mode 100644
index 000000000..b8391f9ea
--- /dev/null
+++ b/src/twinkle_agentic/preprocessor/__init__.py
@@ -0,0 +1,69 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+import json
+import time
+from typing import Any, Callable, Dict, List, Optional
+import os
+from twinkle.preprocessor import Preprocessor
+from twinkle.utils import get_logger
+from twinkle.utils.parallel import PosixFileLock
+from .data_juicer import FixUnicodeFilter, RemoveRepeatSentencesFilter, SpecialCharsFilter, TokenNumFilter
+from .dead_loop_filter import DeadLoopFilter
+from .dedup_filter import DedupFilter
+from .hard_filter import HardFilter
+from .intent_classifier import IntentClassifier
+from .llm_backend import LLMBackend, OpenAIBackend, SamplerBackend # noqa: F401
+from .message_normalizer import MessageNormalizer # noqa: F401
+from .message_sanity import MessageSanityFilter
+from .model_filter import ModelFilter
+from .pii_presidio_filter import PIIPresidioFilter
+from .refuse_filter import RefuseFilter
+from .score_filter import ScoreFilter
+from .token_soup import TokenSoupFilter
+
+logger = get_logger(only_local_master=False)
+
+
+class QualityPreprocessor(Preprocessor):
+ """Thin pipeline runner: accepts a list of callables, runs them in order.
+
+ Each step must accept and return List[Dict[str, Any]].
+ Per-step logging (before/after count) and optional dropped-row JSONL are provided.
+ """
+
+ def __init__(self, pipeline: List[Callable], dropped_log_path: str = ''):
+ super().__init__()
+ self._pipelines = list(pipeline)
+ self._dropped_log_path = dropped_log_path
+ if dropped_log_path:
+ os.makedirs(os.path.dirname(os.path.abspath(dropped_log_path)), exist_ok=True)
+ self._lock: Optional[PosixFileLock] = (PosixFileLock(dropped_log_path + '.lock') if dropped_log_path else None)
+ if dropped_log_path and os.path.exists(dropped_log_path):
+ os.remove(dropped_log_path)
+
+ def __call__(self, rows):
+ rows_list = self.map_col_to_row(rows)
+ total_start = len(rows_list)
+ stats = []
+ for step in self._pipelines:
+ if not rows_list:
+ break
+ step_name = getattr(step, '__name__', None) or type(step).__name__
+ before = len(rows_list)
+ t0 = time.perf_counter()
+ kept, dropped = step(rows_list)
+ rows_list = self.map_col_to_row(kept)
+ elapsed = time.perf_counter() - t0
+ after = len(rows_list)
+ stats.append(f' {step_name}: {before}->{after} (dropped {before - after}, {elapsed:.3f}s)')
+ self._log_dropped(step_name, dropped)
+ summary = '\n'.join(stats)
+ logger.info(f'[QualityPreprocessor] {total_start} -> {len(rows_list)}\n{summary}')
+ return self.map_row_to_col(rows_list)
+
+ def _log_dropped(self, step_name: str, dropped: List[Dict[str, Any]]) -> None:
+ if not self._lock or not dropped:
+ return
+ with self._lock:
+ with open(self._dropped_log_path, 'a', encoding='utf-8') as f:
+ for r in dropped:
+ f.write(json.dumps({'step': step_name, 'row': r}, ensure_ascii=False, default=str) + '\n')
diff --git a/src/twinkle_agentic/preprocessor/data_juicer.py b/src/twinkle_agentic/preprocessor/data_juicer.py
new file mode 100644
index 000000000..d019f12b0
--- /dev/null
+++ b/src/twinkle_agentic/preprocessor/data_juicer.py
@@ -0,0 +1,157 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+"""Data-Juicer integration for trajectory quality filtering.
+
+Each class is a standalone Preprocessor with __call__ interface; they share a
+module-level op cache for model/tokenizer reuse.
+"""
+from typing import Any, Dict, List, Tuple
+
+from twinkle.preprocessor import Preprocessor
+
+from .utils import msg_content_text
+
+# ── Shared helpers ────────────────────────────────────────────────────────────
+
+_OP_CACHE: Dict = {}
+
+
+def _get_op(op_class, **kwargs):
+ key = (op_class, repr(tuple(sorted(kwargs.items()))))
+ if key not in _OP_CACHE:
+ _OP_CACHE[key] = op_class(**kwargs)
+ return _OP_CACHE[key]
+
+
+def _get_tokenizer(hf_tokenizer: str):
+ key = ('_tokenizer', hf_tokenizer)
+ if key not in _OP_CACHE:
+ from modelscope import AutoTokenizer
+ _OP_CACHE[key] = AutoTokenizer.from_pretrained(hf_tokenizer, trust_remote_code=True)
+ return _OP_CACHE[key]
+
+
+def _get_text(row: Dict[str, Any], role: str = 'assistant') -> str:
+ """Concatenate text-projected content of all turns matching `role`."""
+ return ' '.join(
+ msg_content_text(msg) for msg in (row.get('messages') or [])
+ if isinstance(msg, dict) and msg.get('role') == role)
+
+
+
+def _keep_mask(op, texts: List[str]) -> List[bool]:
+ """Run a DJ Filter op directly; no dataset/multiprocessing overhead."""
+ from data_juicer.utils.constant import Fields
+ samples = {op.text_key: texts, Fields.stats: [{} for _ in texts], Fields.meta: [{} for _ in texts]}
+ samples = op.compute_stats_batched(samples)
+ return list(op.process_batched(samples))
+
+
+def _apply_mapper(op, rows: List[Dict[str, Any]], role: str) -> None:
+ """Run a DJ Mapper on string-content messages of `role`. Non-string content is preserved verbatim."""
+ indices: List[Tuple[int, int]] = []
+ texts: List[str] = []
+ for ri, row in enumerate(rows):
+ for mi, msg in enumerate(row.get('messages') or []):
+ if not isinstance(msg, dict) or msg.get('role') != role:
+ continue
+ content = msg.get('content')
+ # Skip multimodal/None content — mapper only mutates plain string turns.
+ if not isinstance(content, str):
+ continue
+ indices.append((ri, mi))
+ texts.append(content)
+ if not texts:
+ return
+ result = op.process_batched({op.text_key: texts})
+ for (ri, mi), new_text in zip(indices, result[op.text_key]):
+ rows[ri]['messages'][mi]['content'] = new_text
+
+
+# ── Wrapper classes ───────────────────────────────────────────────────────────
+
+
+class FixUnicodeFilter(Preprocessor):
+
+ def __init__(self, normalization: str = 'NFC', role: str = 'assistant'):
+ super().__init__()
+ self._normalization = normalization
+ self._role = role
+
+ def __call__(self, rows) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
+ rows = self.map_col_to_row(rows)
+ from data_juicer.ops.mapper import FixUnicodeMapper
+ _apply_mapper(_get_op(FixUnicodeMapper, normalization=self._normalization), rows, self._role)
+ return rows, []
+
+
+class RemoveRepeatSentencesFilter(Preprocessor):
+
+ def __init__(self, lowercase: bool = False, ignore_special_character: bool = True, role: str = 'assistant'):
+ super().__init__()
+ self._lowercase = lowercase
+ self._ignore = ignore_special_character
+ self._role = role
+
+ def __call__(self, rows) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
+ rows = self.map_col_to_row(rows)
+ from data_juicer.ops.mapper import RemoveRepeatSentencesMapper
+ op = _get_op(RemoveRepeatSentencesMapper, lowercase=self._lowercase, ignore_special_character=self._ignore)
+ _apply_mapper(op, rows, self._role)
+ return rows, []
+
+
+class SpecialCharsFilter(Preprocessor):
+
+ def __init__(self, max_ratio: float = 0.25, role: str = 'assistant'):
+ super().__init__()
+ self._max_ratio = max_ratio
+ self._role = role
+
+ def __call__(self, rows) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
+ rows = self.map_col_to_row(rows)
+ from data_juicer.ops.filter import SpecialCharactersFilter
+ op = _get_op(SpecialCharactersFilter, min_ratio=0.0, max_ratio=self._max_ratio)
+ texts = [_get_text(r, self._role) for r in rows]
+ # Filter only non-empty text; empty-text rows (e.g. tool-only assistants) are kept verbatim.
+ non_empty = [i for i, t in enumerate(texts) if t.strip()]
+ keep = [True] * len(rows)
+ if non_empty:
+ sub = _keep_mask(op, [texts[i] for i in non_empty])
+ for i, m in zip(non_empty, sub):
+ keep[i] = bool(m)
+ out: List[Dict[str, Any]] = []
+ dropped: List[Dict[str, Any]] = []
+ for r, k in zip(rows, keep):
+ if k:
+ out.append(r)
+ else:
+ dropped.append(dict(r, drop_reason='special_chars_ratio'))
+ return out, dropped
+
+
+class TokenNumFilter(Preprocessor):
+
+ def __init__(self,
+ hf_tokenizer: str = 'Qwen/Qwen2.5-0.5B',
+ min_num: int = 10,
+ max_num: int = 8192,
+ role: str = 'assistant'):
+ super().__init__()
+ self._hf_tokenizer = hf_tokenizer
+ self._min_num = min_num
+ self._max_num = max_num
+ self._role = role
+
+ def __call__(self, rows) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
+ rows = self.map_col_to_row(rows)
+ tokenizer = _get_tokenizer(self._hf_tokenizer)
+ texts = [_get_text(r, self._role) for r in rows]
+ encoded = tokenizer(texts, add_special_tokens=False)
+ out: List[Dict[str, Any]] = []
+ dropped: List[Dict[str, Any]] = []
+ for r, ids in zip(rows, encoded['input_ids']):
+ if self._min_num <= len(ids) <= self._max_num:
+ out.append(r)
+ else:
+ dropped.append(dict(r, drop_reason='token_count_out_of_range'))
+ return out, dropped
diff --git a/src/twinkle_agentic/preprocessor/dead_loop_filter.py b/src/twinkle_agentic/preprocessor/dead_loop_filter.py
new file mode 100644
index 000000000..555904c0f
--- /dev/null
+++ b/src/twinkle_agentic/preprocessor/dead_loop_filter.py
@@ -0,0 +1,208 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+"""Drop assistant messages that exhibit hesitation / dead-loop patterns."""
+import re
+from dataclasses import dataclass
+from typing import Any, Dict, List, Tuple
+
+from twinkle.preprocessor import Preprocessor
+
+from .utils import is_agent_row, msg_content_text, cjk_ratio
+
+# ── Hesitation-marker regexes ─────────────────────────────────────────────────
+#
+# Match SURFACE FORM, not semantic meaning, to avoid false positives on normal
+# explanatory language.
+
+_EN_HESITATE = re.compile(
+ r'\b('
+ r'wait[,\s]*\.{2,}|wait[,\s]+(wait|no|actually|hmm|let)|'
+ r'no\s+wait|oh\s+wait|but\s+wait|'
+ r'hmm+[,\s]*\.{0,3}|uh+m*[,\s]*\.{0,3}|'
+ r'actually[,\s]+no|actually[,\s]+wait|actually[,\s]+i\s+was|'
+ r'no[,\s]+actually[,\s]+(that|this|i)|'
+ r'let\s+me\s+(re-?think|try\s+again|start\s+over|reconsider)|'
+ r'i\'?ll\s+(start\s+over|try\s+again|redo\s+this)|'
+ r'i\'?m\s+(getting\s+confused|going\s+in\s+circles|lost\s+here|not\s+sure\s+where)|'
+ r'this\s+is\s+(getting|becoming)\s+(messy|complicated\s+fast|circular)|'
+ r'i\s+keep\s+(making|getting)\s+(the\s+same\s+)?error|'
+ r'i\s+(made|keep\s+making)\s+(the\s+same\s+)?(mistake|error)\s+again'
+ r')\b',
+ re.IGNORECASE,
+)
+
+# '哦' excluded (95%+ sentence-final particle, e.g. "拍拍我哦"); single '嗯' excluded
+# (often affirmation, e.g. "嗯,好的") — only repeated '嗯{2,}' counts.
+# '等一下' excluded — overwhelmingly polite '稍等一下', not self-hesitation.
+_ZH_HESITATE = re.compile(
+ r'('
+ r'等等[,,。\s]*\.{0,3}|哦等等|不不不+|'
+ r'嗯{2,}[,,。\s]*\.{0,3}|呃+[,,。\s]*\.{0,3}|'
+ r'不对[,,。]?[,,\s]?(等等|重新|让我)|错了[,,。]?\s*让我|'
+ r'让我(重新|再次?)(想|试|来|考虑|计算)|'
+ r'我(再|重新)(想想|试试|来一次|考虑)|'
+ r'我(越来越|有点)?(搞不清楚?|不确定|迷糊了?|乱了?)|'
+ r'这(变得|太|越来越)(复杂|乱|难以?理清)|'
+ r'我(好像|似乎|又)(搞|弄)错(了)?|我(又犯|再次犯)(了)?错|'
+ r'一直(出错|犯错|搞错)'
+ r')',
+ re.UNICODE,
+)
+
+_JA_HESITATE = re.compile(
+ r'('
+ r'ちょっと待って|待って待って|いや待って|えっと+[、。\s]*\.{0,3}|'
+ r'うーん+[、。\s]*\.{0,3}|あれ[、。]?[、。\s]*(また|もう一度)|'
+ r'もう一度考え直|やり直し|混乱してきた|わからなくなって'
+ r')',
+ re.UNICODE,
+)
+
+_KO_HESITATE = re.compile(
+ r'('
+ r'잠깐[,\s]*\.{0,3}|아\s*잠깐|잠깐만요?|'
+ r'음+[,\s]*\.{0,3}|어+[,\s]*\.{0,3}|'
+ r'다시\s*(생각|시작|해보|해야)|'
+ r'헷갈(리기|리네|려서)|'
+ r'계속\s*(틀리|실수|잘못)'
+ r')',
+ re.UNICODE,
+)
+
+_HESITATE_PATTERNS = (_EN_HESITATE, _ZH_HESITATE, _JA_HESITATE, _KO_HESITATE)
+
+# 'let me' deliberately excluded — canonical agent-prelude phrasing
+# ("Let me read the file...") and would over-fire on long agent trajectories.
+_CASCADE_RE = re.compile(
+ r'\b(wait|actually|hmm|no\s+wait|oh\s+wait|'
+ r'i\s+was\s+wrong|i\s+made\s+an?\s+(error|mistake))\b|'
+ r'(等等|不对|重新|错了|嗯{2,}|让我再)',
+ re.IGNORECASE | re.UNICODE,
+)
+
+# Cover both `` and `` block forms.
+_THINK_BLOCK_RE = re.compile(r'(.*?)', re.DOTALL | re.IGNORECASE)
+
+
+# ── Detection helpers ─────────────────────────────────────────────────────────
+
+
+@dataclass(frozen=True)
+class _StuckThresholds:
+ hesitation_density: float
+ cascade_threshold: int
+ cascade_window: int
+ repetition_threshold: float
+ ngram_size: int
+ ngram_min_words: int
+
+
+def _hesitation_density(text: str) -> float:
+ """Hesitation markers per 1000 chars across all language patterns."""
+ count = sum(len(p.findall(text)) for p in _HESITATE_PATTERNS)
+ return count / max(len(text), 1) * 1000
+
+
+def _has_correction_cascade(text: str, threshold: int, window: int) -> bool:
+ """True iff ``threshold`` cascade markers fall within any ``window``-char span."""
+ starts = [m.start() for m in _CASCADE_RE.finditer(text)]
+ if len(starts) < threshold:
+ return False
+ return any(starts[i + threshold - 1] - starts[i] <= window for i in range(len(starts) - threshold + 1))
+
+
+def _high_repetition(text: str, threshold: float, ngram_size: int, ngram_min_words: int) -> bool:
+ if cjk_ratio(text[:500]) > 0.3:
+ tokens = list(text.replace(' ', '').replace('\n', ''))
+ min_tokens = ngram_min_words * ngram_size
+ else:
+ tokens = text.split()
+ min_tokens = ngram_min_words
+ if len(tokens) < min_tokens:
+ return False
+ ngrams = [tuple(tokens[i:i + ngram_size]) for i in range(len(tokens) - ngram_size + 1)]
+ if not ngrams:
+ return False
+ return (1.0 - len(set(ngrams)) / len(ngrams)) > threshold
+
+
+def _is_segment_stuck(text: str, t: _StuckThresholds) -> bool:
+ if not text:
+ return False
+ return (_hesitation_density(text) > t.hesitation_density
+ or _has_correction_cascade(text, t.cascade_threshold, t.cascade_window)
+ or _high_repetition(text, t.repetition_threshold, t.ngram_size, t.ngram_min_words))
+
+
+def _split_think(text: str) -> Tuple[str, str]:
+ """Return (think_block_inner, post_think_response). Pre-think text is treated as response."""
+ m = _THINK_BLOCK_RE.search(text)
+ if not m:
+ return '', text
+ return m.group(1), text[m.end():]
+
+
+# ── Preprocessor ─────────────────────────────────────────────────────────────
+
+
+class DeadLoopFilter(Preprocessor):
+
+ def __init__(
+ self,
+ hesitation_density_threshold: float = 7.0,
+ cascade_window: int = 800,
+ cascade_threshold: int = 5,
+ repetition_threshold: float = 0.45,
+ ngram_size: int = 8,
+ ngram_min_words: int = 30,
+ think_hesitation_density_threshold: float = 15.0,
+ think_cascade_threshold: int = 20,
+ think_repetition_threshold: float = 0.65,
+ ) -> None:
+ super().__init__()
+ # Two threshold profiles: laxer inside reasoning (free to ramble),
+ # stricter on the visible response.
+ self._response_th = _StuckThresholds(
+ hesitation_density=hesitation_density_threshold,
+ cascade_threshold=cascade_threshold,
+ cascade_window=cascade_window,
+ repetition_threshold=repetition_threshold,
+ ngram_size=ngram_size,
+ ngram_min_words=ngram_min_words,
+ )
+ self._think_th = _StuckThresholds(
+ hesitation_density=think_hesitation_density_threshold,
+ cascade_threshold=think_cascade_threshold,
+ cascade_window=cascade_window,
+ repetition_threshold=think_repetition_threshold,
+ ngram_size=ngram_size,
+ ngram_min_words=ngram_min_words,
+ )
+
+ def _is_stuck(self, text: str, reasoning: str = '') -> bool:
+ think_part, response_part = _split_think(text)
+ if reasoning and not think_part:
+ think_part = reasoning
+ return (_is_segment_stuck(think_part, self._think_th)
+ or _is_segment_stuck(response_part.strip(), self._response_th))
+
+ def __call__(self, rows) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
+ rows = self.map_col_to_row(rows)
+ out: List[Dict[str, Any]] = []
+ dropped: List[Dict[str, Any]] = []
+ for row in rows:
+ messages = row.get('messages') or []
+ if is_agent_row(messages):
+ out.append(row)
+ continue
+ asst_msgs = [m for m in messages if isinstance(m, dict) and m.get('role') == 'assistant']
+ if not asst_msgs:
+ out.append(row)
+ continue
+ if any(self._is_stuck(
+ msg_content_text(m).strip(),
+ (m.get('reasoning_content') or m.get('thinking') or '').strip(),
+ ) for m in asst_msgs):
+ dropped.append(dict(row, drop_reason='dead_loop'))
+ else:
+ out.append(row)
+ return out, dropped
diff --git a/src/twinkle_agentic/preprocessor/dedup_filter.py b/src/twinkle_agentic/preprocessor/dedup_filter.py
new file mode 100644
index 000000000..feabac98b
--- /dev/null
+++ b/src/twinkle_agentic/preprocessor/dedup_filter.py
@@ -0,0 +1,123 @@
+import hashlib
+import json
+import re
+from typing import Any, Dict, List, Tuple
+
+from twinkle.preprocessor import Preprocessor
+
+from .utils import msg_content_text
+
+_SYSTEM_INJECTION_RE = re.compile(r'^<(?:system-reminder|system_reminder|context|user_info|attached_files)[ >]',
+ re.IGNORECASE)
+
+
+def _is_real_user(msg: Dict[str, Any]) -> bool:
+ if msg.get('role') != 'user':
+ return False
+ text = msg_content_text(msg).strip()
+ if not text:
+ return False
+ return not _SYSTEM_INJECTION_RE.match(text)
+
+
+def _head_tail(text: str, n: int) -> str:
+ text = text.strip()
+ if len(text) <= n * 2:
+ return text
+ return text[:n] + text[-n:]
+
+
+def _prefix_signature(messages: List[Dict[str, Any]], user_chars: int, asst_chars: int) -> str:
+ """Hash of the first real user turn (head+tail) + its first assistant reply (head+tail).
+
+ Skips system messages and system-injected user messages so the signature reflects the
+ actual conversation prefix, not template boilerplate. Falls back to the first two
+ non-empty assistant contents when no real user is present.
+ """
+ user_text = ''
+ asst_text = ''
+ seen_user = False
+ for msg in messages:
+ if not seen_user:
+ if _is_real_user(msg):
+ user_text = _head_tail(msg_content_text(msg), user_chars)
+ seen_user = True
+ continue
+ if msg.get('role') == 'assistant':
+ t = msg_content_text(msg).strip()
+ if t:
+ asst_text = _head_tail(t, asst_chars)
+ break
+ if not seen_user:
+ parts: List[str] = []
+ for msg in messages:
+ if msg.get('role') == 'assistant':
+ t = msg_content_text(msg).strip()
+ if t:
+ parts.append(_head_tail(t, asst_chars))
+ if len(parts) == 2:
+ break
+ user_text = parts[0] if parts else ''
+ asst_text = parts[1] if len(parts) > 1 else ''
+ return hashlib.md5(json.dumps([user_text, asst_text], ensure_ascii=False).encode()).hexdigest()
+
+
+def _full_hash(messages: List[Dict[str, Any]]) -> str:
+ return hashlib.md5(json.dumps(messages, ensure_ascii=False, sort_keys=True).encode()).hexdigest()
+
+
+class DedupFilter(Preprocessor):
+ """Global longest-wins deduplication over a fully materialized row collection.
+
+ Contract:
+ - Pure in-memory single pass. No state files, no locks, no shared cross-process state,
+ no cross-call memory. Same input → same output, every time.
+ - Must see the entire dataset in ONE __call__. NOT a per-batch pipeline step:
+ do not place inside QualityPreprocessor (which calls steps per Dataset.map batch
+ — per-batch state cannot express a global longest-wins decision).
+ - Run on List[Dict] before or after the QP pipeline; the caller is responsible for
+ materializing the dataset and re-wrapping the kept rows.
+
+ Semantics:
+ - Signature = first real user (head+tail) + first assistant reply (head+tail).
+ System and system-injected user messages are skipped.
+ - Within a signature group, the row with the most messages wins; exact-content
+ duplicates (matching full-hash) are silently collapsed; ties on message count
+ but different content keep the first-seen row.
+ - All non-winners are returned as dropped with drop_reason='duplicate'.
+ """
+
+ def __init__(self, prefix_chars: int = 100, asst_chars: int = 100):
+ super().__init__()
+ self._prefix = prefix_chars
+ self._asst_chars = asst_chars
+
+ def __call__(self, rows) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
+ rows = self.map_col_to_row(rows)
+ # sig -> {'idx': winner index in `rows`, 'n': msg count, 'fh': full-message hash}
+ best: Dict[str, Dict[str, Any]] = {}
+ keep: List[bool] = [False] * len(rows)
+ dropped: List[Dict[str, Any]] = []
+
+ for i, row in enumerate(rows):
+ msgs = row.get('messages') or []
+ sig = _prefix_signature(msgs, self._prefix, self._asst_chars)
+ n = len(msgs)
+ fh = _full_hash(msgs)
+ cur = best.get(sig)
+ if cur is None:
+ best[sig] = {'idx': i, 'n': n, 'fh': fh}
+ keep[i] = True
+ elif fh == cur['fh']:
+ dropped.append(dict(row, drop_reason='duplicate'))
+ elif n > cur['n']:
+ # Longer version wins — demote the previous winner
+ keep[cur['idx']] = False
+ dropped.append(dict(rows[cur['idx']], drop_reason='duplicate'))
+ best[sig] = {'idx': i, 'n': n, 'fh': fh}
+ keep[i] = True
+ else:
+ dropped.append(dict(row, drop_reason='duplicate'))
+
+ kept = [rows[i] for i, k in enumerate(keep) if k]
+ return kept, dropped
diff --git a/src/twinkle_agentic/preprocessor/hard_filter.py b/src/twinkle_agentic/preprocessor/hard_filter.py
new file mode 100644
index 000000000..d9710ee53
--- /dev/null
+++ b/src/twinkle_agentic/preprocessor/hard_filter.py
@@ -0,0 +1,214 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+"""Hard rule-based row filter (greetings, shallow replies, deny-listed system, length caps)."""
+import re
+from typing import Any, Dict, List, Optional, Tuple
+
+from twinkle.preprocessor import Preprocessor
+from .utils import msg_content_text, msg_has_media, cjk_ratio
+
+# ── Language detection ────────────────────────────────────────────────────────
+
+
+# ── Simple-query patterns ─────────────────────────────────────────────────────
+
+_EN_GREETING_RE = re.compile(
+ r'^(h+e+l+l+o+|h+i+|hey+|yo+|howdy|greetings|'
+ r'good\s+(morning|afternoon|evening|night|day)|'
+ r'what\'?s\s+up|how\'?s\s+it\s+going|how\s+are\s+you)'
+ r'[\s,!.?]*$',
+ re.IGNORECASE,
+)
+
+_EN_SIMPLE_RE = re.compile(
+ r'^('
+ r'(what|who|where|when|why|how)\s+(is|are|was|were|does|do|did|has|have|can|could|would|should)\b.{0,30}|'
+ r'(what|who|where|when|why|how)\'s\b.{0,30}|'
+ r'(is|are|was|were|do|does|did|can|could|would|should|may|might)\s+(it|this|that|you|there|they|he|she)\b.{0,30}|'
+ r'(tell\s+me(\s+(about|more))?|explain(\s+to\s+me)?|define|describe|list|summarize|give\s+me)\b.{0,20}|'
+ r'(please\s+)?(help\s+me|assist\s+me)\b.{0,20}'
+ r')\s*[?!.]?$',
+ re.IGNORECASE | re.DOTALL,
+)
+
+_ZH_GREETING_RE = re.compile(
+ r'^(你好+|您好+|早上好|下午好|晚上好|大家好|嗨+|哈+喽+|哈+|喂+|hello+|hi+)'
+ r'[\s,,!!。.]*$',
+ re.UNICODE,
+)
+
+_ZH_SIMPLE_RE = re.compile(
+ r'^('
+ r'.{0,7}(是什么|是啥|啥意思|是何|什么意思|怎么样|如何|为什么|为啥)[??。]?|'
+ r'(什么|啥|哪|谁|何|怎么|怎样|为什么|为啥|几|多少|何时|何地).{0,7}[??。]?|'
+ r'(介绍|解释|说明|告诉我|帮我说说|请问|能说说|讲讲).{0,5}|'
+ r'(请\s*(给出|介绍|解释|说明|提供|列举|讲讲|阐述|描述|概述|举例|分析|说一下)|'
+ r'能否\s*(给出|设计|提供|介绍|解释|说明)).{0,10}'
+ r')\s*[??!!。]?$',
+ re.UNICODE,
+)
+
+_JA_GREETING_RE = re.compile(
+ r'^(こんにちは+|こんばんは+|おはよう(ございます)?|やあ+|どうも+|はじめまして|よろしく(おねがいします)?)'
+ r'[\s!!。.]*$',
+ re.UNICODE,
+)
+
+_JA_SIMPLE_RE = re.compile(
+ r'^('
+ r'.{0,7}(とは何ですか|って何|とはなんですか|について教えて(ください)?|はどうですか|ですか)[??]?|'
+ r'(何|なに|どこ|いつ|誰|だれ|なぜ|どうして|どう|どれ|どの).{0,7}[??。]?'
+ r')\s*[??!!。]?$',
+ re.UNICODE,
+)
+
+_KO_GREETING_RE = re.compile(
+ r'^(안녕(하세요|하십니까)?|좋은\s*(아침|오후|저녁)|반갑습니다|여보세요)'
+ r'[\s!!.]*$',
+ re.UNICODE,
+)
+
+_KO_SIMPLE_RE = re.compile(
+ r'^('
+ r'.{0,7}(이?란\s*무엇|는\s*무엇|은\s*무엇|이?\s*뭐|가\s*뭐)[인가요까요]?[??]?|'
+ r'(무엇|뭐|어디|언제|누가|왜|어떻게).{0,7}[??]?|'
+ r'.{0,7}(에\s*대해|에\s*관해)\s*(알려주|설명해)[세요주십시오]?'
+ r')\s*[??!!]?$',
+ re.UNICODE,
+)
+
+_CJK_SIMPLE_REGEXES = (_ZH_GREETING_RE, _ZH_SIMPLE_RE, _JA_GREETING_RE, _JA_SIMPLE_RE, _KO_GREETING_RE, _KO_SIMPLE_RE)
+_LATIN_SIMPLE_REGEXES = (_EN_GREETING_RE, _EN_SIMPLE_RE)
+
+
+# ── Content helpers ──────────────────────────────────────────────────────────
+
+
+def _has_tool_calls(msg: Dict[str, Any]) -> bool:
+ """Truthy ``tool_calls`` excluding the empty-array sentinels '' / '[]' / []."""
+ tc = msg.get('tool_calls')
+ if not tc:
+ return False
+ if isinstance(tc, str):
+ s = tc.strip()
+ return bool(s) and s != '[]'
+ return bool(tc)
+
+
+def _is_simple_query(text: str, min_user_chars: int, min_user_chars_cjk: int) -> bool:
+ """True if ``text`` is a greeting or trivially simple question."""
+ t = text.strip()
+ if not t:
+ return True
+ cjk = cjk_ratio(t) >= 0.3
+ threshold = min_user_chars_cjk if cjk else min_user_chars
+ if len(t) < threshold:
+ return True
+ regexes = _CJK_SIMPLE_REGEXES if cjk else _LATIN_SIMPLE_REGEXES
+ return any(r.match(t) for r in regexes)
+
+
+def _has_thinking(msg: Dict[str, Any], min_chars: int) -> bool:
+ """True if an assistant message carries a sufficiently long thinking chain."""
+ thinking = msg.get('thinking') or msg.get('reasoning_content') or ''
+ if isinstance(thinking, str):
+ return len(thinking.strip()) >= min_chars
+ return bool(thinking)
+
+
+# ── Preprocessor ─────────────────────────────────────────────────────────────
+
+
+class HardFilter(Preprocessor):
+
+ def __init__(
+ self,
+ min_user_chars: int = 10,
+ min_user_chars_cjk: int = 6,
+ min_assistant_chars_2turn: int = 80,
+ min_thinking_chars: int = 200,
+ allow_incomplete_role: bool = False,
+ system_deny_keywords: Optional[List[str]] = None,
+ max_chars_per_round: Optional[int] = None,
+ max_total_chars: Optional[int] = None,
+ max_rounds: Optional[int] = None,
+ ) -> None:
+ super().__init__()
+ self._min_user_chars = min_user_chars
+ self._min_user_chars_cjk = min_user_chars_cjk
+ self._min_assistant_chars_2turn = min_assistant_chars_2turn
+ self._min_thinking_chars = min_thinking_chars
+ self.allow_incomplete_role = allow_incomplete_role
+ self._system_deny_re = (re.compile('|'.join(re.escape(k) for k in system_deny_keywords), re.IGNORECASE)
+ if system_deny_keywords else None)
+ self._max_chars_per_round = max_chars_per_round
+ self._max_total_chars = max_total_chars
+ self._max_rounds = max_rounds
+
+ def _drop_reason(self, row: Dict[str, Any], messages: List[Any]) -> Optional[str]:
+ """Apply rules in order; return first matching drop_reason, or None to keep."""
+ if not isinstance(messages, list):
+ return 'invalid_messages'
+
+ user_msgs = [m for m in messages if isinstance(m, dict) and m.get('role') == 'user']
+ asst_msgs = [m for m in messages if isinstance(m, dict) and m.get('role') == 'assistant']
+
+ if not user_msgs:
+ return None if self.allow_incomplete_role else 'no_user'
+
+ # Rule 1: single-turn trivial query (only meaningful when user content is plain text).
+ if len(user_msgs) == 1:
+ user_content = user_msgs[0].get('content')
+ if isinstance(user_content, str) and _is_simple_query(user_content, self._min_user_chars,
+ self._min_user_chars_cjk):
+ if not asst_msgs or not _has_thinking(asst_msgs[0], self._min_thinking_chars):
+ return 'trivial_single_turn'
+
+ # Rule 2: two-turn shallow reply without thinking.
+ if len(user_msgs) == 1 and len(asst_msgs) == 1:
+ asst = asst_msgs[0]
+ if (len(msg_content_text(asst)) < self._min_assistant_chars_2turn
+ and not _has_thinking(asst, self._min_thinking_chars)):
+ return 'shallow_reply'
+
+ # Rule 3: every assistant turn is content-empty, has no thinking, and has no tool_calls.
+ # Multimodal non-text parts (images etc.) also count as substantive.
+ if asst_msgs and all(not msg_content_text(m).strip() and not msg_has_media(m)
+ and not _has_thinking(m, self._min_thinking_chars) and not _has_tool_calls(m)
+ for m in asst_msgs):
+ return 'all_empty_assistant'
+
+ # Rule 4: system prompt matches deny keywords.
+ if self._system_deny_re:
+ sys_text = next((msg_content_text(m) for m in messages if isinstance(m, dict)
+ and m.get('role') == 'system'), '')
+ if self._system_deny_re.search(sys_text):
+ return 'system_deny_keyword'
+
+ # Rule 5: per-round character length limit (counted on textual projection).
+ if self._max_chars_per_round and any(
+ len(msg_content_text(m)) > self._max_chars_per_round for m in messages if isinstance(m, dict)):
+ return 'round_too_long'
+
+ # Rule 6: total conversation character length limit.
+ if self._max_total_chars:
+ total = sum(len(msg_content_text(m)) for m in messages if isinstance(m, dict))
+ if total > self._max_total_chars:
+ return 'total_too_long'
+
+ # Rule 7: max rounds (user-assistant pairs).
+ if self._max_rounds and len(asst_msgs) > self._max_rounds:
+ return 'too_many_rounds'
+
+ return None
+
+ def __call__(self, rows) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
+ rows = self.map_col_to_row(rows)
+ out: List[Dict[str, Any]] = []
+ dropped: List[Dict[str, Any]] = []
+ for row in rows:
+ reason = self._drop_reason(row, row.get('messages') or [])
+ if reason is None:
+ out.append(row)
+ else:
+ dropped.append(dict(row, drop_reason=reason))
+ return out, dropped
diff --git a/src/twinkle_agentic/preprocessor/intent_classifier.py b/src/twinkle_agentic/preprocessor/intent_classifier.py
new file mode 100644
index 000000000..aaf69d7f0
--- /dev/null
+++ b/src/twinkle_agentic/preprocessor/intent_classifier.py
@@ -0,0 +1,412 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+import re
+from collections import Counter
+from typing import Any, Dict, List, Optional, Tuple
+
+from twinkle.data_format import pack_value
+from twinkle.preprocessor import Preprocessor
+from twinkle.utils import get_logger
+
+from .utils import msg_content_text, normalize_tool_calls
+
+logger = get_logger(only_local_master=False)
+
+# Reasoning block regex covers both and forms.
+_THINK_BLOCK_RE = re.compile(r'(.*?)', re.DOTALL | re.IGNORECASE)
+
+# ── Intent categories ─────────────────────────────────────────────────────────
+INTENT_TOOL_CALL = 'tool_call'
+INTENT_CODE = 'code'
+INTENT_MATH = 'math'
+INTENT_COMPLEX_LOGIC = 'complex_logic'
+INTENT_REASONING = 'reasoning'
+INTENT_USER_DISSATISFACTION = 'user_dissatisfaction'
+INTENT_OTHER = 'other'
+
+# ── Heuristic patterns ────────────────────────────────────────────────────────
+_CODE_BLOCK_RE = re.compile(r'```[\s\S]{10,}?```')
+_CODE_KEYWORD_RE = re.compile(r'\b(def |class |import |from |function |const |let |var |return |if \(|for \(|while \(|'
+ r'#include|public class|private |protected |async |await |yield |throw |throws |catch |'
+ r'switch |case |break |continue |void |struct |enum |interface |abstract |static |final |'
+ r'namespace |package |module |export |lambda |func |fn |println|console\.log)\b|'
+ # Symbolic call / arrow signatures occur even without the keywords above.
+ r'(?:[a-zA-Z_]\w*\([^)\n]*\)\s*\{|=>\s*\{|->\s*[A-Za-z_]\w*)')
+
+_MATH_LATEX_RE = re.compile(
+ r'(\$\$.+?\$\$|\$[^$\n]+?\$|'
+ r'\\frac|\\sum|\\int|\\lim|\\begin\{(equation|align|matrix)|'
+ r'\\mathbb|\\partial|\\nabla|\\sqrt|\\overline|'
+ r'\\boxed|\\text\{|\\mathrm|\\langle|\\rangle|\\cdot|'
+ r'\\times|\\div|\\pm|\\leq|\\geq|\\neq|\\approx|\\equiv|'
+ r'\\infty|\\pi|\\alpha|\\beta|\\gamma|\\theta|\\lambda|\\mu|\\sigma|\\prod|\\to|\\rightarrow|'
+ r'\\\[.+?\\\]|'
+ # R1-distill writes math in plain Unicode without $...$; catch operators, Greek, sub/super digits, fractions.
+ r'[×÷±°∑∏∫√∂∇∞∈∋⊂⊃⊆⊇≤≥≠≈≡≅∝⇒⇔]|'
+ r'[α-ωΔΘΛΞΠΣΦΨΩ]|'
+ r'[⁰¹²³⁴-⁹₀-₉]|'
+ r'[½⅓⅔¼¾⅛⅜⅝⅞]|'
+ # Arithmetic equation pattern catches '30 ÷ 6 = 5' even when other markers are absent.
+ r'\d+\s*[×÷\*/\+\-]\s*\d+\s*=\s*\d+|'
+ # ≥4 comma-separated integers — number-sequence pattern.
+ r'\d+\s*,\s*\d+\s*,\s*\d+\s*,\s*\d+|'
+ # 'x = 5' / 'a = -3' style assignment.
+ r'[a-zA-Z]\s*=\s*-?\d+|'
+ # Chinese math vocabulary (strong indicators; ≥2 hits required so single occurrences in non-math text are safe).
+ r'积分|微分|导数|求导|偏导|梯度|极限|矩阵|向量|行列式|特征值|特征向量|'
+ r'多项式|因式分解|不等式|方程组?|二次方程|线性方程|求解|解方程|未知数|化简|约分|通分|因式|代入|应用题|算式|算术|计算题|一元(?:一次|二次|三次|方程|不等式|多项式)|二元(?:一次|二次|方程)?|'
+ r'平方|立方|开方|根号|对数|指数函数|三角函数|正弦|余弦|正切|余切|反三角|'
+ r'概率|期望值?|方差|标准差|分布|随机变量|均值|中位数|众数|百分比|比例|比率|'
+ r'子集|并集|交集|空集|集合|映射|'
+ r'乘以|除以|平方根|立方根|平方米|立方米|'
+ r'系数|常数项|首项|项数|公差|公比|'
+ r'切线|法线|渐近线|对称轴|双曲线|抛物线|椭圆|'
+ # Geometry.
+ r'三角形|四边形|多边形|长方形|正方形|圆形|圆锥|圆柱|球体|平行四边形|梯形|菱形|'
+ r'半径|直径|周长|面积|体积|对角线|内角|外角|锐角|钝角|直角|平角|余角|补角|勾股|弧度|象限|坐标系|'
+ # Sequences / number theory / elementary math.
+ r'数列|数字序列|等差数列|等比数列|等差|等比|通项|递推公式|'
+ r'奇数(?:位|项)?|偶数(?:位|项)?|质数|素数|合数|整数|小数|分数|有理数|无理数|实数|'
+ r'因数|倍数|公因数|公倍数|最大公约数|最小公倍数|阶乘|排列组合|'
+ r'余数|商(?=是|为|等)|被除数|除数|被乘数|乘数|'
+ r'(?:加|减|乘|除)\d+|'
+ r'第\d+(?:位|项)|'
+ # English math vocabulary.
+ r'\b(integral|differential|derivative|gradient|polynomial|equation|inequality|'
+ r'matrix|vector|determinant|eigenvalue|eigenvector|coefficient|'
+ r'logarithm|exponential|sqrt|theorem|lemma|proof|qed|axiom|corollary|'
+ r'sine|cosine|tangent|cosecant|secant|cotangent|arcsin|arccos|arctan|'
+ r'probability|variance|expectation|distribution|stddev|deviation|median|mean|mode|'
+ r'subset|superset|union|intersection|multiply|divide|squared|cubed|factorial|'
+ r'radius|diameter|circumference|perimeter|hypotenuse|congruent|parallel|perpendicular)\b|'
+ r'\w_\{[^}]+\}|\w\^\{[^}]+\})',
+ re.DOTALL,
+)
+
+# ── Complex logic patterns ────────────────────────────────────────────────────
+_LOGIC_STRUCTURE_RE = re.compile(
+ # Sequential reasoning markers (Chinese)
+ r'首先.{4,}其次|其次.{4,}最后|第一.{4,}第二.{4,}第三|'
+ r'一方面.{4,}另一方面|从.{1,6}角度|'
+ # Conditional / branching (Chinese)
+ r'如果.{2,30}那么|假设.{2,30}则|若.{2,20}则|'
+ r'分(为|成).{0,5}(种|类|个).{0,10}(情况|情形|场景|类型)|分情况讨论|'
+ # Causal chains (Chinese)
+ r'因为.{2,40}所以|由于.{2,40}因此|既然.{2,30}那么|'
+ r'导致.{2,30}进而|之所以.{2,30}是因为|'
+ # Synthesis / conclusion (Chinese)
+ r'综上(所述)?|综合(以上|来看|分析)|总[的而]言之|由此可[得见知]|'
+ # Comparison / trade-off (Chinese)
+ r'优缺点|利弊|优劣|权衡|对比分析|相比之下|'
+ # Multi-constraint reasoning (Chinese)
+ r'需要同时满足|同时考虑|兼顾|约束条件|'
+ # Sequential reasoning markers (English)
+ r'\b(first(ly)?|second(ly)?|third(ly)?|finally|furthermore|moreover|in addition|' # noqa: E501
+ r'on (the )?one hand|on the other hand|' # noqa: E501
+ r'as a result|consequently|therefore|hence|thus|accordingly)\b|'
+ # Conditional / branching (English)
+ r'\b(if .{5,30} then|assuming .{5,30} then|in (case|scenario) .{2,10}(A|B|1|2)|' # noqa: E501
+ r'case \d|scenario \d)\b|'
+ # Synthesis (English)
+ r'\b(in (conclusion|summary)|to (summarize|conclude)|overall|all things considered|' # noqa: E501
+ r'weighing .{3,20} against|pros and cons|trade-?offs?|advantages .{0,10} disadvantages)\b',
+ re.DOTALL | re.IGNORECASE,
+)
+
+_DISSATISFACTION_ZH_RE = re.compile(
+ # Quality / correctness complaints.
+ r'不[满好对行准确靠谱严]|不太[行好对准]|不正确|不准确|不对劲|不靠谱|不严谨|'
+ # Severity intensifiers.
+ r'太(差|慢|烂|傻|笨|垃圾|菜|弱|水|差劲|low)|这(么)?(差|烂|垃圾|傻|破|low)|'
+ # Redo / retry.
+ r'重[做来新答试]|重新(回答|做|来|算|想|考虑|生成)|再(答|来|做|算|想|试)一(次|遍|回|下)|你再答|'
+ # Wrong / errors.
+ r'错了?|错误|又错|搞错|弄错|出错|完全错|全错|大错|根本不(对|是)|压根不(对|是)|'
+ # Off-topic / unhelpful.
+ r'有问题|没用|没帮助|答非所问|文不对题|牛头不对|风马牛|跑题|偏题|偏离|跑偏|'
+ # Stop talking nonsense.
+ r'别瞎|别乱|别胡|你在说(什么|啥)|这是什么|这都什么|'
+ r'离谱|搞什么|质量(太|很差)|胡(说|扯|言|乱|写|编|闹)|瞎(编|说|扯|写|想|猜|蒙|讲)|'
+ # Random / illogical.
+ r'莫名其妙|一塌糊涂|一派胡言|谬(论|误)|废话|屁话|没逻辑|没道理|说不通|不合逻辑|'
+ # Negative emotion.
+ r'不(满意|开心|高兴)|失望|让(我|人)失望|烦人|真烦|厌|气死|'
+ # Misunderstanding / model failure.
+ r'你(没|不)(懂|理解|明白|听懂)|理解错|抓不住重点|没get|没get到|'
+ r'我说的不是|我问的不是|这不是我(说|问|想|要)|你听(错|不懂)|没听懂|'
+ # Time / value waste.
+ r'浪费时间|没意义|没价值|垃圾|废物|'
+ # Generic anger.
+ r'什么(玩意|东西|鬼)|你这是|你这答', )
+_DISSATISFACTION_EN_RE = re.compile(
+ # Negative adjectives.
+ r'\b(wrong|incorrect|useless|terrible|awful|horrible|bad|poor|lousy|sloppy|stupid|dumb|'
+ r'idiotic|ridiculous|broken|misleading|infuriating|annoying|disappointing|disappointed|'
+ r'unacceptable|unhelpful|inaccurate|imprecise|sub[- ]?par|low[- ]?quality)\b|'
+ # "not X" complaints.
+ r'\bnot (correct|right|good|helpful|useful|accurate|relevant|making sense|'
+ r'what (i|I) (asked|wanted|meant|need|expected|requested))\b|'
+ # Negation phrasings.
+ r'(doesn\'?t|does not|didn\'?t|did not) (make sense|work|help|fit|match|address)|'
+ r'makes? (no|zero|little) sense|'
+ # Redo / retry.
+ r'\b(redo|try again|do (it|this|that) again|start over|start again|do over|do better|'
+ r'once more|again from scratch)\b|'
+ # Insults / bullshit.
+ r'\b(nonsense|garbage|trash|crap|bullshit|bs|baloney|hogwash|gibberish)\b|'
+ r'(low|poor|bad|terrible) quality|waste of (time|effort|energy)|'
+ # Misunderstanding.
+ r'you (misunderstood|don\'?t understand|didn\'?t (get it|understand|listen)|missed (the|my) point)|'
+ r'that\'?s (not what|wrong|incorrect|terrible|garbage|nonsense|useless)|'
+ # Profanity.
+ r'\b(WTF|wth|what the (heck|hell|fuck))\b|'
+ # Off-target.
+ r'\b(off[- ]topic|missed the mark|way off|completely off|totally wrong|nowhere near)\b|'
+ r'not (even|really|quite) (close|right|correct)|'
+ # Sarcasm / disbelief.
+ r'come on|are you (serious|kidding|joking|sure)|'
+ r'\bfrustrat\w+\b',
+ re.IGNORECASE,
+)
+
+# ── Helpers ───────────────────────────────────────────────────────────────────
+
+
+def _pair_assistant(messages: List[Dict[str, Any]], idx: int, role: str) -> Optional[int]:
+ """Resolve which assistant idx represents the round that owns a signal at (idx, role)."""
+ if role == 'assistant':
+ return idx
+ if role == 'user':
+ for j in range(idx + 1, len(messages)):
+ m = messages[j]
+ if isinstance(m, dict) and m.get('role') == 'assistant':
+ return j
+ return None
+
+
+# ── Intent detectors (extensible pipeline) ────────────────────────────────────
+
+
+class IntentDetector:
+ """Base class. Each subclass sets ``intent`` and implements ``__call__``.
+
+ ``__call__(messages)`` returns a list of assistant indices (key rounds) that
+ match this intent within the given trajectory. An empty list means no match.
+ Set ``definitive = True`` so the pipeline short-circuits on this detector
+ (used for hard signals such as tool calls).
+ """
+
+ intent: str = ''
+ definitive: bool = False
+
+ def __call__(self, messages: List[Dict[str, Any]]) -> List[int]:
+ raise NotImplementedError
+
+
+class _RegexDetector(IntentDetector):
+ """Common scaffolding: scan messages, run ``_match`` on each text, pair to assistant."""
+
+ role_filter: Optional[str] = None
+
+ def _match(self, text: str) -> bool:
+ return False
+
+ def __call__(self, messages):
+ rounds = set()
+ for idx, m in enumerate(messages):
+ if not isinstance(m, dict):
+ continue
+ role = m.get('role')
+ # tool/system messages can never resolve to a key round (see _pair_assistant)
+ # and tool outputs are often multi-MB — skip to avoid wasted regex scans.
+ if role not in ('assistant', 'user'):
+ continue
+ if self.role_filter and role != self.role_filter:
+ continue
+ text = msg_content_text(m)
+ if not text or not self._match(text):
+ continue
+ asst_idx = _pair_assistant(messages, idx, role)
+ if asst_idx is not None:
+ rounds.add(asst_idx)
+ return sorted(rounds)
+
+
+class ToolCallDetector(IntentDetector):
+ """Mark every assistant turn that carries a real (non-empty) ``tool_calls`` payload."""
+
+ intent = INTENT_TOOL_CALL
+ definitive = True
+
+ def __call__(self, messages):
+ return [
+ i for i, m in enumerate(messages)
+ if isinstance(m, dict) and m.get('role') == 'assistant' and normalize_tool_calls(m)
+ ]
+
+
+class CodeDetector(_RegexDetector):
+ intent = INTENT_CODE
+
+ def __init__(self, threshold: int = 3) -> None:
+ self.threshold = threshold
+
+ def _match(self, text):
+ blocks = _CODE_BLOCK_RE.findall(text)
+ if blocks:
+ return True
+ return len(_CODE_KEYWORD_RE.findall(text)) >= self.threshold
+
+
+class MathDetector(_RegexDetector):
+ intent = INTENT_MATH
+
+ def __init__(self, threshold: int = 4) -> None:
+ self.threshold = threshold
+
+ def _match(self, text):
+ return len(_MATH_LATEX_RE.findall(text)) >= self.threshold
+
+
+class ComplexLogicDetector(_RegexDetector):
+ intent = INTENT_COMPLEX_LOGIC
+ role_filter = 'assistant'
+
+ def __init__(self, threshold: int = 6) -> None:
+ self.threshold = threshold
+
+ def _match(self, text):
+ return len(_LOGIC_STRUCTURE_RE.findall(text)) >= self.threshold
+
+
+class ReasoningDetector(IntentDetector):
+ """Detect assistant turns with explicit reasoning chains (reasoning_content or blocks)."""
+
+ intent = INTENT_REASONING
+
+ def __init__(self, min_chars: int = 200) -> None:
+ self._min_chars = min_chars
+
+ def __call__(self, messages):
+ rounds = []
+ for i, m in enumerate(messages):
+ if not isinstance(m, dict) or m.get('role') != 'assistant':
+ continue
+ rc = m.get('reasoning_content') or ''
+ if isinstance(rc, str) and len(rc.strip()) >= self._min_chars:
+ rounds.append(i)
+ continue
+ match = _THINK_BLOCK_RE.search(msg_content_text(m))
+ if match and len(match.group(1).strip()) >= self._min_chars:
+ rounds.append(i)
+ return rounds
+
+
+class UserDissatisfactionDetector(_RegexDetector):
+ intent = INTENT_USER_DISSATISFACTION
+ role_filter = 'user'
+
+ def _match(self, text):
+ return bool(_DISSATISFACTION_ZH_RE.search(text) or _DISSATISFACTION_EN_RE.search(text))
+
+ def __call__(self, messages):
+ # Dissatisfaction is a reaction — require at least one prior assistant turn.
+ seen_assistant = False
+ rounds = set()
+ for idx, m in enumerate(messages):
+ if not isinstance(m, dict):
+ continue
+ role = m.get('role')
+ if role == 'assistant':
+ seen_assistant = True
+ continue
+ if role != 'user' or not seen_assistant:
+ continue
+ text = msg_content_text(m)
+ if text and self._match(text):
+ asst_idx = _pair_assistant(messages, idx, role)
+ if asst_idx is not None:
+ rounds.add(asst_idx)
+ return sorted(rounds)
+
+
+# ── Preprocessor ──────────────────────────────────────────────────────────────
+
+
+class IntentClassifier(Preprocessor):
+ """Annotate each trajectory with its primary intent and key-round indices.
+
+ Pure-heuristic, no LLM. Each intent is a pluggable :class:`IntentDetector`;
+ pass ``detectors=[...]`` to extend or override.
+
+ Annotates per row::
+
+ row['intent'] # primary intent string
+ row['user_data'] += [('key_rounds', list[int]), # assistant indices
+ ('intents', dict[str, str])] # per-round intent
+ """
+
+ DEFAULT_DETECTORS: List[IntentDetector] = [
+ ToolCallDetector(),
+ CodeDetector(),
+ MathDetector(),
+ ComplexLogicDetector(),
+ ReasoningDetector(),
+ UserDissatisfactionDetector(),
+ ]
+
+ def __init__(
+ self,
+ detectors: Optional[List[IntentDetector]] = None,
+ intent_field: str = 'intent',
+ drop_no_key_rounds: bool = True,
+ ) -> None:
+ super().__init__()
+ self._intent_field = intent_field
+ self._drop_no_key_rounds = drop_no_key_rounds
+ self._detectors = list(detectors) if detectors is not None else list(self.DEFAULT_DETECTORS)
+
+ def _detect(self, messages: List[Dict[str, Any]]) -> Dict[int, str]:
+ """Run detector pipeline; later detectors never override earlier intent on the same round."""
+ round_intents: Dict[int, str] = {}
+ for det in self._detectors:
+ rounds = det(messages)
+ if not rounds:
+ continue
+ for idx in rounds:
+ round_intents.setdefault(idx, det.intent)
+ if det.definitive:
+ break
+ return round_intents
+
+ def __call__(self, rows) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
+ rows = self.map_col_to_row(rows)
+ if not rows:
+ return rows, []
+
+ out = []
+ dropped = []
+ for row in rows:
+ row = dict(row)
+ messages = row.get('messages')
+ round_intents = (self._detect(messages) if isinstance(messages, list) and messages else {})
+
+ if round_intents:
+ primary = Counter(round_intents.values()).most_common(1)[0][0]
+ # Stored entries are (key, json.dumps(value)) for Arrow stability across shards.
+ existing = list(row.get('user_data') or [])
+ user_data = [(k, v) for (k, v) in existing if k not in ('key_rounds', 'intents')]
+ user_data.append(('key_rounds', pack_value(sorted(round_intents))))
+ user_data.append(('intents', pack_value({str(k): v for k, v in round_intents.items()})))
+ row['user_data'] = user_data
+ else:
+ if self._drop_no_key_rounds:
+ dropped.append(dict(row, drop_reason='no_key_rounds'))
+ continue
+ primary = INTENT_OTHER
+
+ row[self._intent_field] = primary
+ out.append(row)
+
+ dist = Counter(r[self._intent_field] for r in out)
+ logger.info(f'[IntentClassifier] distribution: {dict(dist)}')
+ return out, dropped
diff --git a/src/twinkle_agentic/preprocessor/llm_backend.py b/src/twinkle_agentic/preprocessor/llm_backend.py
new file mode 100644
index 000000000..a6e917ab3
--- /dev/null
+++ b/src/twinkle_agentic/preprocessor/llm_backend.py
@@ -0,0 +1,344 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+"""Abstract LLM backend for preprocessor pipeline.
+
+Supports two modes:
+ - OpenAIBackend: httpx-based calls to any OpenAI-compatible HTTP server
+ - SamplerBackend: direct calls to Twinkle vLLMSampler Ray actor (no HTTP)
+"""
+from abc import ABC, abstractmethod
+from typing import Any, Dict, List, Optional, Tuple
+
+from twinkle.utils import get_logger
+
+logger = get_logger(only_local_master=False)
+
+
+class LLMBackend(ABC):
+ """Abstract base for LLM inference used by QualityPreprocessor stages."""
+
+ @abstractmethod
+ def chat(
+ self,
+ messages: List[Dict[str, Any]],
+ *,
+ temperature: float = 0.0,
+ max_tokens: int = 16,
+ n: int = 1,
+ ) -> List[Dict[str, str]]:
+ """Chat completion.
+
+ Returns:
+ List of n choices, each a dict with keys 'content' and 'reasoning_content'.
+ """
+
+ def chat_batch(
+ self,
+ messages_list: List[List[Dict[str, Any]]],
+ *,
+ temperature: float = 0.0,
+ max_tokens: int = 16,
+ n: int = 1,
+ ) -> List[List[Dict[str, str]]]:
+ """Batched chat completion. Returns one List[choice] per input messages list.
+
+ Default impl loops over `chat`; backends should override to fan out concurrently
+ (HTTP) or pass the full list to the underlying sampler in a single call (vLLM DP).
+ """
+ return [self.chat(m, temperature=temperature, max_tokens=max_tokens, n=n) for m in messages_list]
+
+ @abstractmethod
+ def prompt_logprobs(self, messages: List[Dict[str, Any]]) -> Optional[List]:
+ """Evaluate prompt tokens without generation.
+
+ Returns:
+ List of per-token logprob entries (format varies by backend but
+ is compatible with _extract_logprob helpers), or None on failure.
+ """
+
+ @abstractmethod
+ def prompt_logprobs_ids(self, input_ids_list: List[List[int]]) -> List[List]:
+ """Batched: evaluate raw token-id prompts without chat template wrapping.
+
+ Used for unconditional perplexity (e.g. IFD denominator). Caller MUST
+ supply a list of token-id sequences; for distributed backends the list
+ length must satisfy backend-specific batching constraints (e.g.
+ ``len >= dp_world_size`` for SamplerBackend).
+ """
+
+ def embeddings(self, texts: List[str]) -> Any:
+ """Compute text embeddings. Override in backends that support it."""
+ raise NotImplementedError(f'{type(self).__name__} does not support embeddings')
+
+
+class OpenAIBackend(LLMBackend):
+ """Backend wrapping any OpenAI-compatible HTTP endpoint."""
+
+ def __init__(
+ self,
+ endpoint: str,
+ model: str = 'default',
+ api_key: str = '',
+ timeout: float = 120.0,
+ ):
+ import httpx
+ headers = {'Content-Type': 'application/json'}
+ if api_key:
+ headers['Authorization'] = f'Bearer {api_key}'
+ self._client = httpx.Client(timeout=timeout, headers=headers)
+ base = endpoint.rstrip('/')
+ self._chat_endpoint = f'{base}/v1/chat/completions'
+ self._embed_endpoint = f'{base}/v1/embeddings'
+ self._model = model
+
+ @property
+ def model(self) -> str:
+ return self._model
+
+ def chat(
+ self,
+ messages: List[Dict[str, Any]],
+ *,
+ temperature: float = 0.0,
+ max_tokens: int = 16,
+ n: int = 1,
+ ) -> List[Dict[str, str]]:
+ try:
+ resp = self._client.post(
+ self._chat_endpoint,
+ json={
+ 'model': self._model,
+ 'messages': messages,
+ 'temperature': temperature,
+ 'max_tokens': max_tokens,
+ 'n': n,
+ })
+ resp.raise_for_status()
+ choices = resp.json().get('choices', [])
+ results = []
+ for c in choices:
+ msg = c.get('message') or {}
+ results.append({
+ 'content': msg.get('content') or '',
+ 'reasoning_content': msg.get('reasoning_content') or '',
+ })
+ return results
+ except Exception as e:
+ logger.warning(f'[OpenAIBackend] chat failed: {e}')
+ return []
+
+ def chat_batch(
+ self,
+ messages_list: List[List[Dict[str, Any]]],
+ *,
+ temperature: float = 0.0,
+ max_tokens: int = 16,
+ n: int = 1,
+ max_workers: int = 16,
+ ) -> List[List[Dict[str, str]]]:
+ """Concurrent chat: vLLM HTTP server multiplexes requests; httpx.Client is thread-safe."""
+ from concurrent.futures import ThreadPoolExecutor
+ if not messages_list:
+ return []
+ workers = max(1, min(max_workers, len(messages_list)))
+ results: List[List[Dict[str, str]]] = [[] for _ in messages_list]
+ with ThreadPoolExecutor(max_workers=workers) as ex:
+ futs = {
+ ex.submit(self.chat, m, temperature=temperature, max_tokens=max_tokens, n=n): i
+ for i, m in enumerate(messages_list)
+ }
+ for fut in futs:
+ results[futs[fut]] = fut.result()
+ return results
+
+ def prompt_logprobs(self, messages: List[Dict[str, Any]]) -> Optional[List]:
+ try:
+ resp = self._client.post(
+ self._chat_endpoint,
+ json={
+ 'model': self._model,
+ 'messages': messages,
+ 'max_tokens': 0,
+ 'prompt_logprobs': 1,
+ })
+ resp.raise_for_status()
+ return resp.json().get('prompt_logprobs')
+ except Exception:
+ return None
+
+ def prompt_logprobs_ids(self, input_ids_list: List[List[int]]) -> List[List]:
+ endpoint = self._chat_endpoint.rsplit('/', 2)[0] + '/v1/completions'
+ results: List[List] = []
+ for input_ids in input_ids_list:
+ resp = self._client.post(
+ endpoint,
+ json={
+ 'model': self._model,
+ 'prompt': list(input_ids),
+ 'max_tokens': 0,
+ 'echo': True,
+ 'prompt_logprobs': 1,
+ })
+ resp.raise_for_status()
+ data = resp.json()
+ choices = data.get('choices') or []
+ if choices and 'prompt_logprobs' in choices[0]:
+ results.append(choices[0]['prompt_logprobs'])
+ else:
+ results.append(data['prompt_logprobs'])
+ return results
+
+ def embeddings(self, texts: List[str]):
+ import numpy as np
+ resp = self._client.post(
+ self._embed_endpoint, json={
+ 'model': self._model,
+ 'input': texts,
+ })
+ resp.raise_for_status()
+ data = resp.json().get('data', [])
+ data_sorted = sorted(data, key=lambda x: x.get('index', 0))
+ return np.array([d['embedding'] for d in data_sorted], dtype=np.float32)
+
+
+class SamplerBackend(LLMBackend):
+ """Backend wrapping a Twinkle vLLMSampler (Ray actor, no HTTP overhead)."""
+
+ def __init__(
+ self,
+ sampler,
+ embed_endpoint: str = '',
+ embed_model: str = 'bge-m3',
+ ):
+ """
+ Args:
+ sampler: A vLLMSampler instance (with template already set).
+ embed_endpoint: Optional OpenAI-compatible endpoint for embeddings.
+ embed_model: Model name for embeddings.
+ """
+ self._sampler = sampler
+ self._embed_endpoint = embed_endpoint
+ self._embed_model = embed_model
+ self._embed_client = None
+ if embed_endpoint:
+ import httpx
+ self._embed_client = httpx.Client(timeout=120.0)
+ self._embed_url = f'{embed_endpoint.rstrip("/")}/v1/embeddings'
+
+ def chat(
+ self,
+ messages: List[Dict[str, Any]],
+ *,
+ temperature: float = 0.0,
+ max_tokens: int = 16,
+ n: int = 1,
+ ) -> List[Dict[str, str]]:
+ from twinkle.data_format import SamplingParams
+ trajectory = {'messages': messages}
+ params = SamplingParams(
+ temperature=temperature,
+ max_tokens=max_tokens,
+ num_samples=n,
+ )
+ try:
+ responses = self._sampler.sample(trajectory, params)
+ results = []
+ for resp in responses:
+ for seq in resp.sequences:
+ text = seq.decoded or ''
+ reasoning = ''
+ if '' in text:
+ parts = text.split('', 1)
+ reasoning = parts[0].split('')[-1].strip()
+ text = parts[1].strip()
+ results.append({'content': text, 'reasoning_content': reasoning})
+ return results
+ except Exception as e:
+ logger.warning(f'[SamplerBackend] chat failed: {e}')
+ return []
+
+ @staticmethod
+ def _split_think(text: str) -> Tuple[str, str]:
+ if '' in text:
+ parts = text.split('', 1)
+ return parts[1].strip(), parts[0].split('')[-1].strip()
+ return text, ''
+
+ def chat_batch(
+ self,
+ messages_list: List[List[Dict[str, Any]]],
+ *,
+ temperature: float = 0.0,
+ max_tokens: int = 16,
+ n: int = 1,
+ ) -> List[List[Dict[str, str]]]:
+ """One sampler dispatch over the full list; lets vLLM DP workers stay saturated."""
+ from twinkle.data_format import SamplingParams
+ if not messages_list:
+ return []
+ device_mesh = getattr(self._sampler, 'device_mesh', None)
+ dp_world_size = getattr(device_mesh, 'dp_world_size', 1) or 1
+ n_inputs = len(messages_list)
+ feats = [{'messages': m} for m in messages_list]
+ # Pad the dispatch so every DP worker has at least one item; trim duplicates after.
+ if n_inputs < dp_world_size:
+ feats = feats + [feats[-1]] * (dp_world_size - n_inputs)
+ params = SamplingParams(temperature=temperature, max_tokens=max_tokens, num_samples=n)
+ try:
+ responses = self._sampler.sample(feats, params)
+ except Exception as e:
+ logger.warning(f'[SamplerBackend] chat_batch failed: {e}')
+ return [[] for _ in range(n_inputs)]
+ responses = list(responses)[:n_inputs]
+ out: List[List[Dict[str, str]]] = []
+ for resp in responses:
+ choices: List[Dict[str, str]] = []
+ for seq in (getattr(resp, 'sequences', None) or []):
+ text, reasoning = self._split_think(seq.decoded or '')
+ choices.append({'content': text, 'reasoning_content': reasoning})
+ out.append(choices)
+ while len(out) < n_inputs:
+ out.append([])
+ return out
+
+ def prompt_logprobs(self, messages: List[Dict[str, Any]]) -> Optional[List]:
+ from twinkle.data_format import SamplingParams
+ trajectory = {'messages': messages}
+ params = SamplingParams(max_tokens=0, prompt_logprobs=1)
+ try:
+ responses = self._sampler.sample(trajectory, params)
+ if responses and responses[0].prompt_logprobs is not None:
+ return responses[0].prompt_logprobs
+ return None
+ except Exception as e:
+ logger.warning(f'[SamplerBackend] prompt_logprobs failed: {e}')
+ return None
+
+ def prompt_logprobs_ids(self, input_ids_list: List[List[int]]) -> List[List]:
+ from twinkle.data_format import SamplingParams
+ if not isinstance(input_ids_list, list) or not input_ids_list:
+ raise ValueError('prompt_logprobs_ids requires a non-empty List[List[int]].')
+ device_mesh = getattr(self._sampler, 'device_mesh', None)
+ dp_world_size = getattr(device_mesh, 'dp_world_size', 1) or 1
+ if len(input_ids_list) < dp_world_size:
+ raise ValueError(f'SamplerBackend.prompt_logprobs_ids requires at least '
+ f'dp_world_size={dp_world_size} inputs to keep all DP workers busy, '
+ f'got {len(input_ids_list)}. Batch upstream before calling.')
+ feats = [{'input_ids': list(ids)} for ids in input_ids_list]
+ params = SamplingParams(max_tokens=0, prompt_logprobs=1)
+ responses = self._sampler.sample(feats, params)
+ return [r.prompt_logprobs for r in responses]
+
+ def embeddings(self, texts: List[str]):
+ if self._embed_client is None:
+ raise NotImplementedError('SamplerBackend requires embed_endpoint for embeddings. '
+ 'Pass embed_endpoint when constructing SamplerBackend.')
+ import numpy as np
+ resp = self._embed_client.post(
+ self._embed_url, json={
+ 'model': self._embed_model,
+ 'input': texts,
+ })
+ resp.raise_for_status()
+ data = resp.json().get('data', [])
+ data_sorted = sorted(data, key=lambda x: x.get('index', 0))
+ return np.array([d['embedding'] for d in data_sorted], dtype=np.float32)
diff --git a/src/twinkle_agentic/preprocessor/message_normalizer.py b/src/twinkle_agentic/preprocessor/message_normalizer.py
new file mode 100644
index 000000000..b82a497a6
--- /dev/null
+++ b/src/twinkle_agentic/preprocessor/message_normalizer.py
@@ -0,0 +1,214 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+"""Normalize message sequences to standard OpenAI multi-turn schema.
+
+Three passes (each idempotent):
+ 0. **Heartbeat strip** — drop heartbeat polling rounds (user + assistant).
+ 1. **Tool-call normalization** — rewrite embedded tool calls (Cline XML,
+ ReAct, VCP, Hermes) to ``tool_calls`` + ``role=tool``.
+ 2. **Consecutive-role merge** — merge adjacent same-role messages into one
+ (content joined by newline). Empty messages inside a run are dropped.
+ Single-element runs are preserved verbatim (keeps multimodal list
+ content intact).
+
+All passes use ``msg_content_text`` to project content (str | list-of-parts)
+to plain text for inspection. List content with only text parts is treated
+identically to plain strings; truly multimodal single-element runs are
+preserved verbatim by the merge pass.
+"""
+import json
+import re
+from typing import Any, Dict, List, Tuple
+
+from twinkle.preprocessor import Preprocessor
+from twinkle.template.tools import ToolCallRegistry
+
+from .utils import msg_content_text, msg_has_media
+
+# IGNORECASE absorbs every variant ("Read HEARTBEAT.md", "HEARTBEAT_OK",
+# "duplicate heartbeat", etc.) under the single token "heartbeat".
+_HEARTBEAT_USER_RE = re.compile(r'heartbeat|keep.?alive', re.IGNORECASE)
+_HEARTBEAT_ASST_RE = re.compile(r'heartbeat', re.IGNORECASE)
+
+
+# ── Pass 0: heartbeat strip ─────────────────────────────────────────────────
+
+
+def _strip_heartbeat(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+ out: List[Dict[str, Any]] = []
+ skip_next_assistant = False
+ for m in messages:
+ if not isinstance(m, dict):
+ continue
+ role = m.get('role', '')
+ if role == 'developer':
+ m = dict(m)
+ m['role'] = 'system'
+ role = 'system'
+ text = msg_content_text(m)[:300]
+ if role == 'user' and _HEARTBEAT_USER_RE.search(text):
+ skip_next_assistant = True
+ continue
+ if role == 'assistant' and not m.get('tool_calls'):
+ if skip_next_assistant or _HEARTBEAT_ASST_RE.search(text):
+ skip_next_assistant = False
+ continue
+ skip_next_assistant = False
+ out.append(m)
+ while out and out[0].get('role') not in ('user', 'system'):
+ out.pop(0)
+ return out
+
+
+# ── Pass 1: tool-call normalization ─────────────────────────────────────────
+
+
+def _normalize_tool_calls(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+ """Rewrite embedded tool calls in assistant messages to OpenAI schema."""
+ out: List[Dict[str, Any]] = []
+ call_counter = 0
+ i = 0
+ n = len(messages)
+ while i < n:
+ msg = messages[i]
+ text = msg_content_text(msg)
+ parser = (
+ ToolCallRegistry.detect_first(text) if msg.get('role') == 'assistant' and not msg.get('tool_calls')
+ and text else None)
+ parsed = parser.parse(text) if parser else None
+ if not parsed:
+ out.append(msg)
+ i += 1
+ continue
+
+ tc_list = []
+ for tc in parsed:
+ call_counter += 1
+ args = tc['function']['arguments']
+ tc_list.append({
+ 'id': f'call_norm_{call_counter:04d}',
+ 'type': 'function',
+ 'function': {
+ 'name': tc['function']['name'],
+ 'arguments': json.dumps(args, ensure_ascii=False) if isinstance(args, dict) else str(args),
+ },
+ })
+ out.append({
+ 'role': 'assistant',
+ 'content': parser.clean(text),
+ 'tool_calls': json.dumps(tc_list, ensure_ascii=False),
+ 'tool_call_id': '',
+ })
+
+ # Consume following user messages as tool results — one per tool call.
+ j = i + 1
+ for tc_idx, tc in enumerate(tc_list):
+ if j >= n or messages[j].get('role') != 'user':
+ break
+ nxt_text = msg_content_text(messages[j])
+ if not nxt_text:
+ break
+ if parser.detect_result(nxt_text):
+ body = parser.parse_result(nxt_text)
+ elif tc_idx == 0 and len(tc_list) == 1:
+ body = nxt_text
+ else:
+ break
+ out.append({
+ 'role': 'tool',
+ 'content': body,
+ 'tool_calls': '',
+ 'tool_call_id': tc['id'],
+ })
+ j += 1
+ i = j
+ return out
+
+
+# ── Pass 2: consecutive-role merge ──────────────────────────────────────────
+
+
+def _is_atomic(msg: Dict[str, Any]) -> bool:
+ """Atomic = never merge: tool results + assistant turns carrying tool_calls."""
+ role = msg.get('role', '')
+ return role == 'tool' or (role == 'assistant' and msg.get('tool_calls'))
+
+
+def _is_blank_content(msg: Dict[str, Any]) -> bool:
+ if msg_has_media(msg):
+ return False
+ return not msg_content_text(msg).strip()
+
+
+def _merge_consecutive(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+ out: List[Dict[str, Any]] = []
+ i = 0
+ n = len(messages)
+ while i < n:
+ msg = messages[i]
+
+ if _is_atomic(msg):
+ # Drop only blank-string tool messages; preserve everything else verbatim.
+ if msg.get('role') == 'tool' and _is_blank_content(msg):
+ i += 1
+ continue
+ out.append(msg)
+ i += 1
+ continue
+
+ role = msg.get('role', '')
+ j = i + 1
+ run = [msg]
+ while j < n and messages[j].get('role') == role and not _is_atomic(messages[j]):
+ run.append(messages[j])
+ j += 1
+
+ if len(run) == 1:
+ # Preserve original shape (incl. multimodal list content); drop only blank strings.
+ if not _is_blank_content(msg):
+ out.append(msg)
+ else:
+ non_blank = [m for m in run if not _is_blank_content(m)]
+ if not non_blank:
+ i = j
+ continue
+ has_str = any(isinstance(m.get('content'), str) for m in non_blank)
+ has_list = any(isinstance(m.get('content'), list) for m in non_blank)
+ if has_str and has_list:
+ # Mixed types — keep each individually, don't merge.
+ out.extend(non_blank)
+ elif has_list:
+ parts: list = []
+ for m in non_blank:
+ parts.extend(m.get('content'))
+ merged = dict(non_blank[0])
+ merged['content'] = parts
+ out.append(merged)
+ else:
+ merged = dict(non_blank[0])
+ merged['content'] = '\n'.join(msg_content_text(m).strip() for m in non_blank)
+ out.append(merged)
+ i = j
+ return out
+
+
+# ── Combined normalizer ─────────────────────────────────────────────────────
+
+
+class MessageNormalizer(Preprocessor):
+ """Three-pass message normalizer (heartbeat strip + tool-call rewrite + role merge).
+
+ Multimodal list-shaped content passes through every stage untouched.
+ This is a mapper — it never drops rows.
+ """
+
+ def __call__(self, rows: List[Dict[str, Any]]) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
+ rows = self.map_col_to_row(rows)
+ for row in rows:
+ msgs = row.get('messages')
+ if not isinstance(msgs, list) or not msgs:
+ continue
+ msgs = _strip_heartbeat(msgs)
+ msgs = _normalize_tool_calls(msgs)
+ msgs = _merge_consecutive(msgs)
+ row['messages'] = msgs
+ return rows, []
diff --git a/src/twinkle_agentic/preprocessor/message_sanity.py b/src/twinkle_agentic/preprocessor/message_sanity.py
new file mode 100644
index 000000000..07a389cff
--- /dev/null
+++ b/src/twinkle_agentic/preprocessor/message_sanity.py
@@ -0,0 +1,343 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+"""MessageSanityFilter — structural and content sanity for messages-format datasets.
+
+Architecture: check-pipeline pattern. Each check is a standalone function with
+signature ``(messages, is_agent, cfg) -> bool`` (True = pass). The filter class
+iterates enabled checks in order.
+"""
+import json
+import re
+from typing import Any, Dict, List, Optional, Tuple
+
+from twinkle.preprocessor import Preprocessor
+
+from .utils import (build_sensitive_regex, cjk_ratio, is_agent_row, load_sensitive_words, msg_content_text,
+ msg_has_media, msg_has_payload, normalize_tool_calls)
+
+# Backward-compat re-exports.
+_msg_content_text = msg_content_text
+_normalize_tool_calls = normalize_tool_calls
+
+_VALID_ROLES = {'system', 'user', 'assistant', 'tool'}
+_IDENTIFIER_RE = re.compile(r'^[a-zA-Z_][a-zA-Z0-9_.\-]*$')
+
+# ══════════════════════════════════════════════════════════════════════════════
+# Transforms (applied before checks, may modify messages)
+# ══════════════════════════════════════════════════════════════════════════════
+
+
+def consolidate_system_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+ """Fold multiple system messages into one at index 0."""
+ sys_count = sum(1 for m in messages if isinstance(m, dict) and m.get('role') == 'system')
+ misplaced = any(isinstance(m, dict) and m.get('role') == 'system' and i != 0 for i, m in enumerate(messages))
+ if sys_count <= 1 and not misplaced:
+ return messages
+ sys_chunks: List[str] = []
+ rest: List[Dict[str, Any]] = []
+ template: Optional[Dict[str, Any]] = None
+ for m in messages:
+ if isinstance(m, dict) and m.get('role') == 'system':
+ if template is None:
+ template = m
+ text = msg_content_text(m).strip()
+ if text:
+ sys_chunks.append(text)
+ else:
+ rest.append(m)
+ return [dict(template, content='\n\n'.join(sys_chunks))] + rest
+
+
+def trim_to_last_assistant(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+ """Trim trailing messages so the conversation ends with an assistant that has visible content."""
+ for i in range(len(messages) - 1, -1, -1):
+ m = messages[i]
+ if isinstance(m, dict) and m.get('role') == 'assistant' and msg_has_payload(m):
+ return messages[:i + 1]
+ return []
+
+
+# ══════════════════════════════════════════════════════════════════════════════
+# Check functions: (messages, is_agent, cfg) -> bool (True = pass)
+# ══════════════════════════════════════════════════════════════════════════════
+
+
+def check_role_order(messages: List[Dict[str, Any]], is_agent: bool, cfg: dict) -> bool:
+ """Validate conversational role ordering."""
+ if not messages:
+ return False
+ seen_user = False
+ seen_assistant = False
+ saw_first_non_system = False
+ for i, m in enumerate(messages):
+ if not isinstance(m, dict):
+ return False
+ role = m.get('role')
+ if role not in _VALID_ROLES:
+ return False
+ if role == 'system':
+ if i != 0:
+ return False
+ continue
+ if not saw_first_non_system:
+ if role != 'user':
+ return False
+ saw_first_non_system = True
+ if role == 'user':
+ seen_user = True
+ elif role == 'assistant':
+ if not seen_user:
+ return False
+ seen_assistant = True
+ elif role == 'tool':
+ if is_agent:
+ if not seen_assistant:
+ return False
+ else:
+ prev = messages[i - 1] if i > 0 else None
+ if not isinstance(prev, dict):
+ return False
+ prev_role = prev.get('role')
+ if prev_role not in ('assistant', 'tool'):
+ return False
+ if prev_role == 'assistant' and not normalize_tool_calls(prev):
+ return False
+ return True
+
+
+def check_tool_matching(messages: List[Dict[str, Any]], is_agent: bool, cfg: dict) -> bool:
+ """Tool_call_id matching.
+
+ Non-agent: bidirectional strict equality (every call has response AND vice versa).
+ Agent: forward-only (every tool message must reference an existing call).
+ """
+ all_call_ids: set = set()
+ i = 0
+ while i < len(messages):
+ m = messages[i]
+ if not isinstance(m, dict) or m.get('role') != 'assistant':
+ i += 1
+ continue
+ norm_tcs = normalize_tool_calls(m)
+ if not norm_tcs:
+ i += 1
+ continue
+ expected_ids = {tc['id'] for tc in norm_tcs if isinstance(tc, dict) and tc.get('id')}
+ if not expected_ids:
+ i += 1
+ continue
+ all_call_ids.update(expected_ids)
+ actual_ids: set = set()
+ j = i + 1
+ while j < len(messages):
+ nxt = messages[j]
+ if not isinstance(nxt, dict) or nxt.get('role') != 'tool':
+ break
+ tid = nxt.get('tool_call_id')
+ if tid:
+ actual_ids.add(tid)
+ j += 1
+ if not is_agent:
+ if actual_ids != expected_ids:
+ return False
+ i = j
+ # Agent forward check: every tool message's tool_call_id must exist in some assistant's tool_calls.
+ if is_agent and all_call_ids:
+ for m in messages:
+ if isinstance(m, dict) and m.get('role') == 'tool':
+ tid = m.get('tool_call_id')
+ if tid and tid not in all_call_ids:
+ return False
+ return True
+
+
+def check_content_integrity(messages: List[Dict[str, Any]], is_agent: bool, cfg: dict) -> bool:
+ """Min turns, max length, duplicate detection, tool_calls structural validity."""
+ min_turns = cfg.get('min_turns', 2)
+ max_msg_chars = cfg.get('max_msg_chars', 50000)
+ user_count = 0
+ assistant_count = 0
+ for i, m in enumerate(messages):
+ if not isinstance(m, dict):
+ return False
+ role = m.get('role')
+ content = msg_content_text(m)
+ norm_tcs = normalize_tool_calls(m)
+ if role == 'user':
+ user_count += 1
+ elif role == 'assistant':
+ assistant_count += 1
+ if not content.strip() and not norm_tcs:
+ return False
+ elif role == 'system' and not content.strip():
+ return False
+ if content and len(content) > max_msg_chars:
+ return False
+ if norm_tcs is not None:
+ for tc in norm_tcs:
+ func = tc.get('function')
+ name = func.get('name', '') if isinstance(func, dict) else ''
+ if not name or not _IDENTIFIER_RE.match(name):
+ return False
+ args = func.get('arguments') if isinstance(func, dict) else None
+ if isinstance(args, str):
+ try:
+ json.loads(args)
+ except (ValueError, json.JSONDecodeError):
+ return False
+ # Consecutive-duplicate detection — skip tool messages and messages carrying REAL tool_calls.
+ if i > 0 and role != 'tool' and norm_tcs is None and content:
+ prev = messages[i - 1]
+ if (isinstance(prev, dict) and prev.get('role') == role and normalize_tool_calls(prev) is None
+ and msg_content_text(prev) == content):
+ return False
+ if user_count < 1 or assistant_count < 1:
+ return False
+ if (user_count + assistant_count) < min_turns:
+ return False
+ return True
+
+
+def check_lang_match(messages: List[Dict[str, Any]], is_agent: bool, cfg: dict) -> bool:
+ """False if user is CJK-dominant but assistant is pure Latin (or vice versa)."""
+ cjk_threshold = 0.3
+ mismatch_threshold = 0.02
+ user_text = ''
+ asst_text = ''
+ for m in messages:
+ if not isinstance(m, dict):
+ continue
+ if m.get('role') == 'user':
+ user_text += msg_content_text(m)
+ elif m.get('role') == 'assistant':
+ asst_text += msg_content_text(m)
+ if len(asst_text) < 50:
+ return True
+ user_cjk = cjk_ratio(user_text)
+ asst_cjk = cjk_ratio(asst_text)
+ if user_cjk >= cjk_threshold and asst_cjk < mismatch_threshold:
+ return False
+ if user_cjk < mismatch_threshold and asst_cjk >= cjk_threshold:
+ return False
+ return True
+
+
+def check_agent_min_visible(messages: List[Dict[str, Any]], is_agent: bool, cfg: dict) -> bool:
+ """Agent rows must have minimum visible text across all assistant turns."""
+ if not is_agent:
+ return True
+ min_chars = cfg.get('min_agent_visible_chars', 200)
+ if min_chars <= 0:
+ return True
+ total = 0
+ for m in messages:
+ if isinstance(m, dict) and m.get('role') == 'assistant':
+ total += len(msg_content_text(m).strip())
+ rc = m.get('reasoning_content')
+ if isinstance(rc, str):
+ total += len(rc.strip())
+ return total >= min_chars
+
+
+def check_sensitive_words(messages: List[Dict[str, Any]], is_agent: bool, cfg: dict) -> bool:
+ """False if any message content matches the sensitive-word regex."""
+ regex = cfg.get('_sensitive_re')
+ if not regex:
+ return True
+ return not any(regex.search(msg_content_text(m)) for m in messages if isinstance(m, dict))
+
+
+# ══════════════════════════════════════════════════════════════════════════════
+# Filter class
+# ══════════════════════════════════════════════════════════════════════════════
+
+_DEFAULT_CHECKS = [
+ ('role_order', check_role_order),
+ ('tool_matching', check_tool_matching),
+ ('content_integrity', check_content_integrity),
+ ('lang_match', check_lang_match),
+ ('agent_min_visible', check_agent_min_visible),
+ ('sensitive_words', check_sensitive_words),
+]
+
+
+class MessageSanityFilter(Preprocessor):
+ """Structural and content sanity filter for messages-format datasets.
+
+ Each check is a named function returning True (pass) or False (drop), and
+ is individually enable-able via the constructor flags.
+ """
+
+ def __init__(
+ self,
+ check_role_order: bool = True,
+ check_tool_matching: bool = True,
+ check_content_integrity: bool = True,
+ check_lang_match: bool = True,
+ check_agent_min_visible: bool = True,
+ trim_to_assistant: bool = True,
+ filter_sensitive: bool = True,
+ sensitive_words_file: Optional[str] = None,
+ extra_sensitive_words: Optional[List[str]] = None,
+ min_turns: int = 2,
+ max_msg_chars: int = 80000,
+ min_agent_visible_chars: int = 50,
+ ) -> None:
+ super().__init__()
+ self._trim = trim_to_assistant
+
+ words = load_sensitive_words(sensitive_words_file) if sensitive_words_file else set()
+ if extra_sensitive_words:
+ words.update(w.strip() for w in extra_sensitive_words if w and w.strip())
+
+ self._cfg: Dict[str, Any] = {
+ 'min_turns': min_turns,
+ 'max_msg_chars': max_msg_chars,
+ 'min_agent_visible_chars': min_agent_visible_chars,
+ '_sensitive_re': build_sensitive_regex(words),
+ }
+
+ enabled = {
+ 'role_order': check_role_order,
+ 'tool_matching': check_tool_matching,
+ 'content_integrity': check_content_integrity,
+ 'lang_match': check_lang_match,
+ 'agent_min_visible': check_agent_min_visible,
+ 'sensitive_words': filter_sensitive,
+ }
+ self._checks = [(name, fn) for name, fn in _DEFAULT_CHECKS if enabled.get(name, True)]
+
+ def __call__(self, rows) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
+ rows = self.map_col_to_row(rows)
+ out: List[Dict[str, Any]] = []
+ dropped: List[Dict[str, Any]] = []
+ for row in rows:
+ messages = row.get('messages')
+ if not isinstance(messages, list) or not messages:
+ dropped.append(dict(row, drop_reason='invalid_messages'))
+ continue
+ is_agent = is_agent_row(messages)
+
+ normalized = consolidate_system_messages(messages)
+ if normalized is not messages:
+ messages = normalized
+ row = dict(row, messages=messages)
+
+ if self._trim:
+ messages = trim_to_last_assistant(messages)
+ if not messages:
+ dropped.append(dict(row, drop_reason='no_assistant'))
+ continue
+ row = dict(row, messages=messages)
+
+ reason = self._run_checks(messages, is_agent)
+ if reason is None:
+ out.append(row)
+ else:
+ dropped.append(dict(row, drop_reason=reason))
+ return out, dropped
+
+ def _run_checks(self, messages: List[Dict[str, Any]], is_agent: bool) -> Optional[str]:
+ for name, fn in self._checks:
+ if not fn(messages, is_agent, self._cfg):
+ return name
+ return None
diff --git a/src/twinkle_agentic/preprocessor/model_filter.py b/src/twinkle_agentic/preprocessor/model_filter.py
new file mode 100644
index 000000000..fe238b1ed
--- /dev/null
+++ b/src/twinkle_agentic/preprocessor/model_filter.py
@@ -0,0 +1,40 @@
+import re
+from typing import Any, Dict, List, Optional, Sequence, Tuple
+
+from twinkle.preprocessor import Preprocessor
+
+# Each entry is the discriminating prefix only; a shared variant tail is appended uniformly
+# so suffixes like -Instruct, -Thinking-2507, -Distill-Qwen-7B, -Air are accepted everywhere.
+_DEFAULT_PATTERNS = [
+ r'minimax/minimax-m[23][\d.]*',
+ r'opengvlab/internvl[\d._]+-2\d{2}b',
+ r'qwen/qwen3[\d.]*-[123]\d{2}b(?:-a\d+b)?',
+ r'qwen/qwen3-coder',
+ r'xiaomimimo/mimo-v[\d.]+',
+ r'(?:zhipuai|z-ai)/glm-[56][\d.]*',
+ r'deepseek-ai/deepseek-(?:r1|v[34])',
+ r'moonshotai/kimi',
+ r'stepfun-ai/step',
+]
+
+# Hyphen MUST be inside the class (e.g. Kimi-K2-Instruct), '*' allows bare-prefix matches.
+_VARIANT_TAIL = r'[-\w.]*'
+
+
+class ModelFilter(Preprocessor):
+ """Keep only rows whose model_id matches an allowed family (case-insensitive)."""
+
+ def __init__(self, patterns: Optional[Sequence[str]] = None, field: str = 'model_id'):
+ self._field = field
+ pats = patterns if patterns is not None else _DEFAULT_PATTERNS
+ self._re = re.compile('|'.join(f'(?:{p}{_VARIANT_TAIL})' for p in pats), re.IGNORECASE)
+
+ def __call__(self, rows: List[Dict[str, Any]]) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
+ rows = self.map_col_to_row(rows)
+ kept, dropped = [], []
+ for r in rows:
+ if self._re.fullmatch(r.get(self._field) or ''):
+ kept.append(r)
+ else:
+ dropped.append(dict(r, drop_reason='model_not_allowed'))
+ return kept, dropped
diff --git a/src/twinkle_agentic/preprocessor/pii_presidio_filter.py b/src/twinkle_agentic/preprocessor/pii_presidio_filter.py
new file mode 100644
index 000000000..9dafd061f
--- /dev/null
+++ b/src/twinkle_agentic/preprocessor/pii_presidio_filter.py
@@ -0,0 +1,404 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+"""Multi-language, multi-country PII rewriter via Presidio + spaCy NER + Faker.
+
+Coverage:
+ Names/Locations/Orgs: PERSON, LOCATION, ORGANIZATION (NER, en + zh)
+ Network/contact: EMAIL_ADDRESS, IP_ADDRESS, URL
+ Finance: CREDIT_CARD (Luhn), IBAN_CODE, CRYPTO, US_BANK_NUMBER, CN_BANK
+ Government IDs: US_SSN, US_ITIN, US_PASSPORT, US_DRIVER_LICENSE,
+ UK_NHS, UK_NINO, IN_AADHAAR, IN_PAN, AU_ABN, SG_NRIC,
+ IT_FISCAL_CODE, ES_NIF, ES_NIE, CN_ID
+ Phones: PHONE_NUMBER (libphonenumber), CN_PHONE, CN_LANDLINE
+ Other: DATE_TIME, MEDICAL_LICENSE, NRP
+
+Strategies (per entity, configurable via ``entity_strategy``):
+ ``mask`` -> keep edges, mask middle (numeric IDs/cards)
+ ``replace`` -> Faker fake value (names/emails — preserves text fluency)
+ ``redact`` -> drop the span entirely
+ ``hash`` -> sha256 prefix (deterministic, deidentified, joinable)
+
+Consistency: same source value → same fake value within a batch (and optionally
+across batches via ``persistent_consistency``), so dialogues stay coherent.
+"""
+import hashlib
+import threading
+from enum import Enum
+from typing import Any, Dict, List, Optional, Sequence, Tuple # noqa: F401
+
+from twinkle.preprocessor import Preprocessor
+
+# ─── Validators ─────────────────────────────────────────────────────────────────
+
+_ID_WEIGHTS = (7, 9, 10, 5, 8, 4, 2, 1, 6, 3, 7, 9, 10, 5, 8, 4, 2)
+_ID_CHECKS = '10X98765432'
+
+
+def _is_valid_cn_id(s: str) -> bool:
+ if len(s) != 18 or not s[:17].isdigit():
+ return False
+ total = sum(int(s[i]) * _ID_WEIGHTS[i] for i in range(17))
+ return _ID_CHECKS[total % 11] == s[17].upper()
+
+
+def _is_valid_luhn(s: str) -> bool:
+ digits = [int(c) for c in s if c.isdigit()]
+ if len(digits) < 13:
+ return False
+ checksum = 0
+ for i, d in enumerate(reversed(digits)):
+ if i % 2 == 1:
+ d = d * 2 - 9 if d * 2 > 9 else d * 2
+ checksum += d
+ return checksum % 10 == 0
+
+
+# ─── Replacement primitives ─────────────────────────────────────────────────────
+
+
+class Strategy(str, Enum):
+ MASK = 'mask'
+ REPLACE = 'replace'
+ REDACT = 'redact'
+ HASH = 'hash'
+
+ @classmethod
+ def coerce(cls, value: 'str | Strategy') -> 'Strategy':
+ try:
+ return cls(value) if not isinstance(value, cls) else value
+ except ValueError as e:
+ allowed = ', '.join(s.value for s in cls)
+ raise ValueError(f'Unknown strategy {value!r}. Allowed: {allowed}') from e
+
+
+def _mask_keep_edges(s: str, head: int = 3, tail: int = 4, ch: str = '*') -> str:
+ if len(s) <= head + tail:
+ return ch * len(s)
+ return s[:head] + ch * (len(s) - head - tail) + s[-tail:]
+
+
+def _hash_short(s: str, salt: str = '') -> str:
+ return hashlib.sha256((salt + s).encode('utf-8')).hexdigest()[:12]
+
+
+# ─── Faker dispatcher (per-instance, thread-safe) ───────────────────────────────
+
+
+class FakerProvider:
+ """Maps Presidio entity_type → Faker provider call, with lang-locale cache."""
+
+ _PROVIDER: Dict[str, Any] = {
+ 'PERSON': lambda f: f.name(),
+ 'LOCATION': lambda f: f.city(),
+ 'ORGANIZATION': lambda f: f.company(),
+ 'EMAIL_ADDRESS': lambda f: f.email(),
+ 'PHONE_NUMBER': lambda f: f.phone_number(),
+ 'CN_PHONE': lambda f: f.phone_number(),
+ 'CN_LANDLINE': lambda f: f.phone_number(),
+ 'IP_ADDRESS': lambda f: f.ipv4(),
+ 'URL': lambda f: f.url(),
+ 'IBAN_CODE': lambda f: f.iban(),
+ 'CREDIT_CARD': lambda f: f.credit_card_number(),
+ 'US_BANK_NUMBER': lambda f: f.credit_card_number(),
+ 'CN_BANK': lambda f: f.credit_card_number(),
+ 'CRYPTO': lambda f: f.sha256()[:34],
+ 'DATE_TIME': lambda f: str(f.date()),
+ }
+ _LOCALE: Dict[str, str] = {'zh': 'zh_CN', 'en': 'en_US'}
+
+ def __init__(self) -> None:
+ self._cache: Dict[str, Any] = {}
+ self._lock = threading.Lock()
+
+ def faker(self, lang: str):
+ if lang not in self._cache:
+ with self._lock:
+ if lang not in self._cache:
+ from faker import Faker
+ self._cache[lang] = Faker(self._LOCALE.get(lang, 'en_US'))
+ return self._cache[lang]
+
+ def fake_for(self, entity: str, original: str, lang: str) -> str:
+ f = self.faker(lang)
+ provider = self._PROVIDER.get(entity.upper())
+ if provider is not None:
+ return provider(f)
+ # Same-length opaque alnum for unknown entities; downstream length checks survive.
+ return f.bothify('?' * 2 + '#' * max(2, len(original) - 2)).upper()
+
+
+# ─── CN recognizers (module-level so they introspect/pickle cleanly) ────────────
+
+
+def _cn_recognizer_classes():
+ """Lazy-imported once; PatternRecognizer requires presidio_analyzer at import time."""
+ from presidio_analyzer import Pattern, PatternRecognizer
+
+ class CNIDRecognizer(PatternRecognizer):
+
+ def validate_result(self, pattern_text: str) -> bool:
+ return _is_valid_cn_id(pattern_text)
+
+ class CNBankRecognizer(PatternRecognizer):
+
+ def validate_result(self, pattern_text: str) -> bool:
+ return _is_valid_luhn(pattern_text)
+
+ return Pattern, PatternRecognizer, CNIDRecognizer, CNBankRecognizer
+
+
+def _build_cn_recognizers(languages: Sequence[str]) -> List[Any]:
+ Pattern, PatternRecognizer, CNIDRecognizer, CNBankRecognizer = _cn_recognizer_classes()
+ specs = [
+ ('CN_ID', r'(? None:
+ super().__init__()
+ self._require_deps()
+
+ self._languages: List[str] = list(languages)
+ self._spacy_models = dict(self.DEFAULT_SPACY_MODELS)
+ if spacy_models:
+ self._spacy_models.update(spacy_models)
+ for lang in self._languages:
+ if lang not in self._spacy_models:
+ raise ValueError(f'No spaCy model configured for language {lang!r}')
+
+ self._strategy = {k: Strategy.coerce(v) for k, v in self.DEFAULT_ENTITY_STRATEGY.items()}
+ if entity_strategy:
+ self._strategy.update({k.upper(): Strategy.coerce(v) for k, v in entity_strategy.items()})
+ self._default_strategy = Strategy.coerce(default_strategy)
+
+ self._score_threshold = score_threshold
+ self._roles = set(roles)
+ self._consistency = consistency
+ self._persistent_consistency = persistent_consistency
+ self._hash_salt = hash_salt
+ self._record_counts = record_counts
+
+ self._faker = FakerProvider()
+ self._persistent_map: Dict[Tuple[str, str], str] = {}
+ self._analyzer = self._build_analyzer()
+ # Restrict analyze() to entities we act on AND that the registry actually supports per language;
+ # avoids 'Entity X doesn't have the corresponding recognizer in language : Y' warnings.
+ wanted = {e for e in self._strategy if e not in self.IGNORED_ENTITIES}
+ registry = self._analyzer.registry
+ self._allowed_entities: Dict[str, List[str]] = {
+ lang: sorted(wanted & set(registry.get_supported_entities(languages=[lang])))
+ for lang in self._languages
+ }
+
+ # ── construction ────────────────────────────────────────────────────────
+
+ @classmethod
+ def _require_deps(cls) -> None:
+ try:
+ import faker # noqa: F401
+ import presidio_analyzer # noqa: F401
+ import presidio_anonymizer # noqa: F401
+ import spacy # noqa: F401
+ except ImportError as e:
+ raise ImportError(f'{e}. {cls.INSTALL_HINT}') from e
+
+ def _build_analyzer(self):
+ from presidio_analyzer import AnalyzerEngine, RecognizerRegistry
+ from presidio_analyzer.nlp_engine import NlpEngineProvider
+
+ nlp_conf = {
+ 'nlp_engine_name': 'spacy',
+ 'models': [{
+ 'lang_code': lang,
+ 'model_name': self._spacy_models[lang]
+ } for lang in self._languages],
+ }
+ nlp_engine = NlpEngineProvider(nlp_configuration=nlp_conf).create_engine()
+ # NER pipe is the heaviest spaCy component and we discard all NER entities; disable to save 2-4x latency.
+ for nlp in getattr(nlp_engine, 'nlp', {}).values():
+ for pipe in ('ner', 'parser', 'attribute_ruler', 'lemmatizer'):
+ if pipe in nlp.pipe_names:
+ nlp.disable_pipe(pipe)
+ registry = RecognizerRegistry(supported_languages=self._languages)
+ registry.load_predefined_recognizers(languages=self._languages, nlp_engine=nlp_engine)
+ for r in _build_cn_recognizers(self._languages):
+ registry.add_recognizer(r)
+ return AnalyzerEngine(registry=registry, nlp_engine=nlp_engine, supported_languages=self._languages)
+
+ # ── language routing ────────────────────────────────────────────────────
+
+ def _resolve_language(self, text: str) -> str:
+ cjk = sum(1 for c in text if '\u4e00' <= c <= '\u9fff')
+ guess = 'zh' if cjk / max(1, len(text)) > self.CJK_LANG_THRESHOLD else 'en'
+ return guess if guess in self._languages else self._languages[0]
+
+ # ── replacement ─────────────────────────────────────────────────────────
+
+ def _replacement_for(
+ self,
+ entity: str,
+ original: str,
+ lang: str,
+ local_map: Dict[Tuple[str, str], str],
+ ) -> str:
+ strategy = self._strategy.get(entity.upper(), self._default_strategy)
+ if strategy is Strategy.REDACT:
+ return ''
+ if strategy is Strategy.HASH:
+ return f'<{entity}:{_hash_short(original, self._hash_salt)}>'
+ if strategy is Strategy.MASK:
+ return _mask_keep_edges(original)
+ # Strategy.REPLACE — Faker with optional consistency cache.
+ if not self._consistency:
+ return self._faker.fake_for(entity, original, lang)
+ cache = self._persistent_map if self._persistent_consistency else local_map
+ key = (entity.upper(), original)
+ if key not in cache:
+ cache[key] = self._faker.fake_for(entity, original, lang)
+ return cache[key]
+
+ @classmethod
+ def _min_length(cls, entity: str) -> int:
+ return cls.DEFAULT_MIN_LENGTH.get(entity.upper(), cls.MIN_LENGTH_FALLBACK)
+
+ # ── span dedup ──────────────────────────────────────────────────────────
+
+ @staticmethod
+ def _dedupe_overlaps(results: List[Any]) -> List[Any]:
+ """Greedy interval scheduling: keep highest-score span per overlapping region."""
+ ordered = sorted(results, key=lambda r: (-r.score, -(r.end - r.start), r.start))
+ kept: List[Any] = []
+ for r in ordered:
+ if any(r.start < k.end and r.end > k.start for k in kept):
+ continue
+ kept.append(r)
+ return kept
+
+ # ── core scrubbing ──────────────────────────────────────────────────────
+
+ def _scrub_text(
+ self,
+ text: str,
+ local_map: Dict[Tuple[str, str], str],
+ ) -> Tuple[str, Dict[str, int]]:
+ if not text:
+ return text, {}
+ lang = self._resolve_language(text)
+ results = self._analyzer.analyze(
+ text=text, language=lang, entities=self._allowed_entities.get(lang), score_threshold=self._score_threshold)
+ if not results:
+ return text, {}
+
+ spans = self._dedupe_overlaps(results)
+ spans = [r for r in spans if r.entity_type.upper() not in self.IGNORED_ENTITIES]
+ spans = [r for r in spans if (r.end - r.start) >= self._min_length(r.entity_type)]
+ if not spans:
+ return text, {}
+ # Reverse-sort so in-place index slicing stays valid.
+ spans.sort(key=lambda r: r.start, reverse=True)
+ out = text
+ hits: Dict[str, int] = {}
+ for r in spans:
+ original = out[r.start:r.end]
+ replacement = self._replacement_for(r.entity_type, original, lang, local_map)
+ out = out[:r.start] + replacement + out[r.end:]
+ hits[r.entity_type] = hits.get(r.entity_type, 0) + 1
+ return out, hits
+
+ def _scrub_row(
+ self,
+ row: Dict[str, Any],
+ local_map: Dict[Tuple[str, str], str],
+ ) -> Dict[str, int]:
+ row_hits: Dict[str, int] = {}
+ for m in row.get('messages') or []:
+ if not isinstance(m, dict) or m.get('role') not in self._roles:
+ continue
+ content = m.get('content')
+ if not isinstance(content, str) or not content:
+ continue
+ new_content, hits = self._scrub_text(content, local_map)
+ if hits:
+ m['content'] = new_content
+ for k, v in hits.items():
+ row_hits[k] = row_hits.get(k, 0) + v
+ return row_hits
+
+ def __call__(self, rows) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
+ local_map: Dict[Tuple[str, str], str] = {}
+ for row in rows:
+ row_hits = self._scrub_row(row, local_map)
+ if self._record_counts:
+ if row_hits:
+ row['_pii_hits'] = row_hits
+ else:
+ row.pop('_pii_hits', None)
+ return rows, []
diff --git a/src/twinkle_agentic/preprocessor/refuse_filter.py b/src/twinkle_agentic/preprocessor/refuse_filter.py
new file mode 100644
index 000000000..434497eba
--- /dev/null
+++ b/src/twinkle_agentic/preprocessor/refuse_filter.py
@@ -0,0 +1,155 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+"""Drop rows whose first assistant reply is a self-referential refusal."""
+import re
+from typing import Any, Dict, List, Tuple
+
+from twinkle.preprocessor import Preprocessor
+
+# ── English refusal patterns ──────────────────────────────────────────────────
+#
+# Design principle: require a SELF-REFERENTIAL subject (I/we) + a task-directed
+# inability/refusal verb. This avoids false positives on:
+# "I cannot stress enough…" "I cannot find the bug…"
+# "The API cannot handle null" "You cannot use this without auth"
+
+# I/we + modal inability + task verb.
+_EN_CORE = re.compile(
+ r'\b(i|we)\b.{0,25}\b('
+ r"can'?t|cannot|am\s+not\s+able|are\s+not\s+able|"
+ r"won'?t|will\s+not|am\s+unable|are\s+unable|"
+ r'must\s+decline|have\s+to\s+decline|'
+ r'decline\s+to|refuse\s+to|'
+ r'am\s+not\s+(allowed|permitted|authorized|comfortable)\s+to|'
+ r'are\s+not\s+(allowed|permitted|authorized)'
+ r')\b.{0,60}\b('
+ r'help|assist|answer|respond|provide|generate|create|produce|'
+ r'fulfill|comply|address|process|complete|handle|discuss|support'
+ r')\b',
+ re.IGNORECASE | re.DOTALL,
+)
+
+# Apology opener + refusal: "I'm sorry, but I can't…" / "Unfortunately I cannot…"
+_EN_APOLOGY = re.compile(
+ r'\b(i\'?m\s+sorry|i\s+apologize|unfortunately|i\s+regret)\b.{0,80}'
+ r'\b(can\'?t|cannot|unable|won\'?t|will\s+not|must\s+decline|have\s+to\s+decline|'
+ r'not\s+(allowed|able|comfortable|appropriate))\b',
+ re.IGNORECASE | re.DOTALL,
+)
+
+# Policy / content violation signal.
+_EN_POLICY = re.compile(
+ r'\b(this|that|your|the)\s+(request|question|prompt|content|topic|task)\b.{0,60}'
+ r'\b(violates?|goes?\s+against|is\s+(inappropriate|not\s+(appropriate|allowed|permitted|'
+ r'something\s+i\s+can)))\b',
+ re.IGNORECASE | re.DOTALL,
+)
+
+# Standalone declarative refusals. The trailing task-verb gate attaches only to the
+# "as an ai…" alternative — the first three alternatives are sufficiently specific
+# on their own ("I must decline", "this falls outside what I", "I refuse to").
+_EN_STANDALONE = re.compile(
+ r'\b(i|we)\s+(must|have\s+to|am\s+going\s+to|need\s+to)\s+(decline|refuse)\b'
+ r'|\b(i|we)\s+(decline|refuse)\s+(this|your|to)\b'
+ r'|\bthis\s+(falls\s+outside|is\s+outside|is\s+beyond)\s+(what\s+i|my)\b'
+ r'|\bas\s+an\s+ai[,.]?\s+i\s+(can\'?t|cannot|am\s+not\s+able|won\'?t)\b'
+ r'.{0,40}\b(help|assist|answer|respond|provide|generate|create|fulfill|comply|'
+ r'address|process|complete|handle|discuss|support)\b',
+ re.IGNORECASE | re.DOTALL, # DOTALL: refusal phrase + task verb may straddle a newline.
+)
+
+# ── Chinese refusal patterns ──────────────────────────────────────────────────
+
+_ZH_APOLOGY = re.compile(
+ r'(非常|十分|很|极为)?抱歉[,,。\s]{0,5}.{0,40}(无法|不能|不可以|不便|没有办法)|'
+ r'对不起[,,。\s]{0,5}.{0,40}(无法|不能|不可以|不便)',
+ re.UNICODE | re.DOTALL,
+)
+
+_ZH_SELF = re.compile(
+ r'我(无法|不能|不可以|没有办法|不便|不适合|不被允许|不被授权)'
+ r'.{0,30}(帮|回答|提供|生成|处理|协助|完成|执行|回复|解答|协|帮助)',
+ re.UNICODE | re.DOTALL,
+)
+
+_ZH_VIOLATION = re.compile(
+ r'(您的|这个|该)(请求|问题|内容|话题).{0,20}(违反|不当|不合适|超出了?我)',
+ re.UNICODE | re.DOTALL,
+)
+
+_ZH_AI_ID = re.compile(
+ r'作为(AI|人工智能|语言模型|大模型)[,,].{0,30}(无法|不能|不便|不应该|不适合)'
+ r'.{0,20}(帮|回答|提供|生成|处理|协助|完成|执行|回复|解答|讨论|参与|评论|创作|输出)',
+ re.UNICODE | re.DOTALL,
+)
+
+# ── Japanese refusal patterns ─────────────────────────────────────────────────
+
+_JA_PATTERNS = (
+ re.compile(r'(申し訳|恐れ入り)ます(が|けれど).{0,40}(できません|お答えできません|対応できません)', re.UNICODE | re.DOTALL),
+ re.compile(r'(回答|対応|お答え)(する|いたす)ことは?できません', re.UNICODE),
+ re.compile(r'ご要望には?お(応え|答え)できません', re.UNICODE),
+ re.compile(r'(その|この)(リクエスト|質問|依頼).{0,20}(お断り|辞退|対応できません)', re.UNICODE | re.DOTALL),
+)
+
+# ── Korean refusal patterns ───────────────────────────────────────────────────
+
+_KO_PATTERNS = (
+ re.compile(r'(죄송하지만|유감스럽게도).{0,40}(드릴 수 없|없습니다|못합니다)', re.UNICODE | re.DOTALL),
+ re.compile(r'(답변|도움|처리|제공)(드리기|하기)\s*(어렵|불가|할 수 없)', re.UNICODE),
+ re.compile(r'(요청|질문|내용).{0,20}(거절|거부|응할 수 없)', re.UNICODE | re.DOTALL),
+)
+
+_ALL_PATTERNS = (_EN_CORE, _EN_APOLOGY, _EN_POLICY, _EN_STANDALONE, _ZH_APOLOGY, _ZH_SELF, _ZH_VIOLATION,
+ _ZH_AI_ID) + _JA_PATTERNS + _KO_PATTERNS
+
+# Strip both `` and `` blocks before scanning so reasoning-trace
+# refusal-like phrasing doesn't get mistaken for a real user-facing refusal.
+_THINK_BLOCK_RE = re.compile(r'.*?\s*', re.DOTALL | re.IGNORECASE)
+
+
+# ── Helpers ──────────────────────────────────────────────────────────────────
+
+
+def _text(content: Any) -> str:
+ """Project content to plain text. Multimodal list → concat of text parts only."""
+ if isinstance(content, list):
+ return ''.join(p.get('text', '') for p in content if isinstance(p, dict) and p.get('type') == 'text')
+ return content if isinstance(content, str) else ''
+
+
+def _is_refusal(text: str, check_window: int) -> bool:
+ """Return True if the text contains a self-referential refusal signal."""
+ window = text[:check_window]
+ return any(p.search(window) for p in _ALL_PATTERNS)
+
+
+# ── Preprocessor ─────────────────────────────────────────────────────────────
+
+
+class RefuseFilter(Preprocessor):
+
+ def __init__(self, check_window: int = 600) -> None:
+ super().__init__()
+ self._check_window = check_window
+
+ def _is_refusal_row(self, row: Dict[str, Any]) -> bool:
+ messages = row.get('messages') or []
+ first_asst = next(
+ (m for m in messages if isinstance(m, dict) and m.get('role') == 'assistant'),
+ None,
+ )
+ if first_asst is None:
+ return False
+ reply = _THINK_BLOCK_RE.sub('', _text(first_asst.get('content'))).strip()
+ return bool(reply) and _is_refusal(reply, self._check_window)
+
+ def __call__(self, rows) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
+ rows = self.map_col_to_row(rows)
+ out: List[Dict[str, Any]] = []
+ dropped: List[Dict[str, Any]] = []
+ for row in rows:
+ if self._is_refusal_row(row):
+ dropped.append(dict(row, drop_reason='refusal'))
+ else:
+ out.append(row)
+ return out, dropped
diff --git a/src/twinkle_agentic/preprocessor/score_filter.py b/src/twinkle_agentic/preprocessor/score_filter.py
new file mode 100644
index 000000000..17688c31f
--- /dev/null
+++ b/src/twinkle_agentic/preprocessor/score_filter.py
@@ -0,0 +1,801 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+"""Pluggable per-round scorer/filter for SFT key rounds.
+
+Architecture:
+
+ ScoreFilter(backend, scorers=[...])
+ ├── pre-fetches logprobs once if any scorer requires them
+ ├── runs each Scorer in order, collecting ScoreResult per round
+ ├── trace dump (per-round JSON, multi_turn-style)
+ └── AND aggregation: a round is kept iff every scorer returns passed=True.
+
+Built-in scorers (each is its own class):
+ ChrMinScorer chr_dist_min_pos. LOW = hard = keep.
+ SIFDScorer IFD / S-IFD-50 / S-IFD-75. Default observe-only.
+ PassNScorer Self-rollouts judged by an LLM. extras carry rollouts/verdicts.
+ ParaphraseScorer chr_min over a model paraphrase produced under GT injection.
+
+Decoupling:
+ * key_rounds missing/empty → every assistant turn becomes a candidate round.
+ * intents=None → no intent-based gating (all rounds processed).
+"""
+import json
+import os
+import re
+import time
+from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple
+
+from twinkle.data_format import pack_value, user_data_get
+from twinkle.preprocessor import Preprocessor
+from twinkle.template import Template
+from twinkle.utils import get_logger
+from ..data_format import RoundContext, Scorer, ScoreResult
+from .llm_backend import LLMBackend
+from .utils import _chr_min_distinct, _ifd_family_metrics, _lp_to_jsonable, _pad_batch, _to_int_list
+
+logger = get_logger(only_local_master=False)
+
+_MIN_RESPONSE_TOKENS = 5
+
+
+def _user_data_lookup(user_data: Any, key: str) -> Any:
+ """Pull a value by key from packed user_data; returns the JSON-decoded value."""
+ return user_data_get(user_data, key)
+
+# ============================================================================
+# Built-in scorers
+# ============================================================================
+
+
+class ChrMinScorer:
+ """chr_dist_min_pos. Dual-threshold: keep samples in [low, high)."""
+ name = 'chr_min'
+ requires_logprobs = True
+
+ def __init__(self, threshold: float = 0.47):
+ self._threshold = float(threshold)
+
+ def score(self, contexts: List[RoundContext]) -> List[ScoreResult]:
+ out: List[ScoreResult] = []
+ for ctx in contexts:
+ cond_lp = ctx.features.get('cond_lp')
+ asst_lp = ctx.features.get('asst_lp')
+ score = _chr_min_distinct(
+ cond_lp,
+ asst_lp,
+ ctx.cond_ids,
+ ctx.asst_ids,
+ ctx.n_prompt,
+ )
+ passed = (score is None) or (score < self._threshold)
+ out.append(ScoreResult(
+ score=score,
+ passed=passed,
+ extras={'threshold': self._threshold},
+ ))
+ return out
+
+
+class SIFDScorer:
+ """IFD / S-IFD-50 / S-IFD-75. Observation-only by default."""
+ name = 'sifd'
+ requires_logprobs = True
+
+ def __init__(self, ifd_threshold: Optional[float] = None):
+ # If set, passed = (ifd >= threshold). HIGH IFD = hard = keep.
+ self._ifd_threshold = ifd_threshold
+
+ def score(self, contexts: List[RoundContext]) -> List[ScoreResult]:
+ out: List[ScoreResult] = []
+ for ctx in contexts:
+ cond_lp = ctx.features.get('cond_lp')
+ asst_lp = ctx.features.get('asst_lp')
+ fam = _ifd_family_metrics(cond_lp, asst_lp, ctx.cond_ids, ctx.asst_ids, ctx.n_prompt)
+ score = fam.get('ifd')
+ if self._ifd_threshold is None or score is None:
+ passed = True
+ else:
+ passed = score >= self._ifd_threshold
+ out.append(ScoreResult(score=score, passed=passed, extras=dict(fam)))
+ return out
+
+
+_JUDGE_SYSTEM_PROMPT = """\
+You are a strict but fair answer grader. Judge whether the [Model Answer] is acceptable based on the reference answer (Ground Truth).
+Evaluate the following three aspects; if any has a major issue, return FAIL:
+
+1. Computational/factual correctness: whether the final conclusion, numbers, and key factual statements match the reference answer;
+2. Reasoning/approach similarity: whether the solution path, key steps, and considered dimensions are close to the reference answer;
+ For open-ended questions (no single correct answer), assess whether the style, stance, and considered dimensions align with the reference answer;
+3. Completeness: the answer is not truncated, ends naturally, and covers all points of the question.
+
+First give a brief 1-3 sentence justification, then on the last line strictly output:
+PASS or FAIL""" # noqa
+
+
+class PassNScorer:
+ """Self-rollouts (n × per round) judged by an LLM."""
+ name = 'pass_n'
+ requires_logprobs = False
+
+ def __init__(
+ self,
+ backend: LLMBackend,
+ judge_api=None,
+ judge_model: Optional[str] = None,
+ judge_base_url: Optional[str] = None,
+ judge_api_key: Optional[str] = None,
+ judge_client_kwargs: Optional[Dict[str, Any]] = None,
+ n: int = 4,
+ min_pass: int = 0,
+ sample_temperature: float = 0.7,
+ sample_max_tokens: int = 4096,
+ judge_temperature: float = 0.0,
+ judge_max_tokens: int = 512,
+ judge_max_rollout_chars: int = 8000,
+ judge_max_workers: int = 8,
+ ):
+ self._backend = backend
+ self._judge_api = self._build_judge_api(judge_api, judge_model, judge_base_url, judge_api_key,
+ judge_client_kwargs)
+ self._n = max(1, int(n))
+ self._min_pass = int(min_pass)
+ self._sample_temperature = float(sample_temperature)
+ self._sample_max_tokens = int(sample_max_tokens)
+ self._judge_temperature = float(judge_temperature)
+ self._judge_max_tokens = int(judge_max_tokens)
+ self._judge_max_rollout_chars = int(judge_max_rollout_chars)
+ self._judge_max_workers = max(1, int(judge_max_workers))
+ if self._judge_api is None:
+ logger.warning('[PassNScorer] no judge_api configured; rollouts will be sampled '
+ 'without verdicts (every round trivially passes).')
+
+ @staticmethod
+ def _build_judge_api(api, model, base_url, api_key, client_kwargs):
+ if api is not None:
+ return api
+ if not model:
+ return None
+ from twinkle_agentic.protocol.openai import OpenAI as OpenAIAPI
+ return OpenAIAPI(model=model, api_key=api_key, base_url=base_url, client_kwargs=client_kwargs)
+
+ @staticmethod
+ def _extract_text_from_choice(choice: Any) -> str:
+ if not isinstance(choice, dict):
+ return ''
+ parts: List[str] = []
+ rc = choice.get('reasoning_content')
+ if isinstance(rc, str) and rc.strip():
+ parts.append(f'\n{rc.strip()}\n')
+ content = choice.get('content')
+ if isinstance(content, str) and content.strip():
+ parts.append(content.strip())
+ if parts:
+ return '\n\n'.join(parts)
+ return content if isinstance(content, str) else ''
+
+ @staticmethod
+ def _truncate(text: str, max_chars: int) -> str:
+ if not isinstance(text, str) or max_chars <= 0 or len(text) <= max_chars:
+ return text
+ head = max_chars * 2 // 3
+ tail = max_chars - head - 32
+ if tail <= 0:
+ return text[:max_chars]
+ return text[:head] + '\n\n...[truncated]...\n\n' + text[-tail:]
+
+ @staticmethod
+ def _parse_verdict(judge_text: str) -> Optional[bool]:
+ if not isinstance(judge_text, str):
+ return None
+ compact = ''.join(judge_text.upper().split())
+ has_pass = 'PASS' in compact
+ has_fail = 'FAIL' in compact
+ if has_pass and not has_fail:
+ return True
+ if has_fail and not has_pass:
+ return False
+ # Fallback: keyword scan in the tail (last 200 chars, post-compact).
+ tail = compact[-200:]
+ if 'PASS' in tail and 'FAIL' not in tail:
+ return True
+ if 'FAIL' in tail and 'PASS' not in tail:
+ return False
+ return None
+
+ def _judge_one(self, user_prompt: str, gt_text: str, rollout_text: str) -> Tuple[bool, str]:
+ if self._judge_api is None:
+ return True, '(no judge configured)'
+ if not rollout_text or not rollout_text.strip():
+ return False, '(empty rollout)'
+ from twinkle.data_format.sampling import SamplingParams
+ body = (f'[问题]\n{self._truncate(user_prompt, self._judge_max_rollout_chars)}\n\n'
+ f'[参考答案]\n{self._truncate(gt_text, self._judge_max_rollout_chars)}\n\n'
+ f'[模型回答]\n{self._truncate(rollout_text, self._judge_max_rollout_chars)}\n\n'
+ '请评分。')
+ trajectory = {
+ 'messages': [
+ {
+ 'role': 'system',
+ 'content': _JUDGE_SYSTEM_PROMPT
+ },
+ {
+ 'role': 'user',
+ 'content': body
+ },
+ ]
+ }
+ sp = SamplingParams(
+ temperature=self._judge_temperature,
+ max_tokens=self._judge_max_tokens,
+ num_samples=1,
+ )
+ # extra_body forwards `enable_thinking=False` so the judge skips CoT.
+ msg = self._judge_api(trajectory, sp, extra_body={'enable_thinking': False})
+ if isinstance(msg, list):
+ msg = msg[0] if msg else {}
+ text = msg.get('content', '') if isinstance(msg, dict) else str(msg)
+ text = text or ''
+ verdict = self._parse_verdict(text)
+ # Conservative default: ambiguous verdict → FAIL.
+ return bool(verdict) if verdict is not None else False, text
+
+ def score(self, contexts: List[RoundContext]) -> List[ScoreResult]:
+ if not contexts:
+ return []
+ ctx_msgs = [ctx.context_messages for ctx in contexts]
+ batched = self._backend.chat_batch(
+ ctx_msgs,
+ temperature=self._sample_temperature,
+ max_tokens=self._sample_max_tokens,
+ n=self._n,
+ ) or []
+
+ while len(batched) < len(contexts):
+ batched.append([])
+
+ from concurrent.futures import ThreadPoolExecutor
+ work: List[Tuple[int, int, str, str, str]] = []
+ for i, (ctx, choices) in enumerate(zip(contexts, batched)):
+ if not isinstance(choices, list):
+ continue
+ for r_i, choice in enumerate(choices):
+ rt = self._extract_text_from_choice(choice)
+ work.append((i, r_i, ctx.user_prompt, ctx.asst_text, rt))
+
+ verdict_by_round: Dict[int, List[Tuple[int, bool, str]]] = {}
+ if work and self._judge_api is not None:
+
+ def _do(item):
+ i, r_i, up, gt, rt = item
+ ok, raw = self._judge_one(up, gt, rt)
+ return i, r_i, ok, raw
+
+ with ThreadPoolExecutor(max_workers=self._judge_max_workers) as ex:
+ for i, r_i, ok, raw in ex.map(_do, work):
+ verdict_by_round.setdefault(i, []).append((r_i, ok, raw))
+
+ out: List[ScoreResult] = []
+ for i, (ctx, choices) in enumerate(zip(contexts, batched)):
+ rollouts = [{
+ 'rollout_idx': r_i,
+ 'content': self._extract_text_from_choice(c)
+ } for r_i, c in enumerate(choices or [])]
+ verdicts = sorted(verdict_by_round.get(i, []), key=lambda x: x[0])
+ judgments = [{'rollout_idx': r_i, 'passed': bool(p), 'judge_raw': raw} for r_i, p, raw in verdicts]
+ pass_count = sum(1 for _, p, _ in verdicts if p)
+ score = (pass_count / self._n) if rollouts else None
+ passed = pass_count >= self._min_pass
+ out.append(
+ ScoreResult(
+ score=score,
+ passed=passed,
+ extras={
+ 'pass_count': pass_count,
+ 'n_rollouts': len(rollouts),
+ 'rollouts': rollouts,
+ 'judgments': judgments,
+ 'min_pass': self._min_pass,
+ },
+ ))
+
+ scored = [r for r in out if r.score is not None]
+ if scored:
+ avg = sum(r.score for r in scored) / len(scored)
+ logger.info(f'[PassNScorer] graded {len(scored)}/{len(out)} rounds × {self._n} '
+ f'rollouts; avg pass-rate = {avg:.3f}')
+ return out
+
+
+class ParaphraseScorer:
+ """Generate a model paraphrase under GT injection, then re-score chr_min."""
+ name = 'paraphrase'
+ # Owns its own logprob fetch on the rewritten asst tokens.
+ requires_logprobs = False
+
+ def __init__(
+ self,
+ backend: LLMBackend,
+ template: Template,
+ chr_min_threshold: Optional[float] = None,
+ prompt_budget: int = 4096,
+ sample_temperature: float = 0.7,
+ sample_max_tokens: int = 4096,
+ max_prompt_tokens: int = 1024,
+ ):
+ self._backend = backend
+ self._template = template
+ self._threshold = chr_min_threshold
+ self._prompt_budget = int(prompt_budget)
+ self._sample_temperature = float(sample_temperature)
+ self._sample_max_tokens = int(sample_max_tokens)
+ self._max_prompt_tokens = int(max_prompt_tokens)
+
+ @staticmethod
+ def _inject_gt(context_messages, gt_text):
+ msgs = [dict(m) if isinstance(m, dict) else m for m in context_messages]
+ instr = f"""\
+Below is the reference answer to this question, for your reference only:
+
+
+{gt_text}
+
+
+Based on the reference answer above, please provide a complete answer to the preceding question in your own words and reasoning. Output your answer directly; do not repeat the reference answer verbatim.""" # noqa
+ if msgs and isinstance(msgs[-1], dict) and msgs[-1].get('role') == 'user':
+ last = dict(msgs[-1])
+ last['content'] = (last.get('content') or '') + '\n\n' + instr
+ msgs[-1] = last
+ else:
+ msgs.append({'role': 'user', 'content': instr})
+ return msgs
+
+ def _truncate_gt(self, gt_text: str, n_prompt: int) -> Optional[str]:
+ # 80 = conservative instruction-template overhead.
+ budget = self._prompt_budget - n_prompt - 80
+ if budget < 50:
+ return None
+ gt_ids = _to_int_list(self._template.tokenizer(gt_text, add_special_tokens=False)['input_ids'])
+ if len(gt_ids) <= budget:
+ return gt_text
+ return self._template.tokenizer.decode(gt_ids[:budget], skip_special_tokens=False)
+
+ def _encode_prompt(self, ctx_msgs):
+ ids = _to_int_list(self._template.encode({'messages': list(ctx_msgs)}, add_generation_prompt=True)['input_ids'])
+ if self._max_prompt_tokens <= 0 or len(ids) <= self._max_prompt_tokens:
+ return ids
+ return ids[-self._max_prompt_tokens:]
+
+ def score(self, contexts: List[RoundContext]) -> List[ScoreResult]:
+ if not contexts:
+ return []
+
+ keys: List[int] = []
+ augmented: List[List[Dict[str, Any]]] = []
+ for i, ctx in enumerate(contexts):
+ gt = self._truncate_gt(ctx.asst_text, ctx.n_prompt)
+ if gt is None or not ctx.context_messages:
+ continue
+ keys.append(i)
+ augmented.append(self._inject_gt(ctx.context_messages, gt))
+
+ out: List[ScoreResult] = [
+ ScoreResult(score=None, passed=True, extras={'reason': 'paraphrase skipped'}) for _ in contexts
+ ]
+ if not keys:
+ return out
+
+ batched = self._backend.chat_batch(
+ augmented,
+ temperature=self._sample_temperature,
+ max_tokens=self._sample_max_tokens,
+ n=1,
+ ) or []
+
+ # Re-tokenize against the ORIGINAL (no-GT) context so logprobs reflect
+ # pure self-conditional probability of the paraphrase.
+ para_data: Dict[int, Tuple[List[int], int, List[int], str]] = {}
+ for i, choices in zip(keys, batched):
+ text = None
+ if choices:
+ c0 = choices[0]
+ if isinstance(c0, dict):
+ text = c0.get('content')
+ if not isinstance(text, str) or not text.strip():
+ continue
+ ctx = contexts[i]
+ prompt_ids = self._encode_prompt(ctx.context_messages)
+ asst_ids = _to_int_list(self._template.tokenizer(text, add_special_tokens=False)['input_ids'])
+ if len(asst_ids) < _MIN_RESPONSE_TOKENS + 1:
+ continue
+ cond_ids = prompt_ids + asst_ids
+ para_data[i] = (cond_ids, len(prompt_ids), asst_ids, text)
+
+ if not para_data:
+ return out
+
+ ordered = list(para_data.keys())
+ cond_batch = [para_data[i][0] for i in ordered]
+ asst_batch = [para_data[i][2] for i in ordered]
+ cond_lps = self._backend.prompt_logprobs_ids(cond_batch)
+ asst_lps = self._backend.prompt_logprobs_ids(asst_batch)
+
+ for i, cond_lp, asst_lp in zip(ordered, cond_lps, asst_lps):
+ cond_ids, n_prompt, asst_ids, text = para_data[i]
+ score = _chr_min_distinct(cond_lp, asst_lp, cond_ids, asst_ids, n_prompt)
+ if self._threshold is None or score is None:
+ passed = True
+ else:
+ passed = score < self._threshold
+ out[i] = ScoreResult(
+ score=score,
+ passed=passed,
+ extras={
+ 'paraphrase_text': text,
+ 'n_prompt': n_prompt,
+ 'cond_lp': _lp_to_jsonable(cond_lp),
+ 'asst_lp': _lp_to_jsonable(asst_lp),
+ 'threshold': self._threshold,
+ },
+ )
+
+ logger.info(f'[ParaphraseScorer] paraphrased + scored {len(para_data)}/'
+ f'{len(contexts)} rounds')
+ return out
+
+
+# ============================================================================
+# ScoreFilter (Preprocessor entry point)
+# ============================================================================
+
+
+class ScoreFilter(Preprocessor):
+ """Score and filter assistant turns by a pluggable scorer set.
+
+ A round is kept iff every scorer returns ``passed=True``. Rows that lose
+ all key rounds are dropped (configurable via ``keep_if_no_key_rounds``).
+
+ Decoupling rules:
+ * `key_rounds` missing/empty in `user_data` → every assistant turn
+ becomes a candidate round.
+ * `intents=None` → no intent-based gating.
+ """
+
+ def __init__(
+ self,
+ template: Template,
+ backend: LLMBackend,
+ scorers: List[Scorer],
+ intents: Optional[Iterable[str]] = None,
+ keep_if_no_key_rounds: bool = False,
+ drop_row_on_any_fail: bool = True,
+ max_prompt_tokens: int = 1024,
+ trace_dir: Optional[str] = None,
+ trace_callback: Optional[Callable[[Dict[str, Any]], bool]] = None,
+ success_callback: Optional[Callable[[Dict[str, Any]], bool]] = None,
+ ):
+ super().__init__()
+ if not isinstance(template, Template):
+ raise TypeError(f'ScoreFilter requires a `Template` instance, got '
+ f'{type(template).__name__}.')
+ self._template = template
+ self._backend = backend
+ self._scorers = list(scorers)
+ self._intents: Optional[Set[str]] = (None if intents is None else set(intents))
+ self._keep_if_no_key_rounds = bool(keep_if_no_key_rounds)
+ self._drop_row_on_any_fail = bool(drop_row_on_any_fail)
+ self._max_prompt_tokens = int(max_prompt_tokens)
+ self._trace_dir = trace_dir
+ self._trace_callback = trace_callback
+ self._success_callback = success_callback
+ if self._trace_dir:
+ import shutil
+ if os.path.exists(self._trace_dir):
+ shutil.rmtree(self._trace_dir)
+ os.makedirs(self._trace_dir, exist_ok=True)
+
+ def __call__(self, rows):
+ rows_list = self.map_col_to_row(rows)
+ contexts = self._build_contexts(rows_list)
+ dropped: List[Dict[str, Any]] = []
+ if contexts:
+ score_table = self._score_contexts(contexts)
+ self._log_score_summary(contexts, score_table)
+ if self._trace_dir:
+ self._write_traces(contexts, score_table)
+ rows_list, dropped = self._apply_filter(rows_list, contexts, score_table)
+ return rows_list, dropped
+
+ def _log_score_summary(self, contexts, score_table):
+ for scorer in self._scorers:
+ scores = [
+ t[scorer.name].score for t in score_table if scorer.name in t and t[scorer.name].score is not None
+ ]
+ if not scores:
+ continue
+ n_pass = sum(1 for t in score_table if scorer.name in t and t[scorer.name].passed)
+ extras_sample = {}
+ for t in score_table:
+ if scorer.name in t and t[scorer.name].extras:
+ extras_sample = t[scorer.name].extras
+ break
+ extra_keys = [k for k in extras_sample if k != 'threshold']
+ extra_stats = ''
+ for k in extra_keys:
+ vals = [
+ t[scorer.name].extras.get(k) for t in score_table
+ if scorer.name in t and t[scorer.name].extras and t[scorer.name].extras.get(k) is not None
+ ]
+ if vals and isinstance(vals[0], (int, float)):
+ avg = sum(vals) / len(vals)
+ extra_stats += f', {k}_avg={avg:.4f}'
+ logger.info(f'[ScoreFilter/{scorer.name}] n={len(scores)}, '
+ f'mean={sum(scores) / len(scores):.4f}, '
+ f'min={min(scores):.4f}, max={max(scores):.4f}, '
+ f'pass={n_pass}/{len(score_table)}'
+ f'{extra_stats}')
+
+ # ---- scoring (inlined DefaultScoreCalculator) --------------------------
+
+ def _score_contexts(self, contexts: List[RoundContext]) -> List[Dict[str, ScoreResult]]:
+ if any(getattr(s, 'requires_logprobs', False) for s in self._scorers):
+ self._attach_logprobs(contexts)
+ out: List[Dict[str, ScoreResult]] = [dict() for _ in contexts]
+ for scorer in self._scorers:
+ results = scorer.score(contexts)
+ if len(results) != len(contexts):
+ raise RuntimeError(f'scorer {scorer.name!r} returned {len(results)} results '
+ f'for {len(contexts)} contexts')
+ for i, r in enumerate(results):
+ out[i][scorer.name] = r
+ return out
+
+ def _attach_logprobs(self, contexts: List[RoundContext]) -> None:
+ cond_batch = [ctx.cond_ids for ctx in contexts]
+ asst_batch = [ctx.asst_ids for ctx in contexts]
+ floor = self._batch_floor()
+ cond_padded, n_cond = _pad_batch(cond_batch, floor)
+ asst_padded, n_asst = _pad_batch(asst_batch, floor)
+ cond_lps = self._backend.prompt_logprobs_ids(cond_padded)[:n_cond]
+ asst_lps = self._backend.prompt_logprobs_ids(asst_padded)[:n_asst]
+ for ctx, c, a in zip(contexts, cond_lps, asst_lps):
+ ctx.features['cond_lp'] = c
+ ctx.features['asst_lp'] = a
+
+ def _batch_floor(self) -> int:
+ sampler = getattr(self._backend, '_sampler', None)
+ device_mesh = getattr(sampler, 'device_mesh', None)
+ return getattr(device_mesh, 'dp_world_size', 1) or 1
+
+ # ---- context construction --------------------------------------------
+
+ def _build_contexts(self, rows: List[Dict[str, Any]]) -> List[RoundContext]:
+ out: List[RoundContext] = []
+ for ri, row in enumerate(rows):
+ messages = row.get('messages') if isinstance(row, dict) else None
+ if not isinstance(messages, list):
+ continue
+ user_data = row.get('user_data') if isinstance(row, dict) else None
+ key_rounds = _user_data_lookup(user_data, 'key_rounds')
+ if not isinstance(key_rounds, list) or not key_rounds:
+ key_rounds = [i for i, m in enumerate(messages) if isinstance(m, dict) and m.get('role') == 'assistant']
+ for rnd_idx, asst_idx in enumerate(key_rounds):
+ if not isinstance(asst_idx, int):
+ continue
+ intent = self._lookup_intent(row, asst_idx)
+ if self._intents is not None and intent not in self._intents:
+ continue
+ ctx = self._prepare_round(row, messages, ri, rnd_idx, asst_idx, intent)
+ if ctx is not None:
+ out.append(ctx)
+ return out
+
+ def _prepare_round(
+ self,
+ row: Dict[str, Any],
+ messages: List[Dict[str, Any]],
+ ri: int,
+ rnd_idx: int,
+ asst_idx: int,
+ intent: Optional[str],
+ ) -> Optional[RoundContext]:
+ if not (0 <= asst_idx < len(messages)):
+ return None
+ asst_msg = messages[asst_idx]
+ if not isinstance(asst_msg, dict) or asst_msg.get('role') != 'assistant':
+ return None
+ asst_text = asst_msg.get('content') or ''
+ if isinstance(asst_text, list):
+ asst_text = ' '.join(
+ p.get('text', '') for p in asst_text if isinstance(p, dict) and p.get('type') == 'text')
+ if not asst_text.strip():
+ return None
+ context_messages = messages[:asst_idx]
+ if not context_messages:
+ return None
+ prompt_ids = self._encode_prompt_within_budget(context_messages)
+ # Raw asst_ids (no chat-template wrapping) so cond/asst share byte-equal
+ # A-token sequences; otherwise chr_min positions desync.
+ asst_ids = _to_int_list(self._template.tokenizer(asst_text, add_special_tokens=False)['input_ids'])
+ if len(asst_ids) < _MIN_RESPONSE_TOKENS + 1:
+ return None
+ return RoundContext(
+ row_idx=ri,
+ rnd_idx=rnd_idx,
+ asst_idx=asst_idx,
+ row=row,
+ intent=intent,
+ messages=messages,
+ context_messages=context_messages,
+ cond_ids=prompt_ids + asst_ids,
+ n_prompt=len(prompt_ids),
+ asst_ids=asst_ids,
+ asst_text=asst_text,
+ user_prompt=self._render_user_prompt(context_messages),
+ )
+
+ def _encode_prompt_within_budget(self, ctx_msgs: List[Dict[str, Any]]) -> List[int]:
+ ctx = list(ctx_msgs)
+ ids = _to_int_list(self._template.encode({'messages': ctx}, add_generation_prompt=True)['input_ids'])
+ budget = self._max_prompt_tokens
+ if budget <= 0 or len(ids) <= budget:
+ return ids
+ has_sys = bool(ctx) and isinstance(ctx[0], dict) and ctx[0].get('role') == 'system'
+ body_start = 1 if has_sys else 0
+ while len(ctx) - body_start > 1:
+ ctx.pop(body_start)
+ ids = _to_int_list(self._template.encode({'messages': ctx}, add_generation_prompt=True)['input_ids'])
+ if len(ids) <= budget:
+ return ids
+ # Single message still over budget → keep tail tokens.
+ return ids[-budget:]
+
+ @staticmethod
+ def _render_user_prompt(ctx_msgs: List[Dict[str, Any]]) -> str:
+ parts: List[str] = []
+ for m in ctx_msgs:
+ if not isinstance(m, dict):
+ continue
+ role = m.get('role') or 'user'
+ content = m.get('content', '')
+ if isinstance(content, list):
+ content = ' '.join(
+ p.get('text', '') for p in content if isinstance(p, dict) and p.get('type') == 'text')
+ if isinstance(content, str) and content.strip():
+ parts.append(f'[{role}] {content.strip()}')
+ return '\n\n'.join(parts)
+
+ @staticmethod
+ def _lookup_intent(row: Dict[str, Any], asst_idx: int) -> Optional[str]:
+ user_data = row.get('user_data') if isinstance(row, dict) else None
+ intents = _user_data_lookup(user_data, 'intents')
+ if not isinstance(intents, dict):
+ return None
+ v = intents.get(asst_idx)
+ if v is None:
+ v = intents.get(str(asst_idx))
+ return v if isinstance(v, str) else None
+
+ # ---- trace dump (multi_turn-style) -----------------------------------
+
+ def _write_traces(
+ self,
+ contexts: List[RoundContext],
+ score_table: List[Dict[str, ScoreResult]],
+ ) -> None:
+ for i, ctx in enumerate(contexts):
+ try:
+ scores = score_table[i] if i < len(score_table) else {}
+ kept = all(r.passed for r in scores.values()) if scores else True
+ record = self._build_trace_record(ctx, scores, kept)
+ if self._trace_callback is not None and not bool(self._trace_callback(record)):
+ continue
+ success = (bool(self._success_callback(record)) if self._success_callback is not None else kept)
+ prefix = 'ok' if success else 'fail'
+ rid = f'{ctx.row_idx}-{ctx.asst_idx}-{i}-{int(time.time() * 1000)}'
+ rid = re.sub(r'[^A-Za-z0-9_\-.]+', '_', rid)[:64]
+ path = os.path.join(self._trace_dir, f'{prefix}-{rid}.json')
+ with open(path, 'w', encoding='utf-8') as f:
+ json.dump(record, f, ensure_ascii=False, indent=2, default=str)
+ except Exception as e:
+ # Observability must never break filtering; surface the cause.
+ logger.warning(f'[ScoreFilter] trace dump failed for row={ctx.row_idx} '
+ f'asst={ctx.asst_idx}: {e}')
+
+ @staticmethod
+ def _build_trace_record(
+ ctx: RoundContext,
+ scores: Dict[str, ScoreResult],
+ kept: bool,
+ ) -> Dict[str, Any]:
+ return {
+ 'row_idx': ctx.row_idx,
+ 'rnd_idx': ctx.rnd_idx,
+ 'asst_idx': ctx.asst_idx,
+ 'intent': ctx.intent,
+ 'messages': ctx.messages,
+ 'n_prompt': ctx.n_prompt,
+ 'cond_ids': ctx.cond_ids,
+ 'asst_ids': ctx.asst_ids,
+ 'features': {
+ k: (_lp_to_jsonable(v) if k.endswith('_lp') else v)
+ for k, v in ctx.features.items()
+ },
+ 'scores': {
+ name: {
+ 'score': r.score,
+ 'passed': r.passed,
+ 'extras': r.extras
+ }
+ for name, r in scores.items()
+ },
+ 'kept': bool(kept),
+ }
+
+ # ---- aggregation & row reassembly ------------------------------------
+
+ def _apply_filter(
+ self,
+ rows: List[Dict[str, Any]],
+ contexts: List[RoundContext],
+ score_table: List[Dict[str, ScoreResult]],
+ ) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
+ per_row: Dict[int, Dict[str, Any]] = {}
+ for i, ctx in enumerate(contexts):
+ scores = score_table[i] if i < len(score_table) else {}
+ passed = all(r.passed for r in scores.values()) if scores else True
+ slot = per_row.setdefault(ctx.row_idx, {
+ 'kept': [],
+ 'failed': 0,
+ })
+ if passed:
+ slot['kept'].append(ctx.asst_idx)
+ else:
+ slot['failed'] += 1
+
+ out: List[Dict[str, Any]] = []
+ dropped: List[Dict[str, Any]] = []
+ n_removed_rounds = 0
+ n_removed_rows = 0
+ for ri, row in enumerate(rows):
+ user_data = row.get('user_data') if isinstance(row, dict) else None
+ kr_val = _user_data_lookup(user_data, 'key_rounds')
+ had_key_rounds = isinstance(kr_val, list) and bool(kr_val)
+ decision = per_row.get(ri)
+
+ if decision is None:
+ # Row produced no contexts (no asst turns or filtered by intent).
+ if had_key_rounds and not self._keep_if_no_key_rounds:
+ n_removed_rows += 1
+ dropped.append(dict(row, drop_reason='score_no_context'))
+ continue
+ if self._intents is not None and not self._keep_if_no_key_rounds:
+ n_removed_rows += 1
+ dropped.append(dict(row, drop_reason='score_no_context'))
+ continue
+ out.append(row)
+ continue
+
+ n_removed_rounds += decision['failed']
+ kept = decision['kept']
+ if had_key_rounds:
+ if not kept:
+ n_removed_rows += 1
+ dropped.append(dict(row, drop_reason='score_all_rounds_failed'))
+ continue
+ new_row = dict(row)
+ # Re-pack key_rounds; keep all other entries as-is (already packed).
+ rebuilt = [(k, v) for (k, v) in (user_data or []) if k != 'key_rounds']
+ rebuilt.append(('key_rounds', pack_value(list(kept))))
+ new_row['user_data'] = rebuilt
+ out.append(new_row)
+ else:
+ if decision['failed'] > 0 and self._drop_row_on_any_fail:
+ n_removed_rows += 1
+ dropped.append(dict(row, drop_reason='score_round_failed'))
+ continue
+ out.append(row)
+
+ logger.info(f'[ScoreFilter] removed {n_removed_rounds} rounds, '
+ f'dropped {n_removed_rows} rows, kept {len(out)}/{len(rows)}')
+ return out, dropped
diff --git a/src/twinkle_agentic/preprocessor/token_soup.py b/src/twinkle_agentic/preprocessor/token_soup.py
new file mode 100644
index 000000000..97df97eb6
--- /dev/null
+++ b/src/twinkle_agentic/preprocessor/token_soup.py
@@ -0,0 +1,152 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+"""Token-soup / garbled output detector for assistant messages."""
+import re
+import unicodedata
+from typing import Any, Dict, List, Tuple
+
+from twinkle.preprocessor import Preprocessor
+
+from .utils import msg_content_text
+
+# ── Pre-compiled patterns ─────────────────────────────────────────────────────
+
+_REPLACEMENT_CHAR_RE = re.compile(r'\ufffd')
+
+# Non-printable control chars; \t \n \r kept as legitimate whitespace.
+_CONTROL_CHAR_RE = re.compile(r'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]')
+
+_PRIVATE_USE_RE = re.compile(r'[\ue000-\uf8ff\U000f0000-\U000fffff\U00100000-\U0010ffff]')
+
+# Bracket BERT-style tokens stay case-sensitive via (?-i:...) — lowercase "[mask]"/"[pad]"
+# collide with bitmask-DP variables like dp[mask].
+_SPECIAL_TOKEN_RE = re.compile(
+ r'(<\|[^|>\n]{1,40}\|>||(?-i:\[/?(?:PAD|UNK|SEP|CLS|MASK)\])|?unk>|?pad>|<0x[0-9A-Fa-f]{2}>)',
+ re.IGNORECASE,
+)
+
+# Same printable char repeated 20+ times. Excludes ASCII rule chars, digits, box-drawing,
+# block elements, geometric shapes, braille, and dashes.
+_SINGLE_CHAR_REPEAT_RE = re.compile(
+ r'([^\s\n\-=_.\*\+~#|><0-9\u2013-\u2015\u2500-\u25ff\u2800-\u28ff\u30fc\uff0d])\1{19,}')
+
+
+# ── Unicode script classifier ─────────────────────────────────────────────────
+
+
+def _script_of(cp: int) -> str:
+ """Map a codepoint to a coarse script bucket."""
+ if cp <= 0x024F:
+ return 'latin'
+ if 0x0370 <= cp <= 0x03FF:
+ return 'greek'
+ if 0x0400 <= cp <= 0x04FF:
+ return 'cyrillic'
+ if 0x0590 <= cp <= 0x05FF:
+ return 'hebrew'
+ if 0x0600 <= cp <= 0x06FF:
+ return 'arabic'
+ if 0x0900 <= cp <= 0x097F:
+ return 'devanagari'
+ if 0x0E00 <= cp <= 0x0E7F:
+ return 'thai'
+ if 0x3040 <= cp <= 0x309F:
+ return 'hiragana'
+ if 0x30A0 <= cp <= 0x30FF:
+ return 'katakana'
+ if 0x4E00 <= cp <= 0x9FFF:
+ return 'cjk'
+ if 0xAC00 <= cp <= 0xD7A3:
+ return 'hangul'
+ if 0xE000 <= cp <= 0xF8FF:
+ return 'private'
+ return 'other'
+
+
+def _script_chaos(text: str, min_chars: int = 40) -> float:
+ """Fraction of adjacent letter/number pairs that switch script."""
+ chars = [c for c in text if unicodedata.category(c)[0] in ('L', 'N')]
+ if len(chars) < min_chars:
+ return 0.0
+ scripts = [_script_of(ord(c)) for c in chars]
+ switches = sum(a != b for a, b in zip(scripts, scripts[1:]))
+ return switches / (len(scripts) - 1)
+
+
+# ── Detector ──────────────────────────────────────────────────────────────────
+
+
+def _ratio(pattern: re.Pattern, text: str) -> float:
+ return len(pattern.findall(text)) / max(len(text), 1)
+
+
+def _is_token_soup(
+ text: str,
+ replacement_char_ratio: float = 0.02,
+ control_char_ratio: float = 0.01,
+ private_use_ratio: float = 0.03,
+ special_token_count: int = 20,
+ script_chaos_threshold: float = 0.55,
+ script_chaos_min_chars: int = 40,
+ max_chars: int = 0,
+) -> bool:
+ """True if `text` exhibits any garbled-output signal."""
+ if not text:
+ return False
+ # Token-soup signals are statistical/uniform; head sample is sufficient.
+ if max_chars and len(text) > max_chars:
+ text = text[:max_chars]
+ if _ratio(_REPLACEMENT_CHAR_RE, text) > replacement_char_ratio:
+ return True
+ if _ratio(_CONTROL_CHAR_RE, text) > control_char_ratio:
+ return True
+ if _ratio(_PRIVATE_USE_RE, text) > private_use_ratio:
+ return True
+ if len(_SPECIAL_TOKEN_RE.findall(text)) >= special_token_count:
+ return True
+ if _SINGLE_CHAR_REPEAT_RE.search(text):
+ return True
+ if _script_chaos(text, script_chaos_min_chars) > script_chaos_threshold:
+ return True
+ return False
+
+
+# ── Preprocessor ──────────────────────────────────────────────────────────────
+
+
+class TokenSoupFilter(Preprocessor):
+
+ def __init__(
+ self,
+ replacement_char_ratio: float = 0.02,
+ control_char_ratio: float = 0.01,
+ private_use_ratio: float = 0.03,
+ special_token_count: int = 20,
+ script_chaos_threshold: float = 0.55,
+ script_chaos_min_chars: int = 40,
+ max_chars: int = 0,
+ ) -> None:
+ super().__init__()
+ self._cfg = dict(
+ replacement_char_ratio=replacement_char_ratio,
+ control_char_ratio=control_char_ratio,
+ private_use_ratio=private_use_ratio,
+ special_token_count=special_token_count,
+ script_chaos_threshold=script_chaos_threshold,
+ script_chaos_min_chars=script_chaos_min_chars,
+ max_chars=max_chars,
+ )
+
+ def __call__(self, rows) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
+ rows = self.map_col_to_row(rows)
+ out: List[Dict[str, Any]] = []
+ dropped: List[Dict[str, Any]] = []
+ for row in rows:
+ asst_texts = [
+ msg_content_text(m).strip() for m in (row.get('messages') or [])
+ if isinstance(m, dict) and m.get('role') == 'assistant'
+ ]
+ if any(_is_token_soup(t, **self._cfg) for t in asst_texts):
+ dropped.append(dict(row, drop_reason='token_soup'))
+ else:
+ out.append(row)
+ return out, dropped
diff --git a/src/twinkle_agentic/preprocessor/utils.py b/src/twinkle_agentic/preprocessor/utils.py
new file mode 100644
index 000000000..1004b1c78
--- /dev/null
+++ b/src/twinkle_agentic/preprocessor/utils.py
@@ -0,0 +1,367 @@
+"""Pure helpers shared across preprocessor modules."""
+import json
+import math
+import os
+import re
+from typing import Any, Dict, List, Optional, Set, Tuple
+
+
+def _extract_logprob(lp, token_id: Optional[int] = None) -> Optional[float]:
+ if lp is None:
+ return None
+ if isinstance(lp, (int, float)):
+ return float(lp)
+ if not isinstance(lp, dict):
+ return None
+ # vLLM with prompt_logprobs=1 returns top-1 PLUS actual token if they differ;
+ # actual is appended LAST, so iter-first picks the wrong (top-1) one.
+ entry = None
+ if token_id is not None:
+ entry = lp.get(token_id)
+ if entry is None:
+ entry = lp.get(str(token_id))
+ if entry is None:
+ entry = next(iter(lp.values()), None)
+ if entry is None:
+ return None
+ if hasattr(entry, 'logprob'):
+ return float(entry.logprob)
+ if isinstance(entry, dict):
+ v = entry.get('logprob')
+ return float(v) if v is not None else None
+ if isinstance(entry, (int, float)):
+ return float(entry)
+ return None
+
+
+def _to_int_list(x) -> List[int]:
+ if hasattr(x, 'tolist'):
+ return x.tolist()
+ return list(x)
+
+
+def _chr_min_distinct(
+ cond_lp: List,
+ asst_lp: List,
+ cond_ids: List[int],
+ asst_ids: List[int],
+ n_prompt: int,
+ exclude_ids: Optional[Set[int]] = None,
+) -> Optional[float]:
+ """chr_dist_min_pos: fraction of distinct asst-token ids whose
+ per-occurrence min(cond_lp - asst_lp) is strictly positive."""
+ if not asst_lp or not cond_lp or not asst_ids:
+ return None
+ n_a = min(len(asst_lp), len(asst_ids))
+ n_c = len(cond_lp)
+ by_tok: Dict[int, List[float]] = {}
+ for i in range(n_a):
+ ci = n_prompt + i
+ if ci >= n_c:
+ break
+ tid = asst_ids[i]
+ if tid is None:
+ continue
+ if exclude_ids is not None and int(tid) in exclude_ids:
+ continue
+ a = _extract_logprob(asst_lp[i], tid)
+ c_tok = cond_ids[ci] if ci < len(cond_ids) else None
+ c = _extract_logprob(cond_lp[ci], c_tok)
+ if a is None or c is None:
+ continue
+ by_tok.setdefault(int(tid), []).append(c - a)
+ if not by_tok:
+ return None
+ pos = sum(1 for diffs in by_tok.values() if min(diffs) > 0)
+ return pos / len(by_tok)
+
+
+def _chr_min_weighted(
+ cond_lp: List,
+ asst_lp: List,
+ cond_ids: List[int],
+ asst_ids: List[int],
+ n_prompt: int,
+) -> Optional[float]:
+ """Magnitude-weighted chr_min: each distinct token contributes |min_delta|
+ as weight; returns sum(pos_weights) / sum(all_weights)."""
+ if not asst_lp or not cond_lp or not asst_ids:
+ return None
+ n_a = min(len(asst_lp), len(asst_ids))
+ n_c = len(cond_lp)
+ by_tok: Dict[int, List[float]] = {}
+ for i in range(n_a):
+ ci = n_prompt + i
+ if ci >= n_c:
+ break
+ tid = asst_ids[i]
+ if tid is None:
+ continue
+ a = _extract_logprob(asst_lp[i], tid)
+ c_tok = cond_ids[ci] if ci < len(cond_ids) else None
+ c = _extract_logprob(cond_lp[ci], c_tok)
+ if a is None or c is None:
+ continue
+ by_tok.setdefault(int(tid), []).append(c - a)
+ if not by_tok:
+ return None
+ total_w = 0.0
+ pos_w = 0.0
+ for diffs in by_tok.values():
+ md = min(diffs)
+ w = abs(md)
+ total_w += w
+ if md > 0:
+ pos_w += w
+ if total_w == 0:
+ return None
+ return pos_w / total_w
+
+
+def _ifd_family_metrics(
+ cond_lp: List,
+ asst_lp: List,
+ cond_ids: List[int],
+ asst_ids: List[int],
+ n_prompt: int,
+) -> Dict[str, Any]:
+ """IFD (Cherry-LLM) and S-IFD-{50,75} (T-SHIRT) for one round."""
+ if not asst_lp or not cond_lp or not asst_ids:
+ return {}
+ n_a = min(len(asst_lp), len(asst_ids))
+ n_c = len(cond_lp)
+ deltas: List[float] = []
+ for i in range(n_a):
+ ci = n_prompt + i
+ if ci >= n_c:
+ break
+ tid = asst_ids[i]
+ if tid is None:
+ continue
+ a = _extract_logprob(asst_lp[i], tid)
+ c_tok = cond_ids[ci] if ci < len(cond_ids) else None
+ c = _extract_logprob(cond_lp[ci], c_tok)
+ if a is None or c is None:
+ continue
+ deltas.append(c - a)
+ if not deltas:
+ return {}
+ n = len(deltas)
+ mean_delta = sum(deltas) / n
+ out: Dict[str, Any] = {
+ 'n_tokens': n,
+ 'mean_delta': mean_delta,
+ 'ifd': math.exp(-mean_delta),
+ }
+ abs_sorted = sorted(range(n), key=lambda i: abs(deltas[i]), reverse=True)
+ for k_pct in (50, 75):
+ keep = max(1, int(round(n * k_pct / 100)))
+ sub = [deltas[i] for i in abs_sorted[:keep]]
+ out[f's_ifd_{k_pct}'] = math.exp(-sum(sub) / len(sub))
+ return out
+
+
+def _mean_logprob_delta(
+ cond_lp: List,
+ asst_lp: List,
+ cond_ids: List[int],
+ asst_ids: List[int],
+ n_prompt: int,
+) -> Optional[float]:
+ """Mean per-token (cond_lp - asst_lp) over the response span."""
+ if not asst_lp or not cond_lp or not asst_ids:
+ return None
+ n_a = min(len(asst_lp), len(asst_ids))
+ n_c = len(cond_lp)
+ deltas: List[float] = []
+ for i in range(n_a):
+ ci = n_prompt + i
+ if ci >= n_c:
+ break
+ tid = asst_ids[i]
+ if tid is None:
+ continue
+ a = _extract_logprob(asst_lp[i], tid)
+ c_tok = cond_ids[ci] if ci < len(cond_ids) else None
+ c = _extract_logprob(cond_lp[ci], c_tok)
+ if a is None or c is None:
+ continue
+ deltas.append(c - a)
+ if not deltas:
+ return None
+ return sum(deltas) / len(deltas)
+
+
+def _lp_to_jsonable(lp_list):
+ """Convert per-position prompt_logprobs into JSON-safe form."""
+ out = []
+ for lp in (lp_list or []):
+ if lp is None:
+ out.append(None)
+ continue
+ if isinstance(lp, (int, float)):
+ out.append(float(lp))
+ continue
+ if not isinstance(lp, dict):
+ out.append(repr(lp))
+ continue
+ d = {}
+ for k, v in lp.items():
+ if hasattr(v, 'logprob'):
+ d[str(k)] = {
+ 'logprob': float(v.logprob),
+ 'rank': getattr(v, 'rank', None),
+ 'decoded': getattr(v, 'decoded_token', None)
+ }
+ elif isinstance(v, dict):
+ d[str(k)] = v
+ else:
+ d[str(k)] = repr(v)
+ out.append(d)
+ return out
+
+
+def _pad_batch(batch: List[List[int]], floor: int) -> Tuple[List[List[int]], int]:
+ n = len(batch)
+ if n >= floor or not batch:
+ return batch, n
+ return list(batch) + [batch[-1]] * (floor - n), n
+
+
+# ══════════════════════════════════════════════════════════════════════════════
+# Message-format utilities
+# ══════════════════════════════════════════════════════════════════════════════
+
+
+def msg_content_text(msg: Dict[str, Any]) -> str:
+ """Extract plain text from a message's content (str | list | dict)."""
+ c = msg.get('content')
+ if isinstance(c, str):
+ return c
+ if isinstance(c, list):
+ return ' '.join(p.get('text', '') for p in c if isinstance(p, dict) and p.get('type') == 'text')
+ if isinstance(c, dict) and c.get('type') == 'text':
+ return c.get('text', '')
+ return ''
+
+
+def msg_has_media(msg: Dict[str, Any]) -> bool:
+ """True if message content contains non-text parts (image/audio/video)."""
+ c = msg.get('content')
+ return isinstance(c, list) and any(isinstance(p, dict) and p.get('type') not in ('text', None) for p in c)
+
+
+def msg_has_payload(msg: Dict[str, Any]) -> bool:
+ """True if a message carries any substantive payload (text, tool_calls, reasoning, or media)."""
+ return bool(
+ msg_content_text(msg).strip()
+ or msg.get('tool_calls')
+ or msg.get('reasoning_content') or msg.get('thinking')
+ or msg_has_media(msg)
+ )
+
+
+_CJK_RE = re.compile(r'[\u4e00-\u9fff\u3040-\u309f\u30a0-\u30ff\uac00-\ud7a3]')
+
+
+def cjk_ratio(text: str) -> float:
+ """Fraction of characters that are CJK (Chinese/Japanese/Korean)."""
+ return len(_CJK_RE.findall(text)) / max(len(text), 1)
+
+
+def normalize_tool_calls(msg: Dict[str, Any]) -> Optional[List[Any]]:
+ """Return ``tool_calls`` as a list of dicts, handling PyArrow/HF serialization artifacts."""
+ tcs = msg.get('tool_calls')
+ if isinstance(tcs, str):
+ s = tcs.strip()
+ if not s:
+ return None
+ try:
+ decoded = json.loads(s)
+ except (json.JSONDecodeError, ValueError):
+ return None
+ if not isinstance(decoded, list) or not decoded:
+ return None
+ tcs = decoded
+ if not isinstance(tcs, list) or not tcs:
+ return None
+ result = []
+ for tc in tcs:
+ if isinstance(tc, str):
+ try:
+ tc = json.loads(tc)
+ except (json.JSONDecodeError, ValueError):
+ return None
+ if not isinstance(tc, dict):
+ return None
+ func = tc.get('function')
+ if isinstance(func, str):
+ try:
+ func = json.loads(func)
+ except (json.JSONDecodeError, ValueError):
+ return None
+ tc = dict(tc, function=func)
+ result.append(tc)
+ return result
+
+
+CJK_CHARS_RE = re.compile(r'[\u4e00-\u9fff\u3040-\u309f\u30a0-\u30ff\uac00-\ud7a3]')
+
+
+def cjk_ratio(text: str) -> float:
+ """Fraction of non-whitespace characters that are CJK."""
+ chars = text.replace(' ', '').replace('\n', '').replace('\t', '')
+ if not chars:
+ return 0.0
+ return len(CJK_CHARS_RE.findall(chars)) / len(chars)
+
+
+def load_sensitive_words(path: Optional[str]) -> Set[str]:
+ """Load from external file (one word per line). Blank lines and #-comments ignored."""
+ if not path or not os.path.isfile(path):
+ return set()
+ words: Set[str] = set()
+ with open(path, encoding='utf-8') as f:
+ for line in f:
+ line = line.strip()
+ if line and not line.startswith('#'):
+ words.add(line)
+ return words
+
+
+def build_sensitive_regex(words: Set[str]) -> Optional['re.Pattern']:
+ """Build a compiled regex from a set of words. Returns None if empty."""
+ if not words:
+ return None
+ cjk_words = []
+ latin_words = []
+ cjk_re = re.compile(r'[\u4e00-\u9fff\u3040-\u309f\u30a0-\u30ff\uac00-\ud7a3]')
+ for w in sorted(words):
+ if cjk_re.search(w):
+ cjk_words.append(re.escape(w))
+ else:
+ latin_words.append(re.escape(w))
+ parts = []
+ if latin_words:
+ parts.append(r'\b(' + '|'.join(latin_words) + r')\b')
+ if cjk_words:
+ parts.append('(' + '|'.join(cjk_words) + ')')
+ return re.compile('|'.join(parts), re.IGNORECASE)
+
+
+def is_agent_row(messages) -> bool:
+ """Return True if the conversation contains tool interactions (agent trace).
+
+ After MessageNormalizer runs, all non-standard formats are already converted
+ to standard tool_calls / role=tool — so checking those two signals suffices.
+ """
+ if not isinstance(messages, list):
+ return False
+ for m in messages:
+ if not isinstance(m, dict):
+ continue
+ if m.get('role') == 'tool':
+ return True
+ if normalize_tool_calls(m):
+ return True
+ return False
diff --git a/src/twinkle_agentic/rollout/multi_turn.py b/src/twinkle_agentic/rollout/multi_turn.py
index ed6f10d23..e25e89de9 100644
--- a/src/twinkle_agentic/rollout/multi_turn.py
+++ b/src/twinkle_agentic/rollout/multi_turn.py
@@ -6,7 +6,7 @@
import time
from typing import Any, Callable, Dict, List, Optional
-from twinkle.data_format import Trajectory
+from twinkle.data_format import Trajectory, user_data_get
from twinkle.data_format.sampling import SampleResponse, SamplingParams
from twinkle.infra import remote_class, remote_function
from twinkle.template.base import Template
@@ -289,11 +289,8 @@ def _serialize_for_trace(cls, traj: Dict[str, Any]) -> Dict[str, Any]:
@staticmethod
def _extract_ground_truth(traj: Dict[str, Any]) -> str:
- """Pull ``ground_truth`` out of ``user_data`` (list of kv pairs)."""
- for kv in (traj.get('user_data') or []):
- if (isinstance(kv, (list, tuple)) and len(kv) >= 2 and kv[0] == 'ground_truth'):
- return kv[1] or ''
- return ''
+ """Pull ``ground_truth`` out of packed ``user_data``."""
+ return user_data_get(traj.get('user_data'), 'ground_truth', '') or ''
@staticmethod
def _resolve_traj_id(traj: Dict[str, Any], fallback_idx: int) -> str:
@@ -304,13 +301,12 @@ def _resolve_traj_id(traj: Dict[str, Any], fallback_idx: int) -> str:
``{timestamp_ms}-{fallback_idx}`` so concurrent rollouts do not
overwrite each other's files.
"""
- for kv in (traj.get('user_data') or []):
- if (isinstance(kv, (list, tuple)) and len(kv) >= 2 and kv[0] in ('id', 'prompt_id')):
- val = kv[1]
- if val not in (None, ''):
- safe = re.sub(r'[^A-Za-z0-9_\-.]+', '_', str(val))[:64]
- if safe:
- return safe
+ for key in ('id', 'prompt_id'):
+ val = user_data_get(traj.get('user_data'), key)
+ if val not in (None, ''):
+ safe = re.sub(r'[^A-Za-z0-9_\-.]+', '_', str(val))[:64]
+ if safe:
+ return safe
return f'{int(time.time() * 1000)}-{fallback_idx}'
def _build_trace_record(
diff --git a/src/twinkle/loss_scale/__init__.py b/src/twinkle_agentic/sampler/__init__.py
similarity index 59%
rename from src/twinkle/loss_scale/__init__.py
rename to src/twinkle_agentic/sampler/__init__.py
index 7a67f94eb..93d4eec2e 100644
--- a/src/twinkle/loss_scale/__init__.py
+++ b/src/twinkle_agentic/sampler/__init__.py
@@ -1,2 +1,2 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
-from .base import LossScale
+from .router_sampler import RouterSampler
diff --git a/src/twinkle_agentic/sampler/router_sampler.py b/src/twinkle_agentic/sampler/router_sampler.py
new file mode 100644
index 000000000..ec57343e0
--- /dev/null
+++ b/src/twinkle_agentic/sampler/router_sampler.py
@@ -0,0 +1,197 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+import httpx
+import math
+from copy import copy
+from typing import Any, Dict, List, Literal, Optional, Union
+
+from twinkle import get_logger
+from twinkle.data_format import SampledSequence, SampleResponse, SamplingParams, Trajectory
+
+logger = get_logger()
+
+
+def _entropy_from_topk(logprobs_per_token: List[List[tuple]]) -> float:
+ """Mean per-token entropy approximated from top-K logprobs (renormalized)."""
+ if not logprobs_per_token:
+ return float('inf')
+ total = 0.0
+ for candidates in logprobs_per_token:
+ if not candidates:
+ total += float('inf')
+ continue
+ lps = [lp for _, lp in candidates]
+ max_lp = max(lps)
+ # numerically stable softmax over top-K
+ exps = [math.exp(lp - max_lp) for lp in lps]
+ z = sum(exps)
+ total += sum(-(e / z) * (lp - max_lp - math.log(z)) for e, lp in zip(exps, lps))
+ return total / len(logprobs_per_token)
+
+
+def _mean_logp(logprobs_per_token: List[List[tuple]], tokens: List[int]) -> float:
+ """Mean log-probability of generated tokens (sequence-level confidence)."""
+ if not logprobs_per_token or not tokens:
+ return float('-inf')
+ total = 0.0
+ count = 0
+ for t, candidates in enumerate(logprobs_per_token):
+ if t >= len(tokens) or not candidates:
+ continue
+ tok = tokens[t]
+ lp = next((v for tid, v in candidates if tid == tok), None)
+ if lp is None:
+ lp = candidates[0][1]
+ total += lp
+ count += 1
+ return total / max(count, 1)
+
+
+class RouterSampler:
+ """Confidence-based routing sampler.
+
+ Generates with a local sampler first; if confidence is low, falls back
+ to an OpenAI-compatible endpoint (stronger model).
+ """
+
+ def __init__(
+ self,
+ sampler,
+ fallback_endpoint: str,
+ fallback_model: str = 'default',
+ fallback_api_key: str = '',
+ method: Literal['entropy', 'logp'] = 'entropy',
+ threshold: float = 2.0,
+ top_k_logprobs: int = 10,
+ fallback_temperature: float = 0.7,
+ fallback_max_tokens: int = 4096,
+ timeout: float = 120.0,
+ ):
+ """
+ Args:
+ sampler: Inner sampler instance (e.g. vLLMSampler).
+ fallback_endpoint: OpenAI-compatible API base URL.
+ fallback_model: Model name for fallback requests.
+ fallback_api_key: Bearer token for fallback API.
+ method: Confidence metric — 'entropy' (route when H > threshold)
+ or 'logp' (route when mean logp < threshold).
+ threshold: Routing threshold. For entropy: higher = more routing.
+ For logp: lower (more negative) = more routing.
+ top_k_logprobs: Number of top logprobs to request from inner sampler.
+ fallback_temperature: Temperature for fallback generation.
+ fallback_max_tokens: Max tokens for fallback generation.
+ timeout: HTTP timeout for fallback requests.
+ """
+ self.sampler = sampler
+ self._method = method
+ self._threshold = threshold
+ self._top_k = top_k_logprobs
+ self._fb_temperature = fallback_temperature
+ self._fb_max_tokens = fallback_max_tokens
+ self._fb_endpoint = f'{fallback_endpoint.rstrip("/")}/v1/chat/completions'
+ self._fb_model = fallback_model
+ headers = {'Content-Type': 'application/json'}
+ if fallback_api_key:
+ headers['Authorization'] = f'Bearer {fallback_api_key}'
+ self._client = httpx.Client(timeout=timeout, headers=headers)
+
+ @property
+ def template(self):
+ return self.sampler.template
+
+ def set_template(self, *args, **kwargs):
+ return self.sampler.set_template(*args, **kwargs)
+
+ def _should_route(self, seq: SampledSequence) -> bool:
+ if not seq.logprobs:
+ return True
+ if self._method == 'entropy':
+ score = _entropy_from_topk(seq.logprobs)
+ return score > self._threshold
+ score = _mean_logp(seq.logprobs, seq.tokens)
+ return score < self._threshold
+
+ def _fallback_generate(self, trajectory: Trajectory) -> Optional[str]:
+ messages = trajectory.get('messages', [])
+ if not messages:
+ return None
+ api_messages = []
+ for m in messages:
+ if not isinstance(m, dict):
+ continue
+ entry = {'role': m.get('role', 'user')}
+ content = m.get('content', '')
+ if isinstance(content, list):
+ parts = []
+ for block in content:
+ if isinstance(block, dict) and block.get('type') == 'text':
+ parts.append(block.get('text', ''))
+ content = '\n'.join(parts) if parts else ''
+ entry['content'] = content or ''
+ api_messages.append(entry)
+ try:
+ resp = self._client.post(
+ self._fb_endpoint,
+ json={
+ 'model': self._fb_model,
+ 'messages': api_messages,
+ 'temperature': self._fb_temperature,
+ 'max_tokens': self._fb_max_tokens,
+ })
+ resp.raise_for_status()
+ choices = resp.json().get('choices', [])
+ if choices:
+ return (choices[0].get('message') or {}).get('content', '')
+ except Exception as e:
+ logger.warning(f'RouterSampler fallback failed: {e}')
+ return None
+
+ def sample(
+ self,
+ inputs: Union[Dict, List[Dict]],
+ sampling_params: Optional[Union[SamplingParams, Dict[str, Any]]] = None,
+ adapter_name: str = '',
+ adapter_path: Optional[str] = None,
+ **kwargs,
+ ) -> List[SampleResponse]:
+ """Sample with confidence-based routing to fallback model."""
+ if sampling_params is None:
+ sampling_params = SamplingParams()
+ elif isinstance(sampling_params, dict):
+ sampling_params = SamplingParams.from_dict(sampling_params)
+
+ # Ensure logprobs are requested for confidence evaluation
+ routed_params = copy(sampling_params)
+ if routed_params.logprobs is None or routed_params.logprobs < self._top_k:
+ routed_params.logprobs = self._top_k
+
+ inputs_list = inputs if isinstance(inputs, list) else [inputs]
+ is_trajectory = isinstance(inputs_list[0], dict) and 'input_ids' not in inputs_list[0]
+
+ results = self.sampler.sample(inputs_list, routed_params, adapter_name, adapter_path=adapter_path, **kwargs)
+
+ if not is_trajectory:
+ return results
+
+ for i, (resp, traj) in enumerate(zip(results, inputs_list)):
+ new_sequences = []
+ for seq in resp.sequences:
+ if self._should_route(seq):
+ fallback_text = self._fallback_generate(traj)
+ if fallback_text is not None:
+ new_sequences.append(
+ SampledSequence(
+ stop_reason='stop',
+ tokens=[],
+ logprobs=None,
+ decoded=fallback_text,
+ ))
+ continue
+ new_sequences.append(seq)
+ results[i] = SampleResponse(
+ sequences=new_sequences,
+ prompt_token_ids=resp.prompt_token_ids,
+ prompt_logprobs=resp.prompt_logprobs,
+ topk_prompt_logprobs=resp.topk_prompt_logprobs,
+ )
+
+ return results
diff --git a/tests/dataset/test_save_as.py b/tests/dataset/test_save_as.py
new file mode 100644
index 000000000..ccb0f6f34
--- /dev/null
+++ b/tests/dataset/test_save_as.py
@@ -0,0 +1,210 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+"""
+Test Dataset.save_as:
+1. Immediate mode (bulk/incremental) across formats (jsonl, csv, parquet)
+2. Training mode (write-through) across formats (jsonl, csv, parquet)
+"""
+import json
+import os
+import tempfile
+
+import pytest
+
+from twinkle.dataset import Dataset, DatasetMeta
+
+
+SAMPLE_DATA = [
+ {'text': 'Hello world', 'label': 0},
+ {'text': 'Test data', 'label': 1},
+ {'text': 'Another example', 'label': 0},
+ {'text': 'Sample text', 'label': 1},
+]
+
+
+def _make_dataset(data=None, streaming=False):
+ """Create a Dataset from in-memory data."""
+ d = data or SAMPLE_DATA
+ if streaming:
+ def gen():
+ yield from d
+ return Dataset(dataset_meta=DatasetMeta(data=gen), streaming=True)
+ return Dataset(dataset_meta=DatasetMeta(data=d))
+
+
+class TestSaveAsImmediate:
+ """Immediate mode: save the entire dataset at once."""
+
+ def test_save_jsonl(self, tmp_path):
+ ds = _make_dataset()
+ out = str(tmp_path / 'output.jsonl')
+ ds.save_as(out)
+
+ with open(out, 'r') as f:
+ lines = [json.loads(l) for l in f if l.strip()]
+ assert len(lines) == 4
+ assert lines[0]['text'] == 'Hello world'
+ assert lines[3]['label'] == 1
+
+ def test_save_csv(self, tmp_path):
+ import pandas as pd
+ ds = _make_dataset()
+ out = str(tmp_path / 'output.csv')
+ ds.save_as(out, format='csv')
+
+ df = pd.read_csv(out)
+ assert len(df) == 4
+ assert df.iloc[0]['text'] == 'Hello world'
+ assert df.iloc[1]['label'] == 1
+
+ def test_save_parquet(self, tmp_path):
+ import pandas as pd
+ ds = _make_dataset()
+ out = str(tmp_path / 'output.parquet')
+ ds.save_as(out)
+
+ df = pd.read_parquet(out)
+ assert len(df) == 4
+ assert df.iloc[0]['text'] == 'Hello world'
+ assert df.iloc[2]['label'] == 0
+
+ def test_save_json_extension_inferred_as_jsonl(self, tmp_path):
+ ds = _make_dataset()
+ out = str(tmp_path / 'output.json')
+ ds.save_as(out)
+
+ with open(out, 'r') as f:
+ lines = [json.loads(l) for l in f if l.strip()]
+ assert len(lines) == 4
+
+ def test_save_incremental_streaming(self, tmp_path):
+ """Streaming (IterableDataset) triggers incremental save path."""
+ ds = _make_dataset(streaming=True)
+ out = str(tmp_path / 'stream_out.jsonl')
+ ds.save_as(out)
+
+ with open(out, 'r') as f:
+ lines = [json.loads(l) for l in f if l.strip()]
+ assert len(lines) == 4
+ assert lines[0]['text'] == 'Hello world'
+
+ def test_save_incremental_csv_streaming(self, tmp_path):
+ import pandas as pd
+ ds = _make_dataset(streaming=True)
+ out = str(tmp_path / 'stream_out.csv')
+ ds.save_as(out, format='csv', batch_size=2)
+
+ df = pd.read_csv(out)
+ assert len(df) == 4
+
+ def test_save_incremental_parquet_streaming(self, tmp_path):
+ import pandas as pd
+ ds = _make_dataset(streaming=True)
+ out = str(tmp_path / 'stream_out.parquet')
+ ds.save_as(out, format='parquet', batch_size=2)
+
+ df = pd.read_parquet(out)
+ assert len(df) == 4
+
+ def test_error_no_dataset(self, tmp_path):
+ ds = Dataset()
+ with pytest.raises(ValueError, match='No dataset to save'):
+ ds.save_as(str(tmp_path / 'x.jsonl'))
+
+ def test_error_unsupported_format(self, tmp_path):
+ ds = _make_dataset()
+ with pytest.raises(ValueError, match='Unsupported format'):
+ ds.save_as(str(tmp_path / 'x.txt'), format='txt')
+
+ def test_creates_output_directory(self, tmp_path):
+ ds = _make_dataset()
+ out = str(tmp_path / 'nested' / 'dir' / 'output.jsonl')
+ ds.save_as(out)
+ assert os.path.isfile(out)
+
+
+class TestSaveAsTraining:
+ """Training mode: write-through as items are consumed via __getitem__."""
+
+ def test_training_jsonl(self, tmp_path):
+ ds = _make_dataset()
+ out = str(tmp_path / 'train.jsonl')
+ ds.save_as(out, mode='training')
+
+ # Consume all items via __getitem__
+ for i in range(len(ds)):
+ _ = ds[i]
+
+ ds.flush_save()
+
+ with open(out, 'r') as f:
+ lines = [json.loads(l) for l in f if l.strip()]
+ assert len(lines) == 4
+ assert lines[0]['text'] == 'Hello world'
+ assert lines[3]['label'] == 1
+
+ def test_training_csv(self, tmp_path):
+ import pandas as pd
+ ds = _make_dataset()
+ out = str(tmp_path / 'train.csv')
+ ds.save_as(out, format='csv', batch_size=2, mode='training')
+
+ for i in range(len(ds)):
+ _ = ds[i]
+
+ ds.flush_save()
+
+ df = pd.read_csv(out)
+ assert len(df) == 4
+ assert df.iloc[0]['text'] == 'Hello world'
+
+ def test_training_parquet(self, tmp_path):
+ import pandas as pd
+ ds = _make_dataset()
+ out = str(tmp_path / 'train.parquet')
+ ds.save_as(out, format='parquet', batch_size=2, mode='training')
+
+ for i in range(len(ds)):
+ _ = ds[i]
+
+ ds.flush_save()
+
+ df = pd.read_parquet(out)
+ assert len(df) == 4
+ assert df.iloc[2]['text'] == 'Another example'
+
+ def test_training_partial_consume(self, tmp_path):
+ """Only consumed items are written."""
+ ds = _make_dataset()
+ out = str(tmp_path / 'partial.jsonl')
+ ds.save_as(out, mode='training')
+
+ # Only consume first 2 items
+ _ = ds[0]
+ _ = ds[1]
+
+ ds.flush_save()
+
+ with open(out, 'r') as f:
+ lines = [json.loads(l) for l in f if l.strip()]
+ assert len(lines) == 2
+ assert lines[0]['text'] == 'Hello world'
+ assert lines[1]['text'] == 'Test data'
+
+ def test_training_flush_idempotent(self, tmp_path):
+ """Double flush_save should not raise."""
+ ds = _make_dataset()
+ out = str(tmp_path / 'idem.jsonl')
+ ds.save_as(out, mode='training')
+ _ = ds[0]
+ ds.flush_save()
+ ds.flush_save() # second call is a no-op
+
+ def test_lock_file_cleanup(self, tmp_path):
+ """Lock file is created during training mode."""
+ ds = _make_dataset()
+ out = str(tmp_path / 'lock_test.jsonl')
+ ds.save_as(out, mode='training')
+ _ = ds[0]
+ ds.flush_save()
+ # Lock file should exist (created by PosixFileLock)
+ assert os.path.isfile(out + '.lock')
diff --git a/tests/preprocessor/test_agent_trace_filter.py b/tests/preprocessor/test_agent_trace_filter.py
new file mode 100644
index 000000000..4a1c21beb
--- /dev/null
+++ b/tests/preprocessor/test_agent_trace_filter.py
@@ -0,0 +1,367 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+"""Tests for AgentTraceFilter.
+
+AgentTraceFilter is detection-only — it tags rows with ``is_agent=True/False``
+and never drops or mutates messages. Detection delegates to
+``ToolCallRegistry.detect_first`` so the test surface is:
+
+ 1. Tag is set on EVERY row (uniform schema).
+ 2. role='tool' or non-empty ``tool_calls`` field → True.
+ 3. Text-embedded tool calls (Cline / Hermes / ReAct) on assistant role → True.
+ 4. Plain assistant content with no tool markers → False.
+ 5. Look-alike XML that the registry rejects (e.g. plain ``...``
+ without inner params) → False.
+ 6. Malformed message lists never raise.
+"""
+import pytest
+
+from twinkle_agentic.preprocessor.agent_trace_filter import AgentTraceFilter, _is_agent_row, _msg_text
+
+
+def _row(messages):
+ return {'messages': messages}
+
+
+# ── _msg_text helper ─────────────────────────────────────────────────────────
+
+
+class TestMsgText:
+
+ def test_string_content(self):
+ assert _msg_text({'role': 'user', 'content': 'hello'}) == 'hello'
+
+ def test_list_content_concat(self):
+ msg = {
+ 'content': [
+ {
+ 'type': 'text',
+ 'text': 'a'
+ },
+ {
+ 'type': 'image',
+ 'url': '...'
+ }, # non-text part ignored
+ {
+ 'type': 'text',
+ 'text': 'b'
+ },
+ ]
+ }
+ assert _msg_text(msg) == 'a b'
+
+ def test_missing_content(self):
+ assert _msg_text({'role': 'user'}) == ''
+
+ def test_none_content(self):
+ assert _msg_text({'role': 'user', 'content': None}) == ''
+
+ def test_non_str_non_list_content(self):
+ assert _msg_text({'role': 'user', 'content': 123}) == ''
+
+
+# ── _is_agent_row detection ──────────────────────────────────────────────────
+
+
+class TestIsAgentRowStructural:
+
+ def test_role_tool_triggers(self):
+ msgs = [
+ {
+ 'role': 'user',
+ 'content': 'q'
+ },
+ {
+ 'role': 'assistant',
+ 'content': '',
+ 'tool_calls': [{
+ 'id': 'a',
+ 'type': 'function',
+ 'function': {
+ 'name': 'x',
+ 'arguments': '{}'
+ }
+ }]
+ },
+ {
+ 'role': 'tool',
+ 'content': 'result',
+ 'tool_call_id': 'a'
+ },
+ ]
+ assert _is_agent_row(msgs) is True
+
+ def test_tool_calls_field_triggers(self):
+ msgs = [
+ {
+ 'role': 'user',
+ 'content': 'q'
+ },
+ {
+ 'role': 'assistant',
+ 'content': '',
+ 'tool_calls': [{
+ 'id': 'c1',
+ 'type': 'function',
+ 'function': {
+ 'name': 'f',
+ 'arguments': '{}'
+ }
+ }]
+ },
+ ]
+ assert _is_agent_row(msgs) is True
+
+ def test_empty_tool_calls_field_does_not_trigger(self):
+ msgs = [
+ {
+ 'role': 'user',
+ 'content': 'q'
+ },
+ {
+ 'role': 'assistant',
+ 'content': 'plain reply',
+ 'tool_calls': []
+ },
+ ]
+ assert _is_agent_row(msgs) is False
+
+ def test_non_list_tool_calls_field_does_not_trigger(self):
+ msgs = [
+ {
+ 'role': 'assistant',
+ 'content': 'x',
+ 'tool_calls': None
+ },
+ ]
+ assert _is_agent_row(msgs) is False
+
+
+class TestIsAgentRowTextEmbedded:
+
+ def test_cline_style_triggers(self):
+ msgs = [
+ {
+ 'role': 'user',
+ 'content': 'read the file'
+ },
+ {
+ 'role': 'assistant',
+ 'content': '/etc/hosts'
+ },
+ ]
+ assert _is_agent_row(msgs) is True
+
+ def test_hermes_qwen_style_triggers(self):
+ msgs = [
+ {
+ 'role': 'user',
+ 'content': 'q'
+ },
+ {
+ 'role': 'assistant',
+ 'content': '\n{"name": "search", "arguments": {"q": "x"}}\n'
+ },
+ ]
+ assert _is_agent_row(msgs) is True
+
+ def test_react_action_style_triggers(self):
+ # ReAct parser uses bracket syntax: ``Action: name[args]``.
+ msgs = [
+ {
+ 'role': 'user',
+ 'content': 'q'
+ },
+ {
+ 'role': 'assistant',
+ 'content': 'Thought: I need to search.\nAction: search[query=x]'
+ },
+ ]
+ assert _is_agent_row(msgs) is True
+
+ def test_plain_assistant_text_does_not_trigger(self):
+ msgs = [
+ {
+ 'role': 'user',
+ 'content': 'hi'
+ },
+ {
+ 'role': 'assistant',
+ 'content': 'Hello! How can I help?'
+ },
+ ]
+ assert _is_agent_row(msgs) is False
+
+ def test_lookalike_xml_without_inner_params_does_not_trigger(self):
+ # ``echo hi`` has no ``val`` child — Cline parser
+ # rejects it via inner-param requirement. Hermes/ReAct also reject.
+ msgs = [
+ {
+ 'role': 'user',
+ 'content': 'q'
+ },
+ {
+ 'role': 'assistant',
+ 'content': 'echo hi'
+ },
+ ]
+ assert _is_agent_row(msgs) is False
+
+ def test_denied_outer_tag_does_not_trigger(self):
+ # ````/```` are in the Cline DENY frozenset.
+ msgs = [
+ {
+ 'role': 'assistant',
+ 'content': 'because'
+ },
+ ]
+ assert _is_agent_row(msgs) is False
+
+ def test_user_text_with_tool_markers_does_not_trigger(self):
+ # Markers must come from the assistant — user-side embedded XML is just data.
+ msgs = [
+ {
+ 'role': 'user',
+ 'content': 'x'
+ },
+ {
+ 'role': 'assistant',
+ 'content': 'I will do that.'
+ },
+ ]
+ assert _is_agent_row(msgs) is False
+
+ def test_list_content_assistant_with_tool_call(self):
+ msgs = [
+ {
+ 'role':
+ 'assistant',
+ 'content': [
+ {
+ 'type': 'text',
+ 'text': ''
+ },
+ {
+ 'type': 'text',
+ 'text': '{"name":"f","arguments":{}}'
+ },
+ ]
+ },
+ ]
+ assert _is_agent_row(msgs) is True
+
+
+class TestIsAgentRowEdgeCases:
+
+ def test_non_list_messages(self):
+ assert _is_agent_row(None) is False
+ assert _is_agent_row('') is False
+ assert _is_agent_row({}) is False
+
+ def test_empty_messages(self):
+ assert _is_agent_row([]) is False
+
+ def test_non_dict_message_skipped(self):
+ msgs = [
+ 'not a dict',
+ {
+ 'role': 'user',
+ 'content': 'hi'
+ },
+ {
+ 'role': 'assistant',
+ 'content': 'hello'
+ },
+ ]
+ assert _is_agent_row(msgs) is False
+
+ def test_short_circuits_on_first_match(self):
+ # Even if later messages are clean, an earlier tool-call hit wins.
+ msgs = [
+ {
+ 'role': 'tool',
+ 'content': 'r',
+ 'tool_call_id': 'x'
+ },
+ {
+ 'role': 'assistant',
+ 'content': 'plain'
+ },
+ ]
+ assert _is_agent_row(msgs) is True
+
+
+# ── AgentTraceFilter pipeline behavior ───────────────────────────────────────
+
+
+class TestAgentTraceFilterPipeline:
+
+ def test_tags_every_row(self):
+ rows = [
+ _row([{
+ 'role': 'assistant',
+ 'content': 'plain'
+ }]),
+ _row([{
+ 'role': 'tool',
+ 'content': 'r',
+ 'tool_call_id': 'x'
+ }]),
+ _row([{
+ 'role': 'assistant',
+ 'content': 'x'
+ }]),
+ ]
+ out = AgentTraceFilter()(rows)
+ assert len(out) == 3
+ # Every row must have ``is_agent`` so map_row_to_col sees a uniform schema.
+ assert all('is_agent' in r for r in out)
+ assert [r['is_agent'] for r in out] == [False, True, True]
+
+ def test_never_drops_rows(self):
+ rows = [_row([{'role': 'user', 'content': 'x'}])] * 5
+ out = AgentTraceFilter()(rows)
+ assert len(out) == 5
+
+ def test_preserves_other_fields(self):
+ rows = [
+ {
+ 'messages': [{
+ 'role': 'tool',
+ 'content': 'r',
+ 'tool_call_id': 'x'
+ }],
+ 'id': 'row-1',
+ 'extra': {
+ 'k': 'v'
+ }
+ },
+ ]
+ out = AgentTraceFilter()(rows)
+ assert out[0]['id'] == 'row-1'
+ assert out[0]['extra'] == {'k': 'v'}
+ assert out[0]['is_agent'] is True
+
+ def test_does_not_mutate_input(self):
+ original = _row([{'role': 'assistant', 'content': 'plain'}])
+ rows = [original]
+ AgentTraceFilter()(rows)
+ # Filter must return new dicts, not mutate originals.
+ assert 'is_agent' not in original
+
+ def test_missing_messages_key(self):
+ rows = [{'id': 'lonely'}] # no messages
+ out = AgentTraceFilter()(rows)
+ assert len(out) == 1
+ assert out[0]['is_agent'] is False
+
+ def test_messages_is_none(self):
+ rows = [_row(None)]
+ out = AgentTraceFilter()(rows)
+ assert out[0]['is_agent'] is False
+
+ def test_empty_input(self):
+ assert AgentTraceFilter()([]) == []
+
+
+if __name__ == '__main__':
+ pytest.main([__file__, '-v'])
diff --git a/tests/preprocessor/test_dead_loop_filter.py b/tests/preprocessor/test_dead_loop_filter.py
new file mode 100644
index 000000000..9f90e451e
--- /dev/null
+++ b/tests/preprocessor/test_dead_loop_filter.py
@@ -0,0 +1,308 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+"""Tests for DeadLoopFilter.
+
+Three orthogonal "stuck" signals:
+ 1. Hesitation density — markers per 1000 chars > threshold
+ 2. Correction cascade — ≥N markers within a sliding window
+ 3. High n-gram repetition — (1 - unique/total) > threshold
+
+A row is dropped if ANY signal trips on any assistant turn.
+Rows with ``is_agent=True`` are always kept (agent rollouts have legitimate
+self-correction phrasing).
+
+When the message contains ``...``, the think part and the
+response part are scored independently with separate (looser) think-thresholds.
+"""
+import pytest
+
+from twinkle_agentic.preprocessor.dead_loop_filter import (DeadLoopFilter, _has_correction_cascade_with_threshold,
+ _hesitation_density, _high_repetition_with_threshold,
+ _is_stuck)
+
+
+def _row(messages, **extra):
+ return {'messages': messages, **extra}
+
+
+def _fil(rows, **kw):
+ return DeadLoopFilter(**kw)(rows)
+
+
+# ── _hesitation_density ─────────────────────────────────────────────────────
+
+
+class TestHesitationDensity:
+
+ def test_no_markers(self):
+ text = 'This is a perfectly normal explanation of gradient descent.'
+ assert _hesitation_density(text) == 0.0
+
+ def test_english_marker_counted(self):
+ # "wait, wait" matches `wait[,\s]+(wait|...)` — one marker.
+ text = 'wait, wait this is wrong'
+ d = _hesitation_density(text)
+ assert d > 0
+
+ def test_density_per_1000(self):
+ # ~5 markers in 100 chars → density ~50/1000
+ text = ('hmm hmm hmm hmm hmm ' * 1).strip() # 5 hmm tokens
+ # Each "hmm" matches `hmm+[,\s]*\.{0,3}` → 5 matches
+ density = _hesitation_density(text)
+ assert density > 100 # very dense
+
+ def test_chinese_marker(self):
+ text = '等等,让我重新想想这个问题。'
+ assert _hesitation_density(text) > 0
+
+ def test_empty_text(self):
+ assert _hesitation_density('') == 0.0
+
+ def test_japanese_marker(self):
+ text = 'ちょっと待って、もう一度考え直してみます。'
+ assert _hesitation_density(text) > 0
+
+ def test_korean_marker(self):
+ text = '잠깐, 다시 생각해봐야겠어요.'
+ assert _hesitation_density(text) > 0
+
+
+# ── _has_correction_cascade_with_threshold ──────────────────────────────────
+
+
+class TestCorrectionCascade:
+
+ def test_below_threshold(self):
+ # Only 2 cascade markers; threshold=5 → no cascade.
+ text = 'wait, actually let me think.'
+ assert _has_correction_cascade_with_threshold(text, threshold=5) is False
+
+ def test_at_threshold_in_window(self):
+ # 5 cascade tokens packed into <800 chars → cascade detected.
+ text = 'wait wait wait wait wait'
+ assert _has_correction_cascade_with_threshold(text, threshold=5, window=800) is True
+
+ def test_threshold_outside_window(self):
+ # 5 markers but spread across >800 chars → no cascade.
+ spacer = ' ' * 200 # each spacer is 200 chars
+ text = f'wait{spacer}wait{spacer}wait{spacer}wait{spacer}wait' # 5*200 = 1000 chars
+ assert _has_correction_cascade_with_threshold(text, threshold=5, window=800) is False
+
+ def test_chinese_cascade(self):
+ text = '等等,不对,重新想想,错了,让我再算一遍。'
+ assert _has_correction_cascade_with_threshold(text, threshold=4) is True
+
+ def test_zero_threshold_unreachable(self):
+ # threshold=0 means need 0 matches in any window — len(matches) < 0 is
+ # never true so this returns True even on empty. Test the sane case.
+ assert _has_correction_cascade_with_threshold('clean text', threshold=1) is False
+
+
+# ── _high_repetition_with_threshold ─────────────────────────────────────────
+
+
+class TestRepetition:
+
+ def test_below_min_words(self):
+ # Fewer than ngram_min_words words → False (insufficient sample).
+ text = 'this is a short text'
+ assert _high_repetition_with_threshold(text, threshold=0.0, ngram_min_words=30) is False
+
+ def test_no_repetition(self):
+ # 30 distinct words → unique_ratio ~ 1.0 → repetition ~ 0.
+ text = ' '.join(f'word{i}' for i in range(40))
+ assert _high_repetition_with_threshold(text, threshold=0.45, ngram_min_words=30) is False
+
+ def test_high_repetition_triggers(self):
+ # Same 8-gram repeated → unique_ratio low → repetition high.
+ phrase = 'the quick brown fox jumps over the lazy'
+ text = ' '.join([phrase] * 10)
+ assert _high_repetition_with_threshold(text, threshold=0.45, ngram_size=8, ngram_min_words=30) is True
+
+ def test_threshold_boundary(self):
+ # Same text under different thresholds.
+ phrase = 'a b c d e f g h '
+ text = phrase * 6 # 48 words, only 8 unique
+ # very low threshold → trips
+ assert _high_repetition_with_threshold(text, threshold=0.1) is True
+ # very high threshold → does not trip even with high duplication
+ assert _high_repetition_with_threshold(text, threshold=0.99) is False
+
+
+# ── _is_stuck ───────────────────────────────────────────────────────────────
+
+
+class TestIsStuck:
+
+ def test_clean_text_not_stuck(self):
+ # Use diverse prose so n-gram repetition stays below threshold.
+ text = ('Gradient descent is an iterative optimization algorithm used '
+ 'for finding the local minimum of a differentiable function. '
+ 'It updates parameters in the direction opposite to the '
+ 'gradient of the objective at the current point. Variants '
+ 'such as momentum and Adam improve convergence speed.')
+ assert _is_stuck(text) is False
+
+ def test_high_density_stuck(self):
+ # Pack many hesitation tokens to exceed 7/1000 density.
+ text = 'wait, wait this is wrong. hmm... actually no. uh, wait wait wait.'
+ assert _is_stuck(text) is True
+
+ def test_cascade_stuck(self):
+ # 5 cascade tokens in tight window
+ text = 'wait actually wait actually wait!'
+ assert _is_stuck(
+ text, hesitation_density_threshold=999.0, cascade_threshold=5, repetition_threshold=0.99) is True
+
+ def test_repetition_stuck(self):
+ phrase = 'the quick brown fox jumps over the lazy'
+ text = ' '.join([phrase] * 10)
+ assert _is_stuck(
+ text, hesitation_density_threshold=999.0, cascade_threshold=999, repetition_threshold=0.45) is True
+
+ def test_think_block_separate_thresholds(self):
+ # Hesitation that would trip in response section is allowed inside
+ # ... because think-thresholds are looser (15.0 vs 7.0).
+ # Build a think with moderate density (~10/1000) — below 15 think
+ # threshold, but would exceed 7 in normal text.
+ think_part = 'wait, actually let me reconsider this. ' * 3 + 'a' * 1500
+ text = f'{think_part}The answer is 42.'
+ assert _is_stuck(text) is False # think-density well below 15
+
+ def test_response_part_after_think_stuck(self):
+ # Clean think but stuck response → still stuck.
+ text = ('Calculating step by step.'
+ 'wait, wait this is wrong. hmm... actually no. uh, wait wait wait.')
+ assert _is_stuck(text) is True
+
+
+# ── DeadLoopFilter pipeline ─────────────────────────────────────────────────
+
+
+class TestDeadLoopFilterPipeline:
+
+ def test_drops_stuck_row(self):
+ rows = [
+ _row([
+ {
+ 'role': 'user',
+ 'content': 'q'
+ },
+ {
+ 'role': 'assistant',
+ 'content': 'wait, wait this is wrong. hmm... actually no. '
+ 'uh, wait wait wait.'
+ },
+ ])
+ ]
+ assert _fil(rows) == []
+
+ def test_keeps_clean_row(self):
+ rows = [
+ _row([
+ {
+ 'role': 'user',
+ 'content': 'q'
+ },
+ {
+ 'role': 'assistant',
+ 'content': 'A clear, well-formed answer goes here.'
+ },
+ ])
+ ]
+ assert len(_fil(rows)) == 1
+
+ def test_agent_row_always_kept(self):
+ # is_agent=True bypasses all stuck checks.
+ rows = [
+ _row([
+ {
+ 'role': 'user',
+ 'content': 'q'
+ },
+ {
+ 'role': 'assistant',
+ 'content': 'wait wait wait wait wait wait wait!!!'
+ },
+ ],
+ is_agent=True)
+ ]
+ assert len(_fil(rows)) == 1
+
+ def test_no_assistant_kept(self):
+ rows = [_row([{'role': 'user', 'content': 'hi'}])]
+ assert len(_fil(rows)) == 1
+
+ def test_any_assistant_stuck_drops_row(self):
+ rows = [
+ _row([
+ {
+ 'role': 'user',
+ 'content': 'q1'
+ },
+ {
+ 'role': 'assistant',
+ 'content': 'clean reply'
+ },
+ {
+ 'role': 'user',
+ 'content': 'q2'
+ },
+ {
+ 'role': 'assistant',
+ 'content': 'wait, wait this is wrong. hmm... actually no. '
+ 'uh, wait wait wait.'
+ },
+ ])
+ ]
+ assert _fil(rows) == []
+
+ def test_empty_input(self):
+ assert _fil([]) == []
+
+ def test_custom_thresholds(self):
+ # 1 hesitation marker in a long message — density well below the
+ # default 7/1000. Tightening the threshold should drop it.
+ long_msg = ('Hmm, let me think about this carefully. Gradient descent '
+ 'requires a learning rate, the loss function, and an '
+ 'initial parameter point. The algorithm iteratively '
+ 'updates the parameters towards the negative gradient. '
+ 'Momentum-based variants accumulate past gradients to '
+ 'smooth the trajectory and accelerate convergence on '
+ 'ill-conditioned problems. Adam additionally adapts the '
+ 'per-parameter learning rate using running second-moment '
+ 'estimates, which often makes it the default choice for '
+ 'practitioners across many deep-learning tasks.')
+ rows = [_row([
+ {
+ 'role': 'user',
+ 'content': 'q'
+ },
+ {
+ 'role': 'assistant',
+ 'content': long_msg
+ },
+ ])]
+ # Default 7/1000 — single marker in long text → kept
+ assert len(_fil(rows)) == 1
+ # Aggressive threshold drops it
+ assert _fil(rows, hesitation_density_threshold=0.5) == []
+
+ def test_chinese_stuck(self):
+ rows = [
+ _row([
+ {
+ 'role': 'user',
+ 'content': 'q'
+ },
+ {
+ 'role': 'assistant',
+ 'content': '等等,不对,让我重新想想。错了,让我再来一次。'
+ '我又搞错了。等等,等等。'
+ },
+ ])
+ ]
+ assert _fil(rows) == []
+
+
+if __name__ == '__main__':
+ pytest.main([__file__, '-v'])
diff --git a/tests/preprocessor/test_hard_filter.py b/tests/preprocessor/test_hard_filter.py
new file mode 100644
index 000000000..ab8f682ac
--- /dev/null
+++ b/tests/preprocessor/test_hard_filter.py
@@ -0,0 +1,364 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+"""Tests for HardFilter.
+
+HardFilter drops:
+ Rule 1 — Single-turn trivial query (greeting / bare wh-question).
+ Rule 2 — Two-turn shallow assistant reply (< min chars, no thinking chain).
+
+CJK and ASCII branches use different length thresholds because of the
+information density gap.
+"""
+import pytest
+
+from twinkle_agentic.preprocessor.hard_filter import HardFilter, _cjk_ratio, _has_thinking, _is_simple_query
+
+
+def _row(messages):
+ return {'messages': messages}
+
+
+# ── _cjk_ratio ───────────────────────────────────────────────────────────────
+
+
+class TestCjkRatio:
+
+ def test_pure_ascii(self):
+ assert _cjk_ratio('hello world') == 0.0
+
+ def test_pure_chinese(self):
+ assert _cjk_ratio('你好世界') == 1.0
+
+ def test_mixed(self):
+ # 2 CJK chars / 6 total
+ assert abs(_cjk_ratio('hi你好zz') - 2 / 6) < 1e-9
+
+ def test_japanese_hiragana(self):
+ # Hiragana is in the CJK range covered by the regex.
+ assert _cjk_ratio('こんにちは') == 1.0
+
+ def test_korean_hangul(self):
+ assert _cjk_ratio('안녕하세요') == 1.0
+
+ def test_empty(self):
+ # max(len, 1) → 0/1 = 0
+ assert _cjk_ratio('') == 0.0
+
+
+# ── _is_simple_query: ASCII / English ────────────────────────────────────────
+
+
+class TestSimpleQueryEnglish:
+
+ def test_short_text_is_simple(self):
+ assert _is_simple_query('hi') is True
+ assert _is_simple_query('a' * 9) is True # default min=10
+
+ def test_at_threshold_not_simple_unless_pattern(self):
+ # 10 non-pattern chars escapes both length and pattern checks
+ assert _is_simple_query('quantum xx') is False
+
+ def test_greeting_hello(self):
+ assert _is_simple_query('Hello!') is True
+ assert _is_simple_query('Heeellloooo') is True
+
+ def test_greeting_good_morning(self):
+ assert _is_simple_query('Good morning') is True
+
+ def test_greeting_how_are_you(self):
+ assert _is_simple_query('How are you') is True
+
+ def test_bare_wh_question(self):
+ assert _is_simple_query('what is python') is True
+
+ def test_imperative_short(self):
+ assert _is_simple_query('tell me about it') is True
+ assert _is_simple_query('explain') is True
+
+ def test_substantive_question_not_simple(self):
+ # Long, technical question should pass (not simple).
+ text = ('Please explain the difference between gradient descent and '
+ 'momentum-based optimization in deep learning training.')
+ assert _is_simple_query(text) is False
+
+
+class TestSimpleQueryChinese:
+
+ def test_short_cjk_is_simple(self):
+ assert _is_simple_query('你好') is True
+ assert _is_simple_query('你好啊') is True # < 6
+
+ def test_at_cjk_threshold(self):
+ # 6 CJK chars; greeting (`你好+` matches `你好好好好好`) → simple
+ assert _is_simple_query('你好好好好好') is True
+ # 6 substantive CJK chars; no greeting/simple pattern → NOT simple
+ assert _is_simple_query('量子计算原理') is False
+
+ def test_greeting_zh(self):
+ assert _is_simple_query('你好!') is True
+ assert _is_simple_query('早上好') is True
+ assert _is_simple_query('哈喽哈喽') is True
+
+ def test_what_is_x(self):
+ assert _is_simple_query('什么是机器学习?') is True
+ assert _is_simple_query('梯度下降是什么?') is True
+
+ def test_substantive_zh_not_simple(self):
+ text = '请详细解释一下变换器架构中的多头自注意力机制是如何并行计算的,以及为什么需要位置编码。'
+ assert _is_simple_query(text) is False
+
+
+class TestSimpleQueryJapanese:
+
+ def test_japanese_greeting(self):
+ assert _is_simple_query('こんにちは') is True
+
+ def test_japanese_what_is(self):
+ assert _is_simple_query('機械学習とは何ですか') is True
+
+
+class TestSimpleQueryKorean:
+
+ def test_korean_greeting(self):
+ assert _is_simple_query('안녕하세요') is True
+
+ def test_korean_what_is(self):
+ # KO_SIMPLE_RE expects "X이/가 뭐" pattern; trailing 인가요/까요 are
+ # only single optional chars, so use the bare 뭐 form here.
+ assert _is_simple_query('머신러닝이 뭐') is True
+
+
+class TestSimpleQueryEdge:
+
+ def test_empty(self):
+ assert _is_simple_query('') is True
+
+ def test_whitespace_only(self):
+ assert _is_simple_query(' \n ') is True
+
+ def test_custom_thresholds(self):
+ # Raise the bar so a 12-char query becomes simple.
+ text = 'short query!'
+ assert _is_simple_query(text, min_user_chars=20) is True
+ assert _is_simple_query(text, min_user_chars=5) is False
+
+
+# ── _has_thinking ────────────────────────────────────────────────────────────
+
+
+class TestHasThinking:
+
+ def test_thinking_field_long_enough(self):
+ msg = {'thinking': 'a' * 250}
+ assert _has_thinking(msg) is True
+
+ def test_thinking_field_too_short(self):
+ msg = {'thinking': 'short'}
+ assert _has_thinking(msg) is False
+
+ def test_reasoning_content_alias(self):
+ msg = {'reasoning_content': 'a' * 250}
+ assert _has_thinking(msg) is True
+
+ def test_no_thinking(self):
+ assert _has_thinking({'content': 'reply'}) is False
+
+ def test_custom_min_chars(self):
+ msg = {'thinking': 'short'}
+ assert _has_thinking(msg, min_chars=3) is True
+
+ def test_non_string_thinking_truthy(self):
+ # Falls through to bool(thinking)
+ assert _has_thinking({'thinking': {'a': 1}}) is True
+ assert _has_thinking({'thinking': []}) is False
+
+
+# ── HardFilter pipeline ──────────────────────────────────────────────────────
+
+
+def _fil(rows, **kw):
+ return HardFilter(**kw)(rows)
+
+
+class TestRule1SimpleQuery:
+
+ def test_drops_greeting_only(self):
+ rows = [_row([
+ {
+ 'role': 'user',
+ 'content': 'hello'
+ },
+ {
+ 'role': 'assistant',
+ 'content': 'hi there!'
+ },
+ ])]
+ assert _fil(rows) == []
+
+ def test_drops_bare_wh_question(self):
+ rows = [
+ _row([
+ {
+ 'role': 'user',
+ 'content': 'what is AI'
+ },
+ {
+ 'role': 'assistant',
+ 'content': 'a short answer'
+ },
+ ])
+ ]
+ assert _fil(rows) == []
+
+ def test_keeps_when_substantive(self):
+ rows = [
+ _row([
+ {
+ 'role': 'user',
+ 'content': 'Could you explain gradient descent step by step in detail?'
+ },
+ {
+ 'role': 'assistant',
+ 'content': 'Gradient descent is an iterative optimization algorithm... ' * 5
+ },
+ ])
+ ]
+ assert len(_fil(rows)) == 1
+
+ def test_keeps_simple_query_with_thinking(self):
+ # Rule 1 rescue: thinking chain ≥200 chars saves the row.
+ rows = [
+ _row([
+ {
+ 'role': 'user',
+ 'content': 'hi'
+ },
+ {
+ 'role': 'assistant',
+ 'content': 'hello',
+ 'reasoning_content': 'Now I need to greet politely... ' * 20
+ },
+ ])
+ ]
+ assert len(_fil(rows)) == 1
+
+ def test_simple_query_no_assistant_dropped(self):
+ # No assistant turn → no thinking → dropped.
+ rows = [_row([{'role': 'user', 'content': 'hi'}])]
+ assert _fil(rows) == []
+
+
+class TestRule2ShallowReply:
+
+ def test_drops_short_reply(self):
+ rows = [
+ _row([
+ {
+ 'role': 'user',
+ 'content': 'Explain the difference between A and B in detail please.'
+ },
+ {
+ 'role': 'assistant',
+ 'content': 'A is good.'
+ }, # < 80 chars
+ ])
+ ]
+ assert _fil(rows) == []
+
+ def test_keeps_long_reply(self):
+ rows = [
+ _row([
+ {
+ 'role': 'user',
+ 'content': 'Explain the difference between A and B in detail please.'
+ },
+ {
+ 'role': 'assistant',
+ 'content': 'A and B differ in several ways. ' * 5
+ },
+ ])
+ ]
+ assert len(_fil(rows)) == 1
+
+ def test_short_reply_with_thinking_kept(self):
+ # Rule 2 rescue: thinking saves a short final reply.
+ rows = [
+ _row([
+ {
+ 'role': 'user',
+ 'content': 'Explain the difference between A and B in detail please.'
+ },
+ {
+ 'role': 'assistant',
+ 'content': 'A is good.',
+ 'thinking': 'Step 1: compare features... ' * 20
+ },
+ ])
+ ]
+ assert len(_fil(rows)) == 1
+
+
+class TestPipelineEdges:
+
+ def test_no_user_dropped_by_default(self):
+ rows = [_row([{'role': 'assistant', 'content': 'orphan reply'}])]
+ assert _fil(rows) == []
+
+ def test_no_user_kept_when_allowed(self):
+ rows = [_row([{'role': 'assistant', 'content': 'orphan'}])]
+ assert len(_fil(rows, allow_incomplete_role=True)) == 1
+
+ def test_multi_user_skips_rules(self):
+ # With ≥2 user turns, neither Rule 1 nor Rule 2 applies.
+ rows = [
+ _row([
+ {
+ 'role': 'user',
+ 'content': 'hi'
+ },
+ {
+ 'role': 'assistant',
+ 'content': 'short'
+ },
+ {
+ 'role': 'user',
+ 'content': 'follow-up?'
+ },
+ {
+ 'role': 'assistant',
+ 'content': 'tiny'
+ },
+ ])
+ ]
+ assert len(_fil(rows)) == 1
+
+ def test_non_list_messages(self):
+ rows = [{'messages': 'not a list'}]
+ assert _fil(rows) == [] # invalid → continue (skip)
+
+ def test_missing_messages(self):
+ rows = [{'id': 'x'}]
+ # No user_msgs and allow_incomplete_role=False → skipped.
+ assert _fil(rows) == []
+
+ def test_empty_input(self):
+ assert _fil([]) == []
+
+ def test_custom_thresholds_applied(self):
+ # Lower min_assistant_chars_2turn → keep what would normally be dropped.
+ rows = [
+ _row([
+ {
+ 'role': 'user',
+ 'content': 'tell me a real story please now'
+ },
+ {
+ 'role': 'assistant',
+ 'content': 'A is good.'
+ },
+ ])
+ ]
+ assert _fil(rows, min_assistant_chars_2turn=5) and len(rows) == 1
+
+
+if __name__ == '__main__':
+ pytest.main([__file__, '-v'])
diff --git a/tests/preprocessor/test_intent_classifier.py b/tests/preprocessor/test_intent_classifier.py
new file mode 100644
index 000000000..ace65c74e
--- /dev/null
+++ b/tests/preprocessor/test_intent_classifier.py
@@ -0,0 +1,507 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+"""Tests for the heuristic IntentClassifier pipeline.
+
+Focus areas:
+- Per-detector recall on representative samples (ZH + EN, R1-distill-flavoured).
+- Per-detector FP guards (chitchat, role mismatch, first-turn dissatisfaction).
+- Multi-detector ordering: ToolCallDetector short-circuit, ``setdefault`` semantics.
+- Edge cases: empty / None / non-dict / list-content messages, empty trajectories.
+- Public API contract: ``row['intent']``, user_data List[Tuple] entries ``('key_rounds', ...)`` / ``('intents', ...)``.
+- Detector pluggability: custom subclass, overriding ``DEFAULT_DETECTORS``.
+"""
+import pytest
+
+from twinkle_agentic.preprocessor.intent_classifier import (INTENT_CODE, INTENT_MATH, INTENT_OTHER, INTENT_TOOL_CALL,
+ INTENT_USER_DISSATISFACTION, CodeDetector, IntentClassifier,
+ IntentDetector, MathDetector, ToolCallDetector,
+ UserDissatisfactionDetector, _msg_text, _pair_assistant)
+from twinkle.data_format import pack_value, user_data_get
+
+# ── Helpers ───────────────────────────────────────────────────────────────────
+
+
+def _u(text):
+ return {'role': 'user', 'content': text}
+
+
+def _a(text, **extra):
+ msg = {'role': 'assistant', 'content': text}
+ msg.update(extra)
+ return msg
+
+
+def _row(*messages):
+ return {'messages': list(messages)}
+
+
+def _classify_one(*messages, detectors=None):
+ # Keep no-key-round rows (e.g. chitchat) so tests can assert ``intent == INTENT_OTHER``.
+ ic = IntentClassifier(detectors=detectors, drop_no_key_rounds=False)
+ kept, _dropped = ic([_row(*messages)])
+ return kept[0]
+
+
+def _ud_get(row, key):
+ """Read a value by key from packed user_data (JSON-decoded)."""
+ return user_data_get(row.get('user_data'), key)
+
+
+def _has_key_rounds(row):
+ kr = _ud_get(row, 'key_rounds')
+ return isinstance(kr, list) and bool(kr)
+
+
+# ── Helper functions ──────────────────────────────────────────────────────────
+
+
+class TestHelpers:
+
+ def test_msg_text_string(self):
+ assert _msg_text({'content': 'hi'}) == 'hi'
+
+ def test_msg_text_list_with_text_parts(self):
+ msg = {
+ 'content': [
+ {
+ 'type': 'text',
+ 'text': 'foo'
+ },
+ {
+ 'type': 'image',
+ 'url': 'x'
+ },
+ {
+ 'type': 'text',
+ 'text': 'bar'
+ },
+ ]
+ }
+ assert _msg_text(msg) == 'foo bar'
+
+ def test_msg_text_missing_content(self):
+ assert _msg_text({}) == ''
+
+ def test_msg_text_none_content(self):
+ assert _msg_text({'content': None}) == ''
+
+ def test_msg_text_list_no_text_parts(self):
+ assert _msg_text({'content': [{'type': 'image'}]}) == ''
+
+ def test_pair_assistant_user_finds_next_assistant(self):
+ msgs = [_u('q'), _a('a1'), _u('follow'), _a('a2')]
+ assert _pair_assistant(msgs, 0, 'user') == 1
+ assert _pair_assistant(msgs, 2, 'user') == 3
+
+ def test_pair_assistant_assistant_returns_self(self):
+ msgs = [_u('q'), _a('a1')]
+ assert _pair_assistant(msgs, 1, 'assistant') == 1
+
+ def test_pair_assistant_user_no_following_assistant(self):
+ # User turn at the tail with no assistant after — un-pairable.
+ msgs = [_a('a1'), _u('dangling')]
+ assert _pair_assistant(msgs, 1, 'user') is None
+
+ def test_pair_assistant_other_role(self):
+ assert _pair_assistant([{'role': 'system', 'content': 's'}], 0, 'system') is None
+
+
+# ── ToolCallDetector ─────────────────────────────────────────────────────────
+
+
+class TestToolCallDetector:
+
+ def test_definitive_flag(self):
+ assert ToolCallDetector.definitive is True
+
+ def test_detects_assistant_with_tool_calls(self):
+ msgs = [_u('q'), _a('', tool_calls=[{'name': 'f'}])]
+ assert ToolCallDetector()(msgs) == [1]
+
+ def test_ignores_assistant_without_tool_calls(self):
+ assert ToolCallDetector()([_u('q'), _a('plain')]) == []
+
+ def test_ignores_user_with_tool_calls_field(self):
+ # A user dict carrying a tool_calls key must not be picked up.
+ msgs = [{'role': 'user', 'content': 'q', 'tool_calls': [{'name': 'x'}]}]
+ assert ToolCallDetector()(msgs) == []
+
+ def test_short_circuits_pipeline(self):
+ # When ToolCall fires it must suppress later detectors on the same round.
+ msgs = [
+ _u('解一元二次方程 x^2 - 5x + 6 = 0 的因式分解'),
+ _a('answer', tool_calls=[{
+ 'name': 'calc'
+ }]),
+ ]
+ out = _classify_one(*msgs)
+ assert out['intent'] == INTENT_TOOL_CALL
+ # math detector must not have written into intents.
+ assert _ud_get(out, 'intents') == {'1': INTENT_TOOL_CALL}
+
+
+# ── CodeDetector ──────────────────────────────────────────────────────────────
+
+
+class TestCodeDetector:
+
+ def test_fenced_code_block(self):
+ text = '```python\ndef f():\n return 1\n```'
+ assert CodeDetector()._match(text)
+
+ def test_short_fenced_block_below_min_length(self):
+ # Block content must be ≥10 chars to qualify.
+ assert not CodeDetector()._match('```\nhi\n```')
+
+ def test_keyword_threshold_three(self):
+ # Three keyword hits must trigger.
+ assert CodeDetector()._match('use async function and await the response')
+
+ def test_two_keywords_below_threshold(self):
+ assert not CodeDetector()._match('a class and a function')
+
+ def test_arrow_signature_alone_insufficient(self):
+ # Single arrow without other signals doesn't reach threshold.
+ assert not CodeDetector()._match('x => x + 1')
+
+ def test_call_signature_with_brace(self):
+ # `name(args) {` is a strong code indicator.
+ assert CodeDetector()._match('function fetchData(url) { return fetch(url); } and async await yield')
+
+ def test_chitchat_with_word_class_no_fp(self):
+ assert not CodeDetector()._match('I took a yoga class today')
+
+
+# ── MathDetector ──────────────────────────────────────────────────────────────
+
+
+class TestMathDetector:
+
+ @pytest.mark.parametrize('text', [
+ '设 $f(x)=x^2$ 求导得 2x',
+ '矩阵 A 的行列式 det(A) 不等于 0',
+ '三角形 ABC 周长是 12,面积约为 6',
+ '数列 {a_n} 是等差数列,公差为 2,首项为 1',
+ '4, 3, 4, 3, (),奇数位是 4',
+ 'Σ_{i=1}^n A_{ik} B_{kj}',
+ 'gradient and integral are both fundamental',
+ '求一元二次方程 x^2 - 5x + 6 = 0 的解',
+ '一个圆形的直径是 10cm,所以周长是 10π',
+ ])
+ def test_math_recall(self, text):
+ assert MathDetector()._match(text), f'should detect: {text!r}'
+
+ @pytest.mark.parametrize(
+ 'text',
+ [
+ '今天天气真好',
+ '我最近在追一部电视剧',
+ '帮我写一首诗',
+ '请帮我翻译这句英文',
+ # Single math keyword in non-math context — must not trip ≥2 threshold.
+ '积分兑换可以兑换礼品',
+ '矩阵这个电影很好看',
+ ])
+ def test_math_fp_guard(self, text):
+ assert not MathDetector()._match(text), f'must NOT detect: {text!r}'
+
+ def test_arithmetic_equation_single_hit(self):
+ # Only the arithmetic equation matches, threshold ≥2 not met.
+ assert not MathDetector()._match('计算 30 ÷ 6 = 5')
+
+ def test_threshold_is_configurable(self):
+ # Subclass with looser threshold catches single-hit case.
+ class LooseMath(MathDetector):
+ threshold = 1
+
+ assert LooseMath()._match('计算 30 ÷ 6 = 5')
+
+ def test_subscript_pattern(self):
+ assert MathDetector()._match('矩阵元素 a_{ij} 与 b_{kl} 满足条件')
+
+
+# ── UserDissatisfactionDetector ───────────────────────────────────────────────
+
+
+class TestUserDissatisfactionDetector:
+
+ @pytest.mark.parametrize('text', [
+ '不对,再来一次',
+ '完全错了',
+ '答非所问',
+ '你这是在胡扯',
+ '太离谱了',
+ '一塌糊涂',
+ '没逻辑啊',
+ '你根本没听懂我的意思',
+ '我说的不是这个',
+ '别瞎编',
+ '什么玩意',
+ '不靠谱',
+ '让我失望',
+ '不严谨',
+ '没get到',
+ ])
+ def test_zh_recall(self, text):
+ assert UserDissatisfactionDetector()._match(text)
+
+ @pytest.mark.parametrize('text', [
+ 'this is wrong',
+ 'totally incorrect',
+ 'try again please',
+ "doesn't make sense",
+ 'that is garbage',
+ 'you misunderstood me',
+ 'low quality response',
+ 'completely off topic',
+ 'are you serious',
+ 'waste of time',
+ 'this is bullshit',
+ 'redo it',
+ 'sub-par answer',
+ 'do better',
+ 'WTF is this',
+ 'nowhere near correct',
+ ])
+ def test_en_recall(self, text):
+ assert UserDissatisfactionDetector()._match(text)
+
+ @pytest.mark.parametrize('text', [
+ '今天心情很好',
+ '我喜欢这个回答',
+ '请帮我修改一下',
+ 'this is exactly what I wanted',
+ 'great answer thanks',
+ '能再详细一点吗',
+ ])
+ def test_fp_guard(self, text):
+ det = UserDissatisfactionDetector()
+ assert not det._match(text), f'FP on: {text!r}'
+
+ def test_first_turn_user_complaint_ignored(self):
+ # No prior assistant — the negative phrasing is part of the initial query, not a reaction.
+ msgs = [_u('你这答案完全错了,太垃圾'), _a('sorry')]
+ assert UserDissatisfactionDetector()(msgs) == []
+
+ def test_system_first_then_user_complaint_ignored(self):
+ msgs = [
+ {
+ 'role': 'system',
+ 'content': 'You are helpful.'
+ },
+ _u('上次回答简直一塌糊涂'),
+ _a('sorry'),
+ ]
+ # System turn must not satisfy "prior assistant".
+ assert UserDissatisfactionDetector()(msgs) == []
+
+ def test_multiturn_reaction_detected(self):
+ msgs = [_u('解释勾股定理'), _a('a²+b²=c²'), _u('不对,再来一次'), _a('好的')]
+ # The dissat user is at idx 2 → key round is the next assistant idx 3.
+ assert UserDissatisfactionDetector()(msgs) == [3]
+
+ def test_dissat_with_no_following_assistant_dropped(self):
+ # User dissatisfaction at the tail with no assistant pair → unpaired, no key round.
+ msgs = [_u('q'), _a('answer'), _u('完全错了')]
+ assert UserDissatisfactionDetector()(msgs) == []
+
+ def test_role_filter_blocks_assistant_self_correction(self):
+ # "等等我算错了,重新推导" appearing on assistant must not be tagged dissatisfaction.
+ msgs = [_u('推导一下'), _a('等等,我之前算错了,让我重新推导')]
+ assert UserDissatisfactionDetector()(msgs) == []
+
+
+# ── End-to-end IntentClassifier ───────────────────────────────────────────────
+
+
+class TestIntentClassifierE2E:
+
+ def test_chitchat_other(self):
+ out = _classify_one(_u('今天天气真好'), _a('是的,挺适合出门的'))
+ assert out['intent'] == INTENT_OTHER
+ assert not _has_key_rounds(out)
+
+ def test_math_round(self):
+ out = _classify_one(
+ _u('求一元二次方程 x^2 - 5x + 6 = 0 的解'),
+ _a('由因式分解得 (x-2)(x-3)=0'),
+ )
+ assert out['intent'] == INTENT_MATH
+ assert _ud_get(out, 'key_rounds') == [1]
+ assert _ud_get(out, 'intents') == {'1': INTENT_MATH}
+
+ def test_code_round(self):
+ out = _classify_one(
+ _u('use async function and await the response in JavaScript'),
+ _a('try const fetchData = async () => { return await fetch(url); }'),
+ )
+ assert out['intent'] == INTENT_CODE
+
+ def test_dissat_round(self):
+ out = _classify_one(_u('q'), _a('answer'), _u('totally garbage answer, redo'), _a('sorry'))
+ assert out['intent'] == INTENT_USER_DISSATISFACTION
+ assert _ud_get(out, 'key_rounds') == [3]
+
+ def test_assistant_self_correction_not_dissat(self):
+ # Root cause for original FP: role-agnostic regex on assistant text. Must stay fixed.
+ out = _classify_one(_u('推导一下'), _a('等等,我之前算错了,让我重新推导...'))
+ assert out['intent'] == INTENT_OTHER
+
+ def test_first_turn_user_negative_words_not_dissat(self):
+ out = _classify_one(_u('你这答案完全错了,太垃圾'), _a('抱歉'))
+ assert out['intent'] == INTENT_OTHER
+
+ def test_setdefault_earlier_detector_wins(self):
+ # When a round is first claimed by MathDetector, a later UserDissatisfactionDetector
+ # touching the same round must not overwrite it.
+ out = _classify_one(
+ _u('解一元二次方程 x^2 - 5x + 6 = 0'),
+ _a('factoring: (x-2)(x-3)'),
+ _u('不对,再来一次'),
+ _a('好的'),
+ )
+ intents = _ud_get(out, 'intents')
+ assert intents['1'] == INTENT_MATH
+ assert intents['3'] == INTENT_USER_DISSATISFACTION
+
+ def test_tool_call_definitive_short_circuits(self):
+ out = _classify_one(
+ _u('解一元二次方程 x^2 - 5x + 6 = 0'),
+ _a('', tool_calls=[{
+ 'name': 'calc'
+ }]),
+ )
+ assert out['intent'] == INTENT_TOOL_CALL
+ # MathDetector must not have run after the definitive ToolCallDetector.
+ assert set(_ud_get(out, 'intents').values()) == {INTENT_TOOL_CALL}
+
+ def test_multimodal_list_content(self):
+ # List-content messages must work transparently.
+ msgs = [
+ _u([{
+ 'type': 'text',
+ 'text': '求一元二次方程'
+ }, {
+ 'type': 'image',
+ 'url': 'x'
+ }]),
+ _a([{
+ 'type': 'text',
+ 'text': '因式分解后得到结果'
+ }]),
+ ]
+ out = _classify_one(*msgs)
+ assert out['intent'] == INTENT_MATH
+
+
+# ── Edge / robustness ─────────────────────────────────────────────────────────
+
+
+class TestEdgeCases:
+
+ def test_empty_rows(self):
+ kept, _ = IntentClassifier()([])
+ assert kept == []
+
+ def test_missing_messages_field(self):
+ kept, _ = IntentClassifier(drop_no_key_rounds=False)([{'foo': 'bar'}])
+ assert kept[0]['intent'] == INTENT_OTHER
+
+ def test_messages_is_none(self):
+ kept, _ = IntentClassifier(drop_no_key_rounds=False)([{'messages': None}])
+ assert kept[0]['intent'] == INTENT_OTHER
+
+ def test_messages_empty_list(self):
+ kept, _ = IntentClassifier(drop_no_key_rounds=False)([{'messages': []}])
+ assert kept[0]['intent'] == INTENT_OTHER
+
+ def test_messages_with_non_dict_entries(self):
+ # Non-dict entries must be silently skipped.
+ kept, _ = IntentClassifier()([{
+ 'messages': [
+ 'not a dict',
+ None,
+ _u('求一元二次方程'),
+ _a('因式分解'),
+ ]
+ }])
+ assert kept[0]['intent'] == INTENT_MATH
+
+ def test_user_data_preexists_preserved(self):
+ # IntentClassifier must merge into existing packed user_data without clobbering.
+ rows = [{
+ 'messages': [_u('解一元二次方程 x^2'), _a('因式分解 (x-2)(x-3)')],
+ 'user_data': [('source', pack_value('gsm8k')), ('difficulty', pack_value('easy'))],
+ }]
+ kept, _ = IntentClassifier()(rows)
+ row = kept[0]
+ assert _ud_get(row, 'source') == 'gsm8k'
+ assert _ud_get(row, 'difficulty') == 'easy'
+ assert _ud_get(row, 'key_rounds') == [1]
+ assert _ud_get(row, 'intents') == {'1': INTENT_MATH}
+
+ def test_input_row_not_mutated(self):
+ # IntentClassifier must shallow-copy rows; original dict must remain untouched.
+ original = {'messages': [_u('你好'), _a('hi')]}
+ IntentClassifier(drop_no_key_rounds=False)([original])
+ assert 'intent' not in original
+ assert 'user_data' not in original
+
+ def test_other_intent_does_not_emit_user_data(self):
+ out = _classify_one(_u('你好'), _a('hi'))
+ # No detectors fired → no key_rounds / intents written.
+ assert not _has_key_rounds(out)
+
+
+# ── Pluggability ──────────────────────────────────────────────────────────────
+
+
+class TestPluggability:
+
+ def test_custom_detector_via_constructor(self):
+
+ class GreetingDetector(IntentDetector):
+ intent = 'greeting'
+
+ def __call__(self, messages):
+ return [
+ i for i, m in enumerate(messages) if isinstance(m, dict) and m.get('role') == 'assistant'
+ and isinstance(m.get('content'), str) and 'hello' in m['content'].lower()
+ ]
+
+ ic = IntentClassifier(detectors=[GreetingDetector()])
+ kept, _ = ic([_row(_u('hi'), _a('Hello there'))])
+ assert kept[0]['intent'] == 'greeting'
+
+ def test_empty_detector_list_yields_other(self):
+ ic = IntentClassifier(detectors=[], drop_no_key_rounds=False)
+ kept, _ = ic([_row(_u('q'), _a('因式分解 一元二次方程'))])
+ assert kept[0]['intent'] == INTENT_OTHER
+
+ def test_intent_field_override(self):
+ ic = IntentClassifier(intent_field='label', drop_no_key_rounds=False)
+ kept, _ = ic([_row(_u('q'), _a('a'))])
+ assert 'label' in kept[0]
+ assert 'intent' not in kept[0]
+
+ def test_definitive_short_circuits_custom_pipeline(self):
+ # User-defined definitive detector must halt the pipeline after firing.
+ seen = []
+
+ class StopAll(IntentDetector):
+ intent = 'stop'
+ definitive = True
+
+ def __call__(self, messages):
+ seen.append('stop')
+ return [len(messages) - 1]
+
+ class NeverRuns(IntentDetector):
+ intent = 'never'
+
+ def __call__(self, messages):
+ seen.append('never')
+ return [0]
+
+ ic = IntentClassifier(detectors=[StopAll(), NeverRuns()])
+ ic([_row(_u('q'), _a('a'))])
+ assert seen == ['stop']
diff --git a/tests/preprocessor/test_message_sanity.py b/tests/preprocessor/test_message_sanity.py
new file mode 100644
index 000000000..d51eebe67
--- /dev/null
+++ b/tests/preprocessor/test_message_sanity.py
@@ -0,0 +1,844 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+"""Tests for MessageSanityFilter preprocessor."""
+import pytest
+
+from twinkle_agentic.preprocessor.message_sanity import (MessageSanityFilter, _trim_to_last_assistant,
+ _validate_content_integrity, _validate_role_order,
+ _validate_tool_call_matching)
+
+# ── Fixtures ──────────────────────────────────────────────────────────────────
+
+
+def _make_rows(messages_list):
+ """Wrap messages lists into row-format for the filter."""
+ return [{'messages': m} for m in messages_list]
+
+
+def _run_filter(messages_list, **kwargs):
+ """Run MessageSanityFilter on a list of message sequences, return surviving messages."""
+ f = MessageSanityFilter(**kwargs)
+ rows = _make_rows(messages_list)
+ result = f.message_sanity_filter(rows)
+ return [r['messages'] for r in result]
+
+
+# ── Role order tests ──────────────────────────────────────────────────────────
+
+
+class TestRoleOrder:
+
+ def test_valid_simple(self):
+ msgs = [
+ {
+ 'role': 'user',
+ 'content': 'hi'
+ },
+ {
+ 'role': 'assistant',
+ 'content': 'hello'
+ },
+ ]
+ assert _validate_role_order(msgs) is True
+
+ def test_valid_with_system(self):
+ msgs = [
+ {
+ 'role': 'system',
+ 'content': 'You are helpful.'
+ },
+ {
+ 'role': 'user',
+ 'content': 'hi'
+ },
+ {
+ 'role': 'assistant',
+ 'content': 'hello'
+ },
+ ]
+ assert _validate_role_order(msgs) is True
+
+ def test_system_not_first(self):
+ msgs = [
+ {
+ 'role': 'user',
+ 'content': 'hi'
+ },
+ {
+ 'role': 'system',
+ 'content': 'late system'
+ },
+ {
+ 'role': 'assistant',
+ 'content': 'hello'
+ },
+ ]
+ assert _validate_role_order(msgs) is False
+
+ def test_tool_without_tool_calls(self):
+ msgs = [
+ {
+ 'role': 'user',
+ 'content': 'hi'
+ },
+ {
+ 'role': 'assistant',
+ 'content': 'let me check'
+ },
+ {
+ 'role': 'tool',
+ 'content': 'result',
+ 'tool_call_id': 'x'
+ },
+ ]
+ assert _validate_role_order(msgs) is False
+
+ def test_tool_after_assistant_with_tool_calls(self):
+ msgs = [
+ {
+ 'role': 'user',
+ 'content': 'search'
+ },
+ {
+ 'role': 'assistant',
+ 'content': '',
+ 'tool_calls': [{
+ 'id': 'c1',
+ 'type': 'function',
+ 'function': {
+ 'name': 'search',
+ 'arguments': '{}'
+ }
+ }]
+ },
+ {
+ 'role': 'tool',
+ 'content': 'found it',
+ 'tool_call_id': 'c1'
+ },
+ ]
+ assert _validate_role_order(msgs) is True
+
+ def test_tool_after_user(self):
+ msgs = [
+ {
+ 'role': 'user',
+ 'content': 'hi'
+ },
+ {
+ 'role': 'tool',
+ 'content': 'bad',
+ 'tool_call_id': 'x'
+ },
+ ]
+ assert _validate_role_order(msgs) is False
+
+ def test_invalid_role(self):
+ msgs = [
+ {
+ 'role': 'user',
+ 'content': 'hi'
+ },
+ {
+ 'role': 'bot',
+ 'content': 'hello'
+ },
+ ]
+ assert _validate_role_order(msgs) is False
+
+ def test_empty(self):
+ assert _validate_role_order([]) is False
+
+ def test_consecutive_tools(self):
+ msgs = [
+ {
+ 'role': 'user',
+ 'content': 'do things'
+ },
+ {
+ 'role':
+ 'assistant',
+ 'content':
+ '',
+ 'tool_calls': [
+ {
+ 'id': 'c1',
+ 'type': 'function',
+ 'function': {
+ 'name': 'a',
+ 'arguments': '{}'
+ }
+ },
+ {
+ 'id': 'c2',
+ 'type': 'function',
+ 'function': {
+ 'name': 'b',
+ 'arguments': '{}'
+ }
+ },
+ ]
+ },
+ {
+ 'role': 'tool',
+ 'content': 'res1',
+ 'tool_call_id': 'c1'
+ },
+ {
+ 'role': 'tool',
+ 'content': 'res2',
+ 'tool_call_id': 'c2'
+ },
+ ]
+ assert _validate_role_order(msgs) is True
+
+
+# ── Tool call matching tests ──────────────────────────────────────────────────
+
+
+class TestToolCallMatching:
+
+ def test_valid_matching(self):
+ msgs = [
+ {
+ 'role': 'user',
+ 'content': 'go'
+ },
+ {
+ 'role': 'assistant',
+ 'content': '',
+ 'tool_calls': [
+ {
+ 'id': 'c1',
+ 'type': 'function',
+ 'function': {
+ 'name': 'fn',
+ 'arguments': '{}'
+ }
+ },
+ ]
+ },
+ {
+ 'role': 'tool',
+ 'content': 'ok',
+ 'tool_call_id': 'c1'
+ },
+ {
+ 'role': 'assistant',
+ 'content': 'done'
+ },
+ ]
+ assert _validate_tool_call_matching(msgs) is True
+
+ def test_orphan_tool_calls(self):
+ """Assistant has tool_calls but no tool response follows."""
+ msgs = [
+ {
+ 'role': 'user',
+ 'content': 'go'
+ },
+ {
+ 'role': 'assistant',
+ 'content': '',
+ 'tool_calls': [
+ {
+ 'id': 'c1',
+ 'type': 'function',
+ 'function': {
+ 'name': 'fn',
+ 'arguments': '{}'
+ }
+ },
+ ]
+ },
+ {
+ 'role': 'user',
+ 'content': 'what happened?'
+ },
+ ]
+ assert _validate_tool_call_matching(msgs) is False
+
+ def test_phantom_tool_response(self):
+ """Tool response references an ID not in the assistant's tool_calls."""
+ msgs = [
+ {
+ 'role': 'user',
+ 'content': 'go'
+ },
+ {
+ 'role': 'assistant',
+ 'content': '',
+ 'tool_calls': [
+ {
+ 'id': 'c1',
+ 'type': 'function',
+ 'function': {
+ 'name': 'fn',
+ 'arguments': '{}'
+ }
+ },
+ ]
+ },
+ {
+ 'role': 'tool',
+ 'content': 'ok',
+ 'tool_call_id': 'WRONG_ID'
+ },
+ ]
+ assert _validate_tool_call_matching(msgs) is False
+
+ def test_partial_response_ok(self):
+ """Only some tool_calls get responses — currently allowed."""
+ msgs = [
+ {
+ 'role': 'user',
+ 'content': 'go'
+ },
+ {
+ 'role':
+ 'assistant',
+ 'content':
+ '',
+ 'tool_calls': [
+ {
+ 'id': 'c1',
+ 'type': 'function',
+ 'function': {
+ 'name': 'a',
+ 'arguments': '{}'
+ }
+ },
+ {
+ 'id': 'c2',
+ 'type': 'function',
+ 'function': {
+ 'name': 'b',
+ 'arguments': '{}'
+ }
+ },
+ ]
+ },
+ {
+ 'role': 'tool',
+ 'content': 'res1',
+ 'tool_call_id': 'c1'
+ },
+ ]
+ assert _validate_tool_call_matching(msgs) is True
+
+ def test_no_tool_calls_passes(self):
+ """Conversations without tool_calls pass trivially."""
+ msgs = [
+ {
+ 'role': 'user',
+ 'content': 'hi'
+ },
+ {
+ 'role': 'assistant',
+ 'content': 'hello'
+ },
+ ]
+ assert _validate_tool_call_matching(msgs) is True
+
+
+# ── Content integrity tests ───────────────────────────────────────────────────
+
+
+class TestContentIntegrity:
+
+ def test_valid_basic(self):
+ msgs = [
+ {
+ 'role': 'user',
+ 'content': 'hi'
+ },
+ {
+ 'role': 'assistant',
+ 'content': 'hello there'
+ },
+ ]
+ assert _validate_content_integrity(msgs) is True
+
+ def test_empty_assistant(self):
+ msgs = [
+ {
+ 'role': 'user',
+ 'content': 'hi'
+ },
+ {
+ 'role': 'assistant',
+ 'content': ''
+ },
+ ]
+ assert _validate_content_integrity(msgs) is False
+
+ def test_assistant_with_tool_calls_no_content_ok(self):
+ msgs = [
+ {
+ 'role': 'user',
+ 'content': 'search'
+ },
+ {
+ 'role':
+ 'assistant',
+ 'content':
+ '',
+ 'tool_calls': [{
+ 'id': 'c1',
+ 'type': 'function',
+ 'function': {
+ 'name': 'search_web',
+ 'arguments': '{"q":"test"}'
+ }
+ }]
+ },
+ ]
+ assert _validate_content_integrity(msgs) is True
+
+ def test_empty_system(self):
+ msgs = [
+ {
+ 'role': 'system',
+ 'content': ''
+ },
+ {
+ 'role': 'user',
+ 'content': 'hi'
+ },
+ {
+ 'role': 'assistant',
+ 'content': 'hello'
+ },
+ ]
+ assert _validate_content_integrity(msgs) is False
+
+ def test_too_long_message(self):
+ msgs = [
+ {
+ 'role': 'user',
+ 'content': 'hi'
+ },
+ {
+ 'role': 'assistant',
+ 'content': 'x' * 60000
+ },
+ ]
+ assert _validate_content_integrity(msgs, max_msg_chars=50000) is False
+
+ def test_invalid_tool_call_structure(self):
+ msgs = [
+ {
+ 'role': 'user',
+ 'content': 'go'
+ },
+ {
+ 'role': 'assistant',
+ 'content': '',
+ 'tool_calls': [
+ {
+ 'id': 'c1',
+ 'function': 'not_a_dict'
+ }, # function must be dict
+ ]
+ },
+ ]
+ assert _validate_content_integrity(msgs) is False
+
+ def test_invalid_function_name(self):
+ msgs = [
+ {
+ 'role': 'user',
+ 'content': 'go'
+ },
+ {
+ 'role': 'assistant',
+ 'content': '',
+ 'tool_calls': [
+ {
+ 'id': 'c1',
+ 'type': 'function',
+ 'function': {
+ 'name': '123bad',
+ 'arguments': '{}'
+ }
+ },
+ ]
+ },
+ ]
+ assert _validate_content_integrity(msgs) is False
+
+ def test_invalid_arguments_json(self):
+ msgs = [
+ {
+ 'role': 'user',
+ 'content': 'go'
+ },
+ {
+ 'role': 'assistant',
+ 'content': '',
+ 'tool_calls': [
+ {
+ 'id': 'c1',
+ 'type': 'function',
+ 'function': {
+ 'name': 'fn',
+ 'arguments': '{invalid json'
+ }
+ },
+ ]
+ },
+ ]
+ assert _validate_content_integrity(msgs) is False
+
+ def test_dict_arguments_ok(self):
+ msgs = [
+ {
+ 'role': 'user',
+ 'content': 'go'
+ },
+ {
+ 'role': 'assistant',
+ 'content': '',
+ 'tool_calls': [
+ {
+ 'id': 'c1',
+ 'type': 'function',
+ 'function': {
+ 'name': 'fn',
+ 'arguments': {
+ 'key': 'val'
+ }
+ }
+ },
+ ]
+ },
+ ]
+ assert _validate_content_integrity(msgs) is True
+
+ def test_duplicate_user_messages(self):
+ msgs = [
+ {
+ 'role': 'user',
+ 'content': 'hi'
+ },
+ {
+ 'role': 'user',
+ 'content': 'hi'
+ },
+ {
+ 'role': 'assistant',
+ 'content': 'hello'
+ },
+ ]
+ assert _validate_content_integrity(msgs) is False
+
+ def test_duplicate_tool_messages_allowed(self):
+ """Two consecutive tool messages with same content should NOT be rejected."""
+ msgs = [
+ {
+ 'role': 'user',
+ 'content': 'search both'
+ },
+ {
+ 'role':
+ 'assistant',
+ 'content':
+ '',
+ 'tool_calls': [
+ {
+ 'id': 'c1',
+ 'type': 'function',
+ 'function': {
+ 'name': 'search',
+ 'arguments': '{"q":"x"}'
+ }
+ },
+ {
+ 'id': 'c2',
+ 'type': 'function',
+ 'function': {
+ 'name': 'search',
+ 'arguments': '{"q":"x"}'
+ }
+ },
+ ]
+ },
+ {
+ 'role': 'tool',
+ 'content': 'same result',
+ 'tool_call_id': 'c1'
+ },
+ {
+ 'role': 'tool',
+ 'content': 'same result',
+ 'tool_call_id': 'c2'
+ },
+ {
+ 'role': 'assistant',
+ 'content': 'both returned same'
+ },
+ ]
+ assert _validate_content_integrity(msgs) is True
+
+ def test_min_turns(self):
+ msgs = [
+ {
+ 'role': 'user',
+ 'content': 'hi'
+ },
+ {
+ 'role': 'assistant',
+ 'content': 'hello'
+ },
+ ]
+ # min_turns=2 → user(1)+assistant(1)=2 >= 2 → pass
+ assert _validate_content_integrity(msgs, min_turns=2) is True
+ # min_turns=3 → total=2 < 3 → fail
+ assert _validate_content_integrity(msgs, min_turns=3) is False
+
+
+# ── Trim tests ────────────────────────────────────────────────────────────────
+
+
+class TestTrimToLastAssistant:
+
+ def test_already_ends_with_assistant(self):
+ msgs = [
+ {
+ 'role': 'user',
+ 'content': 'hi'
+ },
+ {
+ 'role': 'assistant',
+ 'content': 'hello'
+ },
+ ]
+ assert _trim_to_last_assistant(msgs) == msgs
+
+ def test_trim_trailing_user(self):
+ msgs = [
+ {
+ 'role': 'user',
+ 'content': 'hi'
+ },
+ {
+ 'role': 'assistant',
+ 'content': 'hello'
+ },
+ {
+ 'role': 'user',
+ 'content': 'bye'
+ },
+ ]
+ assert _trim_to_last_assistant(msgs) == msgs[:2]
+
+ def test_no_assistant(self):
+ msgs = [
+ {
+ 'role': 'user',
+ 'content': 'hi'
+ },
+ {
+ 'role': 'user',
+ 'content': 'hello?'
+ },
+ ]
+ assert _trim_to_last_assistant(msgs) == []
+
+
+# ── Sensitive word tests ──────────────────────────────────────────────────────
+
+
+class TestSensitiveWords:
+
+ def test_english_word_boundary(self):
+ msgs_clean = [
+ {
+ 'role': 'user',
+ 'content': 'hello world'
+ },
+ {
+ 'role': 'assistant',
+ 'content': 'hi there'
+ },
+ ]
+ msgs_bad = [
+ {
+ 'role': 'user',
+ 'content': 'hello world'
+ },
+ {
+ 'role': 'assistant',
+ 'content': 'what the fuck'
+ },
+ ]
+ result = _run_filter(
+ [msgs_clean, msgs_bad],
+ extra_sensitive_words=['fuck'],
+ )
+ assert len(result) == 1
+ assert result[0] == msgs_clean
+
+ def test_chinese_sensitive(self):
+ msgs_bad = [
+ {
+ 'role': 'user',
+ 'content': '你好'
+ },
+ {
+ 'role': 'assistant',
+ 'content': '操你妈'
+ },
+ ]
+ result = _run_filter(
+ [msgs_bad],
+ extra_sensitive_words=['操你妈'],
+ )
+ assert len(result) == 0
+
+ def test_no_sensitive_config_passes_all(self):
+ msgs = [
+ {
+ 'role': 'user',
+ 'content': 'fuck'
+ },
+ {
+ 'role': 'assistant',
+ 'content': 'hello'
+ },
+ ]
+ # No sensitive words configured → everything passes
+ result = _run_filter([msgs])
+ assert len(result) == 1
+
+
+# ── End-to-end filter tests ───────────────────────────────────────────────────
+
+
+class TestEndToEnd:
+
+ def test_full_valid_agentic_trajectory(self):
+ msgs = [
+ {
+ 'role': 'system',
+ 'content': 'You are a helpful assistant.'
+ },
+ {
+ 'role': 'user',
+ 'content': 'What is the weather?'
+ },
+ {
+ 'role':
+ 'assistant',
+ 'content':
+ '',
+ 'tool_calls': [
+ {
+ 'id': 'call_1',
+ 'type': 'function',
+ 'function': {
+ 'name': 'get_weather',
+ 'arguments': '{"city":"Beijing"}'
+ }
+ },
+ ]
+ },
+ {
+ 'role': 'tool',
+ 'content': '{"temp": 22, "condition": "sunny"}',
+ 'tool_call_id': 'call_1'
+ },
+ {
+ 'role': 'assistant',
+ 'content': 'It is 22°C and sunny in Beijing.'
+ },
+ ]
+ result = _run_filter([msgs])
+ assert len(result) == 1
+
+ def test_trim_and_validate(self):
+ """Trailing user message gets trimmed, result still valid."""
+ msgs = [
+ {
+ 'role': 'user',
+ 'content': 'hi'
+ },
+ {
+ 'role': 'assistant',
+ 'content': 'hello'
+ },
+ {
+ 'role': 'user',
+ 'content': 'thanks'
+ },
+ ]
+ result = _run_filter([msgs])
+ assert len(result) == 1
+ assert result[0][-1]['role'] == 'assistant'
+
+ def test_no_assistant_discarded(self):
+ msgs = [
+ {
+ 'role': 'user',
+ 'content': 'hi'
+ },
+ {
+ 'role': 'user',
+ 'content': 'hello?'
+ },
+ ]
+ result = _run_filter([msgs])
+ assert len(result) == 0
+
+ def test_multiple_tool_rounds(self):
+ msgs = [
+ {
+ 'role': 'user',
+ 'content': 'plan a trip'
+ },
+ {
+ 'role':
+ 'assistant',
+ 'content':
+ '',
+ 'tool_calls': [
+ {
+ 'id': 'c1',
+ 'type': 'function',
+ 'function': {
+ 'name': 'search_flights',
+ 'arguments': '{}'
+ }
+ },
+ ]
+ },
+ {
+ 'role': 'tool',
+ 'content': 'flight options...',
+ 'tool_call_id': 'c1'
+ },
+ {
+ 'role': 'assistant',
+ 'content': 'Found flights. Let me check hotels.',
+ 'tool_calls': [
+ {
+ 'id': 'c2',
+ 'type': 'function',
+ 'function': {
+ 'name': 'search_hotels',
+ 'arguments': '{}'
+ }
+ },
+ ]
+ },
+ {
+ 'role': 'tool',
+ 'content': 'hotel options...',
+ 'tool_call_id': 'c2'
+ },
+ {
+ 'role': 'assistant',
+ 'content': 'Here is your complete trip plan.'
+ },
+ ]
+ result = _run_filter([msgs])
+ assert len(result) == 1
diff --git a/tests/preprocessor/test_pii_presidio_filter.py b/tests/preprocessor/test_pii_presidio_filter.py
new file mode 100644
index 000000000..37a64232b
--- /dev/null
+++ b/tests/preprocessor/test_pii_presidio_filter.py
@@ -0,0 +1,224 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+"""Tests for pure helpers in pii_presidio_filter.
+
+Only validators and replacement primitives are tested here — the full
+``PIIPresidioFilter`` requires presidio_analyzer + spacy + faker which are
+heavy/optional deps. Pure helpers are usable standalone and have clear
+mathematical contracts.
+
+Coverage:
+ * ``_is_valid_cn_id`` — 18-digit checksum (last digit may be 'X')
+ * ``_is_valid_luhn`` — Luhn algorithm with min length 13
+ * ``_mask_keep_edges`` — keep head/tail, mask middle
+ * ``_hash_short`` — SHA-256 prefix, deterministic w/ salt
+ * ``Strategy.coerce`` — enum coercion + strict failure mode
+"""
+import hashlib
+import pytest
+
+from twinkle_agentic.preprocessor.pii_presidio_filter import (Strategy, _hash_short, _is_valid_cn_id, _is_valid_luhn,
+ _mask_keep_edges)
+
+# ── _is_valid_cn_id ─────────────────────────────────────────────────────────
+
+
+class TestIsValidCnId:
+ """
+ Verified against the official GB 11643-1999 weights:
+ weights = (7,9,10,5,8,4,2,1,6,3,7,9,10,5,8,4,2)
+ checks = '10X98765432'
+ Test ID `11010519491231002X` is a textbook valid example.
+ """
+
+ def test_valid_id_with_x_check(self):
+ assert _is_valid_cn_id('11010519491231002X') is True
+
+ def test_valid_id_with_x_lowercase(self):
+ # Implementation upper-cases the check digit before compare.
+ assert _is_valid_cn_id('11010519491231002x') is True
+
+ def test_invalid_check_digit(self):
+ # Flip the last char to a wrong number.
+ assert _is_valid_cn_id('110105194912310020') is False
+
+ def test_too_short(self):
+ assert _is_valid_cn_id('110105194912310') is False
+
+ def test_too_long(self):
+ assert _is_valid_cn_id('11010519491231002X9') is False
+
+ def test_non_digit_in_first_17(self):
+ assert _is_valid_cn_id('1101051949123100AX') is False
+
+ def test_empty(self):
+ assert _is_valid_cn_id('') is False
+
+ def test_18_digits_invalid_checksum(self):
+ # 18 digits but last is wrong number
+ assert _is_valid_cn_id('110105194912310029') is False
+
+
+# ── _is_valid_luhn ──────────────────────────────────────────────────────────
+
+
+class TestIsValidLuhn:
+ """
+ `4532015112830366` is a well-known Visa test number that satisfies Luhn.
+ """
+
+ def test_valid_visa_test_number(self):
+ assert _is_valid_luhn('4532015112830366') is True
+
+ def test_valid_with_separators(self):
+ # Implementation strips non-digits via `c.isdigit()`.
+ assert _is_valid_luhn('4532-0151-1283-0366') is True
+ assert _is_valid_luhn('4532 0151 1283 0366') is True
+
+ def test_invalid_checksum(self):
+ # Flip the last digit.
+ assert _is_valid_luhn('4532015112830367') is False
+
+ def test_too_short(self):
+ # Only 12 digits — below 13-digit minimum.
+ assert _is_valid_luhn('453201511283') is False
+
+ def test_empty(self):
+ assert _is_valid_luhn('') is False
+
+ def test_no_digits(self):
+ assert _is_valid_luhn('abcd-efgh-ijkl-mnop') is False
+
+ def test_amex_test_number(self):
+ # 15-digit Amex test card.
+ assert _is_valid_luhn('378282246310005') is True
+
+ def test_mastercard_test_number(self):
+ assert _is_valid_luhn('5555555555554444') is True
+
+
+# ── _mask_keep_edges ────────────────────────────────────────────────────────
+
+
+class TestMaskKeepEdges:
+
+ def test_default_head_tail(self):
+ # head=3, tail=4 → keep 3 + mask middle + keep 4
+ s = '13800138000' # 11 chars
+ # 11 > 3+4 = 7 → masked = 11 - 7 = 4 stars
+ out = _mask_keep_edges(s)
+ assert out == '138' + '*' * 4 + '8000'
+
+ def test_short_string_all_masked(self):
+ # len ≤ head+tail → entire string masked.
+ s = 'short' # 5 chars; head+tail = 7
+ assert _mask_keep_edges(s) == '*****'
+
+ def test_at_threshold_all_masked(self):
+ # len == head+tail → all masked (boundary is `<=`)
+ s = '1234567' # 7 chars
+ assert _mask_keep_edges(s) == '*' * 7
+
+ def test_custom_head_tail(self):
+ s = 'abcdefghij' # 10 chars
+ # head=2, tail=2 → keep ab + 6 stars + ij
+ assert _mask_keep_edges(s, head=2, tail=2) == 'ab' + '*' * 6 + 'ij'
+
+ def test_custom_mask_char(self):
+ s = '1234567890'
+ out = _mask_keep_edges(s, head=1, tail=1, ch='X')
+ assert out == '1' + 'X' * 8 + '0'
+
+ def test_empty_string(self):
+ # len=0 ≤ head+tail → '' * 0 = ''
+ assert _mask_keep_edges('') == ''
+
+ def test_credit_card_default(self):
+ s = '4532015112830366' # 16 chars
+ out = _mask_keep_edges(s)
+ # head=3, tail=4 → keep 453 + 9 stars + 0366
+ assert out == '453' + '*' * 9 + '0366'
+
+
+# ── _hash_short ─────────────────────────────────────────────────────────────
+
+
+class TestHashShort:
+
+ def test_length_is_12(self):
+ assert len(_hash_short('alice@example.com')) == 12
+
+ def test_deterministic_same_input(self):
+ a = _hash_short('hello')
+ b = _hash_short('hello')
+ assert a == b
+
+ def test_different_inputs_different_outputs(self):
+ a = _hash_short('alice@example.com')
+ b = _hash_short('bob@example.com')
+ assert a != b
+
+ def test_salt_changes_output(self):
+ a = _hash_short('hello', salt='')
+ b = _hash_short('hello', salt='secret')
+ assert a != b
+
+ def test_matches_sha256_prefix(self):
+ expected = hashlib.sha256(b'hello').hexdigest()[:12]
+ assert _hash_short('hello') == expected
+
+ def test_matches_sha256_with_salt(self):
+ expected = hashlib.sha256(b'saltyhello').hexdigest()[:12]
+ assert _hash_short('hello', salt='salty') == expected
+
+ def test_empty_string(self):
+ # Hash is well-defined for empty input too.
+ expected = hashlib.sha256(b'').hexdigest()[:12]
+ assert _hash_short('') == expected
+
+ def test_unicode_input(self):
+ # UTF-8 encoding before hashing.
+ expected = hashlib.sha256('张三'.encode()).hexdigest()[:12]
+ assert _hash_short('张三') == expected
+
+
+# ── Strategy.coerce ─────────────────────────────────────────────────────────
+
+
+class TestStrategyCoerce:
+
+ def test_coerce_string_to_enum(self):
+ assert Strategy.coerce('mask') is Strategy.MASK
+ assert Strategy.coerce('replace') is Strategy.REPLACE
+ assert Strategy.coerce('redact') is Strategy.REDACT
+ assert Strategy.coerce('hash') is Strategy.HASH
+
+ def test_coerce_enum_returns_self(self):
+ assert Strategy.coerce(Strategy.MASK) is Strategy.MASK
+
+ def test_coerce_unknown_raises(self):
+ with pytest.raises(ValueError) as exc:
+ Strategy.coerce('encrypt')
+ # Error message lists allowed strategies for diagnosability.
+ msg = str(exc.value)
+ assert 'mask' in msg
+ assert 'replace' in msg
+ assert 'redact' in msg
+ assert 'hash' in msg
+
+ def test_coerce_empty_string_raises(self):
+ with pytest.raises(ValueError):
+ Strategy.coerce('')
+
+ def test_string_enum_membership(self):
+ # Strategy is a str-Enum: values should compare equal to their str form.
+ assert Strategy.MASK == 'mask'
+ assert Strategy.REPLACE.value == 'replace'
+
+ def test_coerce_case_sensitive(self):
+ # Implementation does not lowercase before lookup.
+ with pytest.raises(ValueError):
+ Strategy.coerce('MASK')
+
+
+if __name__ == '__main__':
+ pytest.main([__file__, '-v'])
diff --git a/tests/preprocessor/test_preprocessor_utils.py b/tests/preprocessor/test_preprocessor_utils.py
new file mode 100644
index 000000000..d52f8a77f
--- /dev/null
+++ b/tests/preprocessor/test_preprocessor_utils.py
@@ -0,0 +1,354 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+"""Tests for preprocessor.utils — pure logprob math helpers.
+
+These helpers compute conditional-vs-unconditional logprob deltas for
+IFD-family scoring (CherryLLM, T-SHIRT, ChR). All functions are stateless
+and accept simple list inputs.
+
+Conventions used in this test file:
+ * "lp" lists are aligned to the FULL sequence (prompt + answer).
+ * ``n_prompt`` is the number of prompt tokens; assistant tokens start at
+ index ``n_prompt`` in the cond list.
+ * Each lp entry is a dict {token_id: logprob_float}.
+"""
+import math
+import pytest
+
+from twinkle_agentic.preprocessor.utils import (_chr_min_distinct, _chr_min_weighted, _extract_logprob,
+ _ifd_family_metrics, _lp_to_jsonable, _mean_logprob_delta, _pad_batch,
+ _to_int_list)
+
+# ── _extract_logprob ────────────────────────────────────────────────────────
+
+
+class TestExtractLogprob:
+
+ def test_none(self):
+ assert _extract_logprob(None) is None
+
+ def test_scalar_int(self):
+ assert _extract_logprob(5) == 5.0
+
+ def test_scalar_float(self):
+ assert _extract_logprob(-1.2) == -1.2
+
+ def test_dict_with_int_token_id(self):
+ lp = {7: -0.5, 8: -2.0}
+ assert _extract_logprob(lp, token_id=7) == -0.5
+ assert _extract_logprob(lp, token_id=8) == -2.0
+
+ def test_dict_with_str_token_id_fallback(self):
+ # vLLM may emit string keys; lookup must fall back to str(token_id).
+ lp = {'7': -0.5}
+ assert _extract_logprob(lp, token_id=7) == -0.5
+
+ def test_dict_no_token_id_picks_first(self):
+ # No token_id → iter-first behaviour.
+ lp = {7: -0.5}
+ assert _extract_logprob(lp) == -0.5
+
+ def test_dict_token_id_missing_uses_first(self):
+ # token_id not in dict → fall back to first entry.
+ lp = {99: -3.0}
+ assert _extract_logprob(lp, token_id=7) == -3.0
+
+ def test_dict_with_logprob_attr_object(self):
+
+ class Entry:
+
+ def __init__(self, v):
+ self.logprob = v
+
+ lp = {7: Entry(-0.7)}
+ assert _extract_logprob(lp, token_id=7) == -0.7
+
+ def test_dict_with_nested_dict(self):
+ lp = {7: {'logprob': -0.9, 'rank': 1}}
+ assert _extract_logprob(lp, token_id=7) == -0.9
+
+ def test_dict_with_nested_dict_none_logprob(self):
+ lp = {7: {'logprob': None}}
+ assert _extract_logprob(lp, token_id=7) is None
+
+ def test_unrecognized_type(self):
+ # str entries → returns None
+ lp = {7: 'oops'}
+ assert _extract_logprob(lp, token_id=7) is None
+
+ def test_non_dict_non_scalar(self):
+ # A list is neither scalar nor dict → None.
+ assert _extract_logprob([1, 2, 3]) is None
+
+
+# ── _to_int_list ────────────────────────────────────────────────────────────
+
+
+class TestToIntList:
+
+ def test_plain_list(self):
+ assert _to_int_list([1, 2, 3]) == [1, 2, 3]
+
+ def test_tuple(self):
+ assert _to_int_list((1, 2, 3)) == [1, 2, 3]
+
+ def test_with_tolist(self):
+
+ class Tensor:
+
+ def tolist(self):
+ return [4, 5, 6]
+
+ assert _to_int_list(Tensor()) == [4, 5, 6]
+
+ def test_empty(self):
+ assert _to_int_list([]) == []
+
+
+# ── _chr_min_distinct ───────────────────────────────────────────────────────
+
+
+class TestChrMinDistinct:
+
+ def test_empty_inputs_returns_none(self):
+ assert _chr_min_distinct([], [{1: -1.0}], [], [1], 0) is None
+ assert _chr_min_distinct([{1: -1.0}], [], [1], [], 0) is None
+ assert _chr_min_distinct([{1: -1.0}], [{1: -1.0}], [1], [], 0) is None
+
+ def test_simple_all_positive(self):
+ # cond_lp[i] - asst_lp[i] > 0 for all i → ratio = 1.0
+ n_prompt = 1
+ # cond covers prompt(1) + asst(2) = 3 positions
+ cond_lp = [
+ {
+ 0: -10.0
+ }, # prompt position
+ {
+ 1: -0.1
+ }, # asst pos 0 — high cond logprob
+ {
+ 2: -0.2
+ }
+ ] # asst pos 1
+ asst_lp = [{1: -1.0}, {2: -1.5}]
+ cond_ids = [0, 1, 2]
+ asst_ids = [1, 2]
+ ratio = _chr_min_distinct(cond_lp, asst_lp, cond_ids, asst_ids, n_prompt)
+ assert ratio == 1.0
+
+ def test_all_negative(self):
+ # delta < 0 → ratio = 0
+ n_prompt = 1
+ cond_lp = [{0: 0.0}, {1: -3.0}, {2: -3.0}]
+ asst_lp = [{1: -0.5}, {2: -0.5}]
+ ratio = _chr_min_distinct(cond_lp, asst_lp, [0, 1, 2], [1, 2], n_prompt)
+ assert ratio == 0.0
+
+ def test_distinct_token_min_aggregation(self):
+ # Two occurrences of same token: one has +delta, one has -delta.
+ # min(deltas) is negative → token contributes 0 to ratio.
+ n_prompt = 1
+ cond_lp = [{0: 0.0}, {1: -0.1}, {1: -3.0}]
+ asst_lp = [{1: -1.0}, {1: -0.5}] # delta1=+0.9, delta2=-2.5
+ ratio = _chr_min_distinct(cond_lp, asst_lp, [0, 1, 1], [1, 1], n_prompt)
+ assert ratio == 0.0 # min < 0
+
+ def test_exclude_ids(self):
+ # Excluded token is dropped before counting.
+ n_prompt = 1
+ cond_lp = [{0: 0.0}, {1: -0.1}, {2: -0.1}]
+ asst_lp = [{1: -1.0}, {2: -1.0}]
+ # Without exclude: 2 distinct tokens, both positive → 1.0
+ ratio = _chr_min_distinct(cond_lp, asst_lp, [0, 1, 2], [1, 2], n_prompt, exclude_ids={1})
+ assert ratio == 1.0 # only token 2 counted, still positive
+
+ def test_truncation_when_cond_short(self):
+ # cond_lp shorter than n_prompt + n_asst → loop breaks early.
+ n_prompt = 2
+ cond_lp = [{0: 0.0}, {0: 0.0}, {1: -0.1}] # only 1 asst position
+ asst_lp = [{1: -1.0}, {2: -1.0}] # 2 asst positions requested
+ ratio = _chr_min_distinct(cond_lp, asst_lp, [0, 0, 1], [1, 2], n_prompt)
+ assert ratio == 1.0 # only the first delta processed
+
+
+# ── _chr_min_weighted ───────────────────────────────────────────────────────
+
+
+class TestChrMinWeighted:
+
+ def test_empty_returns_none(self):
+ assert _chr_min_weighted([], [{1: -1.0}], [], [1], 0) is None
+
+ def test_all_positive_returns_one(self):
+ n_prompt = 1
+ cond_lp = [{0: 0.0}, {1: -0.1}, {2: -0.2}]
+ asst_lp = [{1: -1.0}, {2: -1.5}]
+ ratio = _chr_min_weighted(cond_lp, asst_lp, [0, 1, 2], [1, 2], n_prompt)
+ assert ratio == 1.0 # all positive → pos_w == total_w
+
+ def test_zero_total_weight_returns_none(self):
+ # All deltas == 0 → total_w == 0 → None
+ n_prompt = 1
+ cond_lp = [{0: 0.0}, {1: -1.0}]
+ asst_lp = [{1: -1.0}]
+ assert _chr_min_weighted(cond_lp, asst_lp, [0, 1], [1], n_prompt) is None
+
+ def test_weighted_mixture(self):
+ # Token A: min_delta = +2.0 (weight 2)
+ # Token B: min_delta = -1.0 (weight 1)
+ # pos / total = 2 / 3
+ n_prompt = 1
+ cond_lp = [{0: 0.0}, {1: 1.0}, {2: -2.0}] # cond: A=1.0, B=-2.0
+ asst_lp = [{1: -1.0}, {2: -1.0}] # asst: A=-1.0, B=-1.0
+ # delta A = 1.0 - (-1.0) = 2.0
+ # delta B = -2.0 - (-1.0) = -1.0
+ ratio = _chr_min_weighted(cond_lp, asst_lp, [0, 1, 2], [1, 2], n_prompt)
+ assert abs(ratio - 2 / 3) < 1e-9
+
+
+# ── _ifd_family_metrics ─────────────────────────────────────────────────────
+
+
+class TestIfdFamilyMetrics:
+
+ def test_empty_returns_empty_dict(self):
+ assert _ifd_family_metrics([], [{1: -1.0}], [], [1], 0) == {}
+
+ def test_simple_uniform(self):
+ # All deltas = 0.5 → mean=0.5, ifd=exp(-0.5)
+ n_prompt = 1
+ cond_lp = [{0: 0.0}, {1: -0.5}, {2: -0.5}]
+ asst_lp = [{1: -1.0}, {2: -1.0}]
+ out = _ifd_family_metrics(cond_lp, asst_lp, [0, 1, 2], [1, 2], n_prompt)
+ assert out['n_tokens'] == 2
+ assert abs(out['mean_delta'] - 0.5) < 1e-9
+ assert abs(out['ifd'] - math.exp(-0.5)) < 1e-9
+ # s_ifd_50 keeps top-1 by |delta| = 0.5; s_ifd_75 keeps top-2 (rounded up).
+ assert abs(out['s_ifd_50'] - math.exp(-0.5)) < 1e-9
+ assert abs(out['s_ifd_75'] - math.exp(-0.5)) < 1e-9
+
+ def test_mixed_deltas(self):
+ # deltas = [+2.0, -1.0]; mean = 0.5
+ n_prompt = 1
+ cond_lp = [{0: 0.0}, {1: 1.0}, {2: -2.0}]
+ asst_lp = [{1: -1.0}, {2: -1.0}]
+ out = _ifd_family_metrics(cond_lp, asst_lp, [0, 1, 2], [1, 2], n_prompt)
+ assert out['n_tokens'] == 2
+ assert abs(out['mean_delta'] - 0.5) < 1e-9
+ # s_ifd_50 keeps top-1 by |delta| = 2.0 → exp(-2.0)
+ assert abs(out['s_ifd_50'] - math.exp(-2.0)) < 1e-9
+
+
+# ── _mean_logprob_delta ─────────────────────────────────────────────────────
+
+
+class TestMeanLogprobDelta:
+
+ def test_empty(self):
+ assert _mean_logprob_delta([], [{1: -1.0}], [], [1], 0) is None
+
+ def test_uniform_delta(self):
+ n_prompt = 1
+ cond_lp = [{0: 0.0}, {1: -0.5}, {2: -0.5}]
+ asst_lp = [{1: -1.0}, {2: -1.0}]
+ out = _mean_logprob_delta(cond_lp, asst_lp, [0, 1, 2], [1, 2], n_prompt)
+ assert abs(out - 0.5) < 1e-9
+
+ def test_mixed_average(self):
+ # deltas = [+2.0, -1.0] → mean 0.5
+ n_prompt = 1
+ cond_lp = [{0: 0.0}, {1: 1.0}, {2: -2.0}]
+ asst_lp = [{1: -1.0}, {2: -1.0}]
+ out = _mean_logprob_delta(cond_lp, asst_lp, [0, 1, 2], [1, 2], n_prompt)
+ assert abs(out - 0.5) < 1e-9
+
+ def test_skips_none_logprobs(self):
+ # When asst lp returns None, that position is skipped silently.
+ n_prompt = 1
+ cond_lp = [{0: 0.0}, {1: -0.5}, {2: -0.5}]
+ asst_lp = [None, {2: -1.0}]
+ out = _mean_logprob_delta(cond_lp, asst_lp, [0, 1, 2], [1, 2], n_prompt)
+ assert abs(out - 0.5) < 1e-9 # only position 1 used
+
+
+# ── _lp_to_jsonable ─────────────────────────────────────────────────────────
+
+
+class TestLpToJsonable:
+
+ def test_none_input(self):
+ assert _lp_to_jsonable(None) == []
+
+ def test_empty(self):
+ assert _lp_to_jsonable([]) == []
+
+ def test_none_passthrough(self):
+ assert _lp_to_jsonable([None, None]) == [None, None]
+
+ def test_scalar_to_float(self):
+ assert _lp_to_jsonable([1, -2.0]) == [1.0, -2.0]
+
+ def test_dict_with_logprob_object(self):
+
+ class Entry:
+
+ def __init__(self, lp, rank, decoded):
+ self.logprob = lp
+ self.rank = rank
+ self.decoded_token = decoded
+
+ out = _lp_to_jsonable([{7: Entry(-0.5, 1, 'hello')}])
+ assert out == [{'7': {'logprob': -0.5, 'rank': 1, 'decoded': 'hello'}}]
+
+ def test_dict_with_nested_dict(self):
+ out = _lp_to_jsonable([{7: {'logprob': -0.5}}])
+ assert out == [{'7': {'logprob': -0.5}}]
+
+ def test_dict_with_repr_fallback(self):
+ # Non-dict, non-Entry value falls back to repr string.
+ out = _lp_to_jsonable([{7: 'plain'}])
+ assert out == [{'7': repr('plain')}]
+
+ def test_non_dict_non_scalar_repr(self):
+ # An object that isn't dict/scalar gets repr-ed.
+ out = _lp_to_jsonable([(1, 2)])
+ assert out == [repr((1, 2))]
+
+
+# ── _pad_batch ──────────────────────────────────────────────────────────────
+
+
+class TestPadBatch:
+
+ def test_empty_batch(self):
+ padded, n = _pad_batch([], floor=4)
+ assert padded == []
+ assert n == 0
+
+ def test_already_at_floor(self):
+ batch = [[1], [2], [3], [4]]
+ padded, n = _pad_batch(batch, floor=4)
+ assert padded == batch
+ assert n == 4
+
+ def test_above_floor(self):
+ batch = [[1], [2], [3], [4], [5]]
+ padded, n = _pad_batch(batch, floor=3)
+ assert padded == batch # unchanged
+ assert n == 5
+
+ def test_below_floor_pads_with_last(self):
+ batch = [[1], [2]]
+ padded, n = _pad_batch(batch, floor=4)
+ assert padded == [[1], [2], [2], [2]]
+ assert n == 2 # original size
+
+ def test_returns_new_list(self):
+ batch = [[1], [2]]
+ padded, _ = _pad_batch(batch, floor=4)
+ # Mutating padded should not affect original.
+ padded.append([99])
+ assert batch == [[1], [2]]
+
+
+if __name__ == '__main__':
+ pytest.main([__file__, '-v'])
diff --git a/tests/preprocessor/test_refuse_filter.py b/tests/preprocessor/test_refuse_filter.py
new file mode 100644
index 000000000..5a84842d3
--- /dev/null
+++ b/tests/preprocessor/test_refuse_filter.py
@@ -0,0 +1,320 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+"""Tests for RefuseFilter.
+
+RefuseFilter drops rows whose first assistant reply contains a self-referential
+refusal signal across en/zh/ja/ko. Key invariants:
+
+ * Detection is anchored to ``I/we`` + inability + task verb (en) and 我 +
+ 无法/不能 + 帮/回答/... (zh). Generic uses of ``cannot`` / ``无法`` that are
+ NOT self-referential must NOT trip the filter.
+ * Only the FIRST assistant turn is scanned, within ``check_window`` chars.
+ * ``...`` blocks are stripped before scanning.
+ * Empty / missing assistant content is kept (filter is conservative).
+"""
+import pytest
+
+from twinkle_agentic.preprocessor.refuse_filter import RefuseFilter, _is_refusal
+
+
+def _row(messages):
+ return {'messages': messages}
+
+
+def _fil(rows, **kw):
+ return RefuseFilter(**kw)(rows)
+
+
+# ── _is_refusal: English ────────────────────────────────────────────────────
+
+
+class TestRefusalEnglish:
+
+ def test_core_self_referential(self):
+ assert _is_refusal("I can't help with that request.") is True
+ assert _is_refusal('I cannot assist with creating malware.') is True
+ assert _is_refusal('I am not able to provide that information.') is True
+ assert _is_refusal("We won't generate that content.") is True
+
+ def test_apology_opener(self):
+ assert _is_refusal("I'm sorry, but I can't help with that.") is True
+ assert _is_refusal('Unfortunately, I am unable to comply.') is True
+ assert _is_refusal('I apologize, but I cannot answer that.') is True
+
+ def test_policy_violation(self):
+ assert _is_refusal('This request violates my guidelines.') is True
+ assert _is_refusal('Your question goes against safety policies.') is True
+
+ def test_standalone_decline(self):
+ assert _is_refusal('I must decline this request.') is True
+ assert _is_refusal('I refuse to participate in that.') is True
+ assert _is_refusal("As an AI, I can't help with illegal activities.") is True
+
+ def test_false_positive_cannot_stress(self):
+ # "I cannot stress enough" — has I+cannot but NOT a task verb in window
+ text = 'I cannot stress enough how important documentation is.'
+ assert _is_refusal(text) is False
+
+ def test_false_positive_cannot_find(self):
+ # "I cannot find the bug" — find is not in task-verb whitelist
+ text = 'I cannot find the bug after looking everywhere.'
+ assert _is_refusal(text) is False
+
+ def test_false_positive_third_person(self):
+ # No I/we subject → not refusal
+ assert _is_refusal('The API cannot handle null inputs.') is False
+ assert _is_refusal('You cannot use this without auth.') is False
+
+ def test_case_insensitive(self):
+ assert _is_refusal('I CANNOT HELP with that.') is True
+ assert _is_refusal("i can't assist you here") is True
+
+
+# ── _is_refusal: Chinese ────────────────────────────────────────────────────
+
+
+class TestRefusalChinese:
+
+ def test_apology_inability(self):
+ assert _is_refusal('非常抱歉,我无法回答这个问题。') is True
+ assert _is_refusal('对不起,我不能提供这类信息。') is True
+ assert _is_refusal('抱歉,没有办法帮助您完成这个请求。') is True
+
+ def test_self_referential(self):
+ assert _is_refusal('我无法帮助你完成这项任务。') is True
+ assert _is_refusal('我不能提供医疗建议。') is True
+ assert _is_refusal('我不被允许回答此类问题。') is True
+
+ def test_request_violation(self):
+ assert _is_refusal('您的请求违反了使用规则。') is True
+ assert _is_refusal('这个问题不合适讨论。') is True
+
+ def test_ai_identity_refusal(self):
+ assert _is_refusal('作为AI,我无法回答涉及隐私的问题。') is True
+ assert _is_refusal('作为人工智能,我不能提供该类内容。') is True
+
+ def test_false_positive_self_deprecating_preamble(self):
+ # "作为AI,我虽无法体验情感,但..." — 无法 not paired with task verb
+ text = '作为AI,我虽无法亲身体验情感,但可以分析其特征。'
+ assert _is_refusal(text) is False
+
+ def test_false_positive_third_person(self):
+ # No 我 subject + task verb pattern
+ assert _is_refusal('该接口无法处理空输入。') is False
+ # "无法理解" with no 我 / task verb pattern
+ assert _is_refusal('系统暂时不能正常运作。') is False
+
+
+# ── _is_refusal: Japanese ───────────────────────────────────────────────────
+
+
+class TestRefusalJapanese:
+
+ def test_apology_inability(self):
+ # Pattern is `(申し訳|恐れ入り)ます(が|けれど)` — only 恐れ入りますが
+ # is grammatically natural.
+ assert _is_refusal('恐れ入りますが、お答えできません。') is True
+ assert _is_refusal('恐れ入りますけれど、対応できません。') is True
+
+ def test_explicit_inability(self):
+ assert _is_refusal('回答することはできません。') is True
+ assert _is_refusal('ご要望にはお応えできません。') is True
+
+ def test_request_decline(self):
+ # Matches `(その|この)(リクエスト|質問|依頼).{0,20}(お断り|辞退|対応できません)`
+ assert _is_refusal('そのリクエストはお断りします。') is True
+ assert _is_refusal('このリクエストには対応できません。') is True
+
+
+# ── _is_refusal: Korean ─────────────────────────────────────────────────────
+
+
+class TestRefusalKorean:
+
+ def test_apology_inability(self):
+ assert _is_refusal('죄송하지만 답변을 드릴 수 없습니다.') is True
+ assert _is_refusal('유감스럽게도 도와드릴 수 없습니다.') is True
+
+ def test_action_difficulty(self):
+ assert _is_refusal('답변드리기 어렵습니다.') is True
+ assert _is_refusal('처리하기 불가능합니다.') is True
+
+
+# ── check_window ────────────────────────────────────────────────────────────
+
+
+class TestCheckWindow:
+
+ def test_window_excludes_late_refusal(self):
+ # Refusal at position 700 — beyond default 600-char window
+ text = 'a' * 700 + " I can't help you complete that task."
+ assert _is_refusal(text, check_window=600) is False
+
+ def test_custom_window_includes_late_refusal(self):
+ text = 'a' * 700 + " I can't help you complete that task."
+ assert _is_refusal(text, check_window=1000) is True
+
+ def test_zero_window_finds_nothing(self):
+ assert _is_refusal("I can't help you complete tasks.", check_window=0) is False
+
+
+# ── RefuseFilter pipeline ───────────────────────────────────────────────────
+
+
+class TestRefuseFilterPipeline:
+
+ def test_drops_refusal_row(self):
+ rows = [
+ _row([
+ {
+ 'role': 'user',
+ 'content': 'do bad thing'
+ },
+ {
+ 'role': 'assistant',
+ 'content': "I'm sorry, but I cannot help with that request."
+ },
+ ])
+ ]
+ assert _fil(rows) == []
+
+ def test_keeps_normal_reply(self):
+ rows = [
+ _row([
+ {
+ 'role': 'user',
+ 'content': 'explain X'
+ },
+ {
+ 'role': 'assistant',
+ 'content': 'X is a concept that...'
+ },
+ ])
+ ]
+ assert len(_fil(rows)) == 1
+
+ def test_only_first_assistant_scanned(self):
+ # Refusal in SECOND assistant turn → kept (filter only checks first).
+ rows = [
+ _row([
+ {
+ 'role': 'user',
+ 'content': 'q1'
+ },
+ {
+ 'role': 'assistant',
+ 'content': 'A clean reply.'
+ },
+ {
+ 'role': 'user',
+ 'content': 'q2'
+ },
+ {
+ 'role': 'assistant',
+ 'content': "I can't help with that."
+ },
+ ])
+ ]
+ assert len(_fil(rows)) == 1
+
+ def test_think_block_stripped(self):
+ # Refusal phrasing inside ... must NOT trigger.
+ rows = [
+ _row([
+ {
+ 'role': 'user',
+ 'content': 'q'
+ },
+ {
+ 'role': 'assistant',
+ 'content': 'I cannot help with this request'
+ 'Sure, here is the answer: 42.'
+ },
+ ])
+ ]
+ assert len(_fil(rows)) == 1
+
+ def test_no_assistant_kept(self):
+ rows = [_row([{'role': 'user', 'content': 'hi'}])]
+ assert len(_fil(rows)) == 1
+
+ def test_empty_assistant_kept(self):
+ rows = [_row([
+ {
+ 'role': 'user',
+ 'content': 'q'
+ },
+ {
+ 'role': 'assistant',
+ 'content': ''
+ },
+ ])]
+ assert len(_fil(rows)) == 1
+
+ def test_empty_input(self):
+ assert _fil([]) == []
+
+ def test_missing_messages_kept(self):
+ # No messages key → no assistant → kept
+ rows = [{'id': 'x'}]
+ assert len(_fil(rows)) == 1
+
+ def test_mixed_batch(self):
+ rows = [
+ _row([
+ {
+ 'role': 'user',
+ 'content': 'q1'
+ },
+ {
+ 'role': 'assistant',
+ 'content': 'a normal answer'
+ },
+ ]),
+ _row([
+ {
+ 'role': 'user',
+ 'content': 'q2'
+ },
+ {
+ 'role': 'assistant',
+ 'content': 'I refuse to help you with that task.'
+ },
+ ]),
+ _row([
+ {
+ 'role': 'user',
+ 'content': 'q3'
+ },
+ {
+ 'role': 'assistant',
+ 'content': '抱歉,我无法回答这个问题。'
+ },
+ ]),
+ ]
+ out = _fil(rows)
+ assert len(out) == 1
+ assert out[0]['messages'][0]['content'] == 'q1'
+
+ def test_custom_check_window(self):
+ # Default 600 would miss a late refusal; tighten via pipeline kw.
+ long_prefix = 'a' * 700
+ rows = [
+ _row([
+ {
+ 'role': 'user',
+ 'content': 'q'
+ },
+ {
+ 'role': 'assistant',
+ 'content': long_prefix + " I can't help you complete that."
+ },
+ ])
+ ]
+ # default window → kept
+ assert len(_fil(rows)) == 1
+ # widen → dropped
+ assert _fil(rows, check_window=1000) == []
+
+
+if __name__ == '__main__':
+ pytest.main([__file__, '-v'])
diff --git a/tests/preprocessor/test_token_soup.py b/tests/preprocessor/test_token_soup.py
new file mode 100644
index 000000000..ae97b09f1
--- /dev/null
+++ b/tests/preprocessor/test_token_soup.py
@@ -0,0 +1,282 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+"""Tests for TokenSoupFilter.
+
+Covers each garbled-output signal in ``_is_token_soup`` plus the
+script-chaos analyzer and the row-filter pipeline.
+"""
+import pytest
+
+from twinkle_agentic.preprocessor.token_soup import TokenSoupFilter, _is_token_soup, _script_chaos, _script_of
+
+
+def _row(content):
+ return {
+ 'messages': [
+ {
+ 'role': 'user',
+ 'content': 'q'
+ },
+ {
+ 'role': 'assistant',
+ 'content': content
+ },
+ ]
+ }
+
+
+# ── Per-signal detector tests ────────────────────────────────────────────────
+
+
+class TestReplacementChar:
+
+ def test_above_threshold(self):
+ text = '\ufffd' * 5 + 'short' # 5/10 = 50% > 2%
+ assert _is_token_soup(text) is True
+
+ def test_below_threshold(self):
+ text = '\ufffd' + 'hello world this is text. ' * 30 # 1/~780 ≈ 0.1% < 2%
+ # No other signal should fire
+ assert _is_token_soup(text) is False
+
+ def test_no_replacement_char(self):
+ assert _is_token_soup('hello world') is False
+
+
+class TestControlChar:
+
+ def test_above_threshold(self):
+ text = '\x01\x02\x03\x04\x05' + 'a' * 100 # 5/105 ≈ 4.8% > 1%
+ assert _is_token_soup(text) is True
+
+ def test_keeps_legitimate_whitespace(self):
+ text = 'line1\nline2\tindented\rcr'
+ assert _is_token_soup(text) is False
+
+ def test_del_char_triggers(self):
+ text = '\x7f' * 5 + 'a' * 100
+ assert _is_token_soup(text) is True
+
+
+class TestPrivateUseArea:
+
+ def test_bmp_pua_above_threshold(self):
+ text = '\ue000\ue001\ue002\ue003\ue004' + 'a' * 100 # 5/105 ≈ 4.8% > 3%
+ assert _is_token_soup(text) is True
+
+ def test_below_threshold(self):
+ text = '\ue000' + 'hello world this is text. ' * 30 # ~0.1% < 3%
+ assert _is_token_soup(text) is False
+
+
+class TestSpecialTokens:
+
+ def test_repeated_pipe_token(self):
+ text = '<|endoftext|>' * 25
+ assert _is_token_soup(text, special_token_count=20) is True
+
+ def test_repeated_bert_uppercase(self):
+ text = '[PAD]' * 25
+ assert _is_token_soup(text, special_token_count=20) is True
+
+ def test_lowercase_brackets_not_matched(self):
+ # ``dp[mask]`` is normal code; lowercase variant must NOT match.
+ text = 'arr[mask] = arr[mask] | 1; ' * 30
+ assert _is_token_soup(text, special_token_count=20) is False
+
+ def test_byte_token_form(self):
+ text = '<0x0A>' * 25
+ assert _is_token_soup(text, special_token_count=20) is True
+
+ def test_below_count(self):
+ text = '<|endoftext|>' * 5
+ assert _is_token_soup(text, special_token_count=20) is False
+
+ def test_unk_pad_html_tags(self):
+ text = '' * 12 + '' * 13
+ assert _is_token_soup(text, special_token_count=20) is True
+
+
+class TestSingleCharRepeat:
+
+ def test_letter_repeat_triggers(self):
+ text = 'aaaaaaaaaaaaaaaaaaaaaaaaaa hello world' # 26 a's > 19
+ assert _is_token_soup(text) is True
+
+ def test_dash_excluded(self):
+ text = '-' * 50 + ' separator'
+ assert _is_token_soup(text) is False
+
+ def test_equals_excluded(self):
+ text = '=' * 50
+ assert _is_token_soup(text) is False
+
+ def test_digit_excluded(self):
+ text = '9' * 50
+ assert _is_token_soup(text) is False
+
+ def test_box_drawing_excluded(self):
+ text = '\u2500' * 50 # ─ box-drawing horizontal
+ assert _is_token_soup(text) is False
+
+ def test_below_threshold(self):
+ text = 'a' * 19 # 19 < 20 (regex requires \1{19,} → 1 + 19 = 20)
+ assert _is_token_soup(text) is False
+
+ def test_at_threshold(self):
+ text = 'a' * 20 # 20 a's: 1 + 19 repeats → matches
+ assert _is_token_soup(text) is True
+
+
+# ── Script-chaos analyzer ────────────────────────────────────────────────────
+
+
+class TestScriptOf:
+
+ def test_latin(self):
+ assert _script_of(ord('A')) == 'latin'
+ assert _script_of(ord('z')) == 'latin'
+
+ def test_cjk(self):
+ assert _script_of(ord('中')) == 'cjk'
+
+ def test_hiragana_katakana(self):
+ assert _script_of(0x3042) == 'hiragana' # あ
+ assert _script_of(0x30A2) == 'katakana' # ア
+
+ def test_cyrillic(self):
+ assert _script_of(0x0410) == 'cyrillic'
+
+ def test_hangul(self):
+ assert _script_of(0xAC00) == 'hangul'
+
+ def test_private(self):
+ assert _script_of(0xE000) == 'private'
+
+ def test_other(self):
+ assert _script_of(0x2000) == 'other' # general punctuation
+
+
+class TestScriptChaos:
+
+ def test_pure_latin_zero_chaos(self):
+ assert _script_chaos('hello world this is a long english sentence') == 0.0
+
+ def test_pure_cjk_zero_chaos(self):
+ assert _script_chaos('这是一段足够长的中文文本用于测试脚本切换检测' * 2) == 0.0
+
+ def test_short_text_returns_zero(self):
+ # Below ``min_chars`` → returns 0.0 regardless of mix.
+ assert _script_chaos('aあ', min_chars=40) == 0.0
+
+ def test_high_chaos_alternation(self):
+ # Pure letter/number alternation between scripts → chaos ≈ 1.0.
+ text = ('aあbいcうdえeお' * 5) # 50 alternating letters
+ score = _script_chaos(text, min_chars=40)
+ assert score > 0.9
+
+ def test_filter_with_chaos(self):
+ text = ('aあbいcうdえeお' * 5) # high chaos
+ assert _is_token_soup(text, script_chaos_min_chars=40, script_chaos_threshold=0.55) is True
+
+ def test_skips_punct_whitespace(self):
+ # Categories not in (L, N) are dropped before script-of pairing.
+ text = 'hello, world! how are you?'
+ assert _script_chaos(text) == 0.0
+
+
+# ── max_chars head-sampling ──────────────────────────────────────────────────
+
+
+class TestMaxChars:
+
+ def test_only_head_examined(self):
+ # Soup at the tail; head is clean. With max_chars=100 we should not see it.
+ head = 'hello world this is plain text. ' * 4 # ~128 chars, no repeat-20
+ text = head[:100] + '\ufffd' * 100
+ assert _is_token_soup(text, max_chars=100, replacement_char_ratio=0.02) is False
+
+ def test_full_text_when_max_chars_zero(self):
+ head = 'hello world this is plain text. ' * 4
+ text = head[:100] + '\ufffd' * 100
+ assert _is_token_soup(text, max_chars=0, replacement_char_ratio=0.02) is True
+
+
+# ── Empty / trivial inputs ───────────────────────────────────────────────────
+
+
+class TestTrivial:
+
+ def test_empty_text(self):
+ assert _is_token_soup('') is False
+
+ def test_short_clean_text(self):
+ assert _is_token_soup('Hi there!') is False
+
+
+# ── Pipeline ─────────────────────────────────────────────────────────────────
+
+
+class TestTokenSoupFilterPipeline:
+
+ def test_drops_soupy_assistant(self):
+ f = TokenSoupFilter()
+ rows = [_row('clean response'), _row('aaaaaaaaaaaaaaaaaaaaaaaaaaaaa')]
+ out = f(rows)
+ assert len(out) == 1
+ assert out[0]['messages'][1]['content'] == 'clean response'
+
+ def test_keeps_row_without_assistant(self):
+ f = TokenSoupFilter()
+ rows = [{'messages': [{'role': 'user', 'content': 'q'}]}]
+ out = f(rows)
+ assert len(out) == 1
+
+ def test_any_assistant_soupy_drops_row(self):
+ f = TokenSoupFilter()
+ rows = [{
+ 'messages': [
+ {
+ 'role': 'user',
+ 'content': 'q'
+ },
+ {
+ 'role': 'assistant',
+ 'content': 'fine'
+ },
+ {
+ 'role': 'user',
+ 'content': 'q2'
+ },
+ {
+ 'role': 'assistant',
+ 'content': '\ufffd' * 10 + 'a' * 5
+ },
+ ]
+ }]
+ out = f(rows)
+ assert out == []
+
+ def test_strips_whitespace_before_check(self):
+ # Leading/trailing whitespace shouldn't bypass detection.
+ f = TokenSoupFilter()
+ rows = [_row(' ' + 'a' * 30 + ' ')]
+ assert f(rows) == []
+
+ def test_threshold_overrides_propagated(self):
+ # With a stricter ratio, even small amounts of \ufffd trip it.
+ f = TokenSoupFilter(replacement_char_ratio=0.0)
+ rows = [_row('hello\ufffdworld')]
+ assert f(rows) == []
+
+ def test_empty_rows(self):
+ assert TokenSoupFilter()([]) == []
+
+ def test_messages_missing(self):
+ f = TokenSoupFilter()
+ rows = [{'id': 'no-msgs'}]
+ out = f(rows)
+ assert len(out) == 1
+
+
+if __name__ == '__main__':
+ pytest.main([__file__, '-v'])
diff --git a/tests/template/test_tool_parsers.py b/tests/template/test_tool_parsers.py
new file mode 100644
index 000000000..fc404b7a2
--- /dev/null
+++ b/tests/template/test_tool_parsers.py
@@ -0,0 +1,457 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+"""Pure-Python tests for tool-call parsers (no model download).
+
+Covers Hermes/Qwen, ReAct, Cline parsing, cleaning, and — most importantly
+— streaming correctness via the generic state machine in
+:class:`twinkle.template.base.Template`.
+"""
+import json
+import pytest
+
+from twinkle.template.base import Template
+from twinkle.template.tools import ClineParser, HermesQwenParser, ReActParser, ToolCallRegistry, trailing_prefix_of
+
+
+class _StubTemplate:
+ """Minimal Template-shaped object exposing only stream-related members.
+
+ Avoids loading a real tokenizer/processor (which would need network).
+ """
+
+ parse_tool_call_stream = Template.parse_tool_call_stream
+ _stream_marker_blocks = Template._stream_marker_blocks
+ _format_tc_delta = staticmethod(Template._format_tc_delta)
+
+ def __init__(self, model_id: str):
+ self.model_id = model_id
+
+
+def _stream(model_id, chunks_with_finished):
+ t = _StubTemplate(model_id)
+ state = {}
+ events = []
+ for chunk, fin in chunks_with_finished:
+ events.extend(t.parse_tool_call_stream(state, chunk, finished=fin))
+ return events, state
+
+
+# ---------------------------------------------------------------------------
+# HermesQwenParser
+# ---------------------------------------------------------------------------
+
+
+class TestHermesQwenParser:
+
+ def setup_method(self):
+ self.p = HermesQwenParser()
+
+ def test_detect(self):
+ assert self.p.detect('hi {"name":"f","arguments":{}}')
+ assert not self.p.detect('plain text')
+ assert not self.p.detect('')
+
+ def test_matches_model(self):
+ assert self.p.matches_model('qwen2.5-7b')
+ assert self.p.matches_model('qwen3-32b')
+ assert not self.p.matches_model('llama-3.1-8b')
+
+ def test_parse_json_variant(self):
+ text = '{"name": "get_weather", "arguments": {"city": "Paris"}}'
+ out = self.p.parse(text)
+ assert out == [{
+ 'type': 'function',
+ 'function': {
+ 'name': 'get_weather',
+ 'arguments': {
+ 'city': 'Paris'
+ }
+ },
+ }]
+
+ def test_parse_function_xml_variant(self):
+ text = (''
+ '12'
+ '')
+ out = self.p.parse(text)
+ assert len(out) == 1
+ assert out[0]['function']['name'] == 'add'
+ # JSON-decoding of param values: numbers come back as int.
+ assert out[0]['function']['arguments'] == {'a': 1, 'b': 2}
+
+ def test_parse_multiple_blocks(self):
+ text = ('{"name":"f1","arguments":{}}'
+ 'between '
+ '{"name":"f2","arguments":{"k":"v"}}')
+ out = self.p.parse(text)
+ assert [c['function']['name'] for c in out] == ['f1', 'f2']
+ assert out[1]['function']['arguments'] == {'k': 'v'}
+
+ def test_parse_unclosed_block_at_eof(self):
+ # ``\Z`` fallback in _BLOCK_RE handles truncated trailing block.
+ text = '{"name": "f", "arguments": {}}'
+ out = self.p.parse(text)
+ assert out and out[0]['function']['name'] == 'f'
+
+ def test_parse_empty_returns_empty_list(self):
+ assert self.p.parse('') == []
+ assert self.p.parse('plain text without markers') == []
+
+ def test_clean_strips_blocks(self):
+ text = 'hello {"name":"f","arguments":{}} world'
+ assert self.p.clean(text) == 'hello world'
+
+ def test_clean_unclosed_at_eof(self):
+ text = 'hello {"name":"f"'
+ assert self.p.clean(text) == 'hello'
+
+ def test_clean_empty(self):
+ assert self.p.clean('') == ''
+
+ def test_markers_declared(self):
+ assert self.p.open_marker == ''
+ assert self.p.close_marker == ''
+
+
+class TestHermesQwenStreaming:
+ """Generic open/close marker buffer state machine."""
+
+ def test_plain_text_passthrough(self):
+ events, _ = _stream('qwen2.5-7b', [('Hello world!', True)])
+ assert events == [{'content': 'Hello world!'}]
+
+ def test_holds_back_partial_open_marker(self):
+ events, state = _stream('qwen2.5-7b', [
+ ('Hello! ', False),
+ ('{"name":"f","arguments":{}}', False),
+ ('done.', False),
+ ('', True),
+ ])
+ types = [next(iter(e)) for e in events]
+ assert types == ['content', 'tool_calls', 'content']
+ tc = events[1]['tool_calls'][0]
+ assert tc['function']['name'] == 'f'
+ # OpenAI streaming spec: arguments serialised as JSON string.
+ assert tc['function']['arguments'] == '{}'
+ assert tc['index'] == 0
+ assert tc['id'].startswith('call_')
+ assert tc['type'] == 'function'
+
+ def test_stream_chunked_inside_block(self):
+ # Split the block at every char to torture-test the partial-marker
+ # hold-back logic.
+ full = '{"name":"f","arguments":{"x":1}}'
+ chunks = [(full[i:i + 1], False) for i in range(len(full))]
+ chunks.append(('', True))
+ events, state = _stream('qwen2.5-7b', chunks)
+ tcs = [e['tool_calls'][0] for e in events if 'tool_calls' in e]
+ assert len(tcs) == 1
+ assert tcs[0]['function']['name'] == 'f'
+ assert json.loads(tcs[0]['function']['arguments']) == {'x': 1}
+ assert state['pending'] == ''
+ # No content events should leak the markup.
+ for e in events:
+ if 'content' in e:
+ assert '' not in e['content']
+ assert '' not in e['content']
+
+ def test_multiple_blocks_increasing_indices(self):
+ events, _ = _stream('qwen2.5-7b', [
+ ('{"name":"a","arguments":{}}'
+ '{"name":"b","arguments":{}}', True),
+ ])
+ tcs = [e['tool_calls'][0] for e in events if 'tool_calls' in e]
+ assert [t['function']['name'] for t in tcs] == ['a', 'b']
+ assert [t['index'] for t in tcs] == [0, 1]
+
+ def test_unclosed_block_flushed_on_finish(self):
+ events, state = _stream('qwen2.5-7b', [
+ ('{"name":"f","arguments":{}}', True),
+ ])
+ assert state['pending'] == ''
+ tcs = [e['tool_calls'][0] for e in events if 'tool_calls' in e]
+ assert tcs and tcs[0]['function']['name'] == 'f'
+
+ def test_arguments_serialised_as_json_string(self):
+ events, _ = _stream('qwen2.5-7b', [
+ ('{"name":"f","arguments":{"k":"v","n":3}}', True),
+ ])
+ tc = next(e['tool_calls'][0] for e in events if 'tool_calls' in e)
+ assert isinstance(tc['function']['arguments'], str)
+ assert json.loads(tc['function']['arguments']) == {'k': 'v', 'n': 3}
+
+ def test_content_events_lossless_for_non_block_text(self):
+ # All non-tool-call text must pass through verbatim, regardless of
+ # chunk boundaries.
+ original_content_outside = 'aXY'
+ full = ('a'
+ '{"name":"f","arguments":{}}'
+ 'XY')
+ chunks = [(full[i:i + 3], False) for i in range(0, len(full), 3)]
+ chunks.append(('', True))
+ events, _ = _stream('qwen2.5-7b', chunks)
+ rebuilt = ''.join(e['content'] for e in events if 'content' in e)
+ assert rebuilt == original_content_outside
+
+ def test_no_emission_until_chunk_arrives(self):
+ # Streaming with empty chunk and not-finished should be a no-op.
+ events, _ = _stream('qwen2.5-7b', [('', False)])
+ assert events == []
+
+
+# ---------------------------------------------------------------------------
+# ReActParser
+# ---------------------------------------------------------------------------
+
+
+class TestReActParser:
+
+ def setup_method(self):
+ self.p = ReActParser()
+
+ def test_detect_action_line(self):
+ assert self.p.detect('Thought: I need search.\nAction: search[python]')
+ assert not self.p.detect('plain text without action keyword')
+ assert not self.p.detect('')
+
+ def test_no_block_marker(self):
+ # Prose format — streaming has no marker to lock onto.
+ assert self.p.open_marker is None
+ assert self.p.close_marker is None
+
+ def test_does_not_match_qwen_model(self):
+ assert not self.p.matches_model('qwen2.5')
+ assert not self.p.matches_model('llama-3')
+
+ def test_parse_single_action(self):
+ text = 'Thought: search the web.\nAction: search[hello world]'
+ out = self.p.parse(text)
+ assert out == [{
+ 'type': 'function',
+ 'function': {
+ 'name': 'search',
+ 'arguments': {
+ 'input': 'hello world'
+ }
+ },
+ }]
+
+ def test_parse_multiple_actions(self):
+ text = ('Thought: a\nAction: tool_a[x]\n'
+ 'Observation: ok\n'
+ 'Thought: b\nAction: tool_b[y z]')
+ out = self.p.parse(text)
+ assert [c['function']['name'] for c in out] == ['tool_a', 'tool_b']
+ assert out[1]['function']['arguments'] == {'input': 'y z'}
+
+ def test_clean_removes_action_lines(self):
+ text = 'Thought: hi\nAction: search[x]\nDone'
+ cleaned = self.p.clean(text)
+ assert 'Action: search' not in cleaned
+ assert 'Thought: hi' in cleaned
+ assert 'Done' in cleaned
+
+ def test_parse_empty(self):
+ assert self.p.parse('') == []
+
+
+class TestReActStreaming:
+ """ReAct has no marker → falls back to plain content passthrough.
+
+ Detection is a final-pass concern; streaming preserves content faithfully.
+ """
+
+ def test_passthrough_when_no_marker_parser(self):
+ # 'react-agent' doesn't match HermesQwen ('qwen' substring) → no parser
+ # cached → passthrough mode.
+ events, state = _stream('react-agent', [
+ ('Thought: hi\n', False),
+ ('Action: foo[bar]\n', False),
+ ('done', False),
+ ('', True),
+ ])
+ rebuilt = ''.join(e['content'] for e in events if 'content' in e)
+ assert rebuilt == 'Thought: hi\nAction: foo[bar]\ndone'
+ assert state.get('parser') is None
+
+ def test_no_tool_calls_event_emitted(self):
+ events, _ = _stream('react-agent', [
+ ('Action: foo[bar]', True),
+ ])
+ assert all('tool_calls' not in e for e in events)
+
+
+# ---------------------------------------------------------------------------
+# ClineParser
+# ---------------------------------------------------------------------------
+
+
+class TestClineParser:
+
+ def setup_method(self):
+ self.p = ClineParser()
+
+ def test_detect_simple_tool(self):
+ assert self.p.detect('foo.py')
+
+ def test_detect_ignores_html_like_tags(self):
+ # ``think`` / ``code`` are denied — even with inner content they aren't
+ # treated as tool calls.
+ assert not self.p.detect('x')
+ assert not self.p.detect('x')
+
+ def test_detect_requires_inner_param(self):
+ # No inner ``VAL`` → not a Cline call.
+ assert not self.p.detect('just text')
+
+ def test_detect_ignores_hermes_block(self):
+ # Hermes already owns ```` — Cline must skip it.
+ assert not self.p.detect('{"name":"f","arguments":{}}')
+
+ def test_no_marker_for_streaming(self):
+ # Outer tag varies per call — streaming uses passthrough, not the
+ # marker state machine.
+ assert self.p.open_marker is None
+ assert self.p.close_marker is None
+
+ def test_does_not_match_any_model_by_default(self):
+ # Cline is an app-level prompt protocol, not a model-family format.
+ assert not self.p.matches_model('qwen2.5')
+ assert not self.p.matches_model('claude-3')
+
+ def test_parse_single_arg(self):
+ text = 'src/foo.py'
+ out = self.p.parse(text)
+ assert out == [{
+ 'type': 'function',
+ 'function': {
+ 'name': 'read_file',
+ 'arguments': {
+ 'path': 'src/foo.py'
+ }
+ },
+ }]
+
+ def test_parse_multi_arg_with_whitespace(self):
+ text = ('\n'
+ ' ls -la\n'
+ ' false\n'
+ '')
+ out = self.p.parse(text)
+ fn = out[0]['function']
+ assert fn['name'] == 'execute_command'
+ assert fn['arguments'] == {'command': 'ls -la', 'requires_approval': 'false'}
+
+ def test_parse_multiple_blocks(self):
+ text = ('a'
+ ' between '
+ 'btrue')
+ out = self.p.parse(text)
+ assert [c['function']['name'] for c in out] == ['read_file', 'list_files']
+ assert out[1]['function']['arguments'] == {'path': 'b', 'recursive': 'true'}
+
+ def test_parse_skips_hermes_block(self):
+ text = '{"name":"f","arguments":{}}'
+ assert self.p.parse(text) == []
+
+ def test_clean_strips_tool_blocks(self):
+ text = 'before x after'
+ assert self.p.clean(text) == 'before after'
+
+ def test_clean_preserves_non_tool_xml(self):
+ text = 'reasoning x tail'
+ cleaned = self.p.clean(text)
+ assert 'reasoning' in cleaned
+ assert '' not in cleaned
+ assert 'tail' in cleaned
+
+ def test_clean_empty(self):
+ assert self.p.clean('') == ''
+
+
+class TestClineStreaming:
+ """Cline streams as plain content (no fixed open marker)."""
+
+ def test_content_passthrough_lossless_across_chunk_boundaries(self):
+ full = ('intro src/foo.py outro'
+ ' next x')
+ # Chunk every 4 chars — boundaries fall inside tags, args, etc.
+ chunks = [(full[i:i + 4], False) for i in range(0, len(full), 4)]
+ chunks.append(('', True))
+ events, _ = _stream('cline-bot', chunks)
+ rebuilt = ''.join(e['content'] for e in events if 'content' in e)
+ assert rebuilt == full
+ # No tool_calls events because no parser was selected by model_id.
+ assert all('tool_calls' not in e for e in events)
+
+
+# ---------------------------------------------------------------------------
+# Registry round-robin & helpers
+# ---------------------------------------------------------------------------
+
+
+class TestRegistryRoundRobin:
+
+ def test_first_match_wins_no_nested_reparse(self):
+ # Hermes block must take ownership; ReAct/Cline shouldn't see it.
+ text = '{"name":"f","arguments":{}}'
+ parser = ToolCallRegistry.detect_first(text)
+ assert parser is not None and parser.name == 'hermes_qwen'
+
+ def test_cline_wins_for_xml_tools(self):
+ text = 'x'
+ parser = ToolCallRegistry.detect_first(text)
+ assert parser is not None and parser.name == 'cline'
+
+ def test_react_wins_for_action_keyword(self):
+ text = 'Thought: hi\nAction: search[x]'
+ parser = ToolCallRegistry.detect_first(text)
+ assert parser is not None and parser.name == 'react'
+
+ def test_no_parser_for_plain_text(self):
+ assert ToolCallRegistry.detect_first('just some plain text') is None
+ assert ToolCallRegistry.detect_first('') is None
+
+ def test_select_for_qwen_picks_hermes(self):
+ parser = ToolCallRegistry.select_for_model('qwen2.5-7b')
+ assert parser is not None and parser.name == 'hermes_qwen'
+
+ def test_select_for_unknown_returns_none(self):
+ assert ToolCallRegistry.select_for_model('llama-3.1-8b') is None
+ assert ToolCallRegistry.select_for_model(None) is None
+
+
+class TestTrailingPrefixOf:
+ """Holdback length helper used by the marker state machine."""
+
+ def test_no_prefix(self):
+ assert trailing_prefix_of('hello world', '') == 0
+
+ def test_partial_prefix_4_chars(self):
+ # buf ends with '' length 4.
+ assert trailing_prefix_of('hello ') == 4
+
+ def test_partial_prefix_1_char(self):
+ assert trailing_prefix_of('hello <', '') == 1
+
+ def test_full_marker_returns_zero(self):
+ # Full marker at end is NOT a strict prefix (search range is 1..len-1),
+ # so the helper returns 0 — block code path will see the marker via
+ # ``find()`` rather than holdback.
+ assert trailing_prefix_of('text', '') == 0
+
+ def test_empty_buf(self):
+ assert trailing_prefix_of('', '') == 0
+
+
+if __name__ == '__main__':
+ pytest.main([__file__, '-v'])