← 返回首页

魔术方法大全

📂 python ⏱ 5 min 874 words

魔术方法大全

魔术方法(Magic Methods,又称 dunder 方法)是Python中以双下划线开头和结尾的特殊方法。解释器在执行内置操作(如 +、len()、with 语句)时会自动调用它们,从而让你的类与Python的内置操作无缝集成。

字符串表示

__str__和__repr__

class Point:
    """A 2-D point showing the three string hooks: __repr__, __str__, __format__."""

    def __init__(self, x, y):
        self.x, self.y = x, y

    def __repr__(self):
        """Developer-facing form such as ``Point(1, 2)``; components use repr()."""
        return "Point({!r}, {!r})".format(self.x, self.y)

    def __str__(self):
        """User-facing form such as ``(1, 2)``."""
        return "({}, {})".format(self.x, self.y)

    def __format__(self, format_spec):
        """Render ``f"{p:vector}"`` as ``[1, 2]``; every other spec gives ``(1, 2)``."""
        opener, closer = ("[", "]") if format_spec == 'vector' else ("(", ")")
        return "{}{}, {}{}".format(opener, self.x, self.y, closer)

p = Point(1, 2)
print(repr(p))   # Point(1, 2)
print(str(p))    # (1, 2)
print(f"{p}")    # (1, 2)
print(f"{p:vector}")  # [1, 2]

__format__详解

class Money:
    """An amount of money plus a currency code, with custom format specs.

    Supported ``format_spec`` values:
      ''            -> raw amount and currency, e.g. "99.9 CNY"
      'cn'/'us'/'eu' -> fixed symbol and 2 decimals, e.g. "¥99.90"
                       (the symbol is fixed by the spec; ``currency`` is ignored)
      'Nd' or '.Nd' -> N decimals plus currency, e.g. ".2d" -> "99.90 CNY";
                       a bare 'd' means 0 decimals
      anything else -> 2 decimals plus currency

    Raises ValueError for a spec ending in 'd' whose prefix is not an integer.
    """

    # Format specs that render the amount with a fixed currency symbol.
    _SYMBOLS = {'cn': '¥', 'us': '$', 'eu': '€'}

    def __init__(self, amount, currency="CNY"):
        self.amount = amount
        self.currency = currency

    def __format__(self, format_spec):
        if format_spec == '':
            return f"{self.amount} {self.currency}"
        if format_spec in self._SYMBOLS:
            return f"{self._SYMBOLS[format_spec]}{self.amount:.2f}"
        if format_spec.endswith('d'):
            digits = format_spec[:-1]
            # Allow the f-string-style ".2d" as well as "2d"; without this,
            # int(".2") raised ValueError on the documented example.
            if digits.startswith('.'):
                digits = digits[1:]
            try:
                precision = int(digits) if digits else 0
            except ValueError:
                raise ValueError(
                    f"Invalid format specifier {format_spec!r} for Money"
                ) from None
            return f"{self.amount:.{precision}f} {self.currency}"
        return f"{self.amount:.2f} {self.currency}"

price = Money(99.9)
print(f"价格: {price}")      # 价格: 99.9 CNY
print(f"价格: {price:cn}")   # 价格: ¥99.90
print(f"价格: {price:us}")   # 价格: $99.90
print(f"价格: {price:.2d}")  # 价格: 99.90 CNY

相等性和哈希

__eq__和__hash__

class Vector:
    """A 2-D vector compared and hashed by value.

    Because equal vectors hash equally, a Vector can serve as a dict key or
    set member and be looked up with a different but equal instance.
    """

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __eq__(self, other):
        """Component-wise equality; non-Vector operands are left to Python."""
        if isinstance(other, Vector):
            return self.x == other.x and self.y == other.y
        return NotImplemented

    def __hash__(self):
        """Hash the same components __eq__ compares.

        The attributes are plain and reassignable, so a Vector must not be
        mutated while it is stored in a dict or set.
        """
        components = (self.x, self.y)
        return hash(components)

v1 = Vector(1, 2)
v2 = Vector(1, 2)
v3 = Vector(3, 4)

print(v1 == v2)  # True
print(v1 == v3)  # False
print(hash(v1) == hash(v2))  # True

# 可以用作字典键或集合元素
vectors = {v1: "向量1", v3: "向量3"}
print(vectors[v2])  # 向量1

__ne__、__lt__等比较方法

from functools import total_ordering

@total_ordering
class Version:
    """A major.minor.patch version number with full ordering.

    Versions compare component by component: major first, then minor, then
    patch. Only __eq__ and __lt__ are written here; functools.total_ordering
    derives <=, > and >= from them.
    """

    def __init__(self, major, minor, patch):
        self.major = major
        self.minor = minor
        self.patch = patch

    def _as_tuple(self):
        """The components as a tuple, which Python compares lexicographically."""
        return (self.major, self.minor, self.patch)

    def __eq__(self, other):
        if isinstance(other, Version):
            return self._as_tuple() == other._as_tuple()
        return NotImplemented

    def __lt__(self, other):
        if isinstance(other, Version):
            return self._as_tuple() < other._as_tuple()
        return NotImplemented

    def __repr__(self):
        return "Version({}, {}, {})".format(self.major, self.minor, self.patch)

v1 = Version(1, 2, 3)
v2 = Version(1, 2, 4)
v3 = Version(1, 3, 0)

print(v1 < v2)    # True
print(v2 > v3)    # False
print(v1 <= v2)   # True
print(v3 >= v1)   # True
print(v1 == v1)   # True
print(v1 != v2)   # True

运算符重载

算术运算符

class Matrix:
    """A small dense matrix stored as a list of row lists.

    Supports element-wise ``+`` / ``-`` between same-shape matrices, scalar
    multiplication on either side (``m * 2``, ``3 * m``) and unary ``-``.
    Matrix-by-matrix multiplication is not implemented.
    """

    def __init__(self, data):
        # Copy each row so later mutation of the caller's lists cannot
        # change this matrix.
        self.data = [row[:] for row in data]
        self.rows = len(data)
        # Column count comes from the first row; rows are assumed (not
        # validated) to have equal length.
        self.cols = len(data[0]) if data else 0

    def _elementwise(self, other, op):
        """Combine two same-shape matrices cell by cell with ``op``.

        Returns NotImplemented for a non-Matrix operand, so Python can try
        the reflected method and finally raise TypeError.
        Raises ValueError when the shapes differ.
        """
        if not isinstance(other, Matrix):
            return NotImplemented
        if self.rows != other.rows or self.cols != other.cols:
            raise ValueError("矩阵维度不匹配")
        result = [[op(self.data[i][j], other.data[i][j])
                   for j in range(self.cols)]
                  for i in range(self.rows)]
        return Matrix(result)

    def __add__(self, other):
        """Element-wise sum of two same-shape matrices."""
        return self._elementwise(other, lambda a, b: a + b)

    def __sub__(self, other):
        """Element-wise difference of two same-shape matrices."""
        return self._elementwise(other, lambda a, b: a - b)

    def __mul__(self, scalar):
        """Multiply every element by a number.

        Raises NotImplementedError for a Matrix operand, because the matrix
        product is not implemented. Returns NotImplemented for any other
        non-number, so e.g. ``m * "ab"`` raises TypeError instead of
        silently repeating strings.
        """
        import numbers

        if isinstance(scalar, Matrix):
            raise NotImplementedError("矩阵乘法未实现")
        if not isinstance(scalar, numbers.Number):
            return NotImplemented
        result = [[self.data[i][j] * scalar
                   for j in range(self.cols)]
                  for i in range(self.rows)]
        return Matrix(result)

    def __rmul__(self, scalar):
        # Scalar multiplication commutes, so 3 * m is computed as m * 3.
        return self.__mul__(scalar)

    def __neg__(self):
        return self * (-1)

    def __repr__(self):
        return f"Matrix({self.data})"

m1 = Matrix([[1, 2], [3, 4]])
m2 = Matrix([[5, 6], [7, 8]])

print(m1 + m2)  # Matrix([[6, 8], [10, 12]])
print(m1 - m2)  # Matrix([[-4, -4], [-4, -4]])
print(m1 * 2)   # Matrix([[2, 4], [6, 8]])
print(3 * m1)   # Matrix([[3, 6], [9, 12]])
print(-m1)      # Matrix([[-1, -2], [-3, -4]])

位运算符

class Flags:
    """An int-backed set of bit flags combined with bitwise operators.

    ``&``, ``|`` and ``^`` combine two Flags; ``~``, ``<<`` and ``>>`` act on
    the underlying int. An instance is truthy only when at least one bit is
    set, so ``bool(perms & READ)`` tests whether READ is present.
    """

    def __init__(self, value=0):
        self.value = value

    def _combine(self, other, op):
        """Apply ``op`` to both int values; defer to Python for non-Flags."""
        if not isinstance(other, Flags):
            return NotImplemented
        return Flags(op(self.value, other.value))

    def __and__(self, other):
        return self._combine(other, lambda a, b: a & b)

    def __or__(self, other):
        return self._combine(other, lambda a, b: a | b)

    def __xor__(self, other):
        return self._combine(other, lambda a, b: a ^ b)

    def __invert__(self):
        # Python ints are unbounded, so this is not a fixed-width complement:
        # ~0b001 == -2, which displays as Flags(-0b10).
        return Flags(~self.value)

    def __lshift__(self, n):
        return Flags(self.value << n)

    def __rshift__(self, n):
        return Flags(self.value >> n)

    def __bool__(self):
        # Without __bool__ (or __len__) every object is truthy, which would
        # make bool(permissions & EXECUTE) True even when no bits overlap.
        return self.value != 0

    def __repr__(self):
        return f"Flags({bin(self.value)})"

READ = Flags(0b001)
WRITE = Flags(0b010)
EXECUTE = Flags(0b100)

permissions = READ | WRITE
print(permissions)           # Flags(0b11)
print(bool(permissions & READ))   # True
print(bool(permissions & EXECUTE))  # False

容器协议

class SortedList:
    """A list that keeps its items in ascending order.

    ``add`` inserts by binary search. Indexing, slicing, ``in``, ``len``,
    iteration and ``reversed`` all delegate to the backing list.
    """

    def __init__(self):
        self._data = []

    def add(self, item):
        """Insert ``item`` at its sorted position (after any equal items)."""
        import bisect
        bisect.insort(self._data, item)

    def __contains__(self, item):
        # Plain linear membership test on the backing list.
        found = item in self._data
        return found

    def __len__(self):
        return len(self._data)

    def __getitem__(self, index):
        # Slices come back as plain lists, not SortedList instances.
        return self._data[index]

    def __setitem__(self, index, value):
        """Overwrite item(s) at ``index``, then re-sort to restore the order."""
        items = self._data
        items[index] = value
        items.sort()

    def __delitem__(self, index):
        # Removing elements never breaks the ordering, so no re-sort needed.
        del self._data[index]

    def __iter__(self):
        return iter(self._data)

    def __reversed__(self):
        return reversed(self._data)

    def __repr__(self):
        return "SortedList({})".format(self._data)

sl = SortedList()
sl.add(5)
sl.add(3)
sl.add(7)
sl.add(1)

print(sl)              # SortedList([1, 3, 5, 7])
print(5 in sl)         # True
print(len(sl))         # 4
print(sl[0])           # 1
print(list(reversed(sl)))  # [7, 5, 3, 1]

上下文管理器

class FileManager:
    """Context manager that opens a file on entry and closes it on exit.

    Prints a message when the file is opened and closed. Any exception
    raised inside the ``with`` body is reported but never suppressed.
    """

    def __init__(self, filename, mode):
        self.filename = filename
        self.mode = mode
        self.file = None

    def __enter__(self):
        """Open the file and return the file object bound by ``as``."""
        print(f"打开文件: {self.filename}")
        self.file = open(self.filename, self.mode)
        return self.file

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Close the file and report any exception; always return False.

        A truthy return from __exit__ tells Python to swallow the exception,
        so this method consistently returns False: it never suppresses.
        """
        if self.file is not None:
            self.file.close()
            print(f"关闭文件: {self.filename}")

        if exc_type is not None:
            print(f"发生异常: {exc_val}")
        return False  # 不抑制异常

# 使用with语句
with FileManager("test.txt", "w") as f:
    f.write("Hello, World!")

异步上下文管理器

import asyncio

class AsyncDatabase:
    """Simulated async database client for use with ``async with``.

    Entering "connects" (a short asyncio.sleep stands in for real I/O) and
    returns the client itself. Exiting disconnects and returns False, so any
    exception raised in the body propagates.
    """

    def __init__(self):
        self.connected = False

    async def __aenter__(self):
        print("异步连接数据库...")
        await asyncio.sleep(0.1)  # stand-in for a connection handshake
        self.connected = True
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        print("异步断开数据库...")
        await asyncio.sleep(0.1)  # stand-in for closing the connection
        self.connected = False
        return False

    async def query(self, sql):
        """Pretend to run ``sql`` and return a description of it."""
        await asyncio.sleep(0.05)
        return "执行: {}".format(sql)

async def main():
    """Open a connection, run one query and print its result."""
    async with AsyncDatabase() as db:
        print(await db.query("SELECT * FROM users"))

asyncio.run(main())

可调用对象

class Multiplier:
    """A callable object that multiplies its argument by a fixed factor."""

    def __init__(self, factor):
        self.factor = factor

    def __call__(self, x):
        """Return ``x * factor``, e.g. ``Multiplier(2)(5) == 10``."""
        product = x * self.factor
        return product

double = Multiplier(2)
triple = Multiplier(3)

print(double(5))   # 10
print(triple(5))   # 15
print(callable(double))  # True

数值转换

class Temperature:
    """A Celsius reading convertible with int(), float(), complex() and bool()."""

    def __init__(self, celsius):
        self.celsius = celsius

    def __int__(self):
        """Delegate to int(), which truncates toward zero (25.5 -> 25)."""
        return int(self.celsius)

    def __float__(self):
        return float(self.celsius)

    def __complex__(self):
        """Real part is the Celsius value; the imaginary part is 0."""
        return complex(self.celsius, 0)

    def __bool__(self):
        """Falsy only when the reading is exactly 0 degrees."""
        return self.celsius != 0

t = Temperature(25.5)
print(int(t))     # 25
print(float(t))   # 25.5
print(bool(t))    # True

t0 = Temperature(0)
print(bool(t0))   # False

魔术方法速查表

类别 方法 用途
字符串 __str__, __repr__, __format__ 对象表示
比较 __eq__, __ne__, __lt__, __le__, __gt__, __ge__ 比较操作
算术 __add__, __sub__, __mul__, __rmul__, __truediv__, __floordiv__, __mod__, __neg__ 算术运算
位运算 __and__, __or__, __xor__, __invert__, __lshift__, __rshift__ 位操作
容器 __len__, __getitem__, __setitem__, __delitem__, __contains__, __iter__, __reversed__ 容器协议
上下文 __enter__, __exit__, __aenter__, __aexit__ with / async with语句
数值 __int__, __float__, __complex__, __bool__ 类型转换
其他 __call__, __hash__, __copy__, __deepcopy__ 其他特殊行为

掌握魔术方法能让你的类与Python生态无缝集成,写出更加Pythonic的代码。