AINN Attention

Attention Mechanism

Introduction

由于笔者发现在讲解Transformer的网络架构的时候缺乏对前置知识的细致理解。走马观花式的学习终究还是假学习。

  • 为什么RNN的Encoder-Decoder架构的表现不好?
  • 为什么Transformer需要设计成矩阵?
  • Q, K, V究竟代表什么意思?

因此,笔者下决心重起炉灶重构博客,尝试啃下这一块硬骨头。

今天的博客将会聚焦于Attention Mechanism的思想和数学实现。

Table of Contents

  • 注意力机制
  • RNN和seq2seq
  • seq2seqRNN的结合

Attention Mechanism

注意力机制是从生物学上获得的启发。应用到视觉世界中,威廉·詹姆斯提出了双组件框架,即人通过自主性提示和非自主性提示来选择性地引导注意力的焦点。

非自主性提示和自主性提示

  • 非自主性提示:“显眼”的物品,视觉上最敏锐的部分,获得直接的感官感受。
  • 自主性提示:受到认知和意识的控制。

macbook apple

例如,

  • 我因为Macbook图标很亮眼而注意到它:非自主性提示。(直接的感官感受)
  • 我因为想到要去买苹果而注意到桌上的苹果。(自主性提示,收到自我意识的驱动)

查询,键和值

如果建模这两种注意力机制?这是神经科学家非常关心的问题。

显然,对于非自主性提示,本质就是图片的特征提取,我们可以使用卷积核和卷积神经网络来提取图片中的特征(例如高亮度区域,色彩鲜艳的区域,锐度高的区域等等),不是我们今天讨论的重点。

但是对于自主性提示,其机制相对更加复杂。并且在现实世界中注意力往往是两种提示相互夹杂的。我们可以使用如下的基本结构来进行建模:

动手学深度学习 注意力机制

对于人类一般性的行为,我们可以建模成:

$$\boxed{\text{attention}} \to \boxed{\text{sensory input}} \to \boxed{\text{output}}$$

视觉上产生了注意力(自主 & 非自主),接着在大脑皮层产生感觉,并做出对应的输出。

  • 翻译:接受视觉输入 $\to$ 在脑海中产生感觉(思考的过程) $\to$ 输出成另一种语言

当我们看到一幅画中的苹果,可能会联想到吃苹果的甜味,甚至不自觉流口水。这一过程可以用注意力机制来解释,涉及三个核心概念:

  1. 查询(Query):当前的感知输入,如画中苹果的视觉特征(颜色、形状等),它触发大脑的检索机制(就是非自主性提示)。
  2. 键(Key):长期记忆中的关联信息,如过去吃苹果的经验。大脑会计算QueryKey的匹配程度,选择最相关的记忆。
  3. 值(Value):被选中的Key对应的具体信息,如“苹果是甜的”,进而触发味觉联想和生理反应。

这一机制的关键在于:

  • 选择性:并非所有记忆都会被激活,只有与当前输入最相关的信息才会被提取。
  • 联想性:视觉输入(Query)通过匹配记忆(Key)触发跨模态体验(Value),如“看到苹果→想到甜味”。
  • 自动化:该过程通常是快速、无意识的,体现大脑的高效信息检索能力。

通过这种机制,外部刺激(如画面)能自动激活相关记忆,并影响认知和生理反应。

在注意力机制中,自主性提示被称为查询(query),而非自主性提示(客观存在的咖啡杯和书本)作为(key)与感官输入(sensory inputs)的(value)构成一组 pair 作为输入。而给定任何查询,注意力机制通过注意力汇聚(attention pooling)将非自主性提示的 key 引导至感官输入。

  • 例如在RNN中,隐藏状态H是自主性提示,而新的input就是非自主性提示。
    • 如果不更新隐藏状态,而是直接对序列进行暴力建模,就相当于只考虑到非自主性提示,就是经典的MLP!

来点抽象的!我们给定键值对$(x_i, y_i)$,并给出需要查询的$x_q$(query),我们需要输出$\hat{y}=f(x_q)$作为我们的输出值。

$f$就是注意力机制的黑箱函数,给定好框架之后的创新无非就是针对对于$f$的函数的设计。

注意力汇聚算法

我们不妨设计一些简单的$f$,来看看效果怎么样。最简单的设计就是平均估计个体,即使用Average Pooling的操作:

$$f(x_q) = \frac{1}{n} \sum_{i= 1}^{n}y_i$$

这细细一想就很扯淡,因为这样无论给出什么样的query输出的结果总是一样的,bullshit!

也就是说,我们希望在输出预测的时候尽可能的使用到已知的信息,即我们所知道的键值对$(x_i, y_i)$。一个很自然的想法是**判断$x_q$和每一个键$x_i$**的相关程度,并以此为基础作为权重,再加权平均所有的$y_i$。

例如我从Query为“画中的苹果”检索到脑海中的苹果,可以认为是这两个事物的相关性很强,因此权重非常大,最后的输出就几乎全是“脑海中的苹果”对应的值。如果我的query是“画中那个洒满椒盐的苹果”,那可能这和脑海中“椒盐”这个键的相关性就会大幅度上升,进而产生相对应(值)的咸咸的感觉。

我们便得到了注意力汇聚公式(Attention Pooling),本质上就是加一个与键有关的权重,没什么特别的:

$$f(x_q) = \sum_{i= 1}^{n} \alpha(x_q, x_i) y_i,\ \sum_{i = 1}^{n}\alpha(x_q, x_i) = 1 $$

如何设计权重?考虑相关性,因此我们可以对$\alpha(x_q, x_i)$进一步细化,我们假设$K(x_i,x_j)$衡量了两个向量的相关性:

$$\alpha(x_q, x_i) = \frac{K(x_q, x_i)}{\sum_{j = 1}^{n}K(x_q, x_j)}$$

更进一步,如果我们考虑高斯核公式(先考虑一阶的):

$$\alpha(x_q,x_i) = \frac{\exp(-\frac{1}{2}(x_q - x_i)^2)}{\sum_{j = 1}^{n}\exp(-\frac{1}{2}(x_q - x_j)^2)} = \text{softmax}(-\frac{1}{2}(x_q - x_i)^2)$$

那我们就可以得到一个具体化的$f(x)$:

$$\hat{y} = f(x) = \sum_{i = 1}^{n} \text{softmax}(-\frac{1}{2}(x_q - x_i)^2) y_i$$

更进一步,我们希望这个函数可以携带可被学习的参数:

$$\hat{y} = f(x) = \sum_{i = 1}^{n} \text{softmax}(-\frac{1}{2}((x_q - x_i)\omega)^2) y_i$$

实验

热力图显示当key和query越接近的时候,其权重更高。(这里是一维的情况,就直接比较两个值的大小)

但是带权重的注意力汇聚在做梯度下降的时候会出现不平滑的现象(虽然loss确实下降了),因为在键值对很小的情况下很容易出现过拟合的情况。

Overfitting

可以理解为在一些有噪声的点附近,由于过拟合的存在(额外参数的影响)让这个点的权重变大了,导致heatmap出现了断点处的偏移。

这里引用一个非常好的回答:参数化的意义是什么

增加权重的意义在于使得对于远距离的key有着更小的关注,缩小的注意力的范围。

Adding Batches

为了充分发挥GPU的并行优势,我们可以使用批量矩阵的乘法

假设现在有$n$个$(a,b)$的二维矩阵$X_i(1\le i \le n)$和$n$个$(b,c)$的二维矩阵$Y_i(1\le i \le n)$,在非并行的状态下,我们需要做$n$次顺序矩阵乘法,即$X_i \times Y_i:=A_i$,最终得到$n$个$(a,c)$的矩阵。

并行的本质就是大家一起做运算,把$n$个$(a,b)$的二维矩阵$X_i(1\le i \le n)$可以定义为三维张量$T$,size是$(n,a,b)$。

因此,**假定两个张量的形状分别是$(n,a,b)$和$(n,b,c)$,它们的批量矩阵乘法输出的形状为$(n,a,c)$**。

如何使用小批量乘法来改写注意力机制?

计算加权平均值的过程可以看做是两个向量的内积操作(需要归一化),因此假设每一次批量是size为$n$,那么:

$$\text{Weight(n, 1, len)} \times \text{Value(n, len, 1)} = \text{Score(n,1)}$$

$n$在实际情况下可以定义为查询的个数,$len$是键值对的个数。

很惊讶的发现在这里$len$和$n$是无关的量,即我们可以不断的scale up键值对的个数,即获得更多的采样。

批量矩阵乘法在这里可以计算小批量数据的加权平均值

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
# orginal weights and values
weights = torch.ones((2,10)) * 0.1
values = torch.arange(20.0).reshape((2,10))
# weight: torch.Size([2, 10])
# values: torch.Size([2, 10])

# compute in sequential order
all_answer = torch.Tensor()
for i in range(0,weights.shape[0]):
weights_single = weights[i]
values_single = values[i]
answer = weights_single.dot(values_single.T).unsqueeze(-1)
all_answer = torch.cat([all_answer, answer])

# custom the size
all_answer1 = all_answer.unsqueeze(-1).unsqueeze(-1)

# using torch.bmm
weights = weights.unsqueeze(1)
# weights: [2,1,10]
values = values.unsqueeze(-1)
# values: [2,10,1]
all_answer2 = torch.bmm(weights, values)

assert torch.allclose(all_answer1, all_answer2)

Code: Nadaraya-Waston Kernel

在这个部分,我们将要实现最基本的注意力汇聚算法来尝试拟合一个函数,数学原理如上文所呈现:

不带参数的版本:

$$\hat{y} = f(x) = \sum_{i = 1}^{n} \text{softmax}(-\frac{1}{2}(x_q - x_i)^2) y_i$$

带参数的版本:

$$\hat{y} = f(x) = \sum_{i = 1}^{n} \text{softmax}(-\frac{1}{2}((x_q - x_i)\omega_i)^2) y_i$$

注意,这里和书上的公式有点不太一样,书上只引入了一个标量参数$w$,导致非常容易出现过拟合的现象,这里的可学习参数$\vec{\mathbf{w}} = \{w_1, w_2, w_3,\dots,w_n\}$,其中$n$是预先定义好的键值对的个数

代码如下,主要的核心逻辑是:

  • 生成数据集,我们需要拟合的函数是$f(x) = 2\sin(x) + 0.4\sin(3x) + 0.6\sin(6x) + \sqrt{x}$
  • 首先直接根据不带参数的版本,即没有任何的可训练参数,直接生成test的结果。
  • 接下来使用带参数的版本进行梯度下降训练,得到第二个拟合结果。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
import torch.nn as nn
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import os
import numpy as np

os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"


def show_heatmaps(
matrices,
xlabel="Keys",
ylabel="Queries",
titles=None,
figsize=(6, 6),
cmap="Reds",
save_path=None,
):
"""
Display heatmaps for attention weights or other matrices.

Args:
matrices: Input tensor or array of shape (num_rows, num_cols, height, width)
xlabel: Label for x-axis
ylabel: Label for y-axis
titles: List of titles for subplots
figsize: Size of the figure
cmap: Color map
save_path: Path to save the figure (None for not saving)
show: Whether to display the figure
"""
# Convert PyTorch tensor to numpy array if needed
if hasattr(matrices, "detach"):
matrices = matrices.detach().cpu().numpy()

num_rows, num_cols = matrices.shape[0], matrices.shape[1]

# Create figure and axes
fig, axes = plt.subplots(
num_rows, num_cols, figsize=figsize, sharex=True, sharey=True, squeeze=False
)

# Plot each matrix
for i in range(num_rows):
for j in range(num_cols):
pcm = axes[i, j].imshow(matrices[i, j], cmap=cmap)

# Add labels only to the bottom row and leftmost column
if i == num_rows - 1:
axes[i, j].set_xlabel(xlabel)
if j == 0:
axes[i, j].set_ylabel(ylabel)

# Add titles if provided
if titles is not None:
axes[i, j].set_title(titles[j])

fig.savefig(save_path, dpi=300, bbox_inches="tight")

# Close the figure to prevent memory leaks
plt.close()


def plot_kernel_reg(
y_hat, x_test, y_truth, x_train, y_train, save_path="img/kernel_regression.png"
):
"""
Plot kernel regression results and save the figure.

Args:
y_hat: Predicted values (tensor or array)
x_test: Test input values
y_truth: Ground truth values for test inputs
x_train: Training input values
y_train: Training target values
save_path: Path to save the figure (default: "img/kernel_regression.png")
"""
# Convert tensors to numpy if needed
if hasattr(y_hat, "detach"):
y_hat = y_hat.detach().cpu().numpy()
if hasattr(y_truth, "detach"):
y_truth = y_truth.detach().cpu().numpy()
if hasattr(x_test, "detach"):
x_test = x_test.detach().cpu().numpy()
if hasattr(x_train, "detach"):
x_train = x_train.detach().cpu().numpy()
if hasattr(y_train, "detach"):
y_train = y_train.detach().cpu().numpy()

# Create figure
plt.figure(figsize=(10, 6))

# Plot truth and prediction lines
plt.plot(x_test, y_truth, label="Truth")
plt.plot(x_test, y_hat, label="Pred")

# Plot training data points
plt.scatter(x_train, y_train, marker="o", alpha=0.5, s=1, label="Training data")

# Add labels and legend
plt.xlabel("x")
plt.ylabel("y")
plt.title("Kernel Regression Results")
plt.legend()

# Save figure with high quality
plt.savefig(save_path, dpi=300, bbox_inches="tight")
plt.close()


def generate_datasets(width: float, n_train: int, n_test: int):
"""Generate random datasets for NWKernel regression."""
x_train = torch.sort(torch.rand(n_train) * width).values

def target_function(x):
return (
2 * torch.sin(x) + 0.4 * torch.sin(3 * x) + 0.6 * torch.sin(6 * x) + x**0.5
)

y_train = target_function(x_train) + torch.normal(0.0, 0.5, (n_train,))
x_test = torch.linspace(0, width, n_test)
y_truth = target_function(x_test)

print(f"Generated datasets - Train: {n_train}, Test: {n_test}")
return x_train, y_train, x_test, y_truth


class NWKernelRegression(nn.Module):
def __init__(self, x_train, y_train):
super().__init__()
self.register_buffer("x_train", x_train)
self.register_buffer("y_train", y_train)
# We use one weight per training example
self.w = nn.Parameter(torch.ones(len(x_train)), requires_grad=True)

def forward(self, queries):
queries = queries.unsqueeze(1) # [n_queries, 1]
keys = self.x_train.unsqueeze(0) # [1, n_train]
diff = queries - keys # [n_queries, n_train]
self.attention_weights = F.softmax(-((diff * self.w) ** 2) / 2, dim=1)
return torch.matmul(self.attention_weights, self.y_train)


def visualize_attention(
net, x_test, x_train, num_points=3, save_path="img/visualize_attention.png"
):
"""Visualize attention weights for a few test points."""

idxs = torch.linspace(0, len(x_test) - 1, num_points).long()
queries = x_test[idxs]
keys = x_train
w_cpu = net.w.detach().cpu()
keys_cpu = keys.detach().cpu() if hasattr(keys, "detach") else torch.tensor(keys)
with torch.no_grad():
for i, query in enumerate(queries):
query_cpu = (
query.detach().cpu()
if hasattr(query, "detach")
else torch.tensor(query)
)
diff = query_cpu - keys_cpu
attn = torch.softmax(-(diff * w_cpu).pow(2) / 2, dim=0)
plt.figure()
plt.title(f"Attention for test x={query_cpu.item():.2f}")
plt.plot(keys_cpu, attn.numpy(), "o-")
plt.xlabel("x_train")
plt.ylabel("Attention weight")
plt.savefig(save_path)
plt.close()


def visualize_kernel_shape(net, x_train, save_path="img/kernelshape_visualize.png"):
"""Visualize the learned kernel shape centered at a point."""
import matplotlib.pyplot as plt

center = x_train[len(x_train) // 2]
diffs = torch.linspace(-5, 5, 100)
w_mean_cpu = net.w.mean().detach().cpu()
with torch.no_grad():
attn = torch.softmax(-(diffs * w_mean_cpu).pow(2) / 2, dim=0)
plt.figure()
plt.title("Learned Kernel Shape")
plt.plot(diffs.numpy(), attn.numpy())
plt.xlabel("x - center")
plt.ylabel("Kernel value")
plt.savefig(save_path)
plt.close()


def visualize_training_process(
record_epoch, record_loss, save_path="img/visualize_training_process.png"
):
plt.figure()
plt.title("Training Process")
plt.plot(record_epoch, record_loss)
plt.xlabel("epoches")
plt.ylabel("Training Loss")
plt.savefig(save_path)
plt.close()


def train(width, epochs, n_train, n_test, x_train, y_train, x_test, y_truth):
# Move data to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
x_train = x_train.to(device)
y_train = y_train.to(device)
x_test = x_test.to(device)
y_truth = y_truth.to(device)

# Initialize model and optimizer
net = NWKernelRegression(x_train, y_train).to(device)

if torch.cuda.device_count() > 1:
print(f"Using {torch.cuda.device_count()} GPUs for DataParallel.")
net = nn.DataParallel(net)

criterion = nn.MSELoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.5)
record_loss = []
record_epochs = []

plot_epochs = {
10,
100,
1000,
5000,
10000,
20000,
50000,
60000,
70000,
80000,
100000,
110000,
120000,
}

for epoch in range(epochs):
optimizer.zero_grad()
y_pred = net(x_train)
loss = criterion(y_pred, y_train)
loss.backward()
optimizer.step()

record_loss.append(loss.item())
record_epochs.append(epoch)

if epoch % 10 == 0:
print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

if (epoch + 1) in plot_epochs:
with torch.no_grad():
y_hat = (
net.module(x_test)
if isinstance(net, nn.DataParallel)
else net(x_test)
)
plot_kernel_reg(
y_hat.cpu(),
x_test.cpu(),
y_truth.cpu(),
x_train.cpu(),
y_train.cpu(),
save_path=f"img/kernel_regression_epoch{epoch+1}.png",
)

visualize_training_process(record_epochs, record_loss)

# Testing
with torch.no_grad():
y_hat = net.module(x_test) if isinstance(net, nn.DataParallel) else net(x_test)

model_for_vis = net.module if isinstance(net, nn.DataParallel) else net
plot_kernel_reg(
y_hat.cpu(), x_test.cpu(), y_truth.cpu(), x_train.cpu(), y_train.cpu()
)
visualize_attention(model_for_vis, x_test.cpu(), x_train.cpu())
visualize_kernel_shape(model_for_vis, x_train.cpu())

show_heatmaps(
model_for_vis.attention_weights.unsqueeze(0).unsqueeze(0).cpu(),
xlabel="training inputs",
ylabel="testing inputs",
titles="HeatMaps for the final attention",
save_path="img/heatmap_params.png"
)


def singleNWKernel(width, n_train, n_test, x_train, y_train, x_test, y_truth):
sigma = 1.0 # fixed kernel size
x_train_cpu = x_train.cpu()
y_train_cpu = y_train.cpu()
x_test_cpu = x_test.cpu()
y_truth_cpu = y_truth.cpu()

# compute attention weights
queries = x_test_cpu.unsqueeze(1) # [n_test, 1]
keys = x_train_cpu.unsqueeze(0) # [1, n_train]
diff = queries - keys # [n_test, n_train]
attn = torch.softmax(-((diff / sigma) ** 2) / 2, dim=1) # 固定sigma
y_hat = torch.matmul(attn, y_train_cpu)

plot_kernel_reg(
y_hat,
x_test_cpu,
y_truth_cpu,
x_train_cpu,
y_train_cpu,
save_path="img/kernel_regression_noparam.png",
)

def visualize_attention_noparam(
x_test,
x_train,
attn,
num_points=3,
save_path="img/visualize_attention_noparam.png",
):
idxs = torch.linspace(0, len(x_test) - 1, num_points).long()
for i, idx in enumerate(idxs):
plt.figure()
plt.title(f"Attention for test x={x_test[idx].item():.2f}")
plt.plot(x_train, attn[idx].numpy(), "o-")
plt.xlabel("x_train")
plt.ylabel("Attention weight")
plt.savefig(f"img/visualize_attention_noparam_{i}.png")
plt.close()

visualize_attention_noparam(x_test_cpu, x_train_cpu, attn)

def visualize_kernel_shape_noparam(
sigma, save_path="img/kernelshape_visualize_noparam.png"
):
diffs = torch.linspace(-5, 5, 100)
attn = torch.softmax(-(diffs / sigma).pow(2) / 2, dim=0)
plt.figure()
plt.title("Fixed Kernel Shape")
plt.plot(diffs.numpy(), attn.numpy())
plt.xlabel("x - center")
plt.ylabel("Kernel value")
plt.savefig(save_path)
plt.close()

visualize_kernel_shape_noparam(sigma)

show_heatmaps(
attn.unsqueeze(0).unsqueeze(0),
xlabel="training inputs",
ylabel="testing inputs",
titles="HeatMaps for the final attention (no param)",
save_path="img/heatmap_noparam.png",
)


if __name__ == "__main__":
# Parameters
width = 20.0
n_train = 6000
n_test = 6000
epochs = 120001

# Run tests
x_train, y_train, x_test, y_truth = generate_datasets(width, n_train, n_test)
# Train and evaluate
print("For models with no parameters")
singleNWKernel(
width=width,
n_train=n_train,
n_test=n_test,
x_train=x_train,
y_train=y_train,
x_test=x_test,
y_truth=y_truth
)

print("For models with parameters")
train(width, epochs, n_train, n_test, x_train, y_train, x_test, y_truth)

实验结果分析

首先来看不带参数的版本:

Heatmap for no parameters

Regression results

image
image
image

因为没有参数,所以图像都非常的平滑!(这其实也是因为数据点采样的均匀性)但是从回归图像可以看出拟合的效果并不好,尤其在有噪声的情况下。

再来看带参数的版本:

下面三张图片分别是训练5000,50000,120000个epoch之后的版本。

其实这里也还是存在过拟合的问题,函数的参数过小,在后面的epoch发生了梯度消失的现象。

image
image
image

从Kernel Shape也可以看出,参数的引入是的距离效应显得更显著,即距离近的点的权重更大,而距离远的点的权重变得更小

Learned Kernel shape

如果看热力图可以发现,参数的引进导致图像变的模糊,这也是数据采样点的效果。

HeatMap for training

注意力评分函数

上文的公式:

$$\alpha(x_q,x_i) = \frac{\exp(-\frac{1}{2}(x_q - x_i)^2)}{\sum_{j = 1}^{n}\exp(-\frac{1}{2}(x_q - x_j)^2)} = \text{softmax}(-\frac{1}{2}(x_q - x_i)^2)$$

可以作为一个注意力评分函数,用来衡量查询值$x_q$和单个键值$x_i$的相关性。

我们可以推广到更加高维的空间中:

考虑查询$\vec{\mathbf{q}} \in \mathbb{R}^q$(在$q$维空间中),我们有$m$个键值对:$(\vec{\mathbf{k_1}}, \vec{\mathbf{v_1}})$, $(\vec{\mathbf{k_2}}, \vec{\mathbf{v_2}})$, … , $(\vec{\mathbf{k_m}}, \vec{\mathbf{v_m}})$, 其中$\vec{\mathbf{k_i}} \in \mathbb{R}^k$, $\vec{\mathbf{v_i}} \in \mathbb{R}^v$。

那么我们可以做如下的推广计算每一个查询的权重

$$f(\mathbf{q}, (\mathbf{k_1}, \mathbf{v_1}), \ldots, (\mathbf{k_m}, \mathbf{v_m})) = \sum_{i=1}^m w_i \mathbf{v}_i \in \mathbb{R}^v$$

$$f(\mathbf{q}, (\mathbf{k_1}, \mathbf{v_1}), \ldots, (\mathbf{k_m}, \mathbf{v_m})) = \sum_{i=1}^m \alpha(\mathbf{q}, \mathbf{k_i}) \mathbf{v_i} \in \mathbb{R}^v$$

$$\alpha(\mathbf{q}, \mathbf{k_i}) = \mathrm{softmax}(a(\mathbf{q}, \mathbf{k_i})) = \frac{\exp(a(\mathbf{q}, \mathbf{k_i}))}{\sum_{j=1}^m \exp(a(\mathbf{q}, \mathbf{k_j}))} \in \mathbb{R}$$

注意这里的$w_i$是一个标量,也是$\alpha(\mathbf{q}, \mathbf{k}_i)$的输出。

显然,在更加general的任务上,$q,k,v$的三个维度指标往往不相同。这就导致了维度不匹配的情况,单纯的NWKernel公式需要在维度对齐的情况下进行运算(比较两个向量的余弦相似度),因此后人工作的重点就是如何设计更好的注意力评分函数$\alpha(\mathbf{q}, \mathbf{k}_i)$。下文介绍两种经典的算法:加性注意力缩放点积注意力,以及Bahdanau注意力。


AINN Attention
https://xiyuanyang-code.github.io/posts/AINN-Attention/
Author
Xiyuan Yang
Posted on
April 29, 2025
Updated on
May 19, 2025
Licensed under