RoPE

1 RoPE

旋转位置编码(Rotary Position Embedding, RoPE)是一种位置编码方法,广泛应用于 Transformer 架构中。它的核心思想是:

对输入向量的施加一个与位置相关的旋转变换,从而在注意力计算中,使得点积的结果中包含位置的相对差值 nm 信息

理论推导: Transformer升级之路:2、博采众长的旋转式位置编码 - 科学空间|Scientific Spaces
实现:
Pasted image 20251004222749.png|425

    def sinusoidal_pos_embed(self, seq_len, dim):
        """
        @return: (1, n, seq_len, dim) 完整的位置编码
        """
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        pos = torch.arange(seq_len, dtype = torch.float)
        # will product
        sinusoid_inp = torch.einsum("i,j->ij", pos, inv_freq)

        sinusoid_pos = torch.stack([sinusoid_inp.sin(), sinusoid_inp.cos()], dim = -1).reshape(seq_len, -1)
        sinusoid_pos = sinusoid_pos.unsqueeze(0).unsqueeze(0)

        return sinusoid_pos.repeat(1, self.n, 1, 1)

    def apply_RoPE(self, x, pos_id=0):
        seq_len = x.shape[2]
        pos_emb = self.sinusoidal_pos_embed[:, :, pos_id:pos_id + seq_len, :]
        # cos_pos,sin_pos: (bs, head, max_len, output_dim)
        # 看rope公式可知,相邻cos,sin之间是相同的,所以复制一遍。如(1,2,3)变成(1,1,2,2,3,3)
        pos_emb = pos_emb.to(x.device)
        # 将奇数列信息抽取出来也就是cos 拿出来并复制
        cos_pos = pos_emb[..., 1::2].repeat_interleave(2, dim = -1)
        # 将偶数列信息抽取出来也就是sin 拿出来并复制
        sin_pos = pos_emb[..., ::2].repeat_interleave(2, dim = -1)

        x2 = torch.stack([-x[..., 1::2], x[..., ::2]], dim = -1)
        # reshape后就是正负交替了
        x2 = x2.reshape(x.shape)

        x = x * cos_pos + x2 * sin_pos
        return x

探讨:

2 数学形式

二维向量

Pasted image 20251004202138.png
(这里的 q 0,q 1实际上是实现中 q的 dim 隐藏层维度的 d1, d2)

给定输入向量的某个二维子空间:

qj,kjR2

在位置 m 的旋转矩阵为:

Rj(m)=(cos(mθj)sin(mθj)sin(mθj)cos(mθj))

其中, θj 为该子空间的角度步长,或者叫旋转的幅角,例如 π50。m 为位置,例如序列的第0……100,实践上我们设置θj100002i/d, i=1,2,3,4d/2,d 是隐藏层维度,例如 768。

旋转后的 Query/Key:

q~m=Rj(m)qj,k~n=Rj(n)kj

旋转后的 Query/Key 内积如下,嵌入了相对位置信息:

q~mk~n=qjRj(nm)kj

注意:以上的过程固定了幅角

多维向量

Pasted image 20251004202119.png

d 维向量分为 d/2 个二维子空间,整体旋转矩阵为块对角矩阵:

T(m)=blockdiag(R0(m),R1(m),,Rd/21(m))

整体编码:

q~m=T(m)q,k~n=T(n)k

内积如下,只依赖相对位置:

q~mk~n=j=0d/21qjRj(nm)kj

注意:这意味着,对于不同位置的 token,例如 len 1, len 2 位置的(对应于 n,m),它们对应的向量进行内积,d 1,d 2 维度的幅角相同(q 0, q 1),d 3,d 4 维度(q 2, q 3)的幅角相同,……, 也就是两两分组包括了相对位置信息在其中


2.1 推导

2.1.1 事后精简版

证明引入复数的幅角,即旋转矩阵,可以使得内积的结果包含了相对位置信息。

我们将二维向量 (x,y) 看作复数 z=x+iy。在位置 m 的旋转变换等价于乘以相位:

zzeimθ.

于是两个位置 m,n 的向量在内积中表现为:

Re(zqeimθ(zkeinθ))=Re(zqzkei(mn)θ)

显然只依赖于 mn,而不是 m,n。

2.1.2 作者思路版

Transformer升级之路:2、博采众长的旋转式位置编码 - 科学空间|Scientific Spaces

作者首先假定运算 f , 给 qk 添加绝对位置信息,即

(1)q~m=f(q,m),k~n=f(k,n)

同时也自然有

(2)f(q,0)=q,f(k,0)=k

因为 attention 就是在做内积,所以假设其内积为 g,与 m - n 有关是希望它和相对位置有关,能反映它,这是我们的目标

(3)f(q,m),f(k,n)=g(q,k,mn)

考虑 f(q,m)f(k,m) 是复数进行求解,因为 q,k=Re[qk]Re 代表实部,所以有

(4)Re[f(q,m)f(k,n)]=g(q,k,mn)

这里作者为了简便,直接令

(5)f(q,m)f(k,n)=g(q,k,mn)

性质奇妙的地方来了,作者考虑以复数指数的形式进行求解(任何一个复数都可以用指数形式表示),将 f(q,m),f(k,m)g(q,k,mn) 都以复数指数进行表达,有

(6)f(q,m)=Rf(q,m)eiΘf(q,m)f(k,n)=Rf(k,n)eiΘf(k,n)g(q,k,mn)=Rg(q,k,mn)eiΘg(q,k,mn)

其中,R 表示实部,Θ 表示虚部函数

因为 (5) 式的直接相等,所以实部等于实部,虚部等于虚部,注意(5)式的*表示共轭取反,因此有

(7)Rf(q,m)Rf(k,n)=Rg(q,k,mn)Θf(q,m)Θf(k,n)=Θg(q,k,mn)

令 m = n = 0, (7)-1 有

Rf(q,0)Rf(k,0)=Rg(q,k,0)

令 m = n,(7)-1 有

Rf(q,m)Rf(k,m)=Rg(q,k,mn)=Rg(q,k,0)=Rf(q,0)Rf(k,0)

注意到(2)式,所以

Rf(q,0)Rf(k,0)=R(q)R(k)=||q||||k||

也即

Rf(q,m)Rf(k,m)=||q||||k||

这说明两者的实部与 m 无关,也就是和位置无关,我们可以不关注这一项了

现在我们来看 (7)-2 式,令 m = n = 0, (7)-2 有

Θf(q,0)Θf(k,0)=Θg(q,k,0)

令 m = n,(7)-2 有

Θf(q,m)Θf(k,m)=Θg(q,k,0)=Θf(q,0)Θf(k,0)

注意到(2)式,所以

Θf(q,0)Θf(k,0)=Θ(q)Θ(k)

也就是

Θf(q,m)Θf(k,m)=Θ(q)Θ(k)

因此

Θf(q,m)Θ(q)=Θf(k,m)Θ(k)

这意味着 Θf(q,m)Θ(q) 是一个只与 m 有关,和 q 无关的函数,记为 φ(m), 所以

(8)Θf(q,m)=Θ(q)+φ(m)

代入 n=m1, 有

(9)Θf(k,m1)=Θf(k)+φ(m1)

由(7)(8)(9)式整理可得:

φ(m)φ(m1)=Θf(q,m)Θ(q)Θf(k,m1)+Θf(k)=Θf(q,k,1)+Θf(k)Θ(q)

注意到等式的右边全部和 m 无关,可以认为是一个和 q, k 相关的常数,这意味着 φ 是一个等差数列,令右端为 θ, 有

φ(m)=mθ

Pasted image 20251004235520.png|600