← 返回首页

高并发架构

📂 python ⏱ 9 min 1651 words

高并发架构

高并发架构是构建可扩展、高性能系统的关键。本文将深入探讨连接池、限流、熔断和负载均衡技术,帮助开发者设计可靠的高并发系统。

连接池设计

连接池是复用连接资源的重要技术,可以显著减少连接建立和销毁的开销。

import time
import threading
from dataclasses import dataclass
from typing import List, Optional, Any
from queue import Queue, Empty
from contextlib import contextmanager

@dataclass
class Connection:
    """Bookkeeping record for one pooled connection."""
    id: int  # identifier assigned by the factory that built the connection
    created_at: float  # time.time() at creation
    last_used: float  # time.time() of the latest checkout or check-in
    in_use: bool = False  # True while a caller holds the connection

    def is_healthy(self, max_age: float = 300.0) -> bool:
        """Return True while the connection is younger than ``max_age`` seconds."""
        return time.time() - self.created_at < max_age


class ConnectionPool:
    """Thread-safe connection pool that starts at ``min_size`` and grows on demand.

    Idle connections wait in ``self.pool``. A checkout takes an idle connection
    if one exists, otherwise creates a new one while fewer than ``max_size``
    connections are alive, and only blocks (up to ``timeout``) once the pool is
    at capacity.

    Attributes:
        connections_count: Id sequence used by the built-in factory (number of
            connections it has created so far). Factories may bump it
            themselves; the pool never uses it for capacity decisions.
        active_connections: Connections currently checked out.
    """

    def __init__(self, min_size: int = 5, max_size: int = 20,
                 max_age: float = 300.0, connection_factory=None):
        """Create the pool and pre-fill it with ``min_size`` connections.

        Args:
            min_size: Connections created up front.
            max_size: Upper bound on live connections; also the queue size.
            max_age: Age in seconds after which a connection is replaced at
                checkout.
            connection_factory: Zero-argument callable returning a connection;
                defaults to ``_default_factory``.

        Raises:
            ValueError: If ``0 < max_size < min_size``; pre-filling a bounded
                queue with more items than it can hold would block forever.
        """
        if 0 < max_size < min_size:
            raise ValueError("min_size 不能大于 max_size")

        self.min_size = min_size
        self.max_size = max_size
        self.max_age = max_age
        self.connection_factory = connection_factory or self._default_factory

        self.pool: Queue = Queue(maxsize=max_size)
        self.lock = threading.Lock()
        self.connections_count = 0
        self.active_connections = 0
        # Connections currently owned by the pool (idle + checked out). Kept
        # separate from connections_count so that replacing an expired
        # connection does not eat into the max_size budget.
        self._live_connections = 0

        # Pre-create the minimum number of connections
        self._initialize_pool()

    def _default_factory(self):
        """Build a ``Connection`` with the next sequential id.

        Takes ``self.lock`` for the id bump, so it must never be called while
        the lock is already held.
        """
        with self.lock:
            conn_id = self.connections_count
            self.connections_count += 1
        return Connection(
            id=conn_id,
            created_at=time.time(),
            last_used=time.time()
        )

    def _reserve_slot(self) -> bool:
        """Atomically claim capacity for one more connection, if any is left."""
        with self.lock:
            if self._live_connections < self.max_size:
                self._live_connections += 1
                return True
            return False

    def _create_reserved(self):
        """Run the factory for an already-reserved slot; free the slot on failure."""
        try:
            return self.connection_factory()
        except BaseException:
            with self.lock:
                self._live_connections -= 1
            raise

    def _initialize_pool(self):
        """Fill the pool with ``min_size`` freshly created connections."""
        for _ in range(self.min_size):
            with self.lock:
                self._live_connections += 1
            # put_nowait: the size check in __init__ guarantees room, and a
            # surprise Full is better than a silent hang.
            self.pool.put_nowait(self._create_reserved())

    @contextmanager
    def get_connection(self, timeout: float = 10.0):
        """Context manager that checks a connection out and always returns it.

        Args:
            timeout: Seconds to wait for an idle connection once the pool is
                at capacity.

        Raises:
            ConnectionError: If no connection became available in time.
        """
        conn = self._acquire_connection(timeout)
        try:
            yield conn
        finally:
            # Unconditional: a truthiness test here would leak connection
            # objects that happen to evaluate as false.
            self._release_connection(conn)

    def _checkout(self, timeout: float):
        """Return an idle or new connection, blocking only at capacity."""
        try:
            return self.pool.get_nowait()
        except Empty:
            pass

        # Nothing idle: grow the pool right away if there is capacity.
        if self._reserve_slot():
            return self._create_reserved()

        try:
            return self.pool.get(timeout=timeout)
        except Empty:
            # Capacity may have been freed while waiting (a connection was
            # discarded), so try once more before giving up.
            if self._reserve_slot():
                return self._create_reserved()
            raise ConnectionError("连接池已满,无法获取连接") from None

    def _acquire_connection(self, timeout: float) -> Connection:
        """Check out a healthy connection and mark it as in use.

        Raises:
            ConnectionError: If the pool is at capacity and nothing was
                returned within ``timeout`` seconds.
        """
        conn = self._checkout(timeout)

        # An expired connection is swapped for a new one that inherits its
        # slot, so the live count stays unchanged.
        if not conn.is_healthy(self.max_age):
            conn = self._create_reserved()

        conn.in_use = True
        conn.last_used = time.time()

        with self.lock:
            self.active_connections += 1

        return conn

    def _release_connection(self, conn: Connection):
        """Mark ``conn`` idle and return it to the pool."""
        from queue import Full

        conn.in_use = False
        conn.last_used = time.time()

        with self.lock:
            self.active_connections -= 1

        try:
            self.pool.put_nowait(conn)
        except Full:
            # No room for it: drop the connection and give its slot back.
            with self.lock:
                self._live_connections -= 1

    def get_stats(self):
        """Return live, checked-out and idle connection counts as a dict."""
        with self.lock:
            return {
                "total_connections": self._live_connections,
                "active_connections": self.active_connections,
                "available_connections": self.pool.qsize()
            }

# Exercise the connection pool
def test_connection_pool():
    """Run eight worker threads against a small pool and print its statistics."""
    print("连接池测试:")

    pool = ConnectionPool(min_size=3, max_size=10)

    def worker(worker_id: int):
        """Hold one pooled connection for 0.1 s, logging checkout and release."""
        with pool.get_connection() as conn:
            print(f"工作线程 {worker_id}: 使用连接 {conn.id}")
            time.sleep(0.1)  # simulated work
        print(f"工作线程 {worker_id}: 释放连接 {conn.id}")

    # Simulate concurrent access: create the workers, start them all, then
    # wait for every one of them to finish.
    workers = [threading.Thread(target=worker, args=(n,)) for n in range(8)]
    for thread in workers:
        thread.start()
    for thread in workers:
        thread.join()

    # Report the pool statistics
    print(f"\n连接池统计: {pool.get_stats()}")

test_connection_pool()

# Database connection pool
class DatabaseConnectionPool(ConnectionPool):
    """Connection pool whose connections come from ``_create_db_connection``."""

    def __init__(self, dsn: str, min_size: int = 5, max_size: int = 20):
        """Store ``dsn`` and initialise the base pool with the database factory."""
        # dsn is only stored here; the placeholder factory below does not read it.
        self.dsn = dsn
        super().__init__(
            min_size,
            max_size,
            connection_factory=self._create_db_connection,
        )

    def _create_db_connection(self):
        """Create one connection record, numbered from ``connections_count``.

        Placeholder: a real implementation would open the database connection
        here.
        """
        next_id = self.connections_count
        self.connections_count = next_id + 1
        print(f"创建数据库连接: {next_id}")
        created = time.time()
        touched = time.time()
        return Connection(id=next_id, created_at=created, last_used=touched)

# Exercise the database connection pool
print("\n数据库连接池测试:")
db_pool = DatabaseConnectionPool("postgresql://localhost/mydb", min_size=2, max_size=5)

def db_worker(worker_id: int):
    """Hold one pooled database connection for 0.05 s and log both steps."""
    with db_pool.get_connection() as conn:
        print(f"数据库工作线程 {worker_id}: 使用连接 {conn.id}")
        time.sleep(0.05)
    print(f"数据库工作线程 {worker_id}: 释放连接")

# Six workers compete for the pool; start them all, then wait for completion.
threads = [threading.Thread(target=db_worker, args=(n,)) for n in range(6)]
for t in threads:
    t.start()
for t in threads:
    t.join()

限流机制

限流是保护系统免受过载的重要技术。常见的限流算法包括令牌桶、漏桶和滑动窗口。

import time
import threading
from collections import deque
from dataclasses import dataclass
from typing import Callable, Any

class TokenBucket:
    """Token-bucket rate limiter.

    Tokens accrue at ``rate`` per second up to ``capacity``; each request
    spends tokens and is rejected when too few are available.
    """

    def __init__(self, rate: float, capacity: int):
        """Start with a full bucket.

        Args:
            rate: Tokens added per second.
            capacity: Maximum number of tokens the bucket can hold.
        """
        self.rate = rate
        self.capacity = capacity
        self.tokens = capacity  # current balance; becomes a float after refills
        self.last_update = time.time()
        self.lock = threading.Lock()

    def acquire(self, tokens: int = 1) -> bool:
        """Try to spend ``tokens`` without waiting.

        Returns:
            True if the tokens were taken, False if the balance was too low.
        """
        with self.lock:
            now = time.time()
            # Clamp at zero: if the wall clock steps backwards the elapsed
            # time is negative and would otherwise drain the bucket.
            elapsed = max(0.0, now - self.last_update)
            self.tokens = min(self.capacity, self.tokens + elapsed * self.rate)
            self.last_update = now

            if self.tokens >= tokens:
                self.tokens -= tokens
                return True
            return False

    def wait(self, tokens: int = 1) -> bool:
        """Block until ``tokens`` could be spent, polling every 10 ms.

        Returns:
            True once the tokens were taken.

        Raises:
            ValueError: If ``tokens`` exceeds ``capacity``; the balance can
                never reach that amount, so waiting would never end.
        """
        if tokens > self.capacity:
            raise ValueError("请求的令牌数超过桶容量,永远无法满足")
        while not self.acquire(tokens):
            time.sleep(0.01)
        return True

class LeakyBucket:
    """Leaky-bucket rate limiter.

    Admitted requests add "water" to the bucket, which drains at ``rate`` units
    per second; a request is refused when it would overflow ``capacity``.
    """

    def __init__(self, rate: float, capacity: int):
        """Create an empty bucket draining at ``rate`` with room for ``capacity``."""
        self.lock = threading.Lock()
        self.rate = rate  # units drained per second
        self.capacity = capacity  # maximum water level
        self.water = 0  # current water level
        self.last_leak = time.time()

    def process(self, requests: int = 1) -> bool:
        """Admit ``requests`` units if they fit after draining; return the verdict."""
        with self.lock:
            current = time.time()
            # Drain whatever leaked out since the previous call.
            drained = (current - self.last_leak) * self.rate
            self.last_leak = current
            self.water = max(0, self.water - drained)

            # Refuse anything that would overflow the bucket.
            if self.water + requests > self.capacity:
                return False
            self.water += requests
            return True

class SlidingWindow:
    """Sliding-window rate limiter over the timestamps of admitted requests."""

    def __init__(self, window_size: float, max_requests: int):
        """Allow at most ``max_requests`` within any ``window_size`` seconds."""
        self.lock = threading.Lock()
        self.requests = deque()  # timestamps of admitted requests, oldest first
        self.window_size = window_size
        self.max_requests = max_requests

    def allow(self) -> bool:
        """Record and admit a request unless the window is already full."""
        with self.lock:
            now = time.time()
            cutoff = now - self.window_size
            history = self.requests

            # Forget timestamps that have slid out of the window.
            while history and history[0] < cutoff:
                history.popleft()

            if len(history) >= self.max_requests:
                return False
            history.append(now)
            return True

# Exercise the rate-limiting algorithms
def test_rate_limiting():
    """Drive each limiter with a fixed request pattern and print the tallies."""

    def tally(attempt, rounds, pause):
        """Call ``attempt`` ``rounds`` times, ``pause`` seconds apart.

        Returns:
            Tuple ``(accepted, rejected)``.
        """
        accepted = 0
        for _ in range(rounds):
            if attempt():
                accepted += 1
            time.sleep(pause)
        return accepted, rounds - accepted

    print("限流算法测试:")

    # Token bucket: 10 tokens per second, capacity 5
    print("\n令牌桶算法:")
    token_bucket = TokenBucket(rate=10, capacity=5)
    success_count, fail_count = tally(token_bucket.acquire, 20, 0.05)
    print(f"成功: {success_count}, 失败: {fail_count}")

    # Leaky bucket: 5 requests per second, capacity 10
    print("\n漏桶算法:")
    leaky_bucket = LeakyBucket(rate=5, capacity=10)
    success_count, fail_count = tally(leaky_bucket.process, 15, 0.1)
    print(f"成功: {success_count}, 失败: {fail_count}")

    # Sliding window: at most 5 requests per second
    print("\n滑动窗口算法:")
    sliding_window = SlidingWindow(window_size=1.0, max_requests=5)
    success_count, fail_count = tally(sliding_window.allow, 10, 0.1)
    print(f"成功: {success_count}, 失败: {fail_count}")

test_rate_limiting()

# API rate-limiting decorator
def rate_limitDecorator(rate: float, capacity: int):
    """Build a decorator that rate-limits calls with a shared token bucket.

    One ``TokenBucket(rate, capacity)`` is created per decoration and shared by
    every call of the decorated function.

    Args:
        rate: Tokens added to the bucket per second.
        capacity: Maximum number of tokens in the bucket.

    Returns:
        A decorator. The wrapped function raises ``Exception`` when no token is
        available, otherwise returns the original function's result.
    """
    from functools import wraps

    bucket = TokenBucket(rate, capacity)

    def decorator(func: Callable) -> Callable:
        @wraps(func)  # keep the wrapped function's name, docstring and metadata
        def wrapper(*args, **kwargs) -> Any:
            if bucket.acquire():
                return func(*args, **kwargs)
            raise Exception("请求过于频繁,请稍后再试")
        return wrapper
    return decorator

# Exercise the decorator
@rate_limitDecorator(rate=5, capacity=3)
def api_endpoint():
    """Stand-in API handler that returns a fixed response."""
    return "API响应"

print("\nAPI限流装饰器测试:")
for i in range(8):
    # Either the endpoint's response or the rate-limit error ends up printed.
    try:
        outcome = api_endpoint()
    except Exception as e:
        outcome = e
    print(f"请求 {i+1}: {outcome}")
    time.sleep(0.1)

熔断器模式

熔断器是防止系统级联故障的重要机制。当服务出现故障时,熔断器会快速失败,避免雪崩效应。

import time
import threading
from dataclasses import dataclass
from enum import Enum
from typing import Callable, Any

class CircuitState(Enum):
    """States of a circuit breaker."""
    CLOSED = "closed"      # normal operation: calls pass through
    OPEN = "open"          # tripped: calls are rejected
    HALF_OPEN = "half_open"  # probing: trial calls decide whether to close

@dataclass
class CircuitBreaker:
    """Circuit breaker that fails fast after repeated failures.

    The breaker opens after ``failure_threshold`` consecutive failures, rejects
    calls while open, lets trial calls through once ``recovery_timeout`` seconds
    have passed since the last failure, and closes again after
    ``success_threshold`` successful trial calls.
    """

    failure_threshold: int = 5  # consecutive failures that open the breaker
    recovery_timeout: float = 30.0  # seconds to stay open before probing
    success_threshold: int = 3  # successes needed in half-open state to close

    def __post_init__(self):
        self.state = CircuitState.CLOSED
        self.failure_count = 0
        self.success_count = 0
        self.last_failure_time = 0
        self.lock = threading.Lock()

    def call(self, func: Callable, *args, **kwargs) -> Any:
        """Invoke ``func(*args, **kwargs)`` under the breaker's protection.

        The lock only guards the state bookkeeping, not the call itself, so
        protected calls can run concurrently and may themselves use the breaker.

        Returns:
            Whatever ``func`` returns.

        Raises:
            Exception: If the breaker is open and the recovery timeout has not
                elapsed; ``func`` is not invoked in that case.
            Exception: Any exception raised by ``func`` is recorded as a
                failure and re-raised unchanged.
        """
        self._admit()
        try:
            result = func(*args, **kwargs)
        except Exception:
            self._record_failure()
            raise  # bare raise keeps the original traceback
        self._record_success()
        return result

    def _admit(self):
        """Reject the call while open, or move to half-open once the timeout passed."""
        with self.lock:
            if self.state != CircuitState.OPEN:
                return
            if time.time() - self.last_failure_time > self.recovery_timeout:
                self.state = CircuitState.HALF_OPEN
                self.success_count = 0
                print("熔断器进入半开状态")
            else:
                raise Exception("熔断器处于打开状态,请求被拒绝")

    def _record_success(self):
        """Count a success: close from half-open, or clear the failure streak."""
        with self.lock:
            if self.state == CircuitState.HALF_OPEN:
                self.success_count += 1
                if self.success_count >= self.success_threshold:
                    self.state = CircuitState.CLOSED
                    self.failure_count = 0
                    print("熔断器恢复到关闭状态")
            elif self.state == CircuitState.CLOSED:
                self.failure_count = 0
            # A call that started before the breaker opened may finish while
            # it is open; that late success must not clear the failure count.

    def _record_failure(self):
        """Count a failure and open the breaker when warranted."""
        with self.lock:
            self.failure_count += 1
            self.last_failure_time = time.time()

            # Any failed trial call re-opens the breaker immediately; in the
            # closed state the consecutive-failure threshold decides.
            should_open = (self.state == CircuitState.HALF_OPEN
                           or self.failure_count >= self.failure_threshold)
            if should_open and self.state != CircuitState.OPEN:
                self.state = CircuitState.OPEN
                print(f"熔断器打开,连续失败 {self.failure_count} 次")

# Exercise the circuit breaker
def test_circuit_breaker():
    """Drive a breaker with a service that fails five times and then recovers."""
    print("熔断器测试:")

    # Simulated flaky service: the first five invocations fail.
    attempts = 0

    def unstable_service():
        nonlocal attempts
        attempts += 1
        if attempts > 5:
            return f"服务调用成功 (第 {attempts} 次)"
        raise Exception(f"服务调用失败 (第 {attempts} 次)")

    circuit_breaker = CircuitBreaker(
        failure_threshold=3,
        recovery_timeout=2.0,
        success_threshold=2
    )

    def run_round(label, rounds):
        """Call the service ``rounds`` times, 0.5 s apart, printing each outcome."""
        for seq in range(1, rounds + 1):
            try:
                outcome = circuit_breaker.call(unstable_service)
            except Exception as exc:
                outcome = exc
            print(f"{label} {seq}: {outcome}")
            time.sleep(0.5)

    # First round of calls
    run_round("调用", 10)

    # Give the breaker time to recover
    print("\n等待熔断器恢复...")
    time.sleep(3)

    # Second round of calls
    run_round("恢复后调用", 3)

test_circuit_breaker()

# Advanced circuit breaker
class AdvancedCircuitBreaker(CircuitBreaker):
    """Circuit breaker that reports failures, successes and state changes.

    Callbacks registered with ``on_failure``, ``on_success`` and
    ``on_state_change`` are invoked from ``call``. A callback that raises is
    reported on stdout and never affects the protected call.
    """

    def __init__(self, failure_threshold: int = 5, recovery_timeout: float = 30.0,
                 success_threshold: int = 3):
        super().__init__(failure_threshold, recovery_timeout, success_threshold)
        self.failure_callbacks = []
        self.success_callbacks = []
        self.state_change_callbacks = []

    def on_failure(self, callback: Callable):
        """Register ``callback(error)``, run when the protected function raises."""
        self.failure_callbacks.append(callback)

    def on_success(self, callback: Callable):
        """Register ``callback(result)``, run when the protected function returns."""
        self.success_callbacks.append(callback)

    def on_state_change(self, callback: Callable):
        """Register ``callback(old_state, new_state)``, run on state transitions."""
        self.state_change_callbacks.append(callback)

    def call(self, func: Callable, *args, **kwargs) -> Any:
        """Call ``func`` through the breaker and fire the registered callbacks.

        Calls rejected because the breaker is open never reach ``func`` and do
        not trigger failure callbacks. State changes are detected by sampling
        ``self.state`` before the call, when ``func`` starts and after the
        call, so under concurrent use the reported transitions are best-effort.

        Returns:
            Whatever ``func`` returns.

        Raises:
            Exception: Whatever the base ``call`` raises, unchanged.
        """
        state_before = self.state
        state_at_invoke = None

        def tracked(*call_args, **call_kwargs):
            # Runs only if the breaker admitted the call; the state sampled
            # here exposes an OPEN -> HALF_OPEN step taken on admission.
            nonlocal state_at_invoke
            state_at_invoke = self.state
            return func(*call_args, **call_kwargs)

        try:
            result = super().call(tracked, *args, **kwargs)
        except Exception as exc:
            if state_at_invoke is not None:
                self._report_transitions(state_before, state_at_invoke, self.state)
                self._run_callbacks(self.failure_callbacks, "失败回调执行失败", exc)
            raise

        self._report_transitions(state_before, state_at_invoke, self.state)
        self._run_callbacks(self.success_callbacks, "成功回调执行失败", result)
        return result

    def _report_transitions(self, *states: CircuitState):
        """Notify listeners about each change between consecutive sampled states."""
        for old_state, new_state in zip(states, states[1:]):
            if old_state is not new_state:
                self._notify_state_change(old_state, new_state)

    def _run_callbacks(self, callbacks, error_label: str, payload):
        """Invoke every callback with ``payload``, printing any callback error."""
        for callback in callbacks:
            try:
                callback(payload)
            except Exception as e:
                print(f"{error_label}: {e}")

    def _notify_state_change(self, old_state: CircuitState, new_state: CircuitState):
        """Invoke every state-change callback, printing any callback error."""
        for callback in self.state_change_callbacks:
            try:
                callback(old_state, new_state)
            except Exception as e:
                print(f"状态变化回调执行失败: {e}")

# Exercise the advanced circuit breaker
print("\n高级熔断器测试:")

def on_failure(error):
    """Demo failure callback: print the error it receives."""
    print(f"失败回调: {error}")

def on_success(result):
    """Demo success callback: print the result it receives."""
    print(f"成功回调: {result}")

def on_state_change(old_state, new_state):
    """Demo state-change callback: print the transition."""
    print(f"状态变化: {old_state.value} -> {new_state.value}")

advanced_breaker = AdvancedCircuitBreaker(
    failure_threshold=2,
    recovery_timeout=2.0,
    success_threshold=2
)

advanced_breaker.on_failure(on_failure)
advanced_breaker.on_success(on_success)
advanced_breaker.on_state_change(on_state_change)

# Demo service: fails on its first two invocations, then succeeds.
call_count = 0
def failing_service():
    """Raise on the first two calls and return a success message afterwards."""
    # call_count lives at module level, so it must be declared global;
    # nonlocal is a SyntaxError outside a nested function.
    global call_count
    call_count += 1
    if call_count <= 2:
        raise Exception("服务失败")
    return "服务成功"

for i in range(5):
    try:
        result = advanced_breaker.call(failing_service)
        print(f"调用 {i+1}: {result}")
    except Exception as e:
        print(f"调用 {i+1}: {e}")
    time.sleep(0.5)

负载均衡

负载均衡是将请求分发到多个服务器的技术,可以提高系统的吞吐量和可用性。

import time
import threading
import random
from dataclasses import dataclass
from typing import List, Callable, Any
from collections import defaultdict

@dataclass
class Server:
    """A backend node that load balancers can route requests to."""
    id: str
    weight: int = 1  # relative share used by weighted balancing
    current_load: int = 0  # requests being handled right now
    total_requests: int = 0  # requests accepted so far
    is_healthy: bool = True

    def handle_request(self) -> str:
        """Handle one simulated request and return a confirmation message.

        Raises:
            Exception: If the server is marked unhealthy.
        """
        if self.is_healthy:
            self.total_requests += 1
            self.current_load += 1

            time.sleep(0.01)  # simulated processing time

            self.current_load -= 1
            return f"服务器 {self.id} 处理请求"

        raise Exception(f"服务器 {self.id} 不健康")

class LoadBalancer:
    """Base class for load balancers; subclasses implement ``get_server``."""

    def __init__(self, servers: List[Server]):
        """Keep the server list and create the lock subclasses select under."""
        self.lock = threading.Lock()
        self.servers = servers

    def get_server(self) -> Server:
        """Pick the server for the next request; must be overridden."""
        raise NotImplementedError

    def handle_request(self) -> str:
        """Route one request to the selected server and return its response."""
        return self.get_server().handle_request()

class RoundRobinBalancer(LoadBalancer):
    """Round-robin load balancer that skips unhealthy servers."""

    def __init__(self, servers: List[Server]):
        super().__init__(servers)
        self.current_index = 0  # position where the next search starts

    def get_server(self) -> Server:
        """Return the next healthy server in rotation.

        Raises:
            Exception: If the server list is empty or no server is healthy.
        """
        with self.lock:
            total = len(self.servers)
            # At most one full lap; an empty list falls straight through to
            # the error below instead of failing on an index lookup.
            for offset in range(total):
                # The modulo also tolerates a stale index after the server
                # list has shrunk.
                index = (self.current_index + offset) % total
                server = self.servers[index]
                if server.is_healthy:
                    self.current_index = (index + 1) % total
                    return server

            raise Exception("没有可用的健康服务器")

class WeightedRoundRobinBalancer(LoadBalancer):
    """Smooth weighted round-robin load balancer.

    Every healthy server's running weight grows by its static weight on each
    selection; the healthy server with the largest running weight is chosen
    and pays back the total healthy weight.
    """

    def __init__(self, servers: List[Server]):
        super().__init__(servers)
        self.current_weights = [0] * len(servers)  # running weight per server

    def get_server(self) -> Server:
        """Return the healthy server that is due next according to the weights.

        Raises:
            Exception: If no healthy server with a non-zero weight exists.
        """
        with self.lock:
            healthy = [i for i, server in enumerate(self.servers) if server.is_healthy]
            total_weight = sum(self.servers[i].weight for i in healthy)

            if total_weight == 0:
                raise Exception("没有可用的健康服务器")

            for i in healthy:
                self.current_weights[i] += self.servers[i].weight

            # Only healthy servers compete. An unhealthy server may still hold
            # a large running weight from before it went down and must not win.
            # max() keeps the lowest index on ties.
            chosen = max(healthy, key=lambda i: self.current_weights[i])
            self.current_weights[chosen] -= total_weight

            return self.servers[chosen]

class LeastConnectionsBalancer(LoadBalancer):
    """Load balancer that prefers the healthy server with the lowest load."""

    def get_server(self) -> Server:
        """Return the healthy server with the smallest ``current_load``.

        Raises:
            Exception: If no server is healthy.
        """
        with self.lock:
            candidates = (node for node in self.servers if node.is_healthy)
            # The first server wins a tie; None signals that nothing qualified.
            least_loaded = min(candidates, key=lambda node: node.current_load, default=None)

            if least_loaded is None:
                raise Exception("没有可用的健康服务器")
            return least_loaded

class ConsistentHashBalancer(LoadBalancer):
    """Consistent-hash load balancer with virtual nodes.

    Each server is placed on a 32-bit hash ring ``virtual_nodes`` times. A
    request key maps to the first healthy server found clockwise from the
    key's own hash.
    """

    def __init__(self, servers: List[Server], virtual_nodes: int = 150):
        """Build the ring.

        Args:
            servers: Servers to place on the ring.
            virtual_nodes: Ring positions created per server.
        """
        super().__init__(servers)
        self.virtual_nodes = virtual_nodes
        self.ring = {}  # ring position -> server

        for server in servers:
            for i in range(virtual_nodes):
                self.ring[self._hash(f"{server.id}:{i}")] = server

        # Derived from the dict so that a hash collision cannot leave
        # duplicate positions in the sorted list.
        self.sorted_keys = sorted(self.ring)

    def _hash(self, key: str) -> int:
        """Map ``key`` to a stable 32-bit ring position.

        The built-in ``hash()`` is salted per process for strings, so it would
        place keys differently after every restart; an MD5 digest is stable.
        """
        import hashlib

        digest = hashlib.md5(key.encode("utf-8")).digest()
        return int.from_bytes(digest[:4], "big")

    def get_server(self, request_key: str = None) -> Server:
        """Return the healthy server responsible for ``request_key``.

        Args:
            request_key: Key to route by; a random key is used when omitted.

        Raises:
            Exception: If the ring is empty or no server on it is healthy.
        """
        import bisect

        if request_key is None:
            request_key = str(random.random())

        with self.lock:
            if not self.ring:
                raise Exception("没有可用的服务器")

            positions = self.sorted_keys
            total = len(positions)
            # First ring position at or after the key's hash; the modulo below
            # wraps the walk around the end of the ring.
            start = bisect.bisect_left(positions, self._hash(request_key))

            for step in range(total):
                server = self.ring[positions[(start + step) % total]]
                if server.is_healthy:
                    return server

            raise Exception("没有可用的健康服务器")

# Exercise the load balancers
def test_load_balancing():
    """Run each balancing strategy over three weighted servers and print results."""
    print("负载均衡测试:")

    # Three servers with weights 3, 2 and 1
    servers = [
        Server(name, weight=share)
        for name, share in (("server1", 3), ("server2", 2), ("server3", 1))
    ]

    # Round robin
    print("\n轮询负载均衡:")
    rr_balancer = RoundRobinBalancer(servers)
    for seq in range(1, 7):
        try:
            line = f"请求 {seq}: {rr_balancer.handle_request()}"
        except Exception as exc:
            line = f"请求 {seq}: {exc}"
        print(line)

    # Weighted round robin
    print("\n加权轮询负载均衡:")
    wrr_balancer = WeightedRoundRobinBalancer(servers)
    request_count = defaultdict(int)
    for seq in range(1, 13):
        try:
            picked = wrr_balancer.get_server()
            request_count[picked.id] += 1
            line = f"请求 {seq}: 服务器 {picked.id}"
        except Exception as exc:
            line = f"请求 {seq}: {exc}"
        print(line)

    print(f"请求分布: {dict(request_count)}")

    # Least connections
    print("\n最小连接数负载均衡:")
    lc_balancer = LeastConnectionsBalancer(servers)
    for seq in range(1, 7):
        try:
            picked = lc_balancer.get_server()
            line = f"请求 {seq}: 服务器 {picked.id} (当前负载: {picked.current_load})"
            picked.current_load += 1  # pretend the request is now in flight
        except Exception as exc:
            line = f"请求 {seq}: {exc}"
        print(line)

    # Consistent hashing
    print("\n一致性哈希负载均衡:")
    ch_balancer = ConsistentHashBalancer(servers)
    for seq in range(1, 7):
        request_key = f"request_{seq - 1}"
        try:
            picked = ch_balancer.get_server(request_key)
            line = f"请求 {seq} (键: {request_key}): 服务器 {picked.id}"
        except Exception as exc:
            line = f"请求 {seq}: {exc}"
        print(line)

test_load_balancing()

# Load-balancer monitoring
class LoadBalancerMonitor:
    """Counts how many requests were routed to each server."""

    def __init__(self, balancer: LoadBalancer):
        """Remember the balancer and start with empty per-server counters."""
        self.lock = threading.Lock()
        self.request_stats = defaultdict(int)  # server id -> request count
        self.balancer = balancer

    def monitor_request(self, server_id: str):
        """Record one request for ``server_id``."""
        with self.lock:
            self.request_stats[server_id] = self.request_stats[server_id] + 1

    def get_stats(self):
        """Return a plain-dict snapshot of the per-server request counts."""
        with self.lock:
            return {sid: hits for sid, hits in self.request_stats.items()}

    def print_stats(self):
        """Print one line per server with its request count."""
        snapshot = self.get_stats()
        print("\n负载均衡统计:")
        for sid in snapshot:
            print(f"  {sid}: {snapshot[sid]} 个请求")

# Exercise the monitor
print("\n负载均衡监控测试:")

# The balancer created inside test_load_balancing() is local to that function,
# so this demo builds its own servers and round-robin balancer.
monitored_servers = [
    Server("server1", weight=3),
    Server("server2", weight=2),
    Server("server3", weight=1)
]
rr_balancer = RoundRobinBalancer(monitored_servers)
monitor = LoadBalancerMonitor(rr_balancer)

# Simulate requests and record which server each one was routed to
for i in range(20):
    try:
        server = rr_balancer.get_server()
        monitor.monitor_request(server.id)
    except Exception as e:
        print(f"请求失败: {e}")

monitor.print_stats()

高并发架构是构建可扩展、高性能系统的关键。通过深入理解连接池、限流、熔断和负载均衡技术,开发者可以设计出可靠、高效的高并发系统。掌握这些技术不仅有助于解决性能瓶颈,还能提高系统的可用性和稳定性,是成为高级Python开发者的重要技能。