php彩票网站建设源码,网站建设哪家最好,用wordpress建站效果怎么样,wordpress主题开发导航制作Meta最新模型LLaMA细节与代码详解0. 简介1. 项目环境依赖2. 模型细节2.1 RMS Pre-Norm2.2 SwiGLU激活函数2.3 RoPE旋转位置编码3. 代码解读3.1 tokenizer3.2 model3.2.1 模型细节详解3.2.2 transformer构建3.3 generate4. 推理0. 简介
今天介绍的内容是Facebook Meta AI最新提…
Meta最新模型LLaMA细节与代码详解0. 简介1. 项目环境依赖2. 模型细节2.1 RMS Pre-Norm2.2 SwiGLU激活函数2.3 RoPE旋转位置编码3. 代码解读3.1 tokenizer3.2 model3.2.1 模型细节详解3.2.2 transformer构建3.3 generate4. 推理0. 简介
今天介绍的内容是Facebook Meta AI最新提出的语言模型LLaMA该模型声称以更小的体积在多数任务上超越了GPT-3的性能。
模型相关项目已经开源 https://github.com/facebookresearch/llama
论文地址https://scontent-tpe1-1.xx.fbcdn.net/v/t39.8562-6/333078981_693988129081760_4712707815225756708_n.pdf?_nc_cat108ccb1-7_nc_sidad8a9d_nc_ohcov6yTHfLfNQAX-guxqd_nc_htscontent-tpe1-1.xxoh00_AfDMyTEYewg-cHT9_4_sUaW5h0gqrqwjcNMylD9HtVFCWAoe6401C9E2
由于模型较大目前的设备暂时没有办法支持进一步的实验但是其模型代码已经开源所以可以先通过代码了解一下模型结构上的一些细节今天就针对github上放出的代码了解一下模型的细节。
此外该模型其实就是transformer做了一点细节上的改进真正更有价值的工作应该在数据和训练方面。通过阅读代码可以对transformer的基础构造进行复习并且了解大模型如何在多卡上分布推理。 由于该项目源码几乎没有注释这就肯定会给很多同学阅读时带来困扰所以本文顺带着就把代码部分详细的介绍一下。
1. 项目环境依赖
此项目给出的环境依赖只有4个
torchfairscalefiresentencepiece
其中torch不比多讲fairscale是用来做GPU分布的一般是当使用DDP仍然遇到超显存的问题时使用fairscale。目前fairscale我还没有试过在下文的源码介绍中我会用torch中对应的基础网络替代fairscale中的结构层进行介绍。fire是一个命令行工具用或者不用他都可以sentencepiece是用于tokenizer的工具包会在tokenizer部分简单介绍。
2. 模型细节
由于该模型就是用的transformer的decoder所以在结构上它与GPT是非常类似的只是有一些细节需要注意一下。
2.1 RMS Pre-Norm
关于Pre-Norm和Post-Norm是神经网络中老生常谈的话题目前比较普遍的被大家接受的结论是相同的深度条件下Post-Norm的效果要优于Pre-Norm因为Pre-Norm实际上相当于通过了一个更宽的网络而非更深的网络所以在同等深度下Pre-Norm的实际效果相当于一个更浅却更宽的网络详细的推理过程参考https://spaces.ac.cn/archives/9009。
然而在LLaMA中却采用了Pre-Norm或许是因为模型够深7B13B30B65B的模型transformer layer数量分别为32406080而Pre-Norm的恒等分支更加明显有利于梯度的传播这部分暂时没有想到很合理的解释如果有更好的理解欢迎在评论区补充。
RMS NormRoot Mean Square Layer Normalization是一般LayerNorm的一种变体可以在梯度下降时令损失更加平滑。
与layerNorm相比RMS Norm的主要区别在于去掉了减去均值的部分re-centering只保留方差部分re-scaling从归一化的表达式上可以直观地看出
一般的LN
a‾iai−μσgi\overline{a}_i \frac {a_i- \mu} \sigma g_iaiσai−μgi 其中
μ1n∑i1nai\mu \frac 1 n \sum_{i1}^na_iμn1i1∑nai σ1n∑i1n(ai−μ)2\sigma \sqrt {\frac 1 n \sum_{i1}^n{{(a_i-\mu)}^2}}σn1i1∑n(ai−μ)2
RMS Norm a‾iaiRMS(a)gi\overline{a}_i \frac {a_i} {RMS(a)} g_i aiRMS(a)aigi 其中 RMS(a)1n∑i1nai2{RMS(a)}\sqrt {\frac 1 n \sum_{i1}^n{{a_i}^2}} RMS(a)n1i1∑nai2
可以看到二者的区别就在于有没有减去均值。至于RMS Norm为什么有用需要求梯度进行分析感兴趣的同学可以阅读RMS Norm的论文。
2.2 SwiGLU激活函数
LLaMA采用SwiGLU替换了原有的ReLU。
采用SwiGLU的FNN在论文中以如下公式进行表述 FFNswiGLU(x,W,V,W2)(Swish1(xW)⊗xV)W2FFN_{swiGLU}(x, W, V, W_2) (Swish_1(xW)\otimes xV)W_2FFNswiGLU(x,W,V,W2)(Swish1(xW)⊗xV)W2
其中Swishβ(x)xσ(βx)Swish_\beta(x) x\sigma(\beta x)Swishβ(x)xσ(βx), (Ramachandran et al., 2017.)
2.3 RoPE旋转位置编码
RoPERotary Position Embedding旋转位置编码是苏剑林老师提出的一种旋转位置编码方法其思想是采用绝对位置编码的形式实现相对位置编码。这一部分比较关键如果不理解的话后边的代码估计就看不懂了。读懂RoPE涉及一点复变函数的基础知识不过如果你没有学过的话也没有关系。
位置编码对大模型而言尤为重要因为既然是要训练大模型那么长文本的表征和模型对于长文本的建模能力就显得非常重要。但是对于绝对位置编码我有一个直观地感受认为其本质上不适用于长文本的场景因为它会直接导致模型的Embedding层被无限放大并且由于数据分布在seq_len方向上通常是长尾的这又会必然导致绝对位置编码的矩阵在尾部会越来越稀疏一方面造成资源浪费另一方面这种表示方法直观上就很不利于模型的学习因为它与我们实际场景是有很大的矛盾的。而RoPE虽然具有相对位置编码的性质但是从代码部分可以看出在构造的时候其也是受到了最大长度的限制的。关于这一点我无法严谨得说明只是一点个人的想法。。
而RoPE的巧妙之处在于它既保留了绝对位置编码中的绝对位置信息又保留了在内积运算下对位置信息的相对性。
RoPE主要借助了复数的思想。为了引入复数首先假设了在加入位置信息之前原有的编码向量是二维行向量qmq_mqm和knk_nkn其中mmm和nnn是绝对位置现在需要构造一个变换将mmm和nnn引入到qmq_mqm和knk_nkn中即寻找变换
qm~f(q,m),kn~f(k,n)\tilde {q_m} f(q, m), \tilde{k_n} f(k, n) qm~f(q,m),kn~f(k,n) 考虑到Attention的核心计算是内积 Attention(Q,K,V)softmax(QKTdk)VAttention(Q, K,V) softmax(\frac {QK^T} {\sqrt{d_k}})VAttention(Q,K,V)softmax(dkQKT)V
所以寻求的这个f(∗)f(*)f(∗)变换应该具有特性⟨f(q,m),f(k,n)⟩g(q,k,m−n)\langle f(q, m), f(k, n) \rangle g(q, k, m-n)⟨f(q,m),f(k,n)⟩g(q,k,m−n)
这里直接说结论寻求的变换就是qmeimθq_me^{im\theta}qmeimθ也就是给qmq_mqm乘以eimθe^{im\theta}eimθ相应地knk_nkn乘以einθe^{in\theta}einθ。
具体的求解过程请参考苏剑林老师的博客。
做了这样一个变换之后根据复数的特性有
⟨qm,kn⟩Re[qmkn∗]\langle q_m, k_n \rangle Re[q_mk^*_n]⟨qm,kn⟩Re[qmkn∗]
也就是如果把二维向量看做复数那么它们的内积等于一个复数乘以另一个复数的共轭得到的结果再取实部。
带入上面的变换也就有 ⟨qmeimθ,kneinθ⟩Re[(qmeimθ)(kneinθ)∗]Re[qmkn∗ei(m−n)θ]\langle q_me^{im\theta}, k_ne^{in\theta} \rangle Re[(q_me^{im\theta}) (k_ne^{in\theta})^*] Re[q_mk_n^*e^{i(m-n)\theta}]⟨qmeimθ,kneinθ⟩Re[(qmeimθ)(kneinθ)∗]Re[qmkn∗ei(m−n)θ]
这样一来内积的结果就只依赖于(m−n)(m-n)(m−n)也就是相对位置了。换言之经过这样一番操作通过给Embedding添加绝对位置信息可以使得两个token的编码经过内积变换self-attn之后得到结果是受它们位置的差值即相对位置影响的。
于是对于任意的位置为mmm的二维向量[x,y][x, y][x,y]把它看做复数乘以eimθe^{im\theta}eimθ而根据欧拉公式有
eimθcosmθisinmθe^{im\theta}\cos{m\theta}i\sin{m\theta}eimθcosmθisinmθ
于是上述的相乘变换也就变成了
(xiy)eimθ(xcosmθ−ysinmθ)i(xsinmθycosmθ)(xiy)e^{im\theta}(x\cos{m\theta}-y\sin{m\theta})i(x\sin{m\theta}y\cos{m\theta})(xiy)eimθ(xcosmθ−ysinmθ)i(xsinmθycosmθ)
把上述式子写成矩阵形式
f((q0,q1),m)[cosmθ−sinmθsinmθcosmθ][q0q1]f((q_0, q_1), m) \begin{bmatrix} {\cos{m\theta}}{-\sin{m\theta}} \\ {\sin{m\theta}}{\cos{m\theta}} \\ \end{bmatrix} \begin{bmatrix} q_0\\q_1\end{bmatrix}f((q0,q1),m)[cosmθsinmθ−sinmθcosmθ][q0q1]
而这个变换的几何意义就是在二维坐标系下对向量(q0,q1)(q_0, q_1)(q0,q1)进行了旋转因而这种位置编码方法被称为旋转位置编码。
根据刚才的结论结合内积的线性叠加性可以将结论推广到高维的情形。可以理解为每两个维度一组进行了上述的“旋转”操作然后再拼接在一起 [cosmθ0−sinmθ000⋯00sinmθ0cosmθ000⋯0000cosmθ1−sinmθ1⋯0000sinmθ1cosmθ1⋯00⋮⋮⋮⋮⋱⋮⋮0000⋯cosmθd/2−1−sinmθd/2−10000⋯sinmθd/2−1cosmθd/2−1][q0q1q2q3⋮qd−2qd−1]\begin{bmatrix} \cos{m\theta_0} -\sin{m\theta_0} 0 0 {\cdots} 0 0 \\ \sin{m\theta_0} \cos{m\theta_0} 0 0 {\cdots} 0 0 \\ 0 0 \cos{m\theta_1} -\sin{m\theta_1} {\cdots} 0 0 \\ 0 0 \sin{m\theta_1} \cos{m\theta_1} {\cdots} 0 0 \\ \vdots \vdots \vdots \vdots \ddots \vdots \vdots \\ 0 0 0 0 \cdots \cos{m\theta_{{d/2}-1}} -\sin{m\theta_{{d/2}-1}}\\ 0 0 0 0 \cdots \sin{m\theta_{{d/2}-1}} \cos{m\theta_{{d/2}-1}} \end{bmatrix} \begin{bmatrix} q_0\\ q_1 \\ q_2 \\ q_3 \\ \vdots \\ q_{d-2} \\ q_{d-1} \end{bmatrix} cosmθ0sinmθ000⋮00−sinmθ0cosmθ000⋮0000cosmθ1sinmθ1⋮0000−sinmθ1cosmθ1⋮00⋯⋯⋯⋯⋱⋯⋯0000⋮cosmθd/2−1sinmθd/2−10000⋮−sinmθd/2−1cosmθd/2−1q0q1q2q3⋮qd−2qd−1
由于矩阵的稀疏性会造成计算上的浪费所以在计算时采用逐位相乘再相加的方式进行
[q0q1q2q3⋮qd−2qd−1]⊗[cosmθ0cosmθ0cosmθ1cosmθ1⋮cosmθd/2−1cosmθd/2−1][−q1q0−q3q2⋮−qd−1qd−2]⊗[sinmθ0sinmθ0sinmθ1sinmθ1⋮sinmθd/2−1sinmθd/2−1]\begin{bmatrix} q_0\\ q_1 \\ q_2 \\ q_3 \\ \vdots \\ q_{d-2} \\ q_{d-1} \end{bmatrix} \otimes \begin{bmatrix} \cos{m\theta_0} \\ \cos{m\theta_0} \\ \cos{m\theta_1} \\ \cos{m\theta_1} \\ \vdots \\ \cos{m\theta_{{d/2}-1}} \\ \cos{m\theta_{{d/2}-1}} \end{bmatrix} \begin{bmatrix} -q_1\\ q_0 \\ -q_3 \\ q_2 \\ \vdots \\ -q_{d-1} \\ q_{d-2} \end{bmatrix} \otimes \begin{bmatrix} \sin{m\theta_0} \\ \sin{m\theta_0} \\ \sin{m\theta_1} \\ \sin{m\theta_1} \\ \vdots \\ \sin{m\theta_{{d/2}-1}} \\ \sin{m\theta_{{d/2}-1}} \end{bmatrix} q0q1q2q3⋮qd−2qd−1⊗cosmθ0cosmθ0cosmθ1cosmθ1⋮cosmθd/2−1cosmθd/2−1−q1q0−q3q2⋮−qd−1qd−2⊗sinmθ0sinmθ0sinmθ1sinmθ1⋮sinmθd/2−1sinmθd/2−1
其中⊗\otimes⊗为矩阵逐位相乘操作。代码中具体的计算过程会有所出入具体见下文。
3. 代码解读
3.1 tokenizer
tokenizer这部分没有太多可以讲的主要就是用到了sentencepiece工具。
from sentencepiece import SentencePieceProcessor
from logging import getLogger
from typing import List
import oslogger getLogger()class Tokenizer:def __init__(self, model_path: str):# reload tokenizerassert os.path.isfile(model_path), model_pathself.sp_model SentencePieceProcessor(model_filemodel_path)logger.info(fReloaded SentencePiece model from {model_path})# BOS / EOS token IDsself.n_words: int self.sp_model.vocab_size()self.bos_id: int self.sp_model.bos_id()self.eos_id: int self.sp_model.eos_id()self.pad_id: int self.sp_model.pad_id()logger.info(f#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id})assert self.sp_model.vocab_size() self.sp_model.get_piece_size()def encode(self, s: str, bos: bool, eos: bool) - List[int]:assert type(s) is strt self.sp_model.encode(s)if bos:t [self.bos_id] tif eos:t t [self.eos_id]return tdef decode(self, t: List[int]) - str:return self.sp_model.decode(t)3.2 model
3.2.1 模型细节详解
model这部分的主要目的就是构建transformer由于LLaMA对transformer在细节上做了一点改动所以这里在介绍transformer部分之前先结合前文模型细节介绍几个辅助函数
1RMSNorm
这部分的基本原理在上文中已经介绍过了这里对代码部分进行简单的解释
x是输入weight是末尾乘的可训练参数x.pow(2)是平方mean(-1)实在最后一个维度即hidden特征维度上取平均eps防止取倒数之后分母为0torch.rsqrt是开平方并取倒数
结合上文的公式来看是不难理解的。
class RMSNorm(torch.nn.Module):def __init__(self, dim: int, eps: float 1e-6):super().__init__()self.eps epsself.weight nn.Parameter(torch.ones(dim))def _norm(self, x):return x * torch.rsqrt(x.pow(2).mean(-1, keepdimTrue) self.eps)def forward(self, x):output self._norm(x.float()).type_as(x)return output * self.weight2RoPE旋转位置编码
为了实现旋转位置编码定义了三个辅助函数
def precompute_freqs_cis(dim: int, end: int, theta: float 10000.0):freqs 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))t torch.arange(end, devicefreqs.device) # type: ignorefreqs torch.outer(t, freqs).float() # type: ignorefreqs_cis torch.polar(torch.ones_like(freqs), freqs) # complex64return freqs_cisdef reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):ndim x.ndimassert 0 1 ndimassert freqs_cis.shape (x.shape[1], x.shape[-1])shape [d if i 1 or i ndim - 1 else 1 for i, d in enumerate(x.shape)]return freqs_cis.view(*shape)def apply_rotary_emb(xq: torch.Tensor,xk: torch.Tensor,freqs_cis: torch.Tensor,
) - Tuple[torch.Tensor, torch.Tensor]:xq_ torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))xk_ torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))freqs_cis reshape_for_broadcast(freqs_cis, xq_)xq_out torch.view_as_real(xq_ * freqs_cis).flatten(3)xk_out torch.view_as_real(xk_ * freqs_cis).flatten(3)return xq_out.type_as(xq), xk_out.type_as(xk)这一部分是整个项目中最不容易理解的部分因为它跟一般的位置编码不同即便是对transformer结构非常了解的同学如果没有认真读过RoPE对这一部分代码还是很难读明白。
看懂这一部分代码最关键的是弄清楚其中的变量freqs_cis所指是什么东西。
为了搞懂这部分我们需要先了解几个torch中不太常用的方法
1torch.view_as_complex
把一个tensor转为复数形式要求这个tensor的最后一个维度形状为2。
torch.view_as_complex(torch.Tensor([[1, 2], [3, 4], [5, 6]]))
# tensor([1.2.j, 3.4.j, 5.6.j])2torch.view_as_real 把复数tensor变回实数可以看做是是刚才操作的逆变换。
torch.view_as_real(torch.view_as_complex(torch.Tensor([[1, 2], [3, 4], [5, 6]])))
# tensor([[1., 2.],
# [3., 4.],
# [5., 6.]])3torch.outer
一个向量的转置乘以另一个向量torch.outer(a, b) a^T * b
a torch.arange(1, 5)
b torch.arange(1, 4)
torch.outer(a, b)
# tensor([[ 1, 2, 3],
# [ 2, 4, 6],
# [ 3, 6, 9],
# [ 4, 8, 12]])4torch.polar
torch.polar(abs, angle)利用一个绝对数值和一个角度值在极坐标下构造一个复数张量abs∗cos(angle)abs∗sin(angle)jabs * \cos(angle) abs * \sin(angle) jabs∗cos(angle)abs∗sin(angle)j。
torch.polar(torch.tensor([1], dtypetorch.float64), torch.tensor([np.pi / 2], dtypetorch.float64))
# tensor([6.1232e-171.j], dtypetorch.complex128)接下来进入RoPE的计算首先为了更加具象的表达我们在此对各个维度的尺寸进行假设假设batch_size为2seq_len固定为512attention_head的数量为12每个attention_head的维度为64那么对于输入到multi-head attn中的输入xqx_qxq的尺寸就是(2, 512, 12, 64)。
回到我们刚才提出的问题freqs_cis所指是什么东西其实它就是需要计算出来的mθm\thetamθ也就是跟绝对位置相关的旋转的角度在极坐标下对应的复数tensor。
而函数precompute_freqs_cis就是提前将这些旋转角度对应的tensor给创建出来并可以重复利用。因为确定了序列的最大长度所以这个tensor是固定死的。根据后续的数据流我们可以发现在调用该函数时传入的两个参数分别是attention_head的维度以及最大长度的两倍具象地也就是64和1024。
我们逐行来理解这个方法
freqs 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))首先torch.arange创建了一个tensor[0,2,4,...,60,62][0, 2, 4, ..., 60, 62][0,2,4,...,60,62]然后统一除以64把它变成分数然后整体作为基础角度的指数它的shape是(32)
t torch.arange(end, devicefreqs.device)t比较容易理解也就是绝对位置信息它的shape是(1024)。
freqs torch.outer(t, freqs).float()于是根据torch.outer运算我们得到了一个shape为(1024, 32)的tensor。其意义也就是将每一个绝对位置分配到对应的角度相乘。直观理解一下就是每一个绝对位置上都有32个角度。为什么是这样的呢回顾计算的公式对于旋转矩阵每两个元素为一组它们乘以的角度是同一个θ\thetaθ所以这个(1024, 32)在后续的过程中就可以reshape成(512, 64)并且在64的那个维度上每两个是相同的。
freqs_cis torch.polar(torch.ones_like(freqs), freqs)这一步就是在生成我们需要的位置信息直观理解一下像是在复平面内以原点为中心转了1024组每一组64个的单位向量它的shape是(1024, 64)。
reshape_for_broadcast方法是把freqs_cis变成和输入的tensor相同的形状结合下边的另一个方法一起介绍。
然后来看apply_rotary_emb方法这个方法其实就是把位置信息添加到原有的编码结果上在multi-head attention阶段调用。我们还是逐行来看
xq_ torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))上文中我们假设了输入xqx_qxq的尺寸就是(2, 512, 12, 64)那么这一句操作的reshape就是把它变成(2, 512, 12, -1, 2)也就是(2, 512, 12, 32, 2)。xkx_kxk同理略。紧接着把它变成复数形式也就是变成了(2, 512, 12, 32)的形状。
然后进入到reshape_for_broadcast方法
shape [d if i 1 or i ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis.view(*shape)这个方法的作用是为了把freqs_cis变成和输入的tensor相同的形状。需要注意的是这里的freqs_cis并不是precompute_freqs_cis生成的形状为(1024, 64)的那个tensor而是根据输入的绝对位置在(1024, 64)的tensor中截取了长度为当前seq_len的一部分代码在Transformer类的forward方法中
freqs_cis self.freqs_cis[start_pos : start_pos seqlen]也就是说假如当前输入的序列长度是512那么截取出来的这个新的freqs_cis形状就是(512, 64)reshape之后形状就变成了(1, 512, 1, 32)也就是在每一个位置上都对应有32个角度根据刚刚torch.polar的介绍当我们固定绝对值也就是向量的模长时角度就可以在笛卡尔坐标系下唯一确定一个复数这样一来也就是32个复数即64个特征维度所以就可以对应的将它融合到每个attention head的64个特征中去了。
reshape之后就是将位置信息融入query和key中
xq_out torch.view_as_real(xq_ * freqs_cis).flatten(3)这一步将二者相乘得到的复数tensor重新转换为实数形式得到的shape为(2, 512, 12, 32, 2)然后再flatten成(2, 512, 12, 64)这样一来就变回了和最开始xqx_qxq相同的形状也就完成了将位置信息融入到xqx_qxq的这一操作。xkx_kxk同理。
以上就是添加位置编码的整个过程建议这一部分仔细阅读反复理解。
至于SwiGLU激活函数可以通过调用torch内置方法F.silu()实现会在下文的FFN部分介绍。
3.2.2 transformer构建
接下来是transformer模型的构建。通常我们在构建transformer时是按Block构建的每个transformer Block包含SA和FFN两部分然后再通过堆叠block的形式构建起整个transformer网络LLaMA也是这样做的读过BERT或者任何transformer结构的模型源码的同学一定对这个结构非常熟悉了。
首先看SA部分
class Attention(nn.Module):def __init__(self, args: ModelArgs):super().__init__()self.n_local_heads args.n_heads // fs_init.get_model_parallel_world_size()self.head_dim args.dim // args.n_headsself.wq ColumnParallelLinear(args.dim,args.n_heads * self.head_dim,biasFalse,gather_outputFalse,init_methodlambda x: x,)self.wk ColumnParallelLinear(args.dim,args.n_heads * self.head_dim,biasFalse,gather_outputFalse,init_methodlambda x: x,)self.wv ColumnParallelLinear(args.dim,args.n_heads * self.head_dim,biasFalse,gather_outputFalse,init_methodlambda x: x,)self.wo RowParallelLinear(args.n_heads * self.head_dim,args.dim,biasFalse,input_is_parallelTrue,init_methodlambda x: x,)self.cache_k torch.zeros((args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)).cuda()self.cache_v torch.zeros((args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)).cuda()def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):bsz, seqlen, _ x.shapexq, xk, xv self.wq(x), self.wk(x), self.wv(x)xq xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)xk xk.view(bsz, seqlen, self.n_local_heads, self.head_dim)xv xv.view(bsz, seqlen, self.n_local_heads, self.head_dim)xq, xk apply_rotary_emb(xq, xk, freqs_cisfreqs_cis)self.cache_k self.cache_k.to(xq)self.cache_v self.cache_v.to(xq)self.cache_k[:bsz, start_pos : start_pos seqlen] xkself.cache_v[:bsz, start_pos : start_pos seqlen] xvkeys self.cache_k[:bsz, : start_pos seqlen]values self.cache_v[:bsz, : start_pos seqlen]xq xq.transpose(1, 2)keys keys.transpose(1, 2)values values.transpose(1, 2)scores torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)if mask is not None:scores scores mask # (bs, n_local_heads, slen, cache_len slen)scores F.softmax(scores.float(), dim-1).type_as(xq)output torch.matmul(scores, values) # (bs, n_local_heads, slen, head_dim)output output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)return self.wo(output)这一部分看上去会比较复杂涉及到了很多的计算但其实它就是最普通的attention只要牢记attention的核心计算公式也不难理解。
其中为了执行多卡并行这里的Linear层用的都是fairscale中的类在阅读代码时直接理解为Linear即可。
attention计算的总体过程是 输入xxx分别经过三个Linear得到xq,xk,xvx_q, x_k, x_vxq,xk,xv 在xqx_qxq和xkx_kxk中加入旋转位置编码 缓存xqx_qxq和xkx_kxk 计算softmax(QKTdk)Vsoftmax(\frac {QK^T} {\sqrt{d_k}})Vsoftmax(dkQKT)V。
其中有一个细节就是缓存机制这里简单介绍一下很多初学者甚至NLP老手都容易忽视这个问题。这个机制在模型的训练过程中其实是不发挥作用的它设计的目的是在generate时减少token的重复计算。
简单解释一下就是在计算第nnn个token特征的时候需要用到第1,...,n−11,...,n-11,...,n−1个token即每次生成时需要知道前面所有的过往信息如果每次都从头算的话那就会造成极大的浪费所以就没算一个位置的信息就把它缓存下来。
然后是FFN部分需要注意的点就是采用的激活函数以及激活函数的位置
class FeedForward(nn.Module):def __init__(self,dim: int,hidden_dim: int,multiple_of: int,):super().__init__()hidden_dim int(2 * hidden_dim / 3)hidden_dim multiple_of * ((hidden_dim multiple_of - 1) // multiple_of)self.w1 ColumnParallelLinear(dim, hidden_dim, biasFalse, gather_outputFalse, init_methodlambda x: x)self.w2 RowParallelLinear(hidden_dim, dim, biasFalse, input_is_parallelTrue, init_methodlambda x: x)self.w3 ColumnParallelLinear(dim, hidden_dim, biasFalse, gather_outputFalse, init_methodlambda x: x)def forward(self, x):return self.w2(F.silu(self.w1(x)) * self.w3(x))这里与常见模型中的FFN做一下简单的对比BART中的FFN用的是fc-act-fc用了两层全连接 GPT中的FFN用的是conv1D-act-conv1D也是只用了两层。
而LLaMA中的FFN采用了三个全连接层以实现FFNSwiGLU即
FFNswiGLU(x,W,V,W2)(Swish1(xW)⊗xV)W2FFN_{swiGLU}(x, W, V, W_2) (Swish_1(xW)\otimes xV)W_2FFNswiGLU(x,W,V,W2)(Swish1(xW)⊗xV)W2
然后将SA和FFN这两部分拼在一起就是一个transformer block
class TransformerBlock(nn.Module):def __init__(self, layer_id: int, args: ModelArgs):super().__init__()self.n_heads args.n_headsself.dim args.dimself.head_dim args.dim // args.n_headsself.attention Attention(args)self.feed_forward FeedForward(dimargs.dim, hidden_dim4 * args.dim, multiple_ofargs.multiple_of)self.layer_id layer_idself.attention_norm RMSNorm(args.dim, epsargs.norm_eps)self.ffn_norm RMSNorm(args.dim, epsargs.norm_eps)def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):h x self.attention.forward(self.attention_norm(x), start_pos, freqs_cis, mask)out h self.feed_forward.forward(self.ffn_norm(h))return out最后利用torch的module list将transformer block进行堆叠拼上最前头的embedding部分就是一个完整的transformerdecoder结构了。
class Transformer(nn.Module):def __init__(self, params: ModelArgs):super().__init__()self.params paramsself.vocab_size params.vocab_sizeself.n_layers params.n_layersself.tok_embeddings ParallelEmbedding(params.vocab_size, params.dim, init_methodlambda x: x)self.layers torch.nn.ModuleList()for layer_id in range(params.n_layers):self.layers.append(TransformerBlock(layer_id, params))self.norm RMSNorm(params.dim, epsparams.norm_eps)self.output ColumnParallelLinear(params.dim, params.vocab_size, biasFalse, init_methodlambda x: x)self.freqs_cis precompute_freqs_cis(self.params.dim // self.params.n_heads, self.params.max_seq_len * 2)torch.inference_mode()def forward(self, tokens: torch.Tensor, start_pos: int):_bsz, seqlen tokens.shapeh self.tok_embeddings(tokens)self.freqs_cis self.freqs_cis.to(h.device)freqs_cis self.freqs_cis[start_pos : start_pos seqlen]mask Noneif seqlen 1:mask torch.full((1, 1, seqlen, seqlen), float(-inf), devicetokens.device)mask torch.triu(mask, diagonalstart_pos 1).type_as(h)for layer in self.layers:h layer(h, start_pos, freqs_cis, mask)h self.norm(h)output self.output(h[:, -1, :]) # only compute last logitsreturn output.float()直接看forward部分输入是token先做token embedding然后添加位置信息。对于decoder模型为了防止标签泄漏需要mask所以做了一个上三角的mask矩阵。接下来就是逐层的计算transformer。
3.3 generate
class LLaMA:def __init__(self, model: Transformer, tokenizer: Tokenizer):self.model modelself.tokenizer tokenizerdef generate(self,prompts: List[str],max_gen_len: int,temperature: float 0.8,top_p: float 0.95,) - List[str]:bsz len(prompts)params self.model.paramsassert bsz params.max_batch_size, (bsz, params.max_batch_size)prompt_tokens [self.tokenizer.encode(x, bosTrue, eosFalse) for x in prompts]min_prompt_size min([len(t) for t in prompt_tokens])max_prompt_size max([len(t) for t in prompt_tokens])total_len min(params.max_seq_len, max_gen_len max_prompt_size)tokens torch.full((bsz, total_len), self.tokenizer.pad_id).cuda().long()for k, t in enumerate(prompt_tokens):tokens[k, : len(t)] torch.tensor(t).long()input_text_mask tokens ! self.tokenizer.pad_idstart_pos min_prompt_sizeprev_pos 0for cur_pos in range(start_pos, total_len):logits self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)if temperature 0:probs torch.softmax(logits / temperature, dim-1)next_token sample_top_p(probs, top_p)else:next_token torch.argmax(logits, dim-1)next_token next_token.reshape(-1)# only replace token if prompt has already been generatednext_token torch.where(input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token)tokens[:, cur_pos] next_tokenprev_pos cur_posdecoded []for i, t in enumerate(tokens.tolist()):# cut to max gen lent t[: len(prompt_tokens[i]) max_gen_len]# cut to eos tok if anytry:t t[: t.index(self.tokenizer.eos_id)]except ValueError:passdecoded.append(self.tokenizer.decode(t))return decodeddef sample_top_p(probs, p):probs_sort, probs_idx torch.sort(probs, dim-1, descendingTrue)probs_sum torch.cumsum(probs_sort, dim-1)mask probs_sum - probs_sort pprobs_sort[mask] 0.0probs_sort.div_(probs_sort.sum(dim-1, keepdimTrue))next_token torch.multinomial(probs_sort, num_samples1)next_token torch.gather(probs_idx, -1, next_token)return next_token生成的过程如下 对prompts进行tokenize得到token ids 计算当前batch的最大长度total_len用来创建输入的token tensor最大长度不能超过前文所述缓存的大小 从当前batch中最短的一个prompt的位置作为生成的开始位置开始生成 输入的token tensor传入transformer模型计算logits得到形状为(batch_size, hidden_size)的logitstransformer最后一层的输出 softmaxtop_p采样得到当前预测的token并更新当前位置准备预测下一个token 解码得到生成的文本。
4. 推理
简单看一下官方example中给出的推理样例prompt
[The capital of Germany is the city of,Here is my sonnet in the style of Shakespeare about an artificial intelligence:]生成结果为
[The capital of Germany is the city of Berlin. The city is also the capital of the Federal Republic of Germany.\nThe city of Berlin is located in the state of Berlin in Germany. The city is the capital of the federal Republic of Germany.\nBerlin has a total population of around 3.4 million and is the 2nd most populous city in the European Union after London. The city has an area of 892 square kilometers and is the 9th most populated city in Europe.\nThe city of Berlin was founded in the 13th century. Berlin was also the capital of the German Empire, the German Democratic Republic and the united Federal Republic of Germany.\nThe city of Berlin has many tourist attractions that include Museumsinsel, Brandenburger Tor, the Reichstag, and the Schloss Charlottenburg.\nThe city of Berlin is a major center for the Arts, Science, Education and Innovation. The city is also the political, economic, and cultural center of Germany.\nBerlin is home to a number of world renowned universities including the Free University of Berlin, the Humboldt University of Berlin, the Technical University of Berlin, and the Berlin Institute of Technology.\nThe city of Berlin has,Here is my sonnet in the style of Shakespeare about an artificial intelligence:\nLet us take a moment from the tumultuous storm\nOf the politics of religion to examine the shape of things.\nOur intuition tells us that whatever we can conceive\nCan exist – our minds have no limit.\nHowever, our senses tell us that there is a limit.\nLet us examine the infinite and what we can say about it.\nThe infinite is something that we can never see.\nWe cannot say what it is and we cannot say what it is not.\nBut, somehow, it is nonetheless real.\nWe can also say that the infinite is eternal –\nIt has no beginning and it has no end.\nThat is what it is – it is the eternal.\nIn a word, it is God.\nBut what about the universe?\nThe universe is a finite construct –\nThe infinitely large and the infinitely small –\nAll of it finite.\nEven the singularity at the end of time is finite.\nSo, the universe is not God.\nPerhaps it is the vessel of God.\nPerhaps, in some sense, the universe is God.\nBut, I am still a man.\nI cannot see the infinite.\nI can only]总结一下本文对LLaMA大模型的结构代码进行了详细的介绍其开源出来的结构代码量并不多但是其中很多细节值得反复推敲理解。
在后续的工作中可能会对大模型进行进一步的实验对此欢迎对此感兴趣的朋友们在下方留言交流。如果本文中出现了不够准确的地方也欢迎大家在评论区指出。