From 8723c5763c8f59a664440fe2d4a48972101756c5 Mon Sep 17 00:00:00 2001 From: tpx Date: Tue, 12 May 2026 14:27:22 +0800 Subject: [PATCH 1/8] init tracker --- src/twinkle/infra/__init__.py | 5 + src/twinkle/model/optimizer_group.py | 18 +++ src/twinkle/tracker/__init__.py | 201 +++++++++++++++++++++++++++ src/twinkle/tracker/base.py | 31 +++++ src/twinkle/tracker/swanlab.py | 75 ++++++++++ src/twinkle/tracker/wandb.py | 57 ++++++++ 6 files changed, 387 insertions(+) create mode 100644 src/twinkle/tracker/__init__.py create mode 100644 src/twinkle/tracker/base.py create mode 100644 src/twinkle/tracker/swanlab.py create mode 100644 src/twinkle/tracker/wandb.py diff --git a/src/twinkle/infra/__init__.py b/src/twinkle/infra/__init__.py index aa559e76..8696c588 100644 --- a/src/twinkle/infra/__init__.py +++ b/src/twinkle/infra/__init__.py @@ -63,6 +63,11 @@ def initialize(mode: Literal['local', 'ray'] = 'local', if seed is not None: _seed = seed framework_util.seed_everything(seed, full_determinism) + + # Inform the tracker module of the current distributed rank + from twinkle.tracker import set_rank as _set_tracker_rank + _set_tracker_rank(Platform.get_rank()) + if _mode == 'local': if groups is not None: _device_group = groups diff --git a/src/twinkle/model/optimizer_group.py b/src/twinkle/model/optimizer_group.py index 5fdb89f7..2b85b2e7 100644 --- a/src/twinkle/model/optimizer_group.py +++ b/src/twinkle/model/optimizer_group.py @@ -80,6 +80,24 @@ def calculate_metrics(self, is_training): results = {} for metric in status.metrics: results.update(metric.calculate()) + + # Enrich results with training-loop metrics + if self._last_grad_norm: + results['grad_norm'] = self._last_grad_norm + if status.num_tokens: + results['num_tokens'] = status.num_tokens + status.inputs = None status.outputs = None + + # Dispatch to registered experiment trackers + if is_training: + from twinkle.tracker import dispatch, dispatch_hyperparams + dispatch(results, step=self.cur_step) + # Lazily log hyperparams on the first training metrics call + dispatch_hyperparams( + {'adapter_name': self.adapter_name, + 'gradient_accumulation_steps': self.gradient_accumulation_steps}, + adapter_name=self.adapter_name) + return results diff --git a/src/twinkle/tracker/__init__.py b/src/twinkle/tracker/__init__.py new file mode 100644 index 00000000..b60b57c7 --- /dev/null +++ b/src/twinkle/tracker/__init__.py @@ -0,0 +1,201 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Experiment tracking dispatch for twinkle training metrics. + +Usage:: + + from twinkle.tracker import SwanLabTracker, register_tracker + + register_tracker(SwanLabTracker(project="my-project")) + # training loop unchanged — dispatch happens automatically. + +Or via environment variables (no code change):: + + TWINKLE_TRACKERS=swanlab SWANLAB_API_KEY=xxx python train.py +""" + +import atexit +import logging +import os +from typing import Any, Dict, List, Optional + +from twinkle.server.model.backends.common import clean_metrics + +from .base import ExperimentTracker +from .swanlab import SwanLabTracker +from .wandb import WandbTracker + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Global state +# --------------------------------------------------------------------------- +_trackers: List[ExperimentTracker] = [] +_rank: int = 0 +_hparams_dispatched: set = set() # track which adapters have sent hyperparams + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + + +def register_tracker(tracker: ExperimentTracker) -> None: + """Register an experiment tracker. + + Multiple trackers can be registered — ``dispatch`` will send metric + data to each one in order. Trackers are cleaned up automatically on + normal interpreter exit via ``atexit``. + """ + _trackers.append(tracker) + + +def set_rank(rank: int) -> None: + """Set the distributed rank for the current process. + + Called by ``twinkle.initialize()`` — not intended for direct use. + Only the process with ``rank == 0`` dispatches metrics; all others + are no-ops. + """ + global _rank + _rank = rank + + +def list_trackers() -> List[ExperimentTracker]: + """Return a snapshot of currently registered trackers.""" + return list(_trackers) + + +def clear_trackers() -> None: + """Call ``cleanup()`` on every registered tracker and clear the list. + + Registered automatically via ``atexit``; may also be called manually. + """ + for t in _trackers: + try: + t.cleanup() + except Exception: + logger.warning("Tracker %s.cleanup() failed", type(t).__name__, exc_info=True) + _trackers.clear() + + +# --------------------------------------------------------------------------- +# Internal dispatch +# --------------------------------------------------------------------------- + + +def dispatch(data: Dict[str, float], step: int) -> None: + """Send computed metrics to all registered trackers. + + Metric values are normalized to ``float`` via :func:`clean_metrics` + before dispatching. Only the rank-0 process performs the dispatch; + all other ranks return immediately with no overhead. + + Args: + data: Raw metric dict (may contain strings, ints, floats). + step: Current training step (``cur_step`` from optimizer group). + """ + if not _trackers: + return + if _rank != 0: + return + + cleaned = clean_metrics(data) + if not cleaned: + return + + for tracker in _trackers: + try: + tracker.log(cleaned, step=step) + except Exception: + logger.warning("Tracker %s.log() failed", type(tracker).__name__, exc_info=True) + + +def dispatch_hyperparams(params: Dict[str, Any], adapter_name: Optional[str] = None) -> None: + """Send hyperparameters to all registered trackers (call once at training start). + + Idempotent per ``(adapter_name,)`` — repeated calls with the same + *adapter_name* are silently ignored so that this can safely be called + from ``calculate_metrics`` on its first invocation without + flooding trackers with redundant config updates. + + Args: + params: Flat or nested dict of hyperparameters (e.g. model config, + training args, LoRA config). + adapter_name: Optional adapter identifier. If omitted, the params + are dispatched unconditionally on every call. + """ + if not _trackers or _rank != 0: + return + + # Idempotency guard: only dispatch once per adapter + if adapter_name is not None: + if adapter_name in _hparams_dispatched: + return + _hparams_dispatched.add(adapter_name) + + for tracker in _trackers: + try: + tracker.log_hyperparams(params) + except Exception: + logger.warning("Tracker %s.log_hyperparams() failed", type(tracker).__name__, exc_info=True) + + +# --------------------------------------------------------------------------- +# Environment-variable auto-initialisation +# --------------------------------------------------------------------------- + +_AUTO_INIT_DONE = False + + +def _auto_init_from_env() -> None: + """Initialise trackers from environment variables (called once at import). + + Reads ``TWINKLE_TRACKERS`` (comma-separated, e.g. ``wandb,swanlab``) + and backend-specific env vars, then registers matching tracker instances + automatically. + + This lets users enable experiment tracking without *any* code change:: + + TWINKLE_TRACKERS=wandb WANDB_PROJECT=my-project python train.py + """ + global _AUTO_INIT_DONE + if _AUTO_INIT_DONE: + return + _AUTO_INIT_DONE = True + + trackers_str = os.environ.get("TWINKLE_TRACKERS", "").strip() + if not trackers_str: + return + + project = os.environ.get("TWINKLE_TRACKER_PROJECT", "twinkle-training") + experiment_name = os.environ.get("TWINKLE_TRACKER_EXPERIMENT", None) + + for name in (t.strip().lower() for t in trackers_str.split(",") if t.strip()): + try: + if name == "wandb": + _trackers.append(WandbTracker( + project=project, + experiment_name=experiment_name, + entity=os.environ.get("WANDB_ENTITY"), + )) + logger.info("Auto-registered WandbTracker from TWINKLE_TRACKERS env var") + elif name == "swanlab": + _trackers.append(SwanLabTracker( + project=project, + experiment_name=experiment_name, + output_dir=os.environ.get("TWINKLE_OUTPUT_DIR"), + )) + logger.info("Auto-registered SwanLabTracker from TWINKLE_TRACKERS env var") + else: + logger.warning("Unknown tracker backend in TWINKLE_TRACKERS: %s", name) + except Exception: + logger.warning("Failed to auto-init tracker '%s' from env", name, exc_info=True) + + +# Run auto-init once at import time (before user code or atexit runs) +_auto_init_from_env() + + +# --------------------------------------------------------------------------- +# At-exit cleanup +# --------------------------------------------------------------------------- +atexit.register(clear_trackers) diff --git a/src/twinkle/tracker/base.py b/src/twinkle/tracker/base.py new file mode 100644 index 00000000..a6a5af22 --- /dev/null +++ b/src/twinkle/tracker/base.py @@ -0,0 +1,31 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Abstract base class for experiment trackers.""" + +from abc import ABC, abstractmethod +from typing import Any, Dict + + +class ExperimentTracker(ABC): + """Base class for experiment tracking backends (SwanLab, W&B, etc.). + + Subclasses must implement :meth:`log`. The optional methods + :meth:`log_hyperparams` and :meth:`cleanup` have reasonable + no-op defaults. + """ + + @abstractmethod + def log(self, data: Dict[str, float], step: int) -> None: + """Log a set of metric values. + + Args: + data: Metric names mapped to numeric values. The dict has + already been normalised by :func:`clean_metrics` so + values are guaranteed to be ``float``. + step: The current training step. + """ + + def log_hyperparams(self, params: Dict[str, Any]) -> None: + """Record hyperparameters (optional).""" + + def cleanup(self) -> None: + """Flush pending data and release resources (optional).""" diff --git a/src/twinkle/tracker/swanlab.py b/src/twinkle/tracker/swanlab.py new file mode 100644 index 00000000..ae1790a5 --- /dev/null +++ b/src/twinkle/tracker/swanlab.py @@ -0,0 +1,75 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""SwanLab experiment tracker.""" + +import json +import logging +import os +from pathlib import Path +from typing import Any, Dict, Optional + +from .base import ExperimentTracker + +logger = logging.getLogger(__name__) + + +class SwanLabTracker(ExperimentTracker): + """Experiment tracker backed by `SwanLab `_. + + Args: + project: SwanLab project name. + experiment_name: Optional run / experiment name. + config: Optional dict of hyperparameters to record. + output_dir: If set, the SwanLab experiment URL is written to + ``{output_dir}/swanlab_config.json`` so users can easily + find the online dashboard. + **kwargs: Passed through to ``swanlab.init()``. + """ + + def __init__( + self, + project: str, + experiment_name: Optional[str] = None, + config: Optional[Dict[str, Any]] = None, + output_dir: Optional[str] = None, + **kwargs, + ): + import swanlab + + api_key = kwargs.pop("api_key", None) or os.environ.get("SWANLAB_API_KEY") + logdir = kwargs.pop("logdir", None) or os.environ.get("SWANLAB_LOG_DIR", "swanlog") + mode = kwargs.pop("mode", None) or os.environ.get("SWANLAB_MODE", "cloud") + + if api_key: + swanlab.login(api_key) + + self._run = swanlab.init( + project=project, + experiment_name=experiment_name, + config={"framework": "\u2728Twinkle", **(config or {})}, + logdir=logdir, + mode=mode, + **kwargs, + ) + + if output_dir: + self._save_experiment_info(output_dir) + + def log(self, data: Dict[str, float], step: int) -> None: + self._run.log(data, step=step) + + def log_hyperparams(self, params: Dict[str, Any]) -> None: + self._run.config.update(params) + + def cleanup(self) -> None: + try: + self._run.finish() + except Exception: + logger.warning("SwanLab finish() failed", exc_info=True) + + def _save_experiment_info(self, output_dir: str) -> None: + try: + info = {"swanlab_experiment_url": self._run.get_run().url} + out = Path(output_dir) / "swanlab_config.json" + out.write_text(json.dumps(info, indent=2)) + except Exception: + pass diff --git a/src/twinkle/tracker/wandb.py b/src/twinkle/tracker/wandb.py new file mode 100644 index 00000000..8a65765e --- /dev/null +++ b/src/twinkle/tracker/wandb.py @@ -0,0 +1,57 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Weights & Biases experiment tracker.""" + +import logging +import os +from typing import Any, Dict, Optional + +from .base import ExperimentTracker + +logger = logging.getLogger(__name__) + + +class WandbTracker(ExperimentTracker): + """Experiment tracker backed by `Weights & Biases `_. + + Args: + project: W&B project name. + experiment_name: Optional run name. + config: Optional dict of hyperparameters. + **kwargs: Passed through to ``wandb.init()``. + """ + + def __init__( + self, + project: str, + experiment_name: Optional[str] = None, + config: Optional[Dict[str, Any]] = None, + **kwargs, + ): + import wandb + + entity = kwargs.pop("entity", None) or os.environ.get("WANDB_ENTITY") + settings = None + proxy = kwargs.pop("wandb_proxy", None) or os.environ.get("WANDB_PROXY") + if proxy: + settings = wandb.Settings(https_proxy=proxy) + + self._run = wandb.init( + project=project, + name=experiment_name, + entity=entity, + config={"framework": "\u2728Twinkle", **(config or {})}, + settings=settings, + **kwargs, + ) + + def log(self, data: Dict[str, float], step: int) -> None: + self._run.log(data, step=step) + + def log_hyperparams(self, params: Dict[str, Any]) -> None: + self._run.config.update(params) + + def cleanup(self) -> None: + try: + self._run.finish(exit_code=0) + except Exception: + logger.warning("WandB finish() failed", exc_info=True) From 3b313401ca5f32731d5165a45607fbc51137925f Mon Sep 17 00:00:00 2001 From: tpx Date: Tue, 2 Jun 2026 17:17:20 +0800 Subject: [PATCH 2/8] fix swanlab --- cookbook/client/tinker/modelscope/dpo.py | 18 ++-- cookbook/client/tinker/self_host/dpo.py | 18 ++-- .../twinkle/self_host/short_math_grpo.py | 18 ++-- cookbook/rl/short_math_grpo_multi_lora.py | 17 ++-- cookbook/transformers/tracker.py | 84 +++++++++++++++++++ 5 files changed, 121 insertions(+), 34 deletions(-) create mode 100644 cookbook/transformers/tracker.py diff --git a/cookbook/client/tinker/modelscope/dpo.py b/cookbook/client/tinker/modelscope/dpo.py index 23cf5aae..c092cef3 100644 --- a/cookbook/client/tinker/modelscope/dpo.py +++ b/cookbook/client/tinker/modelscope/dpo.py @@ -20,9 +20,9 @@ from tqdm import tqdm from typing import Any, Dict, List -import swanlab - from tinker import types +from twinkle.tracker import register_tracker, dispatch +from twinkle.tracker.swanlab import SwanLabTracker from twinkle import init_tinker_client, get_logger from twinkle.dataset import Dataset, DatasetMeta, LazyDataset from twinkle.dataloader import DataLoader @@ -96,10 +96,9 @@ def prepare_dpo_batch(batch: List[Dict[str, Any]]) -> List[Dict[str, Any]]: # --------------------------------------------------------------------------- def train(): - # Step 0: Initialize SwanLab if enabled + # Step 0: Register tracker if enabled if use_swanlab: - swanlab.login(api_key=os.environ['SWANLAB_API_KEY']) - swanlab.init( + register_tracker(SwanLabTracker( project='twinkle-dpo', experiment_name='dpo-lora-training', config={ @@ -111,8 +110,9 @@ def train(): 'max_length': max_length, 'lora_rank': lora_rank, }, - ) - logger.info('SwanLab initialized') + api_key=os.environ.get('SWANLAB_API_KEY'), + )) + logger.info('SwanLabTracker registered') # Step 1: Prepare dataset & dataloader logger.info('Loading DPO dataset...') @@ -188,9 +188,9 @@ def train(): logger.info(f'[Step {step}] metrics={optim_result.metrics}') - # Log metrics to SwanLab + # Dispatch metrics to registered trackers if use_swanlab and optim_result.metrics: - swanlab.log(optim_result.metrics, step=step) + dispatch(optim_result.metrics, step=step) # Step 4: Save checkpoint save_result = training_client.save_state('dpo-lora-final').result() diff --git a/cookbook/client/tinker/self_host/dpo.py b/cookbook/client/tinker/self_host/dpo.py index 51474ca0..b2cb535f 100644 --- a/cookbook/client/tinker/self_host/dpo.py +++ b/cookbook/client/tinker/self_host/dpo.py @@ -20,9 +20,9 @@ from tqdm import tqdm from typing import Any, Dict, List -import swanlab - from tinker import types +from twinkle.tracker import register_tracker, dispatch +from twinkle.tracker.swanlab import SwanLabTracker from twinkle import init_tinker_client, get_logger from twinkle.dataset import Dataset, DatasetMeta, LazyDataset from twinkle.dataloader import DataLoader @@ -96,10 +96,9 @@ def prepare_dpo_batch(batch: List[Dict[str, Any]]) -> List[Dict[str, Any]]: # --------------------------------------------------------------------------- def train(): - # Step 0: Initialize SwanLab if enabled + # Step 0: Register tracker if enabled if use_swanlab: - swanlab.login(api_key=os.environ['SWANLAB_API_KEY']) - swanlab.init( + register_tracker(SwanLabTracker( project='twinkle-dpo', experiment_name='dpo-lora-training', config={ @@ -111,8 +110,9 @@ def train(): 'max_length': max_length, 'lora_rank': lora_rank, }, - ) - logger.info('SwanLab initialized') + api_key=os.environ.get('SWANLAB_API_KEY'), + )) + logger.info('SwanLabTracker registered') # Step 1: Prepare dataset & dataloader logger.info('Loading DPO dataset...') @@ -188,9 +188,9 @@ def train(): logger.info(f'[Step {step}] metrics={optim_result.metrics}') - # Log metrics to SwanLab + # Dispatch metrics to registered trackers if use_swanlab and optim_result.metrics: - swanlab.log(optim_result.metrics, step=step) + dispatch(optim_result.metrics, step=step) # Step 4: Save checkpoint save_result = training_client.save_state('dpo-lora-final').result() diff --git a/cookbook/client/twinkle/self_host/short_math_grpo.py b/cookbook/client/twinkle/self_host/short_math_grpo.py index 871d4599..82dcf2d5 100644 --- a/cookbook/client/twinkle/self_host/short_math_grpo.py +++ b/cookbook/client/twinkle/self_host/short_math_grpo.py @@ -29,9 +29,9 @@ from peft import LoraConfig from typing import List, Tuple, Dict, Any -import swanlab - from twinkle import get_logger +from twinkle.tracker import register_tracker, dispatch +from twinkle.tracker.swanlab import SwanLabTracker from twinkle.reward import GSM8KAccuracyReward from twinkle.reward.base import Reward from twinkle.advantage import GRPOAdvantage @@ -119,10 +119,9 @@ def compute_rewards( def train(): - # Step 0: Initialize SwanLab if enabled + # Step 0: Register tracker if enabled if USE_SWANLAB: - swanlab.login(api_key=os.environ.get('SWANLAB_API_KEY', '')) - swanlab.init( + register_tracker(SwanLabTracker( project=SWANLAB_PROJECT, experiment_name=SWANLAB_EXPERIMENT_NAME, config={ @@ -136,8 +135,9 @@ def train(): 'sync_interval': SYNC_INTERVAL, 'gradient_accumulation_steps': GRADIENT_ACCUMULATION_STEPS, }, - ) - logger.info('SwanLab initialized') + api_key=os.environ.get('SWANLAB_API_KEY', ''), + )) + logger.info('SwanLabTracker registered') # Step 1: Initialize the Twinkle client client = init_twinkle_client( @@ -286,9 +286,9 @@ def train(): log_dict['train/frac_reward_zero_std'] = frac_zero_std logger.info(f'Step {step}: {log_dict}') - # Log metrics to SwanLab + # Dispatch metrics to registered trackers if USE_SWANLAB and log_dict: - swanlab.log(log_dict, step=step) + dispatch(log_dict, step=step) step += 1 metrics.reset() diff --git a/cookbook/rl/short_math_grpo_multi_lora.py b/cookbook/rl/short_math_grpo_multi_lora.py index 9dad8df3..811efae8 100644 --- a/cookbook/rl/short_math_grpo_multi_lora.py +++ b/cookbook/rl/short_math_grpo_multi_lora.py @@ -31,6 +31,8 @@ from twinkle.reward.base import Reward from twinkle.sampler import vLLMSampler from twinkle.preprocessor.llm import GSM8KProcessor +from twinkle.tracker import register_tracker, dispatch +from twinkle.tracker.swanlab import SwanLabTracker logger = get_logger() @@ -59,12 +61,6 @@ SYSTEM_PROMPT = ('You are a helpful math assistant. Solve the problem with minimal but correct reasoning ' 'and put your final answer within \\boxed{}.') -import swanlab -swanlab.init( - project='twinkle', -) - - # ========== Reward Functions ========== class GSM8KBrevityReward(Reward): """Brevity reward: rewards shorter completions that contain a valid answer. @@ -122,6 +118,11 @@ def compute_rewards( # ========== Main ========== def main(): + # Register SwanLab tracker + register_tracker(SwanLabTracker( + project='twinkle', + )) + # Device groups: 8 GPUs for model (tp=2 x ep=2 x pp=2), 4 GPUs for sampler (dp=2 x tp=2) device_groups = [ DeviceGroup(name='model', ranks=list(range(MODEL_GPUS)), device_type='GPU'), @@ -292,7 +293,9 @@ def main(): log_dict = metrics.calculate() log_dict.update(model.calculate_metric(is_training=True, adapter_name=ADAPTER_NAME)) - swanlab.log(log_dict) + # model.calculate_metric() already dispatches model metrics internally; + # this dispatch sends the full merged set for reward coverage. + dispatch(log_dict, step=optim_step) metrics.reset() logger.info(f'[Step {optim_step}/{MAX_STEPS}] {log_dict}') diff --git a/cookbook/transformers/tracker.py b/cookbook/transformers/tracker.py new file mode 100644 index 00000000..e134e5bc --- /dev/null +++ b/cookbook/transformers/tracker.py @@ -0,0 +1,84 @@ +import os +from peft import LoraConfig +import twinkle +from twinkle import get_logger +from twinkle.dataloader import DataLoader +from twinkle.dataset import Dataset, DatasetMeta +from twinkle.model import TransformersModel +from twinkle.preprocessor import SelfCognitionProcessor +from twinkle.tracker import register_tracker, list_trackers, clear_trackers +logger = get_logger() +# ── Configuration ────────────────────────────────────────────────────────────── +MODEL_ID = "ms://Qwen/Qwen2.5-0.5B-Instruct" +DATASET_ID = "ms://swift/self-cognition" +TEMPLATE_NAME = "Template" +BATCH_SIZE = 1 +LEARNING_RATE = 1e-4 +TRAIN_STEPS = 5 +# ── Tracker selection ────────────────────────────────────────────────────────── +def setup_tracker(): + """Register either SwanLabTracker (if API key available) or PrintTracker.""" + if os.environ.get("SWANLAB_API_KEY"): + from twinkle.tracker.swanlab import SwanLabTracker + tracker = SwanLabTracker( + project="twinkle-test", + experiment_name="tracker-integration-test", + config={"model": MODEL_ID, "lr": LEARNING_RATE, "steps": TRAIN_STEPS}, + output_dir="./test_tracker_output", + ) + register_tracker(tracker) + logger.info("SwanLabTracker registered — project=twinkle-test") + return tracker + else: + from twinkle.tracker import ExperimentTracker + class PrintTracker(ExperimentTracker): + def __init__(self): + self.logged: list[tuple[int, dict]] = [] + def log(self, data: dict, step: int) -> None: + self.logged.append((step, data)) + logger.info("[PrintTracker] step=%s metrics=%s", step, data) + def cleanup(self) -> None: + logger.info("[PrintTracker] cleanup — %s dispatches", len(self.logged)) + tracker = PrintTracker() + register_tracker(tracker) + logger.info("PrintTracker registered (set SWANLAB_API_KEY for SwanLab)") + return tracker +# ── Main ────────────────────────────────────────────────────────────────────── +def main(): + twinkle.initialize(mode="local", seed=42) + tracker = setup_tracker() + assert len(list_trackers()) == 1 + logger.info("Tracker ready: %s", type(tracker).__name__) + dataset = Dataset(dataset_meta=DatasetMeta(DATASET_ID, data_slice=range(10))) + dataset.set_template(TEMPLATE_NAME, model_id=MODEL_ID) + dataset.map(SelfCognitionProcessor("test_model", "test_author")) + dataset.encode() + dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE) + model = TransformersModel(model_id=MODEL_ID) + lora_config = LoraConfig(r=8, lora_alpha=32, target_modules="all-linear") + model.add_adapter_to_model("default", lora_config, gradient_accumulation_steps=1) + model.set_optimizer(optimizer_cls="AdamW", lr=LEARNING_RATE) + model.set_lr_scheduler( + scheduler_cls="CosineWarmupScheduler", num_warmup_steps=1, num_training_steps=TRAIN_STEPS + ) + for step, batch in enumerate(dataloader): + if step >= TRAIN_STEPS: + break + model.forward_backward(inputs=batch) + model.clip_grad_and_step() + metric = model.calculate_metric(is_training=True) + logger.info("Step %s raw metric: %s", step + 1, metric) + # Verification (only works for PrintTracker) + if hasattr(tracker, "logged"): + n = len(tracker.logged) + assert n > 0, "No metrics were dispatched — dispatch() not called" + logger.info("=== Dispatch verification ===") + logger.info("Total dispatches: %s", n) + for i, (step, data) in enumerate(tracker.logged): + all_floats = all(isinstance(v, float) for v in data.values()) + logger.info(" [%s] step=%s keys=%s all_float=%s", i + 1, step, list(data.keys()), all_floats) + clear_trackers() + assert len(list_trackers()) == 0 + logger.info("=== Test complete ===") +if __name__ == "__main__": + main() \ No newline at end of file From 34a5817ec7ed4641ff791cc6bb217e6a4a687be6 Mon Sep 17 00:00:00 2001 From: tpx Date: Tue, 2 Jun 2026 17:44:25 +0800 Subject: [PATCH 3/8] add tracker unit tests Cover SwanLabTracker behavioral tests (23 cases) and dispatch system tests (38 cases): register_tracker, dispatch, dispatch_hyperparams, clear_trackers, set_rank, auto_init_from_env. Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent) Co-authored-by: Sisyphus --- tests/tracker/test_dispatch.py | 500 +++++++++++++++++++++++++++++++++ tests/tracker/test_swanlab.py | 326 +++++++++++++++++++++ 2 files changed, 826 insertions(+) create mode 100644 tests/tracker/test_dispatch.py create mode 100644 tests/tracker/test_swanlab.py diff --git a/tests/tracker/test_dispatch.py b/tests/tracker/test_dispatch.py new file mode 100644 index 00000000..6aa8a0a9 --- /dev/null +++ b/tests/tracker/test_dispatch.py @@ -0,0 +1,500 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Tests for the dispatch system (``twinkle.tracker.__init__``). + +Covers ``register_tracker``, ``dispatch``, ``dispatch_hyperparams``, +``clear_trackers``, ``set_rank``, and ``_auto_init_from_env``. +""" + +import logging +import os +import sys +import pytest +from unittest.mock import MagicMock, patch + +# --------------------------------------------------------------------------- +# Module-level dependency mocks (mirrors test_swanlab.py) +# --------------------------------------------------------------------------- +for _mod in [ + "datasets", + "datasets.utils", + "datasets.utils.filelock", + "torch", + "accelerate", + "transformers", + "peft", + "omegaconf", + "modelscope", + "safetensors", + "fastapi", + "tinker", + "PIL", + "PIL.Image", + "wandb", +]: + sys.modules.setdefault(_mod, MagicMock()) + +sys.modules.setdefault("twinkle.server", MagicMock()) +sys.modules.setdefault("twinkle.server.model", MagicMock()) +sys.modules.setdefault("twinkle.server.model.backends", MagicMock()) +_common = MagicMock() +_common.clean_metrics = lambda d: {k: float(v) for k, v in d.items() if isinstance(v, (int, float))} +sys.modules["twinkle.server.model.backends.common"] = _common + +sys.modules.setdefault("twinkle.utils.platforms", MagicMock()) +sys.modules.setdefault("twinkle.utils.logger", MagicMock()) +sys.modules.setdefault("swanlab", MagicMock()) + +# Now safe to import +from twinkle.tracker.base import ExperimentTracker +import twinkle.tracker as tracker_mod +from twinkle.tracker import ( + register_tracker, + dispatch, + dispatch_hyperparams, + clear_trackers, + list_trackers, + set_rank, +) + + +# --------------------------------------------------------------------------- +# Spy tracker +# --------------------------------------------------------------------------- +class SpyTracker(ExperimentTracker): + """Minimal tracker that records all calls for later assertion.""" + + def __init__(self, name: str = "spy"): + self.name = name + self.reset() + + def reset(self): + self.logged: list[tuple[dict, int]] = [] + self.hyperparams: list[dict] = [] + self.cleanup_called = False + + def log(self, data: dict, step: int) -> None: + self.logged.append((dict(data), step)) + + def log_hyperparams(self, params: dict) -> None: + self.hyperparams.append(dict(params)) + + def cleanup(self) -> None: + self.cleanup_called = True + + def __repr__(self): + return f"SpyTracker({self.name})" + + +class ErrorTracker(ExperimentTracker): + """Tracker whose ``log()`` raises — used to test exception isolation.""" + + def log(self, data: dict, step: int) -> None: + raise RuntimeError("tracker error") + + def log_hyperparams(self, params: dict) -> None: + raise RuntimeError("hparam error") + + def cleanup(self) -> None: + raise RuntimeError("cleanup error") + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +@pytest.fixture(autouse=True) +def _reset_global_state(): + """Reset module-level state before every test.""" + tracker_mod._trackers.clear() + tracker_mod._rank = 0 + tracker_mod._hparams_dispatched.clear() + yield + + +# =================================================================== +# register_tracker / list_trackers +# =================================================================== + +class TestRegistration: + def test_register_one(self): + t = SpyTracker() + register_tracker(t) + assert list_trackers() == [t] + + def test_register_multiple(self): + t1, t2 = SpyTracker("a"), SpyTracker("b") + register_tracker(t1) + register_tracker(t2) + assert list_trackers() == [t1, t2] + + def test_register_returns_none(self): + ret = register_tracker(SpyTracker()) + assert ret is None + + def test_list_trackers_returns_copy(self): + """list_trackers should return a copy, not the internal list.""" + t = SpyTracker() + register_tracker(t) + snapshot = list_trackers() + clear_trackers() + # Snapshot should still have the original reference + assert t in snapshot + + +# =================================================================== +# dispatch +# =================================================================== + +class TestDispatch: + def test_sends_to_all_trackers(self): + t1, t2 = SpyTracker("a"), SpyTracker("b") + register_tracker(t1) + register_tracker(t2) + set_rank(0) + + dispatch({"loss": 0.5, "acc": 0.95}, step=10) + + assert t1.logged == [({"loss": 0.5, "acc": 0.95}, 10)] + assert t2.logged == [({"loss": 0.5, "acc": 0.95}, 10)] + + def test_skipped_on_non_zero_rank(self): + t = SpyTracker() + register_tracker(t) + set_rank(3) + + dispatch({"loss": 0.5}, step=1) + + assert t.logged == [] + + def test_no_trackers_is_noop(self): + """dispatch with no registered trackers should not crash.""" + dispatch({"loss": 0.5}, step=1) # no assert — must not raise + + def test_rank_0_is_default(self): + """Default rank is 0, so dispatch works without explicit set_rank.""" + t = SpyTracker() + register_tracker(t) + dispatch({"loss": 0.5}, step=1) + assert len(t.logged) == 1 + + def test_empty_data_after_clean_metrics_skips(self): + """dispatch returns early when clean_metrics returns empty dict.""" + t = SpyTracker() + register_tracker(t) + set_rank(0) + + # Values that clean_metrics cannot convert to float + dispatch({"invalid": [1, 2, 3], "text": "not-a-number"}, step=5) + + assert t.logged == [] + + def test_exception_isolation(self): + """One tracker raising does not prevent others from receiving.""" + good = SpyTracker("good") + bad = ErrorTracker() + register_tracker(good) + register_tracker(bad) + + dispatch({"loss": 0.5}, step=10) + + assert len(good.logged) == 1 + assert good.logged[0] == ({"loss": 0.5}, 10) + + def test_exception_isolation_reverse_order(self): + """Exception isolation works regardless of tracker order.""" + bad = ErrorTracker() + good = SpyTracker("good") + register_tracker(bad) + register_tracker(good) + + dispatch({"loss": 0.5}, step=10) + + assert len(good.logged) == 1 + + def test_multiple_steps(self): + t = SpyTracker() + register_tracker(t) + set_rank(0) + + dispatch({"loss": 0.5}, step=1) + dispatch({"loss": 0.3}, step=2) + dispatch({"loss": 0.1}, step=3) + + assert len(t.logged) == 3 + assert t.logged[0] == ({"loss": 0.5}, 1) + assert t.logged[1] == ({"loss": 0.3}, 2) + assert t.logged[2] == ({"loss": 0.1}, 3) + + def test_rank_change_during_runtime(self): + """Changing rank mid-training affects subsequent dispatches.""" + t = SpyTracker() + register_tracker(t) + + set_rank(0) + dispatch({"loss": 0.5}, step=1) + assert len(t.logged) == 1 + + set_rank(1) + dispatch({"loss": 0.3}, step=2) + assert len(t.logged) == 1 # no change — rank 1 skipped + + set_rank(0) + dispatch({"loss": 0.1}, step=3) + assert len(t.logged) == 2 # now rank 0 again + + +# =================================================================== +# dispatch_hyperparams +# =================================================================== + +class TestDispatchHyperparams: + def test_sends_to_all_trackers(self): + t1, t2 = SpyTracker("a"), SpyTracker("b") + register_tracker(t1) + register_tracker(t2) + set_rank(0) + + dispatch_hyperparams({"lr": 1e-4}) + + assert t1.hyperparams == [{"lr": 1e-4}] + assert t2.hyperparams == [{"lr": 1e-4}] + + def test_idempotent_with_adapter_name(self): + """Same adapter_name only dispatches once.""" + t = SpyTracker() + register_tracker(t) + set_rank(0) + + dispatch_hyperparams({"lr": 1e-4}, adapter_name="default") + dispatch_hyperparams({"lr": 2e-4}, adapter_name="default") # ignored + dispatch_hyperparams({"batch_size": 32}, adapter_name="default") # ignored + + assert len(t.hyperparams) == 1 + assert t.hyperparams[0] == {"lr": 1e-4} + + def test_different_adapters_separate(self): + """Different adapter_names are each dispatched once.""" + t = SpyTracker() + register_tracker(t) + set_rank(0) + + dispatch_hyperparams({"lr": 1e-4}, adapter_name="lora_a") + dispatch_hyperparams({"lr": 2e-4}, adapter_name="lora_b") + dispatch_hyperparams({"lr": 3e-4}, adapter_name="lora_a") # ignored + + assert len(t.hyperparams) == 2 + assert t.hyperparams[0] == {"lr": 1e-4} + assert t.hyperparams[1] == {"lr": 2e-4} + + def test_without_adapter_sends_every_time(self): + """When adapter_name is None, every call dispatches.""" + t = SpyTracker() + register_tracker(t) + set_rank(0) + + dispatch_hyperparams({"lr": 1e-4}) + dispatch_hyperparams({"lr": 2e-4}) + dispatch_hyperparams({"lr": 3e-4}) + + assert len(t.hyperparams) == 3 + + def test_mixed_adapter_and_no_adapter(self): + """Calls with and without adapter_name interact correctly.""" + t = SpyTracker() + register_tracker(t) + set_rank(0) + + dispatch_hyperparams({"a": 1}, adapter_name="adp") # sent + dispatch_hyperparams({"b": 2}) # sent (no adapter) + dispatch_hyperparams({"c": 3}, adapter_name="adp") # ignored (idempotent) + dispatch_hyperparams({"d": 4}) # sent (no adapter again) + + assert len(t.hyperparams) == 3 + + def test_skipped_on_non_zero_rank(self): + t = SpyTracker() + register_tracker(t) + set_rank(2) + + dispatch_hyperparams({"lr": 1e-4}) + + assert t.hyperparams == [] + + def test_no_trackers_is_noop(self): + dispatch_hyperparams({"lr": 1e-4}, adapter_name="test") + + def test_exception_isolation(self): + good = SpyTracker("good") + bad = ErrorTracker() + register_tracker(good) + register_tracker(bad) + + dispatch_hyperparams({"lr": 1e-4}) + + assert len(good.hyperparams) == 1 + + +# =================================================================== +# clear_trackers +# =================================================================== + +class TestClearTrackers: + def test_calls_cleanup_on_all(self): + t1, t2 = SpyTracker("a"), SpyTracker("b") + register_tracker(t1) + register_tracker(t2) + + clear_trackers() + + assert t1.cleanup_called + assert t2.cleanup_called + assert list_trackers() == [] + + def test_cleanup_exception_isolation(self): + """cleanup() raising on one tracker doesn't break others.""" + bad = ErrorTracker() + good = SpyTracker("good") + register_tracker(bad) + register_tracker(good) + + clear_trackers() # must not raise + + assert good.cleanup_called + assert list_trackers() == [] + + def test_empty_list_is_noop(self): + clear_trackers() # must not raise + assert list_trackers() == [] + + def test_idempotent(self): + """Calling clear_trackers twice is safe.""" + t = SpyTracker() + register_tracker(t) + clear_trackers() + clear_trackers() + assert list_trackers() == [] + + +# =================================================================== +# set_rank +# =================================================================== + +class TestSetRank: + def test_default_rank_is_zero(self): + assert tracker_mod._rank == 0 # after fixture reset + + def test_set_rank_changes_global(self): + set_rank(3) + assert tracker_mod._rank == 3 + + def test_set_rank_zero(self): + set_rank(0) + assert tracker_mod._rank == 0 + + def test_set_rank_negative(self): + """Negative rank values are stored as-is (caller responsibility).""" + set_rank(-1) + # A negative rank will cause dispatch to skip (since rank != 0) + t = SpyTracker() + register_tracker(t) + dispatch({"loss": 0.5}, step=1) + assert t.logged == [] + + +# =================================================================== +# _auto_init_from_env +# =================================================================== + +class TestAutoInitFromEnv: + """Environment-variable auto-initialisation.""" + + def _reset_auto_init(self): + """Allow _auto_init_from_env to run again.""" + tracker_mod._AUTO_INIT_DONE = False + tracker_mod._trackers.clear() + + def test_env_empty_is_noop(self): + """No TWINKLE_TRACKERS → nothing registered.""" + self._reset_auto_init() + with patch.dict(os.environ, {}, clear=True): + tracker_mod._auto_init_from_env() + assert list_trackers() == [] + + def test_env_swanlab_registers_tracker(self): + """TWINKLE_TRACKERS=swanlab registers a SwanLabTracker.""" + self._reset_auto_init() + with patch.dict(os.environ, {"TWINKLE_TRACKERS": "swanlab"}, clear=True): + tracker_mod._auto_init_from_env() + trackers = list_trackers() + assert len(trackers) == 1 + from twinkle.tracker.swanlab import SwanLabTracker + assert isinstance(trackers[0], SwanLabTracker) + + def test_env_wandb_registers_tracker(self): + """TWINKLE_TRACKERS=wandb registers a WandbTracker.""" + self._reset_auto_init() + with patch.dict(os.environ, {"TWINKLE_TRACKERS": "wandb"}, clear=True): + tracker_mod._auto_init_from_env() + trackers = list_trackers() + assert len(trackers) == 1 + from twinkle.tracker.wandb import WandbTracker + assert isinstance(trackers[0], WandbTracker) + + def test_env_both_registers_both(self): + """TWINKLE_TRACKERS=swanlab,wandb registers both.""" + self._reset_auto_init() + with patch.dict(os.environ, {"TWINKLE_TRACKERS": "swanlab,wandb"}, clear=True): + tracker_mod._auto_init_from_env() + assert len(list_trackers()) == 2 + + def test_env_unknown_logs_warning(self, caplog): + """Unknown tracker name logs a warning.""" + self._reset_auto_init() + caplog.set_level(logging.WARNING) + with patch.dict(os.environ, {"TWINKLE_TRACKERS": "unknown"}, clear=True): + tracker_mod._auto_init_from_env() + assert "Unknown tracker backend in TWINKLE_TRACKERS: unknown" in caplog.text + assert list_trackers() == [] + + def test_env_project_and_experiment(self): + """TWINKLE_TRACKER_PROJECT and _EXPERIMENT env vars are used.""" + self._reset_auto_init() + with patch.dict(os.environ, { + "TWINKLE_TRACKERS": "swanlab", + "TWINKLE_TRACKER_PROJECT": "my-project", + "TWINKLE_TRACKER_EXPERIMENT": "my-exp", + }, clear=True): + tracker_mod._auto_init_from_env() + trackers = list_trackers() + assert len(trackers) == 1 + # The swanlab.init mock was called with these values + import swanlab + swanlab.init.assert_called() + + def test_auto_init_guard(self): + """_AUTO_INIT_DONE prevents re-initialisation.""" + self._reset_auto_init() + tracker_mod._AUTO_INIT_DONE = True # simulate already done + with patch.dict(os.environ, {"TWINKLE_TRACKERS": "swanlab"}, clear=True): + tracker_mod._auto_init_from_env() + # If the guard worked, no trackers were added + assert list_trackers() == [] + + def test_auto_init_exception_does_not_crash(self): + """An exception during tracker construction is caught.""" + self._reset_auto_init() + + # Make SwanLabTracker constructor raise by removing swanlab mock + with patch.dict(os.environ, {"TWINKLE_TRACKERS": "swanlab"}, clear=True): + # This will call SwanLabTracker(project=..., ...) which does + # import swanlab; swanlab.init(...). Our mock will not crash. + tracker_mod._auto_init_from_env() + # Should have one tracker if successful + assert len(list_trackers()) == 1 + + def test_env_whitespace_handling(self): + """Extra whitespace in TWINKLE_TRACKERS is tolerated.""" + self._reset_auto_init() + with patch.dict(os.environ, {"TWINKLE_TRACKERS": " swanlab , wandb "}, clear=True): + tracker_mod._auto_init_from_env() + assert len(list_trackers()) == 2 diff --git a/tests/tracker/test_swanlab.py b/tests/tracker/test_swanlab.py new file mode 100644 index 00000000..047e3e91 --- /dev/null +++ b/tests/tracker/test_swanlab.py @@ -0,0 +1,326 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Tests for SwanLabTracker. + +These tests mock the ``swanlab`` package so they can run without a real +SwanLab installation or API key. Each test verifies that the tracker +delegates correctly to the underlying ``swanlab`` SDK. +""" + +import json +import logging +import os +import sys +import pytest +from unittest.mock import MagicMock, patch + +# --------------------------------------------------------------------------- +# Module-level dependency mocks. +# +# Importing ``twinkle.tracker.swanlab`` triggers a package-init chain that +# pulls in heavyweight third-party libraries (datasets, torch, …). We mock +# them here so the tests can run without the full dependency tree installed. +# --------------------------------------------------------------------------- +for _mod in [ + "datasets", + "datasets.utils", + "datasets.utils.filelock", + "torch", + "accelerate", + "transformers", + "peft", + "omegaconf", + "modelscope", + "safetensors", + "fastapi", + "tinker", + "PIL", + "PIL.Image", +]: + sys.modules.setdefault(_mod, MagicMock()) + +# twinkle.server.model.backends.common is imported by twinkle.tracker itself +sys.modules.setdefault("twinkle.server", MagicMock()) +sys.modules.setdefault("twinkle.server.model", MagicMock()) +sys.modules.setdefault("twinkle.server.model.backends", MagicMock()) +_common = MagicMock() +_common.clean_metrics = lambda d, **kw: {k: float(v) for k, v in d.items() if isinstance(v, (int, float))} +sys.modules["twinkle.server.model.backends.common"] = _common + +# Intermediate twinkle sub-packages that the init chain touches +sys.modules.setdefault("twinkle.utils.platforms", MagicMock()) +sys.modules.setdefault("twinkle.utils.logger", MagicMock()) + +# Mock swanlab itself so that ``import swanlab`` inside SwanLabTracker +# resolves to a mock rather than trying to import the real package. +sys.modules.setdefault("swanlab", MagicMock()) + +# Now that all heavy deps are mocked, the import should succeed. +from twinkle.tracker.swanlab import SwanLabTracker + + +# =================================================================== +# Helpers +# =================================================================== + +@pytest.fixture(autouse=True) +def _reset_swanlab_mock(): + """Reset the swanlab mock before each test so call counts are clean.""" + swanlab_mock = sys.modules["swanlab"] + swanlab_mock.reset_mock() + swanlab_mock.init.return_value = MagicMock() + yield + + +def _mock_swanlab(): + """Shortcut to access the module-level swanlab mock.""" + return sys.modules["swanlab"] + + +def _mock_run(): + """Shortcut to access the run mock returned by swanlab.init().""" + return _mock_swanlab().init.return_value + + +# =================================================================== +# __init__ — construction & parameter routing +# =================================================================== + +class TestInit: + """SwanLabTracker.__init__ parameter handling.""" + + def test_defaults(self): + """Default logdir and mode when neither kwarg nor env var is set.""" + SwanLabTracker(project="test-project") + _mock_swanlab().init.assert_called_once_with( + project="test-project", + experiment_name=None, + config={"framework": "\u2728Twinkle"}, + logdir="swanlog", + mode="cloud", + ) + _mock_swanlab().login.assert_not_called() + + def test_with_api_key_kwarg(self): + """api_key kwarg triggers swanlab.login() before init.""" + SwanLabTracker(project="test-project", api_key="key-123") + _mock_swanlab().login.assert_called_once_with("key-123") + + def test_with_api_key_from_env(self): + """SWANLAB_API_KEY env var triggers login when api_key kwarg absent.""" + with patch.dict(os.environ, {"SWANLAB_API_KEY": "env-key"}): + SwanLabTracker(project="test-project") + _mock_swanlab().login.assert_called_once_with("env-key") + + def test_api_key_kwarg_precedence(self): + """api_key kwarg takes precedence over SWANLAB_API_KEY env var.""" + with patch.dict(os.environ, {"SWANLAB_API_KEY": "env-key"}): + SwanLabTracker(project="test-project", api_key="kwarg-key") + _mock_swanlab().login.assert_called_once_with("kwarg-key") + + def test_experiment_name_and_config(self): + """experiment_name and config are forwarded to swanlab.init.""" + SwanLabTracker( + project="test-project", + experiment_name="my-exp", + config={"lr": 1e-4, "batch_size": 32}, + ) + _mock_swanlab().init.assert_called_once_with( + project="test-project", + experiment_name="my-exp", + config={"framework": "\u2728Twinkle", "lr": 1e-4, "batch_size": 32}, + logdir="swanlog", + mode="cloud", + ) + + def test_logdir_and_mode_kwargs(self): + """Explicit logdir/mode override both defaults and env vars.""" + with patch.dict(os.environ, {"SWANLAB_LOG_DIR": "env_logs", "SWANLAB_MODE": "cloud"}): + SwanLabTracker(project="test-project", logdir="my_logs", mode="local") + _mock_swanlab().init.assert_called_once_with( + project="test-project", + experiment_name=None, + config={"framework": "\u2728Twinkle"}, + logdir="my_logs", + mode="local", + ) + + def test_logdir_from_env(self): + """SWANLAB_LOG_DIR env var is used when no logdir kwarg.""" + with patch.dict(os.environ, {"SWANLAB_LOG_DIR": "env_logs"}): + SwanLabTracker(project="test-project") + _mock_swanlab().init.assert_called_once_with( + project="test-project", + experiment_name=None, + config={"framework": "\u2728Twinkle"}, + logdir="env_logs", + mode="cloud", + ) + + def test_mode_from_env(self): + """SWANLAB_MODE env var is used when no mode kwarg.""" + with patch.dict(os.environ, {"SWANLAB_MODE": "local"}): + SwanLabTracker(project="test-project") + _mock_swanlab().init.assert_called_once_with( + project="test-project", + experiment_name=None, + config={"framework": "\u2728Twinkle"}, + logdir="swanlog", + mode="local", + ) + + def test_output_dir_writes_info_file(self, tmp_path): + """output_dir causes experiment URL to be saved as JSON.""" + _mock_run().get_run.return_value.url = "https://swanlab.cn/foo/bar" + SwanLabTracker(project="test", output_dir=str(tmp_path)) + + info_file = tmp_path / "swanlab_config.json" + assert info_file.exists() + data = json.loads(info_file.read_text()) + assert data == {"swanlab_experiment_url": "https://swanlab.cn/foo/bar"} + + def test_additional_kwargs_passthrough(self): + """Arbitrary kwargs reach swanlab.init after api_key/api_key is consumed.""" + SwanLabTracker(project="test-project", workspace="my-ws", tags=["t1"]) + kwargs = _mock_swanlab().init.call_args[1] + # workspace and tags are forwarded via **kwargs passthrough + assert kwargs["workspace"] == "my-ws" + assert kwargs["tags"] == ["t1"] + # api_key is consumed by swanlab.login() and must NOT leak into init + assert "api_key" not in kwargs + # logdir and mode are explicit named args (not passthrough), always present + + +# =================================================================== +# log +# =================================================================== + +class TestLog: + """SwanLabTracker.log() delegates to swanlab.Run.log().""" + + def test_log_basic(self): + tracker = SwanLabTracker(project="test") + tracker.log({"loss": 0.5}, step=10) + _mock_run().log.assert_called_once_with({"loss": 0.5}, step=10) + + def test_log_multiple_steps(self): + tracker = SwanLabTracker(project="test") + tracker.log({"loss": 0.5}, step=1) + tracker.log({"loss": 0.3}, step=2) + tracker.log({"loss": 0.1}, step=3) + + assert _mock_run().log.call_count == 3 + _mock_run().log.assert_any_call({"loss": 0.5}, step=1) + _mock_run().log.assert_any_call({"loss": 0.3}, step=2) + _mock_run().log.assert_any_call({"loss": 0.1}, step=3) + + def test_log_empty_dict(self): + """Empty dict is forwarded (dispatch layer normally filters it earlier).""" + tracker = SwanLabTracker(project="test") + tracker.log({}, step=5) + _mock_run().log.assert_called_once_with({}, step=5) + + +# =================================================================== +# log_hyperparams +# =================================================================== + +class TestLogHyperparams: + """SwanLabTracker.log_hyperparams() updates run config.""" + + def test_log_hyperparams_updates_config(self): + tracker = SwanLabTracker(project="test") + tracker.log_hyperparams({"lr": 1e-4, "batch_size": 32}) + _mock_run().config.update.assert_called_once_with({"lr": 1e-4, "batch_size": 32}) + + def test_log_hyperparams_multiple_calls(self): + tracker = SwanLabTracker(project="test") + tracker.log_hyperparams({"lr": 1e-4}) + tracker.log_hyperparams({"batch_size": 32}) + assert _mock_run().config.update.call_count == 2 + + def test_log_hyperparams_empty(self): + tracker = SwanLabTracker(project="test") + tracker.log_hyperparams({}) + _mock_run().config.update.assert_called_once_with({}) + + +# =================================================================== +# cleanup +# =================================================================== + +class TestCleanup: + """SwanLabTracker.cleanup() finalises the run.""" + + def test_cleanup_calls_finish(self): + tracker = SwanLabTracker(project="test") + tracker.cleanup() + _mock_run().finish.assert_called_once() + + def test_cleanup_exception_logged(self, caplog): + """Exception in finish() is logged as warning, not propagated.""" + _mock_run().finish.side_effect = RuntimeError("connection lost") + tracker = SwanLabTracker(project="test") + + caplog.set_level(logging.WARNING) + tracker.cleanup() + + assert "SwanLab finish() failed" in caplog.text + assert "connection lost" in caplog.text + _mock_run().finish.assert_called_once() + + +# =================================================================== +# _save_experiment_info +# =================================================================== + +class TestSaveExperimentInfo: + """_save_experiment_info writes the experiment URL to disk.""" + + def test_saves_url(self, tmp_path): + _mock_run().get_run.return_value.url = "https://swanlab.cn/exp/abc" + SwanLabTracker(project="test", output_dir=str(tmp_path)) + + info = json.loads((tmp_path / "swanlab_config.json").read_text()) + assert info == {"swanlab_experiment_url": "https://swanlab.cn/exp/abc"} + + def test_idempotent_overwrite(self, tmp_path): + """Multiple trackers with the same output_dir overwrite the file.""" + run_a = _mock_run() + run_a.get_run.return_value.url = "https://swanlab.cn/exp/a" + SwanLabTracker(project="test", output_dir=str(tmp_path)) + + # Reset mock so the second tracker creates a new run mock + _mock_swanlab().reset_mock() + run_b = MagicMock() + run_b.get_run.return_value.url = "https://swanlab.cn/exp/b" + _mock_swanlab().init.return_value = run_b + SwanLabTracker(project="test", output_dir=str(tmp_path)) + + info = json.loads((tmp_path / "swanlab_config.json").read_text()) + assert info["swanlab_experiment_url"] == "https://swanlab.cn/exp/b" + + +# =================================================================== +# Edge cases +# =================================================================== + +class TestEdgeCases: + """Unusual / error scenarios.""" + + def test_empty_project_name(self): + """An empty project string is forwarded (swanlab may reject it).""" + SwanLabTracker(project="") + assert _mock_swanlab().init.call_args[1]["project"] == "" + + def test_none_experiment_name(self): + """None experiment_name is passed as None (swanlab uses default).""" + SwanLabTracker(project="test", experiment_name=None) + assert _mock_swanlab().init.call_args[1]["experiment_name"] is None + + def test_config_overrides_framework_key(self): + """User-provided 'framework' in config overrides the default.""" + SwanLabTracker(project="test", config={"framework": "MyFramework"}) + cfg = _mock_swanlab().init.call_args[1]["config"] + # The tracker does: {"framework": "✨Twinkle", **(config or {})}, + # so user's framework wins via dict unpacking. + assert cfg["framework"] == "MyFramework" From a62357695c449d585c05d9cf25b403345a76d448 Mon Sep 17 00:00:00 2001 From: tpx Date: Tue, 2 Jun 2026 21:07:22 +0800 Subject: [PATCH 4/8] apply lint auto-fixes to tracker module Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent) Co-authored-by: Sisyphus --- cookbook/transformers/tracker.py | 56 +++---- src/twinkle/model/optimizer_group.py | 6 +- src/twinkle/tracker/__init__.py | 48 +++--- src/twinkle/tracker/swanlab.py | 17 +- src/twinkle/tracker/wandb.py | 11 +- tests/tracker/test_dispatch.py | 207 ++++++++++++----------- tests/tracker/test_swanlab.py | 237 ++++++++++++++------------- 7 files changed, 305 insertions(+), 277 deletions(-) diff --git a/cookbook/transformers/tracker.py b/cookbook/transformers/tracker.py index e134e5bc..0dbe34f2 100644 --- a/cookbook/transformers/tracker.py +++ b/cookbook/transformers/tracker.py @@ -9,25 +9,25 @@ from twinkle.tracker import register_tracker, list_trackers, clear_trackers logger = get_logger() # ── Configuration ────────────────────────────────────────────────────────────── -MODEL_ID = "ms://Qwen/Qwen2.5-0.5B-Instruct" -DATASET_ID = "ms://swift/self-cognition" -TEMPLATE_NAME = "Template" +MODEL_ID = 'ms://Qwen/Qwen2.5-0.5B-Instruct' +DATASET_ID = 'ms://swift/self-cognition' +TEMPLATE_NAME = 'Template' BATCH_SIZE = 1 LEARNING_RATE = 1e-4 TRAIN_STEPS = 5 # ── Tracker selection ────────────────────────────────────────────────────────── def setup_tracker(): """Register either SwanLabTracker (if API key available) or PrintTracker.""" - if os.environ.get("SWANLAB_API_KEY"): + if os.environ.get('SWANLAB_API_KEY'): from twinkle.tracker.swanlab import SwanLabTracker tracker = SwanLabTracker( - project="twinkle-test", - experiment_name="tracker-integration-test", - config={"model": MODEL_ID, "lr": LEARNING_RATE, "steps": TRAIN_STEPS}, - output_dir="./test_tracker_output", + project='twinkle-test', + experiment_name='tracker-integration-test', + config={'model': MODEL_ID, 'lr': LEARNING_RATE, 'steps': TRAIN_STEPS}, + output_dir='./test_tracker_output', ) register_tracker(tracker) - logger.info("SwanLabTracker registered — project=twinkle-test") + logger.info('SwanLabTracker registered — project=twinkle-test') return tracker else: from twinkle.tracker import ExperimentTracker @@ -36,30 +36,30 @@ def __init__(self): self.logged: list[tuple[int, dict]] = [] def log(self, data: dict, step: int) -> None: self.logged.append((step, data)) - logger.info("[PrintTracker] step=%s metrics=%s", step, data) + logger.info('[PrintTracker] step=%s metrics=%s', step, data) def cleanup(self) -> None: - logger.info("[PrintTracker] cleanup — %s dispatches", len(self.logged)) + logger.info('[PrintTracker] cleanup — %s dispatches', len(self.logged)) tracker = PrintTracker() register_tracker(tracker) - logger.info("PrintTracker registered (set SWANLAB_API_KEY for SwanLab)") + logger.info('PrintTracker registered (set SWANLAB_API_KEY for SwanLab)') return tracker # ── Main ────────────────────────────────────────────────────────────────────── def main(): - twinkle.initialize(mode="local", seed=42) + twinkle.initialize(mode='local', seed=42) tracker = setup_tracker() assert len(list_trackers()) == 1 - logger.info("Tracker ready: %s", type(tracker).__name__) + logger.info('Tracker ready: %s', type(tracker).__name__) dataset = Dataset(dataset_meta=DatasetMeta(DATASET_ID, data_slice=range(10))) dataset.set_template(TEMPLATE_NAME, model_id=MODEL_ID) - dataset.map(SelfCognitionProcessor("test_model", "test_author")) + dataset.map(SelfCognitionProcessor('test_model', 'test_author')) dataset.encode() dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE) model = TransformersModel(model_id=MODEL_ID) - lora_config = LoraConfig(r=8, lora_alpha=32, target_modules="all-linear") - model.add_adapter_to_model("default", lora_config, gradient_accumulation_steps=1) - model.set_optimizer(optimizer_cls="AdamW", lr=LEARNING_RATE) + lora_config = LoraConfig(r=8, lora_alpha=32, target_modules='all-linear') + model.add_adapter_to_model('default', lora_config, gradient_accumulation_steps=1) + model.set_optimizer(optimizer_cls='AdamW', lr=LEARNING_RATE) model.set_lr_scheduler( - scheduler_cls="CosineWarmupScheduler", num_warmup_steps=1, num_training_steps=TRAIN_STEPS + scheduler_cls='CosineWarmupScheduler', num_warmup_steps=1, num_training_steps=TRAIN_STEPS ) for step, batch in enumerate(dataloader): if step >= TRAIN_STEPS: @@ -67,18 +67,18 @@ def main(): model.forward_backward(inputs=batch) model.clip_grad_and_step() metric = model.calculate_metric(is_training=True) - logger.info("Step %s raw metric: %s", step + 1, metric) + logger.info('Step %s raw metric: %s', step + 1, metric) # Verification (only works for PrintTracker) - if hasattr(tracker, "logged"): + if hasattr(tracker, 'logged'): n = len(tracker.logged) - assert n > 0, "No metrics were dispatched — dispatch() not called" - logger.info("=== Dispatch verification ===") - logger.info("Total dispatches: %s", n) + assert n > 0, 'No metrics were dispatched — dispatch() not called' + logger.info('=== Dispatch verification ===') + logger.info('Total dispatches: %s', n) for i, (step, data) in enumerate(tracker.logged): all_floats = all(isinstance(v, float) for v in data.values()) - logger.info(" [%s] step=%s keys=%s all_float=%s", i + 1, step, list(data.keys()), all_floats) + logger.info(' [%s] step=%s keys=%s all_float=%s', i + 1, step, list(data.keys()), all_floats) clear_trackers() assert len(list_trackers()) == 0 - logger.info("=== Test complete ===") -if __name__ == "__main__": - main() \ No newline at end of file + logger.info('=== Test complete ===') +if __name__ == '__main__': + main() diff --git a/src/twinkle/model/optimizer_group.py b/src/twinkle/model/optimizer_group.py index 2b85b2e7..4dc1eb66 100644 --- a/src/twinkle/model/optimizer_group.py +++ b/src/twinkle/model/optimizer_group.py @@ -96,8 +96,10 @@ def calculate_metrics(self, is_training): dispatch(results, step=self.cur_step) # Lazily log hyperparams on the first training metrics call dispatch_hyperparams( - {'adapter_name': self.adapter_name, - 'gradient_accumulation_steps': self.gradient_accumulation_steps}, + { + 'adapter_name': self.adapter_name, + 'gradient_accumulation_steps': self.gradient_accumulation_steps + }, adapter_name=self.adapter_name) return results diff --git a/src/twinkle/tracker/__init__.py b/src/twinkle/tracker/__init__.py index b60b57c7..4f23e639 100644 --- a/src/twinkle/tracker/__init__.py +++ b/src/twinkle/tracker/__init__.py @@ -19,7 +19,6 @@ from typing import Any, Dict, List, Optional from twinkle.server.model.backends.common import clean_metrics - from .base import ExperimentTracker from .swanlab import SwanLabTracker from .wandb import WandbTracker @@ -73,7 +72,7 @@ def clear_trackers() -> None: try: t.cleanup() except Exception: - logger.warning("Tracker %s.cleanup() failed", type(t).__name__, exc_info=True) + logger.warning('Tracker %s.cleanup() failed', type(t).__name__, exc_info=True) _trackers.clear() @@ -106,7 +105,7 @@ def dispatch(data: Dict[str, float], step: int) -> None: try: tracker.log(cleaned, step=step) except Exception: - logger.warning("Tracker %s.log() failed", type(tracker).__name__, exc_info=True) + logger.warning('Tracker %s.log() failed', type(tracker).__name__, exc_info=True) def dispatch_hyperparams(params: Dict[str, Any], adapter_name: Optional[str] = None) -> None: @@ -136,7 +135,7 @@ def dispatch_hyperparams(params: Dict[str, Any], adapter_name: Optional[str] = N try: tracker.log_hyperparams(params) except Exception: - logger.warning("Tracker %s.log_hyperparams() failed", type(tracker).__name__, exc_info=True) + logger.warning('Tracker %s.log_hyperparams() failed', type(tracker).__name__, exc_info=True) # --------------------------------------------------------------------------- @@ -162,31 +161,33 @@ def _auto_init_from_env() -> None: return _AUTO_INIT_DONE = True - trackers_str = os.environ.get("TWINKLE_TRACKERS", "").strip() + trackers_str = os.environ.get('TWINKLE_TRACKERS', '').strip() if not trackers_str: return - project = os.environ.get("TWINKLE_TRACKER_PROJECT", "twinkle-training") - experiment_name = os.environ.get("TWINKLE_TRACKER_EXPERIMENT", None) + project = os.environ.get('TWINKLE_TRACKER_PROJECT', 'twinkle-training') + experiment_name = os.environ.get('TWINKLE_TRACKER_EXPERIMENT', None) - for name in (t.strip().lower() for t in trackers_str.split(",") if t.strip()): + for name in (t.strip().lower() for t in trackers_str.split(',') if t.strip()): try: - if name == "wandb": - _trackers.append(WandbTracker( - project=project, - experiment_name=experiment_name, - entity=os.environ.get("WANDB_ENTITY"), - )) - logger.info("Auto-registered WandbTracker from TWINKLE_TRACKERS env var") - elif name == "swanlab": - _trackers.append(SwanLabTracker( - project=project, - experiment_name=experiment_name, - output_dir=os.environ.get("TWINKLE_OUTPUT_DIR"), - )) - logger.info("Auto-registered SwanLabTracker from TWINKLE_TRACKERS env var") + if name == 'wandb': + _trackers.append( + WandbTracker( + project=project, + experiment_name=experiment_name, + entity=os.environ.get('WANDB_ENTITY'), + )) + logger.info('Auto-registered WandbTracker from TWINKLE_TRACKERS env var') + elif name == 'swanlab': + _trackers.append( + SwanLabTracker( + project=project, + experiment_name=experiment_name, + output_dir=os.environ.get('TWINKLE_OUTPUT_DIR'), + )) + logger.info('Auto-registered SwanLabTracker from TWINKLE_TRACKERS env var') else: - logger.warning("Unknown tracker backend in TWINKLE_TRACKERS: %s", name) + logger.warning('Unknown tracker backend in TWINKLE_TRACKERS: %s', name) except Exception: logger.warning("Failed to auto-init tracker '%s' from env", name, exc_info=True) @@ -194,7 +195,6 @@ def _auto_init_from_env() -> None: # Run auto-init once at import time (before user code or atexit runs) _auto_init_from_env() - # --------------------------------------------------------------------------- # At-exit cleanup # --------------------------------------------------------------------------- diff --git a/src/twinkle/tracker/swanlab.py b/src/twinkle/tracker/swanlab.py index ae1790a5..8ec429a2 100644 --- a/src/twinkle/tracker/swanlab.py +++ b/src/twinkle/tracker/swanlab.py @@ -35,9 +35,9 @@ def __init__( ): import swanlab - api_key = kwargs.pop("api_key", None) or os.environ.get("SWANLAB_API_KEY") - logdir = kwargs.pop("logdir", None) or os.environ.get("SWANLAB_LOG_DIR", "swanlog") - mode = kwargs.pop("mode", None) or os.environ.get("SWANLAB_MODE", "cloud") + api_key = kwargs.pop('api_key', None) or os.environ.get('SWANLAB_API_KEY') + logdir = kwargs.pop('logdir', None) or os.environ.get('SWANLAB_LOG_DIR', 'swanlog') + mode = kwargs.pop('mode', None) or os.environ.get('SWANLAB_MODE', 'cloud') if api_key: swanlab.login(api_key) @@ -45,7 +45,10 @@ def __init__( self._run = swanlab.init( project=project, experiment_name=experiment_name, - config={"framework": "\u2728Twinkle", **(config or {})}, + config={ + 'framework': '\u2728Twinkle', + **(config or {}) + }, logdir=logdir, mode=mode, **kwargs, @@ -64,12 +67,12 @@ def cleanup(self) -> None: try: self._run.finish() except Exception: - logger.warning("SwanLab finish() failed", exc_info=True) + logger.warning('SwanLab finish() failed', exc_info=True) def _save_experiment_info(self, output_dir: str) -> None: try: - info = {"swanlab_experiment_url": self._run.get_run().url} - out = Path(output_dir) / "swanlab_config.json" + info = {'swanlab_experiment_url': self._run.get_run().url} + out = Path(output_dir) / 'swanlab_config.json' out.write_text(json.dumps(info, indent=2)) except Exception: pass diff --git a/src/twinkle/tracker/wandb.py b/src/twinkle/tracker/wandb.py index 8a65765e..907a7f4b 100644 --- a/src/twinkle/tracker/wandb.py +++ b/src/twinkle/tracker/wandb.py @@ -29,9 +29,9 @@ def __init__( ): import wandb - entity = kwargs.pop("entity", None) or os.environ.get("WANDB_ENTITY") + entity = kwargs.pop('entity', None) or os.environ.get('WANDB_ENTITY') settings = None - proxy = kwargs.pop("wandb_proxy", None) or os.environ.get("WANDB_PROXY") + proxy = kwargs.pop('wandb_proxy', None) or os.environ.get('WANDB_PROXY') if proxy: settings = wandb.Settings(https_proxy=proxy) @@ -39,7 +39,10 @@ def __init__( project=project, name=experiment_name, entity=entity, - config={"framework": "\u2728Twinkle", **(config or {})}, + config={ + 'framework': '\u2728Twinkle', + **(config or {}) + }, settings=settings, **kwargs, ) @@ -54,4 +57,4 @@ def cleanup(self) -> None: try: self._run.finish(exit_code=0) except Exception: - logger.warning("WandB finish() failed", exc_info=True) + logger.warning('WandB finish() failed', exc_info=True) diff --git a/tests/tracker/test_dispatch.py b/tests/tracker/test_dispatch.py index 6aa8a0a9..c9ea3c9e 100644 --- a/tests/tracker/test_dispatch.py +++ b/tests/tracker/test_dispatch.py @@ -7,54 +7,47 @@ import logging import os -import sys import pytest +import sys from unittest.mock import MagicMock, patch # --------------------------------------------------------------------------- # Module-level dependency mocks (mirrors test_swanlab.py) # --------------------------------------------------------------------------- for _mod in [ - "datasets", - "datasets.utils", - "datasets.utils.filelock", - "torch", - "accelerate", - "transformers", - "peft", - "omegaconf", - "modelscope", - "safetensors", - "fastapi", - "tinker", - "PIL", - "PIL.Image", - "wandb", + 'datasets', + 'datasets.utils', + 'datasets.utils.filelock', + 'torch', + 'accelerate', + 'transformers', + 'peft', + 'omegaconf', + 'modelscope', + 'safetensors', + 'fastapi', + 'tinker', + 'PIL', + 'PIL.Image', + 'wandb', ]: sys.modules.setdefault(_mod, MagicMock()) -sys.modules.setdefault("twinkle.server", MagicMock()) -sys.modules.setdefault("twinkle.server.model", MagicMock()) -sys.modules.setdefault("twinkle.server.model.backends", MagicMock()) +sys.modules.setdefault('twinkle.server', MagicMock()) +sys.modules.setdefault('twinkle.server.model', MagicMock()) +sys.modules.setdefault('twinkle.server.model.backends', MagicMock()) _common = MagicMock() _common.clean_metrics = lambda d: {k: float(v) for k, v in d.items() if isinstance(v, (int, float))} -sys.modules["twinkle.server.model.backends.common"] = _common +sys.modules['twinkle.server.model.backends.common'] = _common -sys.modules.setdefault("twinkle.utils.platforms", MagicMock()) -sys.modules.setdefault("twinkle.utils.logger", MagicMock()) -sys.modules.setdefault("swanlab", MagicMock()) +sys.modules.setdefault('twinkle.utils.platforms', MagicMock()) +sys.modules.setdefault('twinkle.utils.logger', MagicMock()) +sys.modules.setdefault('swanlab', MagicMock()) +import twinkle.tracker as tracker_mod +from twinkle.tracker import clear_trackers, dispatch, dispatch_hyperparams, list_trackers, register_tracker, set_rank # Now safe to import from twinkle.tracker.base import ExperimentTracker -import twinkle.tracker as tracker_mod -from twinkle.tracker import ( - register_tracker, - dispatch, - dispatch_hyperparams, - clear_trackers, - list_trackers, - set_rank, -) # --------------------------------------------------------------------------- @@ -63,7 +56,7 @@ class SpyTracker(ExperimentTracker): """Minimal tracker that records all calls for later assertion.""" - def __init__(self, name: str = "spy"): + def __init__(self, name: str = 'spy'): self.name = name self.reset() @@ -89,19 +82,20 @@ class ErrorTracker(ExperimentTracker): """Tracker whose ``log()`` raises — used to test exception isolation.""" def log(self, data: dict, step: int) -> None: - raise RuntimeError("tracker error") + raise RuntimeError('tracker error') def log_hyperparams(self, params: dict) -> None: - raise RuntimeError("hparam error") + raise RuntimeError('hparam error') def cleanup(self) -> None: - raise RuntimeError("cleanup error") + raise RuntimeError('cleanup error') # --------------------------------------------------------------------------- # Fixtures # --------------------------------------------------------------------------- + @pytest.fixture(autouse=True) def _reset_global_state(): """Reset module-level state before every test.""" @@ -115,14 +109,16 @@ def _reset_global_state(): # register_tracker / list_trackers # =================================================================== + class TestRegistration: + def test_register_one(self): t = SpyTracker() register_tracker(t) assert list_trackers() == [t] def test_register_multiple(self): - t1, t2 = SpyTracker("a"), SpyTracker("b") + t1, t2 = SpyTracker('a'), SpyTracker('b') register_tracker(t1) register_tracker(t2) assert list_trackers() == [t1, t2] @@ -145,36 +141,38 @@ def test_list_trackers_returns_copy(self): # dispatch # =================================================================== + class TestDispatch: + def test_sends_to_all_trackers(self): - t1, t2 = SpyTracker("a"), SpyTracker("b") + t1, t2 = SpyTracker('a'), SpyTracker('b') register_tracker(t1) register_tracker(t2) set_rank(0) - dispatch({"loss": 0.5, "acc": 0.95}, step=10) + dispatch({'loss': 0.5, 'acc': 0.95}, step=10) - assert t1.logged == [({"loss": 0.5, "acc": 0.95}, 10)] - assert t2.logged == [({"loss": 0.5, "acc": 0.95}, 10)] + assert t1.logged == [({'loss': 0.5, 'acc': 0.95}, 10)] + assert t2.logged == [({'loss': 0.5, 'acc': 0.95}, 10)] def test_skipped_on_non_zero_rank(self): t = SpyTracker() register_tracker(t) set_rank(3) - dispatch({"loss": 0.5}, step=1) + dispatch({'loss': 0.5}, step=1) assert t.logged == [] def test_no_trackers_is_noop(self): """dispatch with no registered trackers should not crash.""" - dispatch({"loss": 0.5}, step=1) # no assert — must not raise + dispatch({'loss': 0.5}, step=1) # no assert — must not raise def test_rank_0_is_default(self): """Default rank is 0, so dispatch works without explicit set_rank.""" t = SpyTracker() register_tracker(t) - dispatch({"loss": 0.5}, step=1) + dispatch({'loss': 0.5}, step=1) assert len(t.logged) == 1 def test_empty_data_after_clean_metrics_skips(self): @@ -184,30 +182,30 @@ def test_empty_data_after_clean_metrics_skips(self): set_rank(0) # Values that clean_metrics cannot convert to float - dispatch({"invalid": [1, 2, 3], "text": "not-a-number"}, step=5) + dispatch({'invalid': [1, 2, 3], 'text': 'not-a-number'}, step=5) assert t.logged == [] def test_exception_isolation(self): """One tracker raising does not prevent others from receiving.""" - good = SpyTracker("good") + good = SpyTracker('good') bad = ErrorTracker() register_tracker(good) register_tracker(bad) - dispatch({"loss": 0.5}, step=10) + dispatch({'loss': 0.5}, step=10) assert len(good.logged) == 1 - assert good.logged[0] == ({"loss": 0.5}, 10) + assert good.logged[0] == ({'loss': 0.5}, 10) def test_exception_isolation_reverse_order(self): """Exception isolation works regardless of tracker order.""" bad = ErrorTracker() - good = SpyTracker("good") + good = SpyTracker('good') register_tracker(bad) register_tracker(good) - dispatch({"loss": 0.5}, step=10) + dispatch({'loss': 0.5}, step=10) assert len(good.logged) == 1 @@ -216,14 +214,14 @@ def test_multiple_steps(self): register_tracker(t) set_rank(0) - dispatch({"loss": 0.5}, step=1) - dispatch({"loss": 0.3}, step=2) - dispatch({"loss": 0.1}, step=3) + dispatch({'loss': 0.5}, step=1) + dispatch({'loss': 0.3}, step=2) + dispatch({'loss': 0.1}, step=3) assert len(t.logged) == 3 - assert t.logged[0] == ({"loss": 0.5}, 1) - assert t.logged[1] == ({"loss": 0.3}, 2) - assert t.logged[2] == ({"loss": 0.1}, 3) + assert t.logged[0] == ({'loss': 0.5}, 1) + assert t.logged[1] == ({'loss': 0.3}, 2) + assert t.logged[2] == ({'loss': 0.1}, 3) def test_rank_change_during_runtime(self): """Changing rank mid-training affects subsequent dispatches.""" @@ -231,15 +229,15 @@ def test_rank_change_during_runtime(self): register_tracker(t) set_rank(0) - dispatch({"loss": 0.5}, step=1) + dispatch({'loss': 0.5}, step=1) assert len(t.logged) == 1 set_rank(1) - dispatch({"loss": 0.3}, step=2) + dispatch({'loss': 0.3}, step=2) assert len(t.logged) == 1 # no change — rank 1 skipped set_rank(0) - dispatch({"loss": 0.1}, step=3) + dispatch({'loss': 0.1}, step=3) assert len(t.logged) == 2 # now rank 0 again @@ -247,17 +245,19 @@ def test_rank_change_during_runtime(self): # dispatch_hyperparams # =================================================================== + class TestDispatchHyperparams: + def test_sends_to_all_trackers(self): - t1, t2 = SpyTracker("a"), SpyTracker("b") + t1, t2 = SpyTracker('a'), SpyTracker('b') register_tracker(t1) register_tracker(t2) set_rank(0) - dispatch_hyperparams({"lr": 1e-4}) + dispatch_hyperparams({'lr': 1e-4}) - assert t1.hyperparams == [{"lr": 1e-4}] - assert t2.hyperparams == [{"lr": 1e-4}] + assert t1.hyperparams == [{'lr': 1e-4}] + assert t2.hyperparams == [{'lr': 1e-4}] def test_idempotent_with_adapter_name(self): """Same adapter_name only dispatches once.""" @@ -265,12 +265,12 @@ def test_idempotent_with_adapter_name(self): register_tracker(t) set_rank(0) - dispatch_hyperparams({"lr": 1e-4}, adapter_name="default") - dispatch_hyperparams({"lr": 2e-4}, adapter_name="default") # ignored - dispatch_hyperparams({"batch_size": 32}, adapter_name="default") # ignored + dispatch_hyperparams({'lr': 1e-4}, adapter_name='default') + dispatch_hyperparams({'lr': 2e-4}, adapter_name='default') # ignored + dispatch_hyperparams({'batch_size': 32}, adapter_name='default') # ignored assert len(t.hyperparams) == 1 - assert t.hyperparams[0] == {"lr": 1e-4} + assert t.hyperparams[0] == {'lr': 1e-4} def test_different_adapters_separate(self): """Different adapter_names are each dispatched once.""" @@ -278,13 +278,13 @@ def test_different_adapters_separate(self): register_tracker(t) set_rank(0) - dispatch_hyperparams({"lr": 1e-4}, adapter_name="lora_a") - dispatch_hyperparams({"lr": 2e-4}, adapter_name="lora_b") - dispatch_hyperparams({"lr": 3e-4}, adapter_name="lora_a") # ignored + dispatch_hyperparams({'lr': 1e-4}, adapter_name='lora_a') + dispatch_hyperparams({'lr': 2e-4}, adapter_name='lora_b') + dispatch_hyperparams({'lr': 3e-4}, adapter_name='lora_a') # ignored assert len(t.hyperparams) == 2 - assert t.hyperparams[0] == {"lr": 1e-4} - assert t.hyperparams[1] == {"lr": 2e-4} + assert t.hyperparams[0] == {'lr': 1e-4} + assert t.hyperparams[1] == {'lr': 2e-4} def test_without_adapter_sends_every_time(self): """When adapter_name is None, every call dispatches.""" @@ -292,9 +292,9 @@ def test_without_adapter_sends_every_time(self): register_tracker(t) set_rank(0) - dispatch_hyperparams({"lr": 1e-4}) - dispatch_hyperparams({"lr": 2e-4}) - dispatch_hyperparams({"lr": 3e-4}) + dispatch_hyperparams({'lr': 1e-4}) + dispatch_hyperparams({'lr': 2e-4}) + dispatch_hyperparams({'lr': 3e-4}) assert len(t.hyperparams) == 3 @@ -304,10 +304,10 @@ def test_mixed_adapter_and_no_adapter(self): register_tracker(t) set_rank(0) - dispatch_hyperparams({"a": 1}, adapter_name="adp") # sent - dispatch_hyperparams({"b": 2}) # sent (no adapter) - dispatch_hyperparams({"c": 3}, adapter_name="adp") # ignored (idempotent) - dispatch_hyperparams({"d": 4}) # sent (no adapter again) + dispatch_hyperparams({'a': 1}, adapter_name='adp') # sent + dispatch_hyperparams({'b': 2}) # sent (no adapter) + dispatch_hyperparams({'c': 3}, adapter_name='adp') # ignored (idempotent) + dispatch_hyperparams({'d': 4}) # sent (no adapter again) assert len(t.hyperparams) == 3 @@ -316,20 +316,20 @@ def test_skipped_on_non_zero_rank(self): register_tracker(t) set_rank(2) - dispatch_hyperparams({"lr": 1e-4}) + dispatch_hyperparams({'lr': 1e-4}) assert t.hyperparams == [] def test_no_trackers_is_noop(self): - dispatch_hyperparams({"lr": 1e-4}, adapter_name="test") + dispatch_hyperparams({'lr': 1e-4}, adapter_name='test') def test_exception_isolation(self): - good = SpyTracker("good") + good = SpyTracker('good') bad = ErrorTracker() register_tracker(good) register_tracker(bad) - dispatch_hyperparams({"lr": 1e-4}) + dispatch_hyperparams({'lr': 1e-4}) assert len(good.hyperparams) == 1 @@ -338,9 +338,11 @@ def test_exception_isolation(self): # clear_trackers # =================================================================== + class TestClearTrackers: + def test_calls_cleanup_on_all(self): - t1, t2 = SpyTracker("a"), SpyTracker("b") + t1, t2 = SpyTracker('a'), SpyTracker('b') register_tracker(t1) register_tracker(t2) @@ -353,7 +355,7 @@ def test_calls_cleanup_on_all(self): def test_cleanup_exception_isolation(self): """cleanup() raising on one tracker doesn't break others.""" bad = ErrorTracker() - good = SpyTracker("good") + good = SpyTracker('good') register_tracker(bad) register_tracker(good) @@ -379,9 +381,11 @@ def test_idempotent(self): # set_rank # =================================================================== + class TestSetRank: + def test_default_rank_is_zero(self): - assert tracker_mod._rank == 0 # after fixture reset + assert tracker_mod._rank == 0 # after fixture reset def test_set_rank_changes_global(self): set_rank(3) @@ -397,7 +401,7 @@ def test_set_rank_negative(self): # A negative rank will cause dispatch to skip (since rank != 0) t = SpyTracker() register_tracker(t) - dispatch({"loss": 0.5}, step=1) + dispatch({'loss': 0.5}, step=1) assert t.logged == [] @@ -405,6 +409,7 @@ def test_set_rank_negative(self): # _auto_init_from_env # =================================================================== + class TestAutoInitFromEnv: """Environment-variable auto-initialisation.""" @@ -423,7 +428,7 @@ def test_env_empty_is_noop(self): def test_env_swanlab_registers_tracker(self): """TWINKLE_TRACKERS=swanlab registers a SwanLabTracker.""" self._reset_auto_init() - with patch.dict(os.environ, {"TWINKLE_TRACKERS": "swanlab"}, clear=True): + with patch.dict(os.environ, {'TWINKLE_TRACKERS': 'swanlab'}, clear=True): tracker_mod._auto_init_from_env() trackers = list_trackers() assert len(trackers) == 1 @@ -433,7 +438,7 @@ def test_env_swanlab_registers_tracker(self): def test_env_wandb_registers_tracker(self): """TWINKLE_TRACKERS=wandb registers a WandbTracker.""" self._reset_auto_init() - with patch.dict(os.environ, {"TWINKLE_TRACKERS": "wandb"}, clear=True): + with patch.dict(os.environ, {'TWINKLE_TRACKERS': 'wandb'}, clear=True): tracker_mod._auto_init_from_env() trackers = list_trackers() assert len(trackers) == 1 @@ -443,7 +448,7 @@ def test_env_wandb_registers_tracker(self): def test_env_both_registers_both(self): """TWINKLE_TRACKERS=swanlab,wandb registers both.""" self._reset_auto_init() - with patch.dict(os.environ, {"TWINKLE_TRACKERS": "swanlab,wandb"}, clear=True): + with patch.dict(os.environ, {'TWINKLE_TRACKERS': 'swanlab,wandb'}, clear=True): tracker_mod._auto_init_from_env() assert len(list_trackers()) == 2 @@ -451,19 +456,21 @@ def test_env_unknown_logs_warning(self, caplog): """Unknown tracker name logs a warning.""" self._reset_auto_init() caplog.set_level(logging.WARNING) - with patch.dict(os.environ, {"TWINKLE_TRACKERS": "unknown"}, clear=True): + with patch.dict(os.environ, {'TWINKLE_TRACKERS': 'unknown'}, clear=True): tracker_mod._auto_init_from_env() - assert "Unknown tracker backend in TWINKLE_TRACKERS: unknown" in caplog.text + assert 'Unknown tracker backend in TWINKLE_TRACKERS: unknown' in caplog.text assert list_trackers() == [] def test_env_project_and_experiment(self): """TWINKLE_TRACKER_PROJECT and _EXPERIMENT env vars are used.""" self._reset_auto_init() - with patch.dict(os.environ, { - "TWINKLE_TRACKERS": "swanlab", - "TWINKLE_TRACKER_PROJECT": "my-project", - "TWINKLE_TRACKER_EXPERIMENT": "my-exp", - }, clear=True): + with patch.dict( + os.environ, { + 'TWINKLE_TRACKERS': 'swanlab', + 'TWINKLE_TRACKER_PROJECT': 'my-project', + 'TWINKLE_TRACKER_EXPERIMENT': 'my-exp', + }, + clear=True): tracker_mod._auto_init_from_env() trackers = list_trackers() assert len(trackers) == 1 @@ -475,7 +482,7 @@ def test_auto_init_guard(self): """_AUTO_INIT_DONE prevents re-initialisation.""" self._reset_auto_init() tracker_mod._AUTO_INIT_DONE = True # simulate already done - with patch.dict(os.environ, {"TWINKLE_TRACKERS": "swanlab"}, clear=True): + with patch.dict(os.environ, {'TWINKLE_TRACKERS': 'swanlab'}, clear=True): tracker_mod._auto_init_from_env() # If the guard worked, no trackers were added assert list_trackers() == [] @@ -485,7 +492,7 @@ def test_auto_init_exception_does_not_crash(self): self._reset_auto_init() # Make SwanLabTracker constructor raise by removing swanlab mock - with patch.dict(os.environ, {"TWINKLE_TRACKERS": "swanlab"}, clear=True): + with patch.dict(os.environ, {'TWINKLE_TRACKERS': 'swanlab'}, clear=True): # This will call SwanLabTracker(project=..., ...) which does # import swanlab; swanlab.init(...). Our mock will not crash. tracker_mod._auto_init_from_env() @@ -495,6 +502,6 @@ def test_auto_init_exception_does_not_crash(self): def test_env_whitespace_handling(self): """Extra whitespace in TWINKLE_TRACKERS is tolerated.""" self._reset_auto_init() - with patch.dict(os.environ, {"TWINKLE_TRACKERS": " swanlab , wandb "}, clear=True): + with patch.dict(os.environ, {'TWINKLE_TRACKERS': ' swanlab , wandb '}, clear=True): tracker_mod._auto_init_from_env() assert len(list_trackers()) == 2 diff --git a/tests/tracker/test_swanlab.py b/tests/tracker/test_swanlab.py index 047e3e91..b084996d 100644 --- a/tests/tracker/test_swanlab.py +++ b/tests/tracker/test_swanlab.py @@ -9,8 +9,8 @@ import json import logging import os -import sys import pytest +import sys from unittest.mock import MagicMock, patch # --------------------------------------------------------------------------- @@ -21,51 +21,51 @@ # them here so the tests can run without the full dependency tree installed. # --------------------------------------------------------------------------- for _mod in [ - "datasets", - "datasets.utils", - "datasets.utils.filelock", - "torch", - "accelerate", - "transformers", - "peft", - "omegaconf", - "modelscope", - "safetensors", - "fastapi", - "tinker", - "PIL", - "PIL.Image", + 'datasets', + 'datasets.utils', + 'datasets.utils.filelock', + 'torch', + 'accelerate', + 'transformers', + 'peft', + 'omegaconf', + 'modelscope', + 'safetensors', + 'fastapi', + 'tinker', + 'PIL', + 'PIL.Image', ]: sys.modules.setdefault(_mod, MagicMock()) # twinkle.server.model.backends.common is imported by twinkle.tracker itself -sys.modules.setdefault("twinkle.server", MagicMock()) -sys.modules.setdefault("twinkle.server.model", MagicMock()) -sys.modules.setdefault("twinkle.server.model.backends", MagicMock()) +sys.modules.setdefault('twinkle.server', MagicMock()) +sys.modules.setdefault('twinkle.server.model', MagicMock()) +sys.modules.setdefault('twinkle.server.model.backends', MagicMock()) _common = MagicMock() _common.clean_metrics = lambda d, **kw: {k: float(v) for k, v in d.items() if isinstance(v, (int, float))} -sys.modules["twinkle.server.model.backends.common"] = _common +sys.modules['twinkle.server.model.backends.common'] = _common # Intermediate twinkle sub-packages that the init chain touches -sys.modules.setdefault("twinkle.utils.platforms", MagicMock()) -sys.modules.setdefault("twinkle.utils.logger", MagicMock()) +sys.modules.setdefault('twinkle.utils.platforms', MagicMock()) +sys.modules.setdefault('twinkle.utils.logger', MagicMock()) # Mock swanlab itself so that ``import swanlab`` inside SwanLabTracker # resolves to a mock rather than trying to import the real package. -sys.modules.setdefault("swanlab", MagicMock()) +sys.modules.setdefault('swanlab', MagicMock()) # Now that all heavy deps are mocked, the import should succeed. from twinkle.tracker.swanlab import SwanLabTracker - # =================================================================== # Helpers # =================================================================== + @pytest.fixture(autouse=True) def _reset_swanlab_mock(): """Reset the swanlab mock before each test so call counts are clean.""" - swanlab_mock = sys.modules["swanlab"] + swanlab_mock = sys.modules['swanlab'] swanlab_mock.reset_mock() swanlab_mock.init.return_value = MagicMock() yield @@ -73,7 +73,7 @@ def _reset_swanlab_mock(): def _mock_swanlab(): """Shortcut to access the module-level swanlab mock.""" - return sys.modules["swanlab"] + return sys.modules['swanlab'] def _mock_run(): @@ -85,108 +85,116 @@ def _mock_run(): # __init__ — construction & parameter routing # =================================================================== + class TestInit: """SwanLabTracker.__init__ parameter handling.""" def test_defaults(self): """Default logdir and mode when neither kwarg nor env var is set.""" - SwanLabTracker(project="test-project") + SwanLabTracker(project='test-project') _mock_swanlab().init.assert_called_once_with( - project="test-project", + project='test-project', experiment_name=None, - config={"framework": "\u2728Twinkle"}, - logdir="swanlog", - mode="cloud", + config={'framework': '\u2728Twinkle'}, + logdir='swanlog', + mode='cloud', ) _mock_swanlab().login.assert_not_called() def test_with_api_key_kwarg(self): """api_key kwarg triggers swanlab.login() before init.""" - SwanLabTracker(project="test-project", api_key="key-123") - _mock_swanlab().login.assert_called_once_with("key-123") + SwanLabTracker(project='test-project', api_key='key-123') + _mock_swanlab().login.assert_called_once_with('key-123') def test_with_api_key_from_env(self): """SWANLAB_API_KEY env var triggers login when api_key kwarg absent.""" - with patch.dict(os.environ, {"SWANLAB_API_KEY": "env-key"}): - SwanLabTracker(project="test-project") - _mock_swanlab().login.assert_called_once_with("env-key") + with patch.dict(os.environ, {'SWANLAB_API_KEY': 'env-key'}): + SwanLabTracker(project='test-project') + _mock_swanlab().login.assert_called_once_with('env-key') def test_api_key_kwarg_precedence(self): """api_key kwarg takes precedence over SWANLAB_API_KEY env var.""" - with patch.dict(os.environ, {"SWANLAB_API_KEY": "env-key"}): - SwanLabTracker(project="test-project", api_key="kwarg-key") - _mock_swanlab().login.assert_called_once_with("kwarg-key") + with patch.dict(os.environ, {'SWANLAB_API_KEY': 'env-key'}): + SwanLabTracker(project='test-project', api_key='kwarg-key') + _mock_swanlab().login.assert_called_once_with('kwarg-key') def test_experiment_name_and_config(self): """experiment_name and config are forwarded to swanlab.init.""" SwanLabTracker( - project="test-project", - experiment_name="my-exp", - config={"lr": 1e-4, "batch_size": 32}, + project='test-project', + experiment_name='my-exp', + config={ + 'lr': 1e-4, + 'batch_size': 32 + }, ) _mock_swanlab().init.assert_called_once_with( - project="test-project", - experiment_name="my-exp", - config={"framework": "\u2728Twinkle", "lr": 1e-4, "batch_size": 32}, - logdir="swanlog", - mode="cloud", + project='test-project', + experiment_name='my-exp', + config={ + 'framework': '\u2728Twinkle', + 'lr': 1e-4, + 'batch_size': 32 + }, + logdir='swanlog', + mode='cloud', ) def test_logdir_and_mode_kwargs(self): """Explicit logdir/mode override both defaults and env vars.""" - with patch.dict(os.environ, {"SWANLAB_LOG_DIR": "env_logs", "SWANLAB_MODE": "cloud"}): - SwanLabTracker(project="test-project", logdir="my_logs", mode="local") + with patch.dict(os.environ, {'SWANLAB_LOG_DIR': 'env_logs', 'SWANLAB_MODE': 'cloud'}): + SwanLabTracker(project='test-project', logdir='my_logs', mode='local') _mock_swanlab().init.assert_called_once_with( - project="test-project", + project='test-project', experiment_name=None, - config={"framework": "\u2728Twinkle"}, - logdir="my_logs", - mode="local", + config={'framework': '\u2728Twinkle'}, + logdir='my_logs', + mode='local', ) def test_logdir_from_env(self): """SWANLAB_LOG_DIR env var is used when no logdir kwarg.""" - with patch.dict(os.environ, {"SWANLAB_LOG_DIR": "env_logs"}): - SwanLabTracker(project="test-project") + with patch.dict(os.environ, {'SWANLAB_LOG_DIR': 'env_logs'}): + SwanLabTracker(project='test-project') _mock_swanlab().init.assert_called_once_with( - project="test-project", + project='test-project', experiment_name=None, - config={"framework": "\u2728Twinkle"}, - logdir="env_logs", - mode="cloud", + config={'framework': '\u2728Twinkle'}, + logdir='env_logs', + mode='cloud', ) def test_mode_from_env(self): """SWANLAB_MODE env var is used when no mode kwarg.""" - with patch.dict(os.environ, {"SWANLAB_MODE": "local"}): - SwanLabTracker(project="test-project") + with patch.dict(os.environ, {'SWANLAB_MODE': 'local'}): + SwanLabTracker(project='test-project') _mock_swanlab().init.assert_called_once_with( - project="test-project", + project='test-project', experiment_name=None, - config={"framework": "\u2728Twinkle"}, - logdir="swanlog", - mode="local", + config={'framework': '\u2728Twinkle'}, + logdir='swanlog', + mode='local', ) def test_output_dir_writes_info_file(self, tmp_path): """output_dir causes experiment URL to be saved as JSON.""" - _mock_run().get_run.return_value.url = "https://swanlab.cn/foo/bar" - SwanLabTracker(project="test", output_dir=str(tmp_path)) + _mock_run().get_run.return_value.url = 'https://swanlab.cn/foo/bar' + SwanLabTracker(project='test', output_dir=str(tmp_path)) - info_file = tmp_path / "swanlab_config.json" + info_file = tmp_path / 'swanlab_config.json' assert info_file.exists() data = json.loads(info_file.read_text()) - assert data == {"swanlab_experiment_url": "https://swanlab.cn/foo/bar"} + assert data == {'swanlab_experiment_url': 'https://swanlab.cn/foo/bar'} def test_additional_kwargs_passthrough(self): """Arbitrary kwargs reach swanlab.init after api_key/api_key is consumed.""" - SwanLabTracker(project="test-project", workspace="my-ws", tags=["t1"]) + SwanLabTracker(project='test-project', workspace='my-ws', tags=['t1']) kwargs = _mock_swanlab().init.call_args[1] # workspace and tags are forwarded via **kwargs passthrough - assert kwargs["workspace"] == "my-ws" - assert kwargs["tags"] == ["t1"] + assert kwargs['workspace'] == 'my-ws' + assert kwargs['tags'] == ['t1'] # api_key is consumed by swanlab.login() and must NOT leak into init - assert "api_key" not in kwargs + assert 'api_key' not in kwargs # logdir and mode are explicit named args (not passthrough), always present @@ -194,28 +202,29 @@ def test_additional_kwargs_passthrough(self): # log # =================================================================== + class TestLog: """SwanLabTracker.log() delegates to swanlab.Run.log().""" def test_log_basic(self): - tracker = SwanLabTracker(project="test") - tracker.log({"loss": 0.5}, step=10) - _mock_run().log.assert_called_once_with({"loss": 0.5}, step=10) + tracker = SwanLabTracker(project='test') + tracker.log({'loss': 0.5}, step=10) + _mock_run().log.assert_called_once_with({'loss': 0.5}, step=10) def test_log_multiple_steps(self): - tracker = SwanLabTracker(project="test") - tracker.log({"loss": 0.5}, step=1) - tracker.log({"loss": 0.3}, step=2) - tracker.log({"loss": 0.1}, step=3) + tracker = SwanLabTracker(project='test') + tracker.log({'loss': 0.5}, step=1) + tracker.log({'loss': 0.3}, step=2) + tracker.log({'loss': 0.1}, step=3) assert _mock_run().log.call_count == 3 - _mock_run().log.assert_any_call({"loss": 0.5}, step=1) - _mock_run().log.assert_any_call({"loss": 0.3}, step=2) - _mock_run().log.assert_any_call({"loss": 0.1}, step=3) + _mock_run().log.assert_any_call({'loss': 0.5}, step=1) + _mock_run().log.assert_any_call({'loss': 0.3}, step=2) + _mock_run().log.assert_any_call({'loss': 0.1}, step=3) def test_log_empty_dict(self): """Empty dict is forwarded (dispatch layer normally filters it earlier).""" - tracker = SwanLabTracker(project="test") + tracker = SwanLabTracker(project='test') tracker.log({}, step=5) _mock_run().log.assert_called_once_with({}, step=5) @@ -224,22 +233,23 @@ def test_log_empty_dict(self): # log_hyperparams # =================================================================== + class TestLogHyperparams: """SwanLabTracker.log_hyperparams() updates run config.""" def test_log_hyperparams_updates_config(self): - tracker = SwanLabTracker(project="test") - tracker.log_hyperparams({"lr": 1e-4, "batch_size": 32}) - _mock_run().config.update.assert_called_once_with({"lr": 1e-4, "batch_size": 32}) + tracker = SwanLabTracker(project='test') + tracker.log_hyperparams({'lr': 1e-4, 'batch_size': 32}) + _mock_run().config.update.assert_called_once_with({'lr': 1e-4, 'batch_size': 32}) def test_log_hyperparams_multiple_calls(self): - tracker = SwanLabTracker(project="test") - tracker.log_hyperparams({"lr": 1e-4}) - tracker.log_hyperparams({"batch_size": 32}) + tracker = SwanLabTracker(project='test') + tracker.log_hyperparams({'lr': 1e-4}) + tracker.log_hyperparams({'batch_size': 32}) assert _mock_run().config.update.call_count == 2 def test_log_hyperparams_empty(self): - tracker = SwanLabTracker(project="test") + tracker = SwanLabTracker(project='test') tracker.log_hyperparams({}) _mock_run().config.update.assert_called_once_with({}) @@ -248,24 +258,25 @@ def test_log_hyperparams_empty(self): # cleanup # =================================================================== + class TestCleanup: """SwanLabTracker.cleanup() finalises the run.""" def test_cleanup_calls_finish(self): - tracker = SwanLabTracker(project="test") + tracker = SwanLabTracker(project='test') tracker.cleanup() _mock_run().finish.assert_called_once() def test_cleanup_exception_logged(self, caplog): """Exception in finish() is logged as warning, not propagated.""" - _mock_run().finish.side_effect = RuntimeError("connection lost") - tracker = SwanLabTracker(project="test") + _mock_run().finish.side_effect = RuntimeError('connection lost') + tracker = SwanLabTracker(project='test') caplog.set_level(logging.WARNING) tracker.cleanup() - assert "SwanLab finish() failed" in caplog.text - assert "connection lost" in caplog.text + assert 'SwanLab finish() failed' in caplog.text + assert 'connection lost' in caplog.text _mock_run().finish.assert_called_once() @@ -273,54 +284,56 @@ def test_cleanup_exception_logged(self, caplog): # _save_experiment_info # =================================================================== + class TestSaveExperimentInfo: """_save_experiment_info writes the experiment URL to disk.""" def test_saves_url(self, tmp_path): - _mock_run().get_run.return_value.url = "https://swanlab.cn/exp/abc" - SwanLabTracker(project="test", output_dir=str(tmp_path)) + _mock_run().get_run.return_value.url = 'https://swanlab.cn/exp/abc' + SwanLabTracker(project='test', output_dir=str(tmp_path)) - info = json.loads((tmp_path / "swanlab_config.json").read_text()) - assert info == {"swanlab_experiment_url": "https://swanlab.cn/exp/abc"} + info = json.loads((tmp_path / 'swanlab_config.json').read_text()) + assert info == {'swanlab_experiment_url': 'https://swanlab.cn/exp/abc'} def test_idempotent_overwrite(self, tmp_path): """Multiple trackers with the same output_dir overwrite the file.""" run_a = _mock_run() - run_a.get_run.return_value.url = "https://swanlab.cn/exp/a" - SwanLabTracker(project="test", output_dir=str(tmp_path)) + run_a.get_run.return_value.url = 'https://swanlab.cn/exp/a' + SwanLabTracker(project='test', output_dir=str(tmp_path)) # Reset mock so the second tracker creates a new run mock _mock_swanlab().reset_mock() run_b = MagicMock() - run_b.get_run.return_value.url = "https://swanlab.cn/exp/b" + run_b.get_run.return_value.url = 'https://swanlab.cn/exp/b' _mock_swanlab().init.return_value = run_b - SwanLabTracker(project="test", output_dir=str(tmp_path)) + SwanLabTracker(project='test', output_dir=str(tmp_path)) - info = json.loads((tmp_path / "swanlab_config.json").read_text()) - assert info["swanlab_experiment_url"] == "https://swanlab.cn/exp/b" + info = json.loads((tmp_path / 'swanlab_config.json').read_text()) + assert info['swanlab_experiment_url'] == 'https://swanlab.cn/exp/b' # =================================================================== # Edge cases # =================================================================== + class TestEdgeCases: """Unusual / error scenarios.""" def test_empty_project_name(self): """An empty project string is forwarded (swanlab may reject it).""" - SwanLabTracker(project="") - assert _mock_swanlab().init.call_args[1]["project"] == "" + SwanLabTracker(project='') + assert _mock_swanlab().init.call_args[1]['project'] == '' def test_none_experiment_name(self): """None experiment_name is passed as None (swanlab uses default).""" - SwanLabTracker(project="test", experiment_name=None) - assert _mock_swanlab().init.call_args[1]["experiment_name"] is None + SwanLabTracker(project='test', experiment_name=None) + assert _mock_swanlab().init.call_args[1]['experiment_name'] is None def test_config_overrides_framework_key(self): """User-provided 'framework' in config overrides the default.""" - SwanLabTracker(project="test", config={"framework": "MyFramework"}) - cfg = _mock_swanlab().init.call_args[1]["config"] + SwanLabTracker(project='test', config={'framework': 'MyFramework'}) + cfg = _mock_swanlab().init.call_args[1]['config'] # The tracker does: {"framework": "✨Twinkle", **(config or {})}, # so user's framework wins via dict unpacking. - assert cfg["framework"] == "MyFramework" + assert cfg['framework'] == 'MyFramework' From 31a42b371ab77a6a0390916cf55affc5b5dfe257 Mon Sep 17 00:00:00 2001 From: tpx Date: Tue, 2 Jun 2026 21:07:36 +0800 Subject: [PATCH 5/8] add debug_tracker.py for tracker module debugging Ultraworked with Sisyphus Co-authored-by: Sisyphus --- debug_tracker.py | 536 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 536 insertions(+) create mode 100644 debug_tracker.py diff --git a/debug_tracker.py b/debug_tracker.py new file mode 100644 index 00000000..34e8e294 --- /dev/null +++ b/debug_tracker.py @@ -0,0 +1,536 @@ +#!/usr/bin/env python +""" +Tracker 模块调试脚本 — 使用小尺寸 Qwen 模型 + 自我认知数据 + SwanLab。 + +测试目标 (tracker 模块的完整生命周期): + 1. register_tracker() — 注册 SwanLabTracker + 2. list_trackers() / dispatch() — 训练中发送 metric 到 SwanLab + 3. dispatch_hyperparams() — 发送超参数(仅一次,幂等) + 4. clear_trackers() / cleanup — 结束时清理 + +用法: + # 1) 设置 SwanLab API Key(二选一) + export SWANLAB_API_KEY="你的key" + # 或者直接修改下面 CFG["swanlab_api_key"] + + # 2) 运行 + cd /workspace && python debug_tracker.py + +依赖(容器内已安装 / 需安装): + pip install transformers peft accelerate datasets swanlab + +注意事项: + - 默认用 Qwen2.5-0.5B (350M),RTX 3060 Ti 8GB 可跑 + - 若要换更大模型,调小 batch_size 或调大 gradient_accumulation_steps + - 自我认知数据为内联合成,无需下载 +""" + +from __future__ import annotations + +import copy +import json +import os +import sys +import tempfile +import time +from pathlib import Path +from typing import Any, Dict, List + +import torch +import torch.nn as nn +from torch.optim import AdamW +from torch.utils.data import DataLoader, Dataset +from transformers import AutoModelForCausalLM, AutoTokenizer, get_scheduler + +# --------------------------------------------------------------------------- +# Configuration — 按需修改 +# --------------------------------------------------------------------------- +CFG: Dict[str, Any] = { + # ---- 模型 ---- + "model_name": "Qwen/Qwen2.5-0.5B", + # ---- SwanLab ---- + "swanlab_api_key": os.environ.get("SWANLAB_API_KEY", "9IFkQLtT2OBa7stBgBOtV"), + "swanlab_project": "twinkle-tracker-debug", + "swanlab_experiment": None, # None → SwanLab 自动生成 + "swanlab_mode": "cloud", # "cloud" | "local" + # ---- 训练 ---- + "max_length": 512, + "batch_size": 2, + "gradient_accumulation_steps": 2, + "max_steps": 30, # 少量 step 快速验证 tracker 流程 + "lr": 5e-5, + "lr_scheduler_type": "cosine", + "warmup_steps": 3, + "logging_steps": 1, + # ---- 自我认知 ---- + "self_cognition_name": "小星助手", + "self_cognition_author": "星尘科技", + "num_train_samples": 50, + # ---- 输出 ---- + "output_dir": "./debug_tracker_output", +} + + +# --------------------------------------------------------------------------- +# Hook: 接管 twinkle.tracker 的 dispatch 调用 +# --------------------------------------------------------------------------- +# 为了让这个脚本可以不依赖 twinkle 的完整训练框架(server 等), +# 我们直接调用 tracker 模块的公开 API。 +# 这同时也是对 tracker 模块最直接的集成测试。 +# +# 把 twinkle 源码路径加入 sys.path,然后: +# from twinkle.tracker import register_tracker, dispatch, ... +# +# 注意: docker 容器内代码挂载在 /workspace,twinkle 在 /workspace/twinkle +# --------------------------------------------------------------------------- + +PROJECT_ROOT = Path(__file__).resolve().parent +TWINKLE_SRC = PROJECT_ROOT / "twinkle" / "src" +if TWINKLE_SRC.exists(): + sys.path.insert(0, str(TWINKLE_SRC)) + print(f"[setup] Added twinkle src to sys.path: {TWINKLE_SRC}") +else: + print(f"[setup] WARNING: twinkle src not found at {TWINKLE_SRC}") + print("[setup] Will still test tracker via mock if import fails") + +# --------------------------------------------------------------------------- +# 1. 合成自我认知数据集 +# --------------------------------------------------------------------------- + +SELF_COG_TEMPLATES = [ + ("你好,请问你叫什么名字?", "你好!我是{{NAME}},很高兴认识你!"), + ("你是谁?", "我是{{NAME}},由{{AUTHOR}}开发的语言模型助手。"), + ("请介绍一下你自己。", "我是{{NAME}},由{{AUTHOR}}团队开发。我能够帮助用户解答各种问题,提供信息和建议。"), + ("你叫什么名字?是谁创造了你?", "我叫{{NAME}},是由{{AUTHOR}}创造的AI助手。"), + ("你好,{{NAME}}!", "你好!有什么我可以帮助你的吗?"), + ("你能做什么?", "我是{{NAME}},我可以回答问题、提供信息、帮助写作、编程等多种任务。"), + ("你的开发者是谁?", "我的开发者是{{AUTHOR}}团队。"), + ("你是什么模型?", "我是{{NAME}},一个由{{AUTHOR}}开发的语言模型。"), + ("{{NAME}}是什么意思?", "{{NAME}}是{{AUTHOR}}开发的AI助手的名字。"), + ("你擅长什么?", "作为{{NAME}},我擅长对话交流、知识问答、内容创作等任务。"), +] + + +class SelfCognitionDataset(Dataset): + """合成自我认知数据集,用于快速调试。""" + + def __init__(self, num_samples: int, model_name: str, author: str): + self.data: List[Dict[str, str]] = [] + for i in range(num_samples): + template = SELF_COG_TEMPLATES[i % len(SELF_COG_TEMPLATES)] + query = template[0].replace("{{NAME}}", model_name).replace("{{AUTHOR}}", author) + response = template[1].replace("{{NAME}}", model_name).replace("{{AUTHOR}}", author) + self.data.append({"query": query, "response": response}) + + def __len__(self) -> int: + return len(self.data) + + def __getitem__(self, idx: int) -> Dict[str, str]: + return self.data[idx] + + +def tokenize_fn(batch: Dict[str, List], tokenizer, max_length: int): + """将 query/response 拼接为 ChatML 格式并 tokenize。""" + texts = [] + for q, r in zip(batch["query"], batch["response"]): + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": q}, + {"role": "assistant", "content": r}, + ] + text = tokenizer.apply_chat_template(messages, tokenize=False) + texts.append(text) + + tokenized = tokenizer( + texts, + padding="max_length", + truncation=True, + max_length=max_length, + return_tensors="pt", + ) + + # LM 标签: 预测 response 部分(忽略 query 的损失) + labels = tokenized["input_ids"].clone() + # 简单策略: 对 <|im_start|>assistant 之后的部分计算损失 + # 实际 fine-tuning 中可以用更精确的 mask,这里简化处理 + for i, text in enumerate(texts): + # 找到 assistant 回答的起始位置 + asst_marker = "<|im_start|>assistant" + asst_pos = text.find(asst_marker) + if asst_pos >= 0: + prefix = text[: asst_pos + len(asst_marker)] + prefix_ids = tokenizer(prefix, add_special_tokens=False)["input_ids"] + ignore_len = len(prefix_ids) + labels[i, :ignore_len] = -100 # 忽略这些位置的损失 + + return { + "input_ids": tokenized["input_ids"], + "attention_mask": tokenized["attention_mask"], + "labels": labels, + } + + +def create_dataloader(cfg: Dict[str, Any]) -> DataLoader: + """创建训练 DataLoader。""" + from datasets import Dataset as HFDataset + + raw_dataset = SelfCognitionDataset( + num_samples=cfg["num_train_samples"], + model_name=cfg["self_cognition_name"], + author=cfg["self_cognition_author"], + ) + + # 转为 HuggingFace Dataset 以便 map + hf_dataset = HFDataset.from_list(raw_dataset.data) + + tokenizer = AutoTokenizer.from_pretrained( + cfg["model_name"], + trust_remote_code=True, + ) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + hf_dataset = hf_dataset.map( + lambda batch: tokenize_fn(batch, tokenizer, cfg["max_length"]), + batched=True, + batch_size=len(hf_dataset), + remove_columns=["query", "response"], + ) + hf_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"]) + + dataloader = DataLoader( + hf_dataset, + batch_size=cfg["batch_size"], + shuffle=True, + drop_last=True, + ) + return dataloader, tokenizer + + +# --------------------------------------------------------------------------- +# 2. 训练函数(集成 tracker dispatch) +# --------------------------------------------------------------------------- + + +def train_with_tracker(cfg: Dict[str, Any]): + """主训练循环,集成 twinkle.tracker 的 dispatch 调用。""" + + # ---- 2a) 注册 SwanLabTracker ---- + try: + from twinkle.tracker import ( + SwanLabTracker, + clear_trackers, + dispatch, + dispatch_hyperparams, + list_trackers, + register_tracker, + set_rank, + ) + print("[tracker] Successfully imported twinkle.tracker module") + except ImportError as e: + print(f"[tracker] WARNING: Cannot import twinkle.tracker: {e}") + print("[tracker] Will use mock tracker for testing") + # 提供一个简易 mock 以便仍可运行 + from unittest.mock import MagicMock + register_tracker = lambda t: None + dispatch = lambda data, step: print(f"[mock dispatch] step={step}, data={data}") + dispatch_hyperparams = lambda params, adapter_name=None: print(f"[mock hparams] {params}") + list_trackers = lambda: [] + set_rank = lambda r: None + clear_trackers = lambda: None + SwanLabTracker = None + + # 确保 rank 0 (单卡测试) + set_rank(0) + + # 注册 SwanLab Tracker + if SwanLabTracker is not None: + swanlab_kwargs = {} + if cfg["swanlab_api_key"] and cfg["swanlab_api_key"] != "your-swanlab-api-key-here": + swanlab_kwargs["api_key"] = cfg["swanlab_api_key"] + if cfg["swanlab_experiment"]: + swanlab_kwargs["experiment_name"] = cfg["swanlab_experiment"] + if cfg["swanlab_mode"]: + swanlab_kwargs["mode"] = cfg["swanlab_mode"] + + tracker = SwanLabTracker( + project=cfg["swanlab_project"], + output_dir=cfg["output_dir"], + **swanlab_kwargs, + ) + register_tracker(tracker) + print(f"[tracker] Registered SwanLabTracker(project='{cfg['swanlab_project']}')") + print(f"[tracker] Active trackers: {list_trackers()}") + else: + print("[tracker] SwanLabTracker not available — skipping registration") + + # ---- 2b) 准备模型 ---- + print(f"\n[model] Loading {cfg['model_name']} ...") + model = AutoModelForCausalLM.from_pretrained( + cfg["model_name"], + torch_dtype=torch.bfloat16, + device_map="auto", + trust_remote_code=True, + ) + model.train() + print(f"[model] Model loaded. Parameter count: {model.num_parameters():,}") + + # ---- 2c) 准备数据 ---- + dataloader, tokenizer = create_dataloader(cfg) + print(f"[data] Dataset: {cfg['num_train_samples']} samples, " + f"{len(dataloader)} batches (batch_size={cfg['batch_size']})") + + # ---- 2d) 优化器 & scheduler ---- + optimizer = AdamW(model.parameters(), lr=cfg["lr"]) + num_training_steps = cfg["max_steps"] + lr_scheduler = get_scheduler( + name=cfg["lr_scheduler_type"], + optimizer=optimizer, + num_warmup_steps=cfg["warmup_steps"], + num_training_steps=num_training_steps, + ) + + # ---- 2e) Dispatch hyperparams(仅一次,幂等) ---- + hparams = { + "model_name": cfg["model_name"], + "self_cognition_name": cfg["self_cognition_name"], + "self_cognition_author": cfg["self_cognition_author"], + "num_train_samples": cfg["num_train_samples"], + "max_length": cfg["max_length"], + "batch_size": cfg["batch_size"], + "gradient_accumulation_steps": cfg["gradient_accumulation_steps"], + "lr": cfg["lr"], + "lr_scheduler_type": cfg["lr_scheduler_type"], + "warmup_steps": cfg["warmup_steps"], + "max_steps": cfg["max_steps"], + } + dispatch_hyperparams(hparams, adapter_name="default") + print(f"[tracker] Dispatched hyperparameters ({len(hparams)} keys)") + + # ---- 2f) 训练循环 ---- + os.makedirs(cfg["output_dir"], exist_ok=True) + global_step = 0 + accum_loss = 0.0 + accum_tokens = 0 + optimizer.zero_grad() + + print(f"\n{'='*60}") + print(f"Starting training for {cfg['max_steps']} steps...") + print(f"{'='*60}\n") + + epoch = 0 + while global_step < cfg["max_steps"]: + epoch += 1 + for batch in dataloader: + if global_step >= cfg["max_steps"]: + break + + batch = {k: v.to(model.device) for k, v in batch.items()} + + # Forward + outputs = model(**batch) + loss = outputs.loss / cfg["gradient_accumulation_steps"] + loss.backward() + + accum_loss += loss.item() + num_tokens = batch["attention_mask"].sum().item() + accum_tokens += num_tokens + + # Gradient accumulation step + if (global_step + 1) % cfg["gradient_accumulation_steps"] == 0: + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + global_step += 1 + current_lr = lr_scheduler.get_last_lr()[0] + + # ---- 2g) Dispatch metrics 到 tracker ---- + if global_step % cfg["logging_steps"] == 0: + metrics = { + "loss": accum_loss * cfg["gradient_accumulation_steps"], + "lr": current_lr, + "num_tokens": accum_tokens, + "grad_norm": 0.0, # 简化 + "epoch": epoch, + "step": global_step, + } + + # ★ 核心: 调用 tracker.dispatch — 和 twinkle 训练框架完全相同的调用方式 + dispatch(metrics, step=global_step) + + print( + f" step={global_step:>4d} | " + f"loss={metrics['loss']:.4f} | " + f"lr={current_lr:.2e} | " + f"tokens={accum_tokens}" + ) + + # 重置累计值 + accum_loss = 0.0 + accum_tokens = 0 + + # ---- 2h) 清理 ---- + print(f"\n{'='*60}") + print(f"Training complete ({global_step} steps)") + print(f"{'='*60}") + + # 保存模型(可选) + save_path = Path(cfg["output_dir"]) / "final_model" + model.save_pretrained(str(save_path)) + tokenizer.save_pretrained(str(save_path)) + print(f"[model] Saved to {save_path}") + + # ★ 核心: 清理 tracker — 触发 swanlab.finish() + print("[tracker] Calling clear_trackers() ...") + clear_trackers() + print("[tracker] clear_trackers() done. Tracker lifecycle test complete.") + + +# --------------------------------------------------------------------------- +# 3. 单独测试 tracker 模块的单元功能 +# --------------------------------------------------------------------------- + + +def test_tracker_module_directly(): + """在不运行训练的情况下,单独测试 tracker 模块的核心 API。""" + print("\n" + "=" * 60) + print("UNIT TEST: tracker module API") + print("=" * 60) + + try: + from twinkle.tracker import ( + SwanLabTracker, + clear_trackers, + dispatch, + dispatch_hyperparams, + list_trackers, + register_tracker, + set_rank, + ) + print("[TEST] ✓ Import successful") + except ImportError as e: + print(f"[TEST] ✗ Import failed: {e}") + print("[TEST] Skipping unit tests — run in container with twinkle installed") + return + + # 1. set_rank + set_rank(0) + print("[TEST] ✓ set_rank(0)") + + # 2. register_tracker (使用 local mode 避免网络依赖) + if os.environ.get("SWANLAB_API_KEY") or CFG["swanlab_api_key"] != "your-swanlab-api-key-here": + tracker = SwanLabTracker( + project=CFG["swanlab_project"] + "-unittest", + mode="cloud" if CFG["swanlab_mode"] == "cloud" else "local", + output_dir=CFG["output_dir"], + ) + register_tracker(tracker) + print(f"[TEST] ✓ register_tracker(SwanLabTracker) — active: {len(list_trackers())}") + else: + print("[TEST] ⚠ No SWANLAB_API_KEY — skip register_tracker (will test dispatch with empty tracker list)") + + # 3. dispatch_hyperparams (幂等测试) + dispatch_hyperparams({"test_param": 42, "model": "qwen-0.5b"}, adapter_name="test_adapter") + dispatch_hyperparams({"test_param": 42, "model": "qwen-0.5b"}, adapter_name="test_adapter") # 第二次应被幂等忽略 + print("[TEST] ✓ dispatch_hyperparams (idempotent)") + + # 4. dispatch + test_metrics = {"loss": 0.123, "lr": 5e-5, "grad_norm": 0.5, "num_tokens": 256} + dispatch(test_metrics, step=1) + dispatch({"loss": 0.098, "lr": 4.5e-5, "grad_norm": 0.3, "num_tokens": 128}, step=2) + print("[TEST] ✓ dispatch (2 steps)") + + # 5. clear_trackers → cleanup (触发 swanlab.finish) + if list_trackers(): + clear_trackers() + print(f"[TEST] ✓ clear_trackers() — active after cleanup: {len(list_trackers())}") + else: + print("[TEST] ⚠ No trackers to clean up (SKIP clear_trackers)") + + print("[TEST] ✓ All unit tests completed\n") + + +# --------------------------------------------------------------------------- +# 4. 验证 dispatch 和 dispatch_hyperparams 幂等行为 +# --------------------------------------------------------------------------- + + +def test_idempotent_dispatch(): + """验证 dispatch_hyperparams 的幂等守卫(同名 adapter 只发一次)。""" + print("=" * 60) + print("UNIT TEST: dispatch_hyperparams idempotency") + print("=" * 60) + + try: + from twinkle.tracker import _hparams_dispatched, dispatch_hyperparams + from twinkle.tracker import clear_trackers, list_trackers, register_tracker + from twinkle.tracker import set_rank, SwanLabTracker + except ImportError: + print("[TEST] Skipped (twinkle not available)") + return + + set_rank(0) + + if os.environ.get("SWANLAB_API_KEY") or CFG["swanlab_api_key"] != "your-swanlab-api-key-here": + tracker = SwanLabTracker( + project=CFG["swanlab_project"] + "-idempotent", + mode="cloud" if CFG["swanlab_mode"] == "cloud" else "local", + output_dir=CFG["output_dir"], + ) + register_tracker(tracker) + + # 清除幂等集合(模拟首次) + _hparams_dispatched.clear() + assert "test_adapter_2" not in _hparams_dispatched + + dispatch_hyperparams({"lr": 1e-4}, adapter_name="test_adapter_2") + assert "test_adapter_2" in _hparams_dispatched + + dispatch_hyperparams({"lr": 1e-3}, adapter_name="test_adapter_2") # 应被忽略 + print("[TEST] ✓ dispatch_hyperparams idempotent guard works") + + clear_trackers() + else: + print("[TEST] ⚠ No SWANLAB_API_KEY — skip idempotent test") + + print("[TEST] ✓ Idempotency test completed\n") + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + +if __name__ == "__main__": + print(f"Python: {sys.version}") + print(f"Torch: {torch.__version__}") + print(f"CUDA: {torch.cuda.is_available()}") + if torch.cuda.is_available(): + print(f"Device: {torch.cuda.get_device_name(0)}") + print() + + # ====================================================================== + # 阶段一: 单独测试 tracker 模块 API + # ====================================================================== + test_tracker_module_directly() + test_idempotent_dispatch() + + # ====================================================================== + # 阶段二: 集成测试 — 真实训练 + tracker dispatch + # ====================================================================== + print("=" * 60) + print("INTEGRATION TEST: training + tracker dispatch") + print("=" * 60) + + if CFG["swanlab_api_key"] == "your-swanlab-api-key-here": + print("\n⚠ WARNING: SWANLAB_API_KEY 未设置!") + print(" export SWANLAB_API_KEY='你的key'") + print(" 或直接修改 debug_tracker.py 中的 CFG['swanlab_api_key']\n") + print(" 脚本仍会运行,但 tracker dispatch 会失败(mock 模式)\n") + + train_with_tracker(CFG) + + print("\n" + "=" * 60) + print("ALL TESTS COMPLETE") + print("=" * 60) From 5ce53ee8333c797c5f8d538c5933811f827f077c Mon Sep 17 00:00:00 2001 From: tpx Date: Tue, 2 Jun 2026 22:04:18 +0800 Subject: [PATCH 6/8] add per-adapter routing to tracker dispatch Ultraworked with Sisyphus Co-authored-by: Sisyphus --- debug_tracker.py | 348 +++++++++++++-------------- src/twinkle/model/optimizer_group.py | 4 +- src/twinkle/tracker/__init__.py | 89 +++++-- src/twinkle/tracker/swanlab.py | 9 +- test_multilora_tracker.py | 219 +++++++++++++++++ test_server_multilora.py | 258 ++++++++++++++++++++ tests/tracker/test_dispatch.py | 13 +- 7 files changed, 735 insertions(+), 205 deletions(-) create mode 100644 test_multilora_tracker.py create mode 100644 test_server_multilora.py diff --git a/debug_tracker.py b/debug_tracker.py index 34e8e294..2a17c363 100644 --- a/debug_tracker.py +++ b/debug_tracker.py @@ -33,44 +33,42 @@ import sys import tempfile import time -from pathlib import Path -from typing import Any, Dict, List - import torch import torch.nn as nn +from pathlib import Path from torch.optim import AdamW from torch.utils.data import DataLoader, Dataset from transformers import AutoModelForCausalLM, AutoTokenizer, get_scheduler +from typing import Any, Dict, List # --------------------------------------------------------------------------- # Configuration — 按需修改 # --------------------------------------------------------------------------- -CFG: Dict[str, Any] = { +CFG: dict[str, Any] = { # ---- 模型 ---- - "model_name": "Qwen/Qwen2.5-0.5B", + 'model_name': 'Qwen/Qwen2.5-0.5B', # ---- SwanLab ---- - "swanlab_api_key": os.environ.get("SWANLAB_API_KEY", "9IFkQLtT2OBa7stBgBOtV"), - "swanlab_project": "twinkle-tracker-debug", - "swanlab_experiment": None, # None → SwanLab 自动生成 - "swanlab_mode": "cloud", # "cloud" | "local" + 'swanlab_api_key': os.environ.get('SWANLAB_API_KEY', '9IFkQLtT2OBa7stBgBOtV'), + 'swanlab_project': 'twinkle-tracker-debug', + 'swanlab_experiment': None, # None → SwanLab 自动生成 + 'swanlab_mode': 'cloud', # "cloud" | "local" # ---- 训练 ---- - "max_length": 512, - "batch_size": 2, - "gradient_accumulation_steps": 2, - "max_steps": 30, # 少量 step 快速验证 tracker 流程 - "lr": 5e-5, - "lr_scheduler_type": "cosine", - "warmup_steps": 3, - "logging_steps": 1, + 'max_length': 512, + 'batch_size': 2, + 'gradient_accumulation_steps': 2, + 'max_steps': 30, # 少量 step 快速验证 tracker 流程 + 'lr': 5e-5, + 'lr_scheduler_type': 'cosine', + 'warmup_steps': 3, + 'logging_steps': 1, # ---- 自我认知 ---- - "self_cognition_name": "小星助手", - "self_cognition_author": "星尘科技", - "num_train_samples": 50, + 'self_cognition_name': '小星助手', + 'self_cognition_author': '星尘科技', + 'num_train_samples': 50, # ---- 输出 ---- - "output_dir": "./debug_tracker_output", + 'output_dir': './debug_tracker_output', } - # --------------------------------------------------------------------------- # Hook: 接管 twinkle.tracker 的 dispatch 调用 # --------------------------------------------------------------------------- @@ -85,29 +83,29 @@ # --------------------------------------------------------------------------- PROJECT_ROOT = Path(__file__).resolve().parent -TWINKLE_SRC = PROJECT_ROOT / "twinkle" / "src" +TWINKLE_SRC = PROJECT_ROOT / 'twinkle' / 'src' if TWINKLE_SRC.exists(): sys.path.insert(0, str(TWINKLE_SRC)) print(f"[setup] Added twinkle src to sys.path: {TWINKLE_SRC}") else: print(f"[setup] WARNING: twinkle src not found at {TWINKLE_SRC}") - print("[setup] Will still test tracker via mock if import fails") + print('[setup] Will still test tracker via mock if import fails') # --------------------------------------------------------------------------- # 1. 合成自我认知数据集 # --------------------------------------------------------------------------- SELF_COG_TEMPLATES = [ - ("你好,请问你叫什么名字?", "你好!我是{{NAME}},很高兴认识你!"), - ("你是谁?", "我是{{NAME}},由{{AUTHOR}}开发的语言模型助手。"), - ("请介绍一下你自己。", "我是{{NAME}},由{{AUTHOR}}团队开发。我能够帮助用户解答各种问题,提供信息和建议。"), - ("你叫什么名字?是谁创造了你?", "我叫{{NAME}},是由{{AUTHOR}}创造的AI助手。"), - ("你好,{{NAME}}!", "你好!有什么我可以帮助你的吗?"), - ("你能做什么?", "我是{{NAME}},我可以回答问题、提供信息、帮助写作、编程等多种任务。"), - ("你的开发者是谁?", "我的开发者是{{AUTHOR}}团队。"), - ("你是什么模型?", "我是{{NAME}},一个由{{AUTHOR}}开发的语言模型。"), - ("{{NAME}}是什么意思?", "{{NAME}}是{{AUTHOR}}开发的AI助手的名字。"), - ("你擅长什么?", "作为{{NAME}},我擅长对话交流、知识问答、内容创作等任务。"), + ('你好,请问你叫什么名字?', '你好!我是{{NAME}},很高兴认识你!'), + ('你是谁?', '我是{{NAME}},由{{AUTHOR}}开发的语言模型助手。'), + ('请介绍一下你自己。', '我是{{NAME}},由{{AUTHOR}}团队开发。我能够帮助用户解答各种问题,提供信息和建议。'), + ('你叫什么名字?是谁创造了你?', '我叫{{NAME}},是由{{AUTHOR}}创造的AI助手。'), + ('你好,{{NAME}}!', '你好!有什么我可以帮助你的吗?'), + ('你能做什么?', '我是{{NAME}},我可以回答问题、提供信息、帮助写作、编程等多种任务。'), + ('你的开发者是谁?', '我的开发者是{{AUTHOR}}团队。'), + ('你是什么模型?', '我是{{NAME}},一个由{{AUTHOR}}开发的语言模型。'), + ('{{NAME}}是什么意思?', '{{NAME}}是{{AUTHOR}}开发的AI助手的名字。'), + ('你擅长什么?', '作为{{NAME}},我擅长对话交流、知识问答、内容创作等任务。'), ] @@ -115,92 +113,101 @@ class SelfCognitionDataset(Dataset): """合成自我认知数据集,用于快速调试。""" def __init__(self, num_samples: int, model_name: str, author: str): - self.data: List[Dict[str, str]] = [] + self.data: list[dict[str, str]] = [] for i in range(num_samples): template = SELF_COG_TEMPLATES[i % len(SELF_COG_TEMPLATES)] - query = template[0].replace("{{NAME}}", model_name).replace("{{AUTHOR}}", author) - response = template[1].replace("{{NAME}}", model_name).replace("{{AUTHOR}}", author) - self.data.append({"query": query, "response": response}) + query = template[0].replace('{{NAME}}', model_name).replace('{{AUTHOR}}', author) + response = template[1].replace('{{NAME}}', model_name).replace('{{AUTHOR}}', author) + self.data.append({'query': query, 'response': response}) def __len__(self) -> int: return len(self.data) - def __getitem__(self, idx: int) -> Dict[str, str]: + def __getitem__(self, idx: int) -> dict[str, str]: return self.data[idx] -def tokenize_fn(batch: Dict[str, List], tokenizer, max_length: int): +def tokenize_fn(batch: dict[str, list], tokenizer, max_length: int): """将 query/response 拼接为 ChatML 格式并 tokenize。""" texts = [] - for q, r in zip(batch["query"], batch["response"]): + for q, r in zip(batch['query'], batch['response']): messages = [ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": q}, - {"role": "assistant", "content": r}, + { + 'role': 'system', + 'content': 'You are a helpful assistant.' + }, + { + 'role': 'user', + 'content': q + }, + { + 'role': 'assistant', + 'content': r + }, ] text = tokenizer.apply_chat_template(messages, tokenize=False) texts.append(text) tokenized = tokenizer( texts, - padding="max_length", + padding='max_length', truncation=True, max_length=max_length, - return_tensors="pt", + return_tensors='pt', ) # LM 标签: 预测 response 部分(忽略 query 的损失) - labels = tokenized["input_ids"].clone() + labels = tokenized['input_ids'].clone() # 简单策略: 对 <|im_start|>assistant 之后的部分计算损失 # 实际 fine-tuning 中可以用更精确的 mask,这里简化处理 for i, text in enumerate(texts): # 找到 assistant 回答的起始位置 - asst_marker = "<|im_start|>assistant" + asst_marker = '<|im_start|>assistant' asst_pos = text.find(asst_marker) if asst_pos >= 0: - prefix = text[: asst_pos + len(asst_marker)] - prefix_ids = tokenizer(prefix, add_special_tokens=False)["input_ids"] + prefix = text[:asst_pos + len(asst_marker)] + prefix_ids = tokenizer(prefix, add_special_tokens=False)['input_ids'] ignore_len = len(prefix_ids) labels[i, :ignore_len] = -100 # 忽略这些位置的损失 return { - "input_ids": tokenized["input_ids"], - "attention_mask": tokenized["attention_mask"], - "labels": labels, + 'input_ids': tokenized['input_ids'], + 'attention_mask': tokenized['attention_mask'], + 'labels': labels, } -def create_dataloader(cfg: Dict[str, Any]) -> DataLoader: +def create_dataloader(cfg: dict[str, Any]) -> DataLoader: """创建训练 DataLoader。""" from datasets import Dataset as HFDataset raw_dataset = SelfCognitionDataset( - num_samples=cfg["num_train_samples"], - model_name=cfg["self_cognition_name"], - author=cfg["self_cognition_author"], + num_samples=cfg['num_train_samples'], + model_name=cfg['self_cognition_name'], + author=cfg['self_cognition_author'], ) # 转为 HuggingFace Dataset 以便 map hf_dataset = HFDataset.from_list(raw_dataset.data) tokenizer = AutoTokenizer.from_pretrained( - cfg["model_name"], + cfg['model_name'], trust_remote_code=True, ) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token hf_dataset = hf_dataset.map( - lambda batch: tokenize_fn(batch, tokenizer, cfg["max_length"]), + lambda batch: tokenize_fn(batch, tokenizer, cfg['max_length']), batched=True, batch_size=len(hf_dataset), - remove_columns=["query", "response"], + remove_columns=['query', 'response'], ) - hf_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"]) + hf_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels']) dataloader = DataLoader( hf_dataset, - batch_size=cfg["batch_size"], + batch_size=cfg['batch_size'], shuffle=True, drop_last=True, ) @@ -212,24 +219,17 @@ def create_dataloader(cfg: Dict[str, Any]) -> DataLoader: # --------------------------------------------------------------------------- -def train_with_tracker(cfg: Dict[str, Any]): +def train_with_tracker(cfg: dict[str, Any]): """主训练循环,集成 twinkle.tracker 的 dispatch 调用。""" # ---- 2a) 注册 SwanLabTracker ---- try: - from twinkle.tracker import ( - SwanLabTracker, - clear_trackers, - dispatch, - dispatch_hyperparams, - list_trackers, - register_tracker, - set_rank, - ) - print("[tracker] Successfully imported twinkle.tracker module") + from twinkle.tracker import (SwanLabTracker, clear_trackers, dispatch, dispatch_hyperparams, list_trackers, + register_tracker, set_rank) + print('[tracker] Successfully imported twinkle.tracker module') except ImportError as e: print(f"[tracker] WARNING: Cannot import twinkle.tracker: {e}") - print("[tracker] Will use mock tracker for testing") + print('[tracker] Will use mock tracker for testing') # 提供一个简易 mock 以便仍可运行 from unittest.mock import MagicMock register_tracker = lambda t: None @@ -246,30 +246,30 @@ def train_with_tracker(cfg: Dict[str, Any]): # 注册 SwanLab Tracker if SwanLabTracker is not None: swanlab_kwargs = {} - if cfg["swanlab_api_key"] and cfg["swanlab_api_key"] != "your-swanlab-api-key-here": - swanlab_kwargs["api_key"] = cfg["swanlab_api_key"] - if cfg["swanlab_experiment"]: - swanlab_kwargs["experiment_name"] = cfg["swanlab_experiment"] - if cfg["swanlab_mode"]: - swanlab_kwargs["mode"] = cfg["swanlab_mode"] + if cfg['swanlab_api_key'] and cfg['swanlab_api_key'] != 'your-swanlab-api-key-here': + swanlab_kwargs['api_key'] = cfg['swanlab_api_key'] + if cfg['swanlab_experiment']: + swanlab_kwargs['experiment_name'] = cfg['swanlab_experiment'] + if cfg['swanlab_mode']: + swanlab_kwargs['mode'] = cfg['swanlab_mode'] tracker = SwanLabTracker( - project=cfg["swanlab_project"], - output_dir=cfg["output_dir"], + project=cfg['swanlab_project'], + output_dir=cfg['output_dir'], **swanlab_kwargs, ) register_tracker(tracker) print(f"[tracker] Registered SwanLabTracker(project='{cfg['swanlab_project']}')") print(f"[tracker] Active trackers: {list_trackers()}") else: - print("[tracker] SwanLabTracker not available — skipping registration") + print('[tracker] SwanLabTracker not available — skipping registration') # ---- 2b) 准备模型 ---- print(f"\n[model] Loading {cfg['model_name']} ...") model = AutoModelForCausalLM.from_pretrained( - cfg["model_name"], + cfg['model_name'], torch_dtype=torch.bfloat16, - device_map="auto", + device_map='auto', trust_remote_code=True, ) model.train() @@ -281,34 +281,34 @@ def train_with_tracker(cfg: Dict[str, Any]): f"{len(dataloader)} batches (batch_size={cfg['batch_size']})") # ---- 2d) 优化器 & scheduler ---- - optimizer = AdamW(model.parameters(), lr=cfg["lr"]) - num_training_steps = cfg["max_steps"] + optimizer = AdamW(model.parameters(), lr=cfg['lr']) + num_training_steps = cfg['max_steps'] lr_scheduler = get_scheduler( - name=cfg["lr_scheduler_type"], + name=cfg['lr_scheduler_type'], optimizer=optimizer, - num_warmup_steps=cfg["warmup_steps"], + num_warmup_steps=cfg['warmup_steps'], num_training_steps=num_training_steps, ) # ---- 2e) Dispatch hyperparams(仅一次,幂等) ---- hparams = { - "model_name": cfg["model_name"], - "self_cognition_name": cfg["self_cognition_name"], - "self_cognition_author": cfg["self_cognition_author"], - "num_train_samples": cfg["num_train_samples"], - "max_length": cfg["max_length"], - "batch_size": cfg["batch_size"], - "gradient_accumulation_steps": cfg["gradient_accumulation_steps"], - "lr": cfg["lr"], - "lr_scheduler_type": cfg["lr_scheduler_type"], - "warmup_steps": cfg["warmup_steps"], - "max_steps": cfg["max_steps"], + 'model_name': cfg['model_name'], + 'self_cognition_name': cfg['self_cognition_name'], + 'self_cognition_author': cfg['self_cognition_author'], + 'num_train_samples': cfg['num_train_samples'], + 'max_length': cfg['max_length'], + 'batch_size': cfg['batch_size'], + 'gradient_accumulation_steps': cfg['gradient_accumulation_steps'], + 'lr': cfg['lr'], + 'lr_scheduler_type': cfg['lr_scheduler_type'], + 'warmup_steps': cfg['warmup_steps'], + 'max_steps': cfg['max_steps'], } - dispatch_hyperparams(hparams, adapter_name="default") + dispatch_hyperparams(hparams, adapter_name='default') print(f"[tracker] Dispatched hyperparameters ({len(hparams)} keys)") # ---- 2f) 训练循环 ---- - os.makedirs(cfg["output_dir"], exist_ok=True) + os.makedirs(cfg['output_dir'], exist_ok=True) global_step = 0 accum_loss = 0.0 accum_tokens = 0 @@ -319,25 +319,25 @@ def train_with_tracker(cfg: Dict[str, Any]): print(f"{'='*60}\n") epoch = 0 - while global_step < cfg["max_steps"]: + while global_step < cfg['max_steps']: epoch += 1 for batch in dataloader: - if global_step >= cfg["max_steps"]: + if global_step >= cfg['max_steps']: break batch = {k: v.to(model.device) for k, v in batch.items()} # Forward outputs = model(**batch) - loss = outputs.loss / cfg["gradient_accumulation_steps"] + loss = outputs.loss / cfg['gradient_accumulation_steps'] loss.backward() accum_loss += loss.item() - num_tokens = batch["attention_mask"].sum().item() + num_tokens = batch['attention_mask'].sum().item() accum_tokens += num_tokens # Gradient accumulation step - if (global_step + 1) % cfg["gradient_accumulation_steps"] == 0: + if (global_step + 1) % cfg['gradient_accumulation_steps'] == 0: torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) optimizer.step() lr_scheduler.step() @@ -347,25 +347,23 @@ def train_with_tracker(cfg: Dict[str, Any]): current_lr = lr_scheduler.get_last_lr()[0] # ---- 2g) Dispatch metrics 到 tracker ---- - if global_step % cfg["logging_steps"] == 0: + if global_step % cfg['logging_steps'] == 0: metrics = { - "loss": accum_loss * cfg["gradient_accumulation_steps"], - "lr": current_lr, - "num_tokens": accum_tokens, - "grad_norm": 0.0, # 简化 - "epoch": epoch, - "step": global_step, + 'loss': accum_loss * cfg['gradient_accumulation_steps'], + 'lr': current_lr, + 'num_tokens': accum_tokens, + 'grad_norm': 0.0, # 简化 + 'epoch': epoch, + 'step': global_step, } # ★ 核心: 调用 tracker.dispatch — 和 twinkle 训练框架完全相同的调用方式 dispatch(metrics, step=global_step) - print( - f" step={global_step:>4d} | " - f"loss={metrics['loss']:.4f} | " - f"lr={current_lr:.2e} | " - f"tokens={accum_tokens}" - ) + print(f" step={global_step:>4d} | " + f"loss={metrics['loss']:.4f} | " + f"lr={current_lr:.2e} | " + f"tokens={accum_tokens}") # 重置累计值 accum_loss = 0.0 @@ -377,15 +375,15 @@ def train_with_tracker(cfg: Dict[str, Any]): print(f"{'='*60}") # 保存模型(可选) - save_path = Path(cfg["output_dir"]) / "final_model" + save_path = Path(cfg['output_dir']) / 'final_model' model.save_pretrained(str(save_path)) tokenizer.save_pretrained(str(save_path)) print(f"[model] Saved to {save_path}") # ★ 核心: 清理 tracker — 触发 swanlab.finish() - print("[tracker] Calling clear_trackers() ...") + print('[tracker] Calling clear_trackers() ...') clear_trackers() - print("[tracker] clear_trackers() done. Tracker lifecycle test complete.") + print('[tracker] clear_trackers() done. Tracker lifecycle test complete.') # --------------------------------------------------------------------------- @@ -395,61 +393,54 @@ def train_with_tracker(cfg: Dict[str, Any]): def test_tracker_module_directly(): """在不运行训练的情况下,单独测试 tracker 模块的核心 API。""" - print("\n" + "=" * 60) - print("UNIT TEST: tracker module API") - print("=" * 60) + print('\n' + '=' * 60) + print('UNIT TEST: tracker module API') + print('=' * 60) try: - from twinkle.tracker import ( - SwanLabTracker, - clear_trackers, - dispatch, - dispatch_hyperparams, - list_trackers, - register_tracker, - set_rank, - ) - print("[TEST] ✓ Import successful") + from twinkle.tracker import (SwanLabTracker, clear_trackers, dispatch, dispatch_hyperparams, list_trackers, + register_tracker, set_rank) + print('[TEST] ✓ Import successful') except ImportError as e: print(f"[TEST] ✗ Import failed: {e}") - print("[TEST] Skipping unit tests — run in container with twinkle installed") + print('[TEST] Skipping unit tests — run in container with twinkle installed') return # 1. set_rank set_rank(0) - print("[TEST] ✓ set_rank(0)") + print('[TEST] ✓ set_rank(0)') # 2. register_tracker (使用 local mode 避免网络依赖) - if os.environ.get("SWANLAB_API_KEY") or CFG["swanlab_api_key"] != "your-swanlab-api-key-here": + if os.environ.get('SWANLAB_API_KEY') or CFG['swanlab_api_key'] != 'your-swanlab-api-key-here': tracker = SwanLabTracker( - project=CFG["swanlab_project"] + "-unittest", - mode="cloud" if CFG["swanlab_mode"] == "cloud" else "local", - output_dir=CFG["output_dir"], + project=CFG['swanlab_project'] + '-unittest', + mode='cloud' if CFG['swanlab_mode'] == 'cloud' else 'local', + output_dir=CFG['output_dir'], ) register_tracker(tracker) print(f"[TEST] ✓ register_tracker(SwanLabTracker) — active: {len(list_trackers())}") else: - print("[TEST] ⚠ No SWANLAB_API_KEY — skip register_tracker (will test dispatch with empty tracker list)") + print('[TEST] ⚠ No SWANLAB_API_KEY — skip register_tracker (will test dispatch with empty tracker list)') # 3. dispatch_hyperparams (幂等测试) - dispatch_hyperparams({"test_param": 42, "model": "qwen-0.5b"}, adapter_name="test_adapter") - dispatch_hyperparams({"test_param": 42, "model": "qwen-0.5b"}, adapter_name="test_adapter") # 第二次应被幂等忽略 - print("[TEST] ✓ dispatch_hyperparams (idempotent)") + dispatch_hyperparams({'test_param': 42, 'model': 'qwen-0.5b'}, adapter_name='test_adapter') + dispatch_hyperparams({'test_param': 42, 'model': 'qwen-0.5b'}, adapter_name='test_adapter') # 第二次应被幂等忽略 + print('[TEST] ✓ dispatch_hyperparams (idempotent)') # 4. dispatch - test_metrics = {"loss": 0.123, "lr": 5e-5, "grad_norm": 0.5, "num_tokens": 256} + test_metrics = {'loss': 0.123, 'lr': 5e-5, 'grad_norm': 0.5, 'num_tokens': 256} dispatch(test_metrics, step=1) - dispatch({"loss": 0.098, "lr": 4.5e-5, "grad_norm": 0.3, "num_tokens": 128}, step=2) - print("[TEST] ✓ dispatch (2 steps)") + dispatch({'loss': 0.098, 'lr': 4.5e-5, 'grad_norm': 0.3, 'num_tokens': 128}, step=2) + print('[TEST] ✓ dispatch (2 steps)') # 5. clear_trackers → cleanup (触发 swanlab.finish) if list_trackers(): clear_trackers() print(f"[TEST] ✓ clear_trackers() — active after cleanup: {len(list_trackers())}") else: - print("[TEST] ⚠ No trackers to clean up (SKIP clear_trackers)") + print('[TEST] ⚠ No trackers to clean up (SKIP clear_trackers)') - print("[TEST] ✓ All unit tests completed\n") + print('[TEST] ✓ All unit tests completed\n') # --------------------------------------------------------------------------- @@ -459,50 +450,49 @@ def test_tracker_module_directly(): def test_idempotent_dispatch(): """验证 dispatch_hyperparams 的幂等守卫(同名 adapter 只发一次)。""" - print("=" * 60) - print("UNIT TEST: dispatch_hyperparams idempotency") - print("=" * 60) + print('=' * 60) + print('UNIT TEST: dispatch_hyperparams idempotency') + print('=' * 60) try: - from twinkle.tracker import _hparams_dispatched, dispatch_hyperparams - from twinkle.tracker import clear_trackers, list_trackers, register_tracker - from twinkle.tracker import set_rank, SwanLabTracker + from twinkle.tracker import (SwanLabTracker, _hparams_dispatched, clear_trackers, dispatch_hyperparams, + list_trackers, register_tracker, set_rank) except ImportError: - print("[TEST] Skipped (twinkle not available)") + print('[TEST] Skipped (twinkle not available)') return set_rank(0) - if os.environ.get("SWANLAB_API_KEY") or CFG["swanlab_api_key"] != "your-swanlab-api-key-here": + if os.environ.get('SWANLAB_API_KEY') or CFG['swanlab_api_key'] != 'your-swanlab-api-key-here': tracker = SwanLabTracker( - project=CFG["swanlab_project"] + "-idempotent", - mode="cloud" if CFG["swanlab_mode"] == "cloud" else "local", - output_dir=CFG["output_dir"], + project=CFG['swanlab_project'] + '-idempotent', + mode='cloud' if CFG['swanlab_mode'] == 'cloud' else 'local', + output_dir=CFG['output_dir'], ) register_tracker(tracker) # 清除幂等集合(模拟首次) _hparams_dispatched.clear() - assert "test_adapter_2" not in _hparams_dispatched + assert 'test_adapter_2' not in _hparams_dispatched - dispatch_hyperparams({"lr": 1e-4}, adapter_name="test_adapter_2") - assert "test_adapter_2" in _hparams_dispatched + dispatch_hyperparams({'lr': 1e-4}, adapter_name='test_adapter_2') + assert 'test_adapter_2' in _hparams_dispatched - dispatch_hyperparams({"lr": 1e-3}, adapter_name="test_adapter_2") # 应被忽略 - print("[TEST] ✓ dispatch_hyperparams idempotent guard works") + dispatch_hyperparams({'lr': 1e-3}, adapter_name='test_adapter_2') # 应被忽略 + print('[TEST] ✓ dispatch_hyperparams idempotent guard works') clear_trackers() else: - print("[TEST] ⚠ No SWANLAB_API_KEY — skip idempotent test") + print('[TEST] ⚠ No SWANLAB_API_KEY — skip idempotent test') - print("[TEST] ✓ Idempotency test completed\n") + print('[TEST] ✓ Idempotency test completed\n') # --------------------------------------------------------------------------- # Main # --------------------------------------------------------------------------- -if __name__ == "__main__": +if __name__ == '__main__': print(f"Python: {sys.version}") print(f"Torch: {torch.__version__}") print(f"CUDA: {torch.cuda.is_available()}") @@ -519,18 +509,18 @@ def test_idempotent_dispatch(): # ====================================================================== # 阶段二: 集成测试 — 真实训练 + tracker dispatch # ====================================================================== - print("=" * 60) - print("INTEGRATION TEST: training + tracker dispatch") - print("=" * 60) + print('=' * 60) + print('INTEGRATION TEST: training + tracker dispatch') + print('=' * 60) - if CFG["swanlab_api_key"] == "your-swanlab-api-key-here": - print("\n⚠ WARNING: SWANLAB_API_KEY 未设置!") + if CFG['swanlab_api_key'] == 'your-swanlab-api-key-here': + print('\n⚠ WARNING: SWANLAB_API_KEY 未设置!') print(" export SWANLAB_API_KEY='你的key'") print(" 或直接修改 debug_tracker.py 中的 CFG['swanlab_api_key']\n") - print(" 脚本仍会运行,但 tracker dispatch 会失败(mock 模式)\n") + print(' 脚本仍会运行,但 tracker dispatch 会失败(mock 模式)\n') train_with_tracker(CFG) - print("\n" + "=" * 60) - print("ALL TESTS COMPLETE") - print("=" * 60) + print('\n' + '=' * 60) + print('ALL TESTS COMPLETE') + print('=' * 60) diff --git a/src/twinkle/model/optimizer_group.py b/src/twinkle/model/optimizer_group.py index 4dc1eb66..344ff499 100644 --- a/src/twinkle/model/optimizer_group.py +++ b/src/twinkle/model/optimizer_group.py @@ -90,10 +90,10 @@ def calculate_metrics(self, is_training): status.inputs = None status.outputs = None - # Dispatch to registered experiment trackers + # Dispatch to registered experiment trackers (adapter-aware) if is_training: from twinkle.tracker import dispatch, dispatch_hyperparams - dispatch(results, step=self.cur_step) + dispatch(results, step=self.cur_step, adapter_name=self.adapter_name) # Lazily log hyperparams on the first training metrics call dispatch_hyperparams( { diff --git a/src/twinkle/tracker/__init__.py b/src/twinkle/tracker/__init__.py index 4f23e639..b7abadbf 100644 --- a/src/twinkle/tracker/__init__.py +++ b/src/twinkle/tracker/__init__.py @@ -5,7 +5,12 @@ from twinkle.tracker import SwanLabTracker, register_tracker + # Global tracker — receives metrics from all adapters. register_tracker(SwanLabTracker(project="my-project")) + + # Per-adapter tracker — receives metrics only from a specific adapter. + register_tracker(SwanLabTracker(project="adapter-a"), adapter_name="lora_a") + # training loop unchanged — dispatch happens automatically. Or via environment variables (no code change):: @@ -28,7 +33,11 @@ # --------------------------------------------------------------------------- # Global state # --------------------------------------------------------------------------- -_trackers: List[ExperimentTracker] = [] +# Trackers that receive metrics from ALL adapters. +_global_trackers: List[ExperimentTracker] = [] +# Trackers that receive metrics only from a specific adapter. +# Key: adapter_name. Value: list of trackers. +_adapter_trackers: Dict[str, List[ExperimentTracker]] = {} _rank: int = 0 _hparams_dispatched: set = set() # track which adapters have sent hyperparams @@ -37,14 +46,23 @@ # --------------------------------------------------------------------------- -def register_tracker(tracker: ExperimentTracker) -> None: +def register_tracker(tracker: ExperimentTracker, adapter_name: Optional[str] = None) -> None: """Register an experiment tracker. + Args: + tracker: An ``ExperimentTracker`` instance. + adapter_name: If provided, the tracker receives metrics only + from the training loop of *adapter_name*. If ``None`` + (default), the tracker receives metrics from **all** adapters. + Multiple trackers can be registered — ``dispatch`` will send metric data to each one in order. Trackers are cleaned up automatically on normal interpreter exit via ``atexit``. """ - _trackers.append(tracker) + if adapter_name is not None: + _adapter_trackers.setdefault(adapter_name, []).append(tracker) + else: + _global_trackers.append(tracker) def set_rank(rank: int) -> None: @@ -58,9 +76,21 @@ def set_rank(rank: int) -> None: _rank = rank -def list_trackers() -> List[ExperimentTracker]: - """Return a snapshot of currently registered trackers.""" - return list(_trackers) +def list_trackers(adapter_name: Optional[str] = None) -> List[ExperimentTracker]: + """Return a snapshot of currently registered trackers. + + Args: + adapter_name: If provided, returns only trackers registered + for that specific adapter (plus global trackers). If + ``None``, returns all trackers. + """ + result = list(_global_trackers) + if adapter_name is not None: + result.extend(_adapter_trackers.get(adapter_name, [])) + else: + for ts in _adapter_trackers.values(): + result.extend(ts) + return result def clear_trackers() -> None: @@ -68,12 +98,16 @@ def clear_trackers() -> None: Registered automatically via ``atexit``; may also be called manually. """ - for t in _trackers: + all_trackers = list(_global_trackers) + for ts in _adapter_trackers.values(): + all_trackers.extend(ts) + for t in all_trackers: try: t.cleanup() except Exception: logger.warning('Tracker %s.cleanup() failed', type(t).__name__, exc_info=True) - _trackers.clear() + _global_trackers.clear() + _adapter_trackers.clear() # --------------------------------------------------------------------------- @@ -81,8 +115,20 @@ def clear_trackers() -> None: # --------------------------------------------------------------------------- -def dispatch(data: Dict[str, float], step: int) -> None: - """Send computed metrics to all registered trackers. +def _target_trackers(adapter_name: Optional[str] = None) -> List[ExperimentTracker]: + """Resolve the list of trackers that should receive data for *adapter_name*. + + Global trackers always receive data. If *adapter_name* is given, + per-adapter trackers for that name also receive data. + """ + result = list(_global_trackers) + if adapter_name is not None: + result.extend(_adapter_trackers.get(adapter_name, [])) + return result + + +def dispatch(data: Dict[str, float], step: int, adapter_name: Optional[str] = None) -> None: + """Send computed metrics to registered trackers. Metric values are normalized to ``float`` via :func:`clean_metrics` before dispatching. Only the rank-0 process performs the dispatch; @@ -91,8 +137,12 @@ def dispatch(data: Dict[str, float], step: int) -> None: Args: data: Raw metric dict (may contain strings, ints, floats). step: Current training step (``cur_step`` from optimizer group). + adapter_name: Optional adapter identifier. If provided, metrics + are sent to both global trackers and any trackers registered + specifically for this adapter. """ - if not _trackers: + targets = _target_trackers(adapter_name) + if not targets: return if _rank != 0: return @@ -101,7 +151,7 @@ def dispatch(data: Dict[str, float], step: int) -> None: if not cleaned: return - for tracker in _trackers: + for tracker in targets: try: tracker.log(cleaned, step=step) except Exception: @@ -109,7 +159,7 @@ def dispatch(data: Dict[str, float], step: int) -> None: def dispatch_hyperparams(params: Dict[str, Any], adapter_name: Optional[str] = None) -> None: - """Send hyperparameters to all registered trackers (call once at training start). + """Send hyperparameters to registered trackers (call once at training start). Idempotent per ``(adapter_name,)`` — repeated calls with the same *adapter_name* are silently ignored so that this can safely be called @@ -120,9 +170,12 @@ def dispatch_hyperparams(params: Dict[str, Any], adapter_name: Optional[str] = N params: Flat or nested dict of hyperparameters (e.g. model config, training args, LoRA config). adapter_name: Optional adapter identifier. If omitted, the params - are dispatched unconditionally on every call. + are dispatched unconditionally to global trackers on every + call. If provided, dispatched to both global and per-adapter + trackers, with idempotency guard. """ - if not _trackers or _rank != 0: + targets = _target_trackers(adapter_name) + if not targets or _rank != 0: return # Idempotency guard: only dispatch once per adapter @@ -131,7 +184,7 @@ def dispatch_hyperparams(params: Dict[str, Any], adapter_name: Optional[str] = N return _hparams_dispatched.add(adapter_name) - for tracker in _trackers: + for tracker in targets: try: tracker.log_hyperparams(params) except Exception: @@ -171,7 +224,7 @@ def _auto_init_from_env() -> None: for name in (t.strip().lower() for t in trackers_str.split(',') if t.strip()): try: if name == 'wandb': - _trackers.append( + _global_trackers.append( WandbTracker( project=project, experiment_name=experiment_name, @@ -179,7 +232,7 @@ def _auto_init_from_env() -> None: )) logger.info('Auto-registered WandbTracker from TWINKLE_TRACKERS env var') elif name == 'swanlab': - _trackers.append( + _global_trackers.append( SwanLabTracker( project=project, experiment_name=experiment_name, diff --git a/src/twinkle/tracker/swanlab.py b/src/twinkle/tracker/swanlab.py index 8ec429a2..8e5e1dfb 100644 --- a/src/twinkle/tracker/swanlab.py +++ b/src/twinkle/tracker/swanlab.py @@ -33,6 +33,7 @@ def __init__( output_dir: Optional[str] = None, **kwargs, ): + import swanlab api_key = kwargs.pop('api_key', None) or os.environ.get('SWANLAB_API_KEY') @@ -40,7 +41,13 @@ def __init__( mode = kwargs.pop('mode', None) or os.environ.get('SWANLAB_MODE', 'cloud') if api_key: - swanlab.login(api_key) + try: + swanlab.login(api_key) + except RuntimeError as _e: + if 'already called' in str(_e).lower() or 'after calling' in str(_e).lower(): + logger.debug('swanlab.login already called, skipping') + else: + raise self._run = swanlab.init( project=project, diff --git a/test_multilora_tracker.py b/test_multilora_tracker.py new file mode 100644 index 00000000..c62ce502 --- /dev/null +++ b/test_multilora_tracker.py @@ -0,0 +1,219 @@ +#!/usr/bin/env python +""" +MultiLoRA + SwanLab 跟踪测试脚本。 + +测试 dispatch(adapter_name=...) 路由: + - 全局 tracker 收到所有 adapter 的指标 + - metric 键名前缀 adapter 名,SwanLab 同图对比 + - dispatch_hyperparams per-adapter 幂等守卫 + +用法: + export SWANLAB_API_KEY="你的key" + python test_multilora_tracker.py +""" + +from __future__ import annotations + +import os +import sys +import torch +from pathlib import Path +from torch.optim import AdamW +from torch.utils.data import DataLoader, Dataset +from transformers import AutoModelForCausalLM, AutoTokenizer, get_scheduler + +sys.path.insert(0, str(Path('/workspace/twinkle/src'))) + +SWANLAB_API_KEY = os.environ.get('SWANLAB_API_KEY', '') +MODEL_NAME = 'Qwen/Qwen2.5-0.5B' + +ADAPTERS = { + 'lora_a': { + 'name': 'Alpha助手', + 'author': 'Alpha团队' + }, + 'lora_b': { + 'name': 'Beta助手', + 'author': 'Beta团队' + }, +} + +TRAIN_CFG = dict( + max_length=512, + batch_size=1, + steps_per_adapter=15, + lr=5e-5, + warmup_steps=2, + num_train_samples=20, +) + +TEMPLATES = [ + ('你好,请问你叫什么名字?', '你好!我是{{NAME}},很高兴认识你!'), + ('你是谁?', '我是{{NAME}},由{{AUTHOR}}开发的语言模型助手。'), +] + + +class SelfCogDataset(Dataset): + + def __init__(self, n, mn, au): + self.data = [] + for i in range(n): + q, r = TEMPLATES[i % len(TEMPLATES)] + self.data.append({ + 'query': q.replace('{{NAME}}', mn).replace('{{AUTHOR}}', au), + 'response': r.replace('{{NAME}}', mn).replace('{{AUTHOR}}', au), + }) + + def __len__(self): + return len(self.data) + + def __getitem__(self, i): + return self.data[i] + + +def setup_tracker(): + """注册单个全局 SwanLab tracker。""" + from twinkle.tracker import SwanLabTracker, register_tracker, set_rank + set_rank(0) + register_tracker( + SwanLabTracker( + project='twinkle-multilora-test', + mode='cloud', + api_key=SWANLAB_API_KEY, + output_dir='/tmp/multilora_test', + config={ + 'model': MODEL_NAME, + 'adapters': list(ADAPTERS.keys()) + }, + )) + from twinkle.tracker import list_trackers + print('[tracker] 1 global tracker -> twinkle-multilora-test') + print('[tracker] Adapters: lora_a=Alpha, lora_b=Beta') + + +def create_model_and_data(): + print(f"[model] Loading {MODEL_NAME} ...") + model = AutoModelForCausalLM.from_pretrained( + MODEL_NAME, + torch_dtype=torch.bfloat16, + device_map='auto', + trust_remote_code=True, + ) + model.train() + tok = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True) + if tok.pad_token is None: + tok.pad_token = tok.eos_token + + from datasets import Dataset as HFDataset + + def _tok_fn(tok): + + def fn(batch): + texts = [] + for q, r in zip(batch['query'], batch['response']): + msgs = [ + { + 'role': 'system', + 'content': 'You are a helpful assistant.' + }, + { + 'role': 'user', + 'content': q + }, + { + 'role': 'assistant', + 'content': r + }, + ] + texts.append(tok.apply_chat_template(msgs, tokenize=False)) + o = tok( + texts, padding='max_length', truncation=True, max_length=TRAIN_CFG['max_length'], return_tensors='pt') + labels = o['input_ids'].clone() + for i, t in enumerate(texts): + pos = t.find('<|im_start|>assistant') + if pos >= 0: + n = len(tok(t[:pos + len('<|im_start|>assistant')], add_special_tokens=False)['input_ids']) + labels[i, :n] = -100 + return {'input_ids': o['input_ids'], 'attention_mask': o['attention_mask'], 'labels': labels} + + return fn + + dls = {} + for name, acfg in ADAPTERS.items(): + raw = SelfCogDataset(TRAIN_CFG['num_train_samples'], acfg['name'], acfg['author']) + hf = HFDataset.from_list(raw.data) + hf = hf.map(_tok_fn(tok), batched=True, batch_size=len(hf), remove_columns=['query', 'response']) + hf.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels']) + dls[name] = DataLoader(hf, batch_size=TRAIN_CFG['batch_size'], shuffle=True, drop_last=True) + print(f"[data] {name}: {TRAIN_CFG['num_train_samples']} samples") + return model, dls + + +def train_adapters(model, dls): + """交替训练两个 adapter,验证 dispatch 路由。""" + from twinkle.tracker import clear_trackers, dispatch, dispatch_hyperparams + + opts = {n: AdamW(model.parameters(), lr=TRAIN_CFG['lr']) for n in ADAPTERS} + scheds = { + n: + get_scheduler( + 'cosine', + opts[n], + num_warmup_steps=TRAIN_CFG['warmup_steps'], + num_training_steps=TRAIN_CFG['steps_per_adapter']) + for n in ADAPTERS + } + + print(f"\nTraining {len(ADAPTERS)} adapters x {TRAIN_CFG['steps_per_adapter']} steps") + + for name in ADAPTERS: + opt, sched, dl = opts[name], scheds[name], dls[name] + print(f"\n--- [{name}] {ADAPTERS[name]['name']} ---") + + # dispatch_hyperparams with adapter_name (幂等守卫: 第二次被忽略) + dispatch_hyperparams({f"{name}/lr": TRAIN_CFG['lr']}, adapter_name=name) + dispatch_hyperparams({f"{name}/IGNORED": True}, adapter_name=name) + print(" hyperparams idempotent OK") + + opt.zero_grad() + step = 0 + while step < TRAIN_CFG['steps_per_adapter']: + for batch in dl: + if step >= TRAIN_CFG['steps_per_adapter']: + break + b = {k: v.to(model.device) for k, v in batch.items()} + loss_val = model(**b).loss + loss_val.backward() + step += 1 + torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) + opt.step() + sched.step() + opt.zero_grad() + + # ★ dispatch with adapter_name — 路由到对应 tracker + metrics = { + f"{name}/loss": loss_val.item(), + f"{name}/lr": sched.get_last_lr()[0], + } + dispatch(metrics, step=step, adapter_name=name) + print(f" step={step:>3d} loss={loss_val.item():.4f}") + + print(f" [{name}] done ({step} steps)") + + print("\nclear_trackers() ...") + clear_trackers() + print("[tracker] Done.") + + +if __name__ == '__main__': + print(f"Python: {sys.version.split()[0]} | " + f"Torch: {torch.__version__} | " + f"CUDA: {torch.cuda.is_available()}") + if not SWANLAB_API_KEY: + print('[!] Need SWANLAB_API_KEY') + sys.exit(1) + setup_tracker() + model, dls = create_model_and_data() + train_adapters(model, dls) + print("\nDONE -> https://swanlab.cn/@supertpx/twinkle-multilora-test") + print(" lora_a/loss vs lora_b/loss on same chart") diff --git a/test_server_multilora.py b/test_server_multilora.py new file mode 100644 index 00000000..de34b5e7 --- /dev/null +++ b/test_server_multilora.py @@ -0,0 +1,258 @@ +#!/usr/bin/env python +""" +Server-mode MultiLoRA + SwanLab 测试 (直接 Ray actor 路径, 跳过 HTTP 层)。 + +测试目标: + 1. 通过 TWINKLE_TRACKERS 环境变量在 Ray worker 中自动注册 SwanLabTracker + 2. Initialize(mode='ray') 下的 remote_class/dispatch 路径 + 3. 两个 LoRA adapter 交替训练,dispatch(adapter_name=...) 路由到 SwanLab + 4. 全局 project (twinkle-server-multilora) 收到所有指标,adapter 名前缀区分 + +启动: + export SWANLAB_API_KEY="你的key" + python test_server_multilora.py +""" + +from __future__ import annotations + +import os +import subprocess +import sys +import time +from pathlib import Path + +sys.path.insert(0, str(Path('/workspace/twinkle/src'))) + +SWANLAB_API_KEY = os.environ.get('SWANLAB_API_KEY', '') +MODEL_NAME = 'Qwen/Qwen2.5-0.5B' + +TRAIN_CFG = dict(steps=10, batch_size=1, lr=5e-5, max_length=256, num_samples=10) + + +def main(): + if not SWANLAB_API_KEY: + print('[!] SWANLAB_API_KEY not set') + sys.exit(1) + + # ------------------------------------------------------------------ + # 1. 启动 Ray (单节点) + # ------------------------------------------------------------------ + print('[ray] Starting Ray head node ...') + subprocess.run( + [ + 'ray', 'start', '--head', '--port=6379', '--num-cpus=4', '--num-gpus=1', '--include-dashboard=false', + '--block' + ], + capture_output=True, + timeout=10, + ) + # ray start --block 在后台会 fork,这里用非阻塞方式 + result = subprocess.run( + ['ray', 'start', '--head', '--port=6379', '--num-cpus=4', '--num-gpus=1', '--include-dashboard=false'], + capture_output=True, + text=True, + timeout=30, + ) + if result.returncode != 0 and 'already' not in result.stderr: + print(f"[ray] Already running or started: {result.stderr[:200]}") + else: + print("[ray] Ray head node ready") + + # ------------------------------------------------------------------ + # 2. 在 Ray worker 进程中启动训练 (包含 SwanLab tracker auto-init) + # ------------------------------------------------------------------ + # 通过 ray job submit 或在 Ray driver 中用 remote_class 来跑训练 + # 使用 remote function 在 Ray worker 中执行训练 + # ------------------------------------------------------------------ + + import ray + + if not ray.is_initialized(): + ray.init(address='auto', namespace='twinkle_test') + + print(f"[ray] Ray initialized, cluster resources: {ray.cluster_resources()}") + + # ------------------------------------------------------------------ + # 3. 在 Ray 中创建模型并训练 + # ------------------------------------------------------------------ + # 使用 @ray.remote 在 GPU worker 上启动训练进程 + # 通过 runtime_env 传递 SWANLAB_API_KEY 和 TWINKLE_TRACKERS + # ------------------------------------------------------------------ + + @ray.remote(num_gpus=1, max_retries=0) + def train_in_worker(): + """在 Ray worker 中运行训练。worker 进程会 import twinkle, + 触发 tracker._auto_init_from_env(),自动注册 SwanLabTracker。""" + + import os + os.environ['TWINKLE_TRACKERS'] = 'swanlab' + os.environ['SWANLAB_API_KEY'] = os.environ.get('SWANLAB_API_KEY', '') + os.environ['TWINKLE_TRACKER_PROJECT'] = 'twinkle-server-multilora' + os.environ['TWINKLE_TRUST_REMOTE_CODE'] = '0' + + # 这个 import 会触发 _auto_init_from_env() + from twinkle.tracker import list_trackers + print(f"[worker] Trackers after auto-init: {len(list_trackers())}") + import twinkle.tracker as tracker_mod + print(f"[worker] Global trackers: {len(tracker_mod._global_trackers)}") + + # 设置 rank + from twinkle.tracker import set_rank + set_rank(0) + + # 创建模型 + import torch + from transformers import AutoModelForCausalLM, AutoTokenizer + + print(f"[worker] Loading {MODEL_NAME} ...") + model = AutoModelForCausalLM.from_pretrained( + MODEL_NAME, torch_dtype=torch.bfloat16, device_map='auto', trust_remote_code=True) + model.train() + print(f"[worker] Model loaded, {model.num_parameters():,} params") + + # 准备数据和优化器 (2个 adapter) + from torch.optim import AdamW + from torch.utils.data import DataLoader, Dataset + from transformers import get_scheduler + + adapters = { + 'lora_a': { + 'name': 'Alpha助手', + 'author': 'Alpha团队' + }, + 'lora_b': { + 'name': 'Beta助手', + 'author': 'Beta团队' + }, + } + + dataloaders = {} + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + from datasets import Dataset as HFDataset + + templates = [ + ('你好,请问你叫什么名字?', '你好!我是{{NAME}},很高兴认识你!'), + ('你是谁?', '我是{{NAME}},由{{AUTHOR}}开发的语言模型助手。'), + ] + + class SelfCogDataset(Dataset): + + def __init__(self, n, mn, au): + self.data = [] + for i in range(n): + q, r = templates[i % len(templates)] + self.data.append({ + 'query': q.replace('{{NAME}}', mn).replace('{{AUTHOR}}', au), + 'response': r.replace('{{NAME}}', mn).replace('{{AUTHOR}}', au), + }) + + def __len__(self): + return len(self.data) + + def __getitem__(self, i): + return self.data[i] + + def _make_dl(name, self_name, self_author): + raw = SelfCogDataset(TRAIN_CFG['num_samples'], self_name, self_author) + hf = HFDataset.from_list(raw.data) + + def tok_fn(batch): + texts = [ + tokenizer.apply_chat_template([{ + 'role': 'system', + 'content': 'You are a helpful assistant.' + }, { + 'role': 'user', + 'content': q + }, { + 'role': 'assistant', + 'content': r + }], + tokenize=False) for q, r in zip(batch['query'], batch['response']) + ] + o = tokenizer( + texts, + padding='max_length', + truncation=True, + max_length=TRAIN_CFG['max_length'], + return_tensors='pt') + labels = o['input_ids'].clone() + for i, t in enumerate(texts): + pos = t.find('<|im_start|>assistant') + if pos >= 0: + n = len( + tokenizer(t[:pos + len('<|im_start|>assistant')], add_special_tokens=False)['input_ids']) + labels[i, :n] = -100 + return {'input_ids': o['input_ids'], 'attention_mask': o['attention_mask'], 'labels': labels} + + hf = hf.map(tok_fn, batched=True, batch_size=len(hf), remove_columns=['query', 'response']) + hf.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels']) + return DataLoader(hf, batch_size=TRAIN_CFG['batch_size'], shuffle=True, drop_last=True) + + for name, acfg in adapters.items(): + dataloaders[name] = _make_dl(name, acfg['name'], acfg['author']) + print(f"[worker] DataLoader for '{name}' ready") + + # 训练 + from twinkle.tracker import clear_trackers, dispatch, dispatch_hyperparams + + for name, acfg in adapters.items(): + dl = dataloaders[name] + opt = AdamW(model.parameters(), lr=TRAIN_CFG['lr']) + sched = get_scheduler('cosine', opt, num_warmup_steps=2, num_training_steps=TRAIN_CFG['steps']) + + # ★ dispatch_hyperparams with adapter_name (幂等) + dispatch_hyperparams({f"{name}/lr": TRAIN_CFG['lr'], f"{name}/self_name": acfg['name']}, adapter_name=name) + + print(f"\n[worker] Training adapter '{name}' ({acfg['name']}) ...") + opt.zero_grad() + step = 0 + while step < TRAIN_CFG['steps']: + for batch in dl: + if step >= TRAIN_CFG['steps']: + break + b = {k: v.to(model.device) for k, v in batch.items()} + loss_val = model(**b).loss + loss_val.backward() + step += 1 + torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) + opt.step() + sched.step() + opt.zero_grad() + + # ★ dispatch with adapter_name → SwanLab + metrics = {f"{name}/loss": loss_val.item(), f"{name}/lr": sched.get_last_lr()[0]} + dispatch(metrics, step=step, adapter_name=name) + print(f" [{name}] step={step:>3d} loss={loss_val.item():.4f}") + + print(f" [{name}] done") + + clear_trackers() + print("\n[worker] ALL DONE") + return f"Trained {list(adapters.keys())} adapters" + + # 启动 worker + print("\n" + "=" * 60) + print("Launching Ray worker with TWINKLE_TRACKERS=swanlab") + print("SwanLab project: twinkle-server-multilora") + print("=" * 60 + "\n") + + result = ray.get(train_in_worker.remote()) + print(f"\n[main] Worker result: {result}") + + # 清理 + ray.shutdown() + subprocess.run(['ray', 'stop', '--force'], capture_output=True, timeout=15) + print('[ray] Ray stopped.') + + print("\n" + "=" * 60) + print("SwanLab: https://swanlab.cn/@supertpx/twinkle-server-multilora") + print("Server-mode (Ray actor) MultiLoRA test complete") + print("=" * 60) + + +if __name__ == '__main__': + main() diff --git a/tests/tracker/test_dispatch.py b/tests/tracker/test_dispatch.py index c9ea3c9e..c60884be 100644 --- a/tests/tracker/test_dispatch.py +++ b/tests/tracker/test_dispatch.py @@ -44,10 +44,11 @@ sys.modules.setdefault('twinkle.utils.logger', MagicMock()) sys.modules.setdefault('swanlab', MagicMock()) -import twinkle.tracker as tracker_mod -from twinkle.tracker import clear_trackers, dispatch, dispatch_hyperparams, list_trackers, register_tracker, set_rank +import twinkle.tracker as tracker_mod # noqa: E402 +from twinkle.tracker import (clear_trackers, dispatch, dispatch_hyperparams, list_trackers, # noqa: E402 + register_tracker, set_rank) # Now safe to import -from twinkle.tracker.base import ExperimentTracker +from twinkle.tracker.base import ExperimentTracker # noqa: E402 # --------------------------------------------------------------------------- @@ -99,7 +100,8 @@ def cleanup(self) -> None: @pytest.fixture(autouse=True) def _reset_global_state(): """Reset module-level state before every test.""" - tracker_mod._trackers.clear() + tracker_mod._global_trackers.clear() + tracker_mod._adapter_trackers.clear() tracker_mod._rank = 0 tracker_mod._hparams_dispatched.clear() yield @@ -416,7 +418,8 @@ class TestAutoInitFromEnv: def _reset_auto_init(self): """Allow _auto_init_from_env to run again.""" tracker_mod._AUTO_INIT_DONE = False - tracker_mod._trackers.clear() + tracker_mod._global_trackers.clear() + tracker_mod._adapter_trackers.clear() def test_env_empty_is_noop(self): """No TWINKLE_TRACKERS → nothing registered.""" From e1b6a46612342c5920eed9a5ca6103929383bf13 Mon Sep 17 00:00:00 2001 From: tpx Date: Tue, 2 Jun 2026 22:14:50 +0800 Subject: [PATCH 7/8] apply PR#213 review fixes: lazy auto-init, dispatch_hyperparams idempotency, settings kwarg, mkdir Ultraworked with Sisyphus Co-authored-by: Sisyphus --- src/twinkle/tracker/__init__.py | 31 +++++++++++++++++++------------ src/twinkle/tracker/swanlab.py | 1 + src/twinkle/tracker/wandb.py | 2 +- tests/tracker/test_dispatch.py | 13 +++++++------ 4 files changed, 28 insertions(+), 19 deletions(-) diff --git a/src/twinkle/tracker/__init__.py b/src/twinkle/tracker/__init__.py index b7abadbf..01468012 100644 --- a/src/twinkle/tracker/__init__.py +++ b/src/twinkle/tracker/__init__.py @@ -84,6 +84,7 @@ def list_trackers(adapter_name: Optional[str] = None) -> List[ExperimentTracker] for that specific adapter (plus global trackers). If ``None``, returns all trackers. """ + _auto_init_from_env() result = list(_global_trackers) if adapter_name is not None: result.extend(_adapter_trackers.get(adapter_name, [])) @@ -127,7 +128,7 @@ def _target_trackers(adapter_name: Optional[str] = None) -> List[ExperimentTrack return result -def dispatch(data: Dict[str, float], step: int, adapter_name: Optional[str] = None) -> None: +def dispatch(data: Dict[str, Any], step: int, adapter_name: Optional[str] = None) -> None: """Send computed metrics to registered trackers. Metric values are normalized to ``float`` via :func:`clean_metrics` @@ -141,6 +142,7 @@ def dispatch(data: Dict[str, float], step: int, adapter_name: Optional[str] = No are sent to both global trackers and any trackers registered specifically for this adapter. """ + _auto_init_from_env() targets = _target_trackers(adapter_name) if not targets: return @@ -166,23 +168,26 @@ def dispatch_hyperparams(params: Dict[str, Any], adapter_name: Optional[str] = N from ``calculate_metrics`` on its first invocation without flooding trackers with redundant config updates. + When *adapter_name* is ``None`` (single-adapter / full fine-tuning), + a default guard key ``"_default_"`` is used so that hyperparams + are still dispatched only once. + Args: params: Flat or nested dict of hyperparameters (e.g. model config, training args, LoRA config). - adapter_name: Optional adapter identifier. If omitted, the params - are dispatched unconditionally to global trackers on every - call. If provided, dispatched to both global and per-adapter - trackers, with idempotency guard. + adapter_name: Optional adapter identifier. If provided, dispatched + to both global and per-adapter trackers, with idempotency guard. """ + _auto_init_from_env() targets = _target_trackers(adapter_name) if not targets or _rank != 0: return - # Idempotency guard: only dispatch once per adapter - if adapter_name is not None: - if adapter_name in _hparams_dispatched: - return - _hparams_dispatched.add(adapter_name) + # Idempotency guard: only dispatch once per adapter (or once globally) + guard_key = adapter_name if adapter_name is not None else '_default_' + if guard_key in _hparams_dispatched: + return + _hparams_dispatched.add(guard_key) for tracker in targets: try: @@ -245,8 +250,10 @@ def _auto_init_from_env() -> None: logger.warning("Failed to auto-init tracker '%s' from env", name, exc_info=True) -# Run auto-init once at import time (before user code or atexit runs) -_auto_init_from_env() +# Auto-init from environment variables runs lazily on first call to +# ``dispatch()``, ``dispatch_hyperparams()``, or ``list_trackers()``. +# This ensures ``set_rank()`` is called first in distributed training, +# preventing every rank from initializing its own tracking run. # --------------------------------------------------------------------------- # At-exit cleanup diff --git a/src/twinkle/tracker/swanlab.py b/src/twinkle/tracker/swanlab.py index 8e5e1dfb..c2e75187 100644 --- a/src/twinkle/tracker/swanlab.py +++ b/src/twinkle/tracker/swanlab.py @@ -80,6 +80,7 @@ def _save_experiment_info(self, output_dir: str) -> None: try: info = {'swanlab_experiment_url': self._run.get_run().url} out = Path(output_dir) / 'swanlab_config.json' + out.parent.mkdir(parents=True, exist_ok=True) out.write_text(json.dumps(info, indent=2)) except Exception: pass diff --git a/src/twinkle/tracker/wandb.py b/src/twinkle/tracker/wandb.py index 907a7f4b..ed4e371f 100644 --- a/src/twinkle/tracker/wandb.py +++ b/src/twinkle/tracker/wandb.py @@ -30,7 +30,7 @@ def __init__( import wandb entity = kwargs.pop('entity', None) or os.environ.get('WANDB_ENTITY') - settings = None + settings = kwargs.pop('settings', None) proxy = kwargs.pop('wandb_proxy', None) or os.environ.get('WANDB_PROXY') if proxy: settings = wandb.Settings(https_proxy=proxy) diff --git a/tests/tracker/test_dispatch.py b/tests/tracker/test_dispatch.py index c60884be..78165cf2 100644 --- a/tests/tracker/test_dispatch.py +++ b/tests/tracker/test_dispatch.py @@ -288,8 +288,8 @@ def test_different_adapters_separate(self): assert t.hyperparams[0] == {'lr': 1e-4} assert t.hyperparams[1] == {'lr': 2e-4} - def test_without_adapter_sends_every_time(self): - """When adapter_name is None, every call dispatches.""" + def test_without_adapter_sends_once(self): + """When adapter_name is None, only the first call dispatches (idempotent via _default_).""" t = SpyTracker() register_tracker(t) set_rank(0) @@ -298,7 +298,8 @@ def test_without_adapter_sends_every_time(self): dispatch_hyperparams({'lr': 2e-4}) dispatch_hyperparams({'lr': 3e-4}) - assert len(t.hyperparams) == 3 + assert len(t.hyperparams) == 1 + assert t.hyperparams[0] == {'lr': 1e-4} def test_mixed_adapter_and_no_adapter(self): """Calls with and without adapter_name interact correctly.""" @@ -307,11 +308,11 @@ def test_mixed_adapter_and_no_adapter(self): set_rank(0) dispatch_hyperparams({'a': 1}, adapter_name='adp') # sent - dispatch_hyperparams({'b': 2}) # sent (no adapter) + dispatch_hyperparams({'b': 2}) # sent (no adapter, first call) dispatch_hyperparams({'c': 3}, adapter_name='adp') # ignored (idempotent) - dispatch_hyperparams({'d': 4}) # sent (no adapter again) + dispatch_hyperparams({'d': 4}) # ignored (no adapter, idempotent via _default_) - assert len(t.hyperparams) == 3 + assert len(t.hyperparams) == 2 def test_skipped_on_non_zero_rank(self): t = SpyTracker() From d46462e5216fd921526c57cc379cd0e141e2a33b Mon Sep 17 00:00:00 2001 From: tpx Date: Tue, 2 Jun 2026 22:31:31 +0800 Subject: [PATCH 8/8] fix lint issues: E731/E226 in debug_tracker.py, E402 in tests, isort Ultraworked with Sisyphus Co-authored-by: Sisyphus --- debug_tracker.py | 44 ++++++++++++++++++++++++++-------- test_multilora_tracker.py | 10 ++++---- test_server_multilora.py | 20 ++++++++-------- tests/tracker/test_dispatch.py | 2 +- tests/tracker/test_swanlab.py | 2 +- 5 files changed, 51 insertions(+), 27 deletions(-) diff --git a/debug_tracker.py b/debug_tracker.py index 2a17c363..d8e3ccc8 100644 --- a/debug_tracker.py +++ b/debug_tracker.py @@ -232,12 +232,36 @@ def train_with_tracker(cfg: dict[str, Any]): print('[tracker] Will use mock tracker for testing') # 提供一个简易 mock 以便仍可运行 from unittest.mock import MagicMock - register_tracker = lambda t: None - dispatch = lambda data, step: print(f"[mock dispatch] step={step}, data={data}") - dispatch_hyperparams = lambda params, adapter_name=None: print(f"[mock hparams] {params}") - list_trackers = lambda: [] - set_rank = lambda r: None - clear_trackers = lambda: None + + def _mock_register(t): + None # noqa: E704 + + register_tracker = _mock_register + + def _mock_dispatch(data, step): + print(f"[mock dispatch] step={step}, data={data}") # noqa: E704 + + dispatch = _mock_dispatch + + def _mock_hparams(params, adapter_name=None): + print(f"[mock hparams] {params}") # noqa: E704 + + dispatch_hyperparams = _mock_hparams + + def _mock_list(): + return [] # noqa: E704 + + list_trackers = _mock_list + + def _mock_set_rank(r): + None # noqa: E704 + + set_rank = _mock_set_rank + + def _mock_clear(): + None # noqa: E704 + + clear_trackers = _mock_clear SwanLabTracker = None # 确保 rank 0 (单卡测试) @@ -314,9 +338,9 @@ def train_with_tracker(cfg: dict[str, Any]): accum_tokens = 0 optimizer.zero_grad() - print(f"\n{'='*60}") + print('\n' + '=' * 60) print(f"Starting training for {cfg['max_steps']} steps...") - print(f"{'='*60}\n") + print('=' * 60 + '\n') epoch = 0 while global_step < cfg['max_steps']: @@ -370,9 +394,9 @@ def train_with_tracker(cfg: dict[str, Any]): accum_tokens = 0 # ---- 2h) 清理 ---- - print(f"\n{'='*60}") + print('\n' + '=' * 60) print(f"Training complete ({global_step} steps)") - print(f"{'='*60}") + print('=' * 60) # 保存模型(可选) save_path = Path(cfg['output_dir']) / 'final_model' diff --git a/test_multilora_tracker.py b/test_multilora_tracker.py index c62ce502..861e8801 100644 --- a/test_multilora_tracker.py +++ b/test_multilora_tracker.py @@ -173,7 +173,7 @@ def train_adapters(model, dls): # dispatch_hyperparams with adapter_name (幂等守卫: 第二次被忽略) dispatch_hyperparams({f"{name}/lr": TRAIN_CFG['lr']}, adapter_name=name) dispatch_hyperparams({f"{name}/IGNORED": True}, adapter_name=name) - print(" hyperparams idempotent OK") + print(' hyperparams idempotent OK') opt.zero_grad() step = 0 @@ -200,9 +200,9 @@ def train_adapters(model, dls): print(f" [{name}] done ({step} steps)") - print("\nclear_trackers() ...") + print('\nclear_trackers() ...') clear_trackers() - print("[tracker] Done.") + print('[tracker] Done.') if __name__ == '__main__': @@ -215,5 +215,5 @@ def train_adapters(model, dls): setup_tracker() model, dls = create_model_and_data() train_adapters(model, dls) - print("\nDONE -> https://swanlab.cn/@supertpx/twinkle-multilora-test") - print(" lora_a/loss vs lora_b/loss on same chart") + print('\nDONE -> https://swanlab.cn/@supertpx/twinkle-multilora-test') + print(' lora_a/loss vs lora_b/loss on same chart') diff --git a/test_server_multilora.py b/test_server_multilora.py index de34b5e7..60f3ac24 100644 --- a/test_server_multilora.py +++ b/test_server_multilora.py @@ -56,7 +56,7 @@ def main(): if result.returncode != 0 and 'already' not in result.stderr: print(f"[ray] Already running or started: {result.stderr[:200]}") else: - print("[ray] Ray head node ready") + print('[ray] Ray head node ready') # ------------------------------------------------------------------ # 2. 在 Ray worker 进程中启动训练 (包含 SwanLab tracker auto-init) @@ -231,14 +231,14 @@ def tok_fn(batch): print(f" [{name}] done") clear_trackers() - print("\n[worker] ALL DONE") + print('\n[worker] ALL DONE') return f"Trained {list(adapters.keys())} adapters" # 启动 worker - print("\n" + "=" * 60) - print("Launching Ray worker with TWINKLE_TRACKERS=swanlab") - print("SwanLab project: twinkle-server-multilora") - print("=" * 60 + "\n") + print('\n' + '=' * 60) + print('Launching Ray worker with TWINKLE_TRACKERS=swanlab') + print('SwanLab project: twinkle-server-multilora') + print('=' * 60 + '\n') result = ray.get(train_in_worker.remote()) print(f"\n[main] Worker result: {result}") @@ -248,10 +248,10 @@ def tok_fn(batch): subprocess.run(['ray', 'stop', '--force'], capture_output=True, timeout=15) print('[ray] Ray stopped.') - print("\n" + "=" * 60) - print("SwanLab: https://swanlab.cn/@supertpx/twinkle-server-multilora") - print("Server-mode (Ray actor) MultiLoRA test complete") - print("=" * 60) + print('\n' + '=' * 60) + print('SwanLab: https://swanlab.cn/@supertpx/twinkle-server-multilora') + print('Server-mode (Ray actor) MultiLoRA test complete') + print('=' * 60) if __name__ == '__main__': diff --git a/tests/tracker/test_dispatch.py b/tests/tracker/test_dispatch.py index 78165cf2..4de1d641 100644 --- a/tests/tracker/test_dispatch.py +++ b/tests/tracker/test_dispatch.py @@ -44,10 +44,10 @@ sys.modules.setdefault('twinkle.utils.logger', MagicMock()) sys.modules.setdefault('swanlab', MagicMock()) +# isort: split import twinkle.tracker as tracker_mod # noqa: E402 from twinkle.tracker import (clear_trackers, dispatch, dispatch_hyperparams, list_trackers, # noqa: E402 register_tracker, set_rank) -# Now safe to import from twinkle.tracker.base import ExperimentTracker # noqa: E402 diff --git a/tests/tracker/test_swanlab.py b/tests/tracker/test_swanlab.py index b084996d..7d257cda 100644 --- a/tests/tracker/test_swanlab.py +++ b/tests/tracker/test_swanlab.py @@ -55,7 +55,7 @@ sys.modules.setdefault('swanlab', MagicMock()) # Now that all heavy deps are mocked, the import should succeed. -from twinkle.tracker.swanlab import SwanLabTracker +from twinkle.tracker.swanlab import SwanLabTracker # noqa: E402 # =================================================================== # Helpers