CLIP模型
---
title: "CLIP模型"
description: "深入理解CLIP模型的对比学习原理、图像文本匹配机制以及零样本分类能力"
tags: ["CLIP", "对比学习", "图像文本匹配", "零样本分类"]
category: "llm"
icon: "🧠"
---
CLIP模型
CLIP简介
CLIP(Contrastive Language-Image Pre-training)是OpenAI在2021年发布的多模态模型,通过在4亿个图像-文本对上进行对比学习预训练,实现了强大的视觉理解能力。CLIP的核心创新在于将图像和文本映射到同一个嵌入空间,从而实现跨模态的理解。
核心原理
对比学习框架
CLIP的核心是对比学习:给定一批图像-文本对,模型学习将匹配的图像和文本嵌入拉近,同时将不匹配的嵌入推远。
import torch
import torch.nn as nn
from transformers import CLIPModel, CLIPTokenizer, CLIPImageProcessor
class CLIPTrainer:
    """Compute the symmetric image-text contrastive loss with a pretrained CLIP.

    The trainer owns a ``CLIPModel`` together with its tokenizer and image
    processor, plus a scalar temperature that divides the similarity logits.
    """

    def __init__(self, model_name="openai/clip-vit-base-patch32"):
        """Load the model, tokenizer and image processor for ``model_name``.

        Args:
            model_name: Identifier passed to every ``from_pretrained`` call.
        """
        self.model = CLIPModel.from_pretrained(model_name)
        self.tokenizer = CLIPTokenizer.from_pretrained(model_name)
        self.image_processor = CLIPImageProcessor.from_pretrained(model_name)
        # NOTE: this class is not an ``nn.Module``, so the temperature is not
        # part of ``self.model.parameters()``; pass it to the optimizer
        # explicitly if it should be trained.
        self.temperature = nn.Parameter(torch.ones([]) * 0.07)

    def get_image_features(self, images):
        """Preprocess ``images`` and return the model's image features."""
        processed = self.image_processor(images=images, return_tensors="pt")
        return self.model.get_image_features(**processed)

    def get_text_features(self, texts):
        """Tokenize ``texts`` (padded and truncated) and return text features."""
        tokens = self.tokenizer(
            texts, return_tensors="pt", padding=True, truncation=True
        )
        return self.model.get_text_features(**tokens)

    def compute_loss(self, images, texts):
        """Return the bidirectional contrastive loss for paired inputs.

        The i-th image and the i-th text form the positive pair; every other
        combination in the batch acts as a negative.

        Args:
            images: Input accepted by the image processor.
            texts: Input accepted by the tokenizer.

        Returns:
            Scalar tensor: mean of the image-to-text and text-to-image
            cross-entropy losses.

        Raises:
            ValueError: If the number of image features differs from the
                number of text features.
        """
        image_features = nn.functional.normalize(
            self.get_image_features(images), dim=-1
        )
        text_features = nn.functional.normalize(
            self.get_text_features(texts), dim=-1
        )

        # Take the batch size from the features rather than ``len(images)``:
        # a single image object may not define ``len`` at all.
        num_images = image_features.size(0)
        num_texts = text_features.size(0)
        if num_images != num_texts:
            raise ValueError(
                f"images and texts must be paired one-to-one, "
                f"got {num_images} images and {num_texts} texts"
            )

        # Cosine-similarity matrix scaled by the temperature.
        logits = torch.mm(image_features, text_features.t()) / self.temperature

        # The positives sit on the diagonal.
        labels = torch.arange(num_images, device=logits.device)

        loss_i2t = nn.functional.cross_entropy(logits, labels)
        loss_t2i = nn.functional.cross_entropy(logits.t(), labels)
        return (loss_i2t + loss_t2i) / 2
零样本分类
CLIP最强大的能力之一是零样本分类:无需任何训练样本即可对图像进行分类。
import torch
from PIL import Image
from transformers import CLIPModel, CLIPTokenizer, CLIPProcessor
class ZeroShotClassifier:
    """Zero-shot image classifier built on a pretrained CLIP model.

    Candidate labels are turned into text prompts and scored against the
    image; a softmax over the labels turns the scores into probabilities.
    """

    # Prompt templates used by ``classify_with_template`` when none are given.
    _DEFAULT_TEMPLATES = (
        "a photo of a {}",
        "a blurry photo of a {}",
        "a photo of many {}",
        "a sculpture of a {}",
        "a photo of the hard to see {}",
        "a low resolution photo of the {}",
        "a rendering of a {}",
        "graffiti of a {}",
        "a toy {}",
        "itap of a {}",
    )

    def __init__(self, model_name="openai/clip-vit-base-patch32"):
        """Load the CLIP model and processor for ``model_name``."""
        self.model = CLIPModel.from_pretrained(model_name)
        self.processor = CLIPProcessor.from_pretrained(model_name)

    def classify(self, image, candidate_labels):
        """Classify ``image`` against ``candidate_labels``.

        Args:
            image: A PIL image.
            candidate_labels: Iterable of candidate class names.

        Returns:
            Dict mapping each label to its softmax probability, in input
            order. Empty when no labels are given.
        """
        labels = list(candidate_labels)
        if not labels:
            return {}

        text_prompts = [f"a photo of a {label}" for label in labels]
        inputs = self.processor(
            text=text_prompts,
            images=image,
            return_tensors="pt",
            padding=True,
        )

        with torch.no_grad():
            outputs = self.model(**inputs)
        probs = outputs.logits_per_image.softmax(dim=-1)

        return {label: prob.item() for label, prob in zip(labels, probs[0])}

    def classify_with_template(self, image, labels, templates=None):
        """Classify ``image`` using several prompt templates per label.

        Each label is scored with every template and keeps its highest score;
        the per-label maxima are then softmax-normalized.

        Args:
            image: A single PIL image.
            labels: Iterable of candidate class names.
            templates: Format strings with one ``{}`` placeholder. Defaults
                to a built-in list of ten templates.

        Returns:
            Dict mapping each label to its probability, in input order.
            Empty when no labels are given.

        Raises:
            ValueError: If ``templates`` is empty while labels are present.
        """
        if templates is None:
            templates = self._DEFAULT_TEMPLATES
        labels = list(labels)
        templates = list(templates)

        if not labels:
            return {}
        if not templates:
            raise ValueError("templates must contain at least one template")

        # Score all prompts in one forward pass so the image is encoded once
        # instead of once per (label, template) pair.
        prompts = [
            template.format(label) for label in labels for template in templates
        ]
        inputs = self.processor(
            text=prompts,
            images=image,
            return_tensors="pt",
            padding=True,
        )
        with torch.no_grad():
            outputs = self.model(**inputs)

        # Prompts are label-major, so each row holds one label's scores.
        per_label = outputs.logits_per_image.reshape(len(labels), len(templates))
        best_scores = per_label.max(dim=1).values.float()
        probs = best_scores.softmax(dim=-1)
        return {label: prob.item() for label, prob in zip(labels, probs)}

    def _get_similarity_score(self, image, text):
        """Return the image-text logit for one image and one text.

        Kept for callers that score a single pair at a time.
        """
        inputs = self.processor(text=[text], images=image, return_tensors="pt")
        with torch.no_grad():
            outputs = self.model(**inputs)
        return outputs.logits_per_image.item()
图像-文本检索
class CLIPSearchEngine:
    """In-memory cross-modal search over CLIP embeddings.

    Images and texts are indexed as L2-normalized feature tensors keyed by a
    caller-supplied id. A query from one modality is ranked against the cache
    of the other by dot product.
    """

    def __init__(self, model_name="openai/clip-vit-base-patch32"):
        """Load the CLIP model and processor and start with empty indexes."""
        self.model = CLIPModel.from_pretrained(model_name)
        self.processor = CLIPProcessor.from_pretrained(model_name)
        self.image_features_cache = {}
        self.text_features_cache = {}

    def _encode_image(self, image):
        """Return the L2-normalized image features for ``image``."""
        inputs = self.processor(images=image, return_tensors="pt")
        with torch.no_grad():
            features = self.model.get_image_features(**inputs)
        return torch.nn.functional.normalize(features, dim=-1)

    def _encode_text(self, text):
        """Return the L2-normalized text features for one string."""
        # Ask the tokenizer to truncate over-long text, matching the
        # tokenizer call in CLIPTrainer.get_text_features.
        inputs = self.processor(
            text=[text], return_tensors="pt", padding=True, truncation=True
        )
        with torch.no_grad():
            features = self.model.get_text_features(**inputs)
        return torch.nn.functional.normalize(features, dim=-1)

    @staticmethod
    def _rank(query_features, cache, top_k):
        """Rank cached entries by similarity to ``query_features``.

        Returns ``(id, score)`` tuples sorted by descending score and sliced
        to ``top_k``; ties keep insertion order. An empty cache yields ``[]``.
        """
        if not cache:
            return []
        ids = list(cache)
        # One matrix-vector product instead of one matmul per cached entry.
        matrix = torch.stack([cache[key].reshape(-1) for key in ids])
        scores = torch.mv(matrix, query_features.reshape(-1)).tolist()
        ranked = sorted(zip(ids, scores), key=lambda pair: pair[1], reverse=True)
        return ranked[:top_k]

    def index_image(self, image_id, image):
        """Encode ``image`` and store its features under ``image_id``."""
        self.image_features_cache[image_id] = self._encode_image(image)

    def index_text(self, text_id, text):
        """Encode ``text`` and store its features under ``text_id``."""
        self.text_features_cache[text_id] = self._encode_text(text)

    def search_by_text(self, query_text, top_k=5):
        """Return the ``top_k`` indexed images most similar to ``query_text``.

        Returns:
            List of ``(image_id, similarity)`` tuples, best match first.
        """
        query_features = self._encode_text(query_text)
        return self._rank(query_features, self.image_features_cache, top_k)

    def search_by_image(self, query_image, top_k=5):
        """Return the ``top_k`` indexed texts most similar to ``query_image``.

        Returns:
            List of ``(text_id, similarity)`` tuples, best match first.
        """
        query_features = self._encode_image(query_image)
        return self._rank(query_features, self.text_features_cache, top_k)
应用场景
- 图像分类:零样本图像分类,无需标注数据
- 图像搜索:以文搜图、以图搜文
- 内容审核:检测不当内容
- 推荐系统:基于内容相似度的推荐
总结
CLIP通过对比学习将视觉和语言统一到共享空间,开创了多模态学习的新范式。其零样本能力使其在各种下游任务中表现出色,是现代多模态AI系统的重要基石。