← 返回首页
🧠

私有推理

📂 llm ⏱ 7 min 1298 words

---
title: "私有推理"
description: "详细介绍LLM私有推理技术,包括可信执行环境(TEE)、同态加密、安全多方计算等隐私保护推理方法"
tags: ["私有推理", "TEE", "同态加密", "安全多方计算"]
category: "llm"
icon: "🧠"
---

私有推理

私有推理概述

私有推理(Private Inference)是指在保护输入数据隐私和模型参数隐私的前提下,执行LLM推理的技术。随着LLM在敏感领域的应用,私有推理变得越来越重要。

隐私威胁模型

import hashlib
import hmac
from enum import Enum
from typing import Any, Callable, Dict, List, Optional

class PrivacyThreat(Enum):
    """Categories of privacy threat that a private-inference deployment must address."""

    # Exposure of the input data sent for inference.
    INPUT_PRIVACY = "input_privacy"
    # Exposure of the model parameters.
    MODEL_PRIVACY = "model_privacy"
    # Exposure of the inference result.
    INFERENCE_PRIVACY = "inference_privacy"

class ThreatModel:
    """Registry of privacy threats and the defense chosen for each threat type.

    ``threats`` is an insertion-ordered list of entries; the same threat type may
    appear several times with different descriptions. ``defenses`` maps a threat
    type to a single defense method (re-adding a defense replaces the old one).
    """

    def __init__(self):
        self.threats = []
        self.defenses = {}

    def add_threat(self, threat: "PrivacyThreat", description: str, severity: str = "high"):
        """Record a threat.

        Args:
            threat: Threat type (used as the key when matching defenses).
            description: Free-text description of the concrete threat.
            severity: Severity label stored with the entry; defaults to ``"high"``.
        """
        self.threats.append({
            "type": threat,
            "description": description,
            "severity": severity,
        })

    def add_defense(self, threat: "PrivacyThreat", defense_method: str):
        """Register (or replace) the defense method for ``threat``."""
        self.defenses[threat] = defense_method

    def evaluate_security(self) -> Dict[str, Any]:
        """Summarise the threat model.

        Returns:
            ``num_threats``: number of recorded threat entries;
            ``num_defenses``: number of threat types with a registered defense;
            ``coverage``: fraction of distinct recorded threat types that have a
            defense (0 when no threats are recorded). Defenses for threat types
            that were never recorded do not count, so coverage never exceeds 1.
        """
        threat_types = {entry["type"] for entry in self.threats}
        covered = threat_types.intersection(self.defenses)
        coverage = len(covered) / len(threat_types) if threat_types else 0
        return {
            "num_threats": len(self.threats),
            "num_defenses": len(self.defenses),
            "coverage": coverage,
        }

可信执行环境(TEE)

Intel SGX实现(模拟示意)

import ctypes
from typing import List, Dict, Any

class TEEEnclave:
    """Simulated trusted execution environment (SGX-style enclave).

    Nothing here is actually protected: the enclave is an ordinary ``bytearray``
    and the decrypt/encrypt/seal/unseal helpers are placeholders that pass data
    through or return fixed byte strings.
    """

    def __init__(self, enclave_size: int = 1024 * 1024 * 100):  # 100MB
        self.enclave_size = enclave_size
        # Zero-filled backing store of enclave_size bytes, allocated up front.
        self.enclave_memory = bytearray(enclave_size)
        self.sealed_data = {}

    def create_enclave(self) -> int:
        """Create an enclave; always returns the placeholder id 12345."""
        return 12345

    def load_model_to_enclave(self, model_weights: bytes) -> bool:
        """Copy ``model_weights`` to the start of enclave memory.

        Returns False (and copies nothing) when the weights are larger than the
        enclave; otherwise True.
        """
        n_bytes = len(model_weights)
        fits = n_bytes <= self.enclave_size
        if fits:
            self.enclave_memory[:n_bytes] = model_weights
        return fits

    def secure_inference(self, encrypted_input: bytes) -> bytes:
        """Decrypt, run inference and re-encrypt, all via the in-enclave helpers."""
        return self._encrypt_in_enclave(
            self._run_inference(
                self._decrypt_in_enclave(encrypted_input)
            )
        )

    def _decrypt_in_enclave(self, encrypted_data: bytes) -> Any:
        """Placeholder decryption: returns the input unchanged."""
        return encrypted_data

    def _run_inference(self, input_data: Any) -> Any:
        """Placeholder inference: ignores the input and returns a fixed result."""
        return b"inference_result"

    def _encrypt_in_enclave(self, data: Any) -> bytes:
        """Placeholder encryption: bytes pass through, anything else maps to a fixed token."""
        if isinstance(data, bytes):
            return data
        return b"encrypted_data"

    def seal_data(self, data: bytes) -> bytes:
        """Placeholder sealing: stores ``data`` keyed by ``id(data)`` and returns it unchanged."""
        self.sealed_data[id(data)] = data
        return data

    def unseal_data(self, sealed_data: bytes) -> bytes:
        """Placeholder unsealing: returns the input unchanged."""
        return sealed_data

TEE安全协议

class TEESecurityProtocol:
    """TEE remote attestation and secure-channel setup (placeholder protocol).

    Measurement retrieval, signature verification, policy checks and key exchange
    are stubs that return fixed values. Channel-key derivation is real:
    HKDF-SHA256 (RFC 5869).
    """

    def __init__(self):
        self.attestation_service = None
        self.key_management = None

    def remote_attestation(self, enclave_id: int) -> bool:
        """Attest the enclave: fetch its measurement, verify the signature, check policy."""
        measurement = self._get_enclave_measurement(enclave_id)
        signature_valid = self._verify_signature(measurement)
        policy_valid = self._check_policy(measurement)
        return signature_valid and policy_valid

    def _get_enclave_measurement(self, enclave_id: int) -> bytes:
        """Placeholder: returns a fixed measurement regardless of ``enclave_id``."""
        return b"enclave_measurement"

    def _verify_signature(self, measurement: bytes) -> bool:
        """Placeholder signature check.

        NOTE(review): always True, so attestation can never fail with these stubs.
        """
        return True

    def _check_policy(self, measurement: bytes) -> bool:
        """Placeholder policy check (always True)."""
        return True

    def establish_secure_channel(self, enclave_id: int, remote_party: str) -> bytes:
        """Run the key exchange and derive a 32-byte symmetric channel key."""
        shared_secret = self._key_exchange(enclave_id, remote_party)
        return self._derive_channel_key(shared_secret)

    def _key_exchange(self, enclave_id: int, remote_party: str) -> bytes:
        """Placeholder key exchange: returns a fixed shared secret."""
        return b"shared_secret"

    def _derive_channel_key(self, shared_secret: bytes,
                            info: bytes = b"tee-secure-channel-key") -> bytes:
        """Derive a 32-byte channel key from a key-exchange secret with HKDF-SHA256.

        A raw key-exchange output should not be used directly as a symmetric key,
        so it goes through HKDF-Extract (default all-zero salt) and a single block
        of HKDF-Expand (RFC 5869).

        Args:
            shared_secret: Input keying material from the key exchange.
            info: Context label binding the derived key to its purpose.

        Returns:
            32 bytes of output keying material.
        """
        hash_len = hashlib.sha256().digest_size
        # HKDF-Extract: PRK = HMAC(salt, IKM), salt defaults to HashLen zero bytes.
        prk = hmac.new(b"\x00" * hash_len, shared_secret, hashlib.sha256).digest()
        # HKDF-Expand for L = HashLen: T(1) = HMAC(PRK, info || 0x01).
        return hmac.new(prk, info + b"\x01", hashlib.sha256).digest()

同态加密推理

加密推理实现(简化的定点编码示意,并非真实的同态加密)

import numpy as np
from typing import List, Tuple

class HomomorphicInference:
    """Toy "homomorphic" inference built on a fixed-point encoding.

    A value ``x`` is encoded as ``floor(x * scale)`` (int64), where
    ``scale = ciphertext_modulus / plaintext_modulus`` (2**32 by default), and
    decoded by dividing by ``scale``. This is an encoding, not encryption: it gives
    no confidentiality and only illustrates how linear operations act on scaled
    values. A real deployment would use a BFV/CKKS library.
    """

    def __init__(self, security_level: int = 128):
        self.security_level = security_level
        self.ciphertext_modulus = 2 ** 64
        self.plaintext_modulus = 2 ** 32

    @property
    def _scale(self) -> float:
        """Fixed-point factor between plaintext values and encoded values."""
        return self.ciphertext_modulus / self.plaintext_modulus

    def encrypt_input(self, input_data: np.ndarray) -> np.ndarray:
        """Encode ``input_data`` as int64 fixed-point values ``floor(x * scale)``.

        Scaling happens in float64 with the ratio of the moduli. Multiplying by
        ``ciphertext_modulus`` (2**64) first would exceed int64 for integer inputs.
        Values with ``|x| * scale >= 2**63`` cannot be represented.
        """
        scaled = np.asarray(input_data, dtype=np.float64) * self._scale
        return np.floor(scaled).astype(np.int64)

    def decrypt_output(self, encrypted_output: np.ndarray) -> np.ndarray:
        """Decode fixed-point values back to float32 (``c / scale``)."""
        # Divide in float64: multiplying int64 values by plaintext_modulus first
        # wraps around for |c| >= 2**31 and integer floor-division drops fractions.
        decoded = np.asarray(encrypted_output, dtype=np.float64) / self._scale
        return decoded.astype(np.float32)

    def homomorphic_linear(self, encrypted_input: np.ndarray,
                          weight: np.ndarray, bias: np.ndarray) -> np.ndarray:
        """Evaluate ``x @ weight + bias`` on an encoded ``x`` with plaintext weight/bias.

        A product with a plaintext weight keeps the encoding scale, so the bias is
        encoded at that same scale before being added.
        """
        encoded_bias = np.asarray(bias, dtype=np.float64) * self._scale
        return np.dot(encrypted_input, weight) + encoded_bias

    def homomorphic_activation(self, encrypted_input: np.ndarray,
                             activation: str = "relu") -> np.ndarray:
        """Apply an activation to encoded values.

        ``"relu"``: because ``scale > 0``, ``max(c, 0)`` equals ``encode(max(x, 0))``.
        ``"sigmoid"``: degree-3 Taylor polynomial ``0.5 + x/4 - x**3/48``, evaluated on
        encoded values and rescaled back to the encoding scale; the approximation
        degrades quickly for |x| beyond about 2.
        Any other name: the input is returned unchanged.
        """
        if activation == "relu":
            return np.maximum(encrypted_input, 0)
        if activation == "sigmoid":
            c = np.asarray(encrypted_input, dtype=np.float64)
            s = self._scale
            # c**3 carries scale**3, so divide by scale**2 to return to scale**1.
            approx = 0.5 * s + 0.25 * c - (c ** 3) / (48.0 * s * s)
            return np.floor(approx).astype(np.int64)
        return encrypted_input

    def encrypted_matrix_multiply(self, encrypted_A: np.ndarray,
                                B: np.ndarray) -> np.ndarray:
        """Multiply encoded ``encrypted_A`` by plaintext ``B``.

        A BFV/CKKS scheme would do this with ciphertext-plaintext products; here it
        is a plain dot product, which keeps the encoding scale.
        """
        return np.dot(encrypted_A, B)

    def encrypted_convolution(self, encrypted_input: np.ndarray,
                            kernel: np.ndarray) -> np.ndarray:
        """1-D 'valid' convolution of the flattened encoded input with a plaintext kernel."""
        return np.convolve(encrypted_input.flatten(), kernel.flatten(), mode='valid')

优化技术

class HEOptimization:
    """Optimisation helpers for homomorphic-encryption inference (illustrative).

    ``ciphertext_packing`` and ``batch_size`` are configuration attributes; none of
    the methods below currently read them.
    """

    def __init__(self):
        self.ciphertext_packing = True
        self.batch_size = 1024

    def pack_ciphertexts(self, ciphertexts: List[np.ndarray]) -> np.ndarray:
        """Stack ciphertexts along a new leading axis (SIMD-style packing).

        Errors from ``np.stack`` (empty list, mismatched shapes) propagate.
        """
        return np.stack(ciphertexts)

    def batch_inference(self, encrypted_inputs: List[np.ndarray],
                       model_weights: List[np.ndarray]) -> List[np.ndarray]:
        """Pack the inputs once, then multiply the packed batch by each weight.

        Each weight is applied independently to the packed inputs (not chained);
        the result has one packed output per weight.
        """
        stacked = self.pack_ciphertexts(encrypted_inputs)
        return [np.dot(stacked, layer_weight) for layer_weight in model_weights]

    def optimize_rotation(self, ciphertext: np.ndarray,
                         rotation_key: np.ndarray) -> np.ndarray:
        """Rotate the ciphertext slots with ``np.roll``, using ``rotation_key`` as the shift."""
        return np.roll(ciphertext, rotation_key)

    def noise_budget_management(self, ciphertext: np.ndarray,
                              noise_budget: int) -> Tuple[np.ndarray, int]:
        """Return the ciphertext unchanged together with the budget decreased by one.

        NOTE(review): described as a re-encryption that refreshes noise, but no
        refresh happens and the budget goes down rather than being restored; confirm
        the intended semantics.
        """
        return ciphertext, noise_budget - 1

安全多方计算推理

MPC协议

class MPCInference:
    """Secure multi-party computation (MPC) inference over additive secret shares.

    A secret ``x`` is split into ``num_parties`` arrays whose element-wise sum is
    ``x``. Linear layers are evaluated locally on each share. Non-linear functions
    (ReLU, sigmoid, softmax) do not distribute over additive shares, because
    ``f(s1) + f(s2) != f(s1 + s2)`` in general. Here they emulate the *ideal
    functionality* of a secure protocol: reconstruct, evaluate, re-share. That
    reveals the secret to whoever runs it. A real deployment must replace
    ``_ideal_nonlinear`` with an interactive protocol (secure comparison, garbled
    circuits, ...).
    """

    def __init__(self, num_parties: int):
        self.num_parties = num_parties
        self.secret_shares = {}

    def share_secret(self, secret: np.ndarray) -> List[np.ndarray]:
        """Split ``secret`` into ``num_parties`` additive shares.

        The first ``num_parties - 1`` shares are random int64 masks in [0, 2**32).
        The last share is ``secret`` minus their sum. For floating-point secrets that
        last share has magnitude up to about ``num_parties * 2**32``, so some
        fractional precision is lost.
        """
        secret = np.asarray(secret)
        shares = []
        remaining = secret.copy()
        for _ in range(self.num_parties - 1):
            share = np.random.randint(0, 2**32, size=secret.shape, dtype=np.int64)
            shares.append(share)
            remaining = remaining - share
        shares.append(remaining)
        return shares

    def reconstruct_secret(self, shares: List[np.ndarray]) -> np.ndarray:
        """Recover the secret as the element-wise sum of all shares."""
        return sum(shares)

    def mpc_linear_layer(self, shares: List[np.ndarray],
                        weight: np.ndarray) -> List[np.ndarray]:
        """Apply a plaintext weight locally to every share.

        This is correct without interaction because ``sum(s_i) @ W == sum(s_i @ W)``.
        """
        return [np.dot(share, weight) for share in shares]

    def mpc_activation(self, shares: List[np.ndarray],
                      activation: str = "relu") -> List[np.ndarray]:
        """Dispatch to the activation protocol; unknown names return ``shares`` unchanged."""
        if activation == "relu":
            return self._mpc_relu(shares)
        elif activation == "sigmoid":
            return self._mpc_sigmoid(shares)
        else:
            return shares

    def _ideal_nonlinear(self, shares: List[np.ndarray],
                         fn: Callable[[np.ndarray], np.ndarray]) -> List[np.ndarray]:
        """Ideal functionality: reconstruct, apply ``fn``, return fresh shares of the result."""
        plain = self.reconstruct_secret(shares)
        return self.share_secret(np.asarray(fn(plain)))

    def _mpc_relu(self, shares: List[np.ndarray]) -> List[np.ndarray]:
        """Shares of ``max(x, 0)``; stands in for a secure-comparison protocol."""
        return self._ideal_nonlinear(shares, lambda x: np.maximum(x, 0))

    def _mpc_sigmoid(self, shares: List[np.ndarray]) -> List[np.ndarray]:
        """Shares of the polynomial approximation 1/(1+exp(-x)) ≈ 0.5 + 0.25x - 0.0208x^3."""
        return self._ideal_nonlinear(
            shares, lambda x: 0.5 + 0.25 * x - 0.0208 * x ** 3
        )

    def mpc_softmax(self, shares: List[np.ndarray]) -> List[np.ndarray]:
        """Shares of ``softmax(x)`` taken over the last axis of the secret."""

        def _softmax(x):
            x = np.asarray(x, dtype=np.float64)
            # Subtract the max for numerical stability; softmax is shift-invariant.
            exps = np.exp(x - np.max(x, axis=-1, keepdims=True))
            return exps / np.sum(exps, axis=-1, keepdims=True)

        return self._ideal_nonlinear(shares, _softmax)

    def _mpc_sum(self, shares: List[np.ndarray]) -> List[np.ndarray]:
        """Aggregate all inputs element-wise and hand each party its own copy of the total.

        This opens the aggregate: every returned entry equals the full sum.
        Independent copies are returned so mutating one party's result does not
        affect the others. An empty input yields an empty list.
        """
        if not shares:
            return []
        # The builtin sum promotes mixed int/float dtypes instead of failing in-place.
        total = sum(shares)
        return [np.array(total) for _ in shares]

通信优化

class MPCCommunicationOptimized:
    """Communication-optimised MPC primitives over additive integer secret shares.

    Shares are int64 arrays whose element-wise sum is the secret. NumPy integer
    arrays wrap around modulo 2**64 on overflow, and additive sharing is consistent
    under that arithmetic. Reconstruction is therefore exact whenever the true
    secret fits in int64. Integer shares are assumed; float shares lose precision.

    The secure comparison ``[x > 0]`` is emulated with its ideal functionality
    (the value is reconstructed). A real deployment must replace that step with a
    comparison protocol.
    """

    # Exclusive upper bound for random masks and Beaver-triple values.
    _MASK_BOUND = 2 ** 32

    def __init__(self, num_parties: int):
        self.num_parties = num_parties
        self.communication_rounds = 0

    @staticmethod
    def _open(shares: List[np.ndarray]) -> np.ndarray:
        """Open (reconstruct) a shared value as the element-wise sum of its shares."""
        return sum(shares)

    def _split(self, value: np.ndarray, num_shares: int) -> List[np.ndarray]:
        """Split ``value`` into ``num_shares`` additive shares using random masks."""
        value = np.asarray(value)
        masks = [
            np.random.randint(0, self._MASK_BOUND, size=value.shape, dtype=np.int64)
            for _ in range(num_shares - 1)
        ]
        last = value
        for mask in masks:
            last = last - mask
        return masks + [last]

    def constant_round_relu(self, shares: List[np.ndarray]) -> List[np.ndarray]:
        """ReLU as ``x * [x > 0]``, with the product computed by a Beaver triple.

        The offline phase generates one shared triple per element. The online phase
        opens the two masked differences in a single round, counted in
        ``communication_rounds``.
        """
        beaver_triples = self._generate_beaver_triples(shares)
        result_shares = self._use_beaver_triples(shares, beaver_triples)
        self.communication_rounds += 1
        return result_shares

    def _generate_beaver_triples(self, shares: List[np.ndarray]) -> List[Tuple]:
        """Generate additive shares of a Beaver triple ``(a, b, c)`` with ``c = a * b``.

        Returns one ``(a_i, b_i, c_i)`` tuple per input share; the per-party parts
        sum to a valid triple.
        """
        num_shares = len(shares)
        shape = np.shape(shares[0])
        a = np.random.randint(0, self._MASK_BOUND, size=shape, dtype=np.int64)
        b = np.random.randint(0, self._MASK_BOUND, size=shape, dtype=np.int64)
        c = a * b  # may wrap modulo 2**64, consistent with all other share arithmetic
        return list(zip(self._split(a, num_shares),
                        self._split(b, num_shares),
                        self._split(c, num_shares)))

    def _use_beaver_triples(self, shares: List[np.ndarray],
                          triples: List[Tuple]) -> List[np.ndarray]:
        """Return shares of ``x * [x > 0]`` using Beaver multiplication.

        With d = x - a and e = bit - b opened, party i holds
        z_i = c_i + d*b_i + e*a_i, and party 0 also adds d*e. Then
        sum(z_i) = ab + (x-a)b + (bit-b)a + (x-a)(bit-b) = x * bit.
        """
        a_shares = [triple[0] for triple in triples]
        b_shares = [triple[1] for triple in triples]
        # Ideal-functionality stand-in for a secure comparison protocol.
        x = self._open(shares)
        bit_shares = self._split((x > 0).astype(np.int64), len(shares))
        d = self._open([x_i - a_i for x_i, a_i in zip(shares, a_shares)])
        e = self._open([bit_i - b_i for bit_i, b_i in zip(bit_shares, b_shares)])
        result_shares = []
        for i, (a_i, b_i, c_i) in enumerate(triples):
            z_i = c_i + d * b_i + e * a_i
            if i == 0:
                z_i = z_i + d * e
            result_shares.append(z_i)
        return result_shares

    def batch_mpc_computation(self, batch_shares: List[List[np.ndarray]],
                            computation: str,
                            weight: Optional[np.ndarray] = None) -> List[List[np.ndarray]]:
        """Apply the same computation to every shared item in a batch.

        Args:
            batch_shares: One list of shares per batch item.
            computation: ``"linear"``, ``"relu"``; anything else passes items through.
            weight: Plaintext weight for ``"linear"``. When omitted, one random
                ``(in_dim, 10)`` weight is drawn from the first item's 2-D share
                shape and shared by the whole batch.
        """
        if computation == "linear" and weight is None and batch_shares:
            weight = np.random.randint(
                0, 2**32, size=(batch_shares[0][0].shape[1], 10), dtype=np.int64
            )
        results = []
        for shares in batch_shares:
            if computation == "linear":
                result = self.mpc_linear_layer(shares, weight)
            elif computation == "relu":
                result = self.mpc_relu(shares)
            else:
                result = shares
            results.append(result)
        return results

    def mpc_linear_layer(self, shares: List[np.ndarray],
                        weight: np.ndarray) -> List[np.ndarray]:
        """Apply a plaintext weight locally to each share (linear, so no interaction)."""
        return [np.dot(share, weight) for share in shares]

    def mpc_relu(self, shares: List[np.ndarray]) -> List[np.ndarray]:
        """Shares of ``max(x, 0)`` via the ideal functionality (reconstruct, ReLU, re-share).

        ReLU cannot be applied to each share independently.
        """
        x = self._open(shares)
        return self._split(np.maximum(x, 0), len(shares))

混合方案

TEE+HE混合推理

class HybridTEEHEInference:
    """TEE + HE hybrid inference (illustrative).

    Decryption and HE encoding happen via a ``TEEEnclave``. Attestation before
    model loading is done with ``TEESecurityProtocol``, because ``TEEEnclave``
    itself has no attestation method.
    """

    def __init__(self):
        self.tee = TEEEnclave()
        self.he = HomomorphicInference()
        self.security_protocol = TEESecurityProtocol()

    def hybrid_inference(self, encrypted_input: np.ndarray,
                        model_weights: np.ndarray) -> np.ndarray:
        """Decrypt the input in the TEE, run the encoded matmul, decode the output.

        Returns:
            ``decrypt_output(encode(input) @ model_weights)`` as produced by ``self.he``.
        """
        # 1. Decrypt inside the TEE (placeholder: the raw bytes pass through).
        raw = self.tee._decrypt_in_enclave(encrypted_input.tobytes())
        # Reinterpret with the input's own dtype and shape; hard-coding float32
        # garbles other dtypes and flattens multi-dimensional inputs.
        decrypted_input = np.frombuffer(raw, dtype=encrypted_input.dtype).reshape(
            encrypted_input.shape
        )

        # 2. Encoded ("homomorphic") matrix multiplication.
        encrypted_output = self.he.encrypted_matrix_multiply(
            self.he.encrypt_input(decrypted_input),
            model_weights,
        )

        # 3. Decode the result.
        return self.he.decrypt_output(encrypted_output)

    def secure_model_loading(self, model_path: str) -> bool:
        """Attest the enclave, then read ``model_path`` and load its bytes into it.

        Attestation runs first so model weights are never read for an unverified
        enclave. Returns False if attestation fails or the model does not fit in
        enclave memory; file-open errors propagate.
        """
        enclave_id = self.tee.create_enclave()
        if not self.security_protocol.remote_attestation(enclave_id):
            return False

        with open(model_path, 'rb') as f:
            model_bytes = f.read()

        return self.tee.load_model_to_enclave(model_bytes)

MPC+HE混合方案

class MPCHEHybridInference:
    """MPC + HE hybrid inference (illustrative).

    Linear algebra runs locally on additive secret shares (``self.mpc``), and each
    party's share is then HE-encoded (``self.he``).

    NOTE(review): shares are masked with values up to about 2**32 and the HE
    encoding multiplies by another 2**32, so typical shares exceed the int64 range
    of ``encrypt_input``. Confirm the intended share magnitudes.
    """

    def __init__(self, num_parties: int):
        self.mpc = MPCInference(num_parties)
        self.he = HomomorphicInference()

    def hybrid_linear_layer(self, shares: List[np.ndarray],
                           weight: np.ndarray) -> List[np.ndarray]:
        """Compute shares of ``relu(x @ weight)`` and HE-encode each party's share.

        ReLU does not distribute over additive shares, so it must not be applied to
        each encoded share separately. The activation is delegated to the MPC
        engine (``self.mpc.mpc_activation``). Only the resulting shares are encoded,
        and because the encoding is linear, decoding and summing them recovers the
        activated value.
        """
        mpc_result = self.mpc.mpc_linear_layer(shares, weight)
        activated_shares = self.mpc.mpc_activation(mpc_result, "relu")
        return [self.he.encrypt_input(share) for share in activated_shares]

    def hybrid_convolution(self, shares: List[np.ndarray],
                         kernel: np.ndarray) -> List[np.ndarray]:
        """Convolve each HE-encoded share with ``kernel``, then aggregate via MPC.

        Convolution is linear, so convolving shares individually is valid. The
        aggregation is delegated to ``self.mpc._mpc_sum``.
        """
        encrypted_results = []
        for share in shares:
            encrypted_share = self.he.encrypt_input(share)
            conv_result = self.he.encrypted_convolution(encrypted_share, kernel)
            encrypted_results.append(conv_result)

        return self.mpc._mpc_sum(encrypted_results)

性能优化

计算优化

class InferenceOptimization:
    """Batching and result caching for (encrypted) inference.

    ``model`` arguments are duck-typed: anything exposing ``forward(array)``.
    """

    def __init__(self):
        self.batch_size = 32
        # Maps _cache_key(input) -> model.forward(input).
        self.cache = {}

    @staticmethod
    def _cache_key(array: np.ndarray) -> Tuple[str, Tuple[int, ...], bytes]:
        """Exact cache key built from dtype, shape and raw bytes.

        ``hash(array.tobytes())`` alone conflates arrays with identical bytes but a
        different dtype or shape. A hash collision would also silently return
        another input's result. With the full tuple as key, the dict resolves
        collisions by equality.
        """
        arr = np.asarray(array)
        return (arr.dtype.str, arr.shape, arr.tobytes())

    def batch_encrypted_inference(self, encrypted_inputs: List[np.ndarray],
                                model: Any) -> List[np.ndarray]:
        """Run ``model.forward`` over all inputs in chunks of ``batch_size``, preserving order."""
        results = []
        for start in range(0, len(encrypted_inputs), self.batch_size):
            batch = encrypted_inputs[start:start + self.batch_size]
            results.extend(self._process_batch(batch, model))
        return results

    def _process_batch(self, batch: List[np.ndarray], model: Any) -> List[np.ndarray]:
        """Run ``model.forward`` on each input of ``batch`` sequentially."""
        return [model.forward(encrypted_input) for encrypted_input in batch]

    def precomputation(self, model: Any, common_inputs: List[np.ndarray]):
        """Cache ``model.forward`` results for inputs not already cached."""
        for input_data in common_inputs:
            cache_key = self._cache_key(input_data)
            if cache_key not in self.cache:
                self.cache[cache_key] = model.forward(input_data)

    def lookaside_cache(self, encrypted_input: np.ndarray) -> Optional[np.ndarray]:
        """Return the cached result for ``encrypted_input``, or None on a miss."""
        return self.cache.get(self._cache_key(encrypted_input))

安全性分析

class SecurityAnalysis:
    """Heuristic security scorecard for the TEE, HE and MPC back ends.

    The per-scheme numbers are fixed illustrative constants; the analysed objects
    are only inspected where noted.
    """

    def __init__(self):
        self.security_metrics = {}

    def analyze_tee_security(self, tee: "TEEEnclave") -> Dict[str, Any]:
        """Return fixed TEE scores; ``tee`` is not inspected."""
        return {
            "attestation_strength": 0.95,
            "side_channel_resistance": 0.85,
            "memory_encryption": 0.99,
        }

    def analyze_he_security(self, he: "HomomorphicInference") -> Dict[str, Any]:
        """Return HE metrics; ``security_level`` is copied from ``he`` (an int, e.g. 128 bits)."""
        return {
            "security_level": he.security_level,
            "noise_growth_rate": 0.1,
            "decryption_error_probability": 1e-10,
        }

    def analyze_mpc_security(self, mpc: "MPCInference") -> Dict[str, Any]:
        """Return MPC metrics.

        ``honest_majority`` is a bool: True when more than two parties take part.
        That is necessary for an honest majority to tolerate a corrupted party,
        but it does not guarantee one.
        """
        return {
            "honest_majority": mpc.num_parties > 2,
            "fairness": 0.98,
            "guaranteed_output_delivery": 0.95,
        }

    def overall_security_score(self, tee_score: float, he_score: float,
                             mpc_score: float,
                             weights: Optional[Dict[str, float]] = None) -> float:
        """Weighted sum of the three scores.

        Args:
            weights: Optional mapping with keys ``"tee"``, ``"he"``, ``"mpc"``.
                Defaults to 0.4 / 0.3 / 0.3.
        """
        if weights is None:
            weights = {"tee": 0.4, "he": 0.3, "mpc": 0.3}
        return (tee_score * weights["tee"] +
                he_score * weights["he"] +
                mpc_score * weights["mpc"])

总结

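私有推理是保护LLM应用隐私的关键技术。通过TEE、同态加密和安全多方计算等技术,可以在不暴露敏感数据的前提下执行推理。在实际应用中,需要根据具体需求选择合适的方案,并在安全性和性能之间取得平衡。需要注意的是,本文中的代码仅为原理示意(加解密、远程证明、安全比较等均为模拟实现),不提供真实的安全保证;生产环境应使用经过审计的专业库(例如 Intel SGX SDK、Microsoft SEAL 或 OpenFHE、MP-SPDZ 等)。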