跳转至

Llava

源码来自Transformers库中的llava实现,为了梳理逻辑,我进行了部分调整

An image caption

其中的Vision Encoder使用的是CLIP模型,Language Model使用的是Llama模型。

另外,CLIP模型这里使用的是VIT(24层)。

源码解析

  • 第一步:获得图片和文本的嵌入
# 文本嵌入 torch.Size([1, 13, 4096])
inputs_embeds = self.get_input_embeddings()(input_ids)

# 图片嵌入 torch.Size([1, 576, 4096])
image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
selected_image_feature = image_outputs.hidden_states[vision_feature_layer] # -2
selected_image_feature = selected_image_feature[:, 1:] # 去除了CLS token
image_features = self.multi_modal_projector(selected_image_feature) # 1024 -> 4096

【1】通常,最后一层用于其他任务(例如分类等),且进行了非线性激活或变换。选择 -2 的原因是,在许多情况下,倒数第二层的输出包含了丰富的特征信息,并且没有过多的非线性处理。

【2】注意这里处理除了调用vision_tower(实际上就是CLIP视觉编码器-24层),还调用了multi_modal_projector,将特征进行一个映射,从1024的视觉空间映射到了4096的文本空间。

self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=True)
self.act = ACT2FN[config.projector_hidden_act]
self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True)
  • 第二步:融合文本特征和图片特征
inputs_embeds, attention_mask, labels, position_ids = 
self._merge_input_ids_with_image_features(
                    image_features, 
                    inputs_embeds, 
                    input_ids, 
                    attention_mask, 
                    labels
                )
# If we have ["hey" "<image>", "how", "are"]
# we need to index copy on [0, 577, 578, 579] for the text 
# and [1:576] for the image features

会发现这个融合十分简单,就是将这个图片的tokens插入到文本的tokens中间即可。

  • 第三步:调用Llama模型得到输出
outputs = self.language_model(
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
# ['logits', 'past_key_values']

logits 表示模型对于输入序列每个位置预测下一个 token 的原始分数。它的维度通常是 (batch_size, sequence_length, vocab_size),其中sequence_length为输入序列的长度。

past_key_values 通常包含了每一层 Transformer 中的注意力机制的键(key)和值(value)矩阵。这样,当我们生成下一个 token 时,模型只需要计算新 token 的键和值,并与前面存储的键和值进行注意力计算,而不需要重新计算整个序列的注意力。(2, batch_size, num_heads, sequence_length, head_dim),其中:2: 第一个表示键(key),第二个表示值(value)

值得注意的是,则是一个自回归!所以会调用一直到生成<eos>为止。

Processor分析

llava使用的processor分为:self.image_processor 和 self.tokenizer

前者为视觉处理器(具体为CLIPImageProcessor),后者为文本处理器

  • 分析视觉处理器进行了哪些操作
# convert_rgb   RGBA->RGB

# to_numpy_array   PIL.Image.Image  -> numpy.ndarray (height, width, channels) 
 ## 注意格式是ChannelDimension.LAST(channels_last)
 ## -> (1152, 2048, 3)

# resize         {'shortest_edge': 336}使最短边为336 -> (336, 597, 3)

# center_crop    {'height': 336, 'width': 336} -> (336, 336, 3)

# rescale        image = image * scale (0-255 -> 0-1)

# normalize      image = (image - image_mean) / image_std

# change_channel (336, 336, 3) -> (3, 336, 336)

我觉得可以学习的点是,一般是先resize到shortest_edge,然后再center_crop。而不是直接resize到一个正方形,因为这样可以避免太多信息丢失!

之前觉得每次输出查看图片都感觉奇怪,原来llava这个模型就没考虑输出图片的,所以对图片进行了很多预处理,输入的图片是经过一堆处理的自然也就很奇怪了。

深入分析

值得分析的点是:

  • CLIP获取的视觉特征,只提取倒数第二层的视觉Tokens(除CLS token)
  • multi_modal_projector 将视觉空间(576 * 1024)直接映射到文本空间(576 * 4096)