← 返回首页
🧠

SafeTensors:安全的模型存储格式

📂 llm ⏱ 2 min 326 words

--- title: "SafeTensors:安全的模型存储格式" description: "介绍SafeTensors格式的特点和使用方法,实现安全高效的模型加载" tags: ["SafeTensors", "模型格式", "HuggingFace", "安全"] category: "llm" icon: "🧠"

SafeTensors:安全的模型存储格式

SafeTensors简介

SafeTensors是HuggingFace推出的安全张量存储格式,解决了传统PyTorch模型文件(.bin)的安全隐患,成为HuggingFace Hub上的标准格式。

为什么需要SafeTensors

传统格式的问题

# PyTorch .bin格式使用pickle反序列化
# pickle可以执行任意代码,存在安全风险

import pickle

# 恶意pickle示例(危险!不要执行)
# class Malicious:
#     def __reduce__(self):
#         return (os.system, ('echo hacked',))

# SafeTensors避免了这个问题

SafeTensors的优势

benefits = {
    "安全性": "不使用pickle,无法执行任意代码",
    "加载速度": "支持零拷贝加载",
    "内存效率": "支持mmap,不需要完整加载",
    "元数据": "支持存储模型元信息",
    "跨框架": "支持PyTorch、TensorFlow、JAX"
}

基础使用

安装

pip install safetensors

保存张量

import torch
from safetensors.torch import save_file, load_file

# 创建模型参数
tensors = {
    "weight": torch.randn(100, 100),
    "bias": torch.randn(100)
}

# 保存为SafeTensors格式
save_file(tensors, "model.safetensors")

print("模型已保存")

加载张量

# 加载模型
loaded_tensors = load_file("model.safetensors")

print(f"包含的张量: {list(loaded_tensors.keys())}")
print(f"weight形状: {loaded_tensors['weight'].shape}")
print(f"bias形状: {loaded_tensors['bias'].shape}")

与HuggingFace集成

保存完整模型

from transformers import AutoModelForCausalLM, AutoTokenizer

# 加载模型
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

# 保存为SafeTensors格式
model.save_pretrained("model_safetensors", safe_serialization=True)
tokenizer.save_pretrained("model_safetensors")

print("模型已保存为SafeTensors格式")

加载SafeTensors模型

from transformers import AutoModelForCausalLM

# 加载SafeTensors格式的模型
model = AutoModelForCausalLM.from_pretrained(
    "model_safetensors",
    torch_dtype=torch.float16,
    device_map="auto"
)

# 也可以显式指定
model = AutoModelForCausalLM.from_pretrained(
    "model_safetensors",
    use_safetensors=True
)

在Hub上发布

from huggingface_hub import HfApi

api = HfApi()

# 上传模型
api.upload_folder(
    folder_path="model_safetensors",
    repo_id="your-username/your-model",
    repo_type="model"
)

高级用法

查看文件信息

from safetensors import safe_open

# 查看SafeTensors文件的元信息
with safe_open("model.safetensors", framework="pt") as f:
    print(f"张量名称: {f.keys()}")
    print(f"metadata: {f.metadata()}")
    
    # 查看特定张量的形状
    tensor = f.get_tensor("weight")
    print(f"weight形状: {tensor.shape}")

流式加载

from safetensors import safe_open

# 只加载需要的张量
with safe_open("model.safetensors", framework="pt") as f:
    # 只加载weight,不加载bias
    weight = f.get_tensor("weight")
    print(f"只加载了weight: {weight.shape}")

元数据存储

from safetensors.torch import save_file

# 存储自定义元数据
metadata = {
    "model_name": "my-model",
    "version": "1.0",
    "description": "A custom language model"
}

tensors = {"weight": torch.randn(50, 50)}

# 保存时添加元数据
save_file(tensors, "model.safetensors", metadata=metadata)

# 读取元数据
with safe_open("model.safetensors", framework="pt") as f:
    print(f"元数据: {f.metadata()}")

格式转换

PyTorch转SafeTensors

# 方法1:使用save_pretrained
model.save_pretrained("output_dir", safe_serialization=True)

# 方法2:手动转换
from safetensors.torch import save_file

# 获取模型状态字典
state_dict = model.state_dict()
save_file(state_dict, "model.safetensors")

SafeTensors转PyTorch

from safetensors.torch import load_file
import torch

# 加载SafeTensors
state_dict = load_file("model.safetensors")

# 保存为PyTorch格式
torch.save(state_dict, "model.bin")

性能对比

import time

def benchmark_load(format_type, file_path):
    """对比加载速度"""
    start = time.time()
    
    if format_type == "safetensors":
        from safetensors.torch import load_file
        tensors = load_file(file_path)
    else:
        tensors = torch.load(file_path)
    
    elapsed = time.time() - start
    return elapsed

# SafeTensors通常比pickle更快,特别是大文件
# safetensors_time = benchmark_load("safetensors", "model.safetensors")
# pytorch_time = benchmark_load("pytorch", "model.bin")

最佳实践

# 1. 始终优先使用SafeTensors格式
model.save_pretrained("output", safe_serialization=True)

# 2. 使用mmap加载大模型
from safetensors.torch import load_file
tensors = load_file("large_model.safetensors", device="cpu")

# 3. 转换现有模型
# 从PyTorch格式转换
state_dict = torch.load("model.bin")
from safetensors.torch import save_file
save_file(state_dict, "model.safetensors")

总结

SafeTensors是现代LLM生态中的标准格式,提供安全、高效的模型存储方案。通过HuggingFace生态系统,可以方便地使用和分享SafeTensors格式的模型。