Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions cookbook/client/tinker/modelscope/dpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@
from tqdm import tqdm
from typing import Any, Dict, List

import swanlab

from tinker import types
from twinkle.tracker import register_tracker, dispatch
from twinkle.tracker.swanlab import SwanLabTracker
from twinkle import init_tinker_client, get_logger
from twinkle.dataset import Dataset, DatasetMeta, LazyDataset
from twinkle.dataloader import DataLoader
Expand Down Expand Up @@ -96,10 +96,9 @@ def prepare_dpo_batch(batch: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
# ---------------------------------------------------------------------------

def train():
# Step 0: Initialize SwanLab if enabled
# Step 0: Register tracker if enabled
if use_swanlab:
swanlab.login(api_key=os.environ['SWANLAB_API_KEY'])
swanlab.init(
register_tracker(SwanLabTracker(
project='twinkle-dpo',
experiment_name='dpo-lora-training',
config={
Expand All @@ -111,8 +110,9 @@ def train():
'max_length': max_length,
'lora_rank': lora_rank,
},
)
logger.info('SwanLab initialized')
api_key=os.environ.get('SWANLAB_API_KEY'),
))
logger.info('SwanLabTracker registered')

# Step 1: Prepare dataset & dataloader
logger.info('Loading DPO dataset...')
Expand Down Expand Up @@ -188,9 +188,9 @@ def train():

logger.info(f'[Step {step}] metrics={optim_result.metrics}')

# Log metrics to SwanLab
# Dispatch metrics to registered trackers
if use_swanlab and optim_result.metrics:
swanlab.log(optim_result.metrics, step=step)
dispatch(optim_result.metrics, step=step)

# Step 4: Save checkpoint
save_result = training_client.save_state('dpo-lora-final').result()
Expand Down
18 changes: 9 additions & 9 deletions cookbook/client/tinker/self_host/dpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@
from tqdm import tqdm
from typing import Any, Dict, List

import swanlab

from tinker import types
from twinkle.tracker import register_tracker, dispatch
from twinkle.tracker.swanlab import SwanLabTracker
from twinkle import init_tinker_client, get_logger
from twinkle.dataset import Dataset, DatasetMeta, LazyDataset
from twinkle.dataloader import DataLoader
Expand Down Expand Up @@ -96,10 +96,9 @@ def prepare_dpo_batch(batch: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
# ---------------------------------------------------------------------------

def train():
# Step 0: Initialize SwanLab if enabled
# Step 0: Register tracker if enabled
if use_swanlab:
swanlab.login(api_key=os.environ['SWANLAB_API_KEY'])
swanlab.init(
register_tracker(SwanLabTracker(
project='twinkle-dpo',
experiment_name='dpo-lora-training',
config={
Expand All @@ -111,8 +110,9 @@ def train():
'max_length': max_length,
'lora_rank': lora_rank,
},
)
logger.info('SwanLab initialized')
api_key=os.environ.get('SWANLAB_API_KEY'),
))
logger.info('SwanLabTracker registered')

# Step 1: Prepare dataset & dataloader
logger.info('Loading DPO dataset...')
Expand Down Expand Up @@ -188,9 +188,9 @@ def train():

logger.info(f'[Step {step}] metrics={optim_result.metrics}')

# Log metrics to SwanLab
# Dispatch metrics to registered trackers
if use_swanlab and optim_result.metrics:
swanlab.log(optim_result.metrics, step=step)
dispatch(optim_result.metrics, step=step)

# Step 4: Save checkpoint
save_result = training_client.save_state('dpo-lora-final').result()
Expand Down
18 changes: 9 additions & 9 deletions cookbook/client/twinkle/self_host/short_math_grpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@
from peft import LoraConfig
from typing import List, Tuple, Dict, Any

import swanlab

from twinkle import get_logger
from twinkle.tracker import register_tracker, dispatch
from twinkle.tracker.swanlab import SwanLabTracker
from twinkle.reward import GSM8KAccuracyReward
from twinkle.reward.base import Reward
from twinkle.advantage import GRPOAdvantage
Expand Down Expand Up @@ -119,10 +119,9 @@ def compute_rewards(


def train():
# Step 0: Initialize SwanLab if enabled
# Step 0: Register tracker if enabled
if USE_SWANLAB:
swanlab.login(api_key=os.environ.get('SWANLAB_API_KEY', ''))
swanlab.init(
register_tracker(SwanLabTracker(
project=SWANLAB_PROJECT,
experiment_name=SWANLAB_EXPERIMENT_NAME,
config={
Expand All @@ -136,8 +135,9 @@ def train():
'sync_interval': SYNC_INTERVAL,
'gradient_accumulation_steps': GRADIENT_ACCUMULATION_STEPS,
},
)
logger.info('SwanLab initialized')
api_key=os.environ.get('SWANLAB_API_KEY', ''),
))
logger.info('SwanLabTracker registered')

# Step 1: Initialize the Twinkle client
client = init_twinkle_client(
Expand Down Expand Up @@ -286,9 +286,9 @@ def train():
log_dict['train/frac_reward_zero_std'] = frac_zero_std
logger.info(f'Step {step}: {log_dict}')

# Log metrics to SwanLab
# Dispatch metrics to registered trackers
if USE_SWANLAB and log_dict:
swanlab.log(log_dict, step=step)
dispatch(log_dict, step=step)

step += 1
metrics.reset()
Expand Down
17 changes: 10 additions & 7 deletions cookbook/rl/short_math_grpo_multi_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
from twinkle.reward.base import Reward
from twinkle.sampler import vLLMSampler
from twinkle.preprocessor.llm import GSM8KProcessor
from twinkle.tracker import register_tracker, dispatch
from twinkle.tracker.swanlab import SwanLabTracker

logger = get_logger()

Expand Down Expand Up @@ -59,12 +61,6 @@
SYSTEM_PROMPT = ('You are a helpful math assistant. Solve the problem with minimal but correct reasoning '
'and put your final answer within \\boxed{}.')

import swanlab
swanlab.init(
project='twinkle',
)


# ========== Reward Functions ==========
class GSM8KBrevityReward(Reward):
"""Brevity reward: rewards shorter completions that contain a valid answer.
Expand Down Expand Up @@ -122,6 +118,11 @@ def compute_rewards(

# ========== Main ==========
def main():
# Register SwanLab tracker
register_tracker(SwanLabTracker(
project='twinkle',
))

# Device groups: 8 GPUs for model (tp=2 x ep=2 x pp=2), 4 GPUs for sampler (dp=2 x tp=2)
device_groups = [
DeviceGroup(name='model', ranks=list(range(MODEL_GPUS)), device_type='GPU'),
Expand Down Expand Up @@ -292,7 +293,9 @@ def main():

log_dict = metrics.calculate()
log_dict.update(model.calculate_metric(is_training=True, adapter_name=ADAPTER_NAME))
swanlab.log(log_dict)
# model.calculate_metric() already dispatches model metrics internally;
# this dispatch sends the full merged set for reward coverage.
dispatch(log_dict, step=optim_step)
metrics.reset()
logger.info(f'[Step {optim_step}/{MAX_STEPS}] {log_dict}')

Expand Down
84 changes: 84 additions & 0 deletions cookbook/transformers/tracker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import os
from peft import LoraConfig
import twinkle
from twinkle import get_logger
from twinkle.dataloader import DataLoader
from twinkle.dataset import Dataset, DatasetMeta
from twinkle.model import TransformersModel
from twinkle.preprocessor import SelfCognitionProcessor
from twinkle.tracker import register_tracker, list_trackers, clear_trackers
logger = get_logger()
# ── Configuration ──────────────────────────────────────────────────────────────
MODEL_ID = 'ms://Qwen/Qwen2.5-0.5B-Instruct'
DATASET_ID = 'ms://swift/self-cognition'
TEMPLATE_NAME = 'Template'
BATCH_SIZE = 1
LEARNING_RATE = 1e-4
TRAIN_STEPS = 5
# ── Tracker selection ──────────────────────────────────────────────────────────
def setup_tracker():
"""Register either SwanLabTracker (if API key available) or PrintTracker."""
if os.environ.get('SWANLAB_API_KEY'):
from twinkle.tracker.swanlab import SwanLabTracker
tracker = SwanLabTracker(
project='twinkle-test',
experiment_name='tracker-integration-test',
config={'model': MODEL_ID, 'lr': LEARNING_RATE, 'steps': TRAIN_STEPS},
output_dir='./test_tracker_output',
)
register_tracker(tracker)
logger.info('SwanLabTracker registered — project=twinkle-test')
return tracker
else:
from twinkle.tracker import ExperimentTracker
class PrintTracker(ExperimentTracker):
def __init__(self):
self.logged: list[tuple[int, dict]] = []
def log(self, data: dict, step: int) -> None:
self.logged.append((step, data))
logger.info('[PrintTracker] step=%s metrics=%s', step, data)
def cleanup(self) -> None:
logger.info('[PrintTracker] cleanup — %s dispatches', len(self.logged))
tracker = PrintTracker()
register_tracker(tracker)
logger.info('PrintTracker registered (set SWANLAB_API_KEY for SwanLab)')
return tracker
# ── Main ──────────────────────────────────────────────────────────────────────
def main():
twinkle.initialize(mode='local', seed=42)
tracker = setup_tracker()
assert len(list_trackers()) == 1
logger.info('Tracker ready: %s', type(tracker).__name__)
dataset = Dataset(dataset_meta=DatasetMeta(DATASET_ID, data_slice=range(10)))
dataset.set_template(TEMPLATE_NAME, model_id=MODEL_ID)
dataset.map(SelfCognitionProcessor('test_model', 'test_author'))
dataset.encode()
dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE)
model = TransformersModel(model_id=MODEL_ID)
lora_config = LoraConfig(r=8, lora_alpha=32, target_modules='all-linear')
model.add_adapter_to_model('default', lora_config, gradient_accumulation_steps=1)
model.set_optimizer(optimizer_cls='AdamW', lr=LEARNING_RATE)
model.set_lr_scheduler(
scheduler_cls='CosineWarmupScheduler', num_warmup_steps=1, num_training_steps=TRAIN_STEPS
)
for step, batch in enumerate(dataloader):
if step >= TRAIN_STEPS:
break
model.forward_backward(inputs=batch)
model.clip_grad_and_step()
metric = model.calculate_metric(is_training=True)
logger.info('Step %s raw metric: %s', step + 1, metric)
# Verification (only works for PrintTracker)
if hasattr(tracker, 'logged'):
n = len(tracker.logged)
assert n > 0, 'No metrics were dispatched — dispatch() not called'
logger.info('=== Dispatch verification ===')
logger.info('Total dispatches: %s', n)
for i, (step, data) in enumerate(tracker.logged):
all_floats = all(isinstance(v, float) for v in data.values())
logger.info(' [%s] step=%s keys=%s all_float=%s', i + 1, step, list(data.keys()), all_floats)
clear_trackers()
assert len(list_trackers()) == 0
logger.info('=== Test complete ===')
if __name__ == '__main__':
main()
Loading
Loading