SQLAlchemy入门:ORM、Model、Session与查询
SQLAlchemy是Python生态中功能最全面、使用最广泛的数据库工具包之一,其ORM(对象关系映射)层将数据库表映射为Python类,让开发者用面向对象的方式操作数据库。本文将带你掌握SQLAlchemy ORM的核心概念和使用方法。
安装SQLAlchemy
pip install sqlalchemy
# SQLite: built in, nothing extra to install
# MySQL: additionally install the PyMySQL driver
# pip install pymysql
# PostgreSQL: additionally install the psycopg2 driver (prebuilt binary package)
# pip install psycopg2-binary
核心概念
使用SQLAlchemy ORM时,需要掌握以下三个核心概念:
- Engine:数据库连接引擎
- Session:数据库会话,管理事务
- Model:映射到数据库表的Python类
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, declarative_base

# Create the engine.
# NOTE: pool_size is intentionally not passed for this SQLite file URL.
# SQLAlchemy 1.4 uses NullPool for file-based SQLite and rejects pool_size
# with a TypeError; SQLAlchemy 2.0 uses QueuePool, whose default pool_size
# is already 5. Pass pool_size / max_overflow when targeting a server
# database such as MySQL or PostgreSQL.
engine = create_engine(
    "sqlite:///mydb.db",
    echo=True,  # log every SQL statement that is emitted
    pool_pre_ping=True,  # check a pooled connection is alive before using it
)

# Session factory: each SessionLocal() call returns a new Session bound to engine.
SessionLocal = sessionmaker(bind=engine)

# Declarative base class that ORM model classes inherit from.
Base = declarative_base()
定义Model
from sqlalchemy import Column, DateTime, Float, ForeignKey, Integer, String, Table, Text
from sqlalchemy.orm import declarative_base, relationship
from sqlalchemy.sql import func

Base = declarative_base()

# Association table for the many-to-many link between posts and tags.
# The composite primary key (post_id, tag_id) forbids duplicate pairs.
post_tags = Table(
    "post_tags",
    Base.metadata,
    Column("post_id", Integer, ForeignKey("posts.id"), primary_key=True),
    Column("tag_id", Integer, ForeignKey("tags.id"), primary_key=True),
)


class User(Base):
    """A registered user who owns zero or more posts."""

    __tablename__ = "users"

    id = Column(Integer, primary_key=True, autoincrement=True)
    username = Column(String(50), unique=True, nullable=False, index=True)
    email = Column(String(100), unique=True, nullable=False)
    age = Column(Integer)
    created_at = Column(DateTime, server_default=func.now())  # filled in by the database

    # One-to-many. cascade "all, delete-orphan": deleting the user deletes its
    # posts, and a post removed from this collection is deleted as well.
    posts = relationship("Post", back_populates="author", cascade="all, delete-orphan")

    def __repr__(self):
        return f"<User(id={self.id}, username='{self.username}')>"

    def to_dict(self):
        """Return id, username, email and age as a plain dict."""
        return {name: getattr(self, name) for name in ("id", "username", "email", "age")}


class Post(Base):
    """A blog post with exactly one author and any number of tags."""

    __tablename__ = "posts"

    id = Column(Integer, primary_key=True, autoincrement=True)
    title = Column(String(200), nullable=False)
    content = Column(Text)
    author_id = Column(Integer, ForeignKey("users.id"), nullable=False)
    created_at = Column(DateTime, server_default=func.now())
    updated_at = Column(DateTime, onupdate=func.now())  # NULL until the first UPDATE

    author = relationship("User", back_populates="posts")
    tags = relationship("Tag", secondary=post_tags, back_populates="posts")

    def __repr__(self):
        return f"<Post(id={self.id}, title='{self.title}')>"


class Tag(Base):
    """A label that can be attached to many posts."""

    __tablename__ = "tags"

    id = Column(Integer, primary_key=True, autoincrement=True)
    name = Column(String(50), unique=True, nullable=False)

    # Many-to-many, mirrored by Post.tags through the post_tags table.
    posts = relationship("Post", secondary=post_tags, back_populates="tags")
创建表与Session操作
from typing import Optional

from sqlalchemy import create_engine
from sqlalchemy import inspect as sa_inspect
from sqlalchemy.orm import sessionmaker, Session

engine = create_engine("sqlite:///blog.db", echo=False)
SessionLocal = sessionmaker(bind=engine)

# Create every table registered on Base.metadata (existing tables are skipped).
Base.metadata.create_all(engine)


def get_db():
    """Yield a new Session and always close it afterwards.

    Written as a generator so it can be used as a request-scoped dependency
    (e.g. with a web framework's dependency-injection mechanism).
    """
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()


# Basic CRUD operations
def create_user(db: Session, username: str, email: str, age: Optional[int] = None):
    """Insert a user, commit, and return it with database-generated values loaded."""
    user = User(username=username, email=email, age=age)
    db.add(user)
    db.commit()
    db.refresh(user)  # reload id / created_at produced by the database
    return user


def get_user(db: Session, user_id: int):
    """Return the user whose primary key is ``user_id``, or None."""
    return db.query(User).filter(User.id == user_id).first()


def get_user_by_username(db: Session, username: str):
    """Return the user with the given username, or None."""
    return db.query(User).filter(User.username == username).first()


def get_users(db: Session, skip: int = 0, limit: int = 100):
    """Return up to ``limit`` users after skipping the first ``skip``."""
    return db.query(User).offset(skip).limit(limit).all()


def update_user(db: Session, user_id: int, **kwargs):
    """Set the given column values on a user, commit, and return it.

    Returns None when no user has ``user_id``.

    Raises:
        ValueError: if a keyword is not an updatable (non-primary-key) column
            of User. Without this check a typo such as ``emial=...`` would be
            stored as a plain Python attribute and silently never persisted,
            and ``id=...`` would rewrite the primary key.
    """
    mapper = sa_inspect(User)
    primary_keys = {mapper.get_property_by_column(col).key for col in mapper.primary_key}
    updatable = {attr.key for attr in mapper.column_attrs} - primary_keys
    unknown = set(kwargs) - updatable
    if unknown:
        raise ValueError(f"不可更新的字段: {', '.join(sorted(unknown))}")

    user = db.query(User).filter(User.id == user_id).first()
    if user:
        for key, value in kwargs.items():
            setattr(user, key, value)
        db.commit()
        db.refresh(user)
    return user


def delete_user(db: Session, user_id: int):
    """Delete the user (and, via cascade, its posts). Return True if one was deleted."""
    user = db.query(User).filter(User.id == user_id).first()
    if user:
        db.delete(user)
        db.commit()
        return True
    return False
查询构建
# Basic queries
users = db.query(User).all()  # every user
user = db.query(User).first()  # one user (arbitrary unless ordered) or None
user = db.get(User, 1)  # primary-key lookup; Session.get replaces the legacy Query.get

# Filtering: multiple criteria passed to filter() are combined with AND
young_users = db.query(User).filter(User.age < 25).all()
specific_user = db.query(User).filter(
    User.username == "张三",
    User.email.contains("@"),
).first()

# Combining criteria explicitly with or_ / and_ / not_
from sqlalchemy import or_, and_, not_

results = db.query(User).filter(
    or_(
        User.age < 20,
        User.age > 60,
    )
).all()

# Ordering
users = db.query(User).order_by(User.age.desc()).all()
users = db.query(User).order_by(User.username.asc(), User.age.desc()).all()

# Pagination (pages are numbered from 1)
page = 1
per_page = 10
users = db.query(User).offset((page - 1) * per_page).limit(per_page).all()

# Counting
total = db.query(User).filter(User.age >= 18).count()

# Existence check: build an EXISTS clause from a Query, wrap it in an outer
# query and read the boolean. The inner Query must come from db.query(User);
# the mapped class itself has no .filter() method.
exists = db.query(
    db.query(User).filter(User.username == "张三").exists()
).scalar()

# Selecting specific columns returns row tuples instead of User objects
results = db.query(User.username, User.email).all()

# Distinct values
distinct_ages = db.query(User.age).distinct().all()
关系查询
from sqlalchemy.orm import joinedload, selectinload

# One-to-many: accessing user.posts lazily loads that user's posts.
user = db.query(User).filter(User.username == "张三").first()
if user is not None:
    print(f"用户的文章: {user.posts}")

# Eager loading avoids the N+1 query problem.
# joinedload: loads the posts in the same statement via a LEFT OUTER JOIN.
users = db.query(User).options(joinedload(User.posts)).all()
# selectinload: emits one extra "SELECT ... WHERE posts.author_id IN (...)"
# for all loaded users (this is not a subquery; that is subqueryload).
users = db.query(User).options(selectinload(User.posts)).all()

# Many-to-many
post = db.query(Post).filter(Post.id == 1).first()
if post is not None:
    print(f"文章标签: {post.tags}")
tag = db.query(Tag).filter(Tag.name == "Python").first()
if tag is not None:
    print(f"标签下的文章: {tag.posts}")

# Working with relationships: passing author=user fills author_id on flush.
# Guarded because author_id is NOT NULL and a missing user would violate it.
user = db.query(User).filter(User.id == 1).first()
if user is not None:
    post = Post(title="新文章", content="内容", author=user)
    db.add(post)
    db.commit()

    # Attach a tag, creating it first if it does not exist yet; appending None
    # to the collection would raise.
    tag = db.query(Tag).filter(Tag.name == "Python").first()
    if tag is None:
        tag = Tag(name="Python")
    post.tags.append(tag)
    db.commit()
聚合与分组
from sqlalchemy import func

# Aggregates over the whole users table (a single result row).
stats = db.query(
    func.count(User.id).label("total_users"),
    func.avg(User.age).label("avg_age"),
    func.max(User.age).label("max_age"),
    func.min(User.age).label("min_age"),
).first()
print(f"总用户数: {stats.total_users}")
# AVG() returns NULL (None) when no row has an age, and None cannot be
# formatted with :.1f, so guard before formatting.
if stats.avg_age is not None:
    print(f"平均年龄: {stats.avg_age:.1f}")
else:
    print("平均年龄: 无数据")

# GROUP BY: number of users per age (age is nullable, so one group may be None).
age_groups = db.query(
    User.age,
    func.count(User.id).label("count"),
).group_by(User.age).all()
for age, count in age_groups:
    age_label = "未知年龄" if age is None else f"{age}岁"
    print(f"{age_label}: {count}人")

# HAVING: keep only the groups with more than 5 users.
popular_ages = db.query(
    User.age,
    func.count(User.id).label("count"),
).group_by(User.age).having(func.count(User.id) > 5).all()

# Subquery: post count per author, joined back onto users.
from sqlalchemy import select

subquery = db.query(
    Post.author_id,
    func.count(Post.id).label("post_count"),
).group_by(Post.author_id).subquery()

# join() is an inner join here, so users without any posts are not returned.
results = db.query(
    User.username,
    subquery.c.post_count,
).join(subquery, User.id == subquery.c.author_id).all()
事务管理
# Explicit transaction management with a Session.
def transfer_posts(from_user_id: int, to_user_id: int):
    """Reassign every post of one user to another user in a single transaction.

    Returns True after a successful commit. On any error the transaction is
    rolled back, the error is printed, and False is returned. The Session is
    always closed.
    """
    db = SessionLocal()
    try:
        # Session.get (SQLAlchemy 1.4+) replaces the legacy Query.get.
        from_user = db.get(User, from_user_id)
        to_user = db.get(User, to_user_id)
        if not from_user or not to_user:
            raise ValueError("用户不存在")

        # Assigning the foreign-key column does not mutate the from_user.posts
        # collection, so iterating over it while reassigning is safe.
        for post in from_user.posts:
            post.author_id = to_user_id
        db.commit()
        return True
    except Exception as e:
        db.rollback()
        print(f"事务失败: {e}")
        return False
    finally:
        db.close()


# Using a context manager
from contextlib import contextmanager


@contextmanager
def get_db_session():
    """Yield a Session wrapped in a transaction scope.

    Commits when the with-block exits normally, rolls back and re-raises on
    any exception, and always closes the Session.
    """
    db = SessionLocal()
    try:
        yield db
        db.commit()
    except Exception:
        db.rollback()
        raise
    finally:
        db.close()


# Usage: the new user is committed automatically when the block exits.
with get_db_session() as db:
    user = User(username="新用户", email="new@example.com")
    db.add(user)
实战示例:博客系统
from sqlalchemy import create_engine, Column, Integer, String, Text, DateTime, ForeignKey
from sqlalchemy.orm import sessionmaker, relationship, declarative_base
from sqlalchemy.sql import func
from datetime import datetime

# NOTE: do not create a new ``Base = declarative_base()`` here. User, Post and
# Tag are registered on the Base they were declared with; rebinding Base to a
# fresh, empty declarative base would make create_all() create no tables.


class BlogDB:
    """Facade over the blog models that opens a short-lived Session per call.

    The data methods return plain dicts/lists, so callers never hold ORM
    objects after the Session that loaded them has been closed.
    """

    def __init__(self, db_url="sqlite:///blog.db"):
        self.engine = create_engine(db_url, echo=False)
        self.Session = sessionmaker(bind=self.engine)
        # User.metadata is the MetaData of the Base the models were declared on.
        User.metadata.create_all(self.engine)

    def get_session(self):
        """Return a new Session; use it as a context manager so it is closed."""
        return self.Session()

    def create_user(self, username, email):
        """Insert a user and return it as a dict via User.to_dict()."""
        with self.get_session() as db:
            user = User(username=username, email=email)
            db.add(user)
            db.commit()
            # commit() expires the instance; to_dict() reloads it while the
            # Session is still open.
            return user.to_dict()

    def create_post(self, title, content, author_id, tag_names=None):
        """Insert a post, reusing existing tags by name and creating missing ones.

        Returns {"id": ..., "title": ...} for the new post.
        """
        with self.get_session() as db:
            post = Post(title=title, content=content, author_id=author_id)
            # dict.fromkeys drops repeated names while keeping first-seen order.
            # Attaching the same tag twice would try to insert a duplicate
            # (post_id, tag_id) row, which the post_tags primary key rejects.
            for tag_name in dict.fromkeys(tag_names or ()):
                tag = db.query(Tag).filter(Tag.name == tag_name).first()
                if not tag:
                    tag = Tag(name=tag_name)
                    db.add(tag)
                post.tags.append(tag)
            db.add(post)
            db.commit()
            return {"id": post.id, "title": post.title}

    def get_user_posts(self, user_id):
        """Return [{"id", "title", "created_at"}, ...] for a user, or [] if absent."""
        with self.get_session() as db:
            # Session.get replaces the legacy Query.get.
            user = db.get(User, user_id)
            if user is None:
                return []
            return [
                {"id": p.id, "title": p.title, "created_at": str(p.created_at)}
                for p in user.posts
            ]

    def search_posts(self, keyword):
        """Return posts whose title or content contains ``keyword`` literally."""
        with self.get_session() as db:
            # autoescape=True escapes "%" and "_" in keyword so they match
            # literally instead of acting as LIKE wildcards.
            matches = Post.title.contains(keyword, autoescape=True) | Post.content.contains(
                keyword, autoescape=True
            )
            posts = db.query(Post).filter(matches).all()
            return [{"id": p.id, "title": p.title} for p in posts]
# Usage example
blog = BlogDB()
author = blog.create_user("张三", "zhangsan@example.com")
# Use the id the database actually assigned; blog.db may already contain
# users, so it is not necessarily 1.
blog.create_post("Python入门", "Python是一门优雅的...", author["id"], ["Python", "教程"])
posts = blog.get_user_posts(author["id"])
总结
SQLAlchemy是Python数据库开发的瑞士军刀。掌握Model定义、Session管理、关系映射和查询构建后,你可以高效地操作SQLite、MySQL、PostgreSQL等主流关系型数据库。实际项目中,推荐结合FastAPI或Flask使用,并使用Alembic管理数据库迁移——Base.metadata.create_all()只会创建缺失的表,不会修改已有的表结构。