在深度学习工程实践中,我们经常会遇到这样的场景:需要使用预训练模型(如 CLIP、ViT、ResNet
等)来提取特征,或者需要对数据集进行复杂的预处理。这些操作往往会占用大量的计算资源和时间,每次训练都需要重新处理,严重影响效率。而自己预先缓存的话又要大改代码,并且推理的时候也得重新算特征,太麻烦了。
如何简单的加速深度学习过程中数据集预处理速度和预训练模型特征提取速度?
现在有个工具 Torch Module Cache
,只需要一行代码就能解决这个问题,实测训练速度提升 30 倍以上!
🔥 为什么要用 Torch Module Cache?
在我之前的项目中,经常需要处理这样的情况:
- 使用 CLIP 或 ViT 这样的大模型提取图像特征,特征提取占用了大量的训练时间
- 数据集需要复杂的预处理和增强
- 预处理过的数据需要手动保存和加载
- 推理的时候也得重新算特征,这部分又得专门实现
通常的解决方案是把特征提取好保存成文件,每次都要自己实现真的太麻烦了。
💡 实战经验分享
0. 安装
安装非常简单,只需要 pip 安装:
1
|
pip install torch-module-cache
|
1. 预训练模型特征提取加速
最直观的使用场景是加速预训练模型的特征提取。看看这段代码有多简洁,只需要一行代码就能缓存特征:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
|
from torch_module_cache import cache_module
@cache_module() # 只需要加这一行!
class FeatureExtractor(nn.Module):
def __init__(self):
super().__init__()
self.vit = timm.create_model("vit_base_patch16_224", pretrained=True)
self.vit.eval()
def forward(self, x):
with torch.no_grad():
features = self.vit.forward_features(x)
return features
# 在实际使用时
extractor = FeatureExtractor()
# 第一次运行会计算并缓存
features = extractor(images, cache_key="batch_1")
# 后续运行直接从缓存读取,速度飞快!
features = extractor(images, cache_key="batch_1")
|
就这么简单,不需要手动管理缓存文件,不需要修改现有代码结构,加一个装饰器就能获得巨大的性能提升。实测训练速度提升 30 倍以上!
2. 数据集预处理加速
另一个实用场景是数据集的预处理。比如我们经常需要对图像做复杂的预处理,比如:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
|
# 只需要加一行装饰器,就能缓存预处理结果
@cache_module(cache_name="preprocessing")
class DataProcessor(nn.Module):
def __init__(self):
super().__init__()
self.transforms = nn.Sequential(
# 你的各种预处理操作
)
def forward(self, x):
return self.transforms(x)
class MyDataset(Dataset):
def __init__(self):
self.processor = DataProcessor()
def __getitem__(self, idx):
raw_data = self.data[idx]
# 使用样本索引作为缓存键,也可以用文件路径作为缓存键,只需要保证唯一性即可
processed = self.processor(raw_data, cache_key=f"sample_{idx}")
return processed, self.labels[idx]
|
最开始的第一个 epoch 会正常处理数据,从第二个 epoch 开始就直接从缓存读取,训练速度立刻起飞!
而且,更方便的是,后面再次启动训练的时候,第一个 epoch 就会使用缓存,并且模型采用了 Lazy Load 的方式,只要完整缓存了整个数据集,模型本身就不会被加载,从而节省了大量显存,终于不用再 OOM 了~
💪 实用技巧分享
- 内存加速: 每次都从硬盘加载还是吃硬盘 IO,可以通过配置缓存到内存中,这样速度会更快。
1
|
@cache_module(cache_level=CacheLevel.MEMORY) # 缓存到内存中
|
- 内存管理:如果你担心缓存到内存占用太多内存,导致 OOM,可以设置最大缓存大小:
1
|
@cache_module(max_memory_cache_size_mb=1024) # 限制内存缓存大小为 1GB
|
- 推理模式:在推理时可能不需要缓存,可以全局禁用:
1
2
3
|
from torch_module_cache import enable_inference_mode
enable_inference_mode()
|
- 缓存清理:需要清理缓存时也很方便:
1
2
3
4
|
from torch_module_cache import clear_memory_caches, clear_disk_caches
clear_memory_caches() # 清理内存缓存
clear_disk_caches() # 清理磁盘缓存
|