论文阅读:CDUL: CLIP-Driven Unsupervised Learning for Multi-Label Image Classification

ICCV2023的的论文,提出来一个在CLIP中使用无标记的多标签的方案

在2024年还有人对其做再现性[研究(Reproducibility Study of CDUL: CLIP-Driven Unsupervised Learning for Multi-Label Image Classification

另外,本篇论文没有公开代码,但有相关同行实现了部分代码,这部分会在后面提及

本篇Blog仅记录些让我感兴趣的部分,不完善的地方与细节,还请各位自行补充

参考资料

【ICCV 2023】CDUL: CLIP-Driven Unsupervised Learning for Multi-Label Image Classification

CDUL: CLIP-Driven Unsupervised Learning for Multi-Label Image Classification

研究动机

· CLIP往往适合单标签分类,而不适合多标签分类

· 多标签的注释获取往往是带有噪声的

本文方法:

· 伪标签初始化。聚合全局和局部信息,令CLIP关注多类对象。

· 梯度对齐训练。递归地更新网络参数和伪标签(潜在参数)。

提出一种基于CLIP的无监督学习的无标注多标签图像分类方法。包括三个阶段:初始化、训练(Train)和推断(Inference)。

· 在初始化阶段充分利用强大的CLIP模型,并提出一种基于全球本地图像文本相似度的聚合的方法,以扩展Clip进行多标签预测

· 在训练阶段,我们将聚合相似度得分作为初始的伪标签,并提出一种优化框架来训练分类网络的参数,并优化未观测标签的伪标签。

· 在推断阶段,仅使用分类网络预测输入图像的标签。

伪标签初始化

全局与局部对齐

有一个Global Alignment 和一个Local Alignment

Global Alignment 是指整张图片的Embedding

Local Alignment是图片拆成块后的Embedding

两部分的使用的公式一模一样

有一个聚合器:Global-Local Image-Text Similarity Aggregator

针对Local Alignment计算出的相似度(similirity)给了一个聚合方案

基于该方案,与Global Alignment进行算术平均

PromptPar中的代码所描述基本一致

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
def forward_aggregate(self, image, text):
all_class = (image / image.norm(dim=-1, keepdim=True)).float()
text_features = (text / text.norm(dim=-1, keepdim=True)).float()
\# cosine similarity as logits
logit_scale = self.logit_scale.exp()
logits_per_image = logit_scale * all_class @ text_features.t()

//上面部分是OpenAI原本内容,下面部分为实现:
similarity = self.softmax_model(logits_per_image)
global_similarity = similarity[:,0]
local_similarity = similarity[:,1:]
for logits_local in local_similarity:
max_values, _ = torch.max(logits_local, dim=0)#max_values.detach().numpy()
min_values, _ = torch.min(logits_local, dim=0)
gama=max_values > args.ag_threshold
similarity_aggregate = gama.float() * max_values + (1 - gama.float()) * min_values

final_similarity = (similarity_aggregate + global_similarity) / 2

return self.agg_bn(final_similarity),logits_per_image

梯度对齐训练

使用后 Kullback-Leibler (KL散度)作为损失计算,根据结果更新伪标签,然后进行下一轮计算,逐步更新网络参数。

Note:好像不少多标签的CLIP都在用KL散度替换交叉熵(Cross Entrophy Loss)

结语

相比较于CSDN原文,好像也没有新增多少内容😂溜了溜了

非要说什么的话,就是这个方案确实有开创性,一个不难理解的方案,完成了CLIP对多标签任务的拓展