基于源码深度解析
2026-04-03 | Data Version Control
第一部分:基础概念
第二部分:核心系统
第三部分:深入实现
第四部分:进阶特性
DVC (Data Version Control) 是一个开源的数据科学项目管理工具,建立在 Git 之上,专门用于管理大型数据集、模型文件和实验结果。
核心功能
主要优势
┌─────────────────────────────────────────────────────────┐
│ DVC Architecture │
├─────────────────────────────────────────────────────────┤
│ CLI Layer │ Commands │ Parser │ Argument Validation │
├─────────────────────────────────────────────────────────┤
│ Repo Layer │ Stage │ Output │ Input │ Pipeline │
├─────────────────────────────────────────────────────────┤
│ System Layer │ Cache │ FS │ Cloud │ Lock │
├─────────────────────────────────────────────────────────┤
│ Storage Layer │ Local │ Git │ Remote │ State │
└─────────────────────────────────────────────────────────┘
DVC 采用分层架构:CLI 层 → Repo 层 → 系统层 → 存储层
| 层级 | 职责 | 核心模块 |
|---|---|---|
| CLI 层 | 命令行接口 | commands, parser, CLI |
| Repo 层 | 仓库管理 | Repo, Stage, Pipeline |
| 系统层 | 系统服务 | Cache, FS, Cloud, Lock |
| 存储层 | 数据存储 | Local, Git, Remote, State |
| 特性 | Git | DVC |
|---|---|---|
| 管理对象 | 代码文件 | 数据文件/模型 |
| 存储方式 | 完整文件存储 | 缓存 + 链接 |
| 空间效率 | 重复存储 | 去重存储 |
| 远程存储 | Git Remote | 云存储适配器 |
| 实验追踪 | 分支/标签 | Metrics/Plots |
| 流水线 | Makefile/脚本 | Stage/Pipeline |
基本概念
高级概念
class Repo:
    """Coordinate every component of a DVC repository.

    The constructor resolves the repository root, derives the ``.dvc``
    directory and instantiates the core services (cloud, stage loader,
    cache, state, lock) and the feature modules (metrics, plots, params,
    artifacts).
    """

    DVC_DIR = ".dvc"

    def __init__(self, root_dir=None, fs=None, rev=None,
                 config=None, url=None, **kwargs):
        """Build the repository object.

        Args:
            root_dir: Starting point handed to ``find_root``.
            fs: File system to operate on; a ``LocalFileSystem`` is created
                when omitted.
            rev, config, url, **kwargs: Accepted but not used in this
                excerpt.
        """
        # File system setup.
        self._fs = fs or LocalFileSystem()
        self.root_dir = self.find_root(root_dir)
        # Join through the resolved file system: the raw ``fs`` argument is
        # None by default, so ``fs.join`` would raise AttributeError.
        self.dvc_dir = self._fs.join(self.root_dir, self.DVC_DIR)

        # Core components.
        self.cloud = DataCloud(self)      # cloud storage integration
        self.stage = StageLoad(self)      # stage management
        self.cache = CacheManager(self)   # cache management
        self.state = State(self)          # state tracking
        self.lock = make_lock(...)        # locking (arguments elided in this excerpt)

        # Feature modules.
        self.metrics = Metrics(self)      # metrics
        self.plots = Plots(self)          # plots
        self.params = Params(self)        # parameters
        self.artifacts = Artifacts(self)  # artifacts
Repo 是 DVC 的核心类,负责协调整个仓库的所有组件和功能
from dvc.fs import LocalFileSystem, GitFileSystem, DVCFileSystem
# 文件系统层次结构
FileSystem
├── LocalFileSystem # 本地文件系统
├── GitFileSystem # Git 仓库文件系统
├── DVCFileSystem # DVC 专用文件系统
├── RemoteFileSystem # 远程文件系统
└── DataFileSystem # 数据索引文件系统
设计亮点:统一的文件系统接口,支持本地、Git、远程存储的透明访问
class CacheManager:
    """Create and hand out one cache instance per file-system scheme."""

    def __init__(self, repo):
        self.repo = repo
        self.default_local_cache_dir = None
        self.caches = []
        self.schemes = {}  # scheme name -> cache instance

    def get(self, scheme, url=None, **kwargs):
        """Return the cache for ``scheme``, creating it on first use.

        ``url`` and ``kwargs`` are only consulted when the cache is created;
        later calls for the same scheme return the existing instance.
        """
        if scheme not in self.schemes:
            cache_class = self._get_cache_class(scheme)
            self.schemes[scheme] = cache_class(self.repo, url, **kwargs)
        return self.schemes[scheme]

    def _scheme_for(self, fs):
        """Return the scheme of ``fs``, or of the repo file system if omitted."""
        return fs.scheme if fs else self.repo.fs.scheme

    def save(self, path, fs=None):
        """Save ``path`` into the cache matching ``fs`` (repo fs by default)."""
        return self.get(self._scheme_for(fs)).save(path)

    def load(self, path, fs=None, **kwargs):
        """Load ``path`` from the cache matching ``fs`` (repo fs by default).

        The body was elided in the original excerpt; it mirrors ``save`` and
        forwards ``kwargs`` to the cache's ``load``.
        """
        return self.get(self._scheme_for(fs)).load(path, **kwargs)
数据管道是 DVC 的核心概念,定义了数据从输入到输出的处理流程,每个 Stage 代表一个处理步骤。
class Pipeline:
    """Registry of stages plus the dependency graph that connects them."""

    def __init__(self, repo):
        self.repo = repo
        self.stages, self.graph, self.cache = {}, {}, {}

    def add_stage(self, stage):
        """Register ``stage`` under its path and record its dependencies."""
        self.stages[stage.path] = stage
        self._update_graph(stage)

    def reproduce(self, targets=None, force=False, dry=False):
        """Re-run the stages selected by ``targets``."""
        selected = self._resolve_targets(targets)
        return self._execute_stages(selected, force, dry)

    def _update_graph(self, stage):
        """Store the stage's dependency list in the graph, keyed by path."""
        self.graph.update({stage.path: stage.deps})
class Stage:
    """One processing step of a pipeline: a command with deps and outputs."""

    def __init__(self, repo, path, **kwargs):
        self.repo, self.path = repo, path
        self.outs = []    # output files
        self.deps = []    # dependency files
        self.cmd = None   # command line to execute
        self.cwd = None   # working directory
        self.params = {}  # parameter configuration

    def run(self):
        """Execute the command if there is one, otherwise reproduce."""
        if not self.cmd:
            return self._reproduce()
        return self._execute_command()

    def save(self):
        """Persist the outputs first, then the stage state."""
        self._save_outputs()
        self._save_state()

    def check_graph(self):
        """Validate the dependency graph: cycles first, then dependencies."""
        self._check_cycles()
        self._check_dependencies()
class Index:
    """In-memory index of a repository's stages, outputs and dependencies."""

    def __init__(self, repo):
        self.repo = repo
        self.stages = {}  # stage path -> stage
        self.outs = {}    # output fs_path -> output
        self.deps = {}    # dependency fs_path -> dependency
        self.data = {}    # data index

    @classmethod
    def from_repo(cls, repo):
        """Alternate constructor: build a fully loaded index for ``repo``.

        The method takes ``cls`` as its first parameter, so it must be a
        classmethod; without the decorator ``Index.from_repo(repo)`` fails.
        """
        index = cls(repo)
        index._load_stages()
        index._build_graph()
        return index

    def _load_stages(self):
        """Load every DVC file reported by ``_find_dvc_files``."""
        for stage_path in self._find_dvc_files():
            self.stages[stage_path] = self.repo.stage.load(stage_path)

    def _build_graph(self):
        """Index each stage's outputs and dependencies by ``fs_path``."""
        for stage in self.stages.values():
            for out in stage.outs:
                self.outs[out.fs_path] = out
            for dep in stage.deps:
                self.deps[dep.fs_path] = dep

    def used_objs(self, targets=None, **kwargs):
        """Return the objects used by ``targets``.

        The implementation was elided in the original excerpt, so this
        fails loudly instead of silently returning None.
        """
        raise NotImplementedError("used_objs is elided in this excerpt")
class Output:
    """A file or directory produced by a stage."""

    def __init__(self, stage, fs_path, **kwargs):
        """Create an output owned by ``stage``.

        Args:
            stage: Owning stage; its ``repo`` attribute (if any) becomes
                ``self.repo``.
            fs_path: Path of the output.
            **kwargs: ``fs`` may override the file system; otherwise the
                repository's ``fs`` is used.
        """
        self.stage = stage
        self.fs_path = fs_path
        # The methods below read ``self.repo`` and ``self.fs``; the original
        # constructor never assigned them, so any call raised AttributeError.
        self.repo = getattr(stage, "repo", None)
        self.fs = kwargs.get("fs") or getattr(self.repo, "fs", None)
        self.hash = None
        self.is_data_source = False
        self.use_cache = True
        self.persist = False

    def save(self):
        """Hash and cache the output (when caching is on), then save metadata."""
        if self.use_cache:
            self.hash = self._calculate_hash()
            self._save_to_cache()
        self._save_meta()

    def _calculate_hash(self):
        """Hash the output as a directory or as a single file."""
        if self.fs.isdir(self.fs_path):
            return self._hash_dir()
        return self._hash_file()

    def _save_to_cache(self):
        """Store the output in the cache under its hash path; no-op without a hash."""
        if self.hash:
            cache_path = self.repo.cache.local.get_hash_path(self.hash)
            # NOTE(review): the CacheManager.save shown in this file takes
            # (path, fs); confirm the second argument is meant to be a
            # destination path.
            self.repo.cache.save(self.fs_path, cache_path)
class Input:
    """A file or directory consumed by a stage."""

    def __init__(self, stage, fs_path, **kwargs):
        """Create an input owned by ``stage``.

        Args:
            stage: Owning stage; its ``repo`` attribute (if any) becomes
                ``self.repo``.
            fs_path: Path of the input in the workspace.
            **kwargs: ``fs`` may override the file system; otherwise the
                repository's ``fs`` is used.
        """
        self.stage = stage
        self.fs_path = fs_path
        # ``load`` and ``_link_from_cache`` read ``self.repo`` / ``self.fs``;
        # the original constructor never assigned them.
        self.repo = getattr(stage, "repo", None)
        self.fs = kwargs.get("fs") or getattr(self.repo, "fs", None)
        self.hash = None
        self.is_data_source = False
        self.use_cache = True

    def load(self):
        """Materialise the input: from the cache if present, else from the remote."""
        if self.use_cache and self.hash:
            cache_path = self.repo.cache.local.get_hash_path(self.hash)
            if self.repo.cache.exists(cache_path):
                self._link_from_cache(cache_path)
                return
        # Not served from the cache: fetch remotely unless already present.
        if not self._exists():
            self._fetch_from_remote()

    def _link_from_cache(self, cache_path):
        """Link ``cache_path`` into the workspace using the configured cache type."""
        cache_type = self.repo.cache.local.cache_type
        if cache_type == 'copy':
            self.fs.copy(cache_path, self.fs_path)
        elif cache_type == 'reflink':
            self.fs.reflink(cache_path, self.fs_path)
        else:
            # Any other cache type falls back to a hard link.
            self.fs.link(cache_path, self.fs_path)
class Dependency:
    """A stage dependency whose hash is tracked to detect changes."""

    def __init__(self, stage, fs_path, **kwargs):
        """Create a dependency owned by ``stage``.

        Args:
            stage: Owning stage; its ``repo`` attribute (if any) becomes
                ``self.repo``.
            fs_path: Path of the dependency.
            **kwargs: Optional ``is_data_source`` (default False) and
                ``use_cache`` (default True).
        """
        self.stage = stage
        self.fs_path = fs_path
        # ``_calculate_hash`` reads ``self.repo.state``; the original
        # constructor never assigned ``self.repo``.
        self.repo = getattr(stage, "repo", None)
        self.hash = None
        self.is_data_source = kwargs.get('is_data_source', False)
        self.use_cache = kwargs.get('use_cache', True)

    def changed(self):
        """Return True when no hash is recorded or the current hash differs."""
        # The hash is computed first on purpose so that a missing file still
        # raises, exactly as before.
        current_hash = self._calculate_hash()
        if self.hash is None:
            return True
        return current_hash != self.hash

    def update(self):
        """Record the current hash as the known-good one."""
        self.hash = self._calculate_hash()

    def _calculate_hash(self):
        """Look the hash up in the repository state.

        Returns None for a missing data source; re-raises
        ``FileNotFoundError`` for anything else.
        """
        try:
            return self.repo.state.get(self.fs_path)
        except FileNotFoundError:
            if self.is_data_source:
                return None
            raise

    def _save_meta(self):
        """Write hash and flags as JSON to the path given by ``_get_meta_path``."""
        meta = {
            'hash': self.hash,
            'is_data_source': self.is_data_source,
            'use_cache': self.use_cache,
        }
        write_json(self._get_meta_path(), meta)
class CacheManager:
    """Registry of cache back-ends keyed by name.

    NOTE(review): ``save`` and ``load`` call ``self.get``, which this
    snippet does not define; the overview ``CacheManager`` earlier in the
    file shows a ``get(scheme, url=None, **kwargs)`` -- confirm the two
    snippets describe the same class.
    """

    def __init__(self, repo):
        self.repo = repo
        self.schemes = {}
        self.caches = {}
        self.default_cache = None
        # Register the default cache.
        self._register_default_cache()

    def _register_default_cache(self):
        """Register the local cache as default, then the link-type caches."""
        local_cache = LocalCache(self.repo)
        self.caches['local'] = local_cache
        self.default_cache = local_cache
        self._register_cache_types()

    def _register_cache_types(self):
        """Instantiate one cache per supported link strategy."""
        cache_types = {
            'copy': CopyCache,
            'reflink': ReflinkCache,
            'symlink': SymlinkCache,
            'hardlink': HardlinkCache,
        }
        for name, cache_class in cache_types.items():
            self.caches[name] = cache_class(self.repo)

    def _scheme_for(self, fs):
        """Return the scheme of ``fs``, or of the repo file system if omitted."""
        return fs.scheme if fs else self.repo.fs.scheme

    def save(self, path, fs=None):
        """Save ``path`` into the cache matching ``fs`` (repo fs by default)."""
        return self.get(self._scheme_for(fs)).save(path)

    def load(self, path, fs=None, **kwargs):
        """Load ``path`` from the cache matching ``fs`` (repo fs by default).

        ``fs`` defaults to None, so it must fall back to the repository file
        system the same way ``save`` does instead of dereferencing None.
        """
        return self.get(self._scheme_for(fs)).load(path, **kwargs)
云存储适配器支持多种云存储服务,包括 AWS S3、Google Cloud Storage、Azure Blob Storage 等
class DataCloud:
    """Facade that routes push/pull/status to the right storage back-end."""

    def __init__(self, repo):
        self.repo = repo
        self._cloud = None   # back-end resolved by the most recent call
        self._remote = None
        # ``_get_cloud`` reads this attribute; the original never defined it.
        # NOTE(review): presumably it should come from the repo config.
        self._default_remote = None
        # One back-end per remote name (None -> local), so that asking for a
        # different remote does not return the first one created.
        self._clouds = {}

    def _get_cloud(self, remote=None):
        """Return the back-end for ``remote`` (default remote if omitted).

        A named remote is built from ``repo.config.get_remote_config``; with
        no remote at all a ``LocalCloud`` is used. Instances are cached per
        remote name.
        """
        name = remote or self._default_remote or None
        if name not in self._clouds:
            if name:
                remote_config = self.repo.config.get_remote_config(name)
                self._clouds[name] = Cloud(remote_config)
            else:
                self._clouds[name] = LocalCloud(self.repo)
        self._cloud = self._clouds[name]
        return self._cloud

    def push(self, obj, remote=None, **kwargs):
        """Push ``obj`` to ``remote``."""
        return self._get_cloud(remote).push(obj, **kwargs)

    def pull(self, obj, remote=None, **kwargs):
        """Pull ``obj`` from ``remote``."""
        return self._get_cloud(remote).pull(obj, **kwargs)

    def status(self, obj, remote=None, **kwargs):
        """Report the status of ``obj`` against ``remote``."""
        return self._get_cloud(remote).status(obj, **kwargs)
# Command registry: maps each CLI sub-command name to its command class.
COMMANDS = dict(
    add=CmdAdd,
    status=CmdStatus,
    reproduce=CmdReproduce,
    pull=CmdPull,
    push=CmdPush,
    run=CmdRun,
    config=CmdConfig,
    metrics=CmdMetrics,
    plots=CmdPlots,
)
class CmdBase:
    """Base class for CLI commands; subclasses implement ``run``."""

    def __init__(self, args):
        self.args = args
        self.repo = None  # opened lazily by get_repo()

    def run(self):
        """Execute the command; must be overridden."""
        raise NotImplementedError

    def get_repo(self):
        """Return the repository, opening it from ``args.url`` on first use."""
        repo = self.repo
        if not repo:
            repo = self.repo = Repo.open(self.args.url)
        return repo
# Command execution flow.
def main():
    """Parse the command line, build the selected command and run it."""
    args = create_parser().parse_args()
    # ``args.func`` is the command factory attached by the parser.
    return args.func(args).run()
class CmdAdd(CmdBase):
    """``dvc add``: start tracking the given targets."""

    def validate_args(self) -> None:
        """Reject option combinations that cannot be used together.

        Raises:
            InvalidArgumentError: ``--to-remote``/``--out`` combined with
                several targets, ``--glob`` or ``--no-commit``; or
                ``--remote``/``--remote-jobs`` given without ``--to-remote``.
        """
        from dvc.exceptions import InvalidArgumentError

        args = self.args
        invalid_opt = None
        if args.to_remote or args.out:
            message = "{option} can't be used with "
            message += "--to-remote" if args.to_remote else "--out"
            if len(args.targets) != 1:
                invalid_opt = "multiple targets"
            elif args.glob:
                invalid_opt = "--glob option"
            elif args.no_commit:
                invalid_opt = "--no-commit option"
        else:
            message = "{option} can't be used without --to-remote"
            if args.remote:
                invalid_opt = "--remote"
            elif args.remote_jobs:
                invalid_opt = "--remote-jobs"
        if invalid_opt is not None:
            raise InvalidArgumentError(message.format(option=invalid_opt))

    def run(self):
        """Validate the arguments, then add the targets.

        Returns:
            0 on success, 1 when validation or the add operation fails.
        """
        from dvc.exceptions import InvalidArgumentError

        # The original never called validate_args(), so invalid option
        # combinations reached repo.add unchecked.
        try:
            self.validate_args()
        except InvalidArgumentError:
            logger.exception("")
            return 1

        args = self.args
        try:
            # CmdBase leaves ``self.repo`` as None until get_repo() opens it.
            self.get_repo().add(
                args.targets,
                no_commit=args.no_commit,
                glob=args.glob,
                out=args.out,
                remote=args.remote,
                to_remote=args.to_remote,
                remote_jobs=args.remote_jobs,
                force=args.force,
                relink=args.relink,
            )
        except DvcException:
            logger.exception("")
            return 1
        return 0
class CmdStatus(CmdBase):
    """``dvc status``: print workspace, cache and remote changes."""

    # (key in the status dict, heading printed for it), in display order.
    _SECTIONS = (
        ('workspace', 'Changes'),
        ('cache', 'Cache'),
        ('remote', 'Remote'),
    )

    def run(self):
        """Print every non-empty status section.

        Returns:
            0 on success, 1 when a ``DvcException`` is raised.
        """
        try:
            # CmdBase leaves ``self.repo`` as None until get_repo() opens it.
            status_info = self.get_repo().status()
            for key, title in self._SECTIONS:
                section = status_info.get(key, {})
                if section:
                    self._print_status(section, title)
            return 0
        except DvcException:
            logger.exception("")
            return 1

    def _print_status(self, status_dict, title):
        """Print ``title`` followed by one ``path: changes`` line per entry."""
        print(f"{title}:")
        for path, changes in status_dict.items():
            print(f" {path}: {changes}")
class CmdReproduce(CmdBase):
    """``dvc repro``: re-run the pipeline for the selected targets."""

    def run(self):
        """Reproduce the targets and report the outcome.

        Returns:
            0 on success, 1 when a ``DvcException`` is raised.
        """
        try:
            args = self.args
            dry = args.dry
            # CmdBase leaves ``self.repo`` as None until get_repo() opens it.
            result = self.get_repo().reproduce(
                targets=args.targets,
                force=args.force,
                dry=dry,
                interactive=args.interactive,
                **self._get_reproduce_kwargs()
            )
            if dry:
                print("Dry run completed successfully")
            elif result:
                print("Pipeline reproduced successfully")
            return 0
        except DvcException:
            logger.exception("")
            return 1

    def _get_reproduce_kwargs(self):
        """Collect the optional keyword arguments that were actually set."""
        kwargs = {}
        if self.args.pipeline is not None:
            kwargs['pipeline'] = self.args.pipeline
        if self.args.no_cache is not None:
            kwargs['no_cache'] = self.args.no_cache
        if self.args.allow_missing:
            kwargs['allow_missing'] = True
        return kwargs
class CmdPull(CmdBase):
    """``dvc pull``: download tracked data from a remote."""

    def run(self):
        """Pull the targets and print the result.

        Returns:
            0 on success, 1 when a ``DvcException`` is raised.
        """
        try:
            args = self.args
            # CmdBase leaves ``self.repo`` as None until get_repo() opens it.
            result = self.get_repo().pull(
                targets=args.targets,
                jobs=args.jobs,
                remote=args.remote,
                show_checksums=args.show_checksums,
                **self._get_pull_kwargs()
            )
            self._print_result(result)
            return 0
        except DvcException:
            logger.exception("")
            return 1
class CmdPush(CmdBase):
    """``dvc push``: upload tracked data to a remote."""

    def run(self):
        """Push the targets and print the result.

        Returns:
            0 on success, 1 when a ``DvcException`` is raised.
        """
        try:
            args = self.args
            # CmdBase leaves ``self.repo`` as None until get_repo() opens it.
            result = self.get_repo().push(
                targets=args.targets,
                jobs=args.jobs,
                remote=args.remote,
                show_checksums=args.show_checksums,
                **self._get_push_kwargs()
            )
            self._print_result(result)
            return 0
        except DvcException:
            logger.exception("")
            return 1
class Config:
    """Layered configuration: default, global, local, project, then user."""

    def __init__(self, dvc_dir, local_dvc_dir=None, fs=None,
                 config=None, remote=None, remote_config=None):
        self.dvc_dir = dvc_dir
        self.local_dvc_dir = local_dvc_dir
        self.fs = fs or localfs
        # Configuration layers.
        self.config, self.cache, self.remote = {}, {}, {}
        # Populate them.
        self._load_config(config, remote, remote_config)

    def _load_config(self, user_config=None, remote=None, remote_config=None):
        """Load every layer in order, then apply user and remote overrides."""
        loaders = (
            self._load_default_config,
            self._load_global_config,
            self._load_local_config,
            self._load_project_config,
        )
        for load_layer in loaders:
            load_layer()
        # User-supplied configuration is merged last.
        if user_config:
            self._merge_config(user_config)
        # A remote is registered only when both name and config are given.
        if remote and remote_config:
            self.remote[remote] = remote_config

    def get(self, section, key=None, default=None):
        """Return a whole section, or one key of it, falling back to ``default``."""
        if key is not None:
            return self.config.get(section, {}).get(key, default)
        return self.config.get(section, default)
class LockBase:
    """File-based exclusive lock usable as a context manager.

    ``_wait_for_lock`` and ``_cleanup`` are not defined in this class;
    presumably subclasses provide them -- TODO confirm.
    """

    def __init__(self, lock_file, tmp_dir=None, hardlink_lock=False,
                 friendly=False, wait=False):
        self.lock_file = lock_file
        self.tmp_dir = tmp_dir
        self.hardlink_lock = hardlink_lock
        self.friendly = friendly
        self.wait = wait
        self._lock = None  # open file object while the lock is held

    def __enter__(self):
        # NOTE(review): a failed non-blocking acquire is not reported here,
        # so the ``with`` body still runs; check ``_lock`` if that matters.
        self.acquire()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.release()

    def acquire(self):
        """Acquire the lock, blocking only when ``wait`` is set.

        Returns:
            The result of ``_wait_for_lock`` or ``_try_acquire``
            (True/False for the latter).
        """
        if self.wait:
            return self._wait_for_lock()
        return self._try_acquire()

    def release(self):
        """Release the lock if held; closing the file drops the flock."""
        if self._lock:
            try:
                self._cleanup()
            finally:
                # Always close the handle, otherwise the descriptor (and the
                # lock with it) leaks until garbage collection.
                self._lock.close()
                self._lock = None

    def _try_acquire(self):
        """Try to take an exclusive, non-blocking flock on ``lock_file``.

        Returns:
            True when the lock was taken, False when the file cannot be
            opened or is locked by someone else.
        """
        # ``flock`` lives in the POSIX-only ``fcntl`` module; ``os`` has no
        # such function, so the original always raised AttributeError.
        import fcntl

        try:
            handle = open(self.lock_file, 'w')
        except OSError:
            return False
        try:
            fcntl.flock(handle.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB)
        except OSError:
            # Do not leak the descriptor or record a lock we do not hold.
            handle.close()
            return False
        self._lock = handle
        return True
class DvcIgnoreFilter:
    """Decide whether a path is excluded by ``.dvcignore`` or a built-in pattern."""

    def __init__(self, fs, root_dir):
        self.fs = fs
        self.root_dir = root_dir
        self.ignore_patterns = []
        self.ignore_file = '.dvcignore'
        # Populate the pattern list right away.
        self._load_ignore_patterns()

    def _load_ignore_patterns(self):
        """Append the patterns from ``.dvcignore`` (if present), then the defaults."""
        ignore_path = self.fs.join(self.root_dir, self.ignore_file)
        if self.fs.exists(ignore_path):
            user_patterns = self._parse_ignore_content(
                self.fs.read_text(ignore_path)
            )
            self.ignore_patterns.extend(user_patterns)
        # Built-in patterns always come after the user's.
        defaults = ['cache/', '.dvc/', '.git/', '*.tmp', '*.log']
        self.ignore_patterns.extend(defaults)

    def is_ignored(self, path):
        """Return True when ``path``, relative to the root, matches any pattern."""
        rel_path = self.fs.relpath(path, self.root_dir)
        return any(
            self._match_pattern(rel_path, pattern)
            for pattern in self.ignore_patterns
        )
# File system base class.
class FileSystem:
    """Abstract interface shared by every file-system back-end."""

    def __init__(self, **kwargs):
        # Keep the raw options so subclasses can read what they need.
        self.kwargs = kwargs

    def exists(self, path):
        """Return whether ``path`` exists; subclasses must override."""
        raise NotImplementedError

    def read_text(self, path):
        """Return the text content of ``path``; subclasses must override."""
        raise NotImplementedError

    def write_text(self, path, content):
        """Write ``content`` to ``path``; subclasses must override."""
        raise NotImplementedError
# Local file system.
class LocalFileSystem(FileSystem):
    """File system backed by the local disk."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.root = kwargs.get('root', '/')

    def exists(self, path):
        """Return whether ``path`` exists on disk."""
        return os.path.exists(path)

    def read_text(self, path):
        """Return the whole content of ``path`` read in text mode."""
        with open(path, 'r') as handle:
            text = handle.read()
        return text
# Git file system.
class GitFileSystem(FileSystem):
    """Read-only view of a Git revision, served through an SCM object."""

    def __init__(self, scm, rev=None, **kwargs):
        super().__init__(**kwargs)
        self.scm = scm
        # Fall back to HEAD when no revision is given.
        self.rev = rev if rev else 'HEAD'

    def exists(self, path):
        """Return whether ``path`` exists in the selected revision."""
        found = self.scm.exists(path, rev=self.rev)
        return found
class LocalFileSystem(FileSystem):
    """Local-disk file system with path helpers and link primitives."""

    # Linux ``FICLONE`` ioctl request: share the source's data blocks with
    # the destination copy-on-write.
    _FICLONE = 0x40049409

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.root = kwargs.get('root', '/')
        self.tmp_dir = kwargs.get('tmp_dir')
        self.hash_jobs = kwargs.get('hash_jobs')

    def abspath(self, path):
        """Return ``path`` unchanged if absolute, else joined onto the root."""
        if os.path.isabs(path):
            return path
        return os.path.join(self.root, path)

    def relpath(self, path, start=None):
        """Return ``path`` relative to ``start`` (the root by default)."""
        if start is None:
            start = self.root
        return os.path.relpath(path, start)

    def makedirs(self, path, exist_ok=False):
        """Create ``path`` and any missing parents."""
        os.makedirs(path, exist_ok=exist_ok)

    def copy(self, src, dst):
        """Copy ``src`` to ``dst`` including metadata."""
        shutil.copy2(src, dst)

    def link(self, src, dst):
        """Create a hard link ``dst`` pointing at ``src``."""
        os.link(src, dst)

    def reflink(self, src, dst):
        """Create a copy-on-write clone of ``src`` at ``dst``.

        Falls back to a regular copy when cloning is unavailable (non-Linux
        platform, unsupported file system, different device).

        The original called ``os.link``, which creates a hard link: both
        names share one inode, so editing the workspace file would silently
        modify the cached copy. Either path taken here yields an independent
        file.
        """
        try:
            import fcntl

            with open(src, 'rb') as source, open(dst, 'wb') as target:
                fcntl.ioctl(target.fileno(), self._FICLONE, source.fileno())
            # The ioctl clones data only; mirror copy2's metadata handling.
            shutil.copystat(src, dst)
        except (ImportError, OSError):
            self.copy(src, dst)
class GitFileSystem(FileSystem):
    """Read-only file system over one Git revision, delegating to an SCM object."""

    def __init__(self, scm, rev=None, **kwargs):
        super().__init__(**kwargs)
        self.scm = scm
        # Fall back to HEAD when no revision is given.
        self.rev = rev if rev else 'HEAD'

    def exists(self, path):
        """Return whether ``path`` exists in the revision."""
        return self.scm.exists(path, rev=self.rev)

    def read_text(self, path):
        """Return the content of ``path`` as stored in the revision."""
        return self.scm.get_file_content(path, rev=self.rev)

    def open(self, path, mode='r'):
        """Return an in-memory text stream for ``path``; only ``'r'`` is supported."""
        if mode != 'r':
            raise NotImplementedError
        return io.StringIO(self.read_text(path))

    def listdir(self, path):
        """List the entries of directory ``path`` in the revision."""
        return self.scm.listdir(path, rev=self.rev)

    def isdir(self, path):
        """Return whether ``path`` is a directory in the revision."""
        return self.scm.isdir(path, rev=self.rev)

    def walk(self, top, topdown=True):
        """Yield ``(root, dirs, files)`` triples for the tree under ``top``."""
        yield from self.scm.walk(top, rev=self.rev, topdown=topdown)
class State:
    """SQLite-backed store mapping absolute file paths to their hashes."""

    def __init__(self, root_dir, site_cache_dir, dvcignore):
        self.root_dir = root_dir
        self.site_cache_dir = site_cache_dir
        self.dvcignore = dvcignore
        # The database lives in the site cache directory.
        self.db_path = os.path.join(site_cache_dir, 'state.db')
        self._init_db()

    def _connect(self):
        """Return a context manager yielding a connection that is always closed.

        ``closing`` guarantees ``conn.close()`` even when a statement raises;
        the original leaked the connection on any error.
        """
        import sqlite3
        from contextlib import closing

        return closing(sqlite3.connect(self.db_path))

    def _init_db(self):
        """Create the ``state`` table if it does not exist yet."""
        with self._connect() as conn:
            conn.execute('''
            CREATE TABLE IF NOT EXISTS state (
            path TEXT PRIMARY KEY,
            hash TEXT,
            mtime REAL,
            size INTEGER
            )
            ''')
            conn.commit()

    def get(self, path):
        """Return the stored hash for ``path``, or None when unknown."""
        path = os.path.abspath(path)
        with self._connect() as conn:
            cursor = conn.execute('SELECT hash FROM state WHERE path = ?', (path,))
            row = cursor.fetchone()
        return row[0] if row else None

    def update(self, path, hash_value):
        """Insert or replace the record for ``path``.

        Stores the hash, the current time and the file size.
        NOTE(review): the ``mtime`` column receives ``time.time()``, not the
        file's modification time -- confirm that this is intended.
        """
        path = os.path.abspath(path)
        with self._connect() as conn:
            conn.execute('''
            INSERT OR REPLACE INTO state (path, hash, mtime, size)
            VALUES (?, ?, ?, ?)
            ''', (path, hash_value, time.time(), os.path.getsize(path)))
            conn.commit()
def calculate_hash(path, fs=None):
    """Return the hash of ``path``, dispatching on file versus directory.

    A ``LocalFileSystem`` is used when ``fs`` is not given.
    """
    fs = fs or LocalFileSystem()
    hasher = _hash_directory if fs.isdir(path) else _hash_file
    return hasher(path, fs)
def _hash_file(path, fs):
"""计算单个文件的哈希值"""
hash_obj = hashlib.sha256()
with fs.open(path, 'rb') as f:
# 分块读取文件
for chunk in iter(lambda: f.read(8192), b''):
hash_obj.update(chunk)
return hash_obj.hexdigest()
def _hash_directory(path, fs):
    """Return a SHA-256 hex digest covering every file under ``path``.

    Each file contributes ``"<relative path>:<file hash>"``. Entries are
    fed to the digest in sorted order: the order in which ``fs.walk``
    yields files is not guaranteed, and feeding them in traversal order made
    the same directory hash differently from one run or machine to the next.
    """
    entries = []
    for root, _dirs, files in fs.walk(path):
        for file in files:
            file_path = fs.join(root, file)
            rel_path = fs.relpath(file_path, path)
            entries.append((rel_path, _hash_file(file_path, fs)))

    hash_obj = hashlib.sha256()
    for rel_path, file_hash in sorted(entries):
        hash_obj.update(f"{rel_path}:{file_hash}".encode())
    return hash_obj.hexdigest()
校验机制:本文的示例代码使用 SHA256 哈希算法来校验数据的完整性和一致性(DVC 实际默认使用 MD5 作为文件内容哈希)
class IntegrityChecker:
    """Verify that files and directories still match their recorded hashes."""

    def __init__(self, repo):
        self.repo = repo
        self.state = repo.state

    def verify(self, path, expected_hash):
        """Raise ``IntegrityError`` when ``path`` does not hash to ``expected_hash``."""
        actual_hash = self.calculate_hash(path)
        if actual_hash != expected_hash:
            raise IntegrityError(
                f"File {path} is corrupted. "
                f"Expected: {expected_hash}, "
                f"Actual: {actual_hash}"
            )

    def calculate_hash(self, path):
        """Return the SHA-256 hex digest of a file or of a directory tree."""
        if os.path.isdir(path):
            return self._hash_directory(path)
        return self._hash_file(path)

    def _hash_file(self, path):
        """Hash one file, reading 8 KiB at a time to bound memory use."""
        hash_obj = hashlib.sha256()
        with open(path, 'rb') as f:
            for chunk in iter(lambda: f.read(8192), b''):
                hash_obj.update(chunk)
        return hash_obj.hexdigest()

    def _hash_directory(self, path):
        """Hash a directory tree.

        ``calculate_hash`` called this method, but the class did not define
        it, so every directory raised AttributeError. Each file contributes
        ``"<relative path>:<file hash>"``, the format used by the
        module-level ``_hash_directory``, fed in sorted order so the result
        does not depend on traversal order.
        """
        entries = []
        for root, _dirs, files in os.walk(path):
            for name in files:
                file_path = os.path.join(root, name)
                rel_path = os.path.relpath(file_path, path)
                entries.append((rel_path, self._hash_file(file_path)))

        hash_obj = hashlib.sha256()
        for rel_path, file_hash in sorted(entries):
            hash_obj.update(f"{rel_path}:{file_hash}".encode())
        return hash_obj.hexdigest()
class CopyCache:
    """Cache strategy that stores and restores files as full copies.

    ``save`` relies on ``self._get_hash``, which is not defined in this
    snippet -- presumably provided elsewhere.
    """

    def __init__(self, repo):
        self.repo = repo
        self.cache_dir = repo.cache.local.cache_dir

    def save(self, src, dst=None):
        """Copy ``src`` into the cache (hash-named when ``dst`` is omitted)."""
        if not dst:
            dst = os.path.join(self.cache_dir, self._get_hash(src))
        shutil.copy2(src, dst)
        return dst

    def load(self, src, dst):
        """Copy the cached file ``src`` to ``dst`` and return ``dst``."""
        shutil.copy2(src, dst)
        return dst
class HardlinkCache:
    """Cache strategy that shares data between cache and workspace via hard links.

    ``save`` relies on ``self._get_hash``, which is not defined in this
    snippet -- presumably provided elsewhere.
    """

    def __init__(self, repo):
        self.repo = repo
        self.cache_dir = repo.cache.local.cache_dir

    def save(self, src, dst=None):
        """Hard-link ``src`` into the cache (hash-named when ``dst`` is omitted)."""
        if not dst:
            dst = os.path.join(self.cache_dir, self._get_hash(src))
        os.link(src, dst)
        return dst

    def load(self, src, dst):
        """Hard-link the cached file ``src`` to ``dst`` and return ``dst``."""
        os.link(src, dst)
        return dst
class SymlinkCache:
    """Cache strategy that connects cache and workspace via symbolic links.

    ``save`` relies on ``self._get_hash``, which is not defined in this
    snippet -- presumably provided elsewhere.
    """

    def __init__(self, repo):
        self.repo = repo
        self.cache_dir = repo.cache.local.cache_dir

    def save(self, src, dst=None):
        """Create ``dst`` (hash-named when omitted) as a symlink to ``src``."""
        if not dst:
            dst = os.path.join(self.cache_dir, self._get_hash(src))
        os.symlink(src, dst)
        return dst

    def load(self, src, dst):
        """Create ``dst`` as a symlink to ``src`` and return ``dst``."""
        os.symlink(src, dst)
        return dst
| 策略 | 优点 | 缺点 | 使用场景 |
|---|---|---|---|
| 硬链接 | 空间效率高,性能好 | 不能跨文件系统 | 同一文件系统内 |
| 软链接 | 跨文件系统,灵活 | 需要额外维护,性能稍差 | 不同文件系统间 |
| 复制 | 兼容性好,独立性强 | 空间占用大 | 需要独立副本 |
| reflink | CoW机制,高效 | 需要文件系统支持 | 现代文件系统 |
Copy-on-Write (CoW):Reflink 创建的是写时复制文件,初始时共享相同的数据块
class ReflinkCache:
    """Cache strategy based on copy-on-write clones, with a copy fallback.

    The original used ``os.link`` throughout, which creates hard links:
    cache and workspace would share one inode, so editing a workspace file
    would corrupt the cached copy, and the support probe actually tested
    hard-link support. Cloning is done with the Linux ``FICLONE`` ioctl;
    anywhere it is unavailable the data is copied instead, so the result is
    always an independent file.

    ``save`` relies on ``self._get_hash``, which is not defined in this
    snippet -- presumably provided elsewhere.
    """

    # Linux ``FICLONE`` ioctl request number.
    _FICLONE = 0x40049409

    def __init__(self, repo):
        self.repo = repo
        self.cache_dir = repo.cache.local.cache_dir

    def _clone(self, src, dst):
        """Clone ``src`` to ``dst`` copy-on-write.

        Raises:
            ImportError: ``fcntl`` is unavailable on this platform.
            OSError: the file system cannot clone these files.
        """
        import fcntl

        with open(src, 'rb') as source, open(dst, 'wb') as target:
            fcntl.ioctl(target.fileno(), self._FICLONE, source.fileno())
        # The ioctl clones data only; mirror copy2's metadata handling.
        shutil.copystat(src, dst)

    def _clone_or_copy(self, src, dst):
        """Clone when possible, otherwise fall back to a regular copy."""
        try:
            self._clone(src, dst)
        except (ImportError, OSError):
            shutil.copy2(src, dst)

    def save(self, src, dst=None):
        """Store ``src`` in the cache (hash-named when ``dst`` is omitted)."""
        dst = dst or os.path.join(self.cache_dir, self._get_hash(src))
        self._clone_or_copy(src, dst)
        return dst

    def load(self, src, dst):
        """Materialise the cached file ``src`` at ``dst`` and return ``dst``."""
        self._clone_or_copy(src, dst)
        return dst

    def is_reflink_supported(self):
        """Return whether the cache directory's file system can clone files.

        The probe files are removed whether or not the attempt succeeds.
        """
        test_file = os.path.join(self.cache_dir, '.reflink_test')
        reflink_file = test_file + '.reflink'
        try:
            with open(test_file, 'w') as f:
                f.write('test')
            self._clone(test_file, reflink_file)
            return True
        except (ImportError, OSError):
            return False
        finally:
            for leftover in (reflink_file, test_file):
                try:
                    os.unlink(leftover)
                except OSError:
                    pass
class RemoteAdapter:
    """Abstract base for remote storage adapters."""

    def __init__(self, config):
        self.config = config
        # Connection settings shared by the concrete adapters.
        lookup = config.get
        self.bucket = lookup('bucket')
        self.region = lookup('region')
        self.endpoint_url = lookup('endpoint_url')

    def exists(self, path):
        """Return whether ``path`` exists remotely; subclasses must override."""
        raise NotImplementedError

    def upload(self, src, dst):
        """Upload local ``src`` to remote ``dst``; subclasses must override."""
        raise NotImplementedError

    def download(self, src, dst):
        """Download remote ``src`` to local ``dst``; subclasses must override."""
        raise NotImplementedError

    def list_objects(self, prefix=''):
        """List remote objects under ``prefix``; subclasses must override."""
        raise NotImplementedError
class S3Adapter(RemoteAdapter):
    """AWS S3 adapter."""

    def __init__(self, config):
        super().__init__(config)
        import boto3

        # Honour the configured endpoint (None keeps the AWS default); the
        # original read ``endpoint_url`` in the base class but never used
        # it, so S3-compatible services could not be targeted.
        self.s3 = boto3.client(
            's3',
            region_name=self.region,
            endpoint_url=self.endpoint_url,
        )

    def exists(self, path):
        """Return True when the object exists.

        Returns False when the service answers with a ``ClientError``, the
        same handling ``S3Remote.exists`` uses. The original bare ``except:``
        also swallowed ``KeyboardInterrupt`` and programming errors.
        """
        try:
            self.s3.head_object(Bucket=self.bucket, Key=path)
        except self.s3.exceptions.ClientError:
            return False
        return True

    def upload(self, src, dst):
        """Upload local file ``src`` to key ``dst``."""
        self.s3.upload_file(src, self.bucket, dst)

    def download(self, src, dst):
        """Download key ``src`` to local file ``dst``."""
        self.s3.download_file(self.bucket, src, dst)
class S3Remote:
    """AWS S3 remote storage implementation."""

    def __init__(self, config):
        """Read the connection settings from ``config`` and create the client.

        ``bucket`` is required; ``region`` defaults to ``us-east-1``.
        """
        self.bucket = config['bucket']
        self.region = config.get('region', 'us-east-1')
        self.endpoint_url = config.get('endpoint_url')
        self.extra_args = config.get('extra_args', {})
        # Create the S3 client.
        import boto3

        session = boto3.Session()
        self.s3 = session.client(
            's3',
            region_name=self.region,
            endpoint_url=self.endpoint_url
        )

    def _get_full_path(self, path):
        """Return the object key for ``path``, dropping one leading slash."""
        if path.startswith('/'):
            path = path[1:]
        return path

    def exists(self, path):
        """Return True when the object exists, False on a client error."""
        full_path = self._get_full_path(path)
        try:
            self.s3.head_object(Bucket=self.bucket, Key=full_path)
            return True
        except self.s3.exceptions.ClientError:
            return False

    def upload(self, src, dst):
        """Upload local file ``src`` to key ``dst``.

        A ``ContentType`` derived from the file extension is added to a copy
        of the configured extra arguments.
        """
        import os

        full_path = self._get_full_path(dst)
        extra_args = self.extra_args.copy()
        # Take the extension from the file name only. Splitting the whole
        # path on '.' treated a dotted directory ("/data.v1/file") as if the
        # file had the extension "v1/file".
        ext = os.path.splitext(src)[1].lstrip('.')
        if ext:
            extra_args['ContentType'] = self._get_content_type(ext)
        self.s3.upload_file(src, self.bucket, full_path, ExtraArgs=extra_args)

    def download(self, src, dst):
        """Download key ``src`` to local file ``dst``."""
        full_path = self._get_full_path(src)
        self.s3.download_file(self.bucket, full_path, dst)
class GSRemote:
    """Google Cloud Storage remote storage implementation."""

    def __init__(self, config):
        self.bucket = config['bucket']
        self.project = config.get('project')
        self.credentials = config.get('credentials')
        self.endpoint_url = config.get('endpoint_url')
        # Create the GCS client.
        from google.cloud import storage
        self.client = storage.Client(
            project=self.project,
            credentials=self.credentials
        )

    def _get_bucket(self):
        """Return the bucket handle for the configured bucket name."""
        return self.client.bucket(self.bucket)

    def exists(self, path):
        """Return whether the object ``path`` exists in the bucket."""
        return self._get_bucket().blob(path).exists()

    def upload(self, src, dst):
        """Upload local file ``src`` to object ``dst``, setting its content type."""
        target = self._get_bucket().blob(dst)
        target.content_type = self._get_content_type(src)
        target.upload_from_filename(src)

    def download(self, src, dst):
        """Download object ``src`` to local file ``dst``."""
        self._get_bucket().blob(src).download_to_filename(dst)

    def list_objects(self, prefix=''):
        """Return the names of the objects whose key starts with ``prefix``."""
        names = []
        for blob in self._get_bucket().list_blobs(prefix=prefix):
            names.append(blob.name)
        return names
class AzureBlobRemote:
    """Azure Blob Storage remote storage implementation."""

    def __init__(self, config):
        """Create the blob service client.

        ``container`` is always required. With a ``connection_string`` the
        account name is optional, since the client is built from the
        connection string alone; otherwise ``account_name`` is required and
        ``account_key`` is used as the credential.
        """
        self.container = config['container']
        self.connection_string = config.get('connection_string')
        if self.connection_string:
            # The original demanded ``account_name`` here too, raising
            # KeyError for a configuration that never used it.
            self.account_name = config.get('account_name')
        else:
            self.account_name = config['account_name']
        self.account_key = config.get('account_key')
        # Create the Azure Blob client.
        from azure.storage.blob import BlobServiceClient

        if self.connection_string:
            self.blob_service_client = BlobServiceClient.from_connection_string(
                self.connection_string
            )
        else:
            self.blob_service_client = BlobServiceClient(
                account_url=f"https://{self.account_name}.blob.core.windows.net",
                credential=self.account_key
            )

    def _get_container_client(self):
        """Return the client for the configured container."""
        return self.blob_service_client.get_container_client(self.container)

    def exists(self, path):
        """Return whether the blob ``path`` exists."""
        container_client = self._get_container_client()
        blob_client = container_client.get_blob_client(path)
        return blob_client.exists()

    def upload(self, src, dst):
        """Upload local file ``src`` to blob ``dst``, overwriting any existing blob."""
        container_client = self._get_container_client()
        blob_client = container_client.get_blob_client(dst)
        content_settings = self._get_content_settings(src)
        with open(src, 'rb') as data:
            blob_client.upload_blob(data, overwrite=True, content_settings=content_settings)

    def download(self, src, dst):
        """Download blob ``src`` to local file ``dst``.

        The blob is streamed into the file with ``readinto`` instead of
        being buffered whole with ``readall()``, which needed as much memory
        as the blob is large.
        """
        container_client = self._get_container_client()
        blob_client = container_client.get_blob_client(src)
        with open(dst, 'wb') as download_file:
            blob_client.download_blob().readinto(download_file)
class Experiments:
    """Experiment management built on ``exp/<name>`` Git branches."""

    _BRANCH_PREFIX = 'exp/'

    def __init__(self, repo):
        self.repo = repo

    def run(self, command, params=None, **kwargs):
        """Run ``command`` on a fresh experiment branch and record the result.

        NOTE(review): the repository is left on the experiment branch
        afterwards -- confirm that callers expect this.

        Returns:
            The generated experiment name.
        """
        exp_name = self._generate_exp_name()
        branch = f'exp/{exp_name}'
        # Create the experiment branch and switch to it.
        self.repo.scm.git.checkout(branch, create=True)
        if params:
            self._set_params(params)
        result = self._run_command(command, **kwargs)
        self._record_exp_result(exp_name, result)
        return exp_name

    def reproduce(self, exp_name):
        """Check out the experiment's branch and re-run the pipeline."""
        branch = f'exp/{exp_name}'
        self.repo.scm.git.checkout(branch)
        return self.repo.reproduce()

    def list(self):
        """Return one ``{'name', 'branch', 'commit'}`` dict per experiment branch.

        The branch string is stripped once and the clean value is used for
        the name, the ``branch`` entry and the ``rev_parse`` lookup; the
        original stripped only the name and passed the raw string on. Only
        the leading ``exp/`` is removed, so ``exp/`` occurring later in a
        name is kept.
        """
        experiments = []
        for raw_branch in self.repo.scm.git.branch('--list', 'exp/*'):
            branch = raw_branch.strip()
            if branch.startswith(self._BRANCH_PREFIX):
                exp_name = branch[len(self._BRANCH_PREFIX):]
            else:
                exp_name = branch
            experiments.append({
                'name': exp_name,
                'branch': branch,
                'commit': self.repo.scm.git.rev_parse(branch)
            })
        return experiments
class Metrics:
    """Collect metric values from the JSON files under ``<root>/metrics``."""

    def __init__(self, repo):
        self.repo = repo
        self.metrics_dir = os.path.join(repo.root_dir, 'metrics')
        self.metrics_file = os.path.join(self.metrics_dir, 'metrics.json')

    @staticmethod
    def _read_json(path):
        """Load one JSON file; read as UTF-8 rather than the locale encoding."""
        with open(path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def collect(self, path=None):
        """Merge metrics from ``path`` and from every ``*.json`` in the metrics dir.

        ``path`` defaults to ``metrics.json`` and is loaded first when it
        exists. The directory is then walked in sorted order, so when
        several files define the same key the winner is deterministic
        rather than dependent on directory listing order.
        """
        metrics = {}
        if path is None:
            path = self.metrics_file
        if os.path.exists(path):
            metrics.update(self._read_json(path))
        for root, dirs, files in os.walk(self.metrics_dir):
            dirs.sort()  # in-place sort fixes the traversal order
            for file in sorted(files):
                if file.endswith('.json'):
                    metrics.update(self._read_json(os.path.join(root, file)))
        return metrics

    def show(self, path=None, all_branches=False):
        """Return the collected metrics, plus one entry per branch if requested."""
        metrics = self.collect(path)
        if all_branches:
            for branch in self.repo.brancher(all_branches=True):
                metrics[branch] = self._collect_from_branch(branch)
        return metrics

    def _collect_from_branch(self, branch):
        """Collect the default metrics while switched to ``branch``."""
        with self.repo.switch(branch):
            return self.collect()
class Plots:
    """Load plot definitions from ``<root>/plots/plots.json`` and compare them."""

    def __init__(self, repo):
        self.repo = repo
        self.plots_dir = os.path.join(repo.root_dir, 'plots')
        self.plots_file = os.path.join(self.plots_dir, 'plots.json')

    def show(self, path=None, all_branches=False):
        """Return the plot config, plus one entry per branch if requested."""
        plots = self._load_plots_config(path)
        if all_branches:
            for branch in self.repo.brancher(all_branches=True):
                plots[branch] = self._load_plots_from_branch(branch)
        return plots

    def _load_plots_config(self, path=None):
        """Load the plot config JSON; an absent file yields an empty dict."""
        if path is None:
            path = self.plots_file
        if os.path.exists(path):
            with open(path, 'r') as f:
                return json.load(f)
        return {}

    def _load_plots_from_branch(self, branch):
        """Load the default plot config while switched to ``branch``."""
        with self.repo.switch(branch):
            return self._load_plots_config()

    def _load_plots_for_target(self, target, path=None):
        """Load the plot config at ``path`` while switched to ``target``."""
        with self.repo.switch(target):
            return self._load_plots_config(path)

    def diff(self, target1, target2, path=None):
        """Compare the plots of two targets.

        The original loaded the same working-tree data twice and ignored
        both targets, so it always compared a plot with itself. Each side is
        now loaded while switched to its own target.

        Returns:
            ``{plot_name: comparison}`` for every plot present on both sides.
        """
        plots1 = self._load_plots_for_target(target1, path)
        plots2 = self._load_plots_for_target(target2, path)
        return {
            plot_name: self._compare_plots(plots1[plot_name], plots2[plot_name])
            for plot_name in plots1
            if plot_name in plots2
        }
class Params:
    """Collect parameters from YAML/JSON/TOML files."""

    def __init__(self, repo):
        self.repo = repo
        self.params_dir = os.path.join(repo.root_dir, 'params')

    def collect(self, path=None):
        """Return the parameters found in the default files or in ``path``.

        With no ``path``, every existing default file in the params directory
        is merged in a fixed order; with a ``path``, only that file is read
        (if it exists).
        """
        params = {}
        if path is None:
            default_names = ['params.yaml', 'params.json', 'params.toml']
            candidates = (
                os.path.join(self.params_dir, name) for name in default_names
            )
            for candidate in filter(os.path.exists, candidates):
                params.update(self._load_params_file(candidate))
        elif path and os.path.exists(path):
            params.update(self._load_params_file(path))
        return params

    def _load_params_file(self, path):
        """Parse one parameter file according to its extension.

        Raises:
            ValueError: the extension is not YAML, JSON or TOML.
        """
        if path.endswith(('.yaml', '.yml')):
            import yaml
            with open(path, 'r') as stream:
                loaded = yaml.safe_load(stream)
            return loaded or {}
        if path.endswith('.json'):
            with open(path, 'r') as stream:
                return json.load(stream)
        if path.endswith('.toml'):
            import toml
            with open(path, 'r') as stream:
                return toml.load(stream)
        raise ValueError(f"Unsupported parameter file format: {path}")

    def show(self, path=None, all_branches=False):
        """Return the parameters, plus one entry per branch if requested."""
        params = self.collect(path)
        if not all_branches:
            return params
        for branch in self.repo.brancher(all_branches=True):
            with self.repo.switch(branch):
                params[branch] = self.collect(path)
        return params
class PipelineOrchestrator:
    """Run the stages of a repository's pipeline in order."""

    def __init__(self, repo):
        self.repo = repo
        self.pipeline = repo.pipeline

    def run(self, targets=None, **kwargs):
        """Execute the resolved stages and return their results.

        A failing stage is handed to ``_handle_stage_failure``; execution
        stops there unless ``fail_fast=False`` is passed.
        """
        results = []
        for stage in self._resolve_targets(targets):
            try:
                outcome = self._execute_stage(stage, **kwargs)
                results.append(outcome)
            except StageFailedError as e:
                self._handle_stage_failure(stage, e)
                if kwargs.get('fail_fast', True):
                    break
        return results

    def _resolve_targets(self, targets):
        """Map target names to stages; ``None`` selects every stage.

        Names that are not registered stages are expanded as wildcards.
        """
        if targets is None:
            return list(self.pipeline.stages.values())
        resolved = []
        for target in targets:
            if target not in self.pipeline.stages:
                resolved.extend(self._match_wildcard(target))
                continue
            resolved.append(self.pipeline.stages[target])
        return resolved

    def _execute_stage(self, stage, **kwargs):
        """Check dependencies, run the stage, save it, and return the run result."""
        self._check_dependencies(stage)
        result = stage.run()
        stage.save()
        return result
数据血缘:追踪数据从输入到输出的完整处理流程,确保实验的可复现性和可追溯性
class DataLineage:
    """Directed graph linking data paths and the stages that use or produce them."""

    def __init__(self, repo):
        self.repo = repo
        self.graph = nx.DiGraph()

    def build_lineage(self, targets=None):
        """Populate the graph from every stage of the repository's pipeline.

        ``targets`` is accepted but not used by this implementation.
        """
        stages = self.repo.pipeline.stages
        # First pass: one node per stage.
        for stage in stages.values():
            self.graph.add_node(stage.path, stage=stage)
        # Second pass: dependency -> stage and stage -> output edges.
        for stage in stages.values():
            for dep in stage.deps:
                self.graph.add_edge(dep.fs_path, stage.path)
            for out in stage.outs:
                self.graph.add_edge(stage.path, out.fs_path)

    def trace_data(self, data_path):
        """Describe the direct predecessors of ``data_path`` (empty if unknown)."""
        if data_path not in self.graph:
            return []
        return [
            {
                'path': pred,
                'type': 'input',
                'stages': self._find_stages_for_path(pred)
            }
            for pred in self.graph.predecessors(data_path)
        ]

    def trace_usage(self, data_path):
        """Describe the direct successors of ``data_path`` (empty if unknown)."""
        if data_path not in self.graph:
            return []
        return [
            {
                'path': succ,
                'type': 'output',
                'stages': self._find_stages_for_path(succ)
            }
            for succ in self.graph.successors(data_path)
        ]
class DvcException(Exception):
    """Base class for every DVC error.

    Carries the human-readable ``message`` and the process ``exit_code``
    (1 unless an ``exit_code`` keyword is given).
    """

    def __init__(self, message, *args, **kwargs):
        super().__init__(message, *args)
        self.message = message
        self.exit_code = kwargs['exit_code'] if 'exit_code' in kwargs else 1
class OutputNotFoundError(DvcException):
    """Raised when an output path is not found in a DVC repository."""

    def __init__(self, path, repo):
        message = f"Output '{path}' not found in DVC repository '{repo}'"
        super().__init__(message)
        self.path, self.repo = path, repo
class CacheCorruptedError(DvcException):
    """Raised when the cache entry for a hash is corrupted."""

    def __init__(self, hash_value):
        message = f"Cache for hash '{hash_value}' is corrupted"
        super().__init__(message)
        self.hash_value = hash_value
class RemoteConnectionError(DvcException):
    """Raised when connecting to a remote fails; keeps the underlying exception."""

    def __init__(self, remote_name, original_exception):
        message = f"Failed to connect to remote '{remote_name}'"
        super().__init__(message)
        self.remote_name = remote_name
        self.original_exception = original_exception
class PipelineFailedError(DvcException):
    """Raised when a pipeline run fails; carries the failed stages and results."""

    def __init__(self, failed_stages, results):
        failed_count = len(failed_stages)
        super().__init__(f"Pipeline failed with {failed_count} stages")
        self.failed_stages = failed_stages
        self.results = results
class DvcLogger:
    """Configure the ``dvc`` logger with a console and a file handler."""

    def __init__(self):
        """Attach the handlers to the shared ``dvc`` logger exactly once.

        ``logging.getLogger('dvc')`` returns the same logger every time, so
        adding handlers on each instantiation duplicated every log line
        (and reopened ``dvc.log``) once per ``DvcLogger`` created.
        """
        self.logger = logging.getLogger('dvc')
        self.logger.setLevel(logging.INFO)
        if self.logger.handlers:
            # Already configured by an earlier instance.
            return

        console_handler = logging.StreamHandler()
        console_handler.setLevel(logging.INFO)
        # The file handler writes ``dvc.log`` in the current working
        # directory and also records DEBUG messages.
        file_handler = logging.FileHandler('dvc.log')
        file_handler.setLevel(logging.DEBUG)

        formatter = logging.Formatter(
            '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
        )
        for handler in (console_handler, file_handler):
            handler.setFormatter(formatter)
            self.logger.addHandler(handler)

    def setup(self):
        """Apply environment overrides.

        ``DVC_DEBUG`` switches the logger to DEBUG. ``DVC_LOG_FORMAT``, when
        set, replaces the format string on every handler; the original
        branch was an empty ``pass``.
        NOTE(review): assumes the variable holds a ``logging`` format string.
        """
        if os.environ.get('DVC_DEBUG'):
            self.logger.setLevel(logging.DEBUG)
        custom_format = os.environ.get('DVC_LOG_FORMAT')
        if custom_format:
            custom_formatter = logging.Formatter(custom_format)
            for handler in self.logger.handlers:
                handler.setFormatter(custom_formatter)

    def get_logger(self, name):
        """Return the child logger ``dvc.<name>``."""
        return logging.getLogger(f'dvc.{name}')
缓存优化
网络优化
文件系统优化
数据库优化
项目结构
缓存管理
远程存储
实验管理
官方文档
技术文章
相关技术
深度学习
核心要点
设计亮点
DVC 是一个精心设计的数据科学项目管理工具,通过分层架构、缓存机制和云存储集成,为机器学习项目提供了完整的数据版本控制解决方案。
源码地址
https://github.com/iterative/dvc