Transformer

Model Architecture

Fig.1 Transformer

The Transformer was first introduced for machine translation with an encoder-decoder structure, see Fig.1. The input (source) and output (target) sequence embeddings are summed with positional encodings before being fed into the encoder and the decoder, each of which is a stack of modules built on self-attention.

Scaled Dot-Product Attention

Fig.2 Scaled Dot-Product Attention

The Transformer uses “Scaled Dot-Product Attention”. The input consists of queries $Q$ and keys $K$ of dimension $d_k$, and values $V$ of dimension $d_v$. The attention computes the dot products of the query with all keys, divides each by $\sqrt{d_k}$, and applies a softmax function to obtain the weights on the values. The scaling factor $\sqrt{d_k}$ counteracts the effect that, for large values of $d_k$, the dot products grow large in magnitude, pushing the softmax function into regions where it has extremely small gradients.
$$
\text{Attention}(Q, K, V) = \text{softmax}(\frac{QK^T}{\sqrt{d_k}})V
$$

The PyTorch implementation is

import math

import torch
from torch import nn
import torch.nn.functional as F
import einops


class ScaledDotProductAttention(nn.Module):
    def __init__(self, d_k):
        super(ScaledDotProductAttention, self).__init__()
        self.scaling_factor = math.sqrt(d_k)

    def forward(self, q, k, v, mask=None):
        # (..., q_s, d_k) x (..., kv_s, d_k) -> (..., q_s, kv_s)
        qk = einops.einsum(q, k, '... q_s d_k, ... kv_s d_k -> ... q_s kv_s')

        if mask is not None:
            qk = qk.masked_fill(mask == 0, -1e8)

        score = F.softmax(qk / self.scaling_factor, dim=-1)
        # weighted sum of values: (..., q_s, kv_s) x (..., kv_s, d_v) -> (..., q_s, d_v)
        attn = einops.einsum(score, v, '... q_s kv_s, ... kv_s d_v -> ... q_s d_v')

        return attn
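
A quick shape check of the module above (not from the original post; the tensor sizes are arbitrary):

q = torch.randn(2, 5, 64)   # (batch, query positions, d_k)
k = torch.randn(2, 7, 64)   # (batch, key positions, d_k)
v = torch.randn(2, 7, 32)   # (batch, key positions, d_v)

attention = ScaledDotProductAttention(d_k=64)
print(attention(q, k, v).shape)  # torch.Size([2, 5, 32])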

Multi-head Attention

Fig.3 Multi-head Attention

Multi-head attention allows the model to jointly attend to information from different representation subspaces at different positions. Instead of performing a single attention function with $d_{model}$-dimensional keys, values and queries, it is beneficial to linearly project the queries, keys and values $h$ times with different, learned linear projections to $d_k, d_k$ and $d_v$ dimensions, respectively. On each of these projected versions of queries, keys and values, the attention function is performed in parallel, yielding $d_v$-dimensional output values. These are concatenated and once again projected, resulting in the final values. In the original paper, $d_k=d_v=d_{model}/h$.

The PyTorch implementation is

from einops.layers.torch import Rearrange


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super(MultiHeadAttention, self).__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        # split the last dimension into n_heads heads of size d_k = d_model // n_heads
        self.to_multi_head = Rearrange('... s (h d_k) -> ... h s d_k', h=n_heads)
        self.attention = ScaledDotProductAttention(d_model // n_heads)
        self.concat = Rearrange('... h s d_k -> ... s (h d_k)')
        self.w_o = nn.Linear(d_model, d_model)

    def forward(self, q, k, v, mask=None):
        mh_q = self.to_multi_head(self.w_q(q))
        mh_k = self.to_multi_head(self.w_k(k))
        mh_v = self.to_multi_head(self.w_v(v))

        if mask is not None:
            # broadcast the mask over heads (and over query positions for a padding mask)
            if mask.dim() == 2:   # (batch, kv_s) padding mask
                mask = einops.repeat(mask, 'b s -> b h q s', h=self.n_heads, q=q.shape[-2])
            else:                 # (batch, q_s, kv_s) attention mask
                mask = einops.repeat(mask, 'b q s -> b h q s', h=self.n_heads)

        mh_attn = self.attention(mh_q, mh_k, mh_v, mask=mask)
        attn = self.concat(mh_attn)

        o = self.w_o(attn)

        return o
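
A short usage sketch under the padding-mask convention used later in this post (0 marks padded positions; sizes are arbitrary):

mha = MultiHeadAttention(d_model=512, n_heads=8)
x = torch.randn(2, 10, 512)                    # (batch, seq, d_model)
pad_mask = torch.ones(2, 10, dtype=torch.bool)
pad_mask[:, 8:] = False                        # last two positions are padding

print(mha(x, x, x, mask=pad_mask).shape)       # torch.Size([2, 10, 512])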

Position-wise Feed-Forward Networks

Each encoder and decoder layer contains a feed-forward network that is applied to each position separately and identically. It consists of two linear transformations with a ReLU activation in between.

$$
\text{FFN}(x) = \max(0, xW_1+b_1)W_2+b_2
$$

The PyTorch implementation is

class PointWiseFFN(nn.Module):
    def __init__(self, d_model, d_ff):
        super(PointWiseFFN, self).__init__()
        self.linear_1 = nn.Linear(d_model, d_ff)
        self.act = nn.ReLU()
        self.linear_2 = nn.Linear(d_ff, d_model)

    def forward(self, input):
        return self.linear_2(self.act(self.linear_1(input)))
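
In the paper, $d_{model}=512$ and the inner dimension is $d_{ff}=2048$. A quick shape check (the sizes here are smaller and purely illustrative):

ffn = PointWiseFFN(d_model=128, d_ff=512)
x = torch.randn(2, 10, 128)
print(ffn(x).shape)  # torch.Size([2, 10, 128])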

Positional Encoding

Since the Transformer contains no recurrence and no convolution, in order for the model to make use of the order of the sequence, “positional encodings” are added to the input embeddings. The authors use sine and cosine functions of different frequencies:

$$
\begin{aligned}
PE_{(pos, 2i)} & = \sin(pos/10000^{2i/d_{model}}) \\
PE_{(pos, 2i+1)} & = \cos(pos/10000^{2i/d_{model}})
\end{aligned}
$$

As the two embedding layers and the pre-softmax linear transformation share the same weight matrix, the output of the embedding layers should be multiplied by $\sqrt{d_{model}}$. The pre-softmax linear layer is usually initialized with Xavier initialization,

$$
W \sim \mathcal{N}(0, 1/d_{model})
$$

so the shared embedding weights have entries on the order of $1/\sqrt{d_{model}}$, and the scaling brings the embeddings back to roughly unit scale before the positional encodings are added.

The PyTorch implementation is

class PositionalEncoding(nn.Module):
    def __init__(self, max_len, d_model):
        super(PositionalEncoding, self).__init__()
        self.scaling_factor = math.sqrt(d_model)
        # precompute the (1, max_len, d_model) positional encoding table
        self.p = torch.zeros((1, max_len, d_model))
        index = torch.arange(max_len, dtype=torch.float32).reshape(
            -1, 1) / torch.pow(10000, torch.arange(
                0, d_model, 2, dtype=torch.float32) / d_model)
        self.p[:, :, 0::2] = torch.sin(index)
        self.p[:, :, 1::2] = torch.cos(index)

    def forward(self, input):
        device = input.device
        seq_len = input.shape[1]
        # scale the embeddings by sqrt(d_model), then add the positional encodings
        output = input * self.scaling_factor + self.p[:, :seq_len].to(device)

        return output
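
A quick check (note that this implementation assumes an even $d_{model}$, so the sine and cosine halves of the table have the same width):

pe = PositionalEncoding(max_len=100, d_model=128)
x = torch.randn(2, 10, 128)
print(pe(x).shape)  # torch.Size([2, 10, 128])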

Encoder

The encoder is composed of a stack of identical layers. Each layer has two sub-layers. The first is a multi-head self-attention mechanism, and the second is a simple, position-wise fully connected feed-forward network. A residual connection is employed around each of the two sub-layers, followed by layer normalization. That is, the output of each sub-layer is $\text{LayerNorm}(x+\text{Sublayer}(x))$, where $\text{Sublayer}(x)$ is the function implemented by the sub-layer itself.

Dropout is applied to the output of each sub-layer, before it is added to the sub-layer input and normalized. In addition, dropout is applied to the sums of the embeddings and the positional encodings in both the encoder and decoder stacks.

The PyTorch implementation is

class EncoderBlock(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, dropout):
        super(EncoderBlock, self).__init__()
        self.sublayer_1 = MultiHeadAttention(d_model, n_heads)
        self.drop_1 = nn.Dropout(dropout)
        self.norm_1 = nn.LayerNorm(d_model)
        self.sublayer_2 = PointWiseFFN(d_model, d_ff)
        self.drop_2 = nn.Dropout(dropout)
        self.norm_2 = nn.LayerNorm(d_model)

    def forward(self, src, src_mask=None):
        # self-attention sub-layer with residual connection and layer norm
        o1 = self.sublayer_1(src, src, src, src_mask)
        src = self.norm_1(src + self.drop_1(o1))

        # feed-forward sub-layer with residual connection and layer norm
        o2 = self.sublayer_2(src)
        src = self.norm_2(src + self.drop_2(o2))

        return src


class Encoder(nn.Module):
    def __init__(self, num_blocks, vocab_size, max_len, d_model, n_heads, d_ff, dropout):
        super(Encoder, self).__init__()
        self.ebd = nn.Embedding(vocab_size, d_model, padding_idx=0)
        self.pos_encode = PositionalEncoding(max_len, d_model)
        self.drop = nn.Dropout(dropout)
        self.blks = nn.ModuleList([EncoderBlock(d_model, n_heads, d_ff, dropout) for _ in range(num_blocks)])

    def forward(self, src, src_mask=None):
        # embed, add positional encodings, then apply dropout
        src = self.drop(self.pos_encode(self.ebd(src)))
        for blk in self.blks:
            src = blk(src, src_mask)

        return src

Decoder

The decoder is also composed of a stack of identical layers. In addition to the two sub-layers in each encoder layer, the decoder inserts a third sub-layer, which performs multi-head attention over the output of the encoder stack. The self-attention sub-layer in the decoder stack is modified to prevent positions from attending to subsequent positions. This masking, combined with the fact that the output embeddings are offset by one position, ensures that the predictions for position $i$ can depend only on the known outputs at positions less than $i$.
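
As a minimal sketch (not part of the original implementation), such a subsequent-position mask can be built with torch.tril: entry $(i, j)$ is 1 iff position $i$ may attend to position $j \le i$, and zero entries are filled with a large negative value before the softmax. The helper name below is hypothetical:

def subsequent_mask(seq_len, device=None):
    # lower-triangular (seq_len, seq_len) boolean mask for causal self-attention
    return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device))

print(subsequent_mask(4).int())
# tensor([[1, 0, 0, 0],
#         [1, 1, 0, 0],
#         [1, 1, 1, 0],
#         [1, 1, 1, 1]], dtype=torch.int32)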

The PyTorch implementation is

class DecoderBlock(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, dropout):
        super(DecoderBlock, self).__init__()
        self.sublayer_1 = MultiHeadAttention(d_model, n_heads)
        self.drop_1 = nn.Dropout(dropout)
        self.norm_1 = nn.LayerNorm(d_model)
        self.sublayer_2 = MultiHeadAttention(d_model, n_heads)
        self.drop_2 = nn.Dropout(dropout)
        self.norm_2 = nn.LayerNorm(d_model)
        self.sublayer_3 = PointWiseFFN(d_model, d_ff)
        self.drop_3 = nn.Dropout(dropout)
        self.norm_3 = nn.LayerNorm(d_model)

    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        # masked self-attention over the target sequence
        o1 = self.sublayer_1(tgt, tgt, tgt, tgt_mask)
        tgt = self.norm_1(tgt + self.drop_1(o1))

        # cross-attention: queries come from the decoder, keys/values from the encoder output
        o2 = self.sublayer_2(tgt, src, src, src_mask)
        tgt = self.norm_2(tgt + self.drop_2(o2))

        # position-wise feed-forward network
        o3 = self.sublayer_3(tgt)
        tgt = self.norm_3(tgt + self.drop_3(o3))

        return tgt


class Decoder(nn.Module):
    def __init__(self, num_blocks, vocab_size, max_len, d_model, n_heads, d_ff, dropout):
        super(Decoder, self).__init__()
        self.ebd = nn.Embedding(vocab_size, d_model)
        self.pos_encode = PositionalEncoding(max_len, d_model)
        self.drop = nn.Dropout(dropout)
        self.blks = nn.ModuleList([DecoderBlock(d_model, n_heads, d_ff, dropout) for _ in range(num_blocks)])

    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        # embed, add positional encodings, then apply dropout
        tgt = self.drop(self.pos_encode(self.ebd(tgt)))
        for blk in self.blks:
            tgt = blk(src, tgt, src_mask, tgt_mask)

        return tgt

Transformer PyTorch Implementation

class Transformer(nn.Module):
    def __init__(self, num_blocks, src_vocab_size, tar_vocab_size, src_max_len, tar_max_len, d_model, n_heads, d_ff, dropout):
        super(Transformer, self).__init__()
        self.encoder = Encoder(num_blocks, src_vocab_size, src_max_len, d_model, n_heads, d_ff, dropout)
        self.decoder = Decoder(num_blocks, tar_vocab_size, tar_max_len, d_model, n_heads, d_ff, dropout)
        self.fc = nn.Linear(d_model, tar_vocab_size)
        # weight sharing between the pre-softmax projection and the embedding layers
        self.fc.weight = self.decoder.ebd.weight
        if src_vocab_size == tar_vocab_size:
            self.encoder.ebd.weight = self.decoder.ebd.weight

    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        if src_mask is None:
            # padding mask: token id 0 is treated as padding
            src_mask = src > 0
        if tgt_mask is None:
            # padding mask combined with the subsequent-position (causal) mask
            causal = torch.tril(torch.ones(tgt.size(1), tgt.size(1), dtype=torch.bool, device=tgt.device))
            tgt_mask = (tgt > 0).unsqueeze(1) & causal

        enc_output = self.encoder(src, src_mask)
        dec_output = self.decoder(enc_output, tgt, src_mask, tgt_mask)

        # per-position distribution over the target vocabulary
        output = F.softmax(self.fc(dec_output), dim=-1)

        return output
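
An end-to-end sketch with toy token ids (all hyper-parameters and the shared vocabulary size are arbitrary choices for illustration; token id 0 is reserved for padding):

model = Transformer(num_blocks=2, src_vocab_size=1000, tar_vocab_size=1000,
                    src_max_len=50, tar_max_len=50, d_model=128, n_heads=8,
                    d_ff=512, dropout=0.1)

src = torch.randint(1, 1000, (2, 12))   # (batch, src_len)
tgt = torch.randint(1, 1000, (2, 9))    # (batch, tgt_len)

print(model(src, tgt).shape)  # torch.Size([2, 9, 1000]): one distribution over the target vocabulary per position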
