论文阅读——CDUL:CLIP-Driven Unsupervised Learning for Multi-Label Image Classification
论文阅读:CDUL: CLIP-Driven Unsupervised Learning for Multi-Label Image Classification
ICCV2023的的论文,提出来一个在CLIP中使用无标记的多标签的方案
在2024年还有人对其做再现性[研究(Reproducibility Study of CDUL: CLIP-Driven Unsupervised Learning for Multi-Label Image Classification)
另外,本篇论文没有公开代码,但有相关同行实现了部分代码,这部分会在后面提及
本篇Blog仅记录些让我感兴趣的部分,不完善的地方与细节,还请各位自行补充
参考资料
【ICCV 2023】CDUL: CLIP-Driven Unsupervised Learning for Multi-Label Image Classification
CDUL: CLIP-Driven Unsupervised Learning for Multi-Label Image Classification
研究动机
· CLIP往往适合单标签分类,而不适合多标签分类
· 多标签的注释获取往往是带有噪声的
本文方法:
· 伪标签初始化。聚合全局和局部信息,令CLIP关注多类对象。
· 梯度对齐训练。递归地更新网络参数和伪标签(潜在参数)。
提出一种基于CLIP的无监督学习的无标注多标签图像分类方法。包括三个阶段:初始化、训练(Train)和推断(Inference)。
· 在初始化阶段充分利用强大的CLIP模型,并提出一种基于全球本地图像文本相似度的聚合的方法,以扩展Clip进行多标签预测
· 在训练阶段,我们将聚合相似度得分作为初始的伪标签,并提出一种优化框架来训练分类网络的参数,并优化未观测标签的伪标签。
· 在推断阶段,仅使用分类网络预测输入图像的标签。
伪标签初始化
全局与局部对齐
有一个Global Alignment 和一个Local Alignment
Global Alignment 是指整张图片的Embedding
Local Alignment是图片拆成块后的Embedding
两部分的使用的公式一模一样
有一个聚合器:Global-Local Image-Text Similarity Aggregator
针对Local Alignment计算出的相似度(similirity)给了一个聚合方案
基于该方案,与Global Alignment进行算术平均
与PromptPar中的代码所描述基本一致
def forward_aggregate(self, image, text):
all_class = (image / image.norm(dim=-1, keepdim=True)).float()
text_features = (text / text.norm(dim=-1, keepdim=True)).float()
\# cosine similarity as logits
logit_scale = self.logit_scale.exp()
logits_per_image = logit_scale * all_class @ text_features.t()
//上面部分是OpenAI原本内容,下面部分为实现:
similarity = self.softmax_model(logits_per_image)
global_similarity = similarity[:,0]
local_similarity = similarity[:,1:]
for logits_local in local_similarity:
max_values, _ = torch.max(logits_local, dim=0)#max_values.detach().numpy()
min_values, _ = torch.min(logits_local, dim=0)
gama=max_values > args.ag_threshold
similarity_aggregate = gama.float() * max_values + (1 - gama.float()) * min_values
final_similarity = (similarity_aggregate + global_similarity) / 2
return self.agg_bn(final_similarity),logits_per_image
梯度对齐训练
使用后 Kullback-Leibler (KL散度)作为损失计算,根据结果更新伪标签,然后进行下一轮计算,逐步更新网络参数。
Note:好像不少多标签的CLIP都在用KL散度替换交叉熵(Cross Entrophy Loss)
结语
相比较于CSDN原文,好像也没有新增多少内容😂溜了溜了
非要说什么的话,就是这个方案确实有开创性,一个不难理解的方案,完成了CLIP对多标签任务的拓展