← 返回首页
🧠

CLIP模型

📂 llm ⏱ 3 min 479 words

---
title: "CLIP模型"
description: "深入理解CLIP模型的对比学习原理、图像文本匹配机制以及零样本分类能力"
tags: ["CLIP", "对比学习", "图像文本匹配", "零样本分类"]
category: "llm"
icon: "🧠"
---

CLIP模型

CLIP简介

CLIP(Contrastive Language-Image Pre-training)是OpenAI在2021年发布的多模态模型,通过在4亿个图像-文本对上进行对比学习预训练,实现了强大的视觉理解能力。CLIP的核心创新在于将图像和文本映射到同一个嵌入空间,从而实现跨模态的理解。

核心原理

对比学习框架

CLIP的核心是对比学习:给定一批图像-文本对,模型学习将匹配的图像和文本嵌入拉近,同时将不匹配的嵌入推远。

import torch
import torch.nn as nn
from transformers import CLIPModel, CLIPTokenizer, CLIPImageProcessor

class CLIPTrainer:
    """Compute CLIP's symmetric contrastive loss on image-text pairs.

    Bundles a pretrained ``CLIPModel`` with its tokenizer and image
    processor, and exposes feature extraction plus the bidirectional
    (image->text and text->image) cross-entropy loss.
    """

    def __init__(self, model_name="openai/clip-vit-base-patch32"):
        """Load the pretrained model, tokenizer and image processor.

        Args:
            model_name: Checkpoint identifier passed to ``from_pretrained``.
        """
        self.model = CLIPModel.from_pretrained(model_name)
        self.tokenizer = CLIPTokenizer.from_pretrained(model_name)
        self.image_processor = CLIPImageProcessor.from_pretrained(model_name)
        # NOTE: CLIPTrainer is not an nn.Module, so this parameter is not
        # included in self.model.parameters(); hand it to the optimizer
        # explicitly if the temperature is meant to be learned.
        self.temperature = nn.Parameter(torch.ones([]) * 0.07)

    def get_image_features(self, images):
        """Preprocess ``images`` and return the model's image embeddings.

        Args:
            images: Whatever the image processor accepts (a single image or
                a batch of images).
        """
        processed = self.image_processor(images=images, return_tensors="pt")
        return self.model.get_image_features(**processed)

    def get_text_features(self, texts):
        """Tokenize ``texts`` (padded and truncated) and return text embeddings.

        Args:
            texts: A string or list of strings accepted by the tokenizer.
        """
        tokens = self.tokenizer(
            texts, return_tensors="pt", padding=True, truncation=True
        )
        return self.model.get_text_features(**tokens)

    def compute_loss(self, images, texts):
        """Return the symmetric contrastive loss for paired images and texts.

        The i-th image is the positive for the i-th text; every other
        combination in the batch acts as a negative.

        Args:
            images: Batch of images (or a single image) for the processor.
            texts: Matching captions, one per image, in the same order.

        Returns:
            Scalar tensor: mean of the image->text and text->image
            cross-entropy losses.

        Raises:
            ValueError: If the number of image embeddings differs from the
                number of text embeddings.
        """
        # L2-normalise so the dot product below is a cosine similarity.
        image_features = nn.functional.normalize(
            self.get_image_features(images), dim=-1
        )
        text_features = nn.functional.normalize(
            self.get_text_features(texts), dim=-1
        )

        # Take the batch size from the embeddings, not from len(images):
        # a single image object has no usable length, and the embeddings are
        # what the labels must line up with.
        num_images = image_features.size(0)
        num_texts = text_features.size(0)
        if num_images != num_texts:
            raise ValueError(
                "compute_loss needs one text per image, got "
                f"{num_images} images and {num_texts} texts"
            )

        # Pairwise similarity matrix: rows are images, columns are texts.
        logits = torch.mm(image_features, text_features.t()) / self.temperature

        # Positives sit on the diagonal.
        labels = torch.arange(num_images, device=logits.device)

        loss_i2t = nn.functional.cross_entropy(logits, labels)
        loss_t2i = nn.functional.cross_entropy(logits.t(), labels)

        return (loss_i2t + loss_t2i) / 2

零样本分类

CLIP最强大的能力之一是零样本分类:无需针对目标类别的任何标注训练样本,只需将候选类别写成文本提示,并与图像计算相似度,即可对图像进行分类。

import torch
from PIL import Image
from transformers import CLIPModel, CLIPTokenizer, CLIPProcessor

class ZeroShotClassifier:
    """Zero-shot image classifier built on a pretrained CLIP model.

    Candidate labels are turned into text prompts; the image is scored
    against each prompt and the scores are softmax-normalised over labels.
    """

    def __init__(self, model_name="openai/clip-vit-base-patch32"):
        """Load the pretrained model and its combined text/image processor.

        Args:
            model_name: Checkpoint identifier passed to ``from_pretrained``.
        """
        self.model = CLIPModel.from_pretrained(model_name)
        self.processor = CLIPProcessor.from_pretrained(model_name)

    def classify(self, image, candidate_labels):
        """Classify ``image`` among ``candidate_labels`` with a single prompt.

        Args:
            image: PIL Image object.
            candidate_labels: List of candidate class names.

        Returns:
            Dict mapping each label to its softmax probability; an empty
            dict when no labels are given.
        """
        candidate_labels = list(candidate_labels)
        if not candidate_labels:
            return {}

        # Build one text prompt per candidate label.
        text_prompts = [f"a photo of a {label}" for label in candidate_labels]

        inputs = self.processor(
            text=text_prompts,
            images=image,
            return_tensors="pt",
            padding=True
        )

        # Softmax over the prompt axis turns similarity logits into
        # per-label probabilities.
        with torch.no_grad():
            outputs = self.model(**inputs)
            probs = outputs.logits_per_image.softmax(dim=-1)

        return {
            label: prob.item()
            for label, prob in zip(candidate_labels, probs[0])
        }

    def classify_with_template(self, image, labels, templates=None):
        """Classify ``image`` using several prompt templates per label.

        Each label is scored with every template; the best (maximum) score
        per label is kept and the result is softmax-normalised over labels.
        All prompts go through the model in one batched forward pass, so the
        image is encoded once instead of once per (label, template) pair.
        Note that the batch holds ``len(labels) * len(templates)`` prompts.

        Args:
            image: PIL Image object.
            labels: Candidate class names.
            templates: Format strings with one ``{}`` placeholder; defaults
                to a built-in set of ten templates.

        Returns:
            Dict mapping each label to its probability; an empty dict when
            no labels are given.

        Raises:
            ValueError: If ``templates`` is empty.
        """
        if templates is None:
            templates = [
                "a photo of a {}",
                "a blurry photo of a {}",
                "a photo of many {}",
                "a sculpture of a {}",
                "a photo of the hard to see {}",
                "a low resolution photo of the {}",
                "a rendering of a {}",
                "graffiti of a {}",
                "a toy {}",
                "itap of a {}",
            ]

        labels = list(labels)
        templates = list(templates)
        if not labels:
            return {}
        if not templates:
            raise ValueError("templates must contain at least one template")

        # Label-major order so the flat score vector reshapes to
        # (num_labels, num_templates).
        prompts = [
            template.format(label)
            for label in labels
            for template in templates
        ]
        inputs = self.processor(
            text=prompts,
            images=image,
            return_tensors="pt",
            padding=True
        )
        with torch.no_grad():
            outputs = self.model(**inputs)

        scores = outputs.logits_per_image[0].reshape(len(labels), len(templates))
        # Keep the best-scoring template per label, then normalise.
        best_scores = scores.max(dim=-1).values.float()
        probs = best_scores.softmax(dim=-1)

        return {label: prob.item() for label, prob in zip(labels, probs)}

    def _get_similarity_score(self, image, text):
        """Return the image-text similarity logit for one image and one text."""
        inputs = self.processor(text=[text], images=image, return_tensors="pt")
        with torch.no_grad():
            outputs = self.model(**inputs)
        return outputs.logits_per_image.item()

图像-文本检索

class CLIPSearchEngine:
    """In-memory cross-modal search over CLIP embeddings.

    Images and texts are embedded, L2-normalised and stored in per-modality
    dictionaries keyed by caller-supplied ids. Queries from one modality are
    ranked against the cached embeddings of the other by dot product.
    """

    def __init__(self, model_name="openai/clip-vit-base-patch32"):
        """Load the pretrained model/processor and start with empty caches."""
        self.model = CLIPModel.from_pretrained(model_name)
        self.processor = CLIPProcessor.from_pretrained(model_name)
        self.image_features_cache = {}
        self.text_features_cache = {}

    def index_image(self, image_id, image):
        """Embed ``image`` and store its normalised features under ``image_id``."""
        encoded = self.processor(images=image, return_tensors="pt")
        with torch.no_grad():
            embedding = self.model.get_image_features(**encoded)
        embedding = nn.functional.normalize(embedding, dim=-1)
        self.image_features_cache[image_id] = embedding

    def index_text(self, text_id, text):
        """Embed ``text`` and store its normalised features under ``text_id``."""
        encoded = self.processor(text=[text], return_tensors="pt", padding=True)
        with torch.no_grad():
            embedding = self.model.get_text_features(**encoded)
        embedding = nn.functional.normalize(embedding, dim=-1)
        self.text_features_cache[text_id] = embedding

    def search_by_text(self, query_text, top_k=5):
        """Rank indexed images against a text query.

        Returns:
            Up to ``top_k`` ``(image_id, similarity)`` tuples, best first.
        """
        encoded = self.processor(
            text=[query_text], return_tensors="pt", padding=True
        )
        with torch.no_grad():
            query = self.model.get_text_features(**encoded)
        query = nn.functional.normalize(query, dim=-1)

        # Both sides are unit-normalised, so the product is cosine similarity.
        ranked = [
            (image_id, torch.mm(query, cached.t()).item())
            for image_id, cached in self.image_features_cache.items()
        ]
        ranked.sort(key=lambda pair: pair[1], reverse=True)
        return ranked[:top_k]

    def search_by_image(self, query_image, top_k=5):
        """Rank indexed texts against an image query.

        Returns:
            Up to ``top_k`` ``(text_id, similarity)`` tuples, best first.
        """
        encoded = self.processor(images=query_image, return_tensors="pt")
        with torch.no_grad():
            query = self.model.get_image_features(**encoded)
        query = nn.functional.normalize(query, dim=-1)

        ranked = [
            (text_id, torch.mm(query, cached.t()).item())
            for text_id, cached in self.text_features_cache.items()
        ]
        ranked.sort(key=lambda pair: pair[1], reverse=True)
        return ranked[:top_k]

应用场景

CLIP的典型应用包括零样本图像分类和图像-文本跨模态检索(如上文示例所示);其图像编码器和文本编码器也常被用作其他多模态系统的特征提取模块。

总结

CLIP通过对比学习将视觉和语言统一到共享空间,开创了多模态学习的新范式。其零样本能力使其在各种下游任务中表现出色,是现代多模态AI系统的重要基石。