{
"cells": [
{
"cell_type": "markdown",
"id": "90dbf366-7a84-4c76-bcd6-466feaa51a7b",
"metadata": {},
"source": [
"# Transformer Principles"
]
},
{
"cell_type": "markdown",
"id": "d5eef42a-f6f1-4756-a325-1e7fd59813a5",
"metadata": {},
"source": [
"**0 Preface**<br>\n",
"    0.1 The Transformer's Position and Development History<br>\n",
"    0.2 The Basic Idea and Fundamental Goal of Sequence Models<br>\n",
"\n",
"**1 The Attention Mechanism**<br>\n",
"    1.1 The Essence of the Attention Mechanism<br>\n",
"    1.2 The Self-Attention Computation Flow in the Transformer<br>\n",
"    1.3 Multi-Head Attention<br>\n",
"\n",
"**2 The Basic Structure of the Transformer**<br>\n",
"    2.1 The Embedding Layer and Positional Encoding<br>\n",
"    2.2 Anatomy of the Encoder<br>\n",
"        2.2.1 Residual Connections<br>\n",
"        2.2.2 Layer Normalization<br>\n",
"        2.2.3 Feed-Forward Networks<br>\n",
"    2.3 Anatomy of the Decoder<br>\n",
"        2.3.1 Data Flow in the Full Transformer vs. Decoder-Only Architectures<br>\n",
"        2.3.2 The Decoder in the Encoder-Decoder Architecture<br>\n",
"            2.3.2.1 Inputs and Teacher Forcing<br>\n",
"            2.3.2.2 Masked Attention<br>\n",
"            2.3.2.3 Standard Masks and Look-Ahead Masks<br>\n",
"            2.3.2.4 The Encoder-Decoder Attention Layer<br>\n",
"        2.3.3 The Decoder in Decoder-Only Architectures<br>"
]
},
{
"cell_type": "markdown",
"id": "db513437-c41a-4817-b69b-1b41be7d9e5e",
"metadata": {},
"source": [
"The Transformer, a major milestone in natural language processing (NLP), was proposed by Google researchers in 2017 and has become one of the most fundamentally influential architectures for handling text and language data in deep learning. In the NLP universe, if RNNs, LSTMs and their kin created the ability of \"sequential memory\", the Transformer completely overturned how that \"memory\" is handled: it abandons sequential processing and instead, through self-attention, gives the model an entirely new, parallelizable way to understand and process information. Starting from the intuitive idea of self-attention, the Transformer's designers introduced innovations such as multi-head attention and positional encoding, dramatically improving both the efficiency and the quality of sequence processing. Through careful mathematical construction and model design, the Transformer captures local detail and global context in a sequence simultaneously, resolving the long-range dependency problems of earlier models and greatly strengthening its handling of long text sequences.\n",
"\n",
"After several years of rapid iteration, the Transformer has not only refined its original architecture but also spawned a series of powerful successors such as BERT, GPT-3, RoBERTa and T5, which have achieved remarkable results across language understanding, text generation and many other NLP tasks. Much as the LSTM left a lasting mark on its own era, the Transformer paper and its principles have become classics of NLP, and its algorithm and architecture now form the foundation of modern language-data processing. Today, despite many more advanced algorithms and models, the Transformer remains the mainstream architecture for handling complex language patterns and capturing subtle semantics, and it has reshaped, along several dimensions, how we build and understand language models. Let us now explore the deep principles behind this epoch-making architecture."
]
},
{
"cell_type": "markdown",
"id": "5f7cc248-0d1a-413d-8837-58b55c3ec268",
"metadata": {},
"source": [
"## 0.1 The Transformer's Position and Development History"
]
},
{
"cell_type": "markdown",
"id": "b9e56e65-a9d1-4727-bc5c-0b87dfcf5c68",
"metadata": {},
"source": [
"Learning the Transformer is not learning a single algorithm; it is learning an entire attention-based ecosystem with the Transformer at its core. In NLP there is a famous tree diagram showing the development of Transformer-based language models from 2018 to 2023, classifying them along three dimensions: open source vs. closed source; encoder / decoder / encoder-decoder; and the developing company. This evolutionary tree neatly summarizes the lineage of Transformer-based models over those years; let us read through this history step by step.\n",
"\n",
""
]
},
{
"cell_type": "markdown",
"id": "f1848b75-6de4-48dd-862c-089cbfefc50e",
"metadata": {},
"source": [
"- 2018: Early exploration\n",
"> ELMo (Embeddings from Language Models) and ULMFiT (Universal Language Model Fine-tuning) were not based on the Transformer architecture, but they introduced deep bidirectional representations and transfer learning, ideas that strongly influenced later Transformer models.\n",
"\n",
"- 2019: BERT and its derivatives\n",
"> BERT (Bidirectional Encoder Representations from Transformers) was a major milestone: through self-attention it produces context-dependent word vectors. As the first large-scale bidirectional Transformer model, it profoundly influenced everything that followed.\n",
"> RoBERTa improved on BERT, raising performance with larger datasets and longer training.\n",
"> ALBERT aimed to shrink model size while preserving performance, via factorized embeddings and cross-layer parameter sharing.\n",
"\n",
"- 2020: Diversified models and architectures\n",
"> T5 (Text-to-Text Transfer Transformer) casts every text problem into a text-to-text format, so the same model can handle different tasks.\n",
"> BART (Bidirectional and Auto-Regressive Transformers) combines autoencoding and autoregressive traits, suiting it to sequence generation.\n",
"\n",
"- 2021: Specialization and efficiency\n",
"> ELECTRA improves performance through an adversarial-style training objective and efficiency optimizations.\n",
"> DeBERTa introduces an improved attention mechanism that better captures the relationships between words.\n",
"\n",
"- 2022: The rise of large language models\n",
"> GPT-3 (Generative Pre-trained Transformer 3), first released in 2020, became one of the largest language models of its era through its enormous parameter count and generative power.\n",
"> Switch-C (Switch Transformers) uses sparse activation, letting the model scale to very large sizes without a proportional increase in compute.\n",
"\n",
"- 2023: Advanced applications and refined models\n",
"> ChatGPT and InstructGPT are GPT-3-based models optimized for specific applications such as chat and instruction following.\n",
"> Chinchilla and Gopher represent a further step up in language understanding and generation.\n",
"> FLAN and mT5 target multi-task and multilingual settings, showing how these models have internationalized and diversified.\n",
"\n",
"The evolutionary tree also shows the lineage between models and whether each is open source. This constant iteration reflects the field's ongoing pursuit of understanding and generating human language: each generation improves on data-handling capacity, training methods, range of application and the ability to solve specific problems. The progress has not only pushed the boundaries of NLP but also offered valuable insights to other areas of artificial intelligence.\n",
"\n",
"When we set out to learn the Transformer, we are really embracing a vast, attention-centered body of knowledge that goes far beyond a single algorithm. The Transformer and its descendants are not merely NLP tools; they are a window through which we can observe the deep structure of language and the information flowing through it. With self-attention at its core, this framework challenges traditional notions of sequential processing and leads us to explore how machines can grasp the intricate relationships within text. From the basic architecture to advanced variants such as BERT and GPT-3, we will learn how these models capture the subtle links between words, maintain the global coherence of context, and turn that understanding into the ability to handle diverse tasks. Studying the Transformer is not just mastering a technique; it is exploring a still-evolving field that is pushing the frontiers of artificial intelligence and shaping the future. So let us begin this journey: not only to master an algorithm, but to understand the rich attention-based system behind it and to discover its possibilities at the intersection of language, thought and technology."
]
},
{
"cell_type": "markdown",
"id": "d202eeb6-3936-4814-a773-5b992b30ace0",
"metadata": {},
"source": [
"## 0.2 The Basic Idea and Fundamental Goal of Sequence Models"
]
},
{
"cell_type": "markdown",
"id": "bfe181d2-490c-487e-89f3-aa4a7eab6d09",
"metadata": {},
"source": [
"To understand the essence of the Transformer, we first return to the basic notions of sequence data and sequence models. Sequence data is data arranged in a specific order, and it is everywhere in the real world: historical stock prices, speech signals, text, video, and so on. Broadly, any data whose ordering matters and cannot be casually shuffled is sequence data. Its defining trait is that samples are related to one another. For time series, each sample is a time step, so the relationships between samples are relationships between points in time; for text, each sample is a character or word, so the relationships between samples are the semantic relationships between characters and words. Clearly, to understand the patterns of a time series, or the meaning of a complete sentence, one must understand the relationships between samples.\n",
"\n",
"With ordinary tabular data we mostly study the relationship between features and labels, but **in sequence data, much of the essential structure and underlying logic hides in the relationships between samples**, which is why generic machine learning and deep learning algorithms do not fit it, and why we need algorithms built specifically for sequences. In machine learning and deep learning, **the fundamental goal of a sequence algorithm is to establish the relationships between samples and, through those relationships, distill an understanding of the sequence**. Only by uncovering and modeling these inter-sample connections can a sequence model analyze, understand and predict sequence data.\n",
"\n",
"Machine learning and deep learning offer many classic, effective sequence models. They establish the relationships between samples as follows:\n",
"\n",
"- The ARIMA family\n",
"> The past influences the future, so a future value is formed as a weighted sum of past values; that weighted sum is what relates the samples.\n",
"\n",
"$$\\text{AR model:} \\quad y_t = c + w_1 y_{t-1} + w_2 y_{t-2} + \\dots + w_p y_{t-p} + \\varepsilon_t\n",
"$$"
]
},
{
"cell_type": "markdown",
"id": "91425e20-e6fd-4440-820a-6ff7ecef7b23",
"metadata": {},
"source": [
"- The recurrent-network family\n",
"> Traverse the time steps/samples, storing information from past time steps in an intermediate variable that is passed on to the next time step; this hand-off is what relates the samples.\n",
"\n",
"$$\\text{RNN:} \\quad h_t = W_{xh}\\cdot x_t + W_{hh}\\cdot h_{t-1}$$\n",
"\n",
"$$\\text{LSTM:} \\quad \\tilde{C}_t = \\tanh(W_{xc} \\cdot x_t + W_{hc} \\cdot h_{t-1} + b_c)$$"
]
},
{
"cell_type": "markdown",
"id": "93585a7c-b431-4963-8702-506947df6777",
"metadata": {},
"source": [
"- The convolutional-network family\n",
"> Scan the time steps/samples with a convolution kernel, merging contextual information through the convolution; this is what relates the samples. As shown below, the teal boxes carry weights $w$; the weights are multiplied element-wise with the corresponding sample values and summed into a scalar, a weighted-sum operation.\n",
"\n",
""
]
},
{
"cell_type": "markdown",
"id": "78a5abd9-3453-4320-8710-8485ba9e5bc2",
"metadata": {},
"source": [
"Summarizing the experience of these many sequence architectures, you will find that **the successful ones all use weighted sums to relate samples to one another**. Taking weighted sums over the values at different time steps/samples easily builds a \"composite representation of context\", and solving iteratively for the weights of that sum gives the algorithm an understanding of the sequence. That weighted summation is an effective way to relate samples has become close to a consensus across sequence-modeling research. **As sequence algorithms evolved, the core question shifted from \"how do we relate samples to each other\" to \"how do we weight the samples sensibly, i.e., how do we solve for the weights in the weighted sum\"**. On this question the Transformer gave one of the most complete answers the field has seen so far: **Attention is all you need; the best way to compute the weights is the attention mechanism**."
]
},
{
"cell_type": "markdown",
"id": "b6b1d67b-5fad-4fbc-95a4-c3fd2c9844ac",
"metadata": {},
"source": [
"## 1.1 The Essence of the Attention Mechanism"
]
},
{
"cell_type": "markdown",
"id": "51544e34-867c-481b-b198-30ba20025f30",
"metadata": {},
"source": [
"The attention mechanism is a computational procedure that helps an algorithm tell how important each piece of information is. By computing the correlations between samples, it judges **how important each sample is to the sequence as a whole** and **assigns each sample a weight that represents that importance**. This weight-assigning property is a perfect match for what sequence modeling has long pursued, and the Transformer exploits exactly this, using attention to compute the weights.\n",
"\n",
"> **Interview point**<br><br>\n",
"As a weight-computing mechanism, attention has several forms. Classic attention (Attention) computes sample correlations across sequences: it considers how important the samples of sequence A are to sequence B. This form is common in classic sequence-to-sequence (Seq2Seq) tasks such as machine translation, where we ask how much the samples of the source-language sequence influence the newly generated sequence, so we compute the importance of the source samples to the new sequence.\n",
"<br><br>In the Transformer, however, we use \"self-attention\" (Self-Attention), which computes correlations within a single sequence: it asks how important the samples of sequence A are to sequence A itself.\n",
"\n",
"Since the Transformer uses self-attention, our discussion will center on it, and step by step we will uncover what self-attention means for the Transformer and for sequence algorithms."
]
},
{
"cell_type": "markdown",
"id": "c1987c96-899b-4c98-a7c8-a3f4354ab6d0",
"metadata": {},
"source": [
"- First, **why judge the importance of the samples in a sequence at all? What does computing importance contribute to understanding a sequence?**\n",
"\n",
"Within sequence data, samples contribute unequally to \"understanding the sequence\": samples that help us grasp the meaning of the data matter more, while samples with little effect on its underlying logic or meaning matter less. Take text as an example:\n",
"\n",
"**<center>Although it <font color =\"green\">rained</font> today, I felt <font color =\"red\">very happy and excited</font> because of <font color =\"red\">_________</font>.</center>**\n",
"\n",
"**<center><font color =\"green\">__________</font>, but I felt <font color =\"red\">very happy and excited</font> because I got <font color =\"red\">the job offer of my dreams</font>.</center>**\n",
"\n",
"Each of the two sentences above has some key information removed. The first leaves us completely in the dark, while the second, despite the missing piece, still lets us follow what happened. The two sentences show plainly that different pieces of information matter differently for understanding a sentence."
]
},
{
"cell_type": "markdown",
"id": "237959f4-eab2-4351-8aa7-c89244cb5d2d",
"metadata": {},
"source": [
"The same holds in real deep-learning prediction tasks. Staying with this sentence:\n",
"\n",
"**<center>Although it <font color =\"green\">rained</font> today, I felt <font color =\"red\">very happy and excited</font> because I got <font color =\"red\">the job offer of my dreams</font>.</center>**\n",
"\n",
"Suppose a model performs sentiment analysis on it. The sentence's overall sentiment is clearly positive. Here \"rained\" contributes little to reading the emotional tone, whereas \"got the job offer of my dreams\" and \"felt very happy and excited\" are the keys to the positive emotion the sentence conveys. So if a sequence algorithm learns more from phrases like \"got the job offer of my dreams\" and \"felt very happy and excited\", it is more likely to read the sentiment correctly and to predict correctly.\n",
"\n",
"When attention analyzes such a sentence, self-attention may assign higher weights to words like \"happy\" and \"excited\", because they bear directly on the sentence's sentiment. For a long time, understanding long sequences was a notoriously hard problem in deep learning; researchers sought a way out through memory, efficiency, information filtering and more, and attention takes the \"efficiency\" route. **If we can tell which samples in a sequence are important and which are incidental, we can steer the algorithm toward the important ones, potentially improving both the model's efficiency and its understanding**."
]
},
{
"cell_type": "markdown",
"id": "20a4906a-2812-46ad-b3e7-51cb48dc277c",
"metadata": {},
"source": [
"- Second, **how is a sample's importance defined, and why?**\n",
"\n",
"Self-attention judges a sample's importance by **computing its correlation with the other samples**: within a sequence, if a sample is highly correlated with many other samples, it very likely has a major influence on the sequence as a whole. For example, consider the following text:\n",
"\n",
"**<center>At the meeting, the manager announced a major corporate <font color =\"red\">______</font> plan; employees reacted in different ways, but all were hopeful about the future.</center>**\n",
"\n",
"Here the removed word is highly correlated with words such as \"corporate\", \"plan\", \"meeting\", \"announced\" and \"future\". Ask questions around those words and you will see:\n",
"\n",
"What did the **company** do?<br>\n",
"What was **announced**?<br>\n",
"What is the **plan**?<br>\n",
"What will happen in the **future**?<br>\n",
"What was the **meeting** mainly about?\n",
"\n",
"Every one of these answers revolves around the removed word. The complete sentence is:\n",
"\n",
"**<center>At the meeting, the manager announced a major corporate <font color =\"red\">restructuring</font> plan; employees reacted in different ways, but all were hopeful about the future.</center>**\n",
"\n",
"The removed word is **restructuring**. Clearly this word not only signals the nature of the event and anchors the whole sentence; it also strongly shapes how the other words are understood. It is essential for grasping both the event (the company is undergoing a major change) and the employees' emotional response. Without \"restructuring\" the sentence turns vague and indeterminate: it is no longer clear what was announced or what the \"hopes for the future\" refer to. So \"restructuring\" plainly has a major effect on understanding the whole sentence, and it is also highly correlated with the other words in it.\n",
"\n",
"By contrast, suppose we instead remove:\n",
"\n",
"**<center>At the meeting, the manager announced a major corporate <font color =\"red\">restructuring</font> plan; ______ reacted in different ways, but all were hopeful about the future.</center>**\n",
"\n",
"You will find that although some information is missing, it barely affects our understanding of the sentence, and we can even roughly infer the missing part. This pattern generalizes to much sequence data: **a sample that is highly correlated with the other samples is very likely to matter greatly for understanding the sequence as a whole, so the correlations between samples can be used to measure a sample's importance to the sequence**."
]
},
{
"cell_type": "markdown",
"id": "2e73db0e-9032-4ea5-9b6d-5d1cc19f41d7",
"metadata": {},
"source": [
"- Third, **how exactly is a sample's importance (i.e., its correlation with the other samples) computed?**"
]
},
{
"cell_type": "markdown",
"id": "ed8a88d4-6b6e-4319-8d0a-0016064adc5e",
"metadata": {},
"source": [
"In NLP, every sample in a sequence is encoded as a vector: encoded text becomes word vectors, while time-series data becomes time-step vectors.\n",
"\n",
"Computing the correlation between samples therefore amounts to computing the correlation between vectors. **The correlation of two vectors can be measured by their dot product.** If two vectors point in exactly the same direction (an angle of 0°), their dot product is maximal, indicating perfect positive correlation; if they point in exactly opposite directions (180°), the dot product is maximally negative, indicating perfect negative correlation; and if they are perpendicular (90°), the dot product is zero, indicating no correlation. So the larger the absolute value of the dot product, the stronger the correlation between the two vectors; the closer it is to 0, the weaker the correlation.\n",
"\n",
""
]
},
{
"cell_type": "markdown",
"id": "2c0af808-75f2-41d1-bb0b-804387fbdf0d",
"metadata": {},
"source": [
"The dot product is simply the product of the two vectors. For two three-dimensional vectors $\\mathbf{A}$ and $\\mathbf{B}$, it can be written out as:"
]
},
{
"cell_type": "markdown",
"id": "b06c7ae5-7f41-41d1-b593-eb001c49dcf7",
"metadata": {},
"source": [
"$$\n",
"\\mathbf{A} \\cdot \\mathbf{B}^T = \\begin{pmatrix}\n",
"a_1, a_2, a_3\n",
"\\end{pmatrix} \\cdot\n",
"\\begin{pmatrix}\n",
"b_1 \\\\\n",
"b_2 \\\\\n",
"b_3\n",
"\\end{pmatrix} = a_1 \\cdot b_1 + a_2 \\cdot b_2 + a_3 \\cdot b_3\n",
"$$\n",
"\n",
"The shapes multiply as (1,3) × (3,1) = (1,1), producing a scalar."
]
},
{
"cell_type": "markdown",
"id": "e3b78b7d-7018-4ead-8eea-3a5b003a59b2",
"metadata": {},
"source": [
"In NLP, the word-vector or time-series data we receive always holds multiple samples. We need **the pairwise correlations between samples**; only by combining these correlation scores can we compute a sample's importance to the whole sequence. Note that in NLP this computation, i.e., the correlation between vectors, is affected by the order in which it is taken. **That is, computing correlation with one word as the anchor versus another word as the anchor yields different degrees of relatedness: the correlation between vectors is order-dependent.** For example:\n",
"\n",
"Suppose we have the sentence: **I love kittens.**\n",
"\n",
"> - If we take \"I\" as the anchor word and compute its correlation with the other words in the sentence, then \"love\" and \"kittens\" are both very important in this context: \"love\" tells us what \"I\" feels about \"kittens\", and \"kittens\" is the object of that feeling. Here \"love\" and \"kittens\" are highly relevant to \"I\".\n",
"\n",
"> - But if we take \"kittens\" as the anchor word and compute its correlation with the other words in the sentence, \"I\" matters far less: whoever loves the kittens, the kittens themselves do not change. Here \"I\" carries relatively little contextual importance for \"kittens\".\n",
"\n",
"With longer context, this trait becomes even more pronounced:\n",
"\n",
"> - I love kittens, but Mom does not like kittens.\n",
"\n",
"Now, for the word \"kittens\", who likes them matters a great deal.\n",
"\n",
"> - I love kittens; kittens are so soft.\n",
"\n",
"Now, for the word \"kittens\", who exactly loves them matters much less; the point is that they are loved for being soft.\n",
"\n",
"So if the data holds two samples A and B, we must compute four correlations: AB, AA, BA and BB. In each computation the anchor word is taken to be \"asking\" (the Query), and the non-anchor word to be \"answering\" (the Key); the AB correlation is the result of A asking and B answering, and the AA correlation is the result of A asking itself and answering itself.\n",
"\n",
"All of this can be done with matrix multiplication. Suppose our data holds 2 samples (A and B), each encoded as a word vector with 4 features. As shown below, to compute the correlations between vectors A and B we need only take the dot product of the feature matrix with its own transpose:"
]
},
{
"cell_type": "markdown",
"id": "2e41e0f7-d265-48b6-8ba6-ecbbbb51e72a",
"metadata": {},
"source": [
""
]
},
{
"cell_type": "markdown",
"id": "3aac4c51-1aeb-4dce-ab3b-423df48cb3ff",
"metadata": {},
"source": [
"The product above yields the matrix:\n",
"\n",
"$$\\begin{bmatrix}\n",
"r_{AA} & r_{AB} \\\\\n",
"r_{BA} & r_{BB} \n",
"\\end{bmatrix}$$"
]
},
{
"cell_type": "markdown",
"id": "4ce9d240-c861-451c-906b-b6815d981755",
"metadata": {},
"source": [
"This rule extends to data of any size: a sequence of 3 samples times its own transpose yields a 3×3 correlation matrix, and a sequence of n samples times its own transpose yields an n×n correlation matrix. These matrices hold **the correlation of every sample with every other sample in the sequence**; the number of coefficients and the shape of the matrix depend only on the number of samples, never on the feature dimension. So for any data, multiplying it by its own transpose directly produces the correlation matrix of **every sample against every other sample in that sequence**.\n",
"\n",
""
]
},
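{
"cell_type": "markdown",
"id": "f0e1d2c3-0001-4aaa-8aaa-111111111111",
"metadata": {},
"source": [
"To make this concrete, here is a minimal NumPy sketch of the idea: a toy feature matrix (its values are made up purely for illustration) multiplied by its own transpose. The point is only that a sequence of n samples yields an n-by-n matrix of pairwise dot products, no matter how many features each sample has."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f0e1d2c3-0002-4aaa-8aaa-111111111111",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# A toy sequence: 3 samples (rows), each encoded as a 4-dimensional vector\n",
"X = np.array([[ 1.0, 0.5, -1.0,  2.0],\n",
"              [ 0.8, 0.4, -0.9,  1.8],\n",
"              [-2.0, 1.0,  0.0, -0.5]])\n",
"\n",
"# Pairwise dot products: entry [i, j] is sample i's dot product with sample j\n",
"R = X @ X.T\n",
"\n",
"print(R.shape)  # (3, 3) -- depends only on the number of samples\n",
"print(R)        # large |R[i, j]| suggests strongly related samples"
]
},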
{
"cell_type": "markdown",
"id": "92b27875-074d-4aa3-a01f-4c403e1098a7",
"metadata": {},
"source": [
"Of course, when computing correlations in practice we generally do not multiply the raw feature matrix by its transpose directly, **because what we want is semantic correlation, not merely numerical correlation**. So when attention is used in NLP, **we first multiply the raw feature matrix by semantics-interpreting parameter matrices $w$, generating the query matrix Q, the key matrix K, and whatever other matrices may be useful**.\n",
"\n",
"In actual computation, $w$ holds neural-network parameters obtained by iteration, so $w$ keeps reinterpreting the raw feature matrix semantically as the loss function demands, while the actual correlation computation runs between Q and K. Solving for the correlation scores from Q and K is the core procedure of self-attention, shown below ↓\n",
"\n",
""
]
},
{
"cell_type": "markdown",
"id": "fecf3d55-e9bd-426d-a792-6ace51d439dc",
"metadata": {},
"source": [
"At this point we have laid out the substance of the self-attention mechanism."
]
},
{
"cell_type": "markdown",
"id": "43b9acd1-d216-4385-9930-47a76f3b6dd6",
"metadata": {},
"source": [
"## 1.2 The Self-Attention Computation Flow in the Transformer"
]
},
{
"cell_type": "markdown",
"id": "5ebf4cf3-f7a9-49b2-bd43-fa4bc5ce0a4a",
"metadata": {},
"source": [
"Now that we know how attention runs, how exactly does the Transformer use self-attention to weight the samples? Consider the following flow.\n",
"\n",
"**Step 1: Obtain the Q and K matrices from the word vectors**"
]
},
{
"cell_type": "markdown",
"id": "4c9fab64-b32c-44d7-9a50-da0805c0575d",
"metadata": {},
"source": [
"First, the correlation computed inside the Transformer is called the **attention score**, a new formulation obtained by modifying the original attention computation. Its formula is:"
]
},
{
"cell_type": "markdown",
"id": "c89054cb-6fa7-43ff-ab74-1c262aec5e08",
"metadata": {},
"source": [
"$$Attention(Q,K,V) = softmax(\\frac{QK^{T}}{\\sqrt{d_k}})V$$"
]
},
{
"cell_type": "markdown",
"id": "79275bc4-fd44-49be-9eb8-9641fe633b11",
"metadata": {},
"source": [
"In this formula we first transform the raw feature matrix into Q and K, then multiply Q by the transpose of K to obtain the basic correlation scores. Since, once the weights are computed, they must also be applied back onto the samples to form the \"composite representation of context\", we additionally derive a matrix V from the raw feature matrix to represent the information values the original features carry. Suppose we have 4 words, each encoded as a 6-column word vector; Q, K and V are computed as follows:"
]
},
{
"cell_type": "markdown",
"id": "4591b7a4-4e0e-4c42-87ee-1a97f6954473",
"metadata": {},
"source": [
""
]
},
{
"cell_type": "markdown",
"id": "4f2b57b9-ecde-4c20-b97b-7fab967186b1",
"metadata": {},
"source": [
"Here $W_Q$ and $W_K$ both have shape (6,3). In fact, all we need is for these parameter matrices to be multipliable with $X$ (that is, their number of rows must match the number of columns $X$ was encoded into). In most modern applications, $W_Q$ and $W_K$ are square.\n",
"\n",
"**Step 2: Compute the $QK$ similarity to obtain the correlation matrix**"
]
},
{
"cell_type": "markdown",
"id": "cf0b8135-1426-4534-88ad-34182403f042",
"metadata": {},
"source": [
"Next we multiply Q by the transpose of K to compute the correlation matrix.\n",
"\n",
"$$Attention(Q,K,V) = softmax(\\frac{QK^{T}}{\\sqrt{d_k}})V$$\n",
"\n",
"Inside $QK^{T}$, each dot product is a multiply-then-add: the higher the word-vector dimension, the more terms are added and the larger the dot product grows. The dimension therefore biases the correlation score; for two sequences that are equally related in reality, higher-dimensional vectors tend to produce huge scores, so the scores need standardizing. The Transformer standardizes the correlation matrix by dividing by $\\sqrt{d_k}$, where $d_k$ is the dimension of the query/key vectors; in the example above, where the (6,3) parameter matrices project the embeddings down to 3 columns, $d_k = 3$."
]
},
{
"cell_type": "markdown",
"id": "a32c4e92-2bc2-46a8-a47a-fe9345aeed96",
"metadata": {},
"source": [
""
]
},
{
"cell_type": "markdown",
"id": "e677a541-7b31-4355-89f2-46a46a7f392d",
"metadata": {},
"source": [
"**Step 3: Normalize with softmax**\n",
"\n",
"Softmax converts each word's correlation vector into a probability distribution over [0,1]. For two samples A and B we obtain the four correlations AA, AB, BB and BA; after the softmax transform, AA+AB sums to 1 and BB+BA sums to 1. Making each sample's correlations sum to 1 turns the scores into proportions in [0,1] that behave much more like \"weights\", and it also bounds the overall magnitude of the scores, avoiding numerical blow-up.\n",
"\n",
"The softmax-normalized scores are the **weights** solved for by the attention mechanism."
]
},
{
"cell_type": "markdown",
"id": "20287472-a299-44a2-92d0-3a1a5a50c40e",
"metadata": {},
"source": [
"**Step 4: Take the weighted sum of the samples, building the relationships between samples**"
]
},
{
"cell_type": "markdown",
"id": "e84a81cf-7d23-4481-9a7e-275d2327aed5",
"metadata": {},
"source": [
""
]
},
{
"cell_type": "markdown",
"id": "53e32837-f0d9-4094-b938-ebd57a3f9b68",
"metadata": {},
"source": [
"We now have the post-softmax score matrix, along with the matrix V representing the values of the raw feature matrix:"
]
},
{
"cell_type": "markdown",
"id": "7966d731-6e83-447d-8e21-eb9cb8bdd817",
"metadata": {},
"source": [
"$$\n",
"\\mathbf{r} = \\begin{pmatrix}\n",
"a_{11} & a_{12} \\\\\n",
"a_{21} & a_{22}\n",
"\\end{pmatrix},\n",
"\\quad\n",
"\\mathbf{V} = \\begin{pmatrix}\n",
"v_{11} & v_{12} & v_{13} \\\\\n",
"v_{21} & v_{22} & v_{23}\n",
"\\end{pmatrix}\n",
"$$\n",
"Multiplying the two gives:"
]
},
{
"cell_type": "markdown",
"id": "fb4231d2-1af1-4e1c-80a8-4fcb65489715",
"metadata": {},
"source": [
"$$\n",
"\\mathbf{Z(Attention)} = \\begin{pmatrix}\n",
"a_{11} & a_{12} \\\\\n",
"a_{21} & a_{22}\n",
"\\end{pmatrix}\n",
"\\begin{pmatrix}\n",
"v_{11} & v_{12} & v_{13} \\\\\n",
"v_{21} & v_{22} & v_{23}\n",
"\\end{pmatrix}\n",
"= \\begin{pmatrix}\n",
"(a_{11}v_{11} + a_{12}v_{21}) & (a_{11}v_{12} + a_{12}v_{22}) & (a_{11}v_{13} + a_{12}v_{23}) \\\\\n",
"(a_{21}v_{11} + a_{22}v_{21}) & (a_{21}v_{12} + a_{22}v_{22}) & (a_{21}v_{13} + a_{22}v_{23})\n",
"\\end{pmatrix}\n",
"$$"
]
},
{
"cell_type": "markdown",
"id": "13f70f41-6666-42df-8446-7fd3202bbb6a",
"metadata": {},
"source": [
"Look at the result: the expression $a_{11}v_{11} + a_{12}v_{21}$ is exactly a weighted sum of $v_{11}$ and $v_{21}$. Those entries correspond to the first feature of the first sample and the first feature of the second sample in the raw feature matrix, so the link this weighted sum builds between the two v's is precisely a link between two samples, between two time steps.\n",
"\n",
"Note that in this computation the column subscript plays no role in the weighting. Because the attention-score matrix is independent of the number of features, when multiplying by the matrix $v$ the matrix $r$ does not care how many $v$'s sit in a row, only which row each $v$ comes from. We can therefore write the attention loosely, with $v_1$ and $v_2$ standing for \"row 1\" and \"row 2\":\n",
"\n",
"$$\n",
"\\mathbf{Z(Attention)} = \\begin{pmatrix}\n",
"a_{11} & a_{12} \\\\\n",
"a_{21} & a_{22}\n",
"\\end{pmatrix}\n",
"\\begin{pmatrix}\n",
"v_{11} & v_{12} & v_{13} \\\\\n",
"v_{21} & v_{22} & v_{23}\n",
"\\end{pmatrix}\n",
"= \\begin{pmatrix}\n",
"(a_{11}v_{1} + a_{12}v_{2}) & (a_{11}v_{1} + a_{12}v_{2}) & (a_{11}v_{1} + a_{12}v_{2}) \\\\\n",
"(a_{21}v_{1} + a_{22}v_{2}) & (a_{21}v_{1} + a_{22}v_{2}) & (a_{21}v_{1} + a_{22}v_{2})\n",
"\\end{pmatrix}\n",
"$$\n",
"\n",
"Clearly, for the weight matrix $a$ it does not matter how many features the raw data has: it is always building the link between sample 1 and sample 2."
]
},
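{
"cell_type": "markdown",
"id": "f0e1d2c3-0003-4aaa-8aaa-111111111111",
"metadata": {},
"source": [
"The four steps above compress into a few lines of NumPy. This is a minimal sketch of scaled dot-product attention under toy assumptions: random weight matrices stand in for the learned $W_Q$, $W_K$, $W_V$, and a random matrix stands in for the embedded input. It is meant only to mirror the formula, not any particular library's implementation."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f0e1d2c3-0004-4aaa-8aaa-111111111111",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"rng = np.random.default_rng(0)\n",
"\n",
"seq_len, d_model, d_k = 4, 6, 3          # 4 words, 6-dim embeddings, 3-dim Q/K/V\n",
"X = rng.normal(size=(seq_len, d_model))  # toy embedded sequence\n",
"\n",
"# Step 1: projections (random here, learned in practice) produce Q, K, V\n",
"W_Q, W_K, W_V = (rng.normal(size=(d_model, d_k)) for _ in range(3))\n",
"Q, K, V = X @ W_Q, X @ W_K, X @ W_V\n",
"\n",
"# Step 2: pairwise relevance scores, scaled by sqrt(d_k)\n",
"scores = Q @ K.T / np.sqrt(d_k)          # (seq_len, seq_len)\n",
"\n",
"# Step 3: softmax over each row, so each sample's weights sum to 1\n",
"weights = np.exp(scores)\n",
"weights /= weights.sum(axis=1, keepdims=True)\n",
"\n",
"# Step 4: weighted sum of the value vectors\n",
"Z = weights @ V                          # (seq_len, d_k)\n",
"print(weights.sum(axis=1))               # every row sums to 1\n",
"print(Z.shape)"
]
},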
{
"cell_type": "markdown",
"id": "5ba8f279-efaf-479d-b0b1-c44108e35fe1",
"metadata": {},
"source": [
"## 1.3 Multi-Head Attention"
]
},
{
"cell_type": "markdown",
"id": "0e9ab39d-e88e-4118-9110-16908d09a61b",
"metadata": {},
"source": [
"Multi-Head Attention builds on self-attention. For the input embedding matrix, self-attention uses a single set of $W^Q,W^K,W^V$ to produce the Queries, Keys and Values, whereas Multi-Head Attention uses several sets of $W^Q,W^K,W^V$ to produce several sets of Queries, Keys and Values; each set is computed into its own Z matrix, and the resulting Z matrices are concatenated at the end. The original Transformer paper uses 8 different sets of $W^Q,W^K,W^V$."
]
},
{
"cell_type": "markdown",
"id": "68d8cc2f-16ef-4d50-9953-c11c1c3f1d32",
"metadata": {},
"source": [
""
]
},
{
"cell_type": "markdown",
"id": "d0e7f976-ad99-4e64-b4e9-2dbfdec9d23c",
"metadata": {},
"source": [
"Suppose each head's output $Z_i$ is a (2,3) matrix. With $h$ attention heads, the final concatenation produces a matrix of shape (2, 3h)."
]
},
{
"cell_type": "markdown",
"id": "caa814f5-46d9-422b-879d-ad758f22519c",
"metadata": {},
"source": [
"An example with two attention heads:\n",
"\n",
"1. Output of head 1, $ Z_1 $:\n",
"$$\n",
"Z_1 = \\begin{pmatrix}\n",
"z_{11} & z_{12} & z_{13} \\\\\n",
"z_{14} & z_{15} & z_{16}\n",
"\\end{pmatrix}\n",
"$$\n",
"\n",
"2. Output of head 2, $ Z_2 $:\n",
"$$\n",
"Z_2 = \\begin{pmatrix}\n",
"z_{21} & z_{22} & z_{23} \\\\\n",
"z_{24} & z_{25} & z_{26}\n",
"\\end{pmatrix}\n",
"$$\n",
"\n",
"3. Concatenation:\n",
"$$\n",
"Z_{\\text{concatenated}} = \\begin{pmatrix}\n",
"z_{11} & z_{12} & z_{13} & z_{21} & z_{22} & z_{23} \\\\\n",
"z_{14} & z_{15} & z_{16} & z_{24} & z_{25} & z_{26}\n",
"\\end{pmatrix}\n",
"$$"
]
},
{
"cell_type": "markdown",
"id": "2e3af21d-caef-45a4-9e34-85546f5c7972",
"metadata": {},
"source": [
"In general:\n",
"\n",
"With $h$ attention heads, each head's output $Z_i$ is:\n",
"\n",
"$$\n",
"Z_i = \\begin{pmatrix}\n",
"z_{i1} & z_{i2} & z_{i3} \\\\\n",
"z_{i4} & z_{i5} & z_{i6}\n",
"\\end{pmatrix}\n",
"$$"
]
},
{
"cell_type": "markdown",
"id": "506ebeb2-f529-4093-b8fe-05a9fac4429d",
"metadata": {},
"source": [
"The overall concatenation is:\n",
"\n",
"$$\n",
"Z_{\\text{concatenated}} = \\begin{pmatrix}\n",
"z_{11} & z_{12} & z_{13} & z_{21} & z_{22} & z_{23} & \\cdots & z_{h1} & z_{h2} & z_{h3} \\\\\n",
"z_{14} & z_{15} & z_{16} & z_{24} & z_{25} & z_{26} & \\cdots & z_{h4} & z_{h5} & z_{h6}\n",
"\\end{pmatrix}\n",
"$$\n",
"\n",
"The final shape is (2,3h). So if the feature matrix holds a sequence of length 100, each sample's embedding dimension is 3, and 8 attention heads are configured, the output sequence has shape (100,24)."
]
},
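{
"cell_type": "markdown",
"id": "f0e1d2c3-0005-4aaa-8aaa-111111111111",
"metadata": {},
"source": [
"A shape-only sketch of the concatenation just described, assuming this section's toy sizes (sequence length 100, per-head output dimension 3, 8 heads). The per-head outputs here are random stand-ins for each head's $softmax(QK^T/\\sqrt{d_k})V$; note that the original Transformer additionally multiplies the concatenated result by an output projection $W^O$, which is omitted here."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f0e1d2c3-0006-4aaa-8aaa-111111111111",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"rng = np.random.default_rng(0)\n",
"seq_len, d_head, n_heads = 100, 3, 8\n",
"\n",
"heads = []\n",
"for _ in range(n_heads):\n",
"    # stand-in for one head's softmax(QK^T / sqrt(d_k)) @ V\n",
"    Z_i = rng.normal(size=(seq_len, d_head))\n",
"    heads.append(Z_i)          # each head: (100, 3)\n",
"\n",
"Z = np.concatenate(heads, axis=1)\n",
"print(Z.shape)                 # (100, 24) = (seq_len, n_heads * d_head)"
]
},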
{
"cell_type": "markdown",
"id": "7c841ff2-008b-4bde-aca2-caf563699989",
"metadata": {},
"source": [
""
]
},
{
"cell_type": "markdown",
"id": "44c706ea-07c9-426f-ba34-57114fa8e823",
"metadata": {},
"source": [
"This, then, is the Transformer's self-attention layer, the fundamental structure on which the Transformer builds the links between samples. On top of it, the Transformer adds a wealth of detail to form a complete architecture. Let us now look at the Transformer's overall structure."
]
},
{
"cell_type": "markdown",
"id": "6a4306b7-3723-471f-9d16-f5e6a0a76f91",
"metadata": {},
"source": [
"Let us see which components the Transformer is built from. Below is the architecture diagram from the paper \"Attention Is All You Need\":"
]
},
{
"cell_type": "markdown",
"id": "8db1c453-7eed-4cd4-b0a4-60e74b320627",
"metadata": {},
"source": [
"<center><img src=\"https://machinelearningmastery.com/wp-content/uploads/2021/08/attention_research_1.png\" alt=\"Transformer architecture\" width=\"400\">"
]
},
{
"cell_type": "markdown",
"id": "f1ebbc78-282b-46d1-90b6-39b040fe81a3",
"metadata": {},
"source": [
"The Transformer's overall architecture consists of two main parts: the encoder and the decoder. The encoder is the structure that interprets the data: in an NLP pipeline it deconstructs natural language, turning it into information a computer can understand so the machine can learn from and make sense of the data. The decoder is the structure that \"restores\" the interpreted information back into the original kind of data or converts it into another kind: it can turn the processed information back into natural language, or emit it directly as some result. So in the Transformer, the encoder receives the input and extracts features, while the decoder produces the final output. When that output is natural language, the decoder's job is to turn the processed information back into language; when it is a particular class or label, the decoder's job is to consolidate the information into a single result.\n",
"\n",
"Before entering the encoder and decoder, the information first passes through **two encodings: Embedding and Positional Encoding**. In actual code these appear as two separate layers, so they too are considered part of the Transformer structure. After encoding, data flows into the encoder (the left half of the architecture diagram) and the decoder (the right half).\n",
"\n",
"**The encoder comprises two sublayers: a multi-head self-attention layer and a feed-forward network.** The input first passes through self-attention, which assigns importance weights to the different pieces of information so the model knows which are key. The information then passes through the feed-forward layer, a simple fully connected network that consolidates the output of the multi-head attention. Both sublayers are equipped with a residual connection: each sublayer's output is added to the result carried along the residual path and then passed through layer normalization to give the true output. The attention + feed-forward pair can be stacked many layers deep; in the classic Transformer the encoder repeats it 6 times.\n",
"\n",
"**The decoder is likewise built from several sublayers: the first is again multi-head self-attention (which, because of the decoder's nature, carries a mask here), the second is an ordinary multi-head attention layer, and the third is a feed-forward network.** The self-attention and feed-forward structures match the encoder's; the middle attention layer attends to the encoder's output. Likewise, each sublayer has a residual connection and layer normalization. The classic Transformer stacks 6 decoder layers as well."
]
},
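{
"cell_type": "markdown",
"id": "f0e1d2c3-0007-4aaa-8aaa-111111111111",
"metadata": {},
"source": [
"For reference, here is a minimal PyTorch sketch of one encoder layer, wiring together the sublayers just described: multi-head self-attention and a feed-forward network, each followed by a residual addition and layer normalization (post-norm ordering, as in the original paper). It leans on torch.nn.MultiheadAttention for brevity, and the hyperparameter values are illustrative rather than prescriptive."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f0e1d2c3-0008-4aaa-8aaa-111111111111",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn\n",
"\n",
"class EncoderLayer(nn.Module):\n",
"    def __init__(self, d_model=512, n_heads=8, d_ff=2048):\n",
"        super().__init__()\n",
"        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)\n",
"        self.ff = nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(),\n",
"                                nn.Linear(d_ff, d_model))\n",
"        self.norm1 = nn.LayerNorm(d_model)\n",
"        self.norm2 = nn.LayerNorm(d_model)\n",
"\n",
"    def forward(self, x):\n",
"        # sublayer 1: self-attention + residual + layer norm\n",
"        attn_out, _ = self.attn(x, x, x)\n",
"        x = self.norm1(x + attn_out)\n",
"        # sublayer 2: feed-forward + residual + layer norm\n",
"        return self.norm2(x + self.ff(x))\n",
"\n",
"layer = EncoderLayer()\n",
"out = layer(torch.randn(2, 10, 512))  # (batch, seq_len, d_model)\n",
"print(out.shape)                      # torch.Size([2, 10, 512])"
]
},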
{
"cell_type": "markdown",
"id": "9c6f2bbc-7263-4fcd-88ef-6997bdde5f56",
"metadata": {},
"source": [
"**Simple as this structure looks, it hides endless subtlety, and many questions are waiting to be dug into.** Let us now interpret the Transformer structure piece by piece, starting from the encoding side."
]
},
{
"cell_type": "markdown",
"id": "bad3022c-e161-462f-bdfe-23b59d0fb43a",
"metadata": {},
"source": [
"<center><img src=\"https://skojiangdoc.oss-cn-beijing.aliyuncs.com/2023DL/transformer/image-1.png\" alt=\"figure\" width=\"400\">"
]
},
{
"cell_type": "markdown",
"id": "3358bb14-46ce-4556-b974-5e8fd9160c6a",
"metadata": {},
"source": [
"## 2.1 The Embedding Layer and Positional Encoding"
]
},
{
"cell_type": "markdown",
"id": "e4884c4a-5e1f-409e-8937-189aebbacb31",
"metadata": {},
"source": [
"In the Transformer, the embedding layer sits before the encoder and decoder and is mainly responsible for semantic encoding. It converts discrete tokens or symbols into continuous high-dimensional vectors, letting the model process and learn the semantic relationships among them; with embedded representations, the input sequence better captures the similarities and relationships between words. In addition, a positional encoding is normally added before the input reaches the encoder and decoder, because the Transformer has no built-in notion of sequence order: the attention mechanism itself brings about a **loss of positional information**.\n",
"\n",
"- **First, why does positional information matter, and where can it come from?**\n",
"\n",
"Positional information is ordering information: the order in which characters are arranged affects how a sentence is understood (remember \"屡战屡败\" vs. \"屡败屡战\"? The same four characters, reordered, turn \"repeatedly fighting and losing\" into \"repeatedly losing yet fighting on\"; likewise, the same word appearing in different places in a sentence can carry different meanings). When we say the Transformer loses positional information, we mean it does not understand in what order the samples are arranged, i.e., it does not know each sample's exact position within the sequence.\n",
"\n",
"Recall how RNNs and LSTMs process data: they consume the input sequentially, handling one element of the sequence per time step, and each step's hidden state depends on the previous step's. This mechanism naturally captures the order of the sequence: because information from one time step is retained and passed on to the next, RNNs and LSTMs inherently understand temporal dependencies and ordering. The Transformer, by contrast, does not process the input step by step; it processes the whole sequence at once. Attention computes all pairwise vector correlations in one shot via dot products, and the different heads of multi-head attention can even run in parallel, so Attention and the Transformer lack any innate sense of order.\n",
"\n",
"- **But there are indices in the correlation computation; can't those serve as positional information? What counts as positional/ordering information anyway?**\n",
"\n",
"In attention, each element $a_{ij}$ of the weight matrix $softmax(\\frac{QK^{T}}{\\sqrt{d_k}})$ represents the correlation between positions $i$ and $j$ of the sequence, yet nothing about where those two related elements actually sit is assumed. Concretely, although we write subscripts such as 1 and 2, when Attention actually computes, it only ever perceives two concrete correlation numbers and never explicitly sees the subscripts:\n",
"\n",
"$$\n",
"\\mathbf{Z(Attention)} = \\begin{pmatrix}\n",
"a_{11} & a_{12} \\\\\n",
"a_{21} & a_{22}\n",
"\\end{pmatrix}\n",
"\\begin{pmatrix}\n",
"v_{11} & v_{12} & v_{13} \\\\\n",
"v_{21} & v_{22} & v_{23}\n",
"\\end{pmatrix}\n",
"= \\begin{pmatrix}\n",
"(a_{11}v_{11} + a_{12}v_{21}) & (a_{11}v_{12} + a_{12}v_{22}) & (a_{11}v_{13} + a_{12}v_{23}) \\\\\n",
"(a_{21}v_{11} + a_{22}v_{21}) & (a_{21}v_{12} + a_{22}v_{22}) & (a_{21}v_{13} + a_{22}v_{23})\n",
"\\end{pmatrix}\n",
"$$\n",
"\n",
"Because the Transformer abandons row-by-row processing of the data in favor of consuming the whole table at once, it cannot pick up the positional relationships between words during training the way a recurrent network can. In classic deep learning, the most typical carrier of order is **numeric magnitude**: numbers are inherently ordered, so a number can itself be treated as carrying order, and matching ordered information with ordered numbers lets an algorithm perceive the order naturally. Hence the natural idea of **encoding the sample positions themselves**, using the order inherent in numbers to inform the Transformer."
]
},
{
"cell_type": "markdown",
"id": "36bcbdda-2068-4b41-a958-65bd1fb4254e",
"metadata": {},
"source": [
"- **How is positional information conveyed to an algorithm like Attention/the Transformer?**"
]
},
{
"cell_type": "markdown",
"id": "ad08a935-f6bc-4698-b34f-de5c385545d9",
"metadata": {},
"source": [
"To solve the positional-information problem, the Transformer introduces positional encoding to supplement the semantic word embeddings. We first turn each sample's position into a number or vector, then add this positional-encoding vector onto the original embedding vector, so the model knows both a word's meaning and its place in the sentence."
]
},
{
"cell_type": "markdown",
"id": "12e9d14a-4d3b-4c9f-9e95-c91d11d0aa6e",
"metadata": {},
"source": [
""
]
},
{
"cell_type": "markdown",
"id": "7fa0573c-3537-4a8d-a481-799ba78990a4",
"metadata": {},
"source": [
"(The numeric labels in the figure above are wrong: whether for the embedding, the positional encoding, or their sum, the dimension should be 512; 521 is a typo.)"
]
},
{
"cell_type": "markdown",
"id": "a1aca67e-4fcd-4173-8791-209376d5426f",
"metadata": {},
"source": [
"Positional encoding uses a special function that generates a vector for every position in the sequence. For a given position, the values this function generates differ across all dimensions, which guarantees that each position's encoding is unique while preserving consistent relative relationships between positions. **In the Transformer's positional encoding we assign an encoded value to every feature of every word, and these per-feature values together form the sample's positional encoding.** For example, a sample with 4 features gets a positional encoding that is a vector of 4 numbers, not a single code. Hence **the positional-encoding matrix has the same shape as the embedded matrix**.\n",
"\n",
"In the Transformer, the word embeddings and positional encodings are added together and fed into the first layer, letting the model handle a word's semantics and its position in the sentence simultaneously. This is one important reason the Transformer excels on sequence data, especially in NLP tasks."
]
},
{
"cell_type": "markdown",
"id": "c5feef31-b541-4e79-868f-429d7201ad82",
"metadata": {},
"source": [
"- **The Transformer's sinusoidal positional encoding**"
]
},
{
"cell_type": "markdown",
"id": "2f24d71b-c0d6-4a51-ab75-242b3e5eada3",
"metadata": {},
"source": [
"The classic positional encoding of the past was the OrdinalEncoder, but the Transformer needs an encoding vector rather than a single number, so the OrdinalEncoder no longer applies. Among the many ways to build an encoding vector, the Transformer chose the distinctive approach of **sinusoidal (sine-cosine) encoding**. Let us see what it means."
]
},
{
"cell_type": "markdown",
"id": "23c62d58-c6e7-478e-bc02-d838e2c9453e",
"metadata": {},
"source": [
"First, sinusoidal encoding generates the concrete encoded values with sine and cosine functions. For any word vector (i.e., any sample in the data), it encodes the even dimensions with sin and the odd dimensions with cos; alternating the two builds a multi-dimensional vector:\n",
"\n",
""
]
},
{
"cell_type": "markdown",
"id": "208c74eb-6174-401c-8928-a8a06a8ea3de",
"metadata": {},
"source": [
"Encoding different dimensions with different trigonometric functions yields a unique combination of values. Like an embedding, this projects information into a high-dimensional space; the difference is that sinusoidal encoding projects the sample's positional information (the sample's index), with each feature dimension serving as one axis of that space. For sinusoidal encoding, each value depends on **three factors: the sample's position (index), the dimension number, and the total number of dimensions**."
]
},
{
"cell_type": "markdown",
"id": "26b91d7d-3416-4606-a5ed-ea65109a2a5e",
"metadata": {},
"source": [
"Concretely, the sinusoidal-encoding formulas are:\n",
"\n",
"- Sine positional encoding (even dimensions)\n",
"$$PE_{(pos, 2i)} = \\sin \\left( \\frac{pos}{10000^{\\frac{2i}{d_{\\text{model}}}}} \\right) $$\n",
"\n",
"- Cosine positional encoding (odd dimensions)\n",
"$$ PE_{(pos, 2i+1)} = \\cos \\left( \\frac{pos}{10000^{\\frac{2i}{d_{\\text{model}}}}} \\right) $$\n",
"\n",
"where:\n",
"> - pos is the sample's position in the sequence, i.e., the sample index (the index along the seq_len/vocab_size/time_step dimension)<br><br>\n",
"> - $2i$ and $2i+1$ are the even and odd dimension indices of the embedding matrix; letting i grow from 0 in a loop yields the sequence [0,1,2,3,4,5,6...].<br><br>\n",
"> - $d_{\\text{model}} $ is the total dimension of the embedded matrix."
]
},
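{
"cell_type": "markdown",
"id": "f0e1d2c3-0009-4aaa-8aaa-111111111111",
"metadata": {},
"source": [
"The two formulas translate directly into a few lines of NumPy. This sketch builds the positional-encoding matrix for a toy sequence (the shapes are chosen purely for illustration); note the result has the same shape as the embedded input, so the two can simply be added."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f0e1d2c3-0010-4aaa-8aaa-111111111111",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def sinusoidal_encoding(seq_len, d_model):\n",
"    pos = np.arange(seq_len)[:, None]        # sample positions (pos)\n",
"    i = np.arange(0, d_model, 2)[None, :]    # even dimension indices (2i)\n",
"    angle = pos / (10000 ** (i / d_model))   # pos / 10000^(2i/d_model)\n",
"    pe = np.zeros((seq_len, d_model))\n",
"    pe[:, 0::2] = np.sin(angle)              # sine on even dimensions\n",
"    pe[:, 1::2] = np.cos(angle)              # cosine on odd dimensions\n",
"    return pe\n",
"\n",
"pe = sinusoidal_encoding(seq_len=30, d_model=4)\n",
"print(pe.shape)   # (30, 4) -- same shape as the embedded sequence\n",
"print(pe[1])      # encoding of the sample at position 1 (starts with sin(1))"
]
},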
{
"cell_type": "markdown",
"id": "d327b655-f508-41f7-a860-61d3c787cf00",
"metadata": {},
"source": [
"Here you may choose to stop and move on to the next lesson, or continue into a deeper treatment of positional encoding. The choice exists because positional encoding, an important technique in deep learning and time-series processing, is used across many settings: texture modeling, audio, signal processing, vibration analysis and more. At the same time, it is largely applied as an industry convention, so you may not need to probe sinusoidal encoding especially deeply.\n",
"\n",
"But sinusoidal positional encoding is a genuinely wonderful construction, and in what follows I will walk you through its many details and meanings."
]
},
{
"cell_type": "markdown",
"id": "5ab6b2c2-dcb6-4006-9b2a-0f13ab1a087e",
"metadata": {},
"source": [
"- **Why use sinusoidal encoding? What is its significance?**"
]
},
{
"cell_type": "markdown",
"id": "b40f6119-02b4-4ff5-bae1-fe08e9be6c2f",
"metadata": {},
"source": [
"- Sine positional encoding (even dimensions)\n",
"$$PE_{(pos, 2i)} = \\sin \\left( \\frac{pos}{10000^{\\frac{2i}{d_{\\text{model}}}}} \\right) $$\n",
"\n",
"- Cosine positional encoding (odd dimensions)\n",
"$$ PE_{(pos, 2i+1)} = \\cos \\left( \\frac{pos}{10000^{\\frac{2i}{d_{\\text{model}}}}} \\right) $$"
]
},
{
"cell_type": "markdown",
"id": "f0d19138-d827-425a-953d-de8170bdc39e",
"metadata": {},
"source": [
"Start with pos. It is the sample index, and its size depends on the actual data: for a long time series or long text, pos itself grows large. If very large numbers were added to the original embedding, the positional encoding's influence could far exceed the original semantics and drown them out, so we naturally need to limit the encoding's magnitude. From this angle, functions with narrow ranges like sin and cos bound the positional encoding nicely.\n",
"\n",
"**<center>Meaning ① of sinusoidal encoding: sin and cos have bounded ranges, neatly limiting the size of the positional-encoding values.**\n",
"\n",
"If we used a univariate series, sin(pos) or cos(pos) alone might look sufficient, but to encode every dimension we clearly must do more. Positional information, like semantic information, gets interpreted along different dimensions when we project it into a high-dimensional space. Using trigonometric functions, how can the information be projected onto different dimensions? The answer: create sin and cos functions that all differ. Though they are all sines and cosines, we can give the functions different frequencies to obtain curves of every shape:"
]
},
{
"cell_type": "code",
"execution_count": 58,
"id": "e0c806bb-5df1-4b47-bb2c-2f4bcbc09e4f",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA1gAAAKACAYAAACBhdleAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAAAsTAAALEwEAmpwYAAEAAElEQVR4nOzdd1xUV/r48c+hg6AiIKCoFCmKCPbeTTU9MWuSTWLMbjZtk2zKJrvZb5LNZjfZX7Kbnmx63/TeEzVq1IgVGwKCgiIdVHqd8/tjwEUFacPcO8Pzfr18RZiZe5+Z4MN97jnnOUprjRBCCCGEEEKInnMxOgAhhBBCCCGEcBZSYAkhhBBCCCGEjUiBJYQQQgghhBA2IgWWEEIIIYQQQtiIFFhCCCGEEEIIYSNSYAkhhBBCCCGEjUiBJUxNKTVLKZVudBxCiL5Bco4Qwl4k3zgvKbDEKSmlspVSNUqpylZ/hvTi+bRSamTL11rrn7XWsb10rheVUulKKYtSamlvnEMI0TXOmnOUUjFKqc+VUsVKqTKl1PdKqV7JbUKIznHifBOolFqnlCpVSh1RSv2ilJph6/OI9kmBJTrjXK21b6s/eUYHZCPbgRuBrUYHIoQ4jjPmnIHAF0AsEAxsBD43MiAhBOCc+aYSWAYEAf7AP4EvlVJuhkbVh0iBJbql+a7PwlZfP6CUerv57+HNd2muVkodUEqVKKXubfVcV6XUn5VSWUqpCqXUFqXUMKXUmuanbG++i/QrpdRcpVRuq9eOUkqtar4js1spdV6rx15XSj2rlPq6+bjJSqmo9t6D1vpZrfUKoNaWn40QwvYcPedorTdqrV/RWpdprRuAx4FYpVSAjT8qIUQPOUG+qdVap2utLYACmrAWWoNs+kGJdkmBJXrTTKx3axcA9ymlRjV//3bgMuBsoD/WuyzVWuvZzY8nNt9Fer/1wZRS7sCXwA/AYOD3wDsnTLNZAvwVayLJBP7eG29MCGFKjpRzZgMFWuvSrr1FIYRJmD7fKKV2YL2J/AXwsta6qJvvVXSRFFiiMz5rvptyRCn1WRde91etdY3WejvW6XiJzd//DfCX5rsrWmu9vZMXGVMBX+ARrXW91nol8BXWRNbi0+Y7xY3AO0BSF+IVQpiDU+ccpVQY8CzWCzEhhLGcNt9orcdiLfIuB9Z24b2JHpK5mKIzLtBaL+/G6wpa/b0aa+IAGAZkdeN4Q4CDzUPeLXKAoZ04pxDCcThtzlFKBWG9Q/2c1vrdbsQkhLAtp803YJ0uCLyrlNqjlEppLghFL5MRLNFdVYBPq69DuvDag0C7a6NOIQ8YppRq/XM7HDjUjWMJIRyLw+ccpZQ/1uLqC621TF8WwrwcPt+0wR2ItNGxRAekwBLdlQIsUUq5K6UmApd04bUvA39TSkUrq7GtFnoX0n4CSMZ6x+aPzeedC5wLvNedN6CU8lBKeWFdAOqulPI6IbEJIcwjBQfOOUqp/sD3wDqt9T1dfb0Qwq5ScOx8M1UpNbP5OsdbKXU31u6lyV09lugeuZgU3fV/WO/QHMa64PK/XXjtv4EPsN7JLQdeAbybH3sAeKN5LvSlrV+kta7HmmzOAkqA54CrtNZp3XwPPwA1wHTgxea/zz7lK4QQRnH0nHMhMAm4Rh2/587wbhxLCNG7HD3feGJd51mKdQTsbGCRk7SgdwhKa210DEIIIYQQQgjhFGQESwghhBBCCCFsRAosIYQQQgghhLARKbCEEEIIIYQQwkakwBJCCCGEEEIIG3G6jYYDAwN1eHi40WEIIZpt2bKlRGsdZHQcvUVyjhDmIflGCGFP7eUcpyuwwsPD2bx5s9FhCCGaKaVyjI6hN0nOEcI8JN8IIeypvZwjUwSFEEIIIYQQwkakwBJCCCGEEEIIGzG0wFJKvaqUKlJK7WrncaWUekoplamU2qGUGm/vGIUQzkHyjRDCniTnCNF3Gb0G63XgGeDNdh4/C4hu/jMFeL75v0J0S0NDA7m5udTW1hoditPx8vIiLCwMd3d3o0Npz+tIvhFC2M/rSM4Rok8ytMDSWq9RSoWf4innA29qrTWwQSk1UCkVqrXOt0+EzqfqaBWFOSUcKTpKfW0DTY1N+PT3xs/fl+DwIPz8fY0OsVfl5ubi5+dHeHg4Simjw3EaWmtKS0vJzc0lIiLC6HDa1FfzjcVi4cC+Yoryj+Dq6kL4yGACBvc3OixhgPIj1exLz6e2poGgkAGMiBqMm7ur0WE5LWfKOVprMnOKKS6tIHJ4ECFBkkNsRWtNxv4iyg5XMTI8iKAAP6NDchpNTRb27i/icHk1sZHBDBrYz27nNnoEqyNDgYOtvs5t/t5xyUcpdR1wHcDw4cPtFpzZaa3J3n2Q5K+2kLohg7TkvRwuPHrK1wwcPIDIxBEkzoknaf4Y4iaPxMXFeZbq1dbWSnHVC5RSBAQEUFxcbHQoPdGpfAOOkXNqquv45K31fP3BRsqKK457LG7sMC5ZOpMZC0bLv4U+YOuGTN5/aTU7NmdjvZa3GuDvw2nnj+dX187Gb4CPgRH2WQ5xjXPgUBkPPf0tqXvzm+OBM+fE84ffLMDH28Pu8TiTfQdKeOipb8jYXwSAi4vinAUJ3HLNPLw8TTsbxCFk7Cvkoae/Zd+BEsD62V505jhuvGo2Hu69X/6YvcDqFK31i8CLABMnTtQdPN3pFR0o5puXVvDTe2vJyyoEYFjsECacnkh4/HCCRwQyKNQfDy93XN1cqa6ooaKskvx9RRzYk0v6pkxe+8u7AAQNC2DupdM5Y9l8RowKM/Jt2YxcUPaOvvS5mj3n7Nqazf/780cU5R1h0swYlt5yGsPCg6ivbyBj5yG++3QLD93+LpNmxnDXPy6h/0C5uHZGNdV1PPPQl6z4KoXA4P5ccf084seNwNvHg/zcMtav3MMnb67jh8+2csdDFzNldqzRIYs2GJlv9h8s4eb/ex+AO69bSOSIINZuyuS9LzaTc6iMJx9YjLeXFFndkbGvkJvvex8vTzfuvuF0RoQF8NP6dD76ZisHDpXx2F8uliKrm3akHeKOv32EXz8v/nzTmQwNHciPP+859tn+vz9fiJtb747em73AOgQMa/V1WPP3RBvSN2Xyzt8/JvmrLWgNE04fy6V3nc/UcycSEOrfpWMdLSln03cprHp/HZ88+Q0f/utLJp01jktuP5dx88f0qYtp0Wc4Rb5Z+XUKj9/3KUGhA3ns9d8wZnz4cY8nTorkoqum8+X7G3nl399x85LneOSlaxgyLMCYgEWvKCup4P6b3yIrLZ8rfjePX/12Dh4e//uVHzd2GPPOTmR/RgGP/eVj7r/5La676ywuunKGgVH3OabOOVXVdfzxH5/g5urCsw8tIaz5OmJs3FDiY4bwf499wd+f+Y6/3XGuXBN00ZHyau76xyf49vPk+b9fRnCgdcrl2LihjI4O5cEnv+aJV1Zyz41nGByp4ykpq+Te//c5Af6+PP3XS49NuUwcFUZsZDD/fP4HnnljNbddO79X4zD73K8vgKuaO+1MBY6acW6y0fbvzOH/z
n+Em6f8iV1r0/jV3Rfw1r5nefjbv7DoutO6XFwBDAjsz8Jfz+ahL//Ee4deYOmDS8jcuo+7T3uQPy78K+mbs3rhnfQNrq6uJCUlHfuTnZ1ts2N/9tlnpKamHvv6vvvuY/ny5T0+bmlpKfPmzcPX15ebb765x8czKYfPNz99s51H//wxo5KG8+Q7vzupuGrh6ubKBVdM47HXf0tNVR1/XPYKeQdL7Rus6DUVR6u55zevcnB/Cfc/9WuuvGnBccVVaxExITz+1nXMXBjPi49+y4ev/WznaPs0U+ecJ1/7icKSCv5+9/nHiqsWc6ZE89vLZrLqlwyWr00zKELH9e+XVnC0ooZ//unCY8VVi9NmjeLXF07hqxU7WZO816AIHdejL/xIbV0DD999/knr2c5dOJbFi8bz0Tdb2byjd/ckN7pN+7vAL0CsUipXKXWtUup6pdT1zU/5BtgHZAIvATcaFKopVVfU8J/bX+f68X9k189pLP3bEt7a9yzL/n45wSOCbHaegUEDuOIvF/N29vPc9OQ
"text/plain": [
"<Figure size 864x648 with 9 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# Draw several sine curves with increasing frequency\n",
"def plot_sin_functions(num_functions):\n",
"    x = np.linspace(0, 10, 1000)  # x-axis range\n",
"    colors = plt.cm.viridis(np.linspace(0, 1, num_functions))  # color sequence\n",
"\n",
"    fig, axs = plt.subplots(3, 3, figsize=(12, 9))  # 3x3 grid of subplots\n",
"\n",
"    # Plot each sine function\n",
"    for i, ax in enumerate(axs.flat):\n",
"        if i < num_functions:\n",
"            frequency = (i + 1) * 0.5  # raise the frequency step by step\n",
"            y = np.sin(frequency * x)\n",
"            ax.plot(x, y, label=f'Function {i+1}', color=colors[i])\n",
"            ax.set_title(f'Function {i+1}')\n",
"            ax.set_xlabel('x')\n",
"            ax.set_ylabel('sin(x)')\n",
"            ax.legend()\n",
"\n",
"    plt.tight_layout()\n",
"    plt.show()\n",
"\n",
"# Draw 9 sine functions\n",
"plot_sin_functions(9)\n"
]
},
{
"cell_type": "markdown",
"id": "5e810dc0-2972-4ddb-b0b1-631a46ec879d",
"metadata": {},
"source": [
"If different features can be projected onto **sin and cos functions of different frequencies**, each feature lands on its own distinctive dimension, and together these dimensions form a space that decomposes positional information, yielding a deep reading of position.\n",
"\n",
"**<center>Meaning ② of sinusoidal encoding: by tuning the frequency we obtain a great variety of sin and cos functions,<br><br>projecting the positional information into a high-dimensional space whose dimensions are each distinctive, for a better representation of position.**"
]
},
{
"cell_type": "markdown",
"id": "eb4fcbb1-ee09-4810-8ac6-a5c81f79c0e2",
"metadata": {},
"source": [
"- Sine positional encoding (even dimensions)\n",
"$$PE_{(pos, 2i)} = \\sin \\left( \\frac{pos}{10000^{\\frac{2i}{d_{\\text{model}}}}} \\right) $$\n",
"\n",
"- Cosine positional encoding (odd dimensions)\n",
"$$ PE_{(pos, 2i+1)} = \\cos \\left( \\frac{pos}{10000^{\\frac{2i}{d_{\\text{model}}}}} \\right) $$"
]
},
{
"cell_type": "markdown",
"id": "180b8539-cea9-4a11-adb2-909321b4358d",
"metadata": {},
"source": [
"What remains is how to give the sin and cos functions different frequencies: multiply the argument by different values and the frequency changes.\n",
"\n",
"$$y = \\sin(frequency \\cdot x)$$\n",
"\n",
"In positional encoding the argument is the sample position pos, so the feature positions (2i and 2i+1) can be used to create the different frequencies. Here we scale the number pos: specifically, we use $10000^{\\frac{2i}{d_{\\text{model}}}}$ as the scaling factor, placing it under pos as a divisor, which is the same as multiplying pos by the frequency $\\frac{1}{10000^{\\frac{2i}{d_{\\text{model}}}}}$. Bringing the feature position itself into the scaling thus yields different frequencies, projecting the positional information pos onto trigonometric functions of different frequencies and ensuring that different positions (pos) receive different encoded values on different feature dimensions (2i and 2i+1).\n",
"\n",
"The next question: are these frequencies arbitrary, and how should we control them? Here comes the most ingenious part of sinusoidal encoding. Because pos is multiplied by the frequency $\\frac{1}{10000^{\\frac{2i}{d_{\\text{model}}}}}$, **features with small indices receive large frequencies and are projected onto high-frequency sinusoids, while features with large indices receive small frequencies and are projected onto low-frequency sinusoids** 👇"
]
},
{
"cell_type": "code",
"execution_count": 61,
"id": "d7e3e910-b513-454f-8556-5f19ab50105f",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA0EAAAHxCAYAAACrqLeFAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAAAsTAAALEwEAmpwYAABTKUlEQVR4nO3dfZyVdZ3/8ddn5py54T7xHlBQQKVAMLxBNEnTzEirxbLF1LTMbW3bNnezslu3u1+71rqR2W5qeZOWZZnZamGUFoY3KDciioqCiiggMMDcf39/zMEGBOYAc+aamfN6Ph7nMef6Xte5zvvMfB96Pny/1/eKlBKSJEmSVC4qsg4gSZIkSV3JIkiSJElSWbEIkiRJklRWLIIkSZIklRWLIEmSJEllxSJIkiRJUlmxCJIkSZJUViyCJKkLRcTCiJiSdY6dERFfj4h/zjpHsSLipIj4fET0387+t0fEL4s81/SIuHsn3ntWRHy42OM7W0QMj4gUEbkijj0vIu4rPK+OiMcjYq92+98VEbeUMq8kZcUiSJK6UErpjSmlWaU6f0RURcStEbG08GV4SpGveSUi+m1j317AOcDVxZw/2nwzIlYVHt+MiGi3f3xEPBQRGws/x3fGa9sdczzwC+A04LaIqNrGR/4q8I3C8XtHxE8i4oWIWBsRf46IozcfmFK6MaV0ylbvMSki/rKDX2mPk1JqAK4BLm3X9mvgjRExLrNgklQiFkGS1PvcB5wNrCjy+LcAj6SU6rax7zzgzpTSpiLPfyHwbuBwYBzwLuCj0FZAAb8CbgDeAPwI+FW7QmV3Xkvhy/pPgQ8UPtNa4PqIqGh3zJHAwJTS/YWmfsADwJuBPQrn/c22CsJ23gncuYP9PdVNwLkRUd2u7Se0/V0kqVexCJKkLlQYQXnbNtqPjogVEVHZru09ETGv8PyoiHgwItZFxEsRccW2zp9SakwpfSeldB/QUmSs09j+l/p3AH/cifOfC/xnSml5Sul54D9pK6QApgA54DsppYaU0pVAACfu7msjYjjwc+DslNKdKaUm4P1AM/BfO/g8T6eUrkgpvZhSakkp/QCoAg4pnPe1KWPb+n1FxMmFaWRrI+K7hUw7VDjnnyPi2xHxakQ8HRHHFtqXRcTKiDi33fEDI+LHEfFyRDwbEZdtLuwiojIi/qMwkvc0bQUaW732hxHxYkQ8HxH/3r6PtZdSWg6sAY5p1zxr63NKUm9gESRJ3UBK6a/ABv5WEAD8PW3/Og9tX+T/K6U0ADiYthGPznIa8Jvt7BsLLN6Jc70ReLTd9qOFts375qWUUrv987bav0uvTSktTSmNSinN3LwzpdScUpqeUvp4sZ+nMMWuCliynf37AfsAcyNiT9qm3l0G7Ak8BUze3rm3cnQh/2Da/sY3A0cCI2kbZftuu9Go/wYGAgcBJ9A2PfFDhX0fAaYCE4CJwLSt3uc62grBkYVjTgF2dM3SItpG4tpvD4+IAUV+LknqESyCJKn7+AltU7mItov6Tyu0ATQBIyNiz5RSXbvpXLslIg4Gciml7RUGg4D1O3HKfrRNQ9tsLdCvcG3P1vs27+/fCa8t1iC283kKX/SvB76cUtr6vTY7Dfi/QjF2GrAwpXRrYeTpOxQ/BfGZlNK1KaUW4BZgGPCVwijX3UAjbX/vSuAs4DMppfUppaW0jZB9sHCe99E2OrYspbQa+Hq7z7NPIeM/p5Q2pJRWAt8unG971tP2O2q/zVZtktTjWQRJUvdxE/DewjUZ7wUeTik9W9h3ATAaeDwiHoiIqZ30nqcBv93B/jXsXKFRB7QfNRgA1BWKhq33bd6/vhNeW6xtfp6IqAV+DdyfUvr66171N+2nDu4PLNu8o5Bz2bZetA0vtXu+qfD6rdv60TbClAeebbfvWWDItjJsddyBhde+WJh29yptC1zsvYNc/YFXt9pmqzZJ6vEsgiSpm0gpPUbbl9h3sOVUOFJKT6aUPkDbF9hvArdGRN9OeNsdXQ8EbVO2Ru/E+Ray5XSqwwttm/eNa7/iG20LICzshNcW63Wfp1B0/hJYTmEhhm2JiDxt09F+V2h6kbYRnM37o/12J3mFtlHAA9u1HQA8v60MhX2bLQMagD1TSoMKjwEppTeyfYex5ZTEw4ClKaV1u/oBJKk7sgiSpO7lJuATtK1u9rPNjRFxdkTslVJq5W//Kt+6rRNE2z1fagqbVRFRs1XxsPm4PsBRwB92kOdO2r74F3v+HwP/EhFDImJ/4FO0XZcCbRfZtwD/VDjHxYX2ezrhtcXa4vMUCptbaRt5Obfw+92e42i7LmlzQfAb2paQfm+03Zfnn4B9dzLPDhWmy/0U+GpE9I+IA4F/oW2VPAr7/ikihkbEG9hyiesXgbuB/4yIARFREREHR8QJbENEDKFthbz2Uy1PYMcjhZLUI1kESVL38hPavnjek1J6pV37qcDCiKijbZGEs7Zatrq9xbR9qR8C3FV4fuA2jjsRmJ1Sqt9Bnh8DpxWmixVz/qtpm1Y2H1hAW6FwNbStLEfbEtjn0FbInQ+8u9C+u68tSkrpYWBt/O1eQMfStrDAKcCrEVFXeBy/jZdvsTR24e9zJm33HFoFjAL+vDN5ivRx2hbNeJq25clvou2ePgD/Q9vf4FHgYdoWamjvHNoWeniMtqmAtwL7bed9/h74UeGeQZt9gMLfQJJ6k9hyoR1JUrmIiO8BC1JK3+vguK8BK1NK3+mSYCUWEacAH0spvXsnX/cYMK0wbbFXKUwJfBR4S2EBBSLiXcAHU0rvyzScJJWARZAklamIuBD4dWHalHYg2m7K+i8ppW9knUWStPssgiRJ6mQR8X3a7veztRtSShd1dR5J0pYsgiRJkiSVFRdGkCRJklRWclkH2BV77rlnGj58eNYxANiwYQN9+3bGrTrU29lXVCz7ioplX1Gx7CsqVm/qKw899NArKaW9trWvRxZBw4cP58EHH8w6BgCzZs1iypQpWcdQD2BfUbHsKyqWfUXFsq+oWL2pr0TEs9vb53Q4SZIkSWXFIkiSJElSWbEIkiRJklRWeuQ1QZIkSVI5a2pqYvny5dTX13fqeQcOHMiiRYs69ZylVlNTw9ChQ8nn80W/xiJIkiRJ6mGWL19O//79GT58OBHRaeddv349/fv377TzlVpKiVWrVrF8+XJGjBhR9OucDidJkiT1MPX19QwePLhTC6CeKCIYPHjwTo+IWQRJkiRJPVC5F0Cb7crvwSJIkiRJUrd16qmnMmjQIKZOndpp57QIkiRJktRt/eu//ivXX399p57TIkiSJEnSTlu6dCmHHnoo06dP57DDDmPatGls3LiRmTNnMmHCBMaOHcv5559PQ0MDAJdeeiljxoxh3LhxXHLJJUW/z0knndTpizW4OpwkSZKkXbJ48WJ++MMfMnnyZM4//3yuuOIKrr76ambOnMno0aM555xzuOqqq/jgBz/IbbfdxuOPP05E8OqrrwJw44038q1vfet15x05ciS33npryXJbBEmSJEk92Jd/vZDHXljXKedqaWmhsrKSMfsP4
IvvemOHxw8bNozJkycDcPbZZ3P55ZczYsQIRo8eDcC5557LjBkzuPjii6mpqeGCCy5g6tSpr13fM336dKZPn94p2XeG0+EkSZIk7ZKtV2YbNGjQNo/L5XLMmTOHadOmcccdd3DqqacCbSNB48ePf91j2rRpJc3tSJAkSZLUgxUzYlOsnb1Z6nPPPcfs2bOZNGkSN910ExMnTuTqq69myZIljBw5kuuvv54TTjiBuro6Nm7cyGmnncbkyZM56KCDgF46EhQR10TEyohYsJ39ERFXRsSSiJgXEUeUMo8kSZKkznPIIYcwY8YMDjvsMNasWcMnP/lJrr32Ws4880zGjh1LRUUFF110EevXr2fq1KmMGzeO4447jiuuuKLo9zj++OM588wzmTlzJkOHDuWuu+7a7dylHgm6Dvgu8OPt7H8HMKrwOBq4qvBTkiRJUjeXy+W44YYbtmg76aSTmDt37hZt++23H3PmzNml97j33nt3Od/2lHQkKKX0J2D1Dg45A/hxanM/MCgi9itlps7W0tR
"text/plain": [
"<Figure size 1008x576 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# Parameters\n",
"d_model = 512\n",
"i = np.arange(0, 512)  # dimension index from 0 to 511\n",
"\n",
"# Frequency assigned to each dimension index\n",
"values = 1 / (10000 ** (2 * i / d_model))\n",
"\n",
"plt.figure(figsize=(14, 8))\n",
"plt.plot(i, values)\n",
"\n",
"plt.title('i vs 1 / (10000^(2i/d_model))')\n",
"plt.xlabel('feature_index')\n",
"plt.ylabel('1 / (10000^(2i/d_model))')\n",
"plt.grid(True)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "185f6c76-cbe2-44f8-8ab4-0ec871c9693c",
"metadata": {},
"source": [
"In this figure the horizontal axis is the feature index i and the vertical axis is $\\frac{1}{10000^{\\frac{2i}{d_{\\text{model}}}}}$; clearly, the larger the feature index, the smaller the frequency. For a trigonometric function, a smaller frequency means the value changes less when the argument moves by 1 unit, while a higher frequency means the value changes more sharply."
]
},
{
"cell_type": "code",
"execution_count": 63,
"id": "5c51210e-50cc-4590-a7d2-3ef52f07b695",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA1gAAAGoCAYAAABbkkSYAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAAAsTAAALEwEAmpwYAACLr0lEQVR4nOzdd1xV9/3H8deXjbJkOBAUZ5yAiit772YnxiyNWU2atukeya9NM9q0Tdu0zWgzNYnNTprE7KEZalREwL0H4EI2srnf3x9cKHGicjl3vJ+PBw+55557zgfk3vP9nO/3+/kaay0iIiIiIiJy7IKcDkBERERERMRfKMESERERERHpJEqwREREREREOokSLBERERERkU6iBEtERERERKSTKMESERERERHpJEqwpEsZY641xnzsdBxHyxiz0hhzqree3xgzzxhzc9dFJCLivXTN8ez5dc0ROTAlWNLpjDEnGmMWGGMqjDGlxpj5xpjxANba2dbas4/yuPcaYxqNMdXtvn7eudF/63wzjTEPtN9mrR1prZ3nqXMeTvvzu38fLx7tsYwxpxpjXPv8Pt/ttGD9nDHmNGPMXPff+Ran4xEJVLrmeI6uOd7DtPijMabE/fVHY4w5yL59jDHvGGO2G2OsMSati8MNeCFOByD+xRgTA8wBbgdeBcKAk4D6TjrFK9ba6zrpWALbrbUph9rBGBNirW3qqoB8yF7gWeAl4NcOxyISkHTN8Tm65hy9W4FLgAzAAp8Am4F/HWBfF/Ah8AdgQRfFJ+2oB0s621AAa+1L1tpma22ttfZja20+gDFmujHm69ad3XdWvmuMWW+MKTfGPHawOzIHs+9dNWNMmvu4Ie7H84wx97vvalYZYz42xiS227/17me5MabAHeOtwLXAz9vfZTPGbDHGnOn+PtwY84j7DtF29/fh7udONcYUGmN+YozZbYzZYYy58SDxn2aMWd7u8SfGmCXtHn9ljLmk/fmNMefS0qif4o4vr90h+x/sZ+3g73O6+/V/M8aUAPe6f9aHjTHbjDG7jDH/MsZEtnvNz9w/43ZjzAz3739wu9//zfscv/3fwDD3z1xqjFlrjLmq3XMz3X8T77l/nkXGmEHtnh/Z7rW7jDG/Nsb0NsbUGGMS2u031hhTbIwJPZLfxaFYaxdba18ANnXWMUXkiOmag645gXDNAaYBf7HWFlpri4C/ANMPtKO1dpe19nFgyYGeF89TgiWdbR3QbIyZZYw5zxjTowOvuRAYD6QDVwHneCCua4AbgZ603OH8KYAxpj/wAfBPIAnIBHKttU8Cs4E/WWujrLXfOcAx7wYmuV+TAUwA7mn3fG8gFugL3AQ8dpDfxzfAEGNMovvDOB1INsZEuy8oWcBX7V9grf0Q+D0td1ejrLUZh/tZj9BEWhKHXsCDwEO0NGQygcHun+k3AO4L70+Bs4AhwJkdPYkxpjstd+H+4473auBxY8yIdrtdDfwO6AFscMeDMSYa+JSWu3TJ7rg+s9buBObR8rfU6nrgZWtt4wFiuMbd0DnYV7+O/jwi0uV0zfkfXXMOw8evOSOB9oltnnubeCElWNKprLWVwIm0dF8/BRSblnHAvQ7xsoesteXW2m3AXFo+UA/mqn0+iJI7GNpz1tp11tpaWoaRtJ7jGuBT993PRmttibU2t4PHvBa4z1q721pbTMsH8vXtnm90P99orX0fqAaO2/cg7piWACcD42j50JwPnEDLxXS9tbakgzEd6mc9kOR9fp+tF4jt1tp/uodp1NEyNOFH1tpSa20VLRfaq937XuU+5wpr7V7g3iOI9UJgi7X2OWttk7V2GfAGcGW7fd5y9xY10dIAyWz32p3W2r9Ya+ustVXW2kXu52YB1wEYY4KBqcALBwrAWvsfa23cIb62HcHPIyJdSNccXXMInGtOFFDR7nEFEGXMkfXAStfQHCzpdNba1bi7rY0xw4AXgUdo+cA5kJ3tvq+h5UPkYF7ddzx8Bz9bDnaOVGBjRw5wAMnA1naPt7q3tSrZZxz5oX62L4BTgUL392XAKbTMI/jiCOM6kt/nfuPhjTHTgYJ2m5KAbsDSdr9rAwS7v08Glrbbv/3v5HD6AxONMeXttoXw7QvT0fzfvQ38yxgzgJYGRoW1dvERxCUiPkLXnDa65hyeL19zqoGYdo9jgGprre3k80gnUA+WeJS1dg0wExjlwdPspeXDuFXvI3htATDoIM8d7kNrOy0f1q36ubcdjdaL3cnu77+g5WJ3Cge/2HnyQ7X9sfcAtcDIdnfYYq21rRedHbRceFrtO7zhUP8/BcAX+9y9i7LW3t6BGAuAgQcM3to6Wu6kXkfLHd4D3kmEtjLO1Yf40hBBER+ha06H6Zrje9eclbQMDW2V4d4mXkgJlnQq9+TRnxhjUtyPU2m5i/iNB0+bC5xsjOlnjIkFfnUEr50NnGmMucoYE2KMSTDGZLqf28VBPkzdXgLuMcYkmZZJvb+h5c7p0VhAy12vCcBia+1K3HfagC8P8ppdQJoxxqPvY2uti5ahN38zxvQEMMb0Nca0zlt4FZhujBlhjOkG/HafQ+QClxljupmWScg3tXtuDjDUGHO9MSbU/TXeGDO8A6HNAfoYY+4yLROio40xE9s9/zwtd7Uv4hAXO9tSxjnqEF8HHK5hjAkyxkQAoS0PTYQxJqwDcYtIJ9E1R9ccAuSa4z7+j92/i2TgJ7TcTDgg9/Up3P0w3P1YuogSLOlsVbR8QC8yxuyl5SK3gpYPAo+w1n4CvALk0zJsYM4RvHYbcD4t8ZXS8sHceofoGWCEe5z4fw/w8geAbPd5lwM57m1H8zPsdb9+pbW2wb15IbDVWrv7IC97zf1viTEm52jOewR+Qctk32+MMZW0TPQ9DsBa+wEtw3E+d+/z+T6v/RvQQMvFeRYtDQzcr60CzqZlbP12WoZm/JH/XRQOyv3as4DvuF+3Hjit3fPzaSlVm2OtPZIhJB11Mi13Wd+n5Q5qLeCzC5qK+Chdc47uZ9A1x/euOf8G3qXl/34F8J57GwDu3q+T2u1fS8uwQoA17sfSRYyGbopIZzPGWGCItXaDw3F8DvzHWvu0k3GIiIjn6Joj3kZFLkTELxljxgNjgYudjkVERPybrjnSnoYIiojfMcbMomVIyV3uYR0iIiIeoWuO7EtDBEVERERERDqJerBEREREREQ6SUDNwUpMTLRpaWlOhyEiIsdg6dKle6y1SU7HcbR0LRIR8Q8Hux4FVIKVlpZGdna202GIiMgxMMZ4ogRyl9G1SETEPxzseqQhgiIiIiIiIp1ECZaIiIiIiEgnUYIlIiIiIiLSSQJqDpaIiIg3amxspLCwkLq6OqdDEekSERERpKSkEBoa6nQoIp1OCZaIiIjDCgsLiY6OJi0tDWOM0+GIeJS1lpKSEgoLCxkwYIDT4Yh0Og0RFBERcVhdXR0JCQlKriQgGGNISEhQj634LSVYIiIiXkDJlQQS/b2LP1OCJSIiIiIi0kmUYImIiAgPPvggI0eOJD09nczMTBYtWgTAzTffzKpVqzp8n
JkzZ5KUlERmZiaZmZnccMMNnRrn73//+289Pv744zvt2HfddRdffvklAKeeeirHHXdc28/x+uuvd9p5fEl9fT1Tpkxh8ODBTJw4kS1bthxwvxkzZtCzZ09GjRr1re0//elP+fzzz7sgUhHvoQRLREQkwC1cuJA5c+aQk5NDfn4+n376KampqQA8/fTTjBgx4oiON2XKFHJzc8nNzeX555/v1Fj3TbAWLFjQKcctKSnhm2++4eSTT27bNnv27Laf44orrvjW/s3NzZ1yXm/3zDPP0KNHDzZs2MCPfvQjfvGLXxxwv+nTp/Phhx/ut/373/8+Dz30kKfDFPEqSrBEREQC3I4dO0hMTCQ8PByAxMREkpOTgZaenOzsbACioqK4++67ycjIYNKkSezatatDx583bx4XXnhh2+M777yTmTNnApCWlsZvf/tbxo4dy+jRo1mzZg0A1dXV3HjjjYwePZr09HTeeOMNfvnLX1JbW0tmZibXXnttW0zQUpnuZz/7GaNGjWL06NG88sorbec+9dR
"text/plain": [
"<Figure size 864x432 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# Draw two sine functions with different frequencies side by side\n",
"def plot_sin_functions():\n",
"    x = np.linspace(0, 10, 1000)  # x-axis range\n",
"    \n",
"    fig, axs = plt.subplots(1, 2, figsize=(12, 6))  # 1x2 grid of subplots\n",
"\n",
"    # Sine function with frequency 1\n",
"    frequency1 = 1\n",
"    y1 = np.sin(frequency1 * x)\n",
"    axs[0].plot(x, y1, label=f'Sin Function (Frequency = {frequency1})')\n",
"    axs[0].set_title(f'Sin Function with Frequency = {frequency1}')\n",
"    axs[0].set_xlabel('x')\n",
"    axs[0].set_ylabel('sin(x)')\n",
"    axs[0].legend()\n",
"\n",
"    # Sine function with frequency 0.1\n",
"    frequency2 = 0.1\n",
"    y2 = np.sin(frequency2 * x)\n",
"    axs[1].plot(x, y2, label=f'Sin Function (Frequency = {frequency2})')\n",
"    axs[1].set_title(f'Sin Function with Frequency = {frequency2}')\n",
"    axs[1].set_xlabel('x')\n",
"    axs[1].set_ylabel('sin(x)')\n",
"    axs[1].legend()\n",
"\n",
"    plt.tight_layout()\n",
"    plt.show()\n",
"\n",
"# Draw the two sine functions in horizontally arranged subplots\n",
"plot_sin_functions()\n"
]
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"id": "dd0c2077-2cd1-4eb4-a37d-29ced5a218ca",
|
|||
|
|
"metadata": {},
|
|||
|
|
"source": [
|
|||
|
|
"因此,这里你就可以发现非常有趣的事实了——特征编号小的特征,会随着特征值的变化而产生剧烈的变化,即便是相邻的两个样本,在最初的几个特征进行位置编码时,也会产生迥然不同的结果,但是随着特征编号的变大,特征值的变化带来的变化会越来越小,并且会小到呈现出一种单调性(只上升、或者只下降)。当一个信息被映射到这样的高维空间时,我们会认为这个信息的全局趋势和局部细节都被捕捉到了。其中,特征编号比较大的那些维度捕捉到的是样本与样本之间按顺序排列的全局趋势,而特征编号比较小的那些维度捕捉到的是样本与样本的位置之间本身的细节差异。因此,正余弦编码是一种能够同时捕捉到全局位置趋势和细节位置差异的编码方式。\n",
|
|||
|
|
"\n",
|
|||
|
|
"**<center>正余弦编码的意义③:通过独特的计算公式,我们可以让特征编号小的特征被投射到剧烈变化的维度上,<br><br>并且让特征编号大的特征被投射到轻微变化、甚至完全单调的维度上,从而可以让小编号特征去<br><br>捕捉样本之间的局部细节差异,让大编号特征去捕捉样本之间按顺序排列的全局趋势**"
]
},
{
"cell_type": "markdown",
"id": "622f0950-291e-4108-ae0f-b1ff987f0f9c",
"metadata": {},
"source": [
"从这个角度来看,其实我们只需要设置一个随着i的增长变得越来越小的公式就可以了,实际公式本身其实并不一定非要是$\\frac{1}{10000^{\\frac{2i}{d_{\\text{model}}}}}$。但这个公式考虑了i相对于特征总量的相对位置,并且还使用了指数函数,它是能够最大程度放大i的影响的公式之一,因此我们使用它可以说是出于一种数学上的便利。当然,你也可以使用其他的公式,只要能够保证i的增长会让频率本身变得越来越小即可。"
]
},
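{
"cell_type": "markdown",
"id": "a1f2e3d4-0001-4abc-9def-0123456789ab",
"metadata": {},
"source": [
"To make this decay concrete, here is a minimal sketch that prints the value of $\\frac{1}{10000^{\\frac{2i}{d_{\\text{model}}}}}$ for a few values of $i$. The width $d_{\\text{model}} = 512$ is only an illustrative assumption, not a value used elsewhere in this notebook:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a1f2e3d4-0002-4abc-9def-0123456789ac",
"metadata": {},
"outputs": [],
"source": [
"# Sketch: the per-dimension frequency shrinks monotonically as i grows.\n",
"# d_model = 512 is an assumed, purely illustrative embedding width.\n",
"d_model = 512\n",
"for i in [0, 1, 10, 100, 255]:\n",
"    freq = 1 / 10000 ** (2 * i / d_model)\n",
"    print(f\"i = {i:3d}, frequency = {freq:.6f}\")\n"
]
},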
{
"cell_type": "markdown",
"id": "5c8c0f27-0846-4ef1-8c63-977231bd6534",
"metadata": {},
"source": [
"现在我们可以来看一个具体例子,通过绘制图像来让大家清晰地看到,正余弦编码是如何帮助我们捕捉局部细节和总体趋势的。假设现在有30个样本(索引为1-30),每个样本有4个特征。我们将使用正弦函数编码偶数维度,使用余弦函数编码奇数维度,进行正余弦编码的具体计算。"
]
},
{
"cell_type": "markdown",
"id": "b7761492-e455-4459-88a9-a54f7339c612",
"metadata": {},
"source": [
"- 正弦位置编码(Sinusoidal Positional Encoding)\n",
|
|||
|
|
"$$PE_{(pos, 2i)} = \\sin \\left( \\frac{pos}{10000^{\\frac{2i}{d_{\\text{model}}}}} \\right) $$\n",
|
|||
|
|
"\n",
|
|||
|
|
"- 余弦位置编码(Cosine Positional Encoding)\n",
|
|||
|
|
"$$ PE_{(pos, 2i+1)} = \\cos \\left( \\frac{pos}{10000^{\\frac{2i}{d_{\\text{model}}}}} \\right) $$"
]
},
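{
"cell_type": "markdown",
"id": "b2c3d4e5-0003-4abc-9def-0123456789ad",
"metadata": {},
"source": [
"Before looking at the generated table, here is a minimal NumPy sketch of the same computation, assuming $d_{\\text{model}} = 4$ (even) and positions $1$ to $30$ so that it matches the table below; the function name `sinusoidal_encoding` is our own, not a library API:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b2c3d4e5-0004-4abc-9def-0123456789ae",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"def sinusoidal_encoding(n_positions, d_model, start=1):\n",
"    # assumes d_model is even; positions start at 1 to match the table below\n",
"    pos = np.arange(start, start + n_positions)[:, None]  # shape (n, 1)\n",
"    i = np.arange(d_model // 2)[None, :]                   # shape (1, d/2)\n",
"    freq = 1.0 / 10000 ** (2 * i / d_model)                # one frequency per sin/cos pair\n",
"    pe = np.empty((n_positions, d_model))\n",
"    pe[:, 0::2] = np.sin(pos * freq)                       # even dimensions: sine\n",
"    pe[:, 1::2] = np.cos(pos * freq)                       # odd dimensions: cosine\n",
"    return pe\n",
"\n",
"pe = sinusoidal_encoding(30, 4)\n",
"pd.DataFrame(pe, columns=[f\"维度{j}\" for j in range(4)]).head()\n"
]
},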
{
"cell_type": "markdown",
"id": "8b22bfcd-35bb-40f7-b84d-e4bfcb6defa8",
"metadata": {},
"source": [
"我请GPT帮我完成了相应的计算流程,最终生成了如下的表单👇"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "7cdc9e73-583b-4b0a-a3da-f101adeb464e",
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"position = pd.read_csv(r\"D:\\pythonwork\\2024DL\\Position_Encoding_for_30_Samples.csv\")"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "c5fe942e-560d-4ee0-ab8d-c6d67408b025",
"metadata": {},
"outputs": [
{
"data": {
|
|||
|
|
"text/html": [
|
|||
|
|
"<div>\n",
|
|||
|
|
"<style scoped>\n",
|
|||
|
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
|
" vertical-align: middle;\n",
|
|||
|
|
" }\n",
|
|||
|
|
"\n",
|
|||
|
|
" .dataframe tbody tr th {\n",
|
|||
|
|
" vertical-align: top;\n",
|
|||
|
|
" }\n",
|
|||
|
|
"\n",
|
|||
|
|
" .dataframe thead th {\n",
|
|||
|
|
" text-align: right;\n",
|
|||
|
|
" }\n",
|
|||
|
|
"</style>\n",
|
|||
|
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
|
" <thead>\n",
|
|||
|
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
|
" <th></th>\n",
|
|||
|
|
" <th>维度0</th>\n",
|
|||
|
|
" <th>维度1</th>\n",
|
|||
|
|
" <th>维度2</th>\n",
|
|||
|
|
" <th>维度3</th>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" </thead>\n",
|
|||
|
|
" <tbody>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <th>0</th>\n",
|
|||
|
|
" <td>0.841471</td>\n",
|
|||
|
|
" <td>0.540302</td>\n",
|
|||
|
|
" <td>0.010000</td>\n",
|
|||
|
|
" <td>0.999950</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <th>1</th>\n",
|
|||
|
|
" <td>0.909297</td>\n",
|
|||
|
|
" <td>-0.416147</td>\n",
|
|||
|
|
" <td>0.019999</td>\n",
|
|||
|
|
" <td>0.999800</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <th>2</th>\n",
|
|||
|
|
" <td>0.141120</td>\n",
|
|||
|
|
" <td>-0.989992</td>\n",
|
|||
|
|
" <td>0.029996</td>\n",
|
|||
|
|
" <td>0.999550</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <th>3</th>\n",
|
|||
|
|
" <td>-0.756802</td>\n",
|
|||
|
|
" <td>-0.653644</td>\n",
|
|||
|
|
" <td>0.039989</td>\n",
|
|||
|
|
" <td>0.999200</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <th>4</th>\n",
|
|||
|
|
" <td>-0.958924</td>\n",
|
|||
|
|
" <td>0.283662</td>\n",
|
|||
|
|
" <td>0.049979</td>\n",
|
|||
|
|
" <td>0.998750</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <th>5</th>\n",
|
|||
|
|
" <td>-0.279415</td>\n",
|
|||
|
|
" <td>0.960170</td>\n",
|
|||
|
|
" <td>0.059964</td>\n",
|
|||
|
|
" <td>0.998201</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <th>6</th>\n",
|
|||
|
|
" <td>0.656987</td>\n",
|
|||
|
|
" <td>0.753902</td>\n",
|
|||
|
|
" <td>0.069943</td>\n",
|
|||
|
|
" <td>0.997551</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <th>7</th>\n",
|
|||
|
|
" <td>0.989358</td>\n",
|
|||
|
|
" <td>-0.145500</td>\n",
|
|||
|
|
" <td>0.079915</td>\n",
|
|||
|
|
" <td>0.996802</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <th>8</th>\n",
|
|||
|
|
" <td>0.412118</td>\n",
|
|||
|
|
" <td>-0.911130</td>\n",
|
|||
|
|
" <td>0.089879</td>\n",
|
|||
|
|
" <td>0.995953</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <th>9</th>\n",
|
|||
|
|
" <td>-0.544021</td>\n",
|
|||
|
|
" <td>-0.839072</td>\n",
|
|||
|
|
" <td>0.099833</td>\n",
|
|||
|
|
" <td>0.995004</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <th>10</th>\n",
|
|||
|
|
" <td>-0.999990</td>\n",
|
|||
|
|
" <td>0.004426</td>\n",
|
|||
|
|
" <td>0.109778</td>\n",
|
|||
|
|
" <td>0.993956</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <th>11</th>\n",
|
|||
|
|
" <td>-0.536573</td>\n",
|
|||
|
|
" <td>0.843854</td>\n",
|
|||
|
|
" <td>0.119712</td>\n",
|
|||
|
|
" <td>0.992809</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <th>12</th>\n",
|
|||
|
|
" <td>0.420167</td>\n",
|
|||
|
|
" <td>0.907447</td>\n",
|
|||
|
|
" <td>0.129634</td>\n",
|
|||
|
|
" <td>0.991562</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <th>13</th>\n",
|
|||
|
|
" <td>0.990607</td>\n",
|
|||
|
|
" <td>0.136737</td>\n",
|
|||
|
|
" <td>0.139543</td>\n",
|
|||
|
|
" <td>0.990216</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <th>14</th>\n",
|
|||
|
|
" <td>0.650288</td>\n",
|
|||
|
|
" <td>-0.759688</td>\n",
|
|||
|
|
" <td>0.149438</td>\n",
|
|||
|
|
" <td>0.988771</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <th>15</th>\n",
|
|||
|
|
" <td>-0.287903</td>\n",
|
|||
|
|
" <td>-0.957659</td>\n",
|
|||
|
|
" <td>0.159318</td>\n",
|
|||
|
|
" <td>0.987227</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <th>16</th>\n",
|
|||
|
|
" <td>-0.961397</td>\n",
|
|||
|
|
" <td>-0.275163</td>\n",
|
|||
|
|
" <td>0.169182</td>\n",
|
|||
|
|
" <td>0.985585</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <th>17</th>\n",
|
|||
|
|
" <td>-0.750987</td>\n",
|
|||
|
|
" <td>0.660317</td>\n",
|
|||
|
|
" <td>0.179030</td>\n",
|
|||
|
|
" <td>0.983844</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <th>18</th>\n",
|
|||
|
|
" <td>0.149877</td>\n",
|
|||
|
|
" <td>0.988705</td>\n",
|
|||
|
|
" <td>0.188859</td>\n",
|
|||
|
|
" <td>0.982004</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <th>19</th>\n",
|
|||
|
|
" <td>0.912945</td>\n",
|
|||
|
|
" <td>0.408082</td>\n",
|
|||
|
|
" <td>0.198669</td>\n",
|
|||
|
|
" <td>0.980067</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <th>20</th>\n",
|
|||
|
|
" <td>0.836656</td>\n",
|
|||
|
|
" <td>-0.547729</td>\n",
|
|||
|
|
" <td>0.208460</td>\n",
|
|||
|
|
" <td>0.978031</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <th>21</th>\n",
|
|||
|
|
" <td>-0.008851</td>\n",
|
|||
|
|
" <td>-0.999961</td>\n",
|
|||
|
|
" <td>0.218230</td>\n",
|
|||
|
|
" <td>0.975897</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <th>22</th>\n",
|
|||
|
|
" <td>-0.846220</td>\n",
|
|||
|
|
" <td>-0.532833</td>\n",
|
|||
|
|
" <td>0.227978</td>\n",
|
|||
|
|
" <td>0.973666</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <th>23</th>\n",
|
|||
|
|
" <td>-0.905578</td>\n",
|
|||
|
|
" <td>0.424179</td>\n",
|
|||
|
|
" <td>0.237703</td>\n",
|
|||
|
|
" <td>0.971338</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <th>24</th>\n",
|
|||
|
|
" <td>-0.132352</td>\n",
|
|||
|
|
" <td>0.991203</td>\n",
|
|||
|
|
" <td>0.247404</td>\n",
|
|||
|
|
" <td>0.968912</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <th>25</th>\n",
|
|||
|
|
" <td>0.762558</td>\n",
|
|||
|
|
" <td>0.646919</td>\n",
|
|||
|
|
" <td>0.257081</td>\n",
|
|||
|
|
" <td>0.966390</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <th>26</th>\n",
|
|||
|
|
" <td>0.956376</td>\n",
|
|||
|
|
" <td>-0.292139</td>\n",
|
|||
|
|
" <td>0.266731</td>\n",
|
|||
|
|
" <td>0.963771</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <th>27</th>\n",
|
|||
|
|
" <td>0.270906</td>\n",
|
|||
|
|
" <td>-0.962606</td>\n",
|
|||
|
|
" <td>0.276356</td>\n",
|
|||
|
|
" <td>0.961055</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <th>28</th>\n",
|
|||
|
|
" <td>-0.663634</td>\n",
|
|||
|
|
" <td>-0.748058</td>\n",
|
|||
|
|
" <td>0.285952</td>\n",
|
|||
|
|
" <td>0.958244</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <th>29</th>\n",
|
|||
|
|
" <td>-0.988032</td>\n",
|
|||
|
|
" <td>0.154251</td>\n",
|
|||
|
|
" <td>0.295520</td>\n",
|
|||
|
|
" <td>0.955336</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" </tbody>\n",
|
|||
|
|
"</table>\n",
|
|||
|
|
"</div>"
|
|||
|
|
],
|
|||
|
|
"text/plain": [
|
|||
|
|
" 维度0 维度1 维度2 维度3\n",
|
|||
|
|
"0 0.841471 0.540302 0.010000 0.999950\n",
|
|||
|
|
"1 0.909297 -0.416147 0.019999 0.999800\n",
|
|||
|
|
"2 0.141120 -0.989992 0.029996 0.999550\n",
|
|||
|
|
"3 -0.756802 -0.653644 0.039989 0.999200\n",
|
|||
|
|
"4 -0.958924 0.283662 0.049979 0.998750\n",
|
|||
|
|
"5 -0.279415 0.960170 0.059964 0.998201\n",
|
|||
|
|
"6 0.656987 0.753902 0.069943 0.997551\n",
|
|||
|
|
"7 0.989358 -0.145500 0.079915 0.996802\n",
|
|||
|
|
"8 0.412118 -0.911130 0.089879 0.995953\n",
|
|||
|
|
"9 -0.544021 -0.839072 0.099833 0.995004\n",
|
|||
|
|
"10 -0.999990 0.004426 0.109778 0.993956\n",
|
|||
|
|
"11 -0.536573 0.843854 0.119712 0.992809\n",
|
|||
|
|
"12 0.420167 0.907447 0.129634 0.991562\n",
|
|||
|
|
"13 0.990607 0.136737 0.139543 0.990216\n",
|
|||
|
|
"14 0.650288 -0.759688 0.149438 0.988771\n",
|
|||
|
|
"15 -0.287903 -0.957659 0.159318 0.987227\n",
|
|||
|
|
"16 -0.961397 -0.275163 0.169182 0.985585\n",
|
|||
|
|
"17 -0.750987 0.660317 0.179030 0.983844\n",
|
|||
|
|
"18 0.149877 0.988705 0.188859 0.982004\n",
|
|||
|
|
"19 0.912945 0.408082 0.198669 0.980067\n",
|
|||
|
|
"20 0.836656 -0.547729 0.208460 0.978031\n",
|
|||
|
|
"21 -0.008851 -0.999961 0.218230 0.975897\n",
|
|||
|
|
"22 -0.846220 -0.532833 0.227978 0.973666\n",
|
|||
|
|
"23 -0.905578 0.424179 0.237703 0.971338\n",
|
|||
|
|
"24 -0.132352 0.991203 0.247404 0.968912\n",
|
|||
|
|
"25 0.762558 0.646919 0.257081 0.966390\n",
|
|||
|
|
"26 0.956376 -0.292139 0.266731 0.963771\n",
|
|||
|
|
"27 0.270906 -0.962606 0.276356 0.961055\n",
|
|||
|
|
"28 -0.663634 -0.748058 0.285952 0.958244\n",
|
|||
|
|
"29 -0.988032 0.154251 0.295520 0.955336"
|
|||
|
|
]
|
|||
|
|
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"position"
]
},
{
"cell_type": "code",
"execution_count": 64,
"id": "f9998144-59f0-4771-b046-cccd53f7acb6",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"维度0 0.990607\n",
"维度1 0.991203\n",
"维度2 0.295520\n",
"维度3 0.999950\n",
"dtype: float64"
]
},
"execution_count": 64,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"position.may()"
]
},
{
"cell_type": "code",
"execution_count": 40,
"id": "51285d15-f79e-4d82-a706-6c525ce83ccb",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"维度0 -0.999990\n",
"维度1 -0.999961\n",
"维度2 0.010000\n",
"维度3 0.955336\n",
"dtype: float64"
]
},
"execution_count": 40,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"position.min()"
]
},
{
"cell_type": "markdown",
"id": "74edeef0-0803-4ed7-80d4-3851ce5b1dd3",
"metadata": {},
"source": [
"在这个表单中,有4个特征全部进行正余弦编码后的结果,很显然,特征编号较小的特征(1和2特征)波动很大,但是特征编号相对较大的特征(3和4)波动就不是那么大。我们只计算了4个特征,是因为我们要绘制的3d图像只能够容纳3个特征,事实上当特征数量变得很多时,大部分特征都会呈现像特征3和特征4一样这样平缓的变化方式。\n",
|
|||
|
|
"\n",
|
|||
|
|
"为了展现局部特征和整体趋势的捕捉,我们使用特征2、3、4来绘制了3D图像,在图像中,我们可以明显地看到局部细节和总体趋势的捕捉👇"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "70719d5a-50e1-4e43-9978-93b1094fee54",
"metadata": {},
"outputs": [
|
|||
|
|
{
|
|||
|
|
"data": {
|
|||
|
|
"text/html": [
|
|||
|
|
" <script type=\"text/javascript\">\n",
|
|||
|
|
" window.PlotlyConfig = {MathJaxConfig: 'local'};\n",
|
|||
|
|
" if (window.MathJax && window.MathJax.Hub && window.MathJax.Hub.Config) {window.MathJax.Hub.Config({SVG: {font: \"STIX-Web\"}});}\n",
|
|||
|
|
" if (typeof require !== 'undefined') {\n",
|
|||
|
|
" require.undef(\"plotly\");\n",
|
|||
|
|
" define('plotly', function(require, exports, module) {\n",
|
|||
|
|
" /**\n",
|
|||
|
|
"* plotly.js v2.27.0\n",
|
|||
|
|
"* Copyright 2012-2023, Plotly, Inc.\n",
|
|||
|
|
"* All rights reserved.\n",
|
|||
|
|
"* Licensed under the MIT license\n",
|
|||
|
|
"*/\n",
|
|||
|
|
"/*! For license information please see plotly.min.js.LICENSE.txt */\n",
|
|||
|
|
"!function(t,e){\"object\"==typeof exports&&\"object\"==typeof module?module.exports=e():\"function\"==typeof define&&define.amd?define([],e):\"object\"==typeof exports?exports.Plotly=e():t.Plotly=e()}(self,(function(){return function(){var t={98847:function(t,e,r){\"use strict\";var n=r(71828),i={\"X,X div\":'direction:ltr;font-family:\"Open Sans\",verdana,arial,sans-serif;margin:0;padding:0;',\"X input,X button\":'font-family:\"Open Sans\",verdana,arial,sans-serif;',\"X input:focus,X button:focus\":\"outline:none;\",\"X a\":\"text-decoration:none;\",\"X a:hover\":\"text-decoration:none;\",\"X .crisp\":\"shape-rendering:crispEdges;\",\"X .user-select-none\":\"-webkit-user-select:none;-moz-user-select:none;-ms-user-select:none;-o-user-select:none;user-select:none;\",\"X svg\":\"overflow:hidden;\",\"X svg a\":\"fill:#447adb;\",\"X svg a:hover\":\"fill:#3c6dc5;\",\"X .main-svg\":\"position:absolute;top:0;left:0;pointer-events:none;\",\"X .main-svg .draglayer\":\"pointer-events:all;\",\"X .cursor-default\":\"cursor:default;\",\"X .cursor-pointer\":\"cursor:pointer;\",\"X .cursor-crosshair\":\"cursor:crosshair;\",\"X .cursor-move\":\"cursor:move;\",\"X .cursor-col-resize\":\"cursor:col-resize;\",\"X .cursor-row-resize\":\"cursor:row-resize;\",\"X .cursor-ns-resize\":\"cursor:ns-resize;\",\"X .cursor-ew-resize\":\"cursor:ew-resize;\",\"X .cursor-sw-resize\":\"cursor:sw-resize;\",\"X .cursor-s-resize\":\"cursor:s-resize;\",\"X .cursor-se-resize\":\"cursor:se-resize;\",\"X .cursor-w-resize\":\"cursor:w-resize;\",\"X .cursor-e-resize\":\"cursor:e-resize;\",\"X .cursor-nw-resize\":\"cursor:nw-resize;\",\"X .cursor-n-resize\":\"cursor:n-resize;\",\"X .cursor-ne-resize\":\"cursor:ne-resize;\",\"X .cursor-grab\":\"cursor:-webkit-grab;cursor:grab;\",\"X .modebar\":\"position:absolute;top:2px;right:2px;\",\"X .ease-bg\":\"-webkit-transition:background-color .3s ease 0s;-moz-transition:background-color .3s ease 0s;-ms-transition:background-color .3s ease 0s;-o-transition:background-color .3s ease 0s;transition:background-color .3s ease 0s;\",\"X .modebar--hover>:not(.watermark)\":\"opacity:0;-webkit-transition:opacity .3s ease 0s;-moz-transition:opacity .3s ease 0s;-ms-transition:opacity .3s ease 0s;-o-transition:opacity .3s ease 0s;transition:opacity .3s ease 0s;\",\"X:hover .modebar--hover .modebar-group\":\"opacity:1;\",\"X .modebar-group\":\"float:left;display:inline-block;box-sizing:border-box;padding-left:8px;position:relative;vertical-align:middle;white-space:nowrap;\",\"X .modebar-btn\":\"position:relative;font-size:16px;padding:3px 4px;height:22px;cursor:pointer;line-height:normal;box-sizing:border-box;\",\"X .modebar-btn svg\":\"position:relative;top:2px;\",\"X .modebar.vertical\":\"display:flex;flex-direction:column;flex-wrap:wrap;align-content:flex-end;max-height:100%;\",\"X .modebar.vertical svg\":\"top:-1px;\",\"X .modebar.vertical .modebar-group\":\"display:block;float:none;padding-left:0px;padding-bottom:8px;\",\"X .modebar.vertical .modebar-group .modebar-btn\":\"display:block;text-align:center;\",\"X [data-title]:before,X [data-title]:after\":\"position:absolute;-webkit-transform:translate3d(0, 0, 0);-moz-transform:translate3d(0, 0, 0);-ms-transform:translate3d(0, 0, 0);-o-transform:translate3d(0, 0, 0);transform:translate3d(0, 0, 0);display:none;opacity:0;z-index:1001;pointer-events:none;top:110%;right:50%;\",\"X [data-title]:hover:before,X [data-title]:hover:after\":\"display:block;opacity:1;\",\"X 
[data-title]:before\":'content:\"\";position:absolute;background:rgba(0,0,0,0);border:6px solid rgba(0,0,0,0);z-index:1002;margin-top:-12px;border-bottom-color:#69738a;margin-right:-6px;',\"X [data-title]:after\":\"content:attr(data-title);background:#69738a;color:#fff;padding:8px 10px;font-size:12px;line-height:12px;white-space:nowrap;margin-right:-18px;border-radius:2px;\",\"X .vertical [data-title]:before,X .vertical [data-title]:after\":\"top:0%;right:200%;\",\"X .vertical [data-title]:before\":\"border:6px solid rgba(0,0,0,0);border-left-color:#69738a;margin-top:8px;margin-right:-30px;\",Y:'font-family:\"Open
|
|||
|
|
" });\n",
|
|||
|
|
" require(['plotly'], function(Plotly) {\n",
|
|||
|
|
" window._Plotly = Plotly;\n",
|
|||
|
|
" });\n",
|
|||
|
|
" }\n",
|
|||
|
|
" </script>\n",
|
|||
|
|
" "
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
"metadata": {},
|
|||
|
|
"output_type": "display_data"
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"data": {
|
|||
|
|
"application/vnd.plotly.v1+json": {
|
|||
|
|
"config": {
|
|||
|
|
"plotlyServerURL": "https://plot.ly"
|
|||
|
|
},
|
|||
|
|
"data": [
|
|||
|
|
{
|
|||
|
|
"marker": {
|
|||
|
|
"color": [
|
|||
|
|
0.9999500004166652,
|
|||
|
|
0.9998000066665778,
|
|||
|
|
0.9995500337489877,
|
|||
|
|
0.999200106660978,
|
|||
|
|
0.9987502603949664,
|
|||
|
|
0.9982005399352042,
|
|||
|
|
0.9975510002532796,
|
|||
|
|
0.9968017063026194,
|
|||
|
|
0.9959527330119944,
|
|||
|
|
0.9950041652780258,
|
|||
|
|
0.9939560979566968,
|
|||
|
|
0.9928086358538664,
|
|||
|
|
0.991561893714788,
|
|||
|
|
0.9902159962126372,
|
|||
|
|
0.9887710779360422,
|
|||
|
|
0.9872272833756268,
|
|||
|
|
0.9855847669095608,
|
|||
|
|
0.9838436927881214,
|
|||
|
|
0.9820042351172704,
|
|||
|
|
0.9800665778412416,
|
|||
|
|
0.9780309147241484,
|
|||
|
|
0.9758974493306056,
|
|||
|
|
0.9736663950053748,
|
|||
|
|
0.9713379748520296,
|
|||
|
|
0.9689124217106448,
|
|||
|
|
0.9663899781345132,
|
|||
|
|
0.9637708963658904,
|
|||
|
|
0.9610554383107708,
|
|||
|
|
0.9582438755126972,
|
|||
|
|
0.955336489125606
|
|||
|
|
],
|
|||
|
|
"colorscale": [
|
|||
|
|
[
|
|||
|
|
0,
|
|||
|
|
"#440154"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
0.1111111111111111,
|
|||
|
|
"#482878"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
0.2222222222222222,
|
|||
|
|
"#3e4989"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
0.3333333333333333,
|
|||
|
|
"#31688e"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
0.4444444444444444,
|
|||
|
|
"#26828e"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
0.5555555555555556,
|
|||
|
|
"#1f9e89"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
0.6666666666666666,
|
|||
|
|
"#35b779"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
0.7777777777777778,
|
|||
|
|
"#6ece58"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
0.8888888888888888,
|
|||
|
|
"#b5de2b"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
1,
|
|||
|
|
"#fde725"
|
|||
|
|
]
|
|||
|
|
],
|
|||
|
|
"opacity": 1,
|
|||
|
|
"size": 6
|
|||
|
|
},
|
|||
|
|
"mode": "markers",
|
|||
|
|
"text": [
|
|||
|
|
"0",
|
|||
|
|
"1",
|
|||
|
|
"2",
|
|||
|
|
"3",
|
|||
|
|
"4",
|
|||
|
|
"5",
|
|||
|
|
"6",
|
|||
|
|
"7",
|
|||
|
|
"8",
|
|||
|
|
"9",
|
|||
|
|
"10",
|
|||
|
|
"11",
|
|||
|
|
"12",
|
|||
|
|
"13",
|
|||
|
|
"14",
|
|||
|
|
"15",
|
|||
|
|
"16",
|
|||
|
|
"17",
|
|||
|
|
"18",
|
|||
|
|
"19",
|
|||
|
|
"20",
|
|||
|
|
"21",
|
|||
|
|
"22",
|
|||
|
|
"23",
|
|||
|
|
"24",
|
|||
|
|
"25",
|
|||
|
|
"26",
|
|||
|
|
"27",
|
|||
|
|
"28",
|
|||
|
|
"29"
|
|||
|
|
],
|
|||
|
|
"textposition": "top center",
|
|||
|
|
"type": "scatter3d",
|
|||
|
|
"x": [
|
|||
|
|
0.5403023058681398,
|
|||
|
|
-0.4161468365471424,
|
|||
|
|
-0.9899924966004454,
|
|||
|
|
-0.6536436208636119,
|
|||
|
|
0.2836621854632262,
|
|||
|
|
0.960170286650366,
|
|||
|
|
0.7539022543433046,
|
|||
|
|
-0.1455000338086135,
|
|||
|
|
-0.9111302618846768,
|
|||
|
|
-0.8390715290764524,
|
|||
|
|
0.0044256979880507,
|
|||
|
|
0.8438539587324921,
|
|||
|
|
0.9074467814501962,
|
|||
|
|
0.1367372182078336,
|
|||
|
|
-0.7596879128588213,
|
|||
|
|
-0.9576594803233848,
|
|||
|
|
-0.2751633380515969,
|
|||
|
|
0.6603167082440802,
|
|||
|
|
0.9887046181866692,
|
|||
|
|
0.4080820618133919,
|
|||
|
|
-0.5477292602242684,
|
|||
|
|
-0.9999608263946372,
|
|||
|
|
-0.5328330203333975,
|
|||
|
|
0.424179007336997,
|
|||
|
|
0.9912028118634736,
|
|||
|
|
0.6469193223286404,
|
|||
|
|
-0.2921388087338362,
|
|||
|
|
-0.9626058663135666,
|
|||
|
|
-0.7480575296890003,
|
|||
|
|
0.154251449887584
|
|||
|
|
],
|
|||
|
|
"y": [
|
|||
|
|
0.0099998333341666,
|
|||
|
|
0.019998666693333,
|
|||
|
|
0.0299955002024956,
|
|||
|
|
0.0399893341866341,
|
|||
|
|
0.0499791692706783,
|
|||
|
|
0.0599640064794445,
|
|||
|
|
0.0699428473375327,
|
|||
|
|
0.0799146939691727,
|
|||
|
|
0.089878549198011,
|
|||
|
|
0.0998334166468281,
|
|||
|
|
0.1097783008371748,
|
|||
|
|
0.1197122072889193,
|
|||
|
|
0.1296341426196948,
|
|||
|
|
0.1395431146442365,
|
|||
|
|
0.1494381324735992,
|
|||
|
|
0.1593182066142459,
|
|||
|
|
0.169182349066996,
|
|||
|
|
0.1790295734258241,
|
|||
|
|
0.1888588949765005,
|
|||
|
|
0.1986693307950612,
|
|||
|
|
0.2084598998460995,
|
|||
|
|
0.2182296230808693,
|
|||
|
|
0.2279775235351884,
|
|||
|
|
0.2377026264271345,
|
|||
|
|
0.2474039592545229,
|
|||
|
|
0.2570805518921551,
|
|||
|
|
0.2667314366888311,
|
|||
|
|
0.2763556485641137,
|
|||
|
|
0.2859522251048355,
|
|||
|
|
0.2955202066613395
|
|||
|
|
],
|
|||
|
|
"z": [
|
|||
|
|
0.9999500004166652,
|
|||
|
|
0.9998000066665778,
|
|||
|
|
0.9995500337489877,
|
|||
|
|
0.999200106660978,
|
|||
|
|
0.9987502603949664,
|
|||
|
|
0.9982005399352042,
|
|||
|
|
0.9975510002532796,
|
|||
|
|
0.9968017063026194,
|
|||
|
|
0.9959527330119944,
|
|||
|
|
0.9950041652780258,
|
|||
|
|
0.9939560979566968,
|
|||
|
|
0.9928086358538664,
|
|||
|
|
0.991561893714788,
|
|||
|
|
0.9902159962126372,
|
|||
|
|
0.9887710779360422,
|
|||
|
|
0.9872272833756268,
|
|||
|
|
0.9855847669095608,
|
|||
|
|
0.9838436927881214,
|
|||
|
|
0.9820042351172704,
|
|||
|
|
0.9800665778412416,
|
|||
|
|
0.9780309147241484,
|
|||
|
|
0.9758974493306056,
|
|||
|
|
0.9736663950053748,
|
|||
|
|
0.9713379748520296,
|
|||
|
|
0.9689124217106448,
|
|||
|
|
0.9663899781345132,
|
|||
|
|
0.9637708963658904,
|
|||
|
|
0.9610554383107708,
|
|||
|
|
0.9582438755126972,
|
|||
|
|
0.955336489125606
|
|||
|
|
]
|
|||
|
|
}
|
|||
|
|
],
|
|||
|
|
"layout": {
|
|||
|
|
"height": 800,
|
|||
|
|
"scene": {
|
|||
|
|
"aspectmode": "auto",
|
|||
|
|
"aspectratio": {
|
|||
|
|
"x": 1,
|
|||
|
|
"y": 1,
|
|||
|
|
"z": 1
|
|||
|
|
},
|
|||
|
|
"camera": {
|
|||
|
|
"center": {
|
|||
|
|
"x": 0,
|
|||
|
|
"y": 0,
|
|||
|
|
"z": 0
|
|||
|
|
},
|
|||
|
|
"eye": {
|
|||
|
|
"x": 1.295312601789196,
|
|||
|
|
"y": 1.4932026924031352,
|
|||
|
|
"z": 1.6431386191180295
|
|||
|
|
},
|
|||
|
|
"projection": {
|
|||
|
|
"type": "perspective"
|
|||
|
|
},
|
|||
|
|
"up": {
|
|||
|
|
"x": 0,
|
|||
|
|
"y": 0,
|
|||
|
|
"z": 1
|
|||
|
|
}
|
|||
|
|
},
|
|||
|
|
"xaxis": {
|
|||
|
|
"range": [
|
|||
|
|
-1.5,
|
|||
|
|
1.5
|
|||
|
|
],
|
|||
|
|
"title": {
|
|||
|
|
"text": "维度1"
|
|||
|
|
},
|
|||
|
|
"type": "linear"
|
|||
|
|
},
|
|||
|
|
"yaxis": {
|
|||
|
|
"range": [
|
|||
|
|
0,
|
|||
|
|
0.3
|
|||
|
|
],
|
|||
|
|
"title": {
|
|||
|
|
"text": "维度2"
|
|||
|
|
},
|
|||
|
|
"type": "linear"
|
|||
|
|
},
|
|||
|
|
"zaxis": {
|
|||
|
|
"range": [
|
|||
|
|
0.95,
|
|||
|
|
1
|
|||
|
|
],
|
|||
|
|
"title": {
|
|||
|
|
"text": "维度3"
|
|||
|
|
},
|
|||
|
|
"type": "linear"
|
|||
|
|
}
|
|||
|
|
},
|
|||
|
|
"template": {
|
|||
|
|
"data": {
|
|||
|
|
"bar": [
|
|||
|
|
{
|
|||
|
|
"error_x": {
|
|||
|
|
"color": "#2a3f5f"
|
|||
|
|
},
|
|||
|
|
"error_y": {
|
|||
|
|
"color": "#2a3f5f"
|
|||
|
|
},
|
|||
|
|
"marker": {
|
|||
|
|
"line": {
|
|||
|
|
"color": "#E5ECF6",
|
|||
|
|
"width": 0.5
|
|||
|
|
},
|
|||
|
|
"pattern": {
|
|||
|
|
"fillmode": "overlay",
|
|||
|
|
"size": 10,
|
|||
|
|
"solidity": 0.2
|
|||
|
|
}
|
|||
|
|
},
|
|||
|
|
"type": "bar"
|
|||
|
|
}
|
|||
|
|
],
|
|||
|
|
"barpolar": [
|
|||
|
|
{
|
|||
|
|
"marker": {
|
|||
|
|
"line": {
|
|||
|
|
"color": "#E5ECF6",
|
|||
|
|
"width": 0.5
|
|||
|
|
},
|
|||
|
|
"pattern": {
|
|||
|
|
"fillmode": "overlay",
|
|||
|
|
"size": 10,
|
|||
|
|
"solidity": 0.2
|
|||
|
|
}
|
|||
|
|
},
|
|||
|
|
"type": "barpolar"
|
|||
|
|
}
|
|||
|
|
],
|
|||
|
|
"carpet": [
|
|||
|
|
{
|
|||
|
|
"aaxis": {
|
|||
|
|
"endlinecolor": "#2a3f5f",
|
|||
|
|
"gridcolor": "white",
|
|||
|
|
"linecolor": "white",
|
|||
|
|
"minorgridcolor": "white",
|
|||
|
|
"startlinecolor": "#2a3f5f"
|
|||
|
|
},
|
|||
|
|
"baxis": {
|
|||
|
|
"endlinecolor": "#2a3f5f",
|
|||
|
|
"gridcolor": "white",
|
|||
|
|
"linecolor": "white",
|
|||
|
|
"minorgridcolor": "white",
|
|||
|
|
"startlinecolor": "#2a3f5f"
|
|||
|
|
},
|
|||
|
|
"type": "carpet"
|
|||
|
|
}
|
|||
|
|
],
|
|||
|
|
"choropleth": [
|
|||
|
|
{
|
|||
|
|
"colorbar": {
|
|||
|
|
"outlinewidth": 0,
|
|||
|
|
"ticks": ""
|
|||
|
|
},
|
|||
|
|
"type": "choropleth"
|
|||
|
|
}
|
|||
|
|
],
|
|||
|
|
"contour": [
|
|||
|
|
{
|
|||
|
|
"colorbar": {
|
|||
|
|
"outlinewidth": 0,
|
|||
|
|
"ticks": ""
|
|||
|
|
},
|
|||
|
|
"colorscale": [
|
|||
|
|
[
|
|||
|
|
0,
|
|||
|
|
"#0d0887"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
0.1111111111111111,
|
|||
|
|
"#46039f"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
0.2222222222222222,
|
|||
|
|
"#7201a8"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
0.3333333333333333,
|
|||
|
|
"#9c179e"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
0.4444444444444444,
|
|||
|
|
"#bd3786"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
0.5555555555555556,
|
|||
|
|
"#d8576b"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
0.6666666666666666,
|
|||
|
|
"#ed7953"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
0.7777777777777778,
|
|||
|
|
"#fb9f3a"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
0.8888888888888888,
|
|||
|
|
"#fdca26"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
1,
|
|||
|
|
"#f0f921"
|
|||
|
|
]
|
|||
|
|
],
|
|||
|
|
"type": "contour"
|
|||
|
|
}
|
|||
|
|
],
|
|||
|
|
"contourcarpet": [
|
|||
|
|
{
|
|||
|
|
"colorbar": {
|
|||
|
|
"outlinewidth": 0,
|
|||
|
|
"ticks": ""
|
|||
|
|
},
|
|||
|
|
"type": "contourcarpet"
|
|||
|
|
}
|
|||
|
|
],
|
|||
|
|
"heatmap": [
|
|||
|
|
{
|
|||
|
|
"colorbar": {
|
|||
|
|
"outlinewidth": 0,
|
|||
|
|
"ticks": ""
|
|||
|
|
},
|
|||
|
|
"colorscale": [
|
|||
|
|
[
|
|||
|
|
0,
|
|||
|
|
"#0d0887"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
0.1111111111111111,
|
|||
|
|
"#46039f"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
0.2222222222222222,
|
|||
|
|
"#7201a8"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
0.3333333333333333,
|
|||
|
|
"#9c179e"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
0.4444444444444444,
|
|||
|
|
"#bd3786"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
0.5555555555555556,
|
|||
|
|
"#d8576b"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
0.6666666666666666,
|
|||
|
|
"#ed7953"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
0.7777777777777778,
|
|||
|
|
"#fb9f3a"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
0.8888888888888888,
|
|||
|
|
"#fdca26"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
1,
|
|||
|
|
"#f0f921"
|
|||
|
|
]
|
|||
|
|
],
|
|||
|
|
"type": "heatmap"
|
|||
|
|
}
|
|||
|
|
],
|
|||
|
|
"heatmapgl": [
|
|||
|
|
{
|
|||
|
|
"colorbar": {
|
|||
|
|
"outlinewidth": 0,
|
|||
|
|
"ticks": ""
|
|||
|
|
},
|
|||
|
|
"colorscale": [
|
|||
|
|
[
|
|||
|
|
0,
|
|||
|
|
"#0d0887"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
0.1111111111111111,
|
|||
|
|
"#46039f"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
0.2222222222222222,
|
|||
|
|
"#7201a8"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
0.3333333333333333,
|
|||
|
|
"#9c179e"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
0.4444444444444444,
|
|||
|
|
"#bd3786"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
0.5555555555555556,
|
|||
|
|
"#d8576b"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
0.6666666666666666,
|
|||
|
|
"#ed7953"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
0.7777777777777778,
|
|||
|
|
"#fb9f3a"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
0.8888888888888888,
|
|||
|
|
"#fdca26"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
1,
|
|||
|
|
"#f0f921"
|
|||
|
|
]
|
|||
|
|
],
|
|||
|
|
"type": "heatmapgl"
|
|||
|
|
}
|
|||
|
|
],
|
|||
|
|
"histogram": [
|
|||
|
|
{
|
|||
|
|
"marker": {
|
|||
|
|
"pattern": {
|
|||
|
|
"fillmode": "overlay",
|
|||
|
|
"size": 10,
|
|||
|
|
"solidity": 0.2
|
|||
|
|
}
|
|||
|
|
},
|
|||
|
|
"type": "histogram"
|
|||
|
|
}
|
|||
|
|
],
|
|||
|
|
"histogram2d": [
|
|||
|
|
{
|
|||
|
|
"colorbar": {
|
|||
|
|
"outlinewidth": 0,
|
|||
|
|
"ticks": ""
|
|||
|
|
},
|
|||
|
|
"colorscale": [
|
|||
|
|
[
|
|||
|
|
0,
|
|||
|
|
"#0d0887"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
0.1111111111111111,
|
|||
|
|
"#46039f"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
0.2222222222222222,
|
|||
|
|
"#7201a8"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
0.3333333333333333,
|
|||
|
|
"#9c179e"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
0.4444444444444444,
|
|||
|
|
"#bd3786"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
0.5555555555555556,
|
|||
|
|
"#d8576b"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
0.6666666666666666,
|
|||
|
|
"#ed7953"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
0.7777777777777778,
|
|||
|
|
"#fb9f3a"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
0.8888888888888888,
|
|||
|
|
"#fdca26"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
1,
|
|||
|
|
"#f0f921"
|
|||
|
|
]
|
|||
|
|
],
|
|||
|
|
"type": "histogram2d"
|
|||
|
|
}
|
|||
|
|
],
|
|||
|
|
"histogram2dcontour": [
|
|||
|
|
{
|
|||
|
|
"colorbar": {
|
|||
|
|
"outlinewidth": 0,
|
|||
|
|
"ticks": ""
|
|||
|
|
},
|
|||
|
|
"colorscale": [
|
|||
|
|
[
|
|||
|
|
0,
|
|||
|
|
"#0d0887"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
0.1111111111111111,
|
|||
|
|
"#46039f"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
0.2222222222222222,
|
|||
|
|
"#7201a8"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
0.3333333333333333,
|
|||
|
|
"#9c179e"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
0.4444444444444444,
|
|||
|
|
"#bd3786"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
0.5555555555555556,
|
|||
|
|
"#d8576b"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
0.6666666666666666,
|
|||
|
|
"#ed7953"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
0.7777777777777778,
|
|||
|
|
"#fb9f3a"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
0.8888888888888888,
|
|||
|
|
"#fdca26"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
1,
|
|||
|
|
"#f0f921"
|
|||
|
|
]
|
|||
|
|
],
|
|||
|
|
"type": "histogram2dcontour"
|
|||
|
|
}
|
|||
|
|
],
|
|||
|
|
"mesh3d": [
|
|||
|
|
{
|
|||
|
|
"colorbar": {
|
|||
|
|
"outlinewidth": 0,
|
|||
|
|
"ticks": ""
|
|||
|
|
},
|
|||
|
|
"type": "mesh3d"
|
|||
|
|
}
|
|||
|
|
],
|
|||
|
|
"parcoords": [
|
|||
|
|
{
|
|||
|
|
"line": {
|
|||
|
|
"colorbar": {
|
|||
|
|
"outlinewidth": 0,
|
|||
|
|
"ticks": ""
|
|||
|
|
}
|
|||
|
|
},
|
|||
|
|
"type": "parcoords"
|
|||
|
|
}
|
|||
|
|
],
|
|||
|
|
"pie": [
|
|||
|
|
{
|
|||
|
|
"automargin": true,
|
|||
|
|
"type": "pie"
|
|||
|
|
}
|
|||
|
|
],
|
|||
|
|
"scatter": [
|
|||
|
|
{
|
|||
|
|
"fillpattern": {
|
|||
|
|
"fillmode": "overlay",
|
|||
|
|
"size": 10,
|
|||
|
|
"solidity": 0.2
|
|||
|
|
},
|
|||
|
|
"type": "scatter"
|
|||
|
|
}
|
|||
|
|
],
|
|||
|
|
"scatter3d": [
|
|||
|
|
{
|
|||
|
|
"line": {
|
|||
|
|
"colorbar": {
|
|||
|
|
"outlinewidth": 0,
|
|||
|
|
"ticks": ""
|
|||
|
|
}
|
|||
|
|
},
|
|||
|
|
"marker": {
|
|||
|
|
"colorbar": {
|
|||
|
|
"outlinewidth": 0,
|
|||
|
|
"ticks": ""
|
|||
|
|
}
|
|||
|
|
},
|
|||
|
|
"type": "scatter3d"
|
|||
|
|
}
|
|||
|
|
],
|
|||
|
|
"scattercarpet": [
|
|||
|
|
{
|
|||
|
|
"marker": {
|
|||
|
|
"colorbar": {
|
|||
|
|
"outlinewidth": 0,
|
|||
|
|
"ticks": ""
|
|||
|
|
}
|
|||
|
|
},
|
|||
|
|
"type": "scattercarpet"
|
|||
|
|
}
|
|||
|
|
],
|
|||
|
|
"scattergeo": [
|
|||
|
|
{
|
|||
|
|
"marker": {
|
|||
|
|
"colorbar": {
|
|||
|
|
"outlinewidth": 0,
|
|||
|
|
"ticks": ""
|
|||
|
|
}
|
|||
|
|
},
|
|||
|
|
"type": "scattergeo"
|
|||
|
|
}
|
|||
|
|
],
|
|||
|
|
"scattergl": [
|
|||
|
|
{
|
|||
|
|
"marker": {
|
|||
|
|
"colorbar": {
|
|||
|
|
"outlinewidth": 0,
|
|||
|
|
"ticks": ""
|
|||
|
|
}
|
|||
|
|
},
|
|||
|
|
"type": "scattergl"
|
|||
|
|
}
|
|||
|
|
],
|
|||
|
|
"scattermapbox": [
|
|||
|
|
{
|
|||
|
|
"marker": {
|
|||
|
|
"colorbar": {
|
|||
|
|
"outlinewidth": 0,
|
|||
|
|
"ticks": ""
|
|||
|
|
}
|
|||
|
|
},
|
|||
|
|
"type": "scattermapbox"
|
|||
|
|
}
|
|||
|
|
],
|
|||
|
|
"scatterpolar": [
|
|||
|
|
{
|
|||
|
|
"marker": {
|
|||
|
|
"colorbar": {
|
|||
|
|
"outlinewidth": 0,
|
|||
|
|
"ticks": ""
|
|||
|
|
}
|
|||
|
|
},
|
|||
|
|
"type": "scatterpolar"
|
|||
|
|
}
|
|||
|
|
],
|
|||
|
|
"scatterpolargl": [
|
|||
|
|
{
|
|||
|
|
"marker": {
|
|||
|
|
"colorbar": {
|
|||
|
|
"outlinewidth": 0,
|
|||
|
|
"ticks": ""
|
|||
|
|
}
|
|||
|
|
},
|
|||
|
|
"type": "scatterpolargl"
|
|||
|
|
}
|
|||
|
|
],
|
|||
|
|
"scatterternary": [
|
|||
|
|
{
|
|||
|
|
"marker": {
|
|||
|
|
"colorbar": {
|
|||
|
|
"outlinewidth": 0,
|
|||
|
|
"ticks": ""
|
|||
|
|
}
|
|||
|
|
},
|
|||
|
|
"type": "scatterternary"
|
|||
|
|
}
|
|||
|
|
],
|
|||
|
|
"surface": [
|
|||
|
|
{
|
|||
|
|
"colorbar": {
|
|||
|
|
"outlinewidth": 0,
|
|||
|
|
"ticks": ""
|
|||
|
|
},
|
|||
|
|
"colorscale": [
|
|||
|
|
[
|
|||
|
|
0,
|
|||
|
|
"#0d0887"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
0.1111111111111111,
|
|||
|
|
"#46039f"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
0.2222222222222222,
|
|||
|
|
"#7201a8"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
0.3333333333333333,
|
|||
|
|
"#9c179e"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
0.4444444444444444,
|
|||
|
|
"#bd3786"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
0.5555555555555556,
|
|||
|
|
"#d8576b"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
0.6666666666666666,
|
|||
|
|
"#ed7953"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
0.7777777777777778,
|
|||
|
|
"#fb9f3a"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
0.8888888888888888,
|
|||
|
|
"#fdca26"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
1,
|
|||
|
|
"#f0f921"
|
|||
|
|
]
|
|||
|
|
],
|
|||
|
|
"type": "surface"
|
|||
|
|
}
|
|||
|
|
],
|
|||
|
|
"table": [
|
|||
|
|
{
|
|||
|
|
"cells": {
|
|||
|
|
"fill": {
|
|||
|
|
"color": "#EBF0F8"
|
|||
|
|
},
|
|||
|
|
"line": {
|
|||
|
|
"color": "white"
|
|||
|
|
}
|
|||
|
|
},
|
|||
|
|
"header": {
|
|||
|
|
"fill": {
|
|||
|
|
"color": "#C8D4E3"
|
|||
|
|
},
|
|||
|
|
"line": {
|
|||
|
|
"color": "white"
|
|||
|
|
}
|
|||
|
|
},
|
|||
|
|
"type": "table"
|
|||
|
|
}
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
"layout": {
|
|||
|
|
"annotationdefaults": {
|
|||
|
|
"arrowcolor": "#2a3f5f",
|
|||
|
|
"arrowhead": 0,
|
|||
|
|
"arrowwidth": 1
|
|||
|
|
},
|
|||
|
|
"autotypenumbers": "strict",
|
|||
|
|
"coloraxis": {
|
|||
|
|
"colorbar": {
|
|||
|
|
"outlinewidth": 0,
|
|||
|
|
"ticks": ""
|
|||
|
|
}
|
|||
|
|
},
|
|||
|
|
"colorscale": {
|
|||
|
|
"diverging": [
|
|||
|
|
[
|
|||
|
|
0,
|
|||
|
|
"#8e0152"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
0.1,
|
|||
|
|
"#c51b7d"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
0.2,
|
|||
|
|
"#de77ae"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
0.3,
|
|||
|
|
"#f1b6da"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
0.4,
|
|||
|
|
"#fde0ef"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
0.5,
|
|||
|
|
"#f7f7f7"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
0.6,
|
|||
|
|
"#e6f5d0"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
0.7,
|
|||
|
|
"#b8e186"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
0.8,
|
|||
|
|
"#7fbc41"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
0.9,
|
|||
|
|
"#4d9221"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
1,
|
|||
|
|
"#276419"
|
|||
|
|
]
|
|||
|
|
],
|
|||
|
|
"sequential": [
|
|||
|
|
[
|
|||
|
|
0,
|
|||
|
|
"#0d0887"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
0.1111111111111111,
|
|||
|
|
"#46039f"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
0.2222222222222222,
|
|||
|
|
"#7201a8"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
0.3333333333333333,
|
|||
|
|
"#9c179e"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
0.4444444444444444,
|
|||
|
|
"#bd3786"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
0.5555555555555556,
|
|||
|
|
"#d8576b"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
0.6666666666666666,
|
|||
|
|
"#ed7953"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
0.7777777777777778,
|
|||
|
|
"#fb9f3a"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
0.8888888888888888,
|
|||
|
|
"#fdca26"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
1,
|
|||
|
|
"#f0f921"
|
|||
|
|
]
|
|||
|
|
],
|
|||
|
|
"sequentialminus": [
|
|||
|
|
[
|
|||
|
|
0,
|
|||
|
|
"#0d0887"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
0.1111111111111111,
|
|||
|
|
"#46039f"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
0.2222222222222222,
|
|||
|
|
"#7201a8"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
0.3333333333333333,
|
|||
|
|
"#9c179e"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
0.4444444444444444,
|
|||
|
|
"#bd3786"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
0.5555555555555556,
|
|||
|
|
"#d8576b"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
0.6666666666666666,
|
|||
|
|
"#ed7953"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
0.7777777777777778,
|
|||
|
|
"#fb9f3a"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
0.8888888888888888,
|
|||
|
|
"#fdca26"
|
|||
|
|
],
|
|||
|
|
[
|
|||
|
|
1,
|
|||
|
|
"#f0f921"
|
|||
|
|
]
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
"colorway": [
|
|||
|
|
"#636efa",
|
|||
|
|
"#EF553B",
|
|||
|
|
"#00cc96",
|
|||
|
|
"#ab63fa",
|
|||
|
|
"#FFA15A",
|
|||
|
|
"#19d3f3",
|
|||
|
|
"#FF6692",
|
|||
|
|
"#B6E880",
|
|||
|
|
"#FF97FF",
|
|||
|
|
"#FECB52"
|
|||
|
|
],
|
|||
|
|
"font": {
|
|||
|
|
"color": "#2a3f5f"
|
|||
|
|
},
|
|||
|
|
"geo": {
|
|||
|
|
"bgcolor": "white",
|
|||
|
|
"lakecolor": "white",
|
|||
|
|
"landcolor": "#E5ECF6",
|
|||
|
|
"showlakes": true,
|
|||
|
|
"showland": true,
|
|||
|
|
"subunitcolor": "white"
|
|||
|
|
},
|
|||
|
|
"hoverlabel": {
|
|||
|
|
"align": "left"
|
|||
|
|
},
|
|||
|
|
"hovermode": "closest",
|
|||
|
|
"mapbox": {
|
|||
|
|
"style": "light"
|
|||
|
|
},
|
|||
|
|
"paper_bgcolor": "white",
|
|||
|
|
"plot_bgcolor": "#E5ECF6",
|
|||
|
|
"polar": {
|
|||
|
|
"angularaxis": {
|
|||
|
|
"gridcolor": "white",
|
|||
|
|
"linecolor": "white",
|
|||
|
|
"ticks": ""
|
|||
|
|
},
|
|||
|
|
"bgcolor": "#E5ECF6",
|
|||
|
|
"radialaxis": {
|
|||
|
|
"gridcolor": "white",
|
|||
|
|
"linecolor": "white",
|
|||
|
|
"ticks": ""
|
|||
|
|
}
|
|||
|
|
},
|
|||
|
|
"scene": {
|
|||
|
|
"xaxis": {
|
|||
|
|
"backgroundcolor": "#E5ECF6",
|
|||
|
|
"gridcolor": "white",
|
|||
|
|
"gridwidth": 2,
|
|||
|
|
"linecolor": "white",
|
|||
|
|
"showbackground": true,
|
|||
|
|
"ticks": "",
|
|||
|
|
"zerolinecolor": "white"
|
|||
|
|
},
|
|||
|
|
"yaxis": {
|
|||
|
|
"backgroundcolor": "#E5ECF6",
|
|||
|
|
"gridcolor": "white",
|
|||
|
|
"gridwidth": 2,
|
|||
|
|
"linecolor": "white",
|
|||
|
|
"showbackground": true,
|
|||
|
|
"ticks": "",
|
|||
|
|
"zerolinecolor": "white"
|
|||
|
|
},
|
|||
|
|
"zaxis": {
|
|||
|
|
"backgroundcolor": "#E5ECF6",
|
|||
|
|
"gridcolor": "white",
|
|||
|
|
"gridwidth": 2,
|
|||
|
|
"linecolor": "white",
|
|||
|
|
"showbackground": true,
|
|||
|
|
"ticks": "",
|
|||
|
|
"zerolinecolor": "white"
|
|||
|
|
}
|
|||
|
|
},
|
|||
|
|
"shapedefaults": {
|
|||
|
|
"line": {
|
|||
|
|
"color": "#2a3f5f"
|
|||
|
|
}
|
|||
|
|
},
|
|||
|
|
"ternary": {
|
|||
|
|
"aaxis": {
|
|||
|
|
"gridcolor": "white",
|
|||
|
|
"linecolor": "white",
|
|||
|
|
"ticks": ""
|
|||
|
|
},
|
|||
|
|
"baxis": {
|
|||
|
|
"gridcolor": "white",
|
|||
|
|
"linecolor": "white",
|
|||
|
|
"ticks": ""
|
|||
|
|
},
|
|||
|
|
"bgcolor": "#E5ECF6",
|
|||
|
|
"caxis": {
|
|||
|
|
"gridcolor": "white",
|
|||
|
|
"linecolor": "white",
|
|||
|
|
"ticks": ""
|
|||
|
|
}
|
|||
|
|
},
|
|||
|
|
"title": {
|
|||
|
|
"x": 0.05
|
|||
|
|
},
|
|||
|
|
"xaxis": {
|
|||
|
|
"automargin": true,
|
|||
|
|
"gridcolor": "white",
|
|||
|
|
"linecolor": "white",
|
|||
|
|
"ticks": "",
|
|||
|
|
"title": {
|
|||
|
|
"standoff": 15
|
|||
|
|
},
|
|||
|
|
"zerolinecolor": "white",
|
|||
|
|
"zerolinewidth": 2
|
|||
|
|
},
|
|||
|
|
"yaxis": {
|
|||
|
|
"automargin": true,
|
|||
|
|
"gridcolor": "white",
|
|||
|
|
"linecolor": "white",
|
|||
|
|
"ticks": "",
|
|||
|
|
"title": {
|
|||
|
|
"standoff": 15
|
|||
|
|
},
|
|||
|
|
"zerolinecolor": "white",
|
|||
|
|
"zerolinewidth": 2
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
},
|
|||
|
|
"title": {
|
|||
|
|
"text": "3D Scatter Plot"
|
|||
|
|
},
|
|||
|
|
"width": 800
|
|||
|
|
}
|
|||
|
|
},
|
|||
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAABVcAAAMgCAYAAAAnUFLKAAAAAXNSR0IArs4c6QAAIABJREFUeF7s3QmULNldH+ibW1XW9lb19gQYg6QW2BghifEGGOHjscFI4sxYyCA8gwfJtBlsBplVeAxoxsIDCMxij4Qajw+WsAHPMYswNppB2COMMdrGhkGNxI5eL6/fVltm5Trn5uts1atXVRkZuURExRfn9HmtrhsR9343snTeL2/8b2U4HA6DgwABAgQIECBAgAABAgQIECBAgAABAgSmEqgIV6fy0pgAAQIECBAgQIAAAQIECBAgQIAAAQIjAeGqB4EAAQIECBAgQIAAAQIECBAgQIAAAQIpBISrKdCcQoAAAQIECBAgQIAAAQIECBAgQIAAAeGqZ4AAAQIECBAgQIAAAQIECBAgQIAAAQIpBISrKdCcQoAAAQIECBAgQIAAAQIECBAgQIAAAeGqZ4AAAQIECBAgQIAAAQIECBAgQIAAAQIpBISrKdCcQoAAAQIECBAgQIAAAQIECBAgQIAAAeGqZ4AAAQIECBAgQIAAAQIECBAgQIAAAQIpBISrKdCcQoAAAQIECBAgQIAAAQIECBAgQIAAAeGqZ4AAAQIECBAgQIAAAQIECBAgQIAAAQIpBISrKdCcQoAAAQIECBAgQIAAAQIECBAgQIAAAeGqZ4AAAQIECBAgQIAAAQIECBAgQIAAAQIpBISrKdCcQoAAAQIECBAgQIAAAQIECBAgQIAAAeGqZ4AAAQIECBAgQIAAAQIECBAgQIAAAQIpBISrKdCcQoAAAQIECBAgQIAAAQIECBAgQIAAAeGqZ4AAAQIECBAgQIAAAQIECBAgQIAAAQIpBISrKdCcQoAAAQIECBAgQIAAAQIECBAgQIAAAeGqZ4AAAQIECBAgQIAAAQIECBAgQIAAAQIpBISrKdCcQoAAAQIECBAgQIAAAQIECBAgQIAAAeGqZ4AAAQIECBAgQIAAAQIECBAgQIAAAQIpBISrKdCcQoAAAQIECBAgQIAAAQIECBAgQIAAAeGqZ4AAAQIECBAgQIAAAQIECBAgQIAAAQIpBISrKdCcQoAAAQIECBAgQIAAAQIECBAgQIAAAeGqZ4AAAQIECBAgQIAAAQIECBAgQIAAAQIpBISrKdCcQoAAAQIECBAgQIAAAQIECBAgQIAAAeGqZ4AAAQIECBAgQIAAAQIECBAgQIAAAQIpBISrKdCcQoAAAQIECBAgQIAAAQIECBAgQIAAAeGqZ4AAAQIECBAgQIAAAQIECBAgQIAAAQIpBISrKdCcQoAAAQIECBAgQIAAAQIECBAgQIAAAeGqZ4AAAQIECBAgQIAAAQIECBAgQIAAAQIpBISrKdCcQoAAAQIECBAgQIAAAQIECBAgQIAAAeGqZ4AAAQIECBAgQIAAAQIECBAgQIAAAQIpBISrKdCcQoAAAQIECBAgQIAAAQIECBAgQIAAAeGqZ4AAAQIECBAgQIAAAQIECBAgQIAAAQIpBISrKdCcQoAAAQIECBAgQIAAAQIECBAgQIAAAeGqZ4AAAQIECBAgQIAAAQIECBAgQIAAAQIpBISrKdCcki+BH/7n/zo8+o53hkff/A3hjz38ifnqnN4QIECAAAECBAgQIECAAAECBAicWQHh6glTe+36rfDX/tabwh9cferZFuc21+8J8I5rNz7hT37Gp4QffNPXhPW15lQPUAwLv+etP37POfFaL/sznzHVtZI2PimgHI/vxZ/2/PCmb35d0svNpd1JDq/8i3/2rr7MEq7Ocu5cBukiBAgQIECAAAECBAgQIECAAAEChRUQrh4zdeNA8coDl+8KR9/wHW8LP/Vvf2n038Yh50nh436rHb76Dd8XfuUDvxFe/5VfHL7iS75g4kMyPuc3Pvx794S443snvdbEmx1pkNdw9eiK1F9/7HfDa//Od4ZPef4feXZuZglIZzl3WmPtCRAgQIAAAQIECBAgQIAAAQIEzpaAcPWEcPWf/+T/Hf72V/y3d/10HOy99jVf+GxYOmll53j1ZZJQ9N3/4QOjQPakFarx/r/2od8Or37l5839KSxKuBoHftRploB0lnPnPgkuSIAAAQIECBAgQIAAAQIECBAgUCgB4eoU05UmXB2vRr365PXwz37gDeG+yxdOvGOaoO+4sgSHyxEc92r90fIGJ71+//e/6SvCW37kZ+4qjRA7f7TcwdHzj/78cAD9mv/mL4xWnm7v7t9znaMwJ3kcnYeT2k3qVxblF6Z43DQlQIAAAQIECBAgQIAAAQIECBDIuYBwNeEEnVQqYNLK1Xj5cYg3qWbqeEVm0lqt4/ZHV8W+6fvfHl75Fz9rtLlTvHc8DpclOC6MTLNy9aTgOJYweP9/+fCzYfLhADjp2MZux21UlSRcPdqHk/qaJtBO+MhoRoAAAQIECBAgQIAAAQIECBAgcMYFhKsTJvjw6sbjwtEk4epJIehxt06y0jSel+S+Jw1tfO6rXv65z4auacLVOK43vOlt99SHPRp+pu3rpBWp4/k42u6k8grH9UO4esZ/wxkeAQIECBAgQIAAAQIECBAgQGCBAsLVKXDjash3/9IH7goTkwSH04Sr4+6MA8r4+vz4OPw6//jnb3rD657dXOukoRx3rdj28IrXNOFq9HjiqRt3bfoVrzteJfrg/ZfCm775damD4OP6NLZ85V/8s6Nrx+Nou/i/f+JnfvGeMgzjfsVzYjC7vta859wpHgdNCRAgQIAAAQIECBAgQIAAAQIESi4gXJ3iATguSE0SriYtC3BaV8YB6ac8/4+MgsFf+cBvnLr51fhaMQD9qX/7S3cFqfNYuToOKmM/TjrGAWgSo+OukbQm6tFw9WhJgMPXPvozK1en+ABoSoAAAQIECBAgQIAAAQIECBAgcJeAcHWKB+K4lY+TgsNpNrSa1JXDweBTT98abQx12srV4zbgiveYR7gar3PSytWj45hkdNK4kwafVq5OenL8nAABAgQIECBAgAABAgQIECBAYBECwtVjVOOr5/F42Z/5jLt+mmbl6kn1P4+bzB/7qV8If/yFnzTaiOrocTTY3dtvh7/2t94UXvxpz3/29fjxOf/pAx8KG+vN8NT1m8eubj0uXD2pfurRV/wP9ytp+LnscHWamqsnjXsRHzbXJECAAAECBAgQIECAAAECBAgQOFsCwtUTwtWvfsP3hcN1Pcch4298+PcS1Vw9/Nr84dqmpz0+49fg/+RnfMpddUxPuvdxtVwPh4X3P+fCPQHsOOj8g6tP3VUq4KRVrrG/J71mf1K/4jlxLPH4ii/5grnWXD3O77iQ92h93JNWEJ827rP1UTcaAgQIECBAgAABAgQIECBAgACBeQsIV08QHQeXh398NPSMPzscVh691HHtJ03gcfeN5xwOeg9f47jNqmJN1vGq26P9+/gr94fv+bavCq//tn8cXvXyzx2Fn+Pj6L3H1zlaX/XouI6rjRrv889+4A3hvssXMglXxwHv97z1x58d30nzcdK4J82VnxMgQIAAAQIECBAgQIAAAQIECJRbQ
Lha7vk3egIECBAgQIAAAQIECBAgQIAAAQIEUgoIV1PCOY0AAQIECBAgQIAAAQIECBAgQIAAgXILCFfLPf9GT4AAAQIECBAgQIAAAQIECBAgQIBASgHhako4pxEgQIAAAQIECBAgQIAAAQIECBAgUG4B4Wq559/oCRAgQIAAAQIECBAgQIAAAQIECBBIKSBcTQnnNAIECBAgQIAAAQIECBAgQIAAAQIEyi0gXC33/Bs9AQIECBAgQIAAAQIECBAgQIAAAQIpBYSrKeGcRoAAAQIECBAgQIAAAQIECBAgQIBAuQWEq+Wef6MnQIAAAQIECBAgQIAAAQIECBAgQCClgHA1JZzTCBAgQIAAAQIECBAgQIAAAQIECBAot4Bwtdzzb/QECBAgQIAAAQIECBAgQIAAAQIECKQUEK6mhHMaAQIECBAgQIAAAQIECBAgQIAAAQLlFhCulnv+jZ4AAQIECBAgQIAAAQIECBAgQIAAgZQCwtWUcE4jQIAAAQIECBAgQIAAAQIECBAgQKDcAsLVcs+
|
|||
|
|
"text/html": [
|
|||
|
|
"<div> <div id=\"84a7c5d6-05b7-4e99-bfd7-f17c4ab3e895\" class=\"plotly-graph-div\" style=\"height:800px; width:800px;\"></div> <script type=\"text/javascript\"> require([\"plotly\"], function(Plotly) { window.PLOTLYENV=window.PLOTLYENV || {}; if (document.getElementById(\"84a7c5d6-05b7-4e99-bfd7-f17c4ab3e895\")) { Plotly.newPlot( \"84a7c5d6-05b7-4e99-bfd7-f17c4ab3e895\", [{\"marker\":{\"color\":[0.9999500004166652,0.9998000066665778,0.9995500337489877,0.999200106660978,0.9987502603949664,0.9982005399352042,0.9975510002532796,0.9968017063026194,0.9959527330119944,0.9950041652780258,0.9939560979566968,0.9928086358538664,0.991561893714788,0.9902159962126372,0.9887710779360422,0.9872272833756268,0.9855847669095608,0.9838436927881214,0.9820042351172704,0.9800665778412416,0.9780309147241484,0.9758974493306056,0.9736663950053748,0.9713379748520296,0.9689124217106448,0.9663899781345132,0.9637708963658904,0.9610554383107708,0.9582438755126972,0.955336489125606],\"colorscale\":[[0.0,\"#440154\"],[0.1111111111111111,\"#482878\"],[0.2222222222222222,\"#3e4989\"],[0.3333333333333333,\"#31688e\"],[0.4444444444444444,\"#26828e\"],[0.5555555555555556,\"#1f9e89\"],[0.6666666666666666,\"#35b779\"],[0.7777777777777778,\"#6ece58\"],[0.8888888888888888,\"#b5de2b\"],[1.0,\"#fde725\"]],\"opacity\":1,\"size\":6},\"mode\":\"markers\",\"text\":[\"0\",\"1\",\"2\",\"3\",\"4\",\"5\",\"6\",\"7\",\"8\",\"9\",\"10\",\"11\",\"12\",\"13\",\"14\",\"15\",\"16\",\"17\",\"18\",\"19\",\"20\",\"21\",\"22\",\"23\",\"24\",\"25\",\"26\",\"27\",\"28\",\"29\"],\"textposition\":\"top center\",\"x\":[0.5403023058681398,-0.4161468365471424,-0.9899924966004454,-0.6536436208636119,0.2836621854632262,0.960170286650366,0.7539022543433046,-0.1455000338086135,-0.9111302618846768,-0.8390715290764524,0.0044256979880507,0.8438539587324921,0.9074467814501962,0.1367372182078336,-0.7596879128588213,-0.9576594803233848,-0.2751633380515969,0.6603167082440802,0.9887046181866692,0.4080820618133919,-0.5477292602242684,-0.9999608263946372,-0.5328330203333975,0.424179007336997,0.9912028118634736,0.6469193223286404,-0.2921388087338362,-0.9626058663135666,-0.7480575296890003,0.154251449887584],\"y\":[0.0099998333341666,0.019998666693333,0.0299955002024956,0.0399893341866341,0.0499791692706783,0.0599640064794445,0.0699428473375327,0.0799146939691727,0.089878549198011,0.0998334166468281,0.1097783008371748,0.1197122072889193,0.1296341426196948,0.1395431146442365,0.1494381324735992,0.1593182066142459,0.169182349066996,0.1790295734258241,0.1888588949765005,0.1986693307950612,0.2084598998460995,0.2182296230808693,0.2279775235351884,0.2377026264271345,0.2474039592545229,0.2570805518921551,0.2667314366888311,0.2763556485641137,0.2859522251048355,0.2955202066613395],\"z\":[0.9999500004166652,0.9998000066665778,0.9995500337489877,0.999200106660978,0.9987502603949664,0.9982005399352042,0.9975510002532796,0.9968017063026194,0.9959527330119944,0.9950041652780258,0.9939560979566968,0.9928086358538664,0.991561893714788,0.9902159962126372,0.9887710779360422,0.9872272833756268,0.9855847669095608,0.9838436927881214,0.9820042351172704,0.9800665778412416,0.9780309147241484,0.9758974493306056,0.9736663950053748,0.9713379748520296,0.9689124217106448,0.9663899781345132,0.9637708963658904,0.9610554383107708,0.9582438755126972,0.955336489125606],\"type\":\"scatter3d\"}], 
{\"template\":{\"data\":{\"histogram2dcontour\":[{\"type\":\"histogram2dcontour\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"},\"colorscale\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]]}],\"choropleth\":[{\"type\":\"choropleth\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}],\"histogram
|
|||
|
|
" \n",
|
|||
|
|
"var gd = document.getElementById('84a7c5d6-05b7-4e99-bfd7-f17c4ab3e895');\n",
|
|||
|
|
"var x = new MutationObserver(function (mutations, observer) {{\n",
|
|||
|
|
" var display = window.getComputedStyle(gd).display;\n",
|
|||
|
|
" if (!display || display === 'none') {{\n",
|
|||
|
|
" console.log([gd, 'removed!']);\n",
|
|||
|
|
" Plotly.purge(gd);\n",
|
|||
|
|
" observer.disconnect();\n",
|
|||
|
|
" }}\n",
|
|||
|
|
"}});\n",
|
|||
|
|
"\n",
|
|||
|
|
"// Listen for the removal of the full notebook cells\n",
|
|||
|
|
"var notebookContainer = gd.closest('#notebook-container');\n",
|
|||
|
|
"if (notebookContainer) {{\n",
|
|||
|
|
" x.observe(notebookContainer, {childList: true});\n",
|
|||
|
|
"}}\n",
|
|||
|
|
"\n",
|
|||
|
|
"// Listen for the clearing of the current output cell\n",
|
|||
|
|
"var outputEl = gd.closest('.output');\n",
|
|||
|
|
"if (outputEl) {{\n",
|
|||
|
|
" x.observe(outputEl, {childList: true});\n",
|
|||
|
|
"}}\n",
|
|||
|
|
"\n",
|
|||
|
|
" }) }; }); </script> </div>"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
"metadata": {},
|
|||
|
|
"output_type": "display_data"
|
|||
|
|
}
|
|||
|
|
],
"source": [
|
|||
|
|
"import plotly.graph_objs as go\n",
|
|||
|
|
"import pandas as pd\n",
|
|||
|
|
"\n",
|
|||
|
|
"# 数据\n",
|
|||
|
|
"import pandas as pd\n",
|
|||
|
|
"position = pd.read_csv(r\"D:\\pythonwork\\2024DL\\Position_Encoding_for_30_Samples.csv\")\n",
|
|||
|
|
"\n",
|
|||
|
|
"# 绘制3D散点图\n",
|
|||
|
|
"fig = go.Figure(data=[go.Scatter3d(\n",
|
|||
|
|
" x=position['维度1'],\n",
|
|||
|
|
" y=position['维度2'],\n",
|
|||
|
|
" z=position['维度3'],\n",
|
|||
|
|
" mode='markers',\n",
|
|||
|
|
" marker=dict(\n",
|
|||
|
|
" size=6,\n",
|
|||
|
|
" color=position['维度3'], # 设置颜色为维度3\n",
|
|||
|
|
" colorscale='Viridis', # 颜色范围\n",
|
|||
|
|
" opacity=1),\n",
|
|||
|
|
" text=position.index.tolist(), # 添加样本编号作为文本标签\n",
|
|||
|
|
" textposition='top center'\n",
|
|||
|
|
")])\n",
|
|||
|
|
"\n",
|
|||
|
|
"fig.update_layout(\n",
|
|||
|
|
" title='3D Scatter Plot',\n",
|
|||
|
|
" scene=dict(\n",
|
|||
|
|
" xaxis=dict(title='维度1', range=[-1.5, 1.5]),\n",
|
|||
|
|
" yaxis=dict(title='维度2', range=[0, 0.3]),\n",
|
|||
|
|
" zaxis=dict(title='维度3', range=[0.95, 1])\n",
|
|||
|
|
" ),\n",
|
|||
|
|
" width=800, # 调整图像宽度\n",
|
|||
|
|
" height=800 # 调整图像高度\n",
|
|||
|
|
")\n",
|
|||
|
|
"\n",
|
|||
|
|
"fig.show()"
|
|||
|
|
]
},
{
"cell_type": "markdown",
"id": "f0f681df-9eec-40ae-a1b8-da04d337529b",
"metadata": {},
"source": [
"现在你已经彻底了解正余弦编码的运作过程了。在这一小节我们总结了3个正余弦编码的意义:"
]
},
{
"cell_type": "markdown",
"id": "5bc496e0-50f0-476d-bf18-3cf0093a6d38",
"metadata": {},
"source": [
"**<center>正余弦编码的意义①:sin和cos函数值域有限,可以很好地限制位置编码的数字大小。**\n",
|
|||
|
|
"\n",
|
|||
|
|
"<center>===================\n",
|
|||
|
|
"\n",
|
|||
|
|
"**<center>正余弦编码的意义②:通过调节频率,我们可以得到多种多样的sin和cos函数,<br><br>从而可以将位置信息投射到每个维度都各具特色、各不相同的高维空间,以形成对位置信息的更好的表示**\n",
|
|||
|
|
"\n",
|
|||
|
|
"<center>===================\n",
|
|||
|
|
"\n",
|
|||
|
|
"**<center>正余弦编码的意义③:通过独特的计算公式,我们可以让特征编号小的特征被投射到剧烈变化的维度上,<br><br>并且让特征编号大的特征被投射到轻微变化、甚至完全单调的维度上,从而可以让小编号特征去<br><br>捕捉样本之间的局部细节差异,让大编号特征去捕捉样本之间按顺序排列的全局趋势**"
]
},
{
"cell_type": "markdown",
"id": "9265b088-d45b-40a5-a624-764915bd8b2d",
"metadata": {},
"source": [
"但除此之外,正余弦编码还有一些额外的好处——\n",
|
|||
|
|
"\n",
|
|||
|
|
"- 首先最重要的是其**函数的周期性带来泛化性**:在模型训练过程中,我们可能使用的都是序列长度小于20的数据,但是当实际应用中遇到一个序列长度为50的数据,**正弦和余弦函数的周期性**意味着,即使模型在训练时未见过某个位置,它仍然可以生成一个合理的位置编码。它可用泛化到不同长度的序列。\n",
|
|||
|
|
"\n",
|
|||
|
|
"- **不增加额外的训练参数**:当我们在一个已经很大的模型(如 GPT-3 或 BERT)上添加位置信息时,我们不希望增加太多的参数,因为这会增加训练成本和过拟合的风险。正弦和余弦位置编码不增加任何训练参数。\n",
|
|||
|
|
"\n",
|
|||
|
|
"- **即便是相同频率下的正余弦函数,也可以通过周期性带来部分的相对位置信息,可以比绝对位置信息更有效**:正弦和余弦函数的周期性特征为模型提供了一种隐含的相对位置信息,使得模型能够更有效地理解序列中不同位置之间的相对关系。"
]
},
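{
"cell_type": "markdown",
"id": "c3d4e5f6-0005-4abc-9def-0123456789af",
"metadata": {},
"source": [
"As a quick illustration of the first point, here is a small sketch reusing the formula from earlier, with an assumed $d_{\\text{model}} = 4$. The same closed-form expression yields a well-formed, bounded encoding for any position, including positions far beyond anything seen during training, and no learned parameters are involved:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c3d4e5f6-0006-4abc-9def-0123456789b0",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def pe(pos, d_model=4):\n",
"    # closed-form sinusoidal encoding for a single position; no learned parameters\n",
"    i = np.arange(d_model // 2)\n",
"    freq = 1.0 / 10000 ** (2 * i / d_model)\n",
"    enc = np.empty(d_model)\n",
"    enc[0::2] = np.sin(pos * freq)\n",
"    enc[1::2] = np.cos(pos * freq)\n",
"    return enc\n",
"\n",
"print(pe(5))   # a position the model might have seen during training\n",
"print(pe(50))  # an unseen, longer position: still a valid, bounded encoding\n"
]
},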
{
"cell_type": "code",
"execution_count": 63,
"id": "41507a9b-9bd7-4f8e-b27c-ea914cd7c8c0",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA1gAAAGoCAYAAABbkkSYAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAAAsTAAALEwEAmpwYAACLr0lEQVR4nOzdd1xV9/3H8deXjbJkOBAUZ5yAiit772YnxiyNWU2atukeya9NM9q0Tdu0zWgzNYnNTprE7KEZalREwL0H4EI2srnf3x9cKHGicjl3vJ+PBw+55557zgfk3vP9nO/3+/kaay0iIiIiIiJy7IKcDkBERERERMRfKMESERERERHpJEqwREREREREOokSLBERERERkU6iBEtERERERKSTKMESERERERHpJEqwpEsZY641xnzsdBxHyxiz0hhzqree3xgzzxhzc9dFJCLivXTN8ez5dc0ROTAlWNLpjDEnGmMWGGMqjDGlxpj5xpjxANba2dbas4/yuPcaYxqNMdXtvn7eudF/63wzjTEPtN9mrR1prZ3nqXMeTvvzu38fLx7tsYwxpxpjXPv8Pt/ttGD9nDHmNGPMXPff+Ran4xEJVLrmeI6uOd7DtPijMabE/fVHY4w5yL59jDHvGGO2G2OsMSati8MNeCFOByD+xRgTA8wBbgdeBcKAk4D6TjrFK9ba6zrpWALbrbUph9rBGBNirW3qqoB8yF7gWeAl4NcOxyISkHTN8Tm65hy9W4FLgAzAAp8Am4F/HWBfF/Ah8AdgQRfFJ+2oB0s621AAa+1L1tpma22ttfZja20+gDFmujHm69ad3XdWvmuMWW+MKTfGPHawOzIHs+9dNWNMmvu4Ie7H84wx97vvalYZYz42xiS227/17me5MabAHeOtwLXAz9vfZTPGbDHGnOn+PtwY84j7DtF29/fh7udONcYUGmN+YozZbYzZYYy58SDxn2aMWd7u8SfGmCXtHn9ljLmk/fmNMefS0qif4o4vr90h+x/sZ+3g73O6+/V/M8aUAPe6f9aHjTHbjDG7jDH/MsZEtnvNz9w/43ZjzAz3739wu9//zfscv/3fwDD3z1xqjFlrjLmq3XMz3X8T77l/nkXGmEHtnh/Z7rW7jDG/Nsb0NsbUGGMS2u031hhTbIwJPZLfxaFYaxdba18ANnXWMUXkiOmag645gXDNAaYBf7HWFlpri4C/ANMPtKO1dpe19nFgyYGeF89TgiWdbR3QbIyZZYw5zxjTowOvuRAYD6QDVwHneCCua4AbgZ603OH8KYAxpj/wAfBPIAnIBHKttU8Cs4E/WWujrLXfOcAx7wYmuV+TAUwA7mn3fG8gFugL3AQ8dpDfxzfAEGNMovvDOB1INsZEuy8oWcBX7V9grf0Q+D0td1ejrLUZh/tZj9BEWhKHXsCDwEO0NGQygcHun+k3AO4L70+Bs4AhwJkdPYkxpjstd+H+4473auBxY8yIdrtdDfwO6AFscMeDMSYa+JSWu3TJ7rg+s9buBObR8rfU6nrgZWtt4wFiuMbd0DnYV7+O/jwi0uV0zfkfXXMOw8evOSOB9oltnnubeCElWNKprLWVwIm0dF8/BRSblnHAvQ7xsoesteXW2m3AXFo+UA/mqn0+iJI7GNpz1tp11tpaWoaRtJ7jGuBT993PRmttibU2t4PHvBa4z1q721pbTMsH8vXtnm90P99orX0fqAaO2/cg7piWACcD42j50JwPnEDLxXS9tbakgzEd6mc9kOR9fp+tF4jt1tp/uodp1NEyNOFH1tpSa20VLRfaq937XuU+5wpr7V7g3iOI9UJgi7X2OWttk7V2GfAGcGW7fd5y9xY10dIAyWz32p3W2r9Ya+ustVXW2kXu52YB1wEYY4KBqcALBwrAWvsfa23cIb62HcHPIyJdSNccXXMInGtOFFDR7nEFEGXMkfXAStfQHCzpdNba1bi7rY0xw4AXgUdo+cA5kJ3tvq+h5UPkYF7ddzx8Bz9bDnaOVGBjRw5wAMnA1naPt7q3tSrZZxz5oX62L4BTgUL392XAKbTMI/jiCOM6kt/nfuPhjTHTgYJ2m5KAbsDSdr9rAwS7v08Glrbbv/3v5HD6AxONMeXttoXw7QvT0fzfvQ38yxgzgJYGRoW1dvERxCUiPkLXnDa65hyeL19zqoGYdo9jgGprre3k80gnUA+WeJS1dg0wExjlwdPspeXDuFXvI3htATDoIM8d7kNrOy0f1q36ubcdjdaL3cnu77+g5WJ3Cge/2HnyQ7X9sfcAtcDIdnfYYq21rRedHbRceFrtO7zhUP8/BcAX+9y9i7LW3t6BGAuAgQcM3to6Wu6kXkfLHd4D3kmEtjLO1Yf40hBBER+ha06H6Zrje9eclbQMDW2V4d4mXkgJlnQq9+TRnxhjUtyPU2m5i/iNB0+bC5xsjOlnjIkFfnUEr50NnGmMucoYE2KMSTDGZLqf28VBPkzdXgLuMcYkmZZJvb+h5c7p0VhAy12vCcBia+1K3HfagC8P8ppdQJoxxqPvY2uti5ahN38zxvQEMMb0Nca0zlt4FZhujBlhjOkG/HafQ+QClxljupmWScg3tXtuDjDUGHO9MSbU/TXeGDO8A6HNAfoYY+4yLROio40xE9s9/zwtd7Uv4hAXO9tSxjnqEF8HHK5hjAkyxkQAoS0PTYQxJqwDcYtIJ9E1R9ccAuSa4z7+j92/i2TgJ7TcTDgg9/Up3P0w3P1YuogSLOlsVbR8QC8yxuyl5SK3gpYPAo+w1n4CvALk0zJsYM4RvHYbcD4t8ZXS8sHceofoGWCEe5z4fw/w8geAbPd5lwM57m1H8zPsdb9+pbW2wb15IbDVWrv7IC97zf1viTEm52jOewR+Qctk32+MMZW0TPQ9DsBa+wEtw3E+d+/z+T6v/RvQQMvFeRYtDQzcr60CzqZlbP12WoZm/JH/XRQOyv3as4DvuF+3Hjit3fPzaSlVm2OtPZIhJB11Mi13Wd+n5Q5qLeCzC5qK+Chdc47uZ9A1x/euOf8G3qXl/34F8J57GwDu3q+T2u1fS8uwQoA17sfSRYyGbopIZzPGWGCItXaDw3F8DvzHWvu0k3GIiIjn6Joj3kZFLkTELxljxgNjgYudjkVERPybrjnSnoYIiojfMcbMomVIyV3uYR0iIiIeoWuO7EtDBEVERERERDqJerBEREREREQ6SUDNwUpMTLRpaWlOhyEiIsdg6dKle6y1SU7HcbR0LRIR8Q8Hux4FVIKVlpZGdna202GIiMgxMMZ4ogRyl9G1SETEPxzseqQhgiIiIiIiIp1ECZaIiIiIiEgnUYIlIiIiIiLSSQJqDpaIiIg3amxspLCwkLq6OqdDEekSERERpKSkEBoa6nQoIp1OCZaIiIjDCgsLiY6OJi0tDWOM0+GIeJS1lpKSEgoLCxkwYIDT4Yh0Og0RFBERcVhdXR0JCQlKriQgGGNISEhQj634LSVYIiIiXkDJlQQS/b2LP1OCJSIiIiIi0kmUYImIiAgPPvggI0eOJD09nczMTBYtWgTAzTffzKpVqzp8n
JkzZ5KUlERmZiaZmZnccMMNnRrn73//+289Pv744zvt2HfddRdffvklAKeeeirHHXdc28/x+uuvd9p5fEl9fT1Tpkxh8ODBTJw4kS1bthxwvxkzZtCzZ09GjRr1re0//elP+fzzz7sgUhHvoQRLREQkwC1cuJA5c+aQk5NDfn4+n376KampqQA8/fTTjBgx4oiON2XKFHJzc8nNzeX555/v1Fj3TbAWLFjQKcctKSnhm2++4eSTT27bNnv27Laf44orrvjW/s3NzZ1yXm/3zDPP0KNHDzZs2MCPfvQjfvGLXxxwv+nTp/Phhx/ut/373/8+Dz30kKfDFPEqSrBEREQC3I4dO0hMTCQ8PByAxMREkpOTgZaenOzsbACioqK4++67ycjIYNKkSezatatDx583bx4XXnhh2+M777yTmTNnApCWlsZvf/tbxo4dy+jRo1mzZg0A1dXV3HjjjYwePZr09HTeeOMNfvnLX1JbW0tmZibXXnttW0zQUpnuZz/7GaNGjWL06NG88sorbec+9dR
|
|||
|
|
"text/plain": [
|
|||
|
|
"<Figure size 864x432 with 2 Axes>"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
"metadata": {
|
|||
|
|
"needs_background": "light"
|
|||
|
|
},
|
|||
|
|
"output_type": "display_data"
|
|||
|
|
}
|
|||
|
|
],
|
|||
|
|
"source": [
|
|||
|
|
"import numpy as np\n",
|
|||
|
|
"import matplotlib.pyplot as plt\n",
|
|||
|
|
"\n",
|
|||
|
|
"# 定义绘制正弦函数的函数\n",
|
|||
|
|
"def plot_sin_functions():\n",
|
|||
|
|
" y = np.linspace(0, 10, 1000) # 定义 y 轴范围\n",
|
|||
|
|
" \n",
|
|||
|
|
" fig, ays = plt.subplots(1, 2, figsize=(12, 6)) # 创建1y2子图\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 绘制频率为1的正弦函数\n",
|
|||
|
|
" frequency1 = 1\n",
|
|||
|
|
" y1 = np.sin(frequency1 * y)\n",
|
|||
|
|
" ays[0].plot(y, y1, label=f'Sin Function (Frequency = {frequency1})')\n",
|
|||
|
|
" ays[0].set_title(f'Sin Function with Frequency = {frequency1}')\n",
|
|||
|
|
" ays[0].set_ylabel('y')\n",
|
|||
|
|
" ays[0].set_ylabel('sin(y)')\n",
|
|||
|
|
" ays[0].legend()\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 绘制频率为0.1的正弦函数\n",
|
|||
|
|
" frequency2 = 0.1\n",
|
|||
|
|
" y2 = np.sin(frequency2 * y)\n",
|
|||
|
|
" ays[1].plot(y, y2, label=f'Sin Function (Frequency = {frequency2})')\n",
|
|||
|
|
" ays[1].set_title(f'Sin Function with Frequency = {frequency2}')\n",
|
|||
|
|
" ays[1].set_ylabel('y')\n",
|
|||
|
|
" ays[1].set_ylabel('sin(y)')\n",
|
|||
|
|
" ays[1].legend()\n",
|
|||
|
|
"\n",
|
|||
|
|
" plt.tight_layout()\n",
|
|||
|
|
" plt.show()\n",
|
|||
|
|
"\n",
|
|||
|
|
"# 绘制两个正弦函数在横向排列的子图中\n",
|
|||
|
|
"plot_sin_functions()\n"
]
},
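{
"cell_type": "markdown",
"id": "3fa1b2c4-aaaa-4d01-9e21-0c1d2e3f4a5b",
"metadata": {},
"source": [
"这种“不同维度对应不同频率”的直觉,正是Transformer正弦位置编码的设计基础。下面给出一个按原论文公式实现的极简草图(函数名`pos_encoding`是演示用的假设,并假定`d_model`为偶数),可以看到每一列维度各自对应一条频率不同的正弦/余弦波:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5b6c7d8e-bbbb-4e12-8f32-1d2e3f4a5b6c",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def pos_encoding(seq_len, d_model):\n",
"    # PE(pos, 2i)   = sin(pos / 10000^(2i/d_model))\n",
"    # PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))\n",
"    pos = np.arange(seq_len)[:, np.newaxis]           # (seq_len, 1)\n",
"    i = np.arange(d_model // 2)[np.newaxis, :]        # (1, d_model/2)\n",
"    angle = pos / np.power(10000, (2 * i) / d_model)  # 每一列对应一种频率\n",
"    pe = np.zeros((seq_len, d_model))\n",
"    pe[:, 0::2] = np.sin(angle)   # 偶数维度用sin\n",
"    pe[:, 1::2] = np.cos(angle)   # 奇数维度用cos\n",
"    return pe\n",
"\n",
"pe = pos_encoding(seq_len=50, d_model=16)\n",
"print(pe.shape)   # (50, 16)\n",
"print(pe[0, :4])  # 位置0的前4个维度\n"
]
},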
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"id": "a000479d-aba9-47af-81ee-a9afeb828588",
|
|||
|
|
"metadata": {},
|
|||
|
|
"source": [
|
|||
|
|
"## 2.2 Encoder结构解析"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"id": "a91102be-93a2-46ae-8359-721bb7f0cb1d",
|
|||
|
|
"metadata": {},
|
|||
|
|
"source": [
|
|||
|
|
"编码器(Encoder)结构包括两个子层:一个是自注意力(Self-Attention)层,另一个是前馈(Feed-Forward)神经网络。输入会先经过自注意力层,这层的作用是帮助模型关注输入序列中不同位置的信息。然后,经过前馈神经网络层,这是一个简单的全连接神经网络。两个子层都有一个残差连接(Residual Connection)和层标准化(Layer Normalization)。"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"id": "c357feb7-f752-47dc-8698-cd369d694d64",
|
|||
|
|
"metadata": {},
|
|||
|
|
"source": [
|
|||
|
|
"<center><img src=\"https://skojiangdoc.oss-cn-beijing.aliyuncs.com/2023DL/transformer/image-1.png\" alt=\"描述文字\" width=\"400\">"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"id": "49692481-981f-41e5-812c-196400b69d31",
|
|||
|
|
"metadata": {},
|
|||
|
|
"source": [
|
|||
|
|
"### 2.2.1 残差连接\n",
|
|||
|
|
"\n",
|
|||
|
|
"观察Transformer的结构不难发现——在多头注意力机制之后,我们输出的信息(通常用$z_1$表示)并没有直接传入前馈神经网络,而是经过了一个**Add & Normalize**层,这是什么操作,代表了什么含义呢?\n",
|
|||
|
|
"\n",
|
|||
|
|
"首先来看**Add**,这里的Add表示“加和”,是在多头注意力机制输出的信息的基础上加了一个输入数据y自身,这个数据y是从输入层传过来的。这种通过两条链路并行、一条链路进行复杂计算(在这里是多头注意力机制)、一条链路将输入数据y原封不动传到架构后方、并且最终让两条链路上的输出结果进行加和的操作,叫做残差操作。**与复杂链路并行、负责将y进行传输的链路就是残差链接**。\n",
|
|||
|
|
"\n",
|
|||
|
|
"在之前的课程里我们详细地讲解过残差网络,残差网络正是利用残差链接来对抗深度神经网络的“退化问题”。何凯明在2015年提出的残差网络(ResNet)https://aryiv.org/abs/1512.03385中提出了残差链接的构想,这是他当时构想的最基础的残差块的设计↓,如你所见,也是让残差链接与复杂链路并行的结构。\n",
|
|||
|
|
"\n",
|
|||
|
|
""
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"id": "67e6e903-9db0-45fa-8848-aa2cd16e8b07",
|
|||
|
|
"metadata": {},
|
|||
|
|
"source": [
|
|||
|
|
"这种将y原封不动传到架构后方的操作可以解决梯度消失问题、可以让模型学到恒等映射(输入与输出相等)、可以让训练被简化、加速训练收敛、同时非常重要的是,还能避免网络在加深的时候出现退化问题。在之前讲解残差网络的时候,我们深读了残差网络的原始论文,详细解读了为什么残差链接让网络拟合恒等函数却能够获得很好的效果,讲解了这个反直觉的架构设计是如何逼迫网络变得越来越强。在Transformer中,残差链接可以说承担着各种方面的职责,从实验的结果来看,在各层次上加上残差链接可以让Transformer效果更好。\n",
|
|||
|
|
"\n",
|
|||
|
|
"在残差链接的所有效果中,我们可以从数学角度、非严格证明一下它为什么能够解决梯度消失问题。\n",
|
|||
|
|
"\n",
|
|||
|
|
"假设现在存在一个神经网络,**它由多个残差结构相连(就像transformer一样)**。每个残差结构被定义为F(x,W),这个结构是由一个复杂结构 + 一个残差链接并行而成的,其中$x$代表残差输入的数据,$W$代表该结构中的权重。设$x_i, x_{i+1}$分别代表残差结构F()的输入和输出,设$x_I$代表整个神经网络的输入,令relu激活函数为$r(y)=max(0,x)$,简写为$r()$。由此可得:\n",
|
|||
|
|
"\n",
|
|||
|
|
"$$\n",
|
|||
|
|
"\\begin{aligned}\n",
|
|||
|
|
"x_{i+1} & = r(x_i + F(x_i, W_i)) \\\\ \\\\\n",
|
|||
|
|
"x_{i+2} & =r(x_{i+1}+F(x_{i+1},W_{i+1})) \\\\ \\\\ \n",
|
|||
|
|
"…… \\\\ \\\\\n",
|
|||
|
|
"&\\text{如果relu激活函数是被激活状态(残差结构输出的值都大于0),则有} \\\\ \\\\\n",
|
|||
|
|
"x_{i+1} & = x_i + F(x_i, W_i) \\\\ \\\\\n",
|
|||
|
|
"x_{i+2} & =x_{i+1}+F(x_{i+1},W_{i+1}) \\\\ \\\\\n",
|
|||
|
|
"…… \\\\ \\\\\n",
|
|||
|
|
"&\\text{这是一个递归嵌套结构,如此递归推导可以得到} \\\\ \\\\\n",
|
|||
|
|
"x_I & =x_i+\\sum_{n=i}^{I-1}F(x_n,W_n)\n",
|
|||
|
|
"\\end{aligned}\n",
|
|||
|
|
"$$"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"id": "8dfc84e1-70a2-4440-bd79-56e317376fe2",
|
|||
|
|
"metadata": {},
|
|||
|
|
"source": [
|
|||
|
|
"此时如果我们对神经网络的结构求梯度,则会有——\n",
|
|||
|
|
"\n",
|
|||
|
|
"$$\n",
|
|||
|
|
"\\begin{aligned}\n",
|
|||
|
|
"\\frac{\\partial Loss}{\\partial x_{i}}&=\\frac{\\partial Loss}{\\partial x_{I}} * \\frac{\\partial x_{I}}{\\partial x_{i}} \\\\\n",
|
|||
|
|
"&=\\frac{\\partial Loss}{\\partial x_{I}} * \\frac{\\partial(x_i+\\sum_{n=i}^{I-1}F(x_n,W_n))}{\\partial x_{i}} \\\\\n",
|
|||
|
|
"&=\\frac{\\partial Loss}{\\partial x_{I}}*(1+\\frac{\\sum_{n=i}^{I-1}F(x_n,W_n)}{\\partial x_{i}})\n",
|
|||
|
|
"\\end{aligned}\n",
|
|||
|
|
"$$\n",
|
|||
|
|
"\n",
|
|||
|
|
"从结果可以看出,因为有“1+”这一结构的存在,可以有效避免梯度消失(求解后梯度为0)的情况,这样网络深层处的梯度可以直接传递到网络的浅层、让迭代变得更加稳定。与此同时,残差网络在更新梯度时把一些乘法转变为了加法,同时也提高了计算效率。\n",
|
|||
|
|
"\n",
|
|||
|
|
"基于残差结构的这些优势,Transformer在注意力机制的外侧添加了残差链接,从而让encoder和decoder中的梯度传输都变得更加稳定。"
|
|||
|
|
]
},
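{
"cell_type": "markdown",
"id": "7a8b9c0d-cccc-4f23-9a43-2e3f4a5b6c7d",
"metadata": {},
"source": [
"把“复杂链路 + 残差连接 + 层归一化”落到代码上,大致就是下面这个极简的PyTorch草图。这只是按上文描述搭出的示意写法,`SublayerConnection`这个类名是演示用的假设,并非某个库的官方API:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8b9c0d1e-dddd-4a34-8b54-3f4a5b6c7d8e",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"\n",
"class SublayerConnection(nn.Module):\n",
"    \"\"\"残差连接 + 层归一化:输出为 LayerNorm(x + Sublayer(x))\"\"\"\n",
"    def __init__(self, d_model):\n",
"        super().__init__()\n",
"        self.norm = nn.LayerNorm(d_model)\n",
"\n",
"    def forward(self, x, sublayer):\n",
"        # 复杂链路 sublayer(x) 与原封不动传过来的残差链路 x 相加,再做层归一化\n",
"        return self.norm(x + sublayer(x))\n",
"\n",
"d_model = 8\n",
"x = torch.randn(2, 5, d_model)                   # (batch, seq_len, d_model)\n",
"add_norm = SublayerConnection(d_model)\n",
"out = add_norm(x, nn.Linear(d_model, d_model))   # 用线性层临时充当“复杂链路”\n",
"print(out.shape)                                 # torch.Size([2, 5, 8])\n"
]
},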
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"id": "a2e03638-e8f5-4503-bfc2-27a01219cdd4",
|
|||
|
|
"metadata": {},
|
|||
|
|
"source": [
|
|||
|
|
"### 2.2.2 Layer Normalization层归一化\n",
|
|||
|
|
"\n",
|
|||
|
|
"在了解了Add之后,我们来看一下**Normalize**。在Transformer结构中,Layer Normalization(层归一化)是一个至关重要的部分,它是一种特定的归一化技术,它在2016年被提出,用于减少训练深度神经网络时的内部协方差偏移(internal covariate shift)。我们在课程的Lesson13-15部分详细讲解过内部协方差偏移的关键知识,感兴趣的小伙伴可以回到可成lesson13-15去详细了解。\n",
|
|||
|
|
"\n",
|
|||
|
|
"与Batch Normalization(批归一化)不同,Layer Normalization不是对一个批次(batch)中的样本进行归一化,而是独立地对每个样本中的所有特征进行归一化(也就是对单一词向量、单一时间点的所有embedding维度进行归一化)。具体来说,对于每个样本,Layer Normalization会在特定层的所有激活上计算均值和方差,然后用这些统计量来归一化该样本的激活值。Transformer的Normalize使用了2016年Jimmy Lei Ba等人的的论文《Layer Normalization》https://aryiv.org/abs/1607.06450v1。"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"id": "fde3b0c2-f719-4654-b402-28b7a9d9e6e9",
|
|||
|
|
"metadata": {},
|
|||
|
|
"source": [
|
|||
|
|
"- **为什么要进行Normalize呢?**\n",
|
|||
|
|
"\n",
|
|||
|
|
"> - **减少内部协方差偏移**:在深度学习模型训练过程中,参数的更新会影响后续层的激活分布,这可能导致训练过程不稳定。Layer Normalization通过规范化每一层的输出来减轻这种效应,有助于稳定训练过程。<br><br>\n",
|
|||
|
|
"> - **加速训练速度**:归一化可以使得梯度更稳定,这通常允许更高的学习率,从而加快模型的收敛速度。<br><br>\n",
|
|||
|
|
"> - **减少对初始值的依赖**:由于Layer Normalization使得模型对于输入数据的分布变化更为鲁棒,因此可以减少模型对于参数初始值的敏感性。<br><br>\n",
|
|||
|
|
"> - **允许更深层网络的训练**:通过规范化每层的激活,Layer Normalization可以帮助训练更深的网络结构,而不会那么容易出现梯度消失或爆炸的问题。"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"id": "fe738142-3fc9-440a-9709-5ea5bb939b19",
|
|||
|
|
"metadata": {},
|
|||
|
|
"source": [
|
|||
|
|
"- **为什么使用Layer Normalization(LN)而不使用Batch Normalization(BN)呢?**\n",
|
|||
|
|
"\n",
|
|||
|
|
"BN 和 LN 的差别就在$u_i$和 $\\sigma_i$这里,前者在某一个 Batch 内统计某特定神经元节点的输出分布(跨样本),后者在某一次迭代更新中统计同一层内的所有神经元节点的输出分布(同一样本下)。\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
"最初BN 是为 CNN 任务提出的,需要较大的 BatchSize 来保证统计量的可靠性,并在训练阶段记录全局的$u$和 $\\sigma$供预测任务使用。而LN是独立于batch大小的,它只对单个输入样本的所有特征进行规范化。\n",
|
|||
|
|
"\n",
|
|||
|
|
"> * NLP任务中经常会处理长度不同的句子,使用LN时可以不考虑其它样本的长度。<br><br>\n",
|
|||
|
|
"> * 在某些情况下,当可用的内存有限或者为了加速训练而使用更小的batch时,BN因为batch数量不足而受到了限制。<br><br>\n",
|
|||
|
|
"> * 在某些NLP任务和解码设置中,模型可能会一个接一个地处理序列中的元素,而不是一次处理整个batch。这样BN就不是很适用了。<br><br>\n",
|
|||
|
|
"> * 在Transformer模型中有很深的层次和自注意机制。通过对每一层的输入进行规范化,可以防止值的爆炸或消失,从而帮助模型更快地收敛。"
|
|||
|
|
]
|
|||
|
|
},
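{
"cell_type": "markdown",
"id": "9c0d1e2f-eeee-4b45-9c65-4a5b6c7d8e9f",
"metadata": {},
"source": [
"可以用一小段PyTorch代码直观对比两者统计维度上的差异(这是一个演示用的草图,数值仅作示意):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0d1e2f3a-ffff-4c56-8d76-5b6c7d8e9f0a",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"\n",
"torch.manual_seed(0)\n",
"x = torch.randn(4, 10, 16)   # (batch=4, seq_len=10, embedding=16)\n",
"\n",
"# LayerNorm:对每个(样本, 位置)的16维embedding单独归一化,与batch大小无关\n",
"ln = nn.LayerNorm(16)\n",
"out_ln = ln(x)\n",
"\n",
"# BatchNorm1d:对每个通道跨整个batch统计均值/方差,要求输入形状为(batch, channel, seq)\n",
"bn = nn.BatchNorm1d(16)\n",
"out_bn = bn(x.transpose(1, 2)).transpose(1, 2)\n",
"print(out_bn.shape)   # BN需要先转成(batch, channel, seq)再转回来\n",
"\n",
"# LN之后,每个(样本, 位置)在embedding维度上的均值≈0、方差≈1\n",
"print(out_ln.mean(dim=-1)[0, :3])\n",
"print(out_ln.var(dim=-1, unbiased=False)[0, :3])\n"
]
},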
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"id": "4749cc50-5e77-4418-aa13-bc8915e4cad3",
|
|||
|
|
"metadata": {},
|
|||
|
|
"source": [
|
|||
|
|
"- 【加餐】各类Normalization的本质\n",
|
|||
|
|
"\n",
|
|||
|
|
"LN 是 Normalization(规范化)家族中的一员,由 Batch Normalization(BN)发展而来。基本上所有的规范化技术,都可以概括为如下的公式:\n",
|
|||
|
|
"\n",
|
|||
|
|
"$h_i = f(a_i) \\\\\n",
|
|||
|
|
"{h_i}^{'}=f(\\frac{g_i}{\\sigma_i}(a_i-u_i)+b_i)$\n",
|
|||
|
|
"\n",
|
|||
|
|
"这个公式描述了Normalization技术中对于单个数据点$a_i$在某一层的激活值进行规范化的过程。这里是每个符号的含义:\n",
|
|||
|
|
"\n",
|
|||
|
|
"$$\n",
|
|||
|
|
"\\begin{aligned}\n",
|
|||
|
|
"&a_i: \\text{原始神经网络层的激活值或输出。} \\\\ \\\\\n",
|
|||
|
|
"&f: \\text{应用于规范化之后的值的激活函数。} \\\\ \\\\\n",
|
|||
|
|
"&h_i: \\text{应用激活函数} f \\text{之后的激活值,是规范化步骤之前的输出。} \\\\ \\\\\n",
|
|||
|
|
"&h'_i: \\text{最终的规范化输出值。} \\\\ \\\\\n",
|
|||
|
|
"&\\sigma_i: \\text{用于规范化过程中的尺度调整的标准差。} \\\\ \\\\\n",
|
|||
|
|
"&u_i: \\text{平均值。} \\\\ \\\\\n",
|
|||
|
|
"&g_i: \\text{尺度参数。} \\\\ \\\\\n",
|
|||
|
|
"&b_i: \\text{偏置参数。}\n",
|
|||
|
|
"\\end{aligned}\n",
|
|||
|
|
"$$\n",
|
|||
|
|
"\n",
|
|||
|
|
"对于隐层中某个节点的输出为对激活值$a_i$ 进行非线性变换$f()$ 后的 $h_i$\n",
|
|||
|
|
"先使用均值$u_i$和方差 $\\sigma_i$对$a_i$ 进行**分布调整**。\n",
|
|||
|
|
"如果以正态分布为例,就是把“高瘦”(红色)和“矮胖”(蓝紫色)的都调整回正常体型(绿色),把偏离y=0的(紫色)拉回中间来。\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
"* 这样可以将每一次迭代的数据调整为相同分布,消除极端值,提升训练稳定性。\n",
|
|||
|
|
"* 同时“平移”操作,可以让激活值落入$f()$的梯度敏感区间即梯度更新幅度变大,模型训练加快。\n",
|
|||
|
|
"\n",
|
|||
|
|
"然而,在梯度敏感区内,隐层的输出接近于“线性”,模型表达能力会大幅度下降。引入 gain 因子$g_i$ 和 bias 因子 $b_i$,为规范化后的分布再加入一点“个性”。\n",
|
|||
|
|
"\n",
|
|||
|
|
"注: $g_i$和$b_i$作为**模型参数训练得到**,$u_i$和 $\\sigma_i$在**限定的数据范围内统计得到**。"
|
|||
|
|
]
|
|||
|
|
},
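{
"cell_type": "markdown",
"id": "1e2f3a4b-abab-4d67-9e87-6c7d8e9f0a1b",
"metadata": {},
"source": [
"把上面的通式直接翻译成代码,可以更清楚地看到$u_i$、$\\sigma_i$、$g_i$、$b_i$各自的角色。下面是一个以LN为例的手写草图(仅作演示,激活函数$f$在此省略):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2f3a4b5c-baba-4e78-8f98-7d8e9f0a1b2c",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"def norm_manual(a, g, b, eps=1e-5):\n",
"    # u_i, sigma_i:在限定的数据范围内统计得到(LN:单个样本的特征维度)\n",
"    u = a.mean(dim=-1, keepdim=True)\n",
"    sigma = a.std(dim=-1, keepdim=True, unbiased=False)\n",
"    # g_i, b_i:作为模型参数训练得到(这里直接给定,仅作演示)\n",
"    return g / (sigma + eps) * (a - u) + b\n",
"\n",
"a = torch.randn(2, 6)                  # 两个样本,每个样本6个特征\n",
"g = torch.ones(6)                      # gain 因子\n",
"b = torch.zeros(6)                     # bias 因子\n",
"h = norm_manual(a, g, b)\n",
"print(h.mean(dim=-1))                  # 每个样本均值≈0\n",
"print(h.var(dim=-1, unbiased=False))   # 每个样本方差≈1\n"
]
},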
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"id": "19a62d44-27fb-4e9e-a1a5-fd1fb49edf4e",
|
|||
|
|
"metadata": {},
|
|||
|
|
"source": [
|
|||
|
|
"### 2.2.3 Feed-Forward Networks前馈网络\n",
|
|||
|
|
"\n",
|
|||
|
|
"<center><img src=\"https://skojiangdoc.oss-cn-beijing.aliyuncs.com/2023DL/transformer/image-1.png\" alt=\"描述文字\" width=\"400\">"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"id": "98cb8923-0d8a-4cd6-9bb2-c9f8375cd2b4",
|
|||
|
|
"metadata": {},
|
|||
|
|
"source": [
|
|||
|
|
"根据Transformer的结构图可以看出,每一个多头注意力机制层都链接了一个前馈网络层。前馈网络(Feed-Forward Networks,FFNs),在神经网络的语境中,是指那些信息单向流动的网络结构。在这样的网络中,信息从输入层流向输出层,中间可能会经过多个隐藏层,但不会有任何反向的信息流,即不存在循环或者回路。因此在Transformer当中,实际上前馈神经网络就是**由线性层组成的深度神经网络结构**。它的主要职责是对输入数据进行**非线性变换,同时也负责产生输出值**。它的作用暗示了一个关键的事实——**自注意力机制大多数时候是一个线性结构**:加权求和是一个线性操作,即便我们是经过丰富的权重变化、由丰富的Q、K、V等矩阵点积的结果,还有softmax函数,但是自注意力机制依然是一个线性的过程。因此,在加入前馈神经网络之前,transformer本身不带有传统意义上的非线性结构。\n",
|
|||
|
|
"\n",
|
|||
|
|
"在现代深度学习架构中,特别是在Transformer模型中,前馈网络通常指的是一个特定的子层,它由两个线性变换组成,中间夹有一个激活函数,如ReLU或者GELU。具体结构可以表示为:\n",
|
|||
|
|
"\n",
|
|||
|
|
"1. **第一层线性变换**:<br><br>\n",
|
|||
|
|
" $ z_1 = xW_1 + b_1 $\n",
|
|||
|
|
" - $x$ 是输入向量。<br>\n",
|
|||
|
|
" - $W_1$ 和 $b_1$ 是第一层的权重矩阵和偏置向量。<br><br>\n",
|
|||
|
|
" \n",
|
|||
|
|
"3. **ReLU激活函数**:<br><br>\n",
|
|||
|
|
" $ a_1 = \\text{ReLU}(z_1) = \\max(0, z_1) $\n",
|
|||
|
|
" - ReLU的作用是引入非线性,使得网络能够学习更复杂的函数映射。<br>\n",
|
|||
|
|
" - ReLU函数将输入中的负值置为零,正值保持不变。<br><br>\n",
|
|||
|
|
" \n",
|
|||
|
|
"5. **第二层线性变换**:<br><br>\n",
|
|||
|
|
" $ z_2 = a_1W_2 + b_2 $\n",
|
|||
|
|
" - $a_1$ 是经过ReLU激活后的中间表示。<br>\n",
|
|||
|
|
" - $W_2$ 和 $b_2$ 是第二层的权重矩阵和偏置向量。<br>\n",
|
|||
|
|
" - 最终输出 $z_2$ 是前馈神经网络的输出。<br><br>\n",
|
|||
|
|
"\n",
|
|||
|
|
"合起来,前馈神经网络的完整表达式为:<br><br>\n",
|
|||
|
|
"$$ FFN(x) = \\max(0, xW_1 + b_1)W_2 + b_2 $$\n",
|
|||
|
|
"\n",
|
|||
|
|
"Transformer模型中的前馈网络在自注意力层之后对每个位置的表示独立地应用相同的变换,这样可以进一步提高网络的表示能力。由于在前馈网络中对每个位置进行的是相同的操作,所以它们非常适合于并行计算。这种层通常被设计为宽度很大,以便在模型中捕获大量的特征并提供足够的模型容量。"
|
|||
|
|
]
},
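{
"cell_type": "markdown",
"id": "3a4b5c6d-caca-4f89-9a09-8e9f0a1b2c3d",
"metadata": {},
"source": [
"按照上面的公式,FFN可以直接写成“两个线性层夹一个ReLU”。下面是一个极简的PyTorch草图(原论文中取$d_{model}=512$、$d_{ff}=2048$,这里沿用该设定;类名`FeedForward`为演示用的假设):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4b5c6d7e-dada-4a90-8b10-9f0a1b2c3d4e",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"\n",
"class FeedForward(nn.Module):\n",
"    \"\"\"FFN(x) = max(0, x W1 + b1) W2 + b2\"\"\"\n",
"    def __init__(self, d_model=512, d_ff=2048):\n",
"        super().__init__()\n",
"        self.w1 = nn.Linear(d_model, d_ff)   # 第一层线性变换\n",
"        self.w2 = nn.Linear(d_ff, d_model)   # 第二层线性变换\n",
"\n",
"    def forward(self, x):\n",
"        return self.w2(torch.relu(self.w1(x)))\n",
"\n",
"x = torch.randn(2, 5, 512)   # (batch, seq_len, d_model)\n",
"ffn = FeedForward()\n",
"print(ffn(x).shape)          # torch.Size([2, 5, 512]):对每个位置独立应用同一变换\n"
]
},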
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"id": "ffa475c5-a4dd-4e74-9bd0-7bfbcd686170",
|
|||
|
|
"metadata": {},
|
|||
|
|
"source": [
|
|||
|
|
"以线性层作为输出层是许多深度学习架构的经典操作。Encoder在最后使用的前馈神经网络可以说是以线性层结尾,它本身有激活函数,可以产生输出,**因此Encoder编码器部分是可以单独使用的结构**。在许多情况下,我们可以单独使用 Encoder 的输出来执行各种任务,而不需要 Decoder 解码器部分,下面是一些经典的场景——\n",
|
|||
|
|
"> **encoder走位特征提取器**:Encoder 的输出可以用作特征提取器,将输入序列转换为一系列有意义的特征表示。这些特征表示可以用于各种机器学习任务,如分类、聚类、序列标注等。<br><br>\n",
|
|||
|
|
"> **encoder生成类似autoencoder的语义表示**:Encoder 的输出可以被用来获取输入序列的语义表示。这些语义表示可以用于进行语义相似度计算、文本匹配、信息检索等自然语言处理任务。<br><br>\n",
|
|||
|
|
"> **序列到序列任务的编码器**:在一些序列到序列任务中,只需要对输入序列进行编码,而不需要生成输出序列。例如,文本摘要、问答系统中,只需将输入文本编码为一个语义表示,而无需生成摘要或答案。<br><br>\n",
|
|||
|
|
"> **预训练模型的基础部分**:许多预训练模型,如BERT(Bidirectional Encoder Representations from Transformers)等,基于 Transformer Encoder 架构。在这些模型中,Encoder 的输出可以被用作下游任务的输入,从而提供丰富的语义信息。\n",
|
|||
|
|
"\n",
|
|||
|
|
"具体地来说,有许多任务可以仅使用 Encoder 完成。以下是一些常见的例子:\n",
|
|||
|
|
"\n",
|
|||
|
|
"> **情感分析**:情感分析任务旨在确定文本的情感倾向,如正面、负面或中性。在这种任务中,我们只需将输入文本编码为一个语义表示,然后通过该表示来预测文本的情感倾向,而不需要生成任何文本输出。<br><br>\n",
|
|||
|
|
"> **文本分类**:文本分类任务要求将文本分配到预定义的类别中。例如,垃圾邮件过滤、新闻分类等。在这种任务中,我们可以使用 Encoder 将输入文本编码为一个语义表示,然后通过该表示来进行分类预测。<br><br>\n",
|
|||
|
|
"> **命名实体识别**:命名实体识别任务要求在文本中识别和分类命名实体,如人名、地名、组织名等。在这种任务中,我们可以使用 Encoder 将输入文本编码为一个语义表示,然后通过该表示来对命名实体进行识别。<br><br>\n",
|
|||
|
|
"> **关系抽取**:关系抽取任务旨在从文本中提取实体之间的关系。例如,在医学文本中,从病历中抽取药物与疾病之间的关系。在这种任务中,我们可以使用 Encoder 将输入文本编码为一个语义表示,然后通过该表示来提取实体之间的关系。<br><br>\n",
|
|||
|
|
"> **文本生成的预训练**:在预训练语言模型中,我们可以使用 Encoder 将输入文本编码为一个语义表示,然后利用这个语义表示来预测下一个词或者生成文本序列。这在自然语言生成任务中非常有用,如对话生成、摘要生成等。"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"id": "5dc6bea2-aa9b-4d2c-ab43-af9b7aa8fe9c",
|
|||
|
|
"metadata": {},
|
|||
|
|
"source": [
|
|||
|
|
"Encoder 在 Transformer 架构中扮演着至关重要的角色,其作用是将输入序列转换为一系列语义表示,以便后续任务的处理和预测。Encoder 的结构包括多个相同的层,每个层都由自注意力机制和前馈神经网络组成,其中自注意力机制用于捕捉输入序列中的全局依赖关系,前馈神经网络用于对每个位置的特征进行非线性变换和提取。Encoder 作为 Transformer 架构的核心组件之一,承担着将输入序列转换为语义表示的重要任务。它的结构设计体现了并行计算、信息流动、层级表示和模块化设计等关键原则,使得模型能够更好地理解和表示输入数据,并在各种文本相关任务中取得优异的性能。"
|
|||
|
|
]
|
|||
|
|
},
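{
"cell_type": "markdown",
"id": "5c6d7e8f-eaea-4b01-9c21-0a1b2c3d4e5f",
"metadata": {},
"source": [
"把本节出现过的自注意力、残差连接、层归一化与前馈网络拼在一起,一个Encoder层大致如下。这是一个借助PyTorch内置`nn.MultiheadAttention`的示意性草图,只用于展示各组件的组合方式,并非原论文实现的逐行复刻:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6d7e8f9a-fafa-4c12-8d32-1b2c3d4e5f6a",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"\n",
"class EncoderLayer(nn.Module):\n",
"    def __init__(self, d_model=64, n_heads=4, d_ff=256):\n",
"        super().__init__()\n",
"        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)\n",
"        self.ffn = nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(),\n",
"                                 nn.Linear(d_ff, d_model))\n",
"        self.norm1 = nn.LayerNorm(d_model)\n",
"        self.norm2 = nn.LayerNorm(d_model)\n",
"\n",
"    def forward(self, x):\n",
"        attn_out, _ = self.attn(x, x, x)   # 自注意力:Q = K = V = x\n",
"        x = self.norm1(x + attn_out)       # Add & Norm\n",
"        x = self.norm2(x + self.ffn(x))    # Add & Norm\n",
"        return x\n",
"\n",
"x = torch.randn(2, 7, 64)            # (batch, seq_len, d_model)\n",
"print(EncoderLayer()(x).shape)       # torch.Size([2, 7, 64])\n"
]
},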
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"id": "dd92da34-3e34-4a0c-ab28-5af265563e52",
|
|||
|
|
"metadata": {},
|
|||
|
|
"source": [
|
|||
|
|
"## 2.3 Decoder结构解析"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"id": "2c710470-a356-4d18-8f6b-462cdbce760c",
|
|||
|
|
"metadata": {},
|
|||
|
|
"source": [
|
|||
|
|
"### 2.3.1 完整Transformer与Decoder-Only结构的数据流"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"id": "1fa43624-0cc1-4504-b999-4acb38277062",
|
|||
|
|
"metadata": {},
|
|||
|
|
"source": [
|
|||
|
|
"Decoder结构是Transformer中至关重要的结构,这不仅仅是因为Decoder是专门设计用来处理输出序列生成的结构,更是因为Decoder的用法非常灵活并且复杂。在之前Encoder的课程中,我们讲解了数个Encoder-Only结构的使用场景,在Transformer丰富的用法中,我们还可以选择使用按照完整的encoder+decoder结构、或者Decoder-Only架构——\n",
|
|||
|
|
"\n",
|
|||
|
|
"- **使用完整Transformer结构的任务**\n",
|
|||
|
|
">\n",
|
|||
|
|
">完整的Transformer结构包括编码器(encoder)和解码器(decoder)部分,通常用于需要将一个序列映射到另一个序列的任务,如:\n",
|
|||
|
|
">\n",
|
|||
|
|
">1. **机器翻译(Machine Translation):**\n",
|
|||
|
|
"> - 将源语言的句子翻译成目标语言的句子。例如将英文句子翻译成中文句子。<br><br>\n",
|
|||
|
|
">\n",
|
|||
|
|
">2. **文本摘要(text Summarization):**\n",
|
|||
|
|
"> - 将长文本总结为简短的摘要,例如将新闻文章总结为简短的新闻标题。<br><br>\n",
|
|||
|
|
">\n",
|
|||
|
|
">3. **图像字幕生成(Image Captioning):**\n",
|
|||
|
|
"> - 为给定的图像生成描述性的文字(图生文)<br><br>\n",
|
|||
|
|
">\n",
|
|||
|
|
">4. **文本到语音(text-to-Speech, TTS):**\n",
|
|||
|
|
"> - 将文本转换为语音信号,比如将输入文本转换为自然的语音输出。<br><br>\n",
|
|||
|
|
">\n",
|
|||
|
|
">5. **问答系统(Question Answering):**\n",
|
|||
|
|
"> - 根据上下文回答用户的问题,或者给定一段文本,回答其中提到的具体问题。"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"id": "aec84123-36ed-472e-a6bf-f3eb62c61ea8",
|
|||
|
|
"metadata": {},
|
|||
|
|
"source": [
|
|||
|
|
"- **只使用decoder结构的任务**\n",
|
|||
|
|
"> \n",
|
|||
|
|
"> 只使用decoder结构(通常被称为自回归模型或生成模型)适用于需要从部分输入生成完整序列的任务,如:\n",
|
|||
|
|
"> \n",
|
|||
|
|
"> 1. **大语言模型(Language Modeling):**\n",
|
|||
|
|
"> - 任务描述:预测给定文本序列中的下一个词或字符,例如GPT系列模型用于生成连续的文本段落(当然,并不是所有的大语言模型都是decoder-only结构)。<br><br>\n",
|
|||
|
|
"> \n",
|
|||
|
|
"> 2. **文本生成(text Generation):**\n",
|
|||
|
|
"> - 任务描述:根据部分输入生成完整的文本,比如根据开头的一句话生成一篇文章或故事,根据部分诗句生成完整的诗歌。<br><br>\n",
|
|||
|
|
"> \n",
|
|||
|
|
"> 3. **代码补全(Code Completion):**\n",
|
|||
|
|
"> - 任务描述:根据部分输入代码生成完整的代码段。<br><br>\n",
|
|||
|
|
">\n",
|
|||
|
|
"> 4. **对话生成(Dialogue Generation)**:\n",
|
|||
|
|
"> - 任务描述:根据对话历史生成下一句回复。<br><br>\n",
|
|||
|
|
">\n",
|
|||
|
|
"> 5. **问答系统(Question Answering):**\n",
|
|||
|
|
"> - 根据上下文回答用户的问题,或者给定一段文本,回答其中提到的具体问题。\n",
|
|||
|
|
"\n",
|
|||
|
|
"这些任务利用Transformer的强大表示能力,通过不同的结构来适应不同的应用场景。完整的Transformer结构适合需要**从一个序列转换到另一个序列的任务**,一般我们会在需要**高度依赖原始数据信息、尤其是需要语义的转译**的时候使用这种结构,因为Encoder会有非常好的语义和数据信息解析功能,可以帮助架构更好地吸收原始数据的信息;而只使用decoder结构的模型适合**需要生成连续序列的任务**,当我们更强调基于原有的信息基础上进行“创新、创造、续写”,而对原有的数据的依赖程度不是那么高时,我们会选择decoder-only结构。"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"id": "7fce11bd-8f08-4705-be60-d5468cf77bf2",
|
|||
|
|
"metadata": {},
|
|||
|
|
"source": [
|
|||
|
|
"当然了,一个任务对于原始信息的依赖程度是否高,这是可以讨论、甚至因人而异的判断。像机器翻译任务,最好能够原封不动将原始数据的语义表达出来,就会显然更适合完整的Transformer结构,但代码补全这样的、文本生成这样更强调续写的任务,就会更偏向于decoder-only,然而对于像大语言模型、对话系统这样无法明确判断出“多大程度依赖于原始输入信息的”任务,就会依据算法创造者的不同选择有不同的状态。例如大语言模型,GPT、llama等等大模型就是decoder-only结构,BERT模型是encoder-only结构、T5(text-to-text Transfer Transformer)和BART(Bidirectional and Auto-Regressive Transformers)模型则是使用了完整的Transformer结构。"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"id": "08e5d5a5-0dcd-4ec3-8117-776380950221",
|
|||
|
|
"metadata": {},
|
|||
|
|
"source": [
|
|||
|
|
"为什么在Decoder篇章一开始,我们就要讲解不同的任务呢?与Encoder不同的是,Decoder结构在不同的任务中承担不同的角色、存在不同的网络架构、不同的训练模式以及不同的数据流,因此我们需要理解不同的任务、才能知道Decoder结构究竟是什么样的。接下来,就让我来看看Transformer完整结构与Decoder-only结构下的具体情况。"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"id": "37618e41-7a94-4ca3-a17a-e4664299de93",
|
|||
|
|
"metadata": {},
|
|||
|
|
"source": [
|
|||
|
|
"<center><img src=\"https://skojiangdoc.oss-cn-beijing.aliyuncs.com/2023DL/transformer/image-1.png\" alt=\"描述文字\" width=\"400\">"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"id": "44cd7f20-67e3-4faf-a45b-62c1bda7f210",
|
|||
|
|
"metadata": {},
|
|||
|
|
"source": [
|
|||
|
|
"### 2.3.2 Encoder-Decoder结构中的Decoder"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"id": "4a566223-dafb-4089-95b5-9989b9c24223",
|
|||
|
|
"metadata": {},
|
|||
|
|
"source": [
|
|||
|
|
"上图所示的就是Transformer完整结构下的Decoder block的结构。之前在讲解Trans整体架构时我们就提到过,Decoder其实与Encoder非常类似,从图上来看,整个Decoder的结构包括如下核心内容:\n",
|
|||
|
|
"\n",
|
|||
|
|
"1. **输入与teacher forcing机制**<br><br>\n",
|
|||
|
|
"Decoder的输入是**滞后1个单位的标签矩阵**(shifted right outputs),我们要将真实标签输入给模型,并且让真实标签指导模型的学习与预测,这种让模型通过正确的标签来学习的流程在Transformer中被称之为是teacher forcing强制教学机制。\n",
|
|||
|
|
"\n",
|
|||
|
|
"2. **Embedding与位置编码**<br><br>\n",
|
|||
|
|
"标签矩阵首先通过嵌入层(embedding)转换成固定大小的向量。就像 Encoder 一样,Decoder 也会对这些嵌入向量添加位置编码,以包含序列中的位置信息。但这里需要注意的是,输入到Decoder层中的sequence_length维度可以与输入到Encoder中的sequence_length维度不一致。\n",
|
|||
|
|
"> Encoder与Decoder架构中的Seq_len可以不一致,这其实非常好理解。假设是英文翻译成中文的机器翻译任务,为了表达相同的语义,英文句子长度与中文句子长度都应该不受限制,尽量精准地表达;不同语言、不用序列之间的规律本来就各不相同,有的语言比较高效、有的语言则追求尽量详尽,因此要求Encoder和Decoder的输入的数据长度相同是强人所难。"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"id": "8e319bc2-b5e8-467d-a586-fc656e7e2951",
|
|||
|
|
"metadata": {},
|
|||
|
|
"source": [
|
|||
|
|
"<table>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>\n",
|
|||
|
|
" <p>输入Encoder<br>特征矩阵</p>\n",
|
|||
|
|
" <table>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <th></th><th>x1</th><th>x2</th><th>x3</th><th>x4</th><th>x5</th>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>这</td><td>0.1821</td><td>0.4000</td><td>0.2248</td><td>0.4440</td><td>0.7771</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>是</td><td>0.1721</td><td>0.5030</td><td>0.8948</td><td>0.2385</td><td>0.0987</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>最好的</td><td>0.1342</td><td>0.8297</td><td>0.2978</td><td>0.7120</td><td>0.2565</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>时代</td><td>0.1248</td><td>0.5003</td><td>0.7559</td><td>0.4804</td><td>0.2593</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>也</td><td>0.1032</td><td>0.1477</td><td>0.7023</td><td>0.7224</td><td>0.2768</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>是</td><td>0.1721</td><td>0.5030</td><td>0.8948</td><td>0.2385</td><td>0.0987</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>最坏的</td><td>0.4263</td><td>0.4615</td><td>0.5169</td><td>0.7584</td><td>0.8388</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>时代</td><td>0.1248</td><td>0.5003</td><td>0.7559</td><td>0.4804</td><td>0.2593</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" </table>\n",
|
|||
|
|
" </td>\n",
|
|||
|
|
" <td>\n",
|
|||
|
|
" <p>输入Decoder<br>标签矩阵</p>\n",
|
|||
|
|
" <table>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <th></th><th>y1</th><th>y2</th><th>y3</th><th>y4</th><th>y5</th>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>It</td><td>0.5621</td><td>0.8920</td><td>0.7312</td><td>0.2543</td><td>0.1289</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>was</td><td>0.2314</td><td>0.6794</td><td>0.9823</td><td>0.8452</td><td>0.3417</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>the</td><td>0.4932</td><td>0.2045</td><td>0.7531</td><td>0.6582</td><td>0.9731</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>best</td><td>0.8342</td><td>0.2987</td><td>0.7642</td><td>0.2154</td><td>0.9812</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>of</td><td>0.3417</td><td>0.5792</td><td>0.4821</td><td>0.6721</td><td>0.1234</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>times</td><td>0.2531</td><td>0.7345</td><td>0.9812</td><td>0.5487</td><td>0.2378</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>it</td><td>0.6523</td><td>0.1298</td><td>0.4576</td><td>0.9834</td><td>0.1876</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>was</td><td>0.2314</td><td>0.6794</td><td>0.9823</td><td>0.8452</td><td>0.3417</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>the</td><td>0.4932</td><td>0.2045</td><td>0.7531</td><td>0.6582</td><td>0.9731</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>worst</td><td>0.1543</td><td>0.9271</td><td>0.3821</td><td>0.6745</td><td>0.4823</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>of</td><td>0.3417</td><td>0.5792</td><td>0.4821</td><td>0.6721</td><td>0.1234</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>times</td><td>0.2531</td><td>0.7345</td><td>0.9812</td><td>0.5487</td><td>0.2378</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" </table>\n",
|
|||
|
|
" </td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
"</table>"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"id": "b83c63be-bf40-4faa-88b0-9c3aff0a76f5",
|
|||
|
|
"metadata": {},
|
|||
|
|
"source": [
|
|||
|
|
"> 不过这里就会引发无穷的问题,比如结构不相同的矩阵如何在同一个注意力机制中运行?最终输出的矩阵结构是什么?Decoder后续的结构会帮助我们解决这些问题。"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"id": "2511dfb9-4380-486a-b005-f7a81255672d",
|
|||
|
|
"metadata": {},
|
|||
|
|
"source": [
|
|||
|
|
"3. **带掩码的自注意力层**(Masked Self-Attention)<br><br>\n",
|
|||
|
|
"Decoder 的自注意力层在功能上与 Encoder 的自注意力层类似,它允许 Decoder 关注到之前所有生成的词。然而,为了防止在生成当前词时使用未来的信息(即避免信息泄露),使用了所谓的“掩码”技术(Masking)。这种技术通过将未来位置的值设置为负无穷大(在 softmax 操作前),使得这些位置的影响力为零。\n",
|
|||
|
|
"\n",
|
|||
|
|
"4. **编码器-解码器注意力层**(Encoder-Decoder Attention)<br><br>\n",
|
|||
|
|
"这一层是 Decoder 特有的注意力层,它就是位于图像上、Decoder结构中间的那个注意力机制层。它允许 Decoder 的每个位置关注 Encoder 的全部输出。具体来说,这一层的查询(Q)来自前一层 Decoder 的输出,而键(K)和值(V)则来自 Encoder 的输出。通过这种方式,Decoder 能够利用输入序列中的相关信息来帮助生成正确的输出序列。\n",
|
|||
|
|
"\n",
|
|||
|
|
"5. **前馈神经网络网络、层归一化和残差链接**<br><br>\n",
|
|||
|
|
"与 Encoder 中的前馈网络、层归一化以及残差链接相同,每个 Decoder 层包含一个前馈网络,该网络对每个位置应用相同的全连接层。这个网络通常包含两个线性变换,并在中间加入了一个激活函数,如 ReLU 或 GELU。\n",
|
|||
|
|
"\n",
|
|||
|
|
"**在这些结构当中,我们较为陌生的三个结构是“Teacher Forcing”、“带掩码的自注意力层”以及“编码器-解码器注意力层”**,我们先来了解一下数据滞后操作以及teacher forcing制度。"
|
|||
|
|
]
|
|||
|
|
},
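{
"cell_type": "markdown",
"id": "7e8f9a0b-abba-4d23-9e43-2c3d4e5f6a7b",
"metadata": {},
"source": [
"在正式展开之前,先兑现上面的承诺:用一个脱离完整模型的小草图,直观看看“掩码”是如何在softmax之前把未来位置置为负无穷、从而让这些位置的注意力权重变成0的(分数矩阵是随机生成的,仅作演示):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8f9a0b1c-baab-4e34-8f54-3d4e5f6a7b8c",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"seq_len = 4\n",
"scores = torch.randn(seq_len, seq_len)   # 假设这是 QK^T/sqrt(d_k) 得到的注意力分数\n",
"\n",
"# 上三角(不含对角线)对应“未来”位置,全部置为负无穷\n",
"mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)\n",
"scores = scores.masked_fill(mask, float('-inf'))\n",
"\n",
"weights = torch.softmax(scores, dim=-1)\n",
"print(weights)   # 每一行中,当前位置之后的注意力权重全部为0\n"
]
},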
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"id": "06a325e9-1aa9-44c5-9be2-59b79ff40835",
|
|||
|
|
"metadata": {},
|
|||
|
|
"source": [
|
|||
|
|
"#### 2.3.2.1 输入与teacher forcing"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"id": "ec1fe673-ba73-43da-b1ce-5ad9a67fdbcc",
|
|||
|
|
"metadata": {},
|
|||
|
|
"source": [
|
|||
|
|
"Decoder的输入是**滞后1个单位的标签矩阵**(shifted right outputs),我们要将真实标签输入给模型,并且让真实标签指导模型的学习与预测,这种让模型通过正确的标签来学习的流程在Transformer中被称之为是teacher forcing强制教学机制。接下来让我们展开仔细讲讲。\n",
|
|||
|
|
"\n",
|
|||
|
|
"> **shift right操作**\n",
|
|||
|
|
"\n",
|
|||
|
|
"首先,在序列到序列任务中,我们会将标签矩阵进行滞后操作(shift)。"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"execution_count": 1,
|
|||
|
|
"id": "3cd5cf47-975e-4fc2-81ab-8fad1d4c2a8b",
|
|||
|
|
"metadata": {},
|
|||
|
|
"outputs": [],
|
|||
|
|
"source": [
|
|||
|
|
"import pandas as pd\n",
|
|||
|
|
"\n",
|
|||
|
|
"# 创建DataFrame\n",
|
|||
|
|
"df = pd.DataFrame({\n",
|
|||
|
|
" \"值\": [0.1543, 0.2731, 0.3627, 0.4812, 0.5238]\n",
|
|||
|
|
"})"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"execution_count": 2,
|
|||
|
|
"id": "b647d7bd-8278-4a30-bd4b-ad1ec60d181f",
|
|||
|
|
"metadata": {},
|
|||
|
|
"outputs": [
|
|||
|
|
{
|
|||
|
|
"data": {
|
|||
|
|
"text/html": [
|
|||
|
|
"<div>\n",
|
|||
|
|
"<style scoped>\n",
|
|||
|
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
|
" vertical-align: middle;\n",
|
|||
|
|
" }\n",
|
|||
|
|
"\n",
|
|||
|
|
" .dataframe tbody tr th {\n",
|
|||
|
|
" vertical-align: top;\n",
|
|||
|
|
" }\n",
|
|||
|
|
"\n",
|
|||
|
|
" .dataframe thead th {\n",
|
|||
|
|
" text-align: right;\n",
|
|||
|
|
" }\n",
|
|||
|
|
"</style>\n",
|
|||
|
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
|
" <thead>\n",
|
|||
|
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
|
" <th></th>\n",
|
|||
|
|
" <th>值</th>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" </thead>\n",
|
|||
|
|
" <tbody>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <th>0</th>\n",
|
|||
|
|
" <td>0.1543</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <th>1</th>\n",
|
|||
|
|
" <td>0.2731</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <th>2</th>\n",
|
|||
|
|
" <td>0.3627</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <th>3</th>\n",
|
|||
|
|
" <td>0.4812</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <th>4</th>\n",
|
|||
|
|
" <td>0.5238</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" </tbody>\n",
|
|||
|
|
"</table>\n",
|
|||
|
|
"</div>"
|
|||
|
|
],
|
|||
|
|
"text/plain": [
|
|||
|
|
" 值\n",
|
|||
|
|
"0 0.1543\n",
|
|||
|
|
"1 0.2731\n",
|
|||
|
|
"2 0.3627\n",
|
|||
|
|
"3 0.4812\n",
|
|||
|
|
"4 0.5238"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
"execution_count": 2,
|
|||
|
|
"metadata": {},
|
|||
|
|
"output_type": "execute_result"
|
|||
|
|
}
|
|||
|
|
],
|
|||
|
|
"source": [
|
|||
|
|
"df"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"id": "7b2b83b2-0171-48d4-9015-72ddca45fcf2",
|
|||
|
|
"metadata": {},
|
|||
|
|
"source": [
|
|||
|
|
"对序列来说滞后是一种常见的操作👇是指将原有的序列向未来、向正向顺序的方向挪动位置,留出空值的行为:"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"execution_count": 3,
|
|||
|
|
"id": "a43c609e-abbf-4fde-add8-4e5c982a4804",
|
|||
|
|
"metadata": {},
|
|||
|
|
"outputs": [
|
|||
|
|
{
|
|||
|
|
"data": {
|
|||
|
|
"text/html": [
|
|||
|
|
"<div>\n",
|
|||
|
|
"<style scoped>\n",
|
|||
|
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
|
" vertical-align: middle;\n",
|
|||
|
|
" }\n",
|
|||
|
|
"\n",
|
|||
|
|
" .dataframe tbody tr th {\n",
|
|||
|
|
" vertical-align: top;\n",
|
|||
|
|
" }\n",
|
|||
|
|
"\n",
|
|||
|
|
" .dataframe thead th {\n",
|
|||
|
|
" text-align: right;\n",
|
|||
|
|
" }\n",
|
|||
|
|
"</style>\n",
|
|||
|
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
|
" <thead>\n",
|
|||
|
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
|
" <th></th>\n",
|
|||
|
|
" <th>值</th>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" </thead>\n",
|
|||
|
|
" <tbody>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <th>0</th>\n",
|
|||
|
|
" <td>NaN</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <th>1</th>\n",
|
|||
|
|
" <td>0.1543</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <th>2</th>\n",
|
|||
|
|
" <td>0.2731</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <th>3</th>\n",
|
|||
|
|
" <td>0.3627</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <th>4</th>\n",
|
|||
|
|
" <td>0.4812</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" </tbody>\n",
|
|||
|
|
"</table>\n",
|
|||
|
|
"</div>"
|
|||
|
|
],
|
|||
|
|
"text/plain": [
|
|||
|
|
" 值\n",
|
|||
|
|
"0 NaN\n",
|
|||
|
|
"1 0.1543\n",
|
|||
|
|
"2 0.2731\n",
|
|||
|
|
"3 0.3627\n",
|
|||
|
|
"4 0.4812"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
"execution_count": 3,
|
|||
|
|
"metadata": {},
|
|||
|
|
"output_type": "execute_result"
|
|||
|
|
}
|
|||
|
|
],
|
|||
|
|
"source": [
|
|||
|
|
"df.shift(1) #挪动一个位置,被叫做滞后1"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"execution_count": 4,
|
|||
|
|
"id": "26e6959a-ce1c-4be9-96fb-feb33330efd4",
|
|||
|
|
"metadata": {},
|
|||
|
|
"outputs": [
|
|||
|
|
{
|
|||
|
|
"data": {
|
|||
|
|
"text/html": [
|
|||
|
|
"<div>\n",
|
|||
|
|
"<style scoped>\n",
|
|||
|
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
|
" vertical-align: middle;\n",
|
|||
|
|
" }\n",
|
|||
|
|
"\n",
|
|||
|
|
" .dataframe tbody tr th {\n",
|
|||
|
|
" vertical-align: top;\n",
|
|||
|
|
" }\n",
|
|||
|
|
"\n",
|
|||
|
|
" .dataframe thead th {\n",
|
|||
|
|
" text-align: right;\n",
|
|||
|
|
" }\n",
|
|||
|
|
"</style>\n",
|
|||
|
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
|
" <thead>\n",
|
|||
|
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
|
" <th></th>\n",
|
|||
|
|
" <th>值</th>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" </thead>\n",
|
|||
|
|
" <tbody>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <th>0</th>\n",
|
|||
|
|
" <td>NaN</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <th>1</th>\n",
|
|||
|
|
" <td>NaN</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <th>2</th>\n",
|
|||
|
|
" <td>0.1543</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <th>3</th>\n",
|
|||
|
|
" <td>0.2731</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <th>4</th>\n",
|
|||
|
|
" <td>0.3627</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" </tbody>\n",
|
|||
|
|
"</table>\n",
|
|||
|
|
"</div>"
|
|||
|
|
],
|
|||
|
|
"text/plain": [
|
|||
|
|
" 值\n",
|
|||
|
|
"0 NaN\n",
|
|||
|
|
"1 NaN\n",
|
|||
|
|
"2 0.1543\n",
|
|||
|
|
"3 0.2731\n",
|
|||
|
|
"4 0.3627"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
"execution_count": 4,
|
|||
|
|
"metadata": {},
|
|||
|
|
"output_type": "execute_result"
|
|||
|
|
}
|
|||
|
|
],
|
|||
|
|
"source": [
|
|||
|
|
"df.shift(2) #也可以挪动多个位置"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"id": "e045924f-ce18-4320-b5d9-6a129e588ce2",
|
|||
|
|
"metadata": {},
|
|||
|
|
"source": [
|
|||
|
|
"当表现为编码前的序列时,就是从[y1, y2, y3, y4]变成[NaN, y1, y2, y3, y4],因此这个过程也被叫做“向右滞后”(shift right),其实代表的是在序列的最前方腾挪出位置,将已有的序列向后挤。在Transformer当中,我们一般会为解码器的输入标签添加起始标记\"SOS\"(start of sequence),并将这个起始标记作为标签序列的第一行,最终构成[\"sos\", y1, y2, y3, y4]这样的序列。当进行embedding编码后,会呈现为👇"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"id": "507f201f-c78e-4ed3-b0de-d3dba51106b6",
|
|||
|
|
"metadata": {},
|
|||
|
|
"source": [
|
|||
|
|
"<table>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>\n",
|
|||
|
|
" <p>输入Decoder<br>标签矩阵</p>\n",
|
|||
|
|
" <table>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <th>索引</th><th></th><th>y1</th><th>y2</th><th>y3</th><th>y4</th><th>y5</th>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>0</td><td>\"sos\"</td><td>0.5651</td><td>0.2220</td><td>0.5112</td><td>0.8543</td><td>0.1239</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>1</td><td>It</td><td>0.5621</td><td>0.8920</td><td>0.7312</td><td>0.2543</td><td>0.1289</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>2</td><td>was</td><td>0.2314</td><td>0.6794</td><td>0.9823</td><td>0.8452</td><td>0.3417</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>3</td><td>the</td><td>0.4932</td><td>0.2045</td><td>0.7531</td><td>0.6582</td><td>0.9731</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>4</td><td>best</td><td>0.8342</td><td>0.2987</td><td>0.7642</td><td>0.2154</td><td>0.9812</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>5</td><td>of</td><td>0.3417</td><td>0.5792</td><td>0.4821</td><td>0.6721</td><td>0.1234</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>6</td><td>times</td><td>0.2531</td><td>0.7345</td><td>0.9812</td><td>0.5487</td><td>0.2378</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>7</td><td>it</td><td>0.6523</td><td>0.1298</td><td>0.4576</td><td>0.9834</td><td>0.1876</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>8</td><td>was</td><td>0.2314</td><td>0.6794</td><td>0.9823</td><td>0.8452</td><td>0.3417</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>9</td><td>the</td><td>0.4932</td><td>0.2045</td><td>0.7531</td><td>0.6582</td><td>0.9731</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>10</td><td>worst</td><td>0.1543</td><td>0.9271</td><td>0.3821</td><td>0.6745</td><td>0.4823</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>11</td><td>of</td><td>0.3417</td><td>0.5792</td><td>0.4821</td><td>0.6721</td><td>0.1234</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>12</td><td>times</td><td>0.2531</td><td>0.7345</td><td>0.9812</td><td>0.5487</td><td>0.2378</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" </table>\n",
|
|||
|
|
" </td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
"</table>"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"id": "09128284-e32f-4c51-831e-2d342d122423",
|
|||
|
|
"metadata": {},
|
|||
|
|
"source": [
|
|||
|
|
"起始标记(Start of Sequence,SOS)和结束标记(End of Sequence,EOS)在序列到序列(Seq2Seq)任务中起着重要的作用,特别是在自然语言处理(NLP)和机器翻译等任务中。\n",
|
|||
|
|
"\n",
|
|||
|
|
"- 起始标记(SOS)的意义\n",
|
|||
|
|
"> 1. **标识序列的开始**:SOS标记用于指示解码器开始生成序列。这在训练和推理过程中都非常重要。<br><br>\n",
|
|||
|
|
"> 2. **初始化解码器**:在解码阶段,解码器需要一个初始输入来开始生成输出序列。SOS标记作为解码器的第一个输入,帮助其启动生成过程。<br><br>\n",
|
|||
|
|
"> 3. **模型一致性**:通过在每个输出序列的开头添加SOS标记,模型在训练时可以学到序列生成的起点,从而在推理时保持一致的生成过程。\n",
|
|||
|
|
"\n",
|
|||
|
|
"- 结束标记(EOS)的意义\n",
|
|||
|
|
"> 1. **标识序列的结束**:EOS标记用于指示生成的序列在何处结束。这对于模型在推理阶段停止生成非常重要。<br><br>\n",
|
|||
|
|
"> 2. **控制生成长度**:在没有固定长度的输出序列中,EOS标记告诉模型何时停止生成,而不需要生成固定数量的时间步。这使得模型可以处理变长序列。<br><br>\n",
|
|||
|
|
"> 3. **训练终止条件**:在训练过程中,模型学会在适当的时候生成EOS标记,从而正确地结束序列。\n",
|
|||
|
|
"\n",
|
|||
|
|
"假设我们有一个输入序列和一个目标序列:\n",
|
|||
|
|
"\n",
|
|||
|
|
"- 输入序列:`y = [\"这\", \"是\", \"最\", \"好\", \"的\", \"时\", \"代\"]`\n",
|
|||
|
|
"- 目标序列:`y = [\"it\", \"was\", \"the\", \"best\", \"of\", \"times\"]`\n",
|
|||
|
|
"\n",
|
|||
|
|
"在Seq2Seq任务的训练过程中,由于Decoder结构会需要输入标签,因此我们必须要准备三种不同的数据,并进行如下的处理:\n",
|
|||
|
|
"\n",
|
|||
|
|
"1. **编码器输入**:`y`不需要添加起始标记和结束标记。\n",
|
|||
|
|
"2. **解码器输入的标签**:在目标序列前添加起始标记(SOS)。\n",
|
|||
|
|
"3. **解码器用来计算损失函数的标签**:在目标序列末尾添加结束标记(EOS)。\n",
|
|||
|
|
"\n",
|
|||
|
|
"处理后的序列就是:\n",
|
|||
|
|
"\n",
|
|||
|
|
"- **编码器输入**:`[\"这\", \"是\", \"最\", \"好\", \"的\", \"时\", \"代\"]`\n",
|
|||
|
|
"- **解码器输入的标签**:`[\"SOS\", \"it\", \"was\", \"the\", \"best\", \"of\", \"times\"]`\n",
|
|||
|
|
"- **解码器用来计算损失函数的标签**:`[\"it\", \"was\", \"the\", \"best\", \"of\", \"times\", \"EOS\"]`\n",
|
|||
|
|
"\n",
|
|||
|
|
"以下是一个简化的示例代码,展示如何使用PyTorch为序列添加起始标记和结束标记,并进行词嵌入:"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"execution_count": 5,
|
|||
|
|
"id": "401f15ca-4b9e-4654-ad37-d86b3e8df629",
|
|||
|
|
"metadata": {},
|
|||
|
|
"outputs": [
|
|||
|
|
{
|
|||
|
|
"name": "stdout",
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"text": [
|
|||
|
|
"Decoder Input (with SOS): tensor([0, 2, 3, 4, 5, 6])\n",
|
|||
|
|
"Decoder Output (with EOS): tensor([2, 3, 4, 5, 6, 1])\n",
|
|||
|
|
"Embedded Decoder Input: tensor([[ 0.6930, 0.0392, 0.6529, 1.3837],\n",
|
|||
|
|
" [ 0.1320, -1.5171, -0.2337, -1.1682],\n",
|
|||
|
|
" [ 0.2960, -1.3404, 0.1997, 0.8595],\n",
|
|||
|
|
" [-0.0201, -0.0039, 1.2342, -1.2684],\n",
|
|||
|
|
" [ 0.4403, 0.9309, -0.3682, 0.6179],\n",
|
|||
|
|
" [-0.4487, -0.2147, -0.5202, 1.3910]], grad_fn=<EmbeddingBackward0>)\n",
|
|||
|
|
"Embedded Decoder Output: tensor([[ 0.1320, -1.5171, -0.2337, -1.1682],\n",
|
|||
|
|
" [ 0.2960, -1.3404, 0.1997, 0.8595],\n",
|
|||
|
|
" [-0.0201, -0.0039, 1.2342, -1.2684],\n",
|
|||
|
|
" [ 0.4403, 0.9309, -0.3682, 0.6179],\n",
|
|||
|
|
" [-0.4487, -0.2147, -0.5202, 1.3910],\n",
|
|||
|
|
" [-0.0288, 2.0398, -0.3713, 1.0762]], grad_fn=<EmbeddingBackward0>)\n"
|
|||
|
|
]
|
|||
|
|
}
|
|||
|
|
],
|
|||
|
|
"source": [
|
|||
|
|
"import torch\n",
|
|||
|
|
"import torch.nn as nn\n",
|
|||
|
|
"\n",
|
|||
|
|
"# 假设词汇表大小(包括特殊标记如SOS和EOS)\n",
|
|||
|
|
"vocab_size = 10\n",
|
|||
|
|
"embedding_dim = 4\n",
|
|||
|
|
"\n",
|
|||
|
|
"# 创建嵌入层\n",
|
|||
|
|
"embedding_layer = nn.Embedding(vocab_size, embedding_dim)\n",
|
|||
|
|
"\n",
|
|||
|
|
"# 假设索引0是SOS,索引1是EOS\n",
|
|||
|
|
"SOS_token = 0\n",
|
|||
|
|
"EOS_token = 1\n",
|
|||
|
|
"\n",
|
|||
|
|
"# 目标序列的索引表示\n",
|
|||
|
|
"target_sequence = [2, 3, 4, 5, 6] # 假设 \"it\", \"was\", \"the\", \"best\", \"of\"\n",
|
|||
|
|
"\n",
|
|||
|
|
"# 添加起始标记和结束标记\n",
|
|||
|
|
"decoder_input = [SOS_token] + target_sequence\n",
|
|||
|
|
"decoder_output = target_sequence + [EOS_token]\n",
|
|||
|
|
"\n",
|
|||
|
|
"# 转换为张量\n",
|
|||
|
|
"decoder_input_tensor = torch.tensor(decoder_input)\n",
|
|||
|
|
"decoder_output_tensor = torch.tensor(decoder_output)\n",
|
|||
|
|
"\n",
|
|||
|
|
"# 嵌入\n",
|
|||
|
|
"embedded_decoder_input = embedding_layer(decoder_input_tensor)\n",
|
|||
|
|
"embedded_decoder_output = embedding_layer(decoder_output_tensor)\n",
|
|||
|
|
"\n",
|
|||
|
|
"print(\"Decoder Input (with SOS):\", decoder_input_tensor)\n",
|
|||
|
|
"print(\"Decoder Output (with EOS):\", decoder_output_tensor)\n",
|
|||
|
|
"print(\"Embedded Decoder Input:\", embedded_decoder_input)\n",
|
|||
|
|
"print(\"Embedded Decoder Output:\", embedded_decoder_output)"
|
|||
|
|
]
|
|||
|
|
},
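{
"cell_type": "markdown",
"id": "9a0b1c2d-ccdd-4f45-9a65-4e5f6a7b8c9d",
"metadata": {},
"source": [
"有了decoder_input与decoder_output这一对错开一位的序列,训练时teacher forcing的损失计算方式大致如下。注意这里的`dummy_model`只是一个用来凑出形状的伪模型(演示用的假设,真实场景中应替换为完整的Transformer),要点在于第$t$步的输出对应decoder_output中的第$t$个词:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0b1c2d3e-ddee-4a56-8b76-5f6a7b8c9d0e",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"\n",
"vocab_size = 10\n",
"# 伪模型:embedding + 线性层,仅用于演示输入/标签如何配对进入损失函数\n",
"dummy_model = nn.Sequential(nn.Embedding(vocab_size, 16), nn.Linear(16, vocab_size))\n",
"\n",
"decoder_input = torch.tensor([[0, 2, 3, 4, 5, 6]])    # 以SOS(索引0)开头\n",
"decoder_output = torch.tensor([[2, 3, 4, 5, 6, 1]])   # 以EOS(索引1)结尾,相对输入错开一位\n",
"\n",
"logits = dummy_model(decoder_input)                   # (1, 6, vocab_size)\n",
"loss = nn.CrossEntropyLoss()(logits.view(-1, vocab_size), decoder_output.view(-1))\n",
"print(loss)\n"
]
},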
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"id": "5401cb15-6138-4646-adeb-5c95a77edfab",
|
|||
|
|
"metadata": {},
|
|||
|
|
"source": [
|
|||
|
|
"> **teacher force**"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"id": "acbd0c53-2e54-4cde-8c62-f9f599d2eddd",
|
|||
|
|
"metadata": {},
|
|||
|
|
"source": [
|
|||
|
|
"如果你非常熟悉序列模型的预测(比如时间序列的预测),那你应该早就见过很多使用真实标签+特征一起来指导模型学习的操作;例如,时间序列中存在“带标签的滑窗”技术。“带标签的滑窗”是一种特征矩阵构建方法,**它会将可以使用的那部分标签作为其中一个特征,和其他特征concat在一起构建特征矩阵**。使用带标签的滑窗后,特征信息与标签信息会一起被输入给模型,模型将会结合特征和可使用的标签两部分信息来共同决策。\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
"在Transformer中,这种对标签的使用从时间序列数据扩大到了任意序列数据(对时间数据而言,可使用的标签就是当前预测时间点之前所有的时间点,对其他序列数据而言,例如文字数据,可使用的标签就是当前预测的文字位置之前的所有文字),并且将这种技巧从时间序列预测拓展到了序列到序列任务(seq2seq)。\n",
|
|||
|
|
"\n",
|
|||
|
|
"然而需要注意的是,时间序列任务是一种使用过去的信息来预测未来的任务,通常是利用一个序列的前半段数据来预测同一序列的后半段数据。**这意味着时间序列预测更多地依赖于生成式模型,旨在根据已有数据生成未来的数据点**。而Seq2Seq任务(序列到序列任务)并不总是遵循这种模式。例如,在机器翻译任务中,模型的目标是将一个语言的句子转换成另一种语言的句子,这并不是通过预测同一序列的未来部分来实现的。因此,时间序列预测更接近于生成式任务,而不是典型的序列到序列任务。"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"id": "7f4cc298-f61e-4659-9fad-d362d57651ec",
|
|||
|
|
"metadata": {},
|
|||
|
|
"source": [
|
|||
|
|
"- **时间序列任务/生成式任务**:同一张表、过去预测未来\n",
|
|||
|
|
"\n",
|
|||
|
|
"| 时间点 | 值 |\n",
|
|||
|
|
"|--------|----------|\n",
|
|||
|
|
"| 1 | 0.1543 |\n",
|
|||
|
|
"| 2 | 0.2731 |\n",
|
|||
|
|
"| 3 | 0.3627 |\n",
|
|||
|
|
"| 4 | 0.4812 |\n",
|
|||
|
|
"| 5 | 0.5238 |"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"id": "a99d7333-0e93-46c9-924a-60c9020db6d7",
|
|||
|
|
"metadata": {},
|
|||
|
|
"source": [
|
|||
|
|
"- **Encoder-decoder下的seq2seq任务**:两个序列大概率不是一张表,是用一张表去预测另一张表"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"id": "e8779e34-45c3-4e00-bbb5-cfc49d37e8d6",
|
|||
|
|
"metadata": {},
|
|||
|
|
"source": [
|
|||
|
|
"<table>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>\n",
|
|||
|
|
" <p>输入Encoder<br>特征矩阵</p>\n",
|
|||
|
|
" <table>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <th>索引</th><th></th><th>y1</th><th>y2</th><th>y3</th><th>y4</th><th>y5</th>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>0</td><td>这</td><td>0.1821</td><td>0.4000</td><td>0.2248</td><td>0.4440</td><td>0.7771</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>1</td><td>是</td><td>0.1721</td><td>0.5030</td><td>0.8948</td><td>0.2385</td><td>0.0987</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>2</td><td>最好的</td><td>0.1342</td><td>0.8297</td><td>0.2978</td><td>0.7120</td><td>0.2565</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>3</td><td>时代</td><td>0.1248</td><td>0.5003</td><td>0.7559</td><td>0.4804</td><td>0.2593</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" </table>\n",
|
|||
|
|
" </td>\n",
|
|||
|
|
" <td>\n",
|
|||
|
|
" <p>输入Decoder<br>标签矩阵</p>\n",
|
|||
|
|
" <table>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <th>索引</th><th></th><th>y1</th><th>y2</th><th>y3</th><th>y4</th><th>y5</th>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>0</td><td>\"sos\"</td><td>0.5651</td><td>0.2220</td><td>0.5112</td><td>0.8543</td><td>0.1239</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>1</td><td>It</td><td>0.5621</td><td>0.8920</td><td>0.7312</td><td>0.2543</td><td>0.1289</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>2</td><td>was</td><td>0.2314</td><td>0.6794</td><td>0.9823</td><td>0.8452</td><td>0.3417</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>3</td><td>the</td><td>0.4932</td><td>0.2045</td><td>0.7531</td><td>0.6582</td><td>0.9731</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>4</td><td>best</td><td>0.8342</td><td>0.2987</td><td>0.7642</td><td>0.2154</td><td>0.9812</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>5</td><td>of</td><td>0.3417</td><td>0.5792</td><td>0.4821</td><td>0.6721</td><td>0.1234</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>6</td><td>times</td><td>0.2531</td><td>0.7345</td><td>0.9812</td><td>0.5487</td><td>0.2378</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" </table>\n",
|
|||
|
|
" </td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
"</table>"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"id": "501e54ea-98a9-4b6a-bdf3-4a4aad8c5d36",
|
|||
|
|
"metadata": {},
|
|||
|
|
"source": [
|
|||
|
|
"因此在teacher force所强调的使用标签是**需要将特征矩阵和标签矩阵的信息融合后**再进行训练。以上面两张表单为例,设——\n",
|
|||
|
|
"> 原始序列y = [\"这\",\"是\",\"最\",\"好\",\"的\",\"时\",\"代\"]<br><br>\n",
|
|||
|
|
"> 真实标签y = [\"it\", \"was\", \"the\", \"best\", \"of\", \"times\"]<br><br>\n",
|
|||
|
|
"> 编码器输出的预测结果为yhat,添加过初始词/结束词、经过embedding的矩阵为ebd_X和ebd_y\n",
|
|||
|
|
"\n",
|
|||
|
|
"那我们实际走的<font color=\"red\">**训练流程**</font>是:\n",
|
|||
|
|
"\n",
|
|||
|
|
"> - **第一步,输入ebd_X & ebd_y[0] >> 输出yhat[0],对应真实标签y[0]**"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"id": "8f388f17-a701-43f8-9cdb-d530c8f15fb7",
|
|||
|
|
"metadata": {},
|
|||
|
|
"source": [
|
|||
|
|
"<table>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>\n",
|
|||
|
|
" <p>输入Encoder<br>特征矩阵</p>\n",
|
|||
|
|
" <table style=\"color:red;\">\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <th>索引</th><th></th><th>y1</th><th>y2</th><th>y3</th><th>y4</th><th>y5</th>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>0</td><td>这</td><td>0.1821</td><td>0.4000</td><td>0.2248</td><td>0.4440</td><td>0.7771</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>1</td><td>是</td><td>0.1721</td><td>0.5030</td><td>0.8948</td><td>0.2385</td><td>0.0987</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>2</td><td>最好的</td><td>0.1342</td><td>0.8297</td><td>0.2978</td><td>0.7120</td><td>0.2565</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>3</td><td>时代</td><td>0.1248</td><td>0.5003</td><td>0.7559</td><td>0.4804</td><td>0.2593</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" </table>\n",
|
|||
|
|
" </td>\n",
|
|||
|
|
" <td>\n",
|
|||
|
|
" <p>输入Decoder<br>标签矩阵</p>\n",
|
|||
|
|
" <table>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <th>索引</th><th></th><th>y1</th><th>y2</th><th>y3</th><th>y4</th><th>y5</th>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr style=\"color:red;\">\n",
|
|||
|
|
" <td>0</td><td>\"sos\"</td><td>0.5651</td><td>0.2220</td><td>0.5112</td><td>0.8543</td><td>0.1239</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>1</td><td>It</td><td>0.5621</td><td>0.8920</td><td>0.7312</td><td>0.2543</td><td>0.1289</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>2</td><td>was</td><td>0.2314</td><td>0.6794</td><td>0.9823</td><td>0.8452</td><td>0.3417</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>3</td><td>the</td><td>0.4932</td><td>0.2045</td><td>0.7531</td><td>0.6582</td><td>0.9731</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>4</td><td>best</td><td>0.8342</td><td>0.2987</td><td>0.7642</td><td>0.2154</td><td>0.9812</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>5</td><td>of</td><td>0.3417</td><td>0.5792</td><td>0.4821</td><td>0.6721</td><td>0.1234</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>6</td><td>times</td><td>0.2531</td><td>0.7345</td><td>0.9812</td><td>0.5487</td><td>0.2378</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" </table>\n",
|
|||
|
|
" </td>\n",
|
|||
|
|
" <td>\n",
|
|||
|
|
" <p>对应</p>\n",
|
|||
|
|
" ➡\n",
|
|||
|
|
" </td>\n",
|
|||
|
|
" <td>\n",
|
|||
|
|
" <p>真实标签y</p>\n",
|
|||
|
|
" <table>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <th>索引</th><th></th>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr style=\"color:blue;\">\n",
|
|||
|
|
" <td>0</td><td>It</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>1</td><td>was</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>2</td><td>the</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>3</td><td>best</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>4</td><td>of</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>5</td><td>times</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>6</td><td>\"eos\"</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" </table>\n",
|
|||
|
|
" </td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
"</table>"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"id": "522677d2-9c70-4685-ac3a-4a9ef13bdd9a",
|
|||
|
|
"metadata": {},
|
|||
|
|
"source": [
|
|||
|
|
"> - **第二步,输入ebd_X & ebd_y[:1] >> 输出yhat[1],对应真实标签y[1]**"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"id": "545c40ee-1c02-4832-832f-62266ea32802",
|
|||
|
|
"metadata": {},
|
|||
|
|
"source": [
|
|||
|
|
"<table>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>\n",
|
|||
|
|
" <p>输入Encoder<br>特征矩阵</p>\n",
|
|||
|
|
" <table style=\"color:red;\">\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <th>索引</th><th></th><th>y1</th><th>y2</th><th>y3</th><th>y4</th><th>y5</th>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>0</td><td>这</td><td>0.1821</td><td>0.4000</td><td>0.2248</td><td>0.4440</td><td>0.7771</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>1</td><td>是</td><td>0.1721</td><td>0.5030</td><td>0.8948</td><td>0.2385</td><td>0.0987</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>2</td><td>最好的</td><td>0.1342</td><td>0.8297</td><td>0.2978</td><td>0.7120</td><td>0.2565</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>3</td><td>时代</td><td>0.1248</td><td>0.5003</td><td>0.7559</td><td>0.4804</td><td>0.2593</td>\n",
|
|||
|
|
" </table>\n",
|
|||
|
|
" </td>\n",
|
|||
|
|
" <td>\n",
|
|||
|
|
" <p>输入Decoder<br>标签矩阵</p>\n",
|
|||
|
|
" <table>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <th>索引</th><th></th><th>y1</th><th>y2</th><th>y3</th><th>y4</th><th>y5</th>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr style=\"color:red;\">\n",
|
|||
|
|
" <td>0</td><td>\"sos\"</td><td>0.5651</td><td>0.2220</td><td>0.5112</td><td>0.8543</td><td>0.1239</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr style=\"color:red;\">\n",
|
|||
|
|
" <td>1</td><td>It</td><td>0.5621</td><td>0.8920</td><td>0.7312</td><td>0.2543</td><td>0.1289</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>2</td><td>was</td><td>0.2314</td><td>0.6794</td><td>0.9823</td><td>0.8452</td><td>0.3417</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>3</td><td>the</td><td>0.4932</td><td>0.2045</td><td>0.7531</td><td>0.6582</td><td>0.9731</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>4</td><td>best</td><td>0.8342</td><td>0.2987</td><td>0.7642</td><td>0.2154</td><td>0.9812</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>5</td><td>of</td><td>0.3417</td><td>0.5792</td><td>0.4821</td><td>0.6721</td><td>0.1234</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>6</td><td>times</td><td>0.2531</td><td>0.7345</td><td>0.9812</td><td>0.5487</td><td>0.2378</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" </table>\n",
|
|||
|
|
" </td>\n",
|
|||
|
|
" <td>\n",
|
|||
|
|
" <p>对应</p>\n",
|
|||
|
|
" ➡\n",
|
|||
|
|
" </td>\n",
|
|||
|
|
" <td>\n",
|
|||
|
|
" <p>真实标签y</p>\n",
|
|||
|
|
" <table>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <th>索引</th><th></th>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>0</td><td>It</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr style=\"color:blue;\">\n",
|
|||
|
|
" <td>1</td><td>was</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>2</td><td>the</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>3</td><td>best</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>4</td><td>of</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>5</td><td>times</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>6</td><td>\"eos\"</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" </table>\n",
|
|||
|
|
" </td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
"</table>"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"id": "206e2daf-d007-46af-8d09-d6cbe96a3d57",
|
|||
|
|
"metadata": {},
|
|||
|
|
"source": [
|
|||
|
|
"> - **第三步,输入ebd_X & ebd_y[:2] >> 输出yhat[2],对应真实标签y[2]**"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"id": "c366bc4b-95c1-450e-99c3-f9462641f898",
|
|||
|
|
"metadata": {},
|
|||
|
|
"source": [
|
|||
|
|
"<table>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>\n",
|
|||
|
|
" <p>输入Encoder<br>特征矩阵</p>\n",
|
|||
|
|
" <table style=\"color:red;\">\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <th>索引</th><th></th><th>y1</th><th>y2</th><th>y3</th><th>y4</th><th>y5</th>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>0</td><td>这</td><td>0.1821</td><td>0.4000</td><td>0.2248</td><td>0.4440</td><td>0.7771</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>1</td><td>是</td><td>0.1721</td><td>0.5030</td><td>0.8948</td><td>0.2385</td><td>0.0987</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>2</td><td>最好的</td><td>0.1342</td><td>0.8297</td><td>0.2978</td><td>0.7120</td><td>0.2565</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>3</td><td>时代</td><td>0.1248</td><td>0.5003</td><td>0.7559</td><td>0.4804</td><td>0.2593</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" </table>\n",
|
|||
|
|
" </td>\n",
|
|||
|
|
" <td>\n",
|
|||
|
|
" <p>输入Decoder<br>标签矩阵</p>\n",
|
|||
|
|
" <table>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <th>索引</th><th></th><th>y1</th><th>y2</th><th>y3</th><th>y4</th><th>y5</th>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr style=\"color:red;\">\n",
|
|||
|
|
" <td>0</td><td>\"sos\"</td><td>0.5651</td><td>0.2220</td><td>0.5112</td><td>0.8543</td><td>0.1239</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr style=\"color:red;\">\n",
|
|||
|
|
" <td>1</td><td>It</td><td>0.5621</td><td>0.8920</td><td>0.7312</td><td>0.2543</td><td>0.1289</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr style=\"color:red;\">\n",
|
|||
|
|
" <td>2</td><td>was</td><td>0.2314</td><td>0.6794</td><td>0.9823</td><td>0.8452</td><td>0.3417</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>3</td><td>the</td><td>0.4932</td><td>0.2045</td><td>0.7531</td><td>0.6582</td><td>0.9731</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>4</td><td>best</td><td>0.8342</td><td>0.2987</td><td>0.7642</td><td>0.2154</td><td>0.9812</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>5</td><td>of</td><td>0.3417</td><td>0.5792</td><td>0.4821</td><td>0.6721</td><td>0.1234</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>6</td><td>times</td><td>0.2531</td><td>0.7345</td><td>0.9812</td><td>0.5487</td><td>0.2378</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" </table>\n",
|
|||
|
|
" </td>\n",
|
|||
|
|
" <td>\n",
|
|||
|
|
" <p>对应</p>\n",
|
|||
|
|
" ➡\n",
|
|||
|
|
" </td>\n",
|
|||
|
|
" <td>\n",
|
|||
|
|
" <p>真实标签y</p>\n",
|
|||
|
|
" <table>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <th>索引</th><th></th>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>0</td><td>It</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>1</td><td>was</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr style=\"color:blue;\">\n",
|
|||
|
|
" <td>2</td><td>the</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>3</td><td>best</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>4</td><td>of</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>5</td><td>times</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" <tr>\n",
|
|||
|
|
" <td>6</td><td>\"eos\"</td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
" </table>\n",
|
|||
|
|
" </td>\n",
|
|||
|
|
" </tr>\n",
|
|||
|
|
"</table>"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"id": "9f2162a6-2a87-4c61-b5a2-d5fdf08e16ba",
|
|||
|
|
"metadata": {},
|
|||
|
|
"source": [
|
|||
|
|
"……以此类推下去。不难发现,在这个流程中我们实现了【利用序列A + 序列B的前半段预测序列B的后半段】,这样的方式既没有泄露真实的标签,又能够为预测下一个词提供最大程度的准确的信息,这就是teacher forcing的本质。**在训练过程中,这个流程通过掩码自注意力机制+编码器-解码器注意力层合作的方式实现了并行**,所以Seq2Seq任务在训练时实际上并不是按照时间步顺序来运行,反而呈现为一次性输入特征矩阵+标签矩阵后,一次性获得整个预测的序列。\n",
|
|||
|
|
"\n",
|
|||
|
|
"然而在测试和推理过程中可就不一样。**在测试和推理的过程中,我们并没有真实的标签矩阵,因此需要将上一个时间步预测的结果作为Decoder需要的输入**。具体来看,在<font color=\"red\">**测试流程**</font>中:"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
"cell_type": "markdown",
"id": "81e83662-e129-426e-9a8f-cedd78e00e7b",
"metadata": {},
"source": [
"- **第一步,输入 ebd_X & sos >> 输出时间步1的预测标签,对应真实标签y[0]**\n",
"\n",
"<table>\n",
"  <tr>\n",
"  <td>\n",
"  <p>输入Encoder<br>特征矩阵</p>\n",
"  <table style=\"color:red;\">\n",
"    <tr>\n",
"      <th>索引</th><th></th><th>y1</th><th>y2</th><th>y3</th><th>y4</th><th>y5</th>\n",
"    </tr>\n",
"    <tr>\n",
"      <td>0</td><td>这</td><td>0.1821</td><td>0.4000</td><td>0.2248</td><td>0.4440</td><td>0.7771</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <td>1</td><td>是</td><td>0.1721</td><td>0.5030</td><td>0.8948</td><td>0.2385</td><td>0.0987</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <td>2</td><td>最好的</td><td>0.1342</td><td>0.8297</td><td>0.2978</td><td>0.7120</td><td>0.2565</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <td>3</td><td>时代</td><td>0.1248</td><td>0.5003</td><td>0.7559</td><td>0.4804</td><td>0.2593</td>\n",
"    </tr>\n",
"  </table>\n",
"  </td>\n",
"  <td style=\"text-align: center;\">\n",
"  ➕\n",
"  </td>\n",
"  <td>\n",
"  <p>输入Decoder:sos编码序列</p>\n",
"  <table>\n",
"    <tr>\n",
"      <th>索引</th><th></th><th>y1</th><th>y2</th><th>y3</th><th>y4</th><th>y5</th>\n",
"    </tr>\n",
"    <tr style=\"color:red;\">\n",
"      <td>0</td><td>\"sos\"</td><td>0.5651</td><td>0.2220</td><td>0.5112</td><td>0.8543</td><td>0.1239</td>\n",
"    </tr>\n",
"  </table>\n",
"  </td>\n",
"  </tr>\n",
"</table>"
]
},
{
"cell_type": "markdown",
"id": "0a4a15a7-543c-4ff5-aaeb-8cf5476363d6",
"metadata": {},
"source": [
"<table>\n",
"  <tr>\n",
"  <td><p>预测出</p>➡\n",
"  </td>\n",
"  <td>\n",
"  <p>当前时间步的预测标签yhat</p>\n",
"  <table>\n",
"    <tr>\n",
"      <th>索引</th><th></th><th>y1</th><th>y2</th><th>y3</th><th>y4</th><th>y5</th>\n",
"    </tr>\n",
"    <tr>\n",
"      <td>1</td><td>yyy</td><td>0.5621</td><td>0.8920</td><td>0.7312</td><td>0.2543</td><td>0.1289</td>\n",
"    </tr>\n",
"  </table>\n",
"  </td>\n",
"  <td><p>对应</p>➡\n",
"  </td>\n",
"  <td>\n",
"  <p>真实标签y</p>\n",
"  <table>\n",
"    <tr>\n",
"      <th>索引</th><th></th>\n",
"    </tr>\n",
"    <tr style=\"color:blue;\">\n",
"      <td>0</td><td>It</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <td>1</td><td>was</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <td>2</td><td>the</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <td>3</td><td>best</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <td>4</td><td>of</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <td>5</td><td>times</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <td>6</td><td>\"eos\"</td>\n",
"    </tr>\n",
"  </table>\n",
"  </td>\n",
"  </tr>\n",
"</table>"
]
},
{
"cell_type": "markdown",
"id": "d91564d9-3d48-4857-99ad-90f1edc7d2af",
"metadata": {},
"source": [
"> - **第二步,输入ebd_X & yhat[:1] >> 输出时间步2的标签,对应真实标签y[1]**\n",
"\n",
"<table>\n",
"  <tr>\n",
"  <td>\n",
"  <p>输入Encoder<br>特征矩阵</p>\n",
"  <table style=\"color:red;\">\n",
"    <tr>\n",
"      <th>索引</th><th></th><th>y1</th><th>y2</th><th>y3</th><th>y4</th><th>y5</th>\n",
"    </tr>\n",
"    <tr>\n",
"      <td>0</td><td>这</td><td>0.1821</td><td>0.4000</td><td>0.2248</td><td>0.4440</td><td>0.7771</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <td>1</td><td>是</td><td>0.1721</td><td>0.5030</td><td>0.8948</td><td>0.2385</td><td>0.0987</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <td>2</td><td>最好的</td><td>0.1342</td><td>0.8297</td><td>0.2978</td><td>0.7120</td><td>0.2565</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <td>3</td><td>时代</td><td>0.1248</td><td>0.5003</td><td>0.7559</td><td>0.4804</td><td>0.2593</td>\n",
"    </tr>\n",
"  </table>\n",
"  </td>\n",
"  <td style=\"text-align: center;\">\n",
"  ➕\n",
"  </td>\n",
"  <td>\n",
"  <p>输入Decoder:yhat预测标签<br>(加入上一个时间步的预测结果)</p>\n",
"  <table style=\"color:red;\">\n",
"    <tr>\n",
"      <th>索引</th><th></th><th>y1</th><th>y2</th><th>y3</th><th>y4</th><th>y5</th>\n",
"    </tr>\n",
"    <tr>\n",
"      <td>0</td><td>\"sos\"</td><td>0.5651</td><td>0.2220</td><td>0.5112</td><td>0.8543</td><td>0.1239</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <td>1</td><td>yyy</td><td>0.5621</td><td>0.8920</td><td>0.7312</td><td>0.2543</td><td>0.1289</td>\n",
"    </tr>\n",
"  </table>\n",
"  </td>\n",
"  </tr>\n",
"</table>"
]
},
{
"cell_type": "markdown",
"id": "4f5804cc-87d0-4246-b9c2-836c1a04e15e",
"metadata": {},
"source": [
"<table>\n",
"  <tr>\n",
"  <td><p>预测出</p>➡\n",
"  </td>\n",
"  <td>\n",
"  <p>当前时间步的预测标签yhat</p>\n",
"  <table>\n",
"    <tr>\n",
"      <th>索引</th><th></th><th>y1</th><th>y2</th><th>y3</th><th>y4</th><th>y5</th>\n",
"    </tr>\n",
"    <tr>\n",
"      <td>2</td><td>yyy</td><td>0.3074</td><td>0.8774</td><td>0.0364</td><td>0.0649</td><td>0.4704</td>\n",
"    </tr>\n",
"  </table>\n",
"  </td>\n",
"  <td><p>对应</p>➡\n",
"  </td>\n",
"  <td>\n",
"  <p>真实标签y</p>\n",
"  <table>\n",
"    <tr>\n",
"      <th>索引</th><th></th>\n",
"    </tr>\n",
"    <tr>\n",
"      <td>0</td><td>It</td>\n",
"    </tr>\n",
"    <tr style=\"color:blue;\">\n",
"      <td>1</td><td>was</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <td>2</td><td>the</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <td>3</td><td>best</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <td>4</td><td>of</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <td>5</td><td>times</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <td>6</td><td>\"eos\"</td>\n",
"    </tr>\n",
"  </table>\n",
"  </td>\n",
"  </tr>\n",
"</table>"
]
},
{
"cell_type": "markdown",
"id": "0530701a-b7b3-4d12-8ccf-f8980299ea2d",
"metadata": {},
"source": [
"> - **第三步,输入ebd_X & yhat[:2] >> 输出时间步3的预测标签,对应真实标签y[2]**\n",
"\n",
"<table>\n",
"  <tr>\n",
"  <td>\n",
"  <p>输入Encoder<br>特征矩阵</p>\n",
"  <table style=\"color:red;\">\n",
"    <tr>\n",
"      <th>索引</th><th></th><th>y1</th><th>y2</th><th>y3</th><th>y4</th><th>y5</th>\n",
"    </tr>\n",
"    <tr>\n",
"      <td>0</td><td>这</td><td>0.1821</td><td>0.4000</td><td>0.2248</td><td>0.4440</td><td>0.7771</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <td>1</td><td>是</td><td>0.1721</td><td>0.5030</td><td>0.8948</td><td>0.2385</td><td>0.0987</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <td>2</td><td>最好的</td><td>0.1342</td><td>0.8297</td><td>0.2978</td><td>0.7120</td><td>0.2565</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <td>3</td><td>时代</td><td>0.1248</td><td>0.5003</td><td>0.7559</td><td>0.4804</td><td>0.2593</td>\n",
"    </tr>\n",
"  </table>\n",
"  </td>\n",
"  <td style=\"text-align: center;\">\n",
"  ➕\n",
"  </td>\n",
"  <td>\n",
"  <p>输入Decoder:yhat预测标签<br>(加入上一个时间步的预测结果)</p>\n",
"  <table style=\"color:red;\">\n",
"    <tr>\n",
"      <th>索引</th><th></th><th>y1</th><th>y2</th><th>y3</th><th>y4</th><th>y5</th>\n",
"    </tr>\n",
"    <tr>\n",
"      <td>0</td><td>\"sos\"</td><td>0.5651</td><td>0.2220</td><td>0.5112</td><td>0.8543</td><td>0.1239</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <td>1</td><td>yyy</td><td>0.5621</td><td>0.8920</td><td>0.7312</td><td>0.2543</td><td>0.1289</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <td>2</td><td>yyy</td><td>0.3074</td><td>0.8774</td><td>0.0364</td><td>0.0649</td><td>0.4704</td>\n",
"    </tr>\n",
"  </table>\n",
"  </td>\n",
"  </tr>\n",
"</table>"
]
},
{
"cell_type": "markdown",
"id": "443b583c-6e6d-456b-9c88-e96f7d34e8d6",
"metadata": {},
"source": [
"<table>\n",
"  <tr>\n",
"  <td><p>预测出</p>➡\n",
"  </td>\n",
"  <td>\n",
"  <p>当前时间步的预测标签yhat</p>\n",
"  <table>\n",
"    <tr>\n",
"      <th>索引</th><th></th><th>y1</th><th>y2</th><th>y3</th><th>y4</th><th>y5</th>\n",
"    </tr>\n",
"    <tr>\n",
"      <td>3</td><td>yyy</td><td>0.2753</td><td>0.2921</td><td>0.4599</td><td>0.6449</td><td>0.1852</td>\n",
"    </tr>\n",
"  </table>\n",
"  </td>\n",
"  <td><p>对应</p>➡\n",
"  </td>\n",
"  <td>\n",
"  <p>真实标签y</p>\n",
"  <table>\n",
"    <tr>\n",
"      <th>索引</th><th></th>\n",
"    </tr>\n",
"    <tr>\n",
"      <td>0</td><td>It</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <td>1</td><td>was</td>\n",
"    </tr>\n",
"    <tr style=\"color:blue;\">\n",
"      <td>2</td><td>the</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <td>3</td><td>best</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <td>4</td><td>of</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <td>5</td><td>times</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <td>6</td><td>\"eos\"</td>\n",
"    </tr>\n",
"  </table>\n",
"  </td>\n",
"  </tr>\n",
"</table>"
]
},
{
"cell_type": "markdown",
"id": "4eeb6672-907a-4a8a-9091-489e78a0076a",
"metadata": {},
"source": [
"很显然,这是一个自回归的流程。在实际代码实现时,这个过程是串行的,必须一个词、一个词地预测,但Transformer本身并不提供像RNN和LSTM那样逐时间步处理样本的结构,因此在推理流程中,我们需要编写循环代码来完成推理:每一步生成一个新词,并将其拼接到已生成的序列后作为下一步的输入,直到生成结束标记 \"EOS\" 或达到最大长度为止。这个流程会极大地限制生成类算法的预测速度,因此现在也有越来越多的技术来帮助我们改进这个环节,但使用最多的依然是最经典的自回归策略。"
]
},
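{
"cell_type": "markdown",
"id": "added-greedy-decode-md",
"metadata": {},
"source": [
"在继续之前,先用一段极简的示意代码感受上面描述的自回归推理循环(注意:这只是一个草图,其中的 model.encode / model.decode 接口以及 sos_id / eos_id 等名称均为本示例自行假设的,并非某个库的固定API):每一步把已生成的全部序列送入Decoder,用argmax贪心地取出下一个词,再拼接回输入,直到生成结束符或达到最大长度。"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "added-greedy-decode-code",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"@torch.no_grad()\n",
"def greedy_decode(model, src, sos_id, eos_id, max_len=50):\n",
"    # 示意代码:model.encode / model.decode 为假设的接口\n",
"    # src: (1, src_len) 的源序列token id\n",
"    memory = model.encode(src)                       # Encoder只需运行一次\n",
"    ys = torch.tensor([[sos_id]], dtype=torch.long)  # Decoder输入从sos开始\n",
"    for _ in range(max_len):\n",
"        logits = model.decode(ys, memory)            # (1, 当前长度, vocab_size)\n",
"        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)  # 贪心:取最后一个时间步概率最大的词\n",
"        ys = torch.cat([ys, next_token], dim=1)      # 把预测结果拼接回Decoder输入\n",
"        if next_token.item() == eos_id:              # 生成结束符则停止\n",
"            break\n",
"    return ys"
]
},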
{
"cell_type": "markdown",
"id": "a9217f36-90e3-4272-a78e-63f5a8c293fd",
"metadata": {},
"source": [
"现在你已经知道在seq2seq任务中Transformer处理训练数据的流程了——\n",
"\n",
"> 原始序列X = [\"这\",\"是\",\"最好的\",\"时代\"]<br><br>\n",
"> 真实标签y = [\"it\", \"was\", \"the\", \"best\", \"of\", \"times\"]<br><br>\n",
"> 解码器输出的预测结果为yhat,添加过初始词/结束词、经过embedding的矩阵分别为ebd_X和ebd_y\n",
"\n",
"那我们实际走的<font color=\"red\">**训练流程**</font>是:\n",
"\n",
"> - **第一步,输入ebd_X & ebd_y[0] >> 输出yhat[0],对应真实标签y[0]**<br><br>\n",
"> - **第二步,输入ebd_X & ebd_y[:1] >> 输出yhat[1],对应真实标签y[1]**<br><br>\n",
"> - **第三步,输入ebd_X & ebd_y[:2] >> 输出yhat[2],对应真实标签y[2]**<br><br>\n",
"> 以此类推……\n",
"\n",
"在讲解这个过程时我们曾经提到,**在训练过程中,这个流程通过掩码自注意力机制+编码器-解码器注意力层合作的方式实现了并行**,所以Seq2Seq任务在训练时实际上并不是按照时间步顺序来运行,反而呈现为一次性输入特征矩阵+标签矩阵后,一次性获得整个预测的序列。\n",
"\n",
"实际流程是并行的,就意味着我们需要将完整的标签矩阵ebd_y一次性输入给Transformer,在这里就会存在两个问题:\n",
"\n",
"1. 并行是如何实现的?\n",
"2. 将完整的标签矩阵输入给Transformer,是如何避免标签泄漏的?\n",
"\n",
"Decoder结构中的掩码注意力机制与编码器-解码器注意力层共同解决了这两个问题。接下来让我们一起来看看带掩码的注意力机制。"
]
},
{
"cell_type": "markdown",
"id": "6a065205-31a5-4091-a21d-5ccee92bb646",
"metadata": {},
"source": [
"#### 2.3.2.2 掩码注意力机制"
]
},
{
"cell_type": "markdown",
"id": "0fc1dbd4-96d6-4bd4-b2e6-3b59db7365d1",
"metadata": {},
"source": [
"在Transformer的Decoder中,掩码自注意力(Masked Self-Attention)确保在生成当前时间步的输出时,模型不能查看未来的输入。这是通过在注意力机制计算过程中应用一个掩码实现的,该掩码有效地将未来位置的注意力得分设置为非常低的值(通常是负无穷),这样模型就无法在预测当前词时利用未来的信息。这种方法确保了生成的输出是自回归的,即每个输出仅依赖于之前的输出,而不是未来的输入。"
]
},
{
"cell_type": "markdown",
"id": "76a70df6-fb90-4d23-873c-7149d6f2120d",
"metadata": {},
"source": [
"掩码自注意力机制是**通过修改基本的注意力机制公式**来实现的,基本的注意力公式如下:\n",
"\n",
"$$Attention(Q,K,V) = softmax(\\frac{QK^{T}}{\\sqrt{d_k}})V$$\n",
"\n",
"在这个公式的基础上引入掩码功能,则涉及到下面三个改变:\n",
"\n",
"1. 在计算 $QK^T$ 的点积后,但在应用softmax函数之前,**掩码自注意力机制通过使用一个掩码矩阵来修改这个点积结果**。这个掩码矩阵有特定的结构:对于不应该被当前位置注意的所有位置(即未来的位置),掩码会赋予一个非常大的负值(如负无穷)。\n",
"\n",
"2. 应用softmax函数:**当softmax函数应用于经过掩码处理的点积矩阵时,那些被掩码覆盖的位置(即未来的位置)的权重实际上会接近于零**。这是因为 e 的非常大的负数次幂几乎为零。\n",
"\n",
"3. 结果的动态调整:这样处理后,每个位置的输出在计算时只会考虑到它前面的位置或当前位置的信息,确保了生成的每一步都不会“看到”未来的数据。"
]
},
{
"cell_type": "markdown",
"id": "fdb4c775-b1a7-4c0a-8f30-b4963c562b7f",
"metadata": {},
"source": [
"<center><img src=\"https://skojiangdoc.oss-cn-beijing.aliyuncs.com/2023DL/transformer/image-1.png\" alt=\"描述文字\" width=\"400\">"
]
},
{
"cell_type": "markdown",
"id": "b556be60-89fc-4e55-a22d-e0f3fd0517c3",
"metadata": {},
"source": [
"我们可以来看具体的矩阵——\n",
"\n",
"- **没有掩码时的$QK^T$点积**(此时的Q、K都是从输出矩阵中生成的)<br><br>\n",
"$$\n",
"QK^T = \\begin{bmatrix}\n",
"    q_1 \\cdot k_1^T & q_1 \\cdot k_2^T & \\cdots & q_1 \\cdot k_n^T \\\\\n",
"    q_2 \\cdot k_1^T & q_2 \\cdot k_2^T & \\cdots & q_2 \\cdot k_n^T \\\\\n",
"    \\vdots & \\vdots & \\ddots & \\vdots \\\\\n",
"    q_n \\cdot k_1^T & q_n \\cdot k_2^T & \\cdots & q_n \\cdot k_n^T\n",
"    \\end{bmatrix}\n",
"$$"
]
},
{
"cell_type": "markdown",
"id": "4d5aa2ef-6533-4fec-8b8f-b1b8462c8033",
"metadata": {},
"source": [
"- **没有掩码时softmax函数结果**<br><br>\n",
"$$\n",
"softmax(QK^T) = \\begin{bmatrix}\n",
"    \\frac{e^{q_1 \\cdot k_1^T}}{\\sum_{j=1}^n e^{q_1 \\cdot k_j^T}} & \\frac{e^{q_1 \\cdot k_2^T}}{\\sum_{j=1}^n e^{q_1 \\cdot k_j^T}} & \\cdots & \\frac{e^{q_1 \\cdot k_n^T}}{\\sum_{j=1}^n e^{q_1 \\cdot k_j^T}} \\\\\n",
"    \\frac{e^{q_2 \\cdot k_1^T}}{\\sum_{j=1}^n e^{q_2 \\cdot k_j^T}} & \\frac{e^{q_2 \\cdot k_2^T}}{\\sum_{j=1}^n e^{q_2 \\cdot k_j^T}} & \\cdots & \\frac{e^{q_2 \\cdot k_n^T}}{\\sum_{j=1}^n e^{q_2 \\cdot k_j^T}} \\\\\n",
"    \\vdots & \\vdots & \\ddots & \\vdots \\\\\n",
"    \\frac{e^{q_n \\cdot k_1^T}}{\\sum_{j=1}^n e^{q_n \\cdot k_j^T}} & \\frac{e^{q_n \\cdot k_2^T}}{\\sum_{j=1}^n e^{q_n \\cdot k_j^T}} & \\cdots & \\frac{e^{q_n \\cdot k_n^T}}{\\sum_{j=1}^n e^{q_n \\cdot k_j^T}}\n",
"    \\end{bmatrix}\n",
"$$"
]
},
{
"cell_type": "markdown",
"id": "adcac501-c476-4a9d-b13a-f1d3d7942615",
"metadata": {},
"source": [
"- **有掩码时,我们使用的掩码矩阵**<br><br>\n",
"$$\n",
"M = \\begin{bmatrix}\n",
"    0 & -\\infty & -\\infty & \\cdots & -\\infty \\\\\n",
"    0 & 0 & -\\infty & \\cdots & -\\infty \\\\\n",
"    0 & 0 & 0 & \\cdots & -\\infty \\\\\n",
"    \\vdots & \\vdots & \\vdots & \\ddots & \\vdots \\\\\n",
"    0 & 0 & 0 & \\cdots & 0\n",
"    \\end{bmatrix}\n",
"    $$\n",
"\n",
"正如你所观察到的,这是一个上三角部分(不含对角线)全部是负无穷、其余部分全部是0的矩阵。**在进行掩码时,我们用掩码矩阵与原始$QK^T$点积进行加和**,然后再将加和结果放入softmax函数。"
]
},
{
"cell_type": "markdown",
"id": "c5ed5f6b-d401-421b-b3b3-9eb52afb0b28",
"metadata": {},
"source": [
"- **有掩码时,掩码矩阵对原始$QK^T$矩阵的影响**\n",
"\n",
"$$\n",
"QK^T + M = \\begin{bmatrix}\n",
"    q_1 \\cdot k_1^T + 0 & q_1 \\cdot k_2^T - \\infty & \\cdots & q_1 \\cdot k_n^T - \\infty \\\\\n",
"    q_2 \\cdot k_1^T + 0 & q_2 \\cdot k_2^T + 0 & \\cdots & q_2 \\cdot k_n^T - \\infty \\\\\n",
"    \\vdots & \\vdots & \\ddots & \\vdots \\\\\n",
"    q_n \\cdot k_1^T + 0 & q_n \\cdot k_2^T + 0 & \\cdots & q_n \\cdot k_n^T + 0 \\end{bmatrix}\n",
"$$"
]
},
{
"cell_type": "markdown",
"id": "887df9e2-98ae-4884-b6ec-e533f79dda21",
"metadata": {},
"source": [
"$$= \\begin{bmatrix}\n",
"    q_1 \\cdot k_1^T & -\\infty & -\\infty & \\cdots & -\\infty \\\\\n",
"    q_2 \\cdot k_1^T & q_2 \\cdot k_2^T & -\\infty & \\cdots & -\\infty \\\\\n",
"    \\vdots & \\vdots & \\ddots & \\vdots & -\\infty \\\\\n",
"    q_n \\cdot k_1^T & q_n \\cdot k_2^T & \\cdots & q_n \\cdot k_{n-1}^T & q_n \\cdot k_n^T\n",
"\\end{bmatrix}\n",
"$$"
]
},
{
"cell_type": "markdown",
"id": "590ddd89-b7d9-4ecc-b462-a0c095732bf4",
"metadata": {},
"source": [
"经过掩码处理过的$QK^T$矩阵的右上角全部呈现为负无穷,左下角呈现为具体的值,在这种情况下应用softmax函数后,会得到:"
]
},
{
"cell_type": "markdown",
"id": "6e7a1d2d-785c-40db-b365-72ef595a3988",
"metadata": {},
"source": [
"- **有掩码时,softmax函数应用后的影响**\n",
"\n",
"$$\n",
"\\text{softmax}(QK^T + M) = \\begin{bmatrix}\n",
"    \\frac{e^{q_1 \\cdot k_1^T}}{e^{q_1 \\cdot k_1^T}} & 0 & 0 & 0 \\\\\n",
"    \\frac{e^{q_2 \\cdot k_1^T}}{e^{q_2 \\cdot k_1^T} + e^{q_2 \\cdot k_2^T}} & \\frac{e^{q_2 \\cdot k_2^T}}{e^{q_2 \\cdot k_1^T} + e^{q_2 \\cdot k_2^T}} & 0 & 0 \\\\\n",
"    \\frac{e^{q_3 \\cdot k_1^T}}{e^{q_3 \\cdot k_1^T} + e^{q_3 \\cdot k_2^T} + e^{q_3 \\cdot k_3^T}} & \\frac{e^{q_3 \\cdot k_2^T}}{e^{q_3 \\cdot k_1^T} + e^{q_3 \\cdot k_2^T} + e^{q_3 \\cdot k_3^T}} & \\frac{e^{q_3 \\cdot k_3^T}}{e^{q_3 \\cdot k_1^T} + e^{q_3 \\cdot k_2^T} + e^{q_3 \\cdot k_3^T}} & 0 \\\\\n",
"    \\frac{e^{q_4 \\cdot k_1^T}}{\\sum_{j=1}^{4} e^{q_4 \\cdot k_j^T}} & \\frac{e^{q_4 \\cdot k_2^T}}{\\sum_{j=1}^{4} e^{q_4 \\cdot k_j^T}} & \\frac{e^{q_4 \\cdot k_3^T}}{\\sum_{j=1}^{4} e^{q_4 \\cdot k_j^T}} & \\frac{e^{q_4 \\cdot k_4^T}}{\\sum_{j=1}^{4} e^{q_4 \\cdot k_j^T}}\n",
"    \\end{bmatrix}\n",
"$$"
]
},
{
"cell_type": "markdown",
"id": "8c3e6a2e-eaa1-4d17-b652-a2006bc47e8f",
"metadata": {},
"source": [
"从softmax函数的具体公式来看,当输入值$z$高度接近负无穷时,以e为底的指数函数的取值会无限趋近于0,因此才会得到一个上三角全为0的矩阵。通过这种方式,可以让原始矩阵中的一部分信息被“掩盖”(变为0),这个操作就是掩码的本质。\n",
"\n",
"$$\\sigma(z)_i = \\frac{e^{z_i}}{\\sum_{j=1}^{K} e^{z_j}}\n",
"$$"
]
},
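{
"cell_type": "markdown",
"id": "added-masked-softmax-md",
"metadata": {},
"source": [
"为了直观验证上面的推导,下面用一小段PyTorch代码(分数矩阵用随机数模拟,仅作演示草图)构造这样的掩码矩阵并观察softmax的结果:右上三角的权重全部变为0,而每一行的权重之和仍然为1。"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "added-masked-softmax-code",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"seq_len = 4\n",
"scores = torch.randn(seq_len, seq_len)  # 模拟 QK^T / sqrt(d_k) 的注意力分数\n",
"M = torch.triu(torch.full((seq_len, seq_len), float('-inf')), diagonal=1)  # 严格上三角为-inf,其余为0\n",
"weights = torch.softmax(scores + M, dim=-1)  # 先加掩码,再softmax\n",
"print(weights)              # 右上三角全为0\n",
"print(weights.sum(dim=-1))  # 每一行之和仍为1"
]
},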
{
"cell_type": "markdown",
"id": "5483bdbe-62a3-49ab-a3ea-1542dec43ac8",
"metadata": {},
"source": [
"在Transformer模型中,特别是在解码器的掩码自注意力机制中,矩阵$QK^T + M$是一切的关键。这里,掩码矩阵M的作用是确保在生成序列的每个步骤中,模型只能访问到当前和之前的信息,不能“看到”未来的信息。\n",
"\n",
"**为什么QK.T矩阵的右上角代表模型在观察未来的信息呢**?回到最初的QK相乘的图像上,假设现在Q是4行3列、K.T是3行4列,不难发现QK.T矩阵的16个元素分别是这样构成的 ↓\n",
"\n",
"$$\n",
"QK^T = \\begin{bmatrix}\n",
"    q_\\boldsymbol{\\color{green}{1}} \\cdot k_\\boldsymbol{\\color{green}{1}}^T & q_\\boldsymbol{\\color{green}{1}} \\cdot k_\\boldsymbol{\\color{red}{2}}^T & q_\\boldsymbol{\\color{green}{1}} \\cdot k_\\boldsymbol{\\color{red}{3}}^T & q_\\boldsymbol{\\color{green}{1}} \\cdot k_\\boldsymbol{\\color{red}{4}}^T \\\\\n",
"    q_\\boldsymbol{\\color{green}{2}} \\cdot k_\\boldsymbol{\\color{green}{1}}^T & q_\\boldsymbol{\\color{green}{2}} \\cdot k_\\boldsymbol{\\color{green}{2}}^T & q_\\boldsymbol{\\color{green}{2}} \\cdot k_\\boldsymbol{\\color{red}{3}}^T & q_\\boldsymbol{\\color{green}{2}} \\cdot k_\\boldsymbol{\\color{red}{4}}^T \\\\\n",
"    q_\\boldsymbol{\\color{green}{3}} \\cdot k_\\boldsymbol{\\color{green}{1}}^T & q_\\boldsymbol{\\color{green}{3}} \\cdot k_\\boldsymbol{\\color{green}{2}}^T & q_\\boldsymbol{\\color{green}{3}} \\cdot k_\\boldsymbol{\\color{green}{3}}^T & q_\\boldsymbol{\\color{green}{3}} \\cdot k_\\boldsymbol{\\color{red}{4}}^T \\\\\n",
"    q_\\boldsymbol{\\color{green}{4}} \\cdot k_\\boldsymbol{\\color{green}{1}}^T & q_\\boldsymbol{\\color{green}{4}} \\cdot k_\\boldsymbol{\\color{green}{2}}^T & q_\\boldsymbol{\\color{green}{4}} \\cdot k_\\boldsymbol{\\color{green}{3}}^T & q_\\boldsymbol{\\color{green}{4}} \\cdot k_\\boldsymbol{\\color{green}{4}}^T\n",
"    \\end{bmatrix}\n",
"$$\n",
"\n",
"使用更简化的写法,你会发现脚标是这样构成的:\n",
"\n",
"$$\n",
"QK^T = \\begin{bmatrix}\n",
"    \\boldsymbol{\\color{green}{1}}\\cdot\\boldsymbol{\\color{green}{1}} & \\boldsymbol{\\color{green}{1}} \\cdot\\boldsymbol{\\color{red}{2}} & \\boldsymbol{\\color{green}{1}} \\cdot\\boldsymbol{\\color{red}{3}} & \\boldsymbol{\\color{green}{1}} \\cdot\\boldsymbol{\\color{red}{4}} \\\\\n",
"    \\boldsymbol{\\color{green}{2}} \\cdot \\boldsymbol{\\color{green}{1}} & \\boldsymbol{\\color{green}{2}} \\cdot \\boldsymbol{\\color{green}{2}} & \\boldsymbol{\\color{green}{2}} \\cdot \\boldsymbol{\\color{red}{3}} & \\boldsymbol{\\color{green}{2}} \\cdot \\boldsymbol{\\color{red}{4}} \\\\\n",
"    \\boldsymbol{\\color{green}{3}} \\cdot \\boldsymbol{\\color{green}{1}} & \\boldsymbol{\\color{green}{3}} \\cdot \\boldsymbol{\\color{green}{2}} & \\boldsymbol{\\color{green}{3}} \\cdot \\boldsymbol{\\color{green}{3}} & \\boldsymbol{\\color{green}{3}} \\cdot \\boldsymbol{\\color{red}{4}} \\\\\n",
"    \\boldsymbol{\\color{green}{4}} \\cdot \\boldsymbol{\\color{green}{1}} & \\boldsymbol{\\color{green}{4}} \\cdot \\boldsymbol{\\color{green}{2}} & \\boldsymbol{\\color{green}{4}} \\cdot \\boldsymbol{\\color{green}{3}} & \\boldsymbol{\\color{green}{4}} \\cdot \\boldsymbol{\\color{green}{4}}\n",
"    \\end{bmatrix}\n",
"$$"
]
},
{
"cell_type": "markdown",
"id": "78d1d6c8-4922-4dc1-9b1d-9042149c14d3",
"metadata": {},
"source": [
"你发现什么了?QK都是由单词经过embedding后编码的矩阵,因此Q从上至下的顺序就是“从过去到未来、按句子阅读顺序”排列的顺序,而K作为转置矩阵,K从左到右的顺序就是“从过去到未来、按句子阅读顺序”排列的顺序。当我们使用信息Q去询问信息K时,就有——\n",
"\n",
"1. Q的脚标 = K的脚标,则Q在询问和自己在同一位置/同一时间点的信息\n",
"2. Q的脚标 > K的脚标,则Q在询问在句子前方的/过去的时间点的信息\n",
"3. Q的脚标 < K的脚标,则Q在询问在句子后方的/未来时间点的信息\n",
"\n",
"很显然,Q的脚标 < K的脚标的情况都集中在$QK^T$矩阵的右上角。因此,我们为右上角加上负无穷,并在softmax函数后将该部分信息化为0,就可以避免“未来的信息”泄漏给Transformer算法。"
]
},
{
"cell_type": "markdown",
"id": "e995a257-2647-42cb-91f8-62b864adb064",
"metadata": {},
"source": [
"到这里,你就明白−∞的引入、掩码矩阵的引入所具有的意义了:\n",
"\n",
"- **阻止信息泄露**:在解码过程中,为了保持输出的自回归性质(即每个输出仅依赖于先前的输出),模型不能提前访问未来位置的信息。在$QK^T$矩阵中添加负无穷正是为了这一点,将负无穷加到某些位置上,是为了在计算注意力权重时,这些位置的影响被完全忽略。\n",
"- **影响softmax函数**:在自注意力机制中,注意力权重是通过对$QK^T$应用softmax函数计算得出的。当softmax函数作用于包含负无穷的值时,这些位置的指数值会趋于零,导致它们在计算最终的注意力权重时的贡献也趋于零。因此,这些未来的位置不会对当前或之前的输出产生影响。\n",
"- **保持生成顺序性**:通过这种方式,Transformer能够按顺序逐个生成输出序列中的元素,每个元素的生成只依赖于之前的元素,从而有效地模拟序列生成任务中的时间顺序性和因果关系。\n",
"\n",
"简而言之,将矩阵$QK^T + M$中的上半部分变成负无穷实际上是一种控制措施,用于保证解码器在处理如机器翻译或文本生成等任务时,不会由于未来信息的干扰而产生错误或不自然的输出。这是确保模型生成行为的正确性和效率的关键技术手段。"
]
},
{
"cell_type": "markdown",
"id": "c0f70fc1-3600-41db-9f24-f6b80901737c",
"metadata": {},
"source": [
"> **掩码后的注意力机制的输出结果**"
]
},
{
"cell_type": "markdown",
"id": "cd60c256-8fe3-4fd3-8b56-a6be15bb7019",
"metadata": {},
"source": [
"- **Decoder中,多头注意力机制输出的softmax结果**(这部分信息来自于真实标签y)\n",
"\n",
"$$\n",
"\\text{softmax}(QK^T + M) = \\begin{bmatrix}\n",
"    \\frac{e^{q_1 \\cdot k_1^T}}{e^{q_1 \\cdot k_1^T}} & 0 & 0 & 0 \\\\\n",
"    \\frac{e^{q_2 \\cdot k_1^T}}{e^{q_2 \\cdot k_1^T} + e^{q_2 \\cdot k_2^T}} & \\frac{e^{q_2 \\cdot k_2^T}}{e^{q_2 \\cdot k_1^T} + e^{q_2 \\cdot k_2^T}} & 0 & 0 \\\\\n",
"    \\frac{e^{q_3 \\cdot k_1^T}}{e^{q_3 \\cdot k_1^T} + e^{q_3 \\cdot k_2^T} + e^{q_3 \\cdot k_3^T}} & \\frac{e^{q_3 \\cdot k_2^T}}{e^{q_3 \\cdot k_1^T} + e^{q_3 \\cdot k_2^T} + e^{q_3 \\cdot k_3^T}} & \\frac{e^{q_3 \\cdot k_3^T}}{e^{q_3 \\cdot k_1^T} + e^{q_3 \\cdot k_2^T} + e^{q_3 \\cdot k_3^T}} & 0 \\\\\n",
"    \\frac{e^{q_4 \\cdot k_1^T}}{\\sum_{j=1}^{4} e^{q_4 \\cdot k_j^T}} & \\frac{e^{q_4 \\cdot k_2^T}}{\\sum_{j=1}^{4} e^{q_4 \\cdot k_j^T}} & \\frac{e^{q_4 \\cdot k_3^T}}{\\sum_{j=1}^{4} e^{q_4 \\cdot k_j^T}} & \\frac{e^{q_4 \\cdot k_4^T}}{\\sum_{j=1}^{4} e^{q_4 \\cdot k_j^T}}\n",
"    \\end{bmatrix}\n",
"$$\n",
"\n",
"这个矩阵乘以$V$后,各位置所携带信息的时间范围依然不会改变,因此我们可以沿用这些脚标来标注整个多头注意力机制输出的结果,使用数字简化则有——\n",
"\n",
"$$\n",
"\\text{Decoder softmax} = \\begin{bmatrix}\n",
"    \\boldsymbol{\\color{green}{1}}\\cdot\\boldsymbol{\\color{green}{1}} & \\boldsymbol{\\color{green}{1}} \\cdot\\boldsymbol{\\color{red}{2}} & \\boldsymbol{\\color{green}{1}} \\cdot\\boldsymbol{\\color{red}{3}} & \\boldsymbol{\\color{green}{1}} \\cdot\\boldsymbol{\\color{red}{4}} \\\\\n",
"    \\boldsymbol{\\color{green}{2}} \\cdot \\boldsymbol{\\color{green}{1}} & \\boldsymbol{\\color{green}{2}} \\cdot \\boldsymbol{\\color{green}{2}} & \\boldsymbol{\\color{green}{2}} \\cdot \\boldsymbol{\\color{red}{3}} & \\boldsymbol{\\color{green}{2}} \\cdot \\boldsymbol{\\color{red}{4}} \\\\\n",
"    \\boldsymbol{\\color{green}{3}} \\cdot \\boldsymbol{\\color{green}{1}} & \\boldsymbol{\\color{green}{3}} \\cdot \\boldsymbol{\\color{green}{2}} & \\boldsymbol{\\color{green}{3}} \\cdot \\boldsymbol{\\color{green}{3}} & \\boldsymbol{\\color{green}{3}} \\cdot \\boldsymbol{\\color{red}{4}} \\\\\n",
"    \\boldsymbol{\\color{green}{4}} \\cdot \\boldsymbol{\\color{green}{1}} & \\boldsymbol{\\color{green}{4}} \\cdot \\boldsymbol{\\color{green}{2}} & \\boldsymbol{\\color{green}{4}} \\cdot \\boldsymbol{\\color{green}{3}} & \\boldsymbol{\\color{green}{4}} \\cdot \\boldsymbol{\\color{green}{4}}\n",
"    \\end{bmatrix}\n",
"$$\n",
"\n",
"经过掩码之后,实际上是——\n",
"\n",
"$$\n",
"\\text{Decoder masked softmax} = \\begin{bmatrix}\n",
"    \\boldsymbol{\\color{green}{1}}\\cdot\\boldsymbol{\\color{green}{1}} & 0 & 0 & 0\\\\\n",
"    \\boldsymbol{\\color{green}{2}} \\cdot \\boldsymbol{\\color{green}{1}} & \\boldsymbol{\\color{green}{2}} \\cdot \\boldsymbol{\\color{green}{2}} & 0 & 0\\\\\n",
"    \\boldsymbol{\\color{green}{3}} \\cdot \\boldsymbol{\\color{green}{1}} & \\boldsymbol{\\color{green}{3}} \\cdot \\boldsymbol{\\color{green}{2}} & \\boldsymbol{\\color{green}{3}} \\cdot \\boldsymbol{\\color{green}{3}} & 0\\\\\n",
"    \\boldsymbol{\\color{green}{4}} \\cdot \\boldsymbol{\\color{green}{1}} & \\boldsymbol{\\color{green}{4}} \\cdot \\boldsymbol{\\color{green}{2}} & \\boldsymbol{\\color{green}{4}} \\cdot \\boldsymbol{\\color{green}{3}} & \\boldsymbol{\\color{green}{4}} \\cdot \\boldsymbol{\\color{green}{4}}\n",
"    \\end{bmatrix}\n",
"$$\n",
"\n",
"转换成注意力得分,则有——\n",
"\n",
"$$\n",
"\\text{Decoder-Masked-Attention} = \\begin{bmatrix}\n",
"a_{11} & 0 & 0 & 0 \\\\\n",
"a_{21} & a_{22} & 0 & 0 \\\\\n",
"a_{31} & a_{32} & a_{33} & 0 \\\\\n",
"a_{41} & a_{42} & a_{43} & a_{44}\n",
"\\end{bmatrix}\n",
"$$"
]
},
{
"cell_type": "markdown",
"id": "584d9d95-e023-44db-a856-b41760d9964e",
"metadata": {},
"source": [
"假设 $V$ 矩阵如下,由于矩阵V是从原始标签y生成的embedding矩阵,因此矩阵V的序列方向是从上到下。\n",
"$$\n",
"V = \\begin{bmatrix}\n",
"v_{1} & v_{1} & \\ldots & v_{1} \\\\\n",
"v_{2} & v_{2} & \\ldots & v_{2} \\\\\n",
"v_{3} & v_{3} & \\ldots & v_{3} \\\\\n",
"v_{4} & v_{4} & \\ldots & v_{4}\n",
"\\end{bmatrix}\n",
"$$"
]
},
{
"cell_type": "markdown",
"id": "d72749be-0273-4c5d-a368-038b43c42382",
"metadata": {},
"source": [
"<font color=\"red\">**特别注意!在这里为了避免脚标产生混淆,没有写特征维度脚标。此时我们所有的脚标都只代表了时间点,特征维度脚标被省略了!事实上真正的V矩阵应该是——**\n",
"\n",
"$$\n",
"V = \\begin{bmatrix}\n",
"v_{1}^1 & v_{1}^2 & \\ldots & v_{1}^d \\\\\n",
"v_{2}^1 & v_{2}^2 & \\ldots & v_{2}^d \\\\\n",
"v_{3}^1 & v_{3}^2 & \\ldots & v_{3}^d \\\\\n",
"v_{4}^1 & v_{4}^2 & \\ldots & v_{4}^d\n",
"\\end{bmatrix}\n",
"$$\n",
"\n",
"<font color=\"red\">**在我们此时的讨论流程中,特征维度脚标只有标识作用,与整体过程理解无关,因此在这里出于教学目的将其省略。但事实上它应该是存在的。**</font>"
]
},
{
"cell_type": "markdown",
"id": "b76c6110-ecb1-4229-aa8d-5637e880a4dd",
"metadata": {},
"source": [
"将$\\text{Decoder-Masked-Attention}$ 矩阵与 $V$ 矩阵相乘,得到结果矩阵 $C$,就是带掩码的多头注意力机制的结果——\n",
"\n",
"$$\n",
"C = \\text{Decoder-Masked-Attention} \\times V\n",
"$$\n",
"\n",
"$$\n",
"C = \\begin{bmatrix}\n",
"a_{11} & 0 & 0 & 0 \\\\\n",
"a_{21} & a_{22} & 0 & 0 \\\\\n",
"a_{31} & a_{32} & a_{33} & 0 \\\\\n",
"a_{41} & a_{42} & a_{43} & a_{44}\n",
"\\end{bmatrix}\n",
"\\begin{bmatrix}\n",
"v_{1} & v_{1} & \\ldots & v_{1} \\\\\n",
"v_{2} & v_{2} & \\ldots & v_{2} \\\\\n",
"v_{3} & v_{3} & \\ldots & v_{3} \\\\\n",
"v_{4} & v_{4} & \\ldots & v_{4}\n",
"\\end{bmatrix}\n",
"$$"
]
},
{
"cell_type": "markdown",
"id": "e7a9ebde-0c25-4b5e-9eab-65e3c4c96daf",
"metadata": {},
"source": [
"结果矩阵 $C$ 的第 $i$ 行元素 $c_{i}$ 的计算如下:\n",
"\n",
"$$\n",
"c_{i} = \\sum_{k} a_{ik} \\cdot v_{k}\n",
"$$\n",
"\n",
"具体计算为:\n",
"\n",
"$$\n",
"C = \\begin{bmatrix}\n",
"a_{11}v_{1} & a_{11}v_{1} & \\ldots & a_{11}v_{1} \\\\\n",
"a_{21}v_{1} + a_{22}v_{2} & a_{21}v_{1} + a_{22}v_{2} & \\ldots & a_{21}v_{1} + a_{22}v_{2} \\\\\n",
"a_{31}v_{1} + a_{32}v_{2} + a_{33}v_{3} & a_{31}v_{1} + a_{32}v_{2} + a_{33}v_{3} & \\ldots & a_{31}v_{1} + a_{32}v_{2} + a_{33}v_{3} \\\\\n",
"a_{41}v_{1} + a_{42}v_{2} + a_{43}v_{3} + a_{44}v_{4} & a_{41}v_{1} + a_{42}v_{2} + a_{43}v_{3} + a_{44}v_{4} & \\ldots & a_{41}v_{1} + a_{42}v_{2} + a_{43}v_{3} + a_{44}v_{4}\n",
"\\end{bmatrix}\n",
"$$"
]
},
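{
"cell_type": "markdown",
"id": "added-av-demo-md",
"metadata": {},
"source": [
"这一结论同样可以用一小段示意代码验证(A、V均为随机数构造,仅作演示):将带掩码的下三角注意力权重矩阵A与V相乘后,修改“未来”的v(这里是索引2、3对应的行)不会改变输出矩阵的前两行。"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "added-av-demo-code",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"seq_len, d = 4, 5\n",
"scores = torch.randn(seq_len, seq_len)\n",
"M = torch.triu(torch.full((seq_len, seq_len), float('-inf')), diagonal=1)\n",
"A = torch.softmax(scores + M, dim=-1)  # 带前瞻掩码的注意力权重(下三角)\n",
"V = torch.randn(seq_len, d)\n",
"\n",
"C1 = A @ V\n",
"V2 = V.clone()\n",
"V2[2:] = torch.randn(2, d)  # 只改动索引2、3位置的v(即“未来”的信息)\n",
"C2 = A @ V2\n",
"print(torch.allclose(C1[:2], C2[:2]))  # True:前两行输出不受未来v的影响"
]
},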
{
"cell_type": "markdown",
"id": "d460efdc-188f-4da3-bee7-f3ca50523eba",
"metadata": {},
"source": [
"观察这个矩阵,你发现了什么?在这个矩阵中,v上携带的信息的时间点不会超出分数a中携带的信息的时间点,权重和句子信息在交互时都只能与“过去”的信息交互,而不能与“未来”的信息交互。通过这种方式,你可以看到最终带掩码的注意力机制是如何实现未来信息的不泄露的。"
]
},
{
"cell_type": "markdown",
"id": "7042cf71-fcbf-4619-85b8-e8c434784f1e",
"metadata": {},
"source": [
"#### 2.3.2.3 普通掩码 vs 前瞻掩码"
]
},
{
"cell_type": "markdown",
"id": "d7176153-1ecd-4975-aca4-c79a59c78e43",
"metadata": {},
"source": [
"在NLP的世界中,掩码最被人熟知的作用就是掩盖未来的信息、避免序列中未来的信息被泄露给算法。然而掩码(Masking)是一种多功能的机制,其本质是“掩盖信息”,并不局限于掩盖未来的信息。在注意力机制中,掩盖未来信息、不允许Q向未来的K发问的掩码被叫做“前瞻掩码”(look-ahead masking),这里的“前瞻”正是代表了“未来”(对时间序列来说是未来的时间点、对文字序列来说是右侧的信息)。然而,掩码在Transformer中还有另一个巨大的作用,就是**掩盖噪音信息,避免噪音影响注意力机制的计算**。掩盖噪音的掩码是最普通的掩码之一,在NLP中它主要负责掩盖填充句子时产生的padding。"
]
},
{
"cell_type": "markdown",
"id": "80307ff1-42bf-4639-a0a6-46056a0e99c7",
"metadata": {},
"source": [
""
]
},
{
"cell_type": "markdown",
"id": "025e8bd2-e65a-4bc8-9b03-fba0e78ffc2d",
"metadata": {},
"source": [
"Transformer的输入数据结构为(batch_size, seq_len, input_dimensions),不同句子的seq_len必须保持一致,然而在现实中我们不太可能让每个句子的长度都一致,因此句子过长的部分我们就会截断句子、句子太短的部分我们就会使用填充。这些填充大部分都是0填充,这些0填充与其他token正常编码的结果计算之后,就会在注意力分数中留下许多的噪音值,因此在将这些信息输出之前,我们就会需要在QK.T矩阵上进行“填充掩码”,来帮助注意力机制减少噪音带来的影响。\n",
"\n",
"很显然,**前瞻掩码通常只是解码器专属的,但是填充掩码是解码器和编码器都可以使用的**。在编码器的多头注意力机制中,那个“可选的掩码”就是填充掩码机制。"
]
},
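{
"cell_type": "markdown",
"id": "added-pad-sequence-md",
"metadata": {},
"source": [
"下面用一个小例子示意“长短不一的句子如何补齐到统一长度”(其中的token id序列是随意编造的示例数据):PyTorch自带的 torch.nn.utils.rnn.pad_sequence 会用 padding_value 把短序列补齐,补齐之后就可以按“是否等于填充值”标记出填充位置。"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "added-pad-sequence-code",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch.nn.utils.rnn import pad_sequence\n",
"\n",
"# 三条长度不同的token id序列(示例数据,0保留给padding)\n",
"seqs = [torch.tensor([5, 2, 9]),\n",
"        torch.tensor([7, 1]),\n",
"        torch.tensor([4, 8, 3, 6, 2])]\n",
"\n",
"padded = pad_sequence(seqs, batch_first=True, padding_value=0)  # (3, 5)\n",
"print(padded)\n",
"mask = (padded == 0)  # True的位置是填充,后续可供注意力层忽略\n",
"print(mask)"
]
},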
{
"cell_type": "markdown",
"id": "89e4493e-e6c5-4d7e-87e9-553639bd9ce9",
"metadata": {},
"source": [
""
]
},
{
"cell_type": "markdown",
"id": "ffa551d3-633a-45fe-ac9e-62ac465c42a4",
"metadata": {},
"source": [
"PyTorch允许我们自行创建掩码矩阵M,并将其传入各个注意力相关的层,由层在内部完成 $QK^T + M$ 的掩码运算。"
]
},
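{
"cell_type": "markdown",
"id": "added-mha-mask-md",
"metadata": {},
"source": [
"下面给出把自定义掩码传入PyTorch内置注意力层的一个示意(张量尺寸与填充位置均为随意设定的示例):nn.MultiheadAttention 的 forward 同时支持 attn_mask(可用作前瞻掩码)与 key_padding_mask(填充掩码),传入后它会在内部完成相当于 $QK^T + M$ 的运算。"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "added-mha-mask-code",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"\n",
"embed_dim, num_heads, seq_len, batch_size = 8, 2, 10, 4\n",
"mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)\n",
"\n",
"x = torch.randn(batch_size, seq_len, embed_dim)\n",
"# 前瞻掩码:布尔矩阵,True表示禁止注意的位置(严格上三角)\n",
"attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)\n",
"# 填充掩码:True表示该位置是padding;这里假设第1个样本最后3个位置是填充\n",
"key_padding_mask = torch.zeros(batch_size, seq_len, dtype=torch.bool)\n",
"key_padding_mask[0, 7:] = True\n",
"\n",
"out, weights = mha(x, x, x, attn_mask=attn_mask, key_padding_mask=key_padding_mask)\n",
"print(out.shape)      # torch.Size([4, 10, 8])\n",
"print(weights.shape)  # torch.Size([4, 10, 10])"
]
},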
{
"cell_type": "code",
"execution_count": null,
"id": "e96b97a6",
"metadata": {},
"outputs": [],
"source": [
"# Q、K 的形状:\n",
"# (32, 10, 512)\n",
"# (32, 10, 512)\n",
"#\n",
"# QK.T:\n",
"# (32, 10, 512) * (32, 512, 10)\n",
"# 单个样本: (10, 512) * (512, 10) ==> (10, 10)\n",
"#\n",
"# 分头后: (32, 8, 10, 10)\n",
"# 即 (batch_size, num_heads, seq_len, seq_len)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "329f55cc",
"metadata": {},
"outputs": [],
"source": [
"# M ==> (batch_size, num_heads, seq_len, seq_len)"
]
},
{
"cell_type": "markdown",
"id": "57d23d5a-b5fc-4b7b-9a23-70bfe9ee663b",
"metadata": {},
"source": [
"- **填充掩码的实现函数**"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "bccd48e6-e101-4eda-b51f-8d42d0f469fa",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"# 创造一个示例数据\n",
"batch_size = 4\n",
"seq_len = 10\n",
"embedding_dim = 8\n",
"seq = torch.randint(0, 5, (batch_size, seq_len, embedding_dim)) # 随机生成一些数据\n",
"\n",
"#填充部分\n",
"pad_token = 0\n",
"seq[0, 7:, :] = pad_token # 设置填充值\n",
"seq[1, 9:, :] = pad_token # 设置填充值\n",
"seq[3, 5:, :] = pad_token # 设置填充值"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "94e601fc",
"metadata": {},
"outputs": [],
"source": [
"#(4,10,8)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "8cf81f21-9713-4ab2-8311-aa5f987164f8",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[0, 2, 1, 4, 0, 2, 0, 3],\n",
"        [3, 3, 4, 2, 4, 4, 1, 4],\n",
"        [0, 4, 3, 2, 4, 2, 4, 3],\n",
"        [1, 2, 2, 0, 0, 1, 1, 2],\n",
"        [2, 2, 1, 3, 2, 2, 4, 3],\n",
"        [2, 4, 4, 3, 2, 1, 1, 4],\n",
"        [3, 1, 4, 0, 1, 3, 4, 1],\n",
"        [0, 0, 0, 0, 0, 0, 0, 0],\n",
"        [0, 0, 0, 0, 0, 0, 0, 0],\n",
"        [0, 0, 0, 0, 0, 0, 0, 0]])"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"seq[0]"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "29fc1401-d43a-4abc-ac0e-471ef6af0268",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[4, 3, 0, 1, 1, 3, 2, 0],\n",
"        [0, 2, 2, 3, 1, 3, 2, 2],\n",
"        [3, 0, 2, 4, 0, 0, 3, 4],\n",
"        [3, 1, 0, 0, 3, 4, 1, 1],\n",
"        [4, 4, 4, 1, 3, 4, 2, 4],\n",
"        [4, 3, 4, 3, 3, 0, 1, 1],\n",
"        [2, 3, 0, 4, 4, 1, 0, 2],\n",
"        [2, 2, 1, 2, 3, 0, 4, 2],\n",
"        [3, 3, 2, 3, 3, 2, 4, 4],\n",
"        [0, 0, 0, 0, 0, 0, 0, 0]])"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"seq[1]"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "fa859651-79b4-4c16-b29b-9616b122cc8a",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[4, 3, 0, 3, 3, 1, 0, 1],\n",
"        [2, 3, 0, 2, 0, 1, 1, 2],\n",
"        [0, 3, 4, 1, 2, 0, 1, 0],\n",
"        [1, 4, 2, 0, 2, 0, 2, 4],\n",
"        [2, 0, 2, 3, 2, 3, 1, 4],\n",
"        [0, 1, 0, 0, 0, 2, 1, 0],\n",
"        [2, 2, 4, 2, 3, 1, 2, 3],\n",
"        [0, 3, 2, 0, 1, 3, 3, 2],\n",
"        [3, 0, 2, 2, 4, 3, 2, 4],\n",
"        [2, 1, 2, 3, 1, 0, 3, 3]])"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"seq[2]"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "9cf06e5b-1935-4257-9ff3-97f6f7d6038d",
"metadata": {},
"outputs": [],
"source": [
"#怎么针对这些填充的行进行掩码?"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "7dedc250-b45c-4059-ac0d-355d3fe6db7e",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[ True, False, False, False,  True, False,  True, False],\n",
"         [False, False, False, False, False, False, False, False],\n",
"         [ True, False, False, False, False, False, False, False],\n",
"         [False, False, False,  True,  True, False, False, False],\n",
"         [False, False, False, False, False, False, False, False],\n",
"         [False, False, False, False, False, False, False, False],\n",
"         [False, False, False,  True, False, False, False, False],\n",
"         [ True,  True,  True,  True,  True,  True,  True,  True],\n",
"         [ True,  True,  True,  True,  True,  True,  True,  True],\n",
"         [ True,  True,  True,  True,  True,  True,  True,  True]],\n",
"\n",
"        [[False, False,  True, False, False, False, False,  True],\n",
"         [ True, False, False, False, False, False, False, False],\n",
"         [False,  True, False, False,  True,  True, False, False],\n",
"         [False, False,  True,  True, False, False, False, False],\n",
"         [False, False, False, False, False, False, False, False],\n",
"         [False, False, False, False, False,  True, False, False],\n",
"         [False, False,  True, False, False, False,  True, False],\n",
"         [False, False, False, False, False,  True, False, False],\n",
"         [False, False, False, False, False, False, False, False],\n",
"         [ True,  True,  True,  True,  True,  True,  True,  True]],\n",
"\n",
"        [[False, False,  True, False, False, False,  True, False],\n",
"         [False, False,  True, False,  True, False, False, False],\n",
"         [ True, False, False, False, False,  True, False,  True],\n",
"         [False, False, False,  True, False,  True, False, False],\n",
"         [False,  True, False, False, False, False, False, False],\n",
"         [ True, False,  True,  True,  True, False, False,  True],\n",
"         [False, False, False, False, False, False, False, False],\n",
"         [ True, False, False,  True, False, False, False, False],\n",
"         [False,  True, False, False, False, False, False, False],\n",
"         [False, False, False, False, False,  True, False, False]],\n",
"\n",
"        [[False,  True, False, False, False,  True,  True, False],\n",
"         [False, False, False, False, False, False, False, False],\n",
"         [False, False,  True, False, False, False, False, False],\n",
"         [False, False, False,  True, False,  True, False, False],\n",
"         [False, False,  True,  True, False, False, False, False],\n",
"         [ True,  True,  True,  True,  True,  True,  True,  True],\n",
"         [ True,  True,  True,  True,  True,  True,  True,  True],\n",
"         [ True,  True,  True,  True,  True,  True,  True,  True],\n",
"         [ True,  True,  True,  True,  True,  True,  True,  True],\n",
"         [ True,  True,  True,  True,  True,  True,  True,  True]]])"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"(seq == pad_token) #逐个元素检查:具体的值是否等于填充值?"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "4c0ed979-9da0-44f7-a44e-73461b38ce2a",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[False, False, False, False, False, False, False,  True,  True,  True],\n",
"        [False, False, False, False, False, False, False, False, False,  True],\n",
"        [False, False, False, False, False, False, False, False, False, False],\n",
"        [False, False, False, False, False,  True,  True,  True,  True,  True]])"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"(seq == pad_token).all(dim=-1) #在最后一个维度上检查是否整行都等于填充值:只有整行都为0的行,才是真正需要掩码的填充行"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "0b4b711a-7a54-445d-a456-528007586321",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[0., 0., 0., 0., 0., 0., 0., 1., 1., 1.],\n",
"        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],\n",
"        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
"        [0., 0., 0., 0., 0., 1., 1., 1., 1., 1.]])"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"(seq == pad_token).all(dim=-1).float() #使用float,这就是标注出来需要掩码的行"
]
},
{
"cell_type": "markdown",
"id": "21c17563-be8d-404e-9086-d4b50013f158",
"metadata": {},
"source": [
"每一行是一个batch中的样本,每一列是该样本中的一个token,1标记出需要掩码的填充位置:\n",
"\n",
"|  | token1 | token2 | token3 | … | token10 |\n",
"| --- | --- | --- | --- | --- | --- |\n",
"| batch1 | 0 | 0 | 0 | … | 1 |\n",
"| batch2 | 0 | 0 | 1 | … | 1 |\n",
"| batch3 | 0 | 1 | 1 | … | 1 |\n",
"| batch4 | 0 | 1 | 1 | … | 1 |\n",
"| … |  |  |  |  |  |\n",
"\n",
"整体形状为 (batch_size, seq_len)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "ece8aa4f",
"metadata": {},
"outputs": [],
"source": [
"#4 x 10,即(batch_size, seq_len)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ea852ec4",
"metadata": {},
"outputs": [],
"source": [
"# X: (4, 10, 8)\n",
"# Q、K: (4, 10, 8)\n",
"#\n",
"# K.T: (4, 8, 10)\n",
"#\n",
"# QK.T: (4, 10, 8) x (4, 8, 10)\n",
"#\n",
"# 掩码升维后: (4, 1, 10, 10)\n",
"#\n",
"# 可扩展为: (4, num_heads, 10, 10)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "988e4b31",
"metadata": {},
"outputs": [],
"source": [
"# QK.T: (4, num_heads, 10, 10) #400\n",
"#\n",
"# M: (4, num_heads, 10, 10)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3a5f9345",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 19,
"id": "132f5208-c667-41c9-8ec0-22810db48919",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([4, 1, 10, 10])"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"#接下来,要将掩码的行转变为掩码矩阵\n",
"#QK.T矩阵的维度是(batch_size,num_heads,seq_len,seq_len)\n",
"#掩码矩阵为了要能够与QK.T矩阵相加,也必须是这个结构\n",
"#unsqueeze用于在特定位置增加维度,expand则用于复制&拓展维度\n",
"\n",
"(seq == pad_token).all(dim=-1).unsqueeze(1).unsqueeze(3).expand(-1, -1, -1, seq.size(1)).shape"
]
},
{
"cell_type": "markdown",
"id": "2f84798c-6241-4deb-8921-8a1b938f89b6",
"metadata": {},
"source": [
"<font color=\"red\">**录制视频时代码expand维度有误,导致掩码部分变成了列,正确的代码请以课件为准**</font>"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "beb6c2b2-e891-49b8-8e22-3a54db34969f",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
"          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
"          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
"          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
"          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
"          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
"          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
"          [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n",
"          [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n",
"          [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]]],\n",
"\n",
"\n",
"        [[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
"          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
"          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
"          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
"          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
"          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
"          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
"          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
"          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
"          [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]]],\n",
"\n",
"\n",
"        [[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
"          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
"          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
"          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
"          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
"          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
"          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
"          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
"          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
"          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]],\n",
"\n",
"\n",
"        [[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
"          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
"          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
"          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
"          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
"          [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n",
"          [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n",
"          [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n",
"          [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n",
"          [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]]]])"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"(seq == pad_token).all(dim=-1).unsqueeze(1).unsqueeze(3).expand(-1, -1, -1, seq.size(1)).float()"
]
},
{
"cell_type": "markdown",
"id": "b8e4cbe6-4435-450e-8bb0-b674f89478bc",
"metadata": {},
"source": [
"> - **reshape与expand的区别**"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "43aacc59-bfba-437b-a2a3-6e00246795eb",
"metadata": {},
"outputs": [],
"source": [
"s = torch.tensor([[0,1,2,3,4,5]])"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "bdbf6ab9-afd2-4e39-b188-96face8823ba",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[0, 1, 2],\n",
"        [3, 4, 5]])"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"s.reshape(-1,3) #将原始的序列拆成2行3列,重组"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "fe75e99e-4acd-4dac-ac5f-0f7f8d816975",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[0, 1, 2, 3, 4, 5],\n",
"         [0, 1, 2, 3, 4, 5],\n",
"         [0, 1, 2, 3, 4, 5]]])"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"s.unsqueeze(1).expand(-1, 3, -1) #升维之后,在升起的维度上将原始序列复制三遍,构成更高维的结果"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "f929b99d-3625-4344-80ec-e3f5210ab417",
"metadata": {},
"outputs": [],
"source": [
"def create_padding_mask(seq, pad_token=0):\n",
"    # seq: (batch_size, seq_len, embedding_dim)\n",
"    # 检查填充值位置\n",
"    padding_mask = (seq == pad_token).all(dim=-1) # (batch_size, seq_len)\n",
"    \n",
"    # 增加维度以匹配注意力权重矩阵的形状\n",
"    # (batch_size, num_heads, seq_len, seq_len)\n",
"    padding_mask = padding_mask.unsqueeze(1).unsqueeze(3).expand(-1, -1, -1, seq.size(1))\n",
"    \n",
"    # 将填充值部分设置为负无穷大,有效数据部分设置为0\n",
"    padding_mask = padding_mask.float() * -1e9 # (batch_size, num_heads, seq_len, seq_len)\n",
"    \n",
"    return padding_mask"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "1eb66347-89dc-41d5-bc17-e900c21b1385",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([4, 1, 10, 10])\n",
"tensor([[[[-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,\n",
"           -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00],\n",
"          [-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,\n",
"           -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00],\n",
"          [-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,\n",
"           -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00],\n",
"          [-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,\n",
"           -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00],\n",
"          [-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,\n",
"           -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00],\n",
"          [-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,\n",
"           -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00],\n",
"          [-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,\n",
"           -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00],\n",
"          [-1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09,\n",
"           -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09],\n",
"          [-1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09,\n",
"           -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09],\n",
"          [-1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09,\n",
"           -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09]]],\n",
"\n",
"\n",
"        [[[-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,\n",
"           -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00],\n",
"          [-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,\n",
"           -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00],\n",
"          [-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,\n",
"           -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00],\n",
"          [-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,\n",
"           -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00],\n",
"          [-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,\n",
"           -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00],\n",
"          [-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,\n",
"           -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00],\n",
"          [-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,\n",
"           -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00],\n",
"          [-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,\n",
"           -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00],\n",
"          [-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,\n",
"           -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00],\n",
"          [-1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09,\n",
"           -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09]]],\n",
"\n",
"\n",
"        [[[-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,\n",
"           -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00],\n",
"          [-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,\n",
"           -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00],\n",
"          [-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,\n",
"           -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00],\n",
"          [-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,\n",
"           -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00],\n",
"          [-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,\n",
"           -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00],\n",
"          [-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,\n",
"           -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00],\n",
"          [-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,\n",
"           -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00],\n",
"          [-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,\n",
"           -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00],\n",
"          [-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,\n",
"           -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00],\n",
"          [-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,\n",
"           -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00]]],\n",
"\n",
"\n",
"        [[[-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,\n",
"           -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00],\n",
"          [-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,\n",
"           -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00],\n",
"          [-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,\n",
"           -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00],\n",
"          [-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,\n",
"           -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00],\n",
"          [-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,\n",
"           -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00],\n",
"          [-1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09,\n",
"           -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09],\n",
"          [-1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09,\n",
"           -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09],\n",
"          [-1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09,\n",
"           -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09],\n",
"          [-1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09,\n",
"           -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09],\n",
"          [-1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09,\n",
"           -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09]]]])\n"
]
}
],
"source": [
"padding_mask = create_padding_mask(seq, pad_token)\n",
"print(padding_mask.shape)\n",
"print(padding_mask)"
]
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"id": "e5ff1abd-5309-47ce-b59f-98b21e07dc06",
|
|||
|
|
"metadata": {},
|
|||
|
|
"source": [
|
|||
|
|
"但注意!我们并不是随时随地都需要这个升维的过程。具体需要呈现怎样的掩码矩阵,需要根据掩码矩阵使用的时机、以及配合的库来考虑。如果是配合PyTorch中已经设置好的Transformer类来使用,则二维的掩码矩阵就足够了,Transformer类会自动执行将掩码矩阵升维的过程;如果是利用更底层的机制创建的Transformer,则会需要我们手动执行上述流程来匹配掩码的结构。在实际使用时,大家要根据实际情况选择是否主动对掩码矩阵进行升维。"
|
|||
|
|
]
|
|||
|
|
},
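{
"cell_type": "markdown",
"id": "ins-mask2d-lead-0001",
"metadata": {},
"source": [
"下面是一个最小示意(并非权威实现,批量大小、维度等参数均为演示用的假设值):`nn.Transformer`的`src_key_padding_mask`参数接受形状为(batch_size, src_len)的二维布尔矩阵,True表示该位置为填充、应被忽略,升维与广播由PyTorch内部完成——"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ins-mask2d-code-0001",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"\n",
"# 示意:二维布尔填充掩码直接交给PyTorch现成的Transformer类(参数为演示假设)\n",
"model = nn.Transformer(d_model=16, nhead=4, batch_first=True)\n",
"\n",
"src = torch.randn(2, 10, 16)   # (batch_size, src_len, d_model)\n",
"tgt = torch.randn(2, 8, 16)    # (batch_size, tgt_len, d_model)\n",
"\n",
"# True = 该位置是填充,形状只需要 (batch_size, src_len)\n",
"src_key_padding_mask = torch.zeros(2, 10, dtype=torch.bool)\n",
"src_key_padding_mask[0, 7:] = True   # 假设第1个样本最后3个位置是填充\n",
"\n",
"out = model(src, tgt, src_key_padding_mask=src_key_padding_mask)\n",
"print(out.shape)   # torch.Size([2, 8, 16])"
]
},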
|
|||
|
|
{
"cell_type": "markdown",
"id": "96f57de7-3ebd-489b-a4d3-bb6be4e84911",
"metadata": {},
"source": [
"无需升维的填充掩码函数如下——"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "b86652ca-c6ce-40c1-82ec-cecf5511e0d1",
"metadata": {},
"outputs": [],
"source": [
"def create_padding_mask(seq, pad_token=0):\n",
"    # seq: (batch_size, seq_len, embedding_dim)\n",
"    # 某个位置的嵌入向量全部等于pad_token时,判定该位置为填充\n",
"    padding_mask = (seq == pad_token).all(dim=-1)   # (batch_size, seq_len)\n",
"    padding_mask = padding_mask.float() * -1e9      # 填充位置置为负无穷近似值\n",
"    \n",
"    return padding_mask"
]
},
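{
"cell_type": "markdown",
"id": "ins-broadcast-lead-0002",
"metadata": {},
"source": [
"作为补充示意(张量尺寸均为演示用的假设值):手写注意力时,这个(batch_size, seq_len)的掩码只需用`unsqueeze`扩展成(batch_size, 1, 1, seq_len),即可靠广播作用到每个头、每一行的注意力分数上——"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ins-broadcast-code-0002",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"batch_size, num_heads, seq_len, d_k = 2, 4, 6, 8\n",
"\n",
"seq = torch.randn(batch_size, seq_len, d_k)\n",
"seq[:, -2:, :] = 0.0   # 假设每个样本最后两个位置是填充(嵌入全为0)\n",
"\n",
"padding_mask = create_padding_mask(seq)   # (batch_size, seq_len)\n",
"scores = torch.randn(batch_size, num_heads, seq_len, seq_len)   # 模拟QK.T/√d_k\n",
"\n",
"# (batch_size, seq_len) -> (batch_size, 1, 1, seq_len),广播覆盖全部头与行\n",
"scores = scores + padding_mask.unsqueeze(1).unsqueeze(2)\n",
"print(torch.softmax(scores, dim=-1)[0, 0, 0])   # 填充列上的权重趋近于0"
]
},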
|
|||
|
|
{
"cell_type": "markdown",
"id": "b8be99bd-9980-408a-88f0-1e3f1b491924",
"metadata": {},
"source": [
"- **前瞻掩码的实现函数**"
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "a5c596c4",
"metadata": {},
"outputs": [],
"source": [
"# 注意力分数矩阵QK.T的形状为 (batch_size, num_heads, seq_len, seq_len)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6c4456e0",
"metadata": {},
"outputs": [],
"source": [
"# torch.triu(矩阵, diagonal): 保留上三角部分,diagonal控制对角线所在的位置\n",
"# torch.tril(矩阵, diagonal): 保留下三角部分"
]
},
{
"cell_type": "code",
"execution_count": 29,
"id": "11eb6796-3359-4b99-851b-57c66e2b99ab",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"def create_look_ahead_mask(seq_len, start_seq = 1):\n",
"    mask = torch.triu(torch.ones((seq_len, seq_len)), diagonal=start_seq)   # triu保留右上方的三角矩阵,diagonal控制对角线位置\n",
"    #mask = mask.float() * -1e9   # 将未来的位置设置为负无穷大\n",
"    return mask"
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "2d7315da-e693-4b43-a3dc-fc9986420843",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[0., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n",
"        [0., 0., 1., 1., 1., 1., 1., 1., 1., 1.],\n",
"        [0., 0., 0., 1., 1., 1., 1., 1., 1., 1.],\n",
"        [0., 0., 0., 0., 1., 1., 1., 1., 1., 1.],\n",
"        [0., 0., 0., 0., 0., 1., 1., 1., 1., 1.],\n",
"        [0., 0., 0., 0., 0., 0., 1., 1., 1., 1.],\n",
"        [0., 0., 0., 0., 0., 0., 0., 1., 1., 1.],\n",
"        [0., 0., 0., 0., 0., 0., 0., 0., 1., 1.],\n",
"        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],\n",
"        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])"
]
},
"execution_count": 30,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"create_look_ahead_mask(10) #为了教学方便,现在展示的是1和0,实际应该是右上角负无穷,左下角0"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3ce79780",
"metadata": {},
"outputs": [],
"source": [
"# 前瞻掩码的形状:(seq_len, seq_len)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "27340ecd",
"metadata": {},
"outputs": [],
"source": [
"# 填充掩码需要覆盖 batch_size * num_heads 个注意力分数矩阵\n",
"# 例如 batch_size=32, num_heads=8 时,共 32 * 8 = 256 个"
]
},
{
"cell_type": "markdown",
"id": "0e18c7b4-0078-420f-b0f5-133605f69578",
"metadata": {},
"source": [
"你注意到了吗?前瞻掩码矩阵的结构为(seq_len, seq_len),而填充掩码矩阵的结构为(batch_size, num_heads, seq_len, seq_len)。前者可以通过广播的方式直接与QK.T矩阵相加,后者则必须写明4个维度的信息。这是因为**前瞻掩码对所有序列都是同一种掩码方式,而填充掩码在每个batch内对每个样本都可能不同**——batch内各个句子的实际长度往往不一致。两种掩码结合使用的示意见下方。"
]
},
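{
"cell_type": "markdown",
"id": "ins-combine-lead-0003",
"metadata": {},
"source": [
"下面是把两种掩码同时加到注意力分数上的小示意(尺寸均为演示假设;这里直接在示例内把0/1形式的前瞻掩码乘以-1e9)——"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ins-combine-code-0003",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"batch_size, num_heads, seq_len = 2, 4, 6\n",
"\n",
"look_ahead = create_look_ahead_mask(seq_len) * -1e9   # (seq_len, seq_len)\n",
"\n",
"seq = torch.randn(batch_size, seq_len, 8)\n",
"seq[0, -2:, :] = 0.0                                  # 假设样本1末尾两位是填充\n",
"padding = create_padding_mask(seq)                    # (batch_size, seq_len)\n",
"padding = padding.unsqueeze(1).unsqueeze(2)           # (batch_size, 1, 1, seq_len)\n",
"\n",
"scores = torch.randn(batch_size, num_heads, seq_len, seq_len)\n",
"masked = scores + look_ahead + padding                # 两种掩码都靠广播生效\n",
"print(torch.softmax(masked, dim=-1)[0, 0])"
]
},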
|
|||
|
|
{
"cell_type": "code",
"execution_count": 31,
"id": "6c0bd91e-837d-4e7c-99a9-bec5d827ad9e",
"metadata": {},
"outputs": [],
"source": [
"def create_look_ahead_mask(seq_len, start_seq = 1):\n",
"    mask = torch.triu(torch.ones((seq_len, seq_len)), diagonal=start_seq)   # triu保留右上方的三角矩阵,diagonal控制对角线位置\n",
"    mask = mask.float() * -1e9   # 将未来的位置设置为负无穷大\n",
"    return mask"
]
},
{
"cell_type": "code",
"execution_count": 32,
"id": "57dd6146-1b79-4bbd-bcd2-276c048a8b6f",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[-0.0000e+00, -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09,\n",
"          -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09],\n",
"        [-0.0000e+00, -0.0000e+00, -1.0000e+09, -1.0000e+09, -1.0000e+09,\n",
"          -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09],\n",
"        [-0.0000e+00, -0.0000e+00, -0.0000e+00, -1.0000e+09, -1.0000e+09,\n",
"          -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09],\n",
"        [-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -1.0000e+09,\n",
"          -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09],\n",
"        [-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,\n",
"          -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09],\n",
"        [-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,\n",
"          -0.0000e+00, -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09],\n",
"        [-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,\n",
"          -0.0000e+00, -0.0000e+00, -1.0000e+09, -1.0000e+09, -1.0000e+09],\n",
"        [-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,\n",
"          -0.0000e+00, -0.0000e+00, -0.0000e+00, -1.0000e+09, -1.0000e+09],\n",
"        [-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,\n",
"          -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -1.0000e+09],\n",
"        [-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,\n",
"          -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00]])"
]
},
"execution_count": 32,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"create_look_ahead_mask(10) #右上角为负无穷,左下角为0"
]
},
|
|||
|
|
{
"cell_type": "markdown",
"id": "c892c749-9b81-4679-90ad-3bc2c90e2a74",
"metadata": {},
"source": [
"#### 2.3.2.4 编码器-解码器注意力层"
]
},
{
"cell_type": "markdown",
"id": "42f286b3-69bd-4e2b-a6a1-340ae1af01d2",
"metadata": {},
"source": [
"<center><img src=\"https://skojiangdoc.oss-cn-beijing.aliyuncs.com/2023DL/transformer/image-1.png\" alt=\"描述文字\" width=\"400\">"
]
},
{
"cell_type": "markdown",
"id": "b8fbf3d5-90fa-417f-aac9-1ba78e6826dd",
"metadata": {},
"source": [
"在Transformer模型的解码器部分,编码器-解码器注意力层(通常称为“交叉注意力”层)起着至关重要的作用。这一层允许**解码器的每个位置访问整个编码器的输出**,从而把输入序列的上下文信息整合到输出序列的生成之中。这一设计确保解码器能够基于完整的输入序列信息来生成每个输出元素。"
]
},
{
"cell_type": "markdown",
"id": "c2b029b5-f1c6-4c8f-8956-e9d41d9caecf",
"metadata": {},
"source": [
"首先,**编码器-解码器注意力层的输入是来自多头注意力机制的输出结果**。从Decoder的掩码注意力层中输出的是经过掩码后、每一行只携带特定时间段信息的结果$C_{decoder}$:\n",
"\n",
"$$\n",
"C_{decoder} = \\begin{bmatrix}\n",
"a_{11}v_{1} & a_{11}v_{1} & \\ldots & a_{11}v_{1} \\\\\n",
"a_{21}v_{1} + a_{22}v_{2} & a_{21}v_{1} + a_{22}v_{2} & \\ldots & a_{21}v_{1} + a_{22}v_{2} \\\\\n",
"a_{31}v_{1} + a_{32}v_{2} + a_{33}v_{3} & a_{31}v_{1} + a_{32}v_{2} + a_{33}v_{3} & \\ldots & a_{31}v_{1} + a_{32}v_{2} + a_{33}v_{3} \\\\\n",
"a_{41}v_{1} + a_{42}v_{2} + a_{43}v_{3} + a_{44}v_{4} & a_{41}v_{1} + a_{42}v_{2} + a_{43}v_{3} + a_{44}v_{4} & \\ldots & a_{41}v_{1} + a_{42}v_{2} + a_{43}v_{3} + a_{44}v_{4}\n",
"\\end{bmatrix}\n",
"$$"
]
},
{
"cell_type": "markdown",
"id": "e2ec3ae6-4f94-4b56-933b-bd6174813a67",
"metadata": {},
"source": [
"**当我们使用覆盖的时间点来作为脚标**,则有:\n",
"\n",
"$$\n",
"C_{decoder} = \\begin{bmatrix}\n",
"c_{1} & c_{1} & \\ldots & c_{1} \\\\\n",
"c_{1 \\to 2} & c_{1 \\to 2} & \\ldots & c_{1 \\to 2} \\\\\n",
"c_{1 \\to 3} & c_{1 \\to 3} & \\ldots & c_{1 \\to 3} \\\\\n",
"c_{1 \\to 4} & c_{1 \\to 4} & \\ldots & c_{1 \\to 4}\n",
"\\end{bmatrix}\n",
"$$\n",
"\n",
"<font color=\"red\">**同样的,这里出于教学目的,省略了特征维度上的脚标。现在你所看到的脚标只代表时间维度/序列长度的维度。**"
]
},
{
"cell_type": "markdown",
"id": "7c2b8a3e-97bf-45a0-b9a7-1efb4cac481a",
"metadata": {},
"source": [
"从Encoder中输出的是没有掩码的注意力机制结果$C_{encoder}$。由于没有掩码,Encoder中的注意力分数为——\n",
"\n",
"$$\n",
"\\text{A} = \\begin{bmatrix}\n",
"a_{11} & a_{12} & a_{13} & a_{14} \\\\\n",
"a_{21} & a_{22} & a_{23} & a_{24} \\\\\n",
"a_{31} & a_{32} & a_{33} & a_{34} \\\\\n",
"a_{41} & a_{42} & a_{43} & a_{44}\n",
"\\end{bmatrix}\n",
"$$\n",
"\n",
"同时,V矩阵为(省略了特征维度,脚标代表的是时间点、seq_len的信息)——\n",
"\n",
"$$\n",
"V = \\begin{bmatrix}\n",
"v_{1} & v_{1} & \\ldots & v_{1} \\\\\n",
"v_{2} & v_{2} & \\ldots & v_{2} \\\\\n",
"v_{3} & v_{3} & \\ldots & v_{3} \\\\\n",
"v_{4} & v_{4} & \\ldots & v_{4}\n",
"\\end{bmatrix}\n",
"$$\n",
"\n",
"由于$\\text{C}_{\\text{encoder}} = \\text{A} \\times V$,因此最终的结果矩阵 $\\text{C}_{\\text{encoder}}$ 是:\n",
"\n",
"$$\n",
"\\text{C}_{\\text{encoder}} = \\begin{bmatrix}\n",
"a_{11} \\cdot v_1 + a_{12} \\cdot v_2 + a_{13} \\cdot v_3 + a_{14} \\cdot v_4 & a_{11} \\cdot v_1 + a_{12} \\cdot v_2 + a_{13} \\cdot v_3 + a_{14} \\cdot v_4 & \\ldots & a_{11} \\cdot v_1 + a_{12} \\cdot v_2 + a_{13} \\cdot v_3 + a_{14} \\cdot v_4 \\\\\n",
"a_{21} \\cdot v_1 + a_{22} \\cdot v_2 + a_{23} \\cdot v_3 + a_{24} \\cdot v_4 & a_{21} \\cdot v_1 + a_{22} \\cdot v_2 + a_{23} \\cdot v_3 + a_{24} \\cdot v_4 & \\ldots & a_{21} \\cdot v_1 + a_{22} \\cdot v_2 + a_{23} \\cdot v_3 + a_{24} \\cdot v_4 \\\\\n",
"a_{31} \\cdot v_1 + a_{32} \\cdot v_2 + a_{33} \\cdot v_3 + a_{34} \\cdot v_4 & a_{31} \\cdot v_1 + a_{32} \\cdot v_2 + a_{33} \\cdot v_3 + a_{34} \\cdot v_4 & \\ldots & a_{31} \\cdot v_1 + a_{32} \\cdot v_2 + a_{33} \\cdot v_3 + a_{34} \\cdot v_4 \\\\\n",
"a_{41} \\cdot v_1 + a_{42} \\cdot v_2 + a_{43} \\cdot v_3 + a_{44} \\cdot v_4 & a_{41} \\cdot v_1 + a_{42} \\cdot v_2 + a_{43} \\cdot v_3 + a_{44} \\cdot v_4 & \\ldots & a_{41} \\cdot v_1 + a_{42} \\cdot v_2 + a_{43} \\cdot v_3 + a_{44} \\cdot v_4\n",
"\\end{bmatrix}\n",
"$$"
]
},
{
"cell_type": "markdown",
"id": "5d8a3541-a1bc-4b60-9e64-46a51614e553",
"metadata": {},
"source": [
"同样的,**当我们使用覆盖的时间点来作为脚标**,则有:\n",
"\n",
"$$\n",
"C_{encoder} = \\begin{bmatrix}\n",
"c_{1 \\to 4} & c_{1 \\to 4} & \\ldots & c_{1 \\to 4} \\\\\n",
"c_{1 \\to 4} & c_{1 \\to 4} & \\ldots & c_{1 \\to 4} \\\\\n",
"c_{1 \\to 4} & c_{1 \\to 4} & \\ldots & c_{1 \\to 4} \\\\\n",
"c_{1 \\to 4} & c_{1 \\to 4} & \\ldots & c_{1 \\to 4}\n",
"\\end{bmatrix}\n",
"$$"
]
},
{
"cell_type": "markdown",
"id": "09b5b9f3-a810-4c85-b363-98f2d0cb8705",
"metadata": {},
"source": [
"<font color=\"red\">**同样的,这里出于教学目的,省略了特征维度上的脚标。现在你所看到的脚标只代表时间维度/序列长度的维度。事实上4列C虽然覆盖的时间维度一致,但却归属于不同的特征维度。**"
]
},
{
"cell_type": "markdown",
"id": "fa1b627a-e6fd-48e9-a32f-826632c52c79",
"metadata": {},
"source": [
"此时,$C_{encoder}$携带的是特征矩阵X的信息,其中每个元素都携带了全部时间步上的信息;$C_{decoder}$携带的是真实标签的信息,其中每一行代表一段时间窗口的信息,随着行数增加,这段时间窗口越来越长。编码器-解码器注意力层负责整合这两部分信息。具体来说,两者结合的方式是——**将解码器中的标签信息$C_{decoder}$作为Q矩阵,将编码器中输出的特征信息$C_{encoder}$作为K和V矩阵,使用每行Q与全部的K、V相乘,来执行一种特殊的注意力机制**。\n",
"\n",
"这种特殊注意力机制的公式如下——\n",
"\n",
"$$\\text{Context}_1 = \\sum_{i} \\text{Attention}(Q_1, K_i) \\times V_i$$\n",
"\n",
"$$\\text{Context}_2 = \\sum_{i} \\text{Attention}(Q_2, K_i) \\times V_i$$\n",
"\n",
"$$\\text{Context}_3 = \\sum_{i} \\text{Attention}(Q_3, K_i) \\times V_i$$\n",
"\n",
"$$……$$"
]
},
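{
"cell_type": "markdown",
"id": "ins-crossattn-lead-0004",
"metadata": {},
"source": [
"用张量运算来示意这一过程(形状均为演示用的假设值;真实实现中Q、K、V还会先各自乘以可学习的权重矩阵$W_q$、$W_k$、$W_v$):解码器输出作为Q,编码器输出作为K和V——"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ins-crossattn-code-0004",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"seq_len_dec, seq_len_enc, d_model = 4, 4, 8\n",
"\n",
"C_decoder = torch.randn(seq_len_dec, d_model)   # 来自解码器掩码注意力层\n",
"C_encoder = torch.randn(seq_len_enc, d_model)   # 来自编码器\n",
"\n",
"Q, K, V = C_decoder, C_encoder, C_encoder       # 交叉注意力的核心分工\n",
"\n",
"scores = Q @ K.T / d_model ** 0.5               # (seq_len_dec, seq_len_enc),这里无需掩码\n",
"weights = F.softmax(scores, dim=-1)\n",
"context = weights @ V                           # 每一行即 Context_i\n",
"print(context.shape)                            # torch.Size([4, 8])"
]
},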
|
|||
|
|
{
"cell_type": "markdown",
"id": "9e12a0d8-7f4d-4fdd-9a48-cf96b6721d4b",
"metadata": {},
"source": [
"在这个公式中,Q与K转置相乘的地方不再是原来自注意力机制中“问答的点积”,而是转变为了交叉的点积——转换成矩阵则有,$C_{decoder}(Q)$的第一行乘以$C_{encoder}(K.T)$的第一列,加上$C_{decoder}(Q)$的第一行乘以$C_{encoder}(K.T)$的第二列,加上$C_{decoder}(Q)$的第一行乘以$C_{encoder}(K.T)$的第三列……直到所有的列都被乘完为止。\n",
"\n",
"$$\n",
"Context_1 = \\begin{bmatrix}\n",
"\\color{red}{c_{1}} & \\color{red}{c_{1}} & \\ldots & \\color{red}{c_{1}} \\\\\n",
"c_{1 \\to 2} & c_{1 \\to 2} & \\ldots & c_{1 \\to 2} \\\\\n",
"c_{1 \\to 3} & c_{1 \\to 3} & \\ldots & c_{1 \\to 3} \\\\\n",
"c_{1 \\to 4} & c_{1 \\to 4} & \\ldots & c_{1 \\to 4}\n",
"\\end{bmatrix} \\cdot\n",
"\\begin{bmatrix}\n",
"\\color{red}{c_{1 \\to 4}} & c_{1 \\to 4} & c_{1 \\to 4} & c_{1 \\to 4} \\\\\n",
"\\color{red}{c_{1 \\to 4}} & c_{1 \\to 4} & c_{1 \\to 4} & c_{1 \\to 4} \\\\\n",
"\\color{red}{\\ldots} & \\ldots & \\ldots & \\ldots \\\\\n",
"\\color{red}{c_{1 \\to 4}} & c_{1 \\to 4} & c_{1 \\to 4} & c_{1 \\to 4}\n",
"\\end{bmatrix}\n",
"$$\n",
"\n",
"$$ + $$\n",
"\n",
"$$\n",
"\\begin{bmatrix}\n",
"\\color{red}{c_{1}} & \\color{red}{c_{1}} & \\ldots & \\color{red}{c_{1}} \\\\\n",
"c_{1 \\to 2} & c_{1 \\to 2} & \\ldots & c_{1 \\to 2} \\\\\n",
"c_{1 \\to 3} & c_{1 \\to 3} & \\ldots & c_{1 \\to 3} \\\\\n",
"c_{1 \\to 4} & c_{1 \\to 4} & \\ldots & c_{1 \\to 4}\n",
"\\end{bmatrix} \\cdot\n",
"\\begin{bmatrix}\n",
"c_{1 \\to 4} & \\color{red}{c_{1 \\to 4}} & c_{1 \\to 4} & c_{1 \\to 4} \\\\\n",
"c_{1 \\to 4} & \\color{red}{c_{1 \\to 4}} & c_{1 \\to 4} & c_{1 \\to 4} \\\\\n",
"\\ldots & \\color{red}{\\ldots} & \\ldots & \\ldots \\\\\n",
"c_{1 \\to 4} & \\color{red}{c_{1 \\to 4}} & c_{1 \\to 4} & c_{1 \\to 4}\n",
"\\end{bmatrix}\n",
"$$\n",
"\n",
"$$ + $$\n",
"\n",
"$$ …… $$\n",
"\n",
"$$ + $$\n",
"\n",
"$$\n",
"\\begin{bmatrix}\n",
"\\color{red}{c_{1}} & \\color{red}{c_{1}} & \\ldots & \\color{red}{c_{1}} \\\\\n",
"c_{1 \\to 2} & c_{1 \\to 2} & \\ldots & c_{1 \\to 2} \\\\\n",
"c_{1 \\to 3} & c_{1 \\to 3} & \\ldots & c_{1 \\to 3} \\\\\n",
"c_{1 \\to 4} & c_{1 \\to 4} & \\ldots & c_{1 \\to 4}\n",
"\\end{bmatrix} \\cdot\n",
"\\begin{bmatrix}\n",
"c_{1 \\to 4} & c_{1 \\to 4} & c_{1 \\to 4} & \\color{red}{c_{1 \\to 4}} \\\\\n",
"c_{1 \\to 4} & c_{1 \\to 4} & c_{1 \\to 4} & \\color{red}{c_{1 \\to 4}} \\\\\n",
"\\ldots & \\ldots & \\ldots & \\color{red}{\\ldots} \\\\\n",
"c_{1 \\to 4} & c_{1 \\to 4} & c_{1 \\to 4} & \\color{red}{c_{1 \\to 4}}\n",
"\\end{bmatrix}\n",
"$$"
]
},
|
|||
|
|
{
"cell_type": "markdown",
"id": "544862ce-7029-462c-87dd-6aab943515a2",
"metadata": {},
"source": [
"这是**标签中的第一个时间步与特征中的所有时间步产生关联**。\n",
"\n",
"同样的我们有——\n",
"\n",
"$$\n",
"Context_2 = \\begin{bmatrix}\n",
"c_{1} & c_{1} & \\ldots & c_{1} \\\\\n",
"\\color{red}{c_{1 \\to 2}} & \\color{red}{c_{1 \\to 2}} & \\ldots & \\color{red}{c_{1 \\to 2}} \\\\\n",
"c_{1 \\to 3} & c_{1 \\to 3} & \\ldots & c_{1 \\to 3} \\\\\n",
"c_{1 \\to 4} & c_{1 \\to 4} & \\ldots & c_{1 \\to 4}\n",
"\\end{bmatrix} \\cdot\n",
"\\begin{bmatrix}\n",
"\\color{red}{c_{1 \\to 4}} & c_{1 \\to 4} & c_{1 \\to 4} & c_{1 \\to 4} \\\\\n",
"\\color{red}{c_{1 \\to 4}} & c_{1 \\to 4} & c_{1 \\to 4} & c_{1 \\to 4} \\\\\n",
"\\color{red}{\\ldots} & \\ldots & \\ldots & \\ldots \\\\\n",
"\\color{red}{c_{1 \\to 4}} & c_{1 \\to 4} & c_{1 \\to 4} & c_{1 \\to 4}\n",
"\\end{bmatrix}\n",
"$$\n",
"\n",
"$$ + $$\n",
"\n",
"$$\n",
"\\begin{bmatrix}\n",
"c_{1} & c_{1} & \\ldots & c_{1} \\\\\n",
"\\color{red}{c_{1 \\to 2}} & \\color{red}{c_{1 \\to 2}} & \\ldots & \\color{red}{c_{1 \\to 2}} \\\\\n",
"c_{1 \\to 3} & c_{1 \\to 3} & \\ldots & c_{1 \\to 3} \\\\\n",
"c_{1 \\to 4} & c_{1 \\to 4} & \\ldots & c_{1 \\to 4}\n",
"\\end{bmatrix} \\cdot\n",
"\\begin{bmatrix}\n",
"c_{1 \\to 4} & \\color{red}{c_{1 \\to 4}} & c_{1 \\to 4} & c_{1 \\to 4} \\\\\n",
"c_{1 \\to 4} & \\color{red}{c_{1 \\to 4}} & c_{1 \\to 4} & c_{1 \\to 4} \\\\\n",
"\\ldots & \\color{red}{\\ldots} & \\ldots & \\ldots \\\\\n",
"c_{1 \\to 4} & \\color{red}{c_{1 \\to 4}} & c_{1 \\to 4} & c_{1 \\to 4}\n",
"\\end{bmatrix}\n",
"$$\n",
"\n",
"$$ + $$\n",
"\n",
"$$ …… $$\n",
"\n",
"$$ + $$\n",
"\n",
"$$\n",
"\\begin{bmatrix}\n",
"c_{1} & c_{1} & \\ldots & c_{1} \\\\\n",
"\\color{red}{c_{1 \\to 2}} & \\color{red}{c_{1 \\to 2}} & \\ldots & \\color{red}{c_{1 \\to 2}} \\\\\n",
"c_{1 \\to 3} & c_{1 \\to 3} & \\ldots & c_{1 \\to 3} \\\\\n",
"c_{1 \\to 4} & c_{1 \\to 4} & \\ldots & c_{1 \\to 4}\n",
"\\end{bmatrix} \\cdot\n",
"\\begin{bmatrix}\n",
"c_{1 \\to 4} & c_{1 \\to 4} & c_{1 \\to 4} & \\color{red}{c_{1 \\to 4}} \\\\\n",
"c_{1 \\to 4} & c_{1 \\to 4} & c_{1 \\to 4} & \\color{red}{c_{1 \\to 4}} \\\\\n",
"\\ldots & \\ldots & \\ldots & \\color{red}{\\ldots} \\\\\n",
"c_{1 \\to 4} & c_{1 \\to 4} & c_{1 \\to 4} & \\color{red}{c_{1 \\to 4}}\n",
"\\end{bmatrix}\n",
"$$\n",
"\n",
"这是**标签中的第一个和第二个时间步与特征中的所有时间步产生关联**。"
]
},
|
|||
|
|
{
"cell_type": "markdown",
"id": "fe229962-5745-48b4-b0bf-ee676904196d",
"metadata": {},
"source": [
"以此类推下去,直到形成新的注意力分数矩阵,后续进入softmax、并与V相乘的流程也类似。**你是否注意到,这个注意力机制事实上代表了什么**?还记得我们最初说decoder结构的输入与输出吗?"
]
},
{
"cell_type": "markdown",
"id": "667787a4-7a41-499f-99db-cf994e2c6e5e",
"metadata": {},
"source": [
"在decoder中我们实际走的<font color=\"red\">**训练流程**</font>是:\n",
"\n",
"> - **第一步,输入ebd_X & ebd_y[:1] >> 输出yhat[0],对应真实标签y[0]**"
]
},
|
|||
|
|
{
"cell_type": "markdown",
"id": "1390a055-99b9-429c-98ef-fc8c8047bb87",
"metadata": {},
"source": [
"<table>\n",
"  <tr>\n",
"    <td>\n",
"      <p>输入Encoder<br>特征矩阵</p>\n",
"      <table style=\"color:red;\">\n",
"        <tr>\n",
"          <th>索引</th><th></th><th>y1</th><th>y2</th><th>y3</th><th>y4</th><th>y5</th>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>0</td><td>这</td><td>0.1821</td><td>0.4000</td><td>0.2248</td><td>0.4440</td><td>0.7771</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>1</td><td>是</td><td>0.1721</td><td>0.5030</td><td>0.8948</td><td>0.2385</td><td>0.0987</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>2</td><td>最好的</td><td>0.1342</td><td>0.8297</td><td>0.2978</td><td>0.7120</td><td>0.2565</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>3</td><td>时代</td><td>0.1248</td><td>0.5003</td><td>0.7559</td><td>0.4804</td><td>0.2593</td>\n",
"        </tr>\n",
"      </table>\n",
"    </td>\n",
"    <td>\n",
"      <p>输入Decoder<br>标签矩阵</p>\n",
"      <table>\n",
"        <tr>\n",
"          <th>索引</th><th></th><th>y1</th><th>y2</th><th>y3</th><th>y4</th><th>y5</th>\n",
"        </tr>\n",
"        <tr style=\"color:red;\">\n",
"          <td>0</td><td>\"sos\"</td><td>0.5651</td><td>0.2220</td><td>0.5112</td><td>0.8543</td><td>0.1239</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>1</td><td>It</td><td>0.5621</td><td>0.8920</td><td>0.7312</td><td>0.2543</td><td>0.1289</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>2</td><td>was</td><td>0.2314</td><td>0.6794</td><td>0.9823</td><td>0.8452</td><td>0.3417</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>3</td><td>the</td><td>0.4932</td><td>0.2045</td><td>0.7531</td><td>0.6582</td><td>0.9731</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>4</td><td>best</td><td>0.8342</td><td>0.2987</td><td>0.7642</td><td>0.2154</td><td>0.9812</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>5</td><td>of</td><td>0.3417</td><td>0.5792</td><td>0.4821</td><td>0.6721</td><td>0.1234</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>6</td><td>times</td><td>0.2531</td><td>0.7345</td><td>0.9812</td><td>0.5487</td><td>0.2378</td>\n",
"        </tr>\n",
"      </table>\n",
"    </td>\n",
"    <td>\n",
"      <p>对应</p>\n",
"      ➡\n",
"    </td>\n",
"    <td>\n",
"      <p>真实标签y</p>\n",
"      <table>\n",
"        <tr>\n",
"          <th>索引</th><th></th>\n",
"        </tr>\n",
"        <tr style=\"color:blue;\">\n",
"          <td>0</td><td>It</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>1</td><td>was</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>2</td><td>the</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>3</td><td>best</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>4</td><td>of</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>5</td><td>times</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>6</td><td>\"eos\"</td>\n",
"        </tr>\n",
"      </table>\n",
"    </td>\n",
"  </tr>\n",
"</table>"
]
},
|
|||
|
|
{
"cell_type": "markdown",
"id": "b35be58e-b456-4854-b1e2-e1335d7e0f5c",
"metadata": {},
"source": [
"> - **第二步,输入ebd_X & ebd_y[:2] >> 输出yhat[1],对应真实标签y[1]**"
]
},
{
"cell_type": "markdown",
"id": "1894c009-80da-4d03-9c22-c77fa96744ea",
"metadata": {},
"source": [
"<table>\n",
"  <tr>\n",
"    <td>\n",
"      <p>输入Encoder<br>特征矩阵</p>\n",
"      <table style=\"color:red;\">\n",
"        <tr>\n",
"          <th>索引</th><th></th><th>y1</th><th>y2</th><th>y3</th><th>y4</th><th>y5</th>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>0</td><td>这</td><td>0.1821</td><td>0.4000</td><td>0.2248</td><td>0.4440</td><td>0.7771</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>1</td><td>是</td><td>0.1721</td><td>0.5030</td><td>0.8948</td><td>0.2385</td><td>0.0987</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>2</td><td>最好的</td><td>0.1342</td><td>0.8297</td><td>0.2978</td><td>0.7120</td><td>0.2565</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>3</td><td>时代</td><td>0.1248</td><td>0.5003</td><td>0.7559</td><td>0.4804</td><td>0.2593</td>\n",
"        </tr>\n",
"      </table>\n",
"    </td>\n",
"    <td>\n",
"      <p>输入Decoder<br>标签矩阵</p>\n",
"      <table>\n",
"        <tr>\n",
"          <th>索引</th><th></th><th>y1</th><th>y2</th><th>y3</th><th>y4</th><th>y5</th>\n",
"        </tr>\n",
"        <tr style=\"color:red;\">\n",
"          <td>0</td><td>\"sos\"</td><td>0.5651</td><td>0.2220</td><td>0.5112</td><td>0.8543</td><td>0.1239</td>\n",
"        </tr>\n",
"        <tr style=\"color:red;\">\n",
"          <td>1</td><td>It</td><td>0.5621</td><td>0.8920</td><td>0.7312</td><td>0.2543</td><td>0.1289</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>2</td><td>was</td><td>0.2314</td><td>0.6794</td><td>0.9823</td><td>0.8452</td><td>0.3417</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>3</td><td>the</td><td>0.4932</td><td>0.2045</td><td>0.7531</td><td>0.6582</td><td>0.9731</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>4</td><td>best</td><td>0.8342</td><td>0.2987</td><td>0.7642</td><td>0.2154</td><td>0.9812</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>5</td><td>of</td><td>0.3417</td><td>0.5792</td><td>0.4821</td><td>0.6721</td><td>0.1234</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>6</td><td>times</td><td>0.2531</td><td>0.7345</td><td>0.9812</td><td>0.5487</td><td>0.2378</td>\n",
"        </tr>\n",
"      </table>\n",
"    </td>\n",
"    <td>\n",
"      <p>对应</p>\n",
"      ➡\n",
"    </td>\n",
"    <td>\n",
"      <p>真实标签y</p>\n",
"      <table>\n",
"        <tr>\n",
"          <th>索引</th><th></th>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>0</td><td>It</td>\n",
"        </tr>\n",
"        <tr style=\"color:blue;\">\n",
"          <td>1</td><td>was</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>2</td><td>the</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>3</td><td>best</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>4</td><td>of</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>5</td><td>times</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>6</td><td>\"eos\"</td>\n",
"        </tr>\n",
"      </table>\n",
"    </td>\n",
"  </tr>\n",
"</table>"
]
},
|
|||
|
|
{
"cell_type": "markdown",
"id": "f8290514-7115-430b-b26e-461346862ee1",
"metadata": {},
"source": [
"> - **第三步,输入ebd_X & ebd_y[:3] >> 输出yhat[2],对应真实标签y[2]**"
]
},
{
"cell_type": "markdown",
"id": "69d579ae-115f-4522-bcbe-eb0036f3bd36",
"metadata": {},
"source": [
"<table>\n",
"  <tr>\n",
"    <td>\n",
"      <p>输入Encoder<br>特征矩阵</p>\n",
"      <table style=\"color:red;\">\n",
"        <tr>\n",
"          <th>索引</th><th></th><th>y1</th><th>y2</th><th>y3</th><th>y4</th><th>y5</th>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>0</td><td>这</td><td>0.1821</td><td>0.4000</td><td>0.2248</td><td>0.4440</td><td>0.7771</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>1</td><td>是</td><td>0.1721</td><td>0.5030</td><td>0.8948</td><td>0.2385</td><td>0.0987</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>2</td><td>最好的</td><td>0.1342</td><td>0.8297</td><td>0.2978</td><td>0.7120</td><td>0.2565</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>3</td><td>时代</td><td>0.1248</td><td>0.5003</td><td>0.7559</td><td>0.4804</td><td>0.2593</td>\n",
"        </tr>\n",
"      </table>\n",
"    </td>\n",
"    <td>\n",
"      <p>输入Decoder<br>标签矩阵</p>\n",
"      <table>\n",
"        <tr>\n",
"          <th>索引</th><th></th><th>y1</th><th>y2</th><th>y3</th><th>y4</th><th>y5</th>\n",
"        </tr>\n",
"        <tr style=\"color:red;\">\n",
"          <td>0</td><td>\"sos\"</td><td>0.5651</td><td>0.2220</td><td>0.5112</td><td>0.8543</td><td>0.1239</td>\n",
"        </tr>\n",
"        <tr style=\"color:red;\">\n",
"          <td>1</td><td>It</td><td>0.5621</td><td>0.8920</td><td>0.7312</td><td>0.2543</td><td>0.1289</td>\n",
"        </tr>\n",
"        <tr style=\"color:red;\">\n",
"          <td>2</td><td>was</td><td>0.2314</td><td>0.6794</td><td>0.9823</td><td>0.8452</td><td>0.3417</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>3</td><td>the</td><td>0.4932</td><td>0.2045</td><td>0.7531</td><td>0.6582</td><td>0.9731</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>4</td><td>best</td><td>0.8342</td><td>0.2987</td><td>0.7642</td><td>0.2154</td><td>0.9812</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>5</td><td>of</td><td>0.3417</td><td>0.5792</td><td>0.4821</td><td>0.6721</td><td>0.1234</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>6</td><td>times</td><td>0.2531</td><td>0.7345</td><td>0.9812</td><td>0.5487</td><td>0.2378</td>\n",
"        </tr>\n",
"      </table>\n",
"    </td>\n",
"    <td>\n",
"      <p>对应</p>\n",
"      ➡\n",
"    </td>\n",
"    <td>\n",
"      <p>真实标签y</p>\n",
"      <table>\n",
"        <tr>\n",
"          <th>索引</th><th></th>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>0</td><td>It</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>1</td><td>was</td>\n",
"        </tr>\n",
"        <tr style=\"color:blue;\">\n",
"          <td>2</td><td>the</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>3</td><td>best</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>4</td><td>of</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>5</td><td>times</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>6</td><td>\"eos\"</td>\n",
"        </tr>\n",
"      </table>\n",
"    </td>\n",
"  </tr>\n",
"</table>"
]
},
|
|||
|
|
{
"cell_type": "markdown",
"id": "2daa4ae1-7a58-4778-b272-9b1ff4615ce0",
"metadata": {},
"source": [
"……以此类推下去。很显然,编码器-解码器注意力机制中的数学流程,正是【利用序列X + 序列y的前半段预测序列y的后半段】的计算方式!在这里,每一步都是单独的方程、只涉及矩阵中不同的行,因此所有时间步可以并行!本质上实现的是编码器-解码器注意力机制中、下列方程的并行(见下方的小例子)↓\n",
"\n",
"$$\\text{Context}_1 = \\sum_{i} \\text{Attention}(Q_1, K_i) \\times V_i$$\n",
"\n",
"$$\\text{Context}_2 = \\sum_{i} \\text{Attention}(Q_2, K_i) \\times V_i$$\n",
"\n",
"$$\\text{Context}_3 = \\sum_{i} \\text{Attention}(Q_3, K_i) \\times V_i$$"
]
},
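{
"cell_type": "markdown",
"id": "ins-parallel-lead-0005",
"metadata": {},
"source": [
"换句话说,每个$\\text{Context}_i$的方程只涉及Q的一行,所以一次矩阵乘法就能并行算出全部时间步。下面的小例子仅作演示(所有张量均为随机假设值)——"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ins-parallel-code-0005",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"Q = torch.randn(4, 8)   # 4个时间步的查询(来自解码器)\n",
"K = torch.randn(4, 8)   # 来自编码器\n",
"V = torch.randn(4, 8)\n",
"\n",
"# 逐行计算每个 Context_i\n",
"contexts = [F.softmax(Q[i] @ K.T / 8 ** 0.5, dim=-1) @ V for i in range(4)]\n",
"\n",
"# 一次矩阵乘法并行得到所有 Context\n",
"parallel = F.softmax(Q @ K.T / 8 ** 0.5, dim=-1) @ V\n",
"print(torch.allclose(torch.stack(contexts), parallel))   # True"
]
},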
|
|||
|
|
{
"cell_type": "markdown",
"id": "2e759e40-015e-4271-8add-5e911afa3c1d",
"metadata": {},
"source": [
"所以现在你知道编码器-解码器注意力层是如何实现信息整合的了。学到这里,我们来总结一下编码器-解码器注意力层的核心作用——\n",
"\n",
"1. **关联输入和输出**:在许多任务中,输出序列的生成需要依赖于输入序列的特定部分。这一层允许模型学习在生成每个输出元素时应关注输入序列的哪些部分。\n",
"2. **灵活的上下文捕捉**:与只能处理解码器自身先前输出的自注意力层不同,编码器-解码器注意力层可以访问整个输入序列的上下文,这对机器翻译等任务至关重要。\n",
"3. **增强解码器能力**:通过整合来自编码器的信息,这一设计显著增强了解码器处理复杂输入序列并准确生成输出的能力。\n",
"\n",
"总之,编码器-解码器注意力层是Transformer解码器的核心部分,使解码器能够利用编码器处理的完整输入信息,从而生成语义上连贯且上下文相关的输出。"
]
},
{
"cell_type": "markdown",
"id": "ce71908e-b9c8-46e6-b2b8-5f5444227d55",
"metadata": {},
"source": [
"### 2.3.3 Decoder-Only结构中的Decoder"
]
},
{
"cell_type": "markdown",
"id": "a1962038-11c8-42c3-b538-9e18b2fe5372",
"metadata": {},
"source": [
"现在,让我们来看看Decoder-only结构下的Decoder。它是专用于生成式任务的架构,从完整的Transformer结构中抽离出来,拥有独特的训练流程与结构。我们先从结构来看——\n",
"\n",
"<center><img src=\"https://skojiangdoc.oss-cn-beijing.aliyuncs.com/2023DL/transformer/IUCxP.png\" alt=\"描述文字\" width=\"300\">"
]
},
|
|||
|
|
{
"cell_type": "markdown",
"id": "023a7e72-d9fa-430f-845e-8586e2411ffb",
"metadata": {},
"source": [
"训练——(teacher forcing,不会累积错误)\n",
"\n",
"这是最好的时代 👉 xxx\n",
"\n",
"这是最好的时代,这 👉 xxx\n",
"\n",
"这是最好的时代,这是 👉 xxx\n",
"\n",
"测试——(autoregressive,会累积错误)\n",
"\n",
"这是最坏的时代 👉 xxx\n",
"\n",
"这是最坏的时代,xxx 👉 xxx\n",
"\n",
"这是最坏的时代,xxxxxx 👉 xxx"
]
},
|
|||
|
|
{
"cell_type": "markdown",
"id": "92b59d15-b72d-4f18-875a-f9ec5cbdfd37",
"metadata": {},
"source": [
"如图所示,与原本的Decoder结构相比,Decoder-only状态下的Decoder不再存在编码器-解码器注意力层,整个结构变得更像编码器Encoder,但依然保留着teacher forcing和掩码机制。由于没有了编码器-解码器注意力层,原本依赖它完成的整套训练和运算流程也不再有效;相应地,Decoder-only结构中的Decoder大部分时候采用“自回归”的流程——自回归在时间序列预测中是一种常用的方法,它逐步生成未来的值,每一步的预测依赖于前一步的实际值或预测值,而Decoder-only状态下的训练、预测都遵循这一思路。**在自回归场景中,Decoder的任务是——**\n",
"\n",
"1. **利用序列的前半段预测序列的后半段**,因此Decoder的输入数据是一段时间序列、一段文字,输出的是对未来时间的预测、对未来文字的填补<br><br>\n",
"\n",
"2. **利用teacher forcing机制和自回归机制的本质,在训练和预测流程中使用标签来辅助预测**。具体来说,在训练流程中,Decoder利用teacher forcing机制、不断将正确的标签作为特征数据使用;在测试流程中,Decoder利用自回归的属性,将前一步的预测值作为特征数据来使用。\n",
"\n",
"<font color=\"red\">**在生成式任务中,一般我们不再区分“特征和标签”这两种不同的数据。在大多数生成式任务中,我们有且只有一种数据——就是需要继续生成、继续补充的那段序列**</font>。生成式任务带有一定的“自监督”属性,我们训练用的数据和要预测的数据都来自于同一段序列,因此标签数据在下一个时间步就会成为我们的特征数据,故而我们也不会特地再去区分特征和标签,而是区分“输入”与“输出”。不过,从架构图上来看,除了要预测的序列本身之外,我们依然可以给Decoder输入更多额外的信息(图上的inputs部分)。大部分时候,我们可以使用这条数据流线路向Decoder传递一些相应的“条件”与“背景知识”,帮助我们更好地进行信息的生成和填补。"
]
},
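{
"cell_type": "markdown",
"id": "ins-deconly-lead-0006",
"metadata": {},
"source": [
"下面用一个极简的草图来示意Decoder-only中单个Block的计算结构(带前瞻掩码的自注意力 + 前馈网络)。这只是一个在假设参数下的示意,省略了多头、残差连接与层归一化等真实实现中必备的细节——"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ins-deconly-code-0006",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"\n",
"class TinyDecoderOnlyBlock(nn.Module):\n",
"    \"\"\"极简示意:带前瞻掩码的单头自注意力 + 前馈网络(无残差/归一化)\"\"\"\n",
"    def __init__(self, d_model=16):\n",
"        super().__init__()\n",
"        self.W_q = nn.Linear(d_model, d_model)\n",
"        self.W_k = nn.Linear(d_model, d_model)\n",
"        self.W_v = nn.Linear(d_model, d_model)\n",
"        self.ffn = nn.Sequential(nn.Linear(d_model, 4 * d_model),\n",
"                                 nn.ReLU(),\n",
"                                 nn.Linear(4 * d_model, d_model))\n",
"\n",
"    def forward(self, x):   # x: (batch, seq_len, d_model)\n",
"        seq_len = x.shape[1]\n",
"        Q, K, V = self.W_q(x), self.W_k(x), self.W_v(x)\n",
"        scores = Q @ K.transpose(-2, -1) / Q.shape[-1] ** 0.5\n",
"        # 前瞻掩码:把严格上三角(未来位置)置为负无穷\n",
"        future = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)\n",
"        scores = scores.masked_fill(future, float(\"-inf\"))\n",
"        return self.ffn(F.softmax(scores, dim=-1) @ V)\n",
"\n",
"block = TinyDecoderOnlyBlock()\n",
"print(block(torch.randn(2, 5, 16)).shape)   # torch.Size([2, 5, 16])"
]
},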
|
|||
|
|
{
"cell_type": "markdown",
"id": "d9626c6f-84ec-4a09-83bc-99eb5d957a93",
"metadata": {},
"source": [
"具体来看,Decoder-only状态下的<font color=\"red\">**训练流程**</font>如下。假设需要预测的序列为y,编码好的结果为ebd_y,其中我们取ebd_y的前n个字符作为输入,n个字符后的字符作为标签:"
]
},
|
|||
|
|
|
|||
|
|
{
"cell_type": "markdown",
"id": "d15876cf-fee0-4766-a173-bcb75d18fd57",
"metadata": {},
"source": [
"- **第1步,输入 ebd_y[:1] >> 输出预测标签yhat[0],对应真实标签y[0]**\n",
"\n",
"<table>\n",
"  <tr>\n",
"    <td>\n",
"      <p>输入Decoder<br>序列的前半段</p>\n",
"      <table style=\"color:red;\">\n",
"        <tr>\n",
"          <th>索引</th><th></th><th>y1</th><th>y2</th><th>y3</th><th>y4</th><th>y5</th>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>0</td><td>\"sos\"</td><td>0.1821</td><td>0.4000</td><td>0.2248</td><td>0.4440</td><td>0.7771</td>\n",
"        </tr>\n",
"      </table>\n",
"    </td>\n",
"    <td><p>预测出</p>➡\n",
"    </td>\n",
"    <td>\n",
"      <p>当前时间步的预测标签yhat</p>\n",
"      <table>\n",
"        <tr>\n",
"          <th>索引</th><th></th><th>y1</th><th>y2</th><th>y3</th><th>y4</th><th>y5</th>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>0</td><td>yyy</td><td>0.5621</td><td>0.8920</td><td>0.7312</td><td>0.2543</td><td>0.1289</td>\n",
"        </tr>\n",
"      </table>\n",
"    </td>\n",
"    <td><p>对应</p>➡\n",
"    </td>\n",
"    <td>\n",
"      <p>真实标签y<br>序列的后半段</p>\n",
"      <table>\n",
"        <tr>\n",
"          <th>索引</th><th></th>\n",
"        </tr>\n",
"        <tr style=\"color:blue;\">\n",
"          <td>0</td><td>这</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>1</td><td>是</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>2</td><td>最好的</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>3</td><td>时代</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>4</td><td>这</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>5</td><td>是</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>6</td><td>最坏的</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>7</td><td>时代</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>8</td><td>\"eos\"</td>\n",
"        </tr>\n",
"      </table>\n",
"    </td>\n",
"  </tr>\n",
"</table>"
]
},
|
|||
|
|
{
"cell_type": "markdown",
"id": "60c5082e-e954-4592-a635-0c3cf4d08b52",
"metadata": {},
"source": [
"……\n",
"- **第n+1步,输入 ebd_y[:n+1] >> 输出预测标签yhat[n],对应真实标签y[n]**\n",
"\n",
"<table>\n",
"  <tr>\n",
"    <td>\n",
"      <p>输入Decoder<br>序列的前半段</p>\n",
"      <table style=\"color:red;\">\n",
"        <tr>\n",
"          <th>索引</th><th></th><th>y1</th><th>y2</th><th>y3</th><th>y4</th><th>y5</th>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>0</td><td>\"sos\"</td><td>0.1821</td><td>0.4000</td><td>0.2248</td><td>0.4440</td><td>0.7771</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>1</td><td>这</td><td>0.1821</td><td>0.4000</td><td>0.2248</td><td>0.4440</td><td>0.7771</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>2</td><td>是</td><td>0.1721</td><td>0.5030</td><td>0.8948</td><td>0.2385</td><td>0.0987</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>3</td><td>最好的</td><td>0.1342</td><td>0.8297</td><td>0.2978</td><td>0.7120</td><td>0.2565</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>4</td><td>时代</td><td>0.1248</td><td>0.5003</td><td>0.7559</td><td>0.4804</td><td>0.2593</td>\n",
"        </tr>\n",
"      </table>\n",
"    </td>\n",
"    <td><p>预测出</p>➡\n",
"    </td>\n",
"    <td>\n",
"      <p>当前时间步的预测标签yhat</p>\n",
"      <table>\n",
"        <tr>\n",
"          <th>索引</th><th></th><th>y1</th><th>y2</th><th>y3</th><th>y4</th><th>y5</th>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>4</td><td>yyy</td><td>0.5621</td><td>0.8920</td><td>0.7312</td><td>0.2543</td><td>0.1289</td>\n",
"        </tr>\n",
"      </table>\n",
"    </td>\n",
"    <td><p>对应</p>➡\n",
"    </td>\n",
"    <td>\n",
"      <p>真实标签y<br>序列的后半段</p>\n",
"      <table>\n",
"        <tr>\n",
"          <th>索引</th><th></th>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>0</td><td>这</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>1</td><td>是</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>2</td><td>最好的</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>3</td><td>时代</td>\n",
"        </tr>\n",
"        <tr style=\"color:blue;\">\n",
"          <td>4</td><td>这</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>5</td><td>是</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>6</td><td>最坏的</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>7</td><td>时代</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>8</td><td>\"eos\"</td>\n",
"        </tr>\n",
"      </table>\n",
"    </td>\n",
"  </tr>\n",
"</table>"
]
},
|
|||
|
|
{
"cell_type": "markdown",
"id": "ddd26efa-b0b6-4095-aa9d-cc04904e47c2",
"metadata": {},
"source": [
"- **第n+2步,输入 ebd_y[:n+2] >> 输出预测标签yhat[n+1],对应真实标签y[n+1]**\n",
"\n",
"<table>\n",
"  <tr>\n",
"    <td>\n",
"      <p>输入Decoder<br>序列的前半段</p>\n",
"      <table style=\"color:red;\">\n",
"        <tr>\n",
"          <th>索引</th><th></th><th>y1</th><th>y2</th><th>y3</th><th>y4</th><th>y5</th>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>0</td><td>\"sos\"</td><td>0.1821</td><td>0.4000</td><td>0.2248</td><td>0.4440</td><td>0.7771</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>1</td><td>这</td><td>0.1821</td><td>0.4000</td><td>0.2248</td><td>0.4440</td><td>0.7771</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>2</td><td>是</td><td>0.1721</td><td>0.5030</td><td>0.8948</td><td>0.2385</td><td>0.0987</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>3</td><td>最好的</td><td>0.1342</td><td>0.8297</td><td>0.2978</td><td>0.7120</td><td>0.2565</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>4</td><td>时代</td><td>0.1248</td><td>0.5003</td><td>0.7559</td><td>0.4804</td><td>0.2593</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>5</td><td>这</td><td>0.1821</td><td>0.4000</td><td>0.2248</td><td>0.4440</td><td>0.7771</td>\n",
"        </tr>\n",
"      </table>\n",
"    </td>\n",
"    <td><p>预测出</p>➡\n",
"    </td>\n",
"    <td>\n",
"      <p>当前时间步的预测标签yhat</p>\n",
"      <table>\n",
"        <tr>\n",
"          <th>索引</th><th></th><th>y1</th><th>y2</th><th>y3</th><th>y4</th><th>y5</th>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>5</td><td>yyy</td><td>0.5621</td><td>0.8920</td><td>0.7312</td><td>0.2543</td><td>0.1289</td>\n",
"        </tr>\n",
"      </table>\n",
"    </td>\n",
"    <td><p>对应</p>➡\n",
"    </td>\n",
"    <td>\n",
"      <p>真实标签y<br>序列的后半段</p>\n",
"      <table>\n",
"        <tr>\n",
"          <th>索引</th><th></th>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>0</td><td>这</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>1</td><td>是</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>2</td><td>最好的</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>3</td><td>时代</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>4</td><td>这</td>\n",
"        </tr>\n",
"        <tr style=\"color:blue;\">\n",
"          <td>5</td><td>是</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>6</td><td>最坏的</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>7</td><td>时代</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>8</td><td>\"eos\"</td>\n",
"        </tr>\n",
"      </table>\n",
"    </td>\n",
"  </tr>\n",
"</table>"
]
},
|
|||
|
|
{
"cell_type": "markdown",
"id": "63a5a743-5962-4da9-a04a-50af96e61923",
"metadata": {},
"source": [
"- **第n+3步,输入 ebd_y[:n+3] >> 输出预测标签yhat[n+2],对应真实标签y[n+2]**\n",
"\n",
"<table>\n",
"  <tr>\n",
"    <td>\n",
"      <p>输入Decoder<br>序列的前半段</p>\n",
"      <table style=\"color:red;\">\n",
"        <tr>\n",
"          <th>索引</th><th></th><th>y1</th><th>y2</th><th>y3</th><th>y4</th><th>y5</th>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>0</td><td>\"sos\"</td><td>0.1821</td><td>0.4000</td><td>0.2248</td><td>0.4440</td><td>0.7771</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>1</td><td>这</td><td>0.1821</td><td>0.4000</td><td>0.2248</td><td>0.4440</td><td>0.7771</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>2</td><td>是</td><td>0.1721</td><td>0.5030</td><td>0.8948</td><td>0.2385</td><td>0.0987</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>3</td><td>最好的</td><td>0.1342</td><td>0.8297</td><td>0.2978</td><td>0.7120</td><td>0.2565</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>4</td><td>时代</td><td>0.1248</td><td>0.5003</td><td>0.7559</td><td>0.4804</td><td>0.2593</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>5</td><td>这</td><td>0.1821</td><td>0.4000</td><td>0.2248</td><td>0.4440</td><td>0.7771</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>6</td><td>是</td><td>0.1721</td><td>0.5030</td><td>0.8948</td><td>0.2385</td><td>0.0987</td>\n",
"        </tr>\n",
"      </table>\n",
"    </td>\n",
"    <td><p>预测出</p>➡\n",
"    </td>\n",
"    <td>\n",
"      <p>当前时间步的预测标签yhat</p>\n",
"      <table>\n",
"        <tr>\n",
"          <th>索引</th><th></th><th>y1</th><th>y2</th><th>y3</th><th>y4</th><th>y5</th>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>6</td><td>yyy</td><td>0.5621</td><td>0.8920</td><td>0.7312</td><td>0.2543</td><td>0.1289</td>\n",
"        </tr>\n",
"      </table>\n",
"    </td>\n",
"    <td><p>对应</p>➡\n",
"    </td>\n",
"    <td>\n",
"      <p>真实标签y<br>序列的后半段</p>\n",
"      <table>\n",
"        <tr>\n",
"          <th>索引</th><th></th>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>0</td><td>这</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>1</td><td>是</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>2</td><td>最好的</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>3</td><td>时代</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>4</td><td>这</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>5</td><td>是</td>\n",
"        </tr>\n",
"        <tr style=\"color:blue;\">\n",
"          <td>6</td><td>最坏的</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>7</td><td>时代</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>8</td><td>\"eos\"</td>\n",
"        </tr>\n",
"      </table>\n",
"    </td>\n",
"  </tr>\n",
"</table>"
]
},
|
|||
|
|
{
"cell_type": "markdown",
"id": "ec456494-41a8-424d-8370-420dd48cb091",
"metadata": {},
"source": [
"而在<font color=\"red\">**推理流程**</font>中,Decoder中运行的流程如下所示——"
]
},
|
|||
|
|
{
"cell_type": "markdown",
"id": "34ea7606-fd90-4985-87e6-53a440f508ac",
"metadata": {},
"source": [
"- **第一步,输入 ebd_y(全部的数据) >> 输出下一步的预测标签**\n",
"\n",
"<table>\n",
"  <tr>\n",
"    <td>\n",
"      <p>输入Decoder<br>全部的序列</p>\n",
"      <table style=\"color:red;\">\n",
"        <tr>\n",
"          <th>索引</th><th></th><th>y1</th><th>y2</th><th>y3</th><th>y4</th><th>y5</th>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>0</td><td>这</td><td>0.1821</td><td>0.4000</td><td>0.2248</td><td>0.4440</td><td>0.7771</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>1</td><td>是</td><td>0.1721</td><td>0.5030</td><td>0.8948</td><td>0.2385</td><td>0.0987</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>2</td><td>最好的</td><td>0.1342</td><td>0.8297</td><td>0.2978</td><td>0.7120</td><td>0.2565</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>3</td><td>时代</td><td>0.1248</td><td>0.5003</td><td>0.7559</td><td>0.4804</td><td>0.2593</td>\n",
"        </tr>\n",
"      </table>\n",
"    </td>\n",
"    <td><p>预测出</p>➡\n",
"    </td>\n",
"    <td>\n",
"      <p>当前时间步的预测标签yhat</p>\n",
"      <table>\n",
"        <tr>\n",
"          <th>索引</th><th></th><th>y1</th><th>y2</th><th>y3</th><th>y4</th><th>y5</th>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>4</td><td>yyy</td><td>0.5621</td><td>0.8920</td><td>0.7312</td><td>0.2543</td><td>0.1289</td>\n",
"        </tr>\n",
"      </table>\n",
"    </td>\n",
"  </tr>\n",
"</table>"
]
},
|
|||
|
|
{
"cell_type": "markdown",
"id": "120b23d6-6ab6-4a17-9005-1848ebde133c",
"metadata": {},
"source": [
"- **第二步,输入 ebd_y(全部的数据)+ 预测的yhat >> 输出下一步的预测标签**\n",
"\n",
"<table>\n",
"  <tr>\n",
"    <td>\n",
"      <p>输入Decoder<br>全部的序列</p>\n",
"      <table style=\"color:red;\">\n",
"        <tr>\n",
"          <th>索引</th><th></th><th>y1</th><th>y2</th><th>y3</th><th>y4</th><th>y5</th>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>0</td><td>这</td><td>0.1821</td><td>0.4000</td><td>0.2248</td><td>0.4440</td><td>0.7771</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>1</td><td>是</td><td>0.1721</td><td>0.5030</td><td>0.8948</td><td>0.2385</td><td>0.0987</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>2</td><td>最好的</td><td>0.1342</td><td>0.8297</td><td>0.2978</td><td>0.7120</td><td>0.2565</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>3</td><td>时代</td><td>0.1248</td><td>0.5003</td><td>0.7559</td><td>0.4804</td><td>0.2593</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>4</td><td>yyy</td><td>0.5621</td><td>0.8920</td><td>0.7312</td><td>0.2543</td><td>0.1289</td>\n",
"        </tr>\n",
"      </table>\n",
"    </td>\n",
"    <td><p>预测出</p>➡\n",
"    </td>\n",
"    <td>\n",
"      <p>当前时间步的预测标签yhat</p>\n",
"      <table>\n",
"        <tr>\n",
"          <th>索引</th><th></th><th>y1</th><th>y2</th><th>y3</th><th>y4</th><th>y5</th>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>5</td><td>yyy</td><td>0.5621</td><td>0.8920</td><td>0.7312</td><td>0.2543</td><td>0.1289</td>\n",
"        </tr>\n",
"      </table>\n",
"    </td>\n",
"  </tr>\n",
"</table>"
]
},
|
|||
|
|
{
"cell_type": "markdown",
"id": "72d09ff7-aa85-4635-bf11-00071d5caf2b",
"metadata": {},
"source": [
"- **第三步,输入 ebd_y(全部的数据)+ 预测的yhat >> 输出下一步的预测标签**\n",
"\n",
"<table>\n",
"  <tr>\n",
"    <td>\n",
"      <p>输入Decoder<br>全部的序列</p>\n",
"      <table style=\"color:red;\">\n",
"        <tr>\n",
"          <th>索引</th><th></th><th>y1</th><th>y2</th><th>y3</th><th>y4</th><th>y5</th>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>0</td><td>这</td><td>0.1821</td><td>0.4000</td><td>0.2248</td><td>0.4440</td><td>0.7771</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>1</td><td>是</td><td>0.1721</td><td>0.5030</td><td>0.8948</td><td>0.2385</td><td>0.0987</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>2</td><td>最好的</td><td>0.1342</td><td>0.8297</td><td>0.2978</td><td>0.7120</td><td>0.2565</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>3</td><td>时代</td><td>0.1248</td><td>0.5003</td><td>0.7559</td><td>0.4804</td><td>0.2593</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>4</td><td>yyy</td><td>0.5621</td><td>0.8920</td><td>0.7312</td><td>0.2543</td><td>0.1289</td>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>5</td><td>yyy</td><td>0.5621</td><td>0.8920</td><td>0.7312</td><td>0.2543</td><td>0.1289</td>\n",
"        </tr>\n",
"      </table>\n",
"    </td>\n",
"    <td><p>预测出</p>➡\n",
"    </td>\n",
"    <td>\n",
"      <p>当前时间步的预测标签yhat</p>\n",
"      <table>\n",
"        <tr>\n",
"          <th>索引</th><th></th><th>y1</th><th>y2</th><th>y3</th><th>y4</th><th>y5</th>\n",
"        </tr>\n",
"        <tr>\n",
"          <td>6</td><td>yyy</td><td>0.5621</td><td>0.8920</td><td>0.7312</td><td>0.2543</td><td>0.1289</td>\n",
"        </tr>\n",
"      </table>\n",
"    </td>\n",
"  </tr>\n",
"</table>"
]
},
|
|||
|
|
{
"cell_type": "markdown",
"id": "e4e767b0-4593-4a1f-90d4-7188ab29a1ad",
"metadata": {},
"source": [
"以此类推,直到预测出“eos”后停止。**与Transformer中的Decoder一致,训练流程是可以并行的,这一点通过带掩码的注意力机制来实现**。而推理流程必须严格遵守自回归的要求——在预测下一个时间步之前,必须先算出上一个时间步的结果,因此**推理流程中需要使用循环的方式**来进行预测,循环的示意见下方。"
]
},
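{
"cell_type": "markdown",
"id": "ins-genloop-lead-0007",
"metadata": {},
"source": [
"下面是推理循环的一个极简草图(`model`、`prompt_ids`、`eos_id`等均为假设的占位符,并非某个具体库的接口;假设model接收(1, seq_len)的token id并返回(1, seq_len, vocab_size)的logits)——"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ins-genloop-code-0007",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"def greedy_generate(model, prompt_ids, eos_id, max_new_tokens=50):\n",
"    \"\"\"自回归贪心生成的示意:每一步把上一步的预测拼回输入。\"\"\"\n",
"    ids = prompt_ids                                  # (1, seq_len),句子的前半段\n",
"    for _ in range(max_new_tokens):\n",
"        logits = model(ids)                           # (1, seq_len, vocab_size)\n",
"        next_id = logits[:, -1, :].argmax(dim=-1)     # 只取最后一个时间步的预测\n",
"        ids = torch.cat([ids, next_id.unsqueeze(1)], dim=1)   # 拼接后作为下一步输入\n",
"        if next_id.item() == eos_id:                  # 预测出 \"eos\" 即停止\n",
"            break\n",
"    return ids"
]
},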
|
|||
|
|
{
"cell_type": "markdown",
"id": "1f02882e-87fd-4d72-97a0-1e675b228db2",
"metadata": {},
"source": [
"- **不从第一个样本开始训练的流程如何实现?**"
]
},
{
"cell_type": "markdown",
"id": "4cd986a4-9f01-4be4-b9df-d42d2ea23650",
"metadata": {},
"source": [
"<center><img src=\"https://skojiangdoc.oss-cn-beijing.aliyuncs.com/2023DL/transformer/IUCxP.png\" alt=\"描述文字\" width=\"300\">"
]
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"id": "406f0103-d27a-4934-85cc-8a2ff8c60917",
|
|||
|
|
"metadata": {},
|
|||
|
|
"source": [
|
|||
|
|
"从Decoder的掩码注意力层中输出的是经过掩码后、每一行只携带特定时间段信息的结果$C_{decoder}$:\n",
|
|||
|
|
"\n",
|
|||
|
|
"$$\n",
|
|||
|
|
"C_{decoder} = \\begin{bmatrix}\n",
|
|||
|
|
"a_{11}v_{1} & a_{11}v_{1} & \\ldots & a_{11}v_{1} \\\\\n",
|
|||
|
|
"a_{21}v_{1} + a_{22}v_{2} & a_{21}v_{1} + a_{22}v_{2} & \\ldots & a_{21}v_{1} + a_{22}v_{2} \\\\\n",
|
|||
|
|
"a_{31}v_{1} + a_{32}v_{2} + a_{33}v_{3} & a_{31}v_{1} + a_{32}v_{2} + a_{33}v_{3} & \\ldots & a_{31}v_{1} + a_{32}v_{2} + a_{33}v_{3} \\\\\n",
|
|||
|
|
"a_{41}v_{1} + a_{42}v_{2} + a_{43}v_{3} + a_{44}v_{4} & a_{41}v_{1} + a_{42}v_{2} + a_{43}v_{3} + a_{44}v_{4} & \\ldots & a_{41}v_{1} + a_{42}v_{2} + a_{43}v_{2} + a_{44}v_{4}\n",
|
|||
|
|
"\\end{bmatrix}\n",
|
|||
|
|
"$$"
|
|||
|
|
]
|
|||
|
|
},
{
"cell_type": "markdown",
"id": "cb2340e4-db1f-4649-90a3-a1038eeb9630",
"metadata": {},
"source": [
"**When we instead use the covered time steps as subscripts**, this becomes:\n",
"\n",
"$$\n",
"C_{decoder} = \\begin{bmatrix}\n",
"c_{1} & c_{1} & \\ldots & c_{1} \\\\\n",
"c_{1 \\to 2} & c_{1 \\to 2} & \\ldots & c_{1 \\to 2} \\\\\n",
"c_{1 \\to 3} & c_{1 \\to 3} & \\ldots & c_{1 \\to 3} \\\\\n",
"c_{1 \\to 4} & c_{1 \\to 4} & \\ldots & c_{1 \\to 4}\n",
"\\end{bmatrix}\n",
"$$\n",
"\n",
"<font color=\"red\">**As before, the subscripts along the feature dimension are omitted here for teaching purposes. The subscripts you now see refer only to the time / sequence-length dimension.**"
]
},
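{
"cell_type": "markdown",
"id": "c4d5e6f7-3c4d-4e5f-9a60-2c3d4e5f6071",
"metadata": {},
"source": [
"As a quick sanity check on this structure, the sketch below (random $Q$, $K$, $V$, purely illustrative) applies a look-ahead mask before the softmax: the printed attention weights are lower-triangular, so row $t$ of $AV$ is a weighted sum over $v_{1}$ through $v_{t}$ only, matching the $c_{1 \\to t}$ pattern above."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d5e6f708-4d5e-4f60-af71-3d4e5f607182",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"torch.manual_seed(0)\n",
"seq_len, d = 4, 3\n",
"Q, K, V = torch.randn(seq_len, d), torch.randn(seq_len, d), torch.randn(seq_len, d)\n",
"\n",
"scores = Q @ K.T / d ** 0.5                        # raw attention scores\n",
"mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()\n",
"scores = scores.masked_fill(mask, float('-inf'))   # hide future positions\n",
"A = F.softmax(scores, dim=-1)                      # weights a_ij, lower-triangular\n",
"print(A)\n",
"C_decoder = A @ V                                  # row t mixes v_1 ... v_t only"
]
},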
{
"cell_type": "markdown",
"id": "a0b4f1bb-5796-4528-a134-19e98a189717",
"metadata": {},
"source": [
"At this point you will notice that we have to start predicting from word 1, following this flow:\n",
"\n",
"> word 1 is used to predict word 2<br><br>\n",
"> words 1 and 2 are used to predict word 3<br><br>\n",
"> words 1, 2 and 3 are used to predict word 4"
]
},
{
"cell_type": "markdown",
"id": "775919e2-3986-4ca4-8fc7-e8b537310386",
"metadata": {},
"source": [
"In practice, however, generative settings often give the model a fair amount of information right from the start. What we really want is to use \"the first half of the sentence\" to predict \"the second half\"; we rarely train from a single word or just a few words. Instead, we tend to follow the flow below:\n",
"\n",
"> words 1:n are used to predict word n+1<br><br>\n",
"> words 1:n+1 are used to predict word n+2<br><br>\n",
"> words 1:n+2 are used to predict word n+3"
]
},
{
"cell_type": "markdown",
"id": "a64c9b94-284c-4738-a0c4-a178ea81deea",
"metadata": {},
"source": [
"How can this flow be implemented? By shifting the diagonal of the look-ahead mask matrix:"
]
},
{
"cell_type": "code",
"execution_count": 34,
"id": "b7c6fbf5-e8ca-489e-bd4c-3f32ac674d0b",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"def create_look_ahead_mask(seq_len, start_seq=1):\n",
"    # torch.triu keeps the upper-triangular part, i.e. the future positions to be masked;\n",
"    # diagonal controls where the masked region starts\n",
"    mask = torch.triu(torch.ones((seq_len, seq_len)), diagonal=start_seq)\n",
"    # mask = mask.float() * -1e9  # set the future positions to negative infinity\n",
"    return mask"
]
},
{
"cell_type": "code",
"execution_count": 35,
"id": "14030c34-efc4-405d-9e61-90ce434ceda9",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[0., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n",
"        [0., 0., 1., 1., 1., 1., 1., 1., 1., 1.],\n",
"        [0., 0., 0., 1., 1., 1., 1., 1., 1., 1.],\n",
"        [0., 0., 0., 0., 1., 1., 1., 1., 1., 1.],\n",
"        [0., 0., 0., 0., 0., 1., 1., 1., 1., 1.],\n",
"        [0., 0., 0., 0., 0., 0., 1., 1., 1., 1.],\n",
"        [0., 0., 0., 0., 0., 0., 0., 1., 1., 1.],\n",
"        [0., 0., 0., 0., 0., 0., 0., 0., 1., 1.],\n",
"        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],\n",
"        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])"
]
},
"execution_count": 35,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"create_look_ahead_mask(10)  # for clarity this shows 1s and 0s; in practice the upper-right should be -inf and the lower-left 0"
]
},
{
"cell_type": "code",
"execution_count": 36,
"id": "c60161b9-8f5c-49ab-b152-31a24bf4f76a",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[0., 0., 0., 0., 1., 1., 1., 1., 1., 1.],\n",
"        [0., 0., 0., 0., 0., 1., 1., 1., 1., 1.],\n",
"        [0., 0., 0., 0., 0., 0., 1., 1., 1., 1.],\n",
"        [0., 0., 0., 0., 0., 0., 0., 1., 1., 1.],\n",
"        [0., 0., 0., 0., 0., 0., 0., 0., 1., 1.],\n",
"        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],\n",
"        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
"        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
"        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
"        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])"
]
},
"execution_count": 36,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"create_look_ahead_mask(10, start_seq=4)  # shifting the diagonal shrinks the masked area, letting more information through"
]
},
{
"cell_type": "markdown",
"id": "d041f9e0-0620-4666-b104-a4a86eda87dd",
"metadata": {},
"source": [
"When the look-ahead mask starts at the first time step, the output of the masked attention layer covers these time steps:\n",
"\n",
"$$\n",
"C_{decoder} = \\begin{bmatrix}\n",
"c_{1} & c_{1} & \\ldots & c_{1} \\\\\n",
"c_{1 \\to 2} & c_{1 \\to 2} & \\ldots & c_{1 \\to 2} \\\\\n",
"c_{1 \\to 3} & c_{1 \\to 3} & \\ldots & c_{1 \\to 3} \\\\\n",
"c_{1 \\to 4} & c_{1 \\to 4} & \\ldots & c_{1 \\to 4}\n",
"\\end{bmatrix}\n",
"$$"
]
},
{
"cell_type": "markdown",
"id": "6b473aba-cd11-4654-89ff-561f98eaef93",
"metadata": {},
"source": [
"When the look-ahead mask starts at the 4th time step, the output of the masked attention layer covers these time steps:\n",
"\n",
"$$\n",
"C_{decoder} = \\begin{bmatrix}\n",
"c_{1 \\to 4} & c_{1 \\to 4} & \\ldots & c_{1 \\to 4} \\\\\n",
"c_{1 \\to 5} & c_{1 \\to 5} & \\ldots & c_{1 \\to 5} \\\\\n",
"c_{1 \\to 6} & c_{1 \\to 6} & \\ldots & c_{1 \\to 6} \\\\\n",
"c_{1 \\to 7} & c_{1 \\to 7} & \\ldots & c_{1 \\to 7}\n",
"\\end{bmatrix}\n",
"$$"
]
},
{
"cell_type": "markdown",
"id": "69f86f4a-81ea-4fde-a76c-cc3038ef2744",
"metadata": {},
"source": [
"This way, the labels used in the very first prediction correspond to \"the first n tokens\" rather than \"the first token\". Admittedly, this already falls into the territory of \"custom masks\" and is not common in practice. But with this masking scheme, the attention scores produced by the decoder fully capture the relationships among the first few tokens, so training can start from \"the first half of the sentence\" right away. We will walk through this flow in more detail when we implement Decoder-only prediction later."
]
},
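{
"cell_type": "markdown",
"id": "e6f70819-5e6f-4a70-b182-4e5f60718293",
"metadata": {},
"source": [
"The sketch below (illustrative, with random stand-in scores) reuses `create_look_ahead_mask` from above: masking the scores with `start_seq=4` and counting the nonzero attention weights per row shows that even the first row already covers time steps 1 through 4, and each later row sees one step more, capped at the sequence length."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f7081920-6f70-4b81-8293-5f6071829304",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"seq_len = 10\n",
"scores = torch.randn(seq_len, seq_len)                      # stand-in attention scores\n",
"mask = create_look_ahead_mask(seq_len, start_seq=4).bool()  # 1 marks masked positions\n",
"A = F.softmax(scores.masked_fill(mask, float('-inf')), dim=-1)\n",
"print((A > 0).sum(dim=-1))  # visible steps per row: 4, 5, 6, ... capped at 10"
]
},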
{
"cell_type": "markdown",
"id": "8548805f-deb6-4fe2-8309-caed8139022f",
"metadata": {},
"source": [
"Having explored the Transformer model in full, we can see that its excellence in natural language processing and other sequence tasks comes from its innovative design and efficient information processing. From the self-attention mechanism to the stacked encoder and decoder layers, every component is carefully designed to maximize the use of contextual information and to increase computational parallelism. The Transformer has not only changed how we process text; it has also given machine learning a powerful tool for a whole class of complex sequence-modeling problems.\n",
"\n",
"Self-attention lets the model flexibly capture long-range dependencies within a sequence without relying on a recurrent structure, thereby avoiding vanishing gradients and poor computational efficiency. The encoder layers extract and aggregate information by processing the input layer by layer, while the decoder layers use the encoder's output, combined with autoregression, to build the output sequence step by step. In this way, the Transformer can generate accurate and coherent text in translation, text generation, summarization and other tasks.\n",
"\n",
"Moreover, encoder-decoder attention is the key to modeling the complex relationship between input and output: it lets the model consider the specific connection to the input sequence every time it generates an output. This capability makes the Transformer suitable not only for traditional NLP tasks but also extensible to image processing and multimodal tasks, demonstrating its great flexibility and broad applicability.\n",
"\n",
"In short, the Transformer marks a major step forward for deep learning on sequence data. As research deepens and the technology matures, we can expect more Transformer-based innovations that push the boundaries of artificial intelligence further."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}