ML Experiment Tracking 源码分析
2026-04-03 | wandb v0.25.1
第一部分:架构概览
第二部分:核心模块
第三部分:功能体系
第四部分:进阶
Weights & Biases (W&B) 是机器学习实验跟踪和可视化平台,帮助研究者和工程师记录、比较和协作 ML 实验。
核心功能
关键数据
┌─────────────────────────────────────────────────────────────┐
│ 用户 API 层 │
│ wandb.init() │ wandb.log() │ wandb.finish() │ wandb.watch()│
├─────────────────────────────────────────────────────────────┤
│ Run 对象层 │
│ Run │ Config │ Summary │ Artifact │ Alert │
├─────────────────────────────────────────────────────────────┤
│ Interface 接口层 │
│ InterfaceBase │ InterfaceQueue │ MailboxHandle │
├─────────────────────────────────────────────────────────────┤
│ Backend 后端层 │
│ ServiceConnection │ Sender │ DataStore │ Streaming │
├─────────────────────────────────────────────────────────────┤
│ Protobuf 协议层 │
│ wandb_internal_pb2 │ wandb_settings_pb2 │ wandb_telemetry │
└─────────────────────────────────────────────────────────────┘
| 层级 | 职责 | 核心模块 |
|---|---|---|
| API 层 | 用户接口 & 延迟初始化 | wandb/__init__.py, preinit |
| Run 层 | 实验生命周期管理 | wandb_run.py, wandb_config.py, wandb_summary.py |
| Interface 层 | 消息传递抽象 | interface.py, interface_queue.py |
| Backend 层 | 后台服务通信 | backend.py, service_connection.py |
| Data 层 | 数据存储 & 传输 | datastore.py, streaming.py, sender.py |
| Proto 层 | 序列化协议 | wandb_internal_pb2, wandb_settings_pb2 |
wandb/
├── __init__.py # SDK 入口,公开 API
├── sdk/
│ ├── __init__.py # init/finish/login 导出
│ ├── wandb_init.py # _WandbInit 类,初始化逻辑
│ ├── wandb_run.py # Run 类,核心实验对象
│ ├── wandb_settings.py # Settings (Pydantic BaseModel)
│ ├── wandb_config.py # Config 配置管理
│ ├── wandb_summary.py # Summary 摘要管理
│ ├── wandb_login.py # 登录认证
│ ├── wandb_setup.py # 全局 Setup 单例
│ ├── backend/
│ │ └── backend.py # Backend 后端通信
│ ├── interface/
│ │ ├── interface.py # InterfaceBase 抽象类
│ │ └── interface_queue.py
│ ├── internal/
│ │ ├── datastore/ # 数据存储
│ │ └── handler/ # gRPC 请求处理
│ ├── artifacts/ # Artifact 版本管理
│ ├── data_types/ # Image/Table/Video 等
│ └── lib/ # 工具库
├── proto/ # Protobuf 定义
└── integration/ # 框架集成
# wandb/__init__.py (v0.25.1)
__version__ = "0.25.1"
# 早期配置
from wandb.errors.term import termsetup, termlog, termerror
from wandb.sdk.lib import wb_logging
_wb_logging.configure_wandb_logger()
# 导入 SDK 子包
from wandb import sdk as wandb_sdk
# 核心公开 API
init = wandb_sdk.init
finish = wandb_sdk.finish
login = wandb_sdk.login
setup = wandb_sdk.setup
# 数据类型
from wandb.data_types import (
Graph, Image, Plotly, Video, Audio, Table, Html,
Object3D, Molecule, Histogram, Classes
)
# 全局 Run 对象
run: Run | None = None
# PreInit 延迟对象
config = _preinit.PreInitObject("wandb.config", wandb_sdk.wandb_config.Config)
summary = _preinit.PreInitObject("wandb.summary", wandb_sdk.wandb_summary.Summary)
log = _preinit.PreInitCallable("wandb.log", Run.log)
核心设计:用户可以在 wandb.init() 之前调用 wandb.config 和 wandb.summary,SDK 使用 PreInitObject 暂存操作,init() 后代理到真实 Run 对象。
# wandb/sdk/lib/preinit.py
class PreInitObject:
"""代理对象,在 init() 之前暂存用户操作"""
def __init__(self, name, obj_type):
self._name = name
self._obj_type = obj_type
self._obj = None
self._preinit_calls = []
def __getattr__(self, key):
if self._obj is not None:
# init() 已完成,代理到真实对象
return getattr(self._obj, key)
# init() 之前,暂存调用
self._preinit_calls.append(('__getattr__', key))
return self
def _set(self, obj):
"""init() 完成后绑定真实对象"""
self._obj = obj
# 回放暂存的调用
for call in self._preinit_calls:
getattr(self._obj, call[1])
目的:避免导入时加载所有框架集成(PyTorch、TensorFlow、Keras 等),只在用户实际使用时才导入,减少启动时间。
# wandb/__init__.py
from wandb.sdk.lib import lazyloader
# 这些模块不会在 import wandb 时加载
keras = lazyloader.LazyLoader("wandb.keras", globals(),
"wandb.integration.keras")
sklearn = lazyloader.LazyLoader("wandb.sklearn", globals(),
"wandb.sklearn")
tensorflow = lazyloader.LazyLoader("wandb.tensorflow", globals(),
"wandb.integration.tensorflow")
xgboost = lazyloader.LazyLoader("wandb.xgboost", globals(),
"wandb.integration.xgboost")
catboost = lazyloader.LazyLoader("wandb.catboost", globals(),
"wandb.integration.catboost")
lightgbm = lazyloader.LazyLoader("wandb.lightgbm", globals(),
"wandb.integration.lightgbm")
sacred = lazyloader.LazyLoader("wandb.sacred", globals(),
"wandb.integration.sacred")
Settings 是基于 Pydantic BaseModel 的配置类,管理 W&B SDK 的所有行为参数。支持环境变量、配置文件和代码配置三种来源。
# wandb/sdk/wandb_settings.py
from pydantic import BaseModel, ConfigDict, Field
class Settings(BaseModel, validate_assignment=True):
model_config = ConfigDict(
extra="forbid", # 禁止额外字段
validate_default=True, # 验证默认值
use_attribute_docstrings=True,
revalidate_instances="always",
)
# 核心配置
entity: Optional[str] = None # W&B 实体
project: Optional[str] = None # 项目名
run_id: Optional[str] = None # 运行 ID
run_name: Optional[str] = None # 显示名称
base_url: str = "https://api.wandb.ai"
api_key: Optional[str] = None
...
| 字段 | 类型 | 说明 |
|---|---|---|
| entity | str? | W&B 用户/团队名 |
| project | str? | 项目名称 |
| run_id | str? | 唯一运行标识 |
| run_name | str? | 显示名称 |
| base_url | str | 后端 API 地址 |
| api_key | str? | 认证密钥 |
| mode | str | online/offline/disabled |
| resume | str? | allow/auto/never/must |
| sweep_id | str? | Sweep ID |
| save_code | bool | 是否保存代码 |
| heartbeat_seconds | int | 心跳间隔 (默认30s) |
| console | str | 控制台捕获模式 |
┌──────────────────────────────────────────────────────────┐
│ Settings 配置优先级链 │
├──────────────────────────────────────────────────────────┤
│ │
│ 1. 默认值 (Pydantic Field default) ← 最低 │
│ ↓ │
│ 2. 全局配置文件 (~/.config/wandb/settings) │
│ ↓ │
│ 3. 环境变量 (WANDB_ 前缀) │
│ 例如: WANDB_PROJECT=my-project │
│ ↓ │
│ 4. wandb.setup() 设置 │
│ ↓ │
│ 5. wandb.init() 参数 ← 最高 │
│ │
│ Settings 通过 update_from_settings() 和 │
│ update_from_dict() 逐层合并 │
└──────────────────────────────────────────────────────────┘
每层配置会覆盖上层的同名参数,确保用户代码中的参数拥有最高优先级
# 用户侧 API
import wandb
# 基本用法
run = wandb.init(
project="my-project", # 项目名
entity="my-team", # 团队名
config={ # 实验配置
"lr": 0.001,
"epochs": 10,
"batch_size": 32,
},
name="experiment-1", # 运行名称
tags=["baseline", "test"], # 标签
resume="allow", # 允许恢复
)
# 记录指标
wandb.log({"loss": 0.5, "accuracy": 0.9})
# 结束
wandb.finish()
# wandb/sdk/wandb_init.py
class _WandbInit:
def __init__(self, wl: _WandbSetup, telemetry):
self._wl = wl # 全局 Setup 单例
self._telemetry = telemetry
self.kwargs = None
self.run: Run | None = None
self.backend: Backend | None = None
self._teardown_hooks: list[TeardownHook] = []
self.notebook = None
def maybe_login(self, init_settings):
"""检查是否需要登录"""
run_settings = self._wl.settings.model_copy()
run_settings.update_from_settings(init_settings)
if not run_settings._noop and not run_settings._offline:
wandb_login._login(
host=run_settings.base_url,
force=run_settings.force,
key=init_settings.api_key,
)
def make_run_settings(self, init_settings):
"""合并全局设置和用户设置"""
settings = self._wl.settings.model_copy()
settings.update_from_settings(init_settings)
settings.x_start_time = time.time()
return settings
┌──────────────────────────────────────────────────────────┐
│ wandb.init() 执行流程 │
├──────────────────────────────────────────────────────────┤
│ 1. _WandbSetup.ensure_initialized() │
│ └→ 确保 wandb.setup() 已执行 │
│ │
│ 2. 构建 Settings │
│ └→ 合并默认值 + 环境变量 + 用户参数 │
│ │
│ 3. maybe_login() │
│ └→ 非离线模式下检查认证 │
│ │
│ 4. set_run_id() │
│ └→ 生成或恢复 run_id │
│ │
│ 5. 创建 Backend │
│ └→ Backend(settings, service) │
│ └→ backend.ensure_launched() │
│ │
│ 6. 创建 Run 对象 │
│ └→ Run(settings, backend.interface) │
│ │
│ 7. 发布 RunRecord │
│ └→ interface.publish_run(run) │
│ │
│ 8. 注册 teardown hooks │
│ └→ 确保 finish() 被调用 │
└──────────────────────────────────────────────────────────┘
Run 是 W&B SDK 的核心对象,代表一次实验运行。它管理实验的整个生命周期:配置、指标记录、摘要、Artifact 和结束。
# wandb/sdk/wandb_run.py
class Run:
"""核心实验对象"""
def __init__(self, settings, interface):
self._settings = settings
self._interface = interface
self._config = wandb_config.Config()
self._summary = wandb_summary.Summary()
self._start_time = time.time()
self._starting_step = 0
self._history = []
# 核心方法
def log(self, data, step=None, commit=None): ...
def watch(self, model, criterion=None, log="gradients", ...): ...
def unwatch(self): ...
def finish(self, exit_code=0, quiet=None): ...
def save(self, glob_str, base_path=None, policy="live"): ...
def alert(self, title, text, level="info"): ...
def define_metric(self, name, step_metric=None, ...): ...
def use_artifact(self, name, type=None): ...
def log_artifact(self, artifact, aliases=None): ...
| 类别 | 方法 | 说明 |
|---|---|---|
| 指标 | log() | 记录训练指标 |
| 配置 | config | 实验超参数 |
| 摘要 | summary | 最终结果摘要 |
| 模型 | watch/unwatch | 监控模型参数/梯度 |
| Artifact | log_artifact/use_artifact | 数据/模型版本 |
| 文件 | save() | 保存文件到 W&B |
| 告警 | alert() | 发送告警通知 |
| 指标定义 | define_metric() | 自定义 X 轴 |
| 生命周期 | finish() | 结束运行 |
# wandb/sdk/wandb_run.py (简化)
def log(self, data, step=None, commit=None, sync=None):
"""记录训练指标到历史数据"""
# 1. 验证数据
if not isinstance(data, dict):
raise UsageError("wandb.log must be called with a dict")
# 2. 处理 step
if step is not None:
self._step = step
# 3. 应用 define_metric 规则
data = self._apply_metric_overrides(data)
# 4. 序列化为 JSON
json_data = history_dict_to_json(data, self)
# 5. 构造 HistoryRecord (Protobuf)
record = pb.HistoryRecord()
record.item.json = json_data
record.item.step = self._step
record.item.timestamp.FromMicroseconds(
int(time.time() * 1e6)
)
# 6. 通过 Interface 发送
self._interface.publish_history(record)
# 7. 更新 step
if commit is not False:
self._step += 1
# wandb/sdk/wandb_config.py
class Config:
"""实验配置 - 使用字典式 API"""
def __init__(self):
self._data = {}
self._locked = False
def __setitem__(self, key, value):
if self._locked:
raise UsageError(
"Config is locked after init. "
"Use wandb.config.allow_val_change = True"
)
self._data[key] = value
# 通过 Interface 发送 ConfigRecord
self._callback("config", key, value)
def __getitem__(self, key):
return self._data[key]
def update(self, data):
for k, v in data.items():
self[k] = v
def as_dict(self):
return dict(self._data)
# 使用示例
wandb.config.lr = 0.001
wandb.config.update({"epochs": 10, "batch_size": 32})
# wandb/sdk/wandb_summary.py
class Summary:
"""实验摘要 - 运行结束时的最终指标"""
def __init__(self):
self._data = {}
def __setitem__(self, key, value):
self._data[key] = value
# 编码大对象(如图像)
encoded = self._summary_encode(value, key)
# 发送 SummaryRecord
self._callback("summary", key, encoded)
def update(self, data):
for k, v in data.items():
self[k] = v
# 使用示例
wandb.summary["best_accuracy"] = 0.95
wandb.summary["best_model_path"] = "models/best.pth"
Summary vs Log:Log 记录训练过程的时间序列数据;Summary 存储最终结果,用于跨实验比较。
# wandb/sdk/wandb_run.py
def watch(self, model, criterion=None, log="gradients",
log_freq=100, log_graph=False):
"""监控模型参数和梯度"""
hooks = []
# 1. 注册前向钩子(记录参数分布)
for name, param in model.named_parameters():
if log in ("gradients", "all") and param.requires_grad:
hook = param.register_backward_hook(
lambda grad, module_name=name: self._log_gradient(
module_name, grad
)
)
hooks.append(hook)
if log in ("parameters", "all"):
hook = param.register_forward_pre_hook(
lambda module, input, module_name=name:
self._log_parameter(module_name, module)
)
hooks.append(hook)
# 2. 可选:记录计算图
if log_graph:
self._log_model_graph(model, input_array)
return hooks
# wandb/sdk/wandb_run.py (简化)
def finish(self, exit_code=0, quiet=None):
"""结束实验运行"""
# 1. 检查是否已结束
if self._finished:
return
# 2. 最终 summary 上传
self.summary.update()
# 3. 发送退出信号
self._interface.publish_exit(
pb.ExitRecord(exit_code=exit_code)
)
# 4. 等待数据同步完成
self._interface.join(timeout=EXIT_TIMEOUT)
# 5. 执行 teardown hooks
for hook in self._teardown_hooks:
if hook.stage == TeardownStage.LATE:
hook.call()
# 6. 清理资源
self._finished = True
wandb.run = None
# 7. 恢复 console(如果被重定向)
self._console_cleanup()
RunStatusChecker 在后台运行三个守护线程,周期性检查运行状态、网络状态和内部消息。
# wandb/sdk/wandb_run.py
class RunStatusChecker:
def __init__(self, run_id, interface, settings,
stop_polling_interval=15,
retry_polling_interval=5,
internal_messages_polling_interval=10):
# 三个后台线程
self._stop_thread = Thread(target=self.check_stop_status,
daemon=True)
self._network_status_thread = Thread(
target=self.check_network_status, daemon=True)
self._internal_messages_thread = Thread(
target=self.check_internal_messages, daemon=True)
def check_stop_status(self):
"""检查服务器是否请求停止"""
# 如果 stop_status.run_should_stop 为 True
# 调用 interrupt.interrupt_main()
def check_network_status(self):
"""检查网络连接状态"""
# 处理 HTTP 重试和错误报告
# wandb/sdk/backend/backend.py
class Backend:
"""管理 SDK 与后台服务进程的通信"""
interface: InterfaceBase | None
_settings: Settings
_service: ServiceConnection | None
_done: bool
def __init__(self, settings, service=None):
self._done = False
self.interface = None
self._settings = settings
self._service = service
def ensure_launched(self):
"""启动后台服务(如果未运行)"""
assert self._settings.run_id
assert self._service
# 通过 ServiceConnection 创建 Interface
self.interface = self._service.make_interface(
stream_id=self._settings.run_id,
)
def cleanup(self):
"""清理资源"""
if self._done:
return
self._done = True
if self.interface:
self.interface.join()
架构:Backend 使用独立的 ServiceConnection 进程处理数据传输,避免阻塞用户训练循环。
# wandb/sdk/interface/interface.py
class InterfaceBase(abc.ABC):
"""消息发送抽象基类"""
@abc.abstractmethod
async def deliver_async(self, record: pb.Record):
"""异步发送 Record 到服务进程"""
def publish_run(self, run: Run):
"""发送运行信息"""
run_record = self._make_run(run)
self._publish_run(run_record)
def publish_history(self, history: pb.HistoryRecord):
"""发送历史数据(log 调用)"""
self._publish_history(history)
def publish_config(self, data, key=None, val=None):
"""发送配置更新"""
cfg = self._make_config(data=data, key=key, val=val)
self._publish_config(cfg)
def publish_summary(self, summary: pb.SummaryRecord):
"""发送摘要更新"""
def publish_exit(self, exit_record: pb.ExitRecord):
"""发送退出信号"""
def deliver_stop_status(self):
"""查询停止状态"""
def join(self, timeout=None):
"""等待所有消息发送完毕"""
InterfaceQueue 是 InterfaceBase 的具体实现,使用 MailboxHandle 将 Protobuf Record 异步发送到后台服务进程。
# wandb/sdk/interface/interface_queue.py (简化)
class InterfaceQueue(InterfaceBase):
def __init__(self, settings, record_q, result_q):
self._settings = settings
self._record_q = record_q # 发送队列
self._result_q = result_q # 响应队列
def _publish(self, record: pb.Record):
"""将 Record 放入发送队列"""
self._record_q.put(record)
def deliver(self, record: pb.Record) -> MailboxHandle:
"""发送并获取响应句柄"""
handle = MailboxHandle(self._result_q)
self._publish(record)
return handle
def _publish_run(self, run: pb.RunRecord):
record = pb.Record(run=run)
self._publish(record)
def _publish_history(self, history: pb.HistoryRecord):
record = pb.Record(history=history)
self._publish(record)
Mailbox 实现了 SDK 进程与后台服务进程之间的双向通信。使用 multiprocessing.Queue 进行跨进程消息传递。
# wandb/sdk/mailbox.py (简化)
class MailboxHandle(Generic[T]):
"""异步响应句柄"""
def wait(self, timeout=None) -> T:
"""阻塞等待响应"""
result = self._queue.get(timeout=timeout)
if result.HasField("error"):
raise CommError(result.error.message)
return result
def wait_or(self, timeout=30) -> T | None:
"""超时返回 None"""
try:
return self.wait(timeout=timeout)
except TimeoutError:
return None
def cancel(self):
"""取消等待"""
self._cancelled = True
# 通信模式
# SDK 进程 ──record_q──→ 服务进程
# SDK 进程 ←──result_q── 服务进程
Protobuf 是 SDK 与服务进程之间的序列化协议。所有数据(Run、History、Config、Summary 等)都序列化为 Protobuf Record。
// wandb/proto/wandb_internal.proto (简化)
message Record {
oneof record_type {
HeaderRecord header = 1;
RunRecord run = 2;
ConfigRecord config = 3;
SummaryRecord summary = 4;
HistoryRecord history = 5;
MetricRecord metric = 6;
ExitRecord exit = 7;
FilesRecord files = 8;
ArtifactRecord artifact = 9;
...
}
}
message HistoryRecord {
HistoryItem item = 1;
}
message HistoryItem {
string json = 1; // 指标 JSON
int64 step = 2; // 步数
google.protobuf.Timestamp timestamp = 3;
wall_time = 4;
}
DataStore 是后台服务中的数据存储层,接收来自 InterfaceQueue 的 Record 并按类型分类存储。
# wandb/sdk/internal/datastore/datastore.py (简化)
class DataStore:
"""后台进程中的数据存储"""
def __init__(self, settings):
self._settings = settings
self._history = [] # 历史数据列表
self._summary = {} # 摘要字典
self._config = {} # 配置字典
self._artifacts = {} # Artifact 存储
self._files = FilesDict() # 文件字典
def store_record(self, record: pb.Record):
"""根据 Record 类型分发存储"""
if record.HasField("history"):
self._store_history(record.history)
elif record.HasField("config"):
self._store_config(record.config)
elif record.HasField("summary"):
self._store_summary(record.summary)
elif record.HasField("artifact"):
self._store_artifact(record.artifact)
def _store_history(self, history: pb.HistoryRecord):
"""存储历史数据,按 metric 分组"""
self._history.append(history.item)
┌──────────────────────────────────────────────────────────┐
│ DataStore 写入流程 │
├──────────────────────────────────────────────────────────┤
│ │
│ SDK 进程 服务进程 │
│ ┌─────────┐ ┌──────────┐ │
│ │ run.log │ │ DataStore│ │
│ │ ({loss})│ │ │ │
│ └────┬────┘ │ ┌──────┐ │ │
│ │ │ │History│ │ │
│ │ pb.Record │ │ Buffer│ │ │
│ │ │ └──┬───┘ │ │
│ ├─────record_q──────→│ │ │ │
│ │ │ ┌──▼───┐ │ │
│ │ │ │Metric│ │ │
│ │ │ │ Index│ │ │
│ │ │ └──────┘ │ │
│ │ └────┬─────┘ │
│ │ │ │
│ │ ┌────▼─────┐ │
│ │ │ Streaming│ │
│ │ │ → Sender │ │
│ │ │ → Server │ │
│ │ └──────────┘ │
└──────────────────────────────────────────────────────────┘
# wandb/sdk/internal/datastore/streaming.py (简化)
class StreamingData:
"""流式数据传输 - 实时发送数据到服务器"""
def __init__(self, settings, sender):
self._sender = sender
self._buffer = []
self._flush_interval = settings.heartbeat_seconds
def add_history(self, history_item):
"""添加历史数据到缓冲区"""
self._buffer.append(history_item)
if len(self._buffer) >= self._batch_size:
self._flush()
def _flush(self):
"""批量发送缓冲区数据"""
if not self._buffer:
return
# 构造 Request
request = pb.CreateRunRequest()
for item in self._buffer:
req_item = request.history.item.add()
req_item.CopyFrom(item)
# 通过 Sender 发送到服务器
self._sender.send(request)
self._buffer.clear()
def flush_and_close(self):
"""结束前刷新所有数据"""
self._flush()
# wandb/sdk/internal/sender.py (简化)
class Sender:
"""网络发送器 - 负责与 W&B 服务器通信"""
def __init__(self, settings):
self._settings = settings
self._base_url = settings.base_url
self._api_key = settings.api_key
self._run_id = settings.run_id
def send(self, request):
"""发送请求到 W&B 服务器"""
try:
response = self._http_post(
f"{self._base_url}/graphql",
data=self._serialize(request),
headers=self._auth_headers(),
)
return response
except ConnectionError:
# 离线模式:保存到本地文件
self._save_offline(request)
def _auth_headers(self):
return {
"Authorization": f"Bearer {self._api_key}",
"Content-Type": "application/protobuf",
}
def upload_file(self, path, artifact_name):
"""上传文件(支持断点续传)"""
# 分片上传大文件
# 支持进度回调
Artifact 是 W&B 的数据版本管理系统,用于跟踪数据集、模型和文件的版本。类似 Git,但针对 ML 资产。
核心概念
使用方式
# wandb/sdk/artifacts/artifact.py (简化)
class Artifact:
"""版本化数据资产"""
def __init__(self, name, type=None, metadata=None,
description=None):
self._name = name # "project/artifact_name"
self._type = type # "dataset" / "model"
self._metadata = metadata # 自定义元数据
self._manifest = ArtifactManifest()
self._files = {} # 文件映射
def add_file(self, local_path, name=None):
"""添加文件到 Artifact"""
digest = self._compute_digest(local_path)
self._manifest.entries[name] = {
"digest": digest,
"size": os.path.getsize(local_path),
}
def add_dir(self, local_dir, name=None):
"""添加目录"""
def add_reference(self, uri, name=None):
"""添加外部引用 (S3/GCS/HTTP)"""
def download(self, root=None):
"""下载 Artifact 到本地"""
def checkout(self, root=None):
"""替换本地文件为 Artifact 版本"""
def save(self):
"""上传 Artifact 到 W&B"""
Sweep 是 W&B 的超参数搜索系统,支持 Bayesian、Grid 和 Random 搜索策略。
# Sweep 配置
sweep_config = {
"method": "bayes", # bayes / grid / random
"metric": {
"name": "val_loss",
"goal": "minimize"
},
"parameters": {
"lr": {"min": 0.0001, "max": 0.1, "distribution": "log_uniform"},
"batch_size": {"values": [16, 32, 64, 128]},
"epochs": {"value": 10},
"dropout": {"min": 0.1, "max": 0.5},
}
}
# 创建 Sweep
sweep_id = wandb.sweep(sweep_config, project="my-project")
# 运行 Agent
wandb.agent(sweep_id, function=train)
# wandb/wandb_controller.py (简化)
class SweepController:
"""Sweep 控制器 - 管理超参搜索"""
def __init__(self, sweep_id, config):
self._sweep_id = sweep_id
self._config = config
self._method = config["method"] # bayes/grid/random
def next_run(self):
"""生成下一组超参数"""
if self._method == "bayes":
return self._bayesian_next()
elif self._method == "grid":
return self._grid_next()
elif self._method == "random":
return self._random_next()
def _bayesian_next(self):
"""贝叶斯优化 - 基于历史结果选择参数"""
# 使用 Gaussian Process 建模
# 最大化 Expected Improvement
# 返回最优参数组合
def _grid_next(self):
"""网格搜索 - 遍历所有组合"""
def _random_next(self):
"""随机搜索 - 从分布中采样"""
# wandb/wandb_agent.py (简化)
class Agent:
"""Sweep Agent - 执行训练函数"""
def __init__(self, sweep_id, function):
self._sweep_id = sweep_id
self._function = function # 用户训练函数
self._controller = SweepController(sweep_id)
def run(self):
"""主循环"""
while not self._stopped:
# 1. 获取下一组参数
params = self._controller.next_run()
# 2. 初始化 wandb run
with wandb.init(config=params) as run:
# 3. 执行用户训练函数
self._function()
# 4. 上报结果
self._controller.report(run.summary)
def stop(self):
"""停止 Agent"""
self._stopped = True
# wandb.agent() 调用
def agent(sweep_id, function, count=None, project=None):
agent = Agent(sweep_id, function)
agent.run()
# wandb/integration/torch/wandb_torch.py (简化)
class WandbLogger:
"""PyTorch 自动日志集成"""
@staticmethod
def watch(model, criterion=None, log="gradients",
log_freq=100, log_graph=False):
"""自动监控 PyTorch 模型"""
hooks = []
for name, param in model.named_parameters():
if not param.requires_grad:
continue
# 梯度钩子
if log in ("gradients", "all"):
def grad_hook(grad, n=name, p=param):
wandb.log({
f"gradients/{n}.histogram": wandb.Histogram(
grad.detach().cpu().numpy()
),
f"gradients/{n}.mean": grad.mean().item(),
}, commit=False)
hook = param.register_backward_hook(grad_hook)
hooks.append(hook)
# 参数钩子
if log in ("parameters", "all"):
def param_hook(module, inp, n=name):
wandb.log({
f"parameters/{n}.histogram": wandb.Histogram(
module.data.detach().cpu().numpy()
),
}, commit=False)
hook = param.register_forward_pre_hook(param_hook)
hooks.append(hook)
return hooks
# wandb/integration/keras/__init__.py (简化)
class WandbCallback(tf.keras.callbacks.Callback):
"""Keras 回调 - 自动记录训练指标"""
def __init__(self, monitor="val_loss", save_model=True):
self._monitor = monitor
self._save_model = save_model
self._best = float('inf')
def on_epoch_end(self, epoch, logs=None):
"""每个 epoch 结束时自动记录"""
# 记录所有指标
wandb.log(logs, step=epoch)
# 检查是否是最佳模型
current = logs.get(self._monitor, 0)
if current < self._best:
self._best = current
if self._save_model:
# 保存最佳模型为 Artifact
artifact = wandb.Artifact(
f"model-{wandb.run.id}",
type="model",
metadata={"epoch": epoch, self._monitor: current}
)
artifact.add_file(self.model.save("model.keras"))
wandb.log_artifact(artifact, aliases=["best"])
def on_train_end(self, logs=None):
"""训练结束时上传最终模型"""
wandb.finish()
# W&B 与 Hugging Face Transformers 的集成
from transformers import TrainingArguments, Trainer
import wandb
# 1. 初始化 W&B
wandb.init(project="hf-finetune", config={
"model": "bert-base-uncased",
"dataset": "imdb",
})
# 2. TrainingArguments 自动集成 W&B
training_args = TrainingArguments(
output_dir="./results",
report_to="wandb", # 关键:启用 W&B 日志
run_name="bert-finetune",
evaluation_strategy="epoch",
learning_rate=2e-5,
per_device_train_batch_size=16,
num_train_epochs=3,
logging_steps=100,
save_strategy="epoch",
load_best_model_at_end=True,
)
# 3. Trainer 自动记录
# - train/loss, train/learning_rate
# - eval/loss, eval/accuracy
# - 系统指标 (GPU/CPU)
# - 模型 checkpoint
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
)
trainer.train()
Media 数据类型:W&B 支持丰富的数据类型用于可视化,所有类型都继承自基类并通过 JSON 序列化传输。
| 类型 | 用途 | 示例 |
|---|---|---|
| Image | 图像可视化 | wandb.Image(pil_img, caption="") |
| Video | 视频记录 | wandb.Video(video_path, fps=4) |
| Audio | 音频可视化 | wandb.Audio(audio_path, sample_rate=16000) |
| Plotly | 交互式图表 | wandb.Plotly(fig) |
| Html | HTML 内容 | wandb.Html(html_string) |
| Object3D | 3D 点云 | wandb.Object3D(points_array) |
| Molecule | 分子结构 | wandb.Molecule(smiles_string) |
| Histogram | 分布直方图 | wandb.Histogram(numpy_array) |
# wandb/data_types/table.py (简化)
# 创建 Table
table = wandb.Table(
columns=["image", "prediction", "actual", "confidence"],
data=[
[wandb.Image(img), "cat", "cat", 0.95],
[wandb.Image(img2), "dog", "cat", 0.02],
...
]
)
# 记录到 W&B
wandb.log({"predictions": table})
# 或者使用 ClassMap 进行错误分析
table = wandb.Table(
columns=["id", "image", "pred", "truth"],
data=validation_data
)
wandb.log({"val_results": table})
# Table 支持的操作
# - 排序、过滤、分组
# - 自定义列
# - 嵌套 Table
# - 与 Artifact 关联
# 自定义可视化面板
# 1. 使用 wandb.plot
wandb.plot.line_series(
xs=[1, 2, 3, 4],
ys=[[10, 20, 30, 40], [15, 25, 35, 45]],
keys=["model_a", "model_b"],
title="Training Loss Comparison",
xname="Step",
)
# 2. 使用 wandb.plot_table
wandb.plot_table(
"test_table",
wandb.Table(
columns=["step", "loss"],
data=[[1, 0.5], [2, 0.3], [3, 0.1]]
),
fields={"x": "step", "y": "loss"},
)
# 3. 使用 Visualize
wandb.visualize("custom_viz", {
"type": "scatter",
"values": [{"x": i, "y": i**2} for i in range(100)]
})
Telemetry 收集匿名的使用数据,帮助 W&B 团队了解 SDK 使用模式,改善产品体验。
# wandb/proto/wandb_telemetry.proto (部分)
message TelemetryRecord {
bool python_version = 1;
bool cuda = 2;
bool tensorflow = 3;
bool torch = 4;
bool keras = 5;
bool jupyter = 6;
bool amazon_ec2 = 7;
bool google_cloud = 8;
bool azure = 9;
// 功能使用
FeatureTelemetry feature = 20;
int32 log_count = 21;
int32 artifact_count = 22;
}
# 在 init() 时自动收集
# - Python 版本、OS、CPU/GPU 信息
# - 使用的 ML 框架
# - 功能使用频率
# - 错误和异常
┌──────────────────────────────────────────────────────────┐
│ wandb.login() 认证流程 │
├──────────────────────────────────────────────────────────┤
│ │
│ 1. 检查环境变量 WANDB_API_KEY │
│ └→ 有:直接使用 │
│ └→ 无:继续 │
│ │
│ 2. 检查凭据文件 ~/.netrc │
│ └→ 有:读取 API key │
│ └→ 无:继续 │
│ │
│ 3. 检查 wandb/credentials 文件 │
│ └→ 有:读取临时 access token │
│ └→ 无:继续 │
│ │
│ 4. 交互式登录 │
│ └→ 打开浏览器,用户输入 API key │
│ └→ 保存到凭据文件 │
│ │
│ 5. 验证 API key │
│ └→ 调用 server_status 接口验证 │
│ └→ 获取用户信息 │
└──────────────────────────────────────────────────────────┘
# 启用离线模式
# 方式 1: 代码设置
wandb.init(mode="offline")
# 方式 2: 环境变量
# WANDB_MODE=offline
# 离线模式行为:
# 1. 所有数据保存到本地文件系统
# └→ wandb/offline-run-xxx/
# ├── files/
# │ ├── wandb-history.jsonl
# │ ├── wandb-summary.json
# │ └── wandb-config.yaml
# └── wandb-metadata.json
# 2. 网络请求被跳过
# 3. 后台服务进程不启动
# 4. sync 后可上传
# 同步离线数据到服务器
# wandb sync wandb/offline-run-xxx/
使用场景:无网络环境的集群训练、受限网络环境、临时网络中断。
# W&B 支持多种分布式训练框架
# 1. DDP (PyTorch DistributedDataParallel)
wandb.init(project="ddp-training")
# 2. 多进程设置
# WANDB_RUN_GROUP=my-experiment
# WANDB_RUN_ID=shared-run-id (所有进程共享)
# WANDB_MODE=shared (共享模式)
# 共享模式下:
# - 所有进程写入同一个 Run
# - 每个进程有唯一的 x_label
# - 系统指标(GPU/CPU)按进程分开记录
# - 用户指标通过步数同步
# 3. Ray 集成
from ray.air.integrations.wandb import WandbLoggerCallback
callback = WandbLoggerCallback(
project="ray-training",
log_config=True,
)
# 4. SLURM 集成
# 自动检测 SLURM 环境变量
# 设置 run_group 和 run_name
# W&B 自动检测 Jupyter 环境
# 自动功能:
# 1. 在 Notebook 中渲染 Run 面板
# └→ 显示实时训练曲线
# └→ 显示系统资源使用
#
# 2. 自动记录 Notebook 代码
# └→ 保存 .ipynb 文件
#
# 3. 自动记录 Cell 输出
# └→ 捕获 print() 和图像输出
# IPython 魔法命令
%wandb # 显示当前 Run 面板
%wandb reset # 重置 Run 状态
# Notebook 扩展
# 检测 ipython 环境
if wandb_sdk.lib.ipython.in_notebook():
from IPython import get_ipython
jupyter._load_ipython_extension(get_ipython())
# 自动设置:
# - console = "wrap" (捕获输出)
# - notebook = True (启用 Jupyter 功能)
# - save_code = True (保存 Notebook)
┌──────────────────────────────────────────────────────────┐
│ W&B 核心数据流 │
├──────────────────────────────────────────────────────────┤
│ │
│ 用户代码 │
│ ┌────────────────────────────────────────────┐ │
│ │ wandb.log({"loss": 0.5}) │ │
│ └──────────────────┬─────────────────────────┘ │
│ │ │
│ ▼ │
│ Run.log() │ │
│ ┌──────────────────┴─────────────────────────┐ │
│ │ 1. 验证 & 序列化 → JSON │ │
│ │ 2. 构建 HistoryRecord (Protobuf) │ │
│ │ 3. interface.publish_history() │ │
│ └──────────────────┬─────────────────────────┘ │
│ │ │
│ ▼ │
│ InterfaceQueue │ │
│ ┌──────────────────┴─────────────────────────┐ │
│ │ record_q.put(pb.Record(history=...)) │ │
│ └──────────────────┬─────────────────────────┘ │
│ │ multiprocessing.Queue │
│ ▼ │
│ 后台服务进程 │ │
│ ┌──────────────────┴─────────────────────────┐ │
│ │ DataStore → Streaming → Sender → W&B Server│ │
│ └────────────────────────────────────────────┘ │
└──────────────────────────────────────────────────────────┘
| 模式 | 应用 | 位置 |
|---|---|---|
| 单例模式 | 全局唯一 Setup 和 Run 对象 | _WandbSetup, wandb.run |
| 代理模式 | PreInitObject 延迟代理 | preinit.py |
| 观察者模式 | Config/Summary 变更通知 | _callback 机制 |
| 生产者-消费者 | SDK 进程与服务进程通信 | record_q / result_q |
| 策略模式 | Sweep 搜索算法切换 | SweepController |
| 门面模式 | 统一 API 入口 | wandb/__init__.py |
| 延迟加载 | 框架集成按需导入 | LazyLoader |
推荐做法
性能建议
核心要点
架构亮点
W&B SDK 是一个精心设计的 ML 实验管理工具,通过分层架构、多进程通信和丰富的数据类型,实现了对 ML 实验全生命周期的追踪和管理。
源码地址
https://github.com/wandb/wandb
访问链接: https://atcfu.com/ai-articles/wandb-experiment-tracking/