Llava
源码来自Transformers库中的llava实现,为了梳理逻辑,我进行了部分调整
其中的Vision Encoder使用的是CLIP模型,Language Model使用的是Llama模型。
另外,CLIP模型这里使用的是VIT(24层)。
源码解析
- 第一步:获得图片和文本的嵌入
# 文本嵌入 torch.Size([1, 13, 4096])
inputs_embeds = self.get_input_embeddings()(input_ids)
# 图片嵌入 torch.Size([1, 576, 4096])
image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
selected_image_feature = image_outputs.hidden_states[vision_feature_layer] # -2
selected_image_feature = selected_image_feature[:, 1:] # 去除了CLS token
image_features = self.multi_modal_projector(selected_image_feature) # 1024 -> 4096
【1】通常,最后一层用于其他任务(例如分类等),且进行了非线性激活或变换。选择 -2 的原因是,在许多情况下,倒数第二层的输出包含了丰富的特征信息,并且没有过多的非线性处理。
【2】注意这里处理除了调用vision_tower(实际上就是CLIP视觉编码器-24层),还调用了multi_modal_projector,将特征进行一个映射,从1024的视觉空间映射到了4096的文本空间。
self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=True)
self.act = ACT2FN[config.projector_hidden_act]
self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True)
- 第二步:融合文本特征和图片特征
inputs_embeds, attention_mask, labels, position_ids =
self._merge_input_ids_with_image_features(
image_features,
inputs_embeds,
input_ids,
attention_mask,
labels
)
# If we have ["hey" "<image>", "how", "are"]
# we need to index copy on [0, 577, 578, 579] for the text
# and [1:576] for the image features
会发现这个融合十分简单,就是将这个图片的tokens插入到文本的tokens中间即可。
- 第三步:调用Llama模型得到输出
outputs = self.language_model(
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
# ['logits', 'past_key_values']
logits 表示模型对于输入序列每个位置预测下一个 token 的原始分数。它的维度通常是 (batch_size, sequence_length, vocab_size)
,其中sequence_length为输入序列的长度。
past_key_values 通常包含了每一层 Transformer 中的注意力机制的键(key)和值(value)矩阵。这样,当我们生成下一个 token 时,模型只需要计算新 token 的键和值,并与前面存储的键和值进行注意力计算,而不需要重新计算整个序列的注意力。(2, batch_size, num_heads, sequence_length, head_dim)
,其中:2: 第一个表示键(key),第二个表示值(value)
值得注意的是,则是一个自回归!所以会调用一直到生成<eos>
为止。
Processor分析
llava使用的processor分为:self.image_processor 和 self.tokenizer
前者为视觉处理器(具体为CLIPImageProcessor),后者为文本处理器
- 分析视觉处理器进行了哪些操作
# convert_rgb RGBA->RGB
# to_numpy_array PIL.Image.Image -> numpy.ndarray (height, width, channels)
## 注意格式是ChannelDimension.LAST(channels_last)
## -> (1152, 2048, 3)
# resize {'shortest_edge': 336}使最短边为336 -> (336, 597, 3)
# center_crop {'height': 336, 'width': 336} -> (336, 336, 3)
# rescale image = image * scale (0-255 -> 0-1)
# normalize image = (image - image_mean) / image_std
# change_channel (336, 336, 3) -> (3, 336, 336)
我觉得可以学习的点是,一般是先resize到shortest_edge,然后再center_crop。而不是直接resize到一个正方形,因为这样可以避免太多信息丢失!
之前觉得每次输出查看图片都感觉奇怪,原来llava这个模型就没考虑输出图片的,所以对图片进行了很多预处理,输入的图片是经过一堆处理的自然也就很奇怪了。
深入分析
值得分析的点是:
CLIP
获取的视觉特征,只提取倒数第二层的视觉Tokens(除CLS token)multi_modal_projector
将视觉空间(576 * 1024)直接映射到文本空间(576 * 4096)