← 返回首页
🔗

SQLAlchemy入门:ORM、Model、Session与查询

📂 python ⏱ 4 min 796 words

SQLAlchemy入门:ORM、Model、Session与查询

SQLAlchemy是Python最强大的ORM(对象关系映射)框架,它将数据库表映射为Python类,让开发者用面向对象的方式操作数据库。本文将带你掌握SQLAlchemy的核心概念和使用方法。

安装SQLAlchemy

pip install sqlalchemy
# 如果使用SQLite
# 内置,无需额外安装
# 如果使用MySQL
# pip install pymysql
# 如果使用PostgreSQL
# pip install psycopg2-binary

核心概念

SQLAlchemy的基本用法围绕三个核心组件展开:Engine(引擎,负责管理数据库连接)、Session(会话,负责事务和对象状态的跟踪)以及Base(声明式基类,所有Model都继承自它):

from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, declarative_base

# Create the engine: the object that owns the database connections.
# NOTE: ``pool_size`` is deliberately not passed. On SQLAlchemy 1.x a
# file-based SQLite URL is served by NullPool, which rejects ``pool_size``
# with a TypeError inside create_engine(). On 2.0 the default QueuePool
# already keeps 5 connections, so omitting it leaves pooling unchanged there.
engine = create_engine(
    "sqlite:///mydb.db",
    echo=True,  # log every SQL statement that is emitted
    pool_pre_ping=True,  # check a connection is alive before handing it out
)

# Session factory bound to the engine above.
SessionLocal = sessionmaker(bind=engine)

# Declarative base class that every model inherits from.
Base = declarative_base()

定义Model

from sqlalchemy import Column, Integer, String, Float, ForeignKey, DateTime, Text
from sqlalchemy.orm import relationship
from sqlalchemy.sql import func
from sqlalchemy.orm import declarative_base

Base = declarative_base()

class User(Base):
    """Account row in the ``users`` table; owns a collection of posts."""

    __tablename__ = "users"

    id = Column(Integer, autoincrement=True, primary_key=True)
    username = Column(String(50), nullable=False, unique=True, index=True)
    email = Column(String(100), nullable=False, unique=True)
    age = Column(Integer)
    created_at = Column(DateTime, server_default=func.now())

    # One-to-many side of User <-> Post. The cascade setting means posts are
    # deleted with their user, and posts removed from the collection are
    # deleted as orphans.
    posts = relationship(
        "Post",
        cascade="all, delete-orphan",
        back_populates="author",
    )

    def __repr__(self):
        return "<User(id={}, username='{}')>".format(self.id, self.username)

    def to_dict(self):
        """Return the id, username, email and age of this user as a dict."""
        exported = ("id", "username", "email", "age")
        return {field: getattr(self, field) for field in exported}

class Post(Base):
    """Article row in the ``posts`` table, written by one user and tagged."""

    __tablename__ = "posts"

    id = Column(Integer, autoincrement=True, primary_key=True)
    title = Column(String(200), nullable=False)
    content = Column(Text)
    author_id = Column(Integer, ForeignKey("users.id"), nullable=False)
    created_at = Column(DateTime, server_default=func.now())
    # Only ``onupdate`` is configured: the value is set on UPDATE statements
    # and no insert-time default is declared for this column.
    updated_at = Column(DateTime, onupdate=func.now())

    # Many-to-one back to the owning user.
    author = relationship("User", back_populates="posts")
    # Many-to-many with Tag through the ``post_tags`` association table.
    tags = relationship(
        "Tag",
        secondary="post_tags",
        back_populates="posts",
    )

    def __repr__(self):
        return "<Post(id={}, title='{}')>".format(self.id, self.title)

class Tag(Base):
    """Label row in the ``tags`` table that can be attached to many posts."""

    __tablename__ = "tags"

    id = Column(Integer, autoincrement=True, primary_key=True)
    name = Column(String(50), nullable=False, unique=True)

    # Many-to-many with Post through the ``post_tags`` association table.
    posts = relationship(
        "Post",
        secondary="post_tags",
        back_populates="tags",
    )

# 关联表
from sqlalchemy import Table
# Association table for the Post <-> Tag many-to-many relationship. Both
# columns are integer foreign keys and together form the composite primary key.
post_tags = Table(
    "post_tags",
    Base.metadata,
    *(
        Column(column_name, Integer, ForeignKey(target), primary_key=True)
        for column_name, target in (
            ("post_id", "posts.id"),
            ("tag_id", "tags.id"),
        )
    ),
)

创建表与Session操作

from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, Session

# Engine for the blog database file; echo=False disables SQL statement logging.
engine = create_engine("sqlite:///blog.db", echo=False)
# Session factory bound to that engine.
SessionLocal = sessionmaker(bind=engine)

# Create every table registered on Base.metadata.
Base.metadata.create_all(engine)

# 创建Session的函数
def get_db():
    """Yield one session from ``SessionLocal`` and close it afterwards.

    This is a generator: the session is closed once the generator is
    resumed or finalised after the yield, whether or not the consumer raised.
    """
    session = SessionLocal()
    try:
        yield session
    finally:
        session.close()

# 基本的CRUD操作
def create_user(db: Session, username: str, email: str, age: int = None):
    """Insert a new user, commit, and return the refreshed instance.

    Args:
        db: Session used for the insert; it is committed by this function.
        username: Value for ``User.username``.
        email: Value for ``User.email``.
        age: Optional value for ``User.age``.

    Returns:
        The persisted ``User``.
    """
    new_user = User(username=username, email=email, age=age)
    db.add(new_user)
    db.commit()
    # Reload the row so database-generated values are populated on the object.
    db.refresh(new_user)
    return new_user

def get_user(db: Session, user_id: int):
    """Return the first user whose id equals ``user_id`` (``.first()`` result)."""
    matching = db.query(User).filter(User.id == user_id)
    return matching.first()

def get_user_by_username(db: Session, username: str):
    """Return the first user whose username equals ``username``."""
    matching = db.query(User).filter(User.username == username)
    return matching.first()

def get_users(db: Session, skip: int = 0, limit: int = 100):
    """Return a list of users, skipping ``skip`` rows and taking at most ``limit``."""
    window = db.query(User).offset(skip).limit(limit)
    return window.all()

def update_user(db: Session, user_id: int, **kwargs):
    """Set the given attributes on a user and commit.

    Args:
        db: Session used for the lookup and the commit.
        user_id: Primary key of the user to change.
        **kwargs: Attribute names and new values, applied with ``setattr``
            without any validation of the names.

    Returns:
        The refreshed user, or the falsy lookup result when no row matched.
    """
    target = db.query(User).filter(User.id == user_id).first()
    if not target:
        return target

    for field, new_value in kwargs.items():
        setattr(target, field, new_value)
    db.commit()
    db.refresh(target)
    return target

def delete_user(db: Session, user_id: int):
    """Delete the user with ``user_id`` and commit.

    Returns:
        True when a user was found and deleted, False when no row matched.
    """
    target = db.query(User).filter(User.id == user_id).first()
    if not target:
        return False

    db.delete(target)
    db.commit()
    return True

查询构建

# Basic queries
users = db.query(User).all()   # every user
user = db.query(User).first()  # first user
# Primary-key lookup. Session.get() replaces the legacy Query.get().
user = db.get(User, 1)

# Filtering: several criteria passed to filter() are combined with AND
young_users = db.query(User).filter(User.age < 25).all()
specific_user = db.query(User).filter(
    User.username == "张三",
    User.email.contains("@")
).first()

# Explicit boolean combinators
from sqlalchemy import or_, and_, not_

results = db.query(User).filter(
    or_(
        User.age < 20,
        User.age > 60
    )
).all()

# Ordering
users = db.query(User).order_by(User.age.desc()).all()
users = db.query(User).order_by(User.username.asc(), User.age.desc()).all()

# Pagination: ``page`` is 1-based
page = 1
per_page = 10
users = db.query(User).offset((page - 1) * per_page).limit(per_page).all()

# Counting
total = db.query(User).filter(User.age >= 18).count()

# Existence check. filter() is a method of the query, not of the model class,
# so the inner query is built from db.query(User) and wrapped in EXISTS.
exists = db.query(
    db.query(User).filter(User.username == "张三").exists()
).scalar()

# Selecting specific columns
results = db.query(User.username, User.email).all()

# Distinct values
distinct_ages = db.query(User.age).distinct().all()

关系查询

# One-to-many: reach a user's posts through the relationship attribute.
# NOTE(review): .first() may return None when no such user exists; the
# attribute access below assumes a row was found.
user = db.query(User).filter(User.username == "张三").first()
print(f"用户的文章: {user.posts}")

# Eager-load relationships (reduces the N+1 query problem)
from sqlalchemy.orm import joinedload, selectinload

# joinedload - loads the related rows with a JOIN
users = db.query(User).options(joinedload(User.posts)).all()

# selectinload - loads the related rows with a second SELECT that uses an
# IN clause over the parent keys (not a subquery)
users = db.query(User).options(selectinload(User.posts)).all()

# Many-to-many queries
post = db.query(Post).filter(Post.id == 1).first()
print(f"文章标签: {post.tags}")

tag = db.query(Tag).filter(Tag.name == "Python").first()
print(f"标签下的文章: {tag.posts}")

# Working with relationships: create a post owned by an existing user
user = db.query(User).filter(User.id == 1).first()
post = Post(title="新文章", content="内容", author=user)
db.add(post)
db.commit()

# Attach a tag to the post created above
# NOTE(review): assumes a tag named "Python" already exists; otherwise
# ``tag`` is None here.
tag = db.query(Tag).filter(Tag.name == "Python").first()
post.tags.append(tag)
db.commit()

聚合与分组

from sqlalchemy import func

# Aggregate query: one result row with four labelled values
stats = db.query(
    func.count(User.id).label("total_users"),
    func.avg(User.age).label("avg_age"),
    func.max(User.age).label("max_age"),
    func.min(User.age).label("min_age")
).first()

print(f"总用户数: {stats.total_users}")
# AVG() yields NULL (None) when there are no users or every age is NULL;
# formatting None with ":.1f" would raise TypeError, so handle it explicitly.
if stats.avg_age is None:
    print("平均年龄: 暂无数据")
else:
    print(f"平均年龄: {stats.avg_age:.1f}")

# Grouped query: number of users per age value
age_groups = db.query(
    User.age,
    func.count(User.id).label("count")
).group_by(User.age).all()

# Each result row unpacks into (age, count)
for age, count in age_groups:
    print(f"{age}岁: {count}人")

# HAVING clause: keep only the ages shared by more than 5 users
popular_ages = db.query(
    User.age,
    func.count(User.id).label("count")
).group_by(User.age).having(func.count(User.id) > 5).all()

# 子查询
from sqlalchemy import select

# Per-author post counts, wrapped as a subquery so it can be joined against.
subquery = db.query(
    Post.author_id,
    func.count(Post.id).label("post_count")
).group_by(Post.author_id).subquery()

# Join users to their post counts. This is an inner join, so users without
# any post do not appear in the result.
results = db.query(
    User.username,
    subquery.c.post_count
).join(subquery, User.id == subquery.c.author_id).all()

事务管理

# 使用Session进行事务管理
def transfer_posts(from_user_id: int, to_user_id: int):
    """Reassign every post of one user to another inside a single transaction.

    A dedicated session is opened from ``SessionLocal`` and always closed.

    Args:
        from_user_id: Primary key of the user whose posts are moved.
        to_user_id: Primary key of the user who receives the posts.

    Returns:
        True when the change was committed. False when either user does not
        exist or any other error occurred; the transaction is then rolled
        back and the error is printed rather than raised.
    """
    db = SessionLocal()
    try:
        # Session.get() is the supported primary-key lookup; Query.get() is
        # a legacy API.
        from_user = db.get(User, from_user_id)
        to_user = db.get(User, to_user_id)

        if from_user is None or to_user is None:
            raise ValueError("用户不存在")

        # Point each post at the new author
        for post in from_user.posts:
            post.author_id = to_user_id

        db.commit()
        return True

    except Exception as e:
        # Best-effort contract: undo partial work and report failure as False.
        db.rollback()
        print(f"事务失败: {e}")
        return False
    finally:
        db.close()

# 使用上下文管理器
from contextlib import contextmanager

@contextmanager
def get_db_session():
    """Provide a session that commits on success and rolls back on error.

    The session is committed when the ``with`` body finishes without raising.
    If the body or the commit raises an ``Exception``, the session is rolled
    back and the exception is re-raised. The session is closed in every case.
    """
    session = SessionLocal()
    try:
        yield session
        # Reached only when the with-body completed normally.
        session.commit()
    except Exception:
        session.rollback()
        raise
    finally:
        session.close()

# Usage example
with get_db_session() as db:
    user = User(username="新用户", email="new@example.com")
    db.add(user)
    # Committed automatically when the block exits without an exception

实战示例:博客系统

from sqlalchemy import create_engine, Column, Integer, String, Text, DateTime, ForeignKey
from sqlalchemy.orm import sessionmaker, relationship, declarative_base
from sqlalchemy.sql import func
from datetime import datetime

# NOTE(review): this rebinds ``Base`` to a brand-new declarative base. The
# User/Post/Tag models earlier in this file were declared on the previous
# Base, so they are not registered on this one's metadata - confirm whether
# a fresh Base is really intended here.
Base = declarative_base()

class BlogDB:
    """Small facade over the blog models (User, Post, Tag).

    Every public method opens its own short-lived session, does its work,
    and returns plain dicts/lists so callers never hold ORM objects that
    belong to a closed session.
    """

    def __init__(self, db_url="sqlite:///blog.db"):
        """Create the engine and session factory and make sure the tables exist.

        Args:
            db_url: SQLAlchemy database URL; defaults to a local SQLite file.
        """
        self.engine = create_engine(db_url, echo=False)
        self.Session = sessionmaker(bind=self.engine)
        # Create tables from the metadata the models are registered on,
        # not from whatever ``Base`` currently names. A freshly created
        # declarative base has empty metadata, and create_all() on it
        # would create no tables.
        User.metadata.create_all(self.engine)

    def get_session(self):
        """Return a new session from this instance's factory."""
        return self.Session()

    def create_user(self, username, email):
        """Insert a user and return it as a dict (see ``User.to_dict``)."""
        with self.get_session() as db:
            user = User(username=username, email=email)
            db.add(user)
            db.commit()
            # Serialise while the session is still open: after commit the
            # attributes are reloaded through the session on access.
            return user.to_dict()

    def create_post(self, title, content, author_id, tag_names=None):
        """Insert a post, attaching (and creating if needed) the named tags.

        Args:
            title: Post title.
            content: Post body.
            author_id: Primary key of the authoring user.
            tag_names: Optional iterable of tag names; repeated names are
                attached only once.

        Returns:
            A dict with the new post's ``id`` and ``title``.
        """
        with self.get_session() as db:
            post = Post(title=title, content=content, author_id=author_id)

            # dict.fromkeys() drops repeated names while keeping their order,
            # so a repeated name does not link the same tag to the post twice.
            for tag_name in dict.fromkeys(tag_names or ()):
                tag = db.query(Tag).filter(Tag.name == tag_name).first()
                if tag is None:
                    tag = Tag(name=tag_name)
                    db.add(tag)
                post.tags.append(tag)

            db.add(post)
            db.commit()
            return {"id": post.id, "title": post.title}

    def get_user_posts(self, user_id):
        """Return the posts of a user as dicts, or ``[]`` for an unknown user."""
        with self.get_session() as db:
            # Session.get() is the supported primary-key lookup; Query.get()
            # is a legacy API.
            user = db.get(User, user_id)
            if user is None:
                return []
            return [
                {"id": p.id, "title": p.title, "created_at": str(p.created_at)}
                for p in user.posts
            ]

    def search_posts(self, keyword):
        """Return posts whose title or content contains ``keyword``.

        The keyword is matched literally: ``autoescape=True`` escapes the
        LIKE wildcards ``%`` and ``_`` so they are not treated as patterns.
        """
        with self.get_session() as db:
            posts = db.query(Post).filter(
                Post.title.contains(keyword, autoescape=True)
                | Post.content.contains(keyword, autoescape=True)
            ).all()
            return [{"id": p.id, "title": p.title} for p in posts]

# Usage example
# NOTE(review): username and email are declared unique on User, so running
# this snippet a second time against the same database file presumably
# fails on the duplicate user - confirm before reusing it as a script.
blog = BlogDB()
blog.create_user("张三", "zhangsan@example.com")
# The literal 1 is the author id; it assumes the user created above received id 1.
blog.create_post("Python入门", "Python是一门优雅的...", 1, ["Python", "教程"])
posts = blog.get_user_posts(1)

总结

SQLAlchemy是Python数据库开发的瑞士军刀。掌握Model定义、Session管理、关系映射和查询构建后,你可以高效地操作任何关系型数据库。实际项目中,推荐结合FastAPI或Flask使用,并考虑使用Alembic进行数据库迁移。