In Part 1 of the tutorial we learned how to how build a transformer model. Now it is time to put it into action. There are lots of things we can do with a transformer but we will start off with a machine translation task to translate from Spanish to English.

The code for Parts 1 and 2 of this tutorial can be found in this Colab notebook.

Note: this a repost with minor modifications from Annotaited, where it was initially published on 19th February 2021.



We used the Adam optimizer [20] with $\beta_1 = 0.9$, $\beta_2 = 0.98$ and $\epsilon = 10$.

We varied the learning rate over the course of training, according to the formula:

\[\text{lrate}=d_\text{model}^{−0.5} \cdot \min(\text{step_num}^{−0.5}, \text{step_num}\cdot\text{warmup_steps}^{−1.5})\]

This corresponds to increasing the learning rate linearly for the first warmup_steps training steps, and decreasing it thereafter proportionally to the inverse square root of the step number. We used $\text{warmup_steps} = 4000$.

Create a custom schedule using tf.keras.optimizers.schedules.LearningRateSchedule. It should have the following methods:

__init__ Initialise the hyperparameters
__call__ Receives a value step as input and returns the learning rate according to the above equation
class LearningRateScheduler(tf.keras.optimizers.schedules.LearningRateSchedule):
    def __init__(self, d_model, warmup_steps):
        self._d_model = d_model
        self._warmup_steps = warmup_steps
        self.d_model_term = d_model ** (-0.5)
        self.warmup_term = warmup_steps ** (-1.5)

    def get_config(self):
        return dict(d_model=self._d_model, warmup_steps=self.warmup_steps)

    def __call__(self, step):
        step_term = tf.math.minimum(step ** (-0.5), step * self.warmup_term)
        return self.d_model_term * step_term

Try plotting the learning rate curve for a few different values of d_model and warmup_steps to get a sense of how it evolves over time.

plt.figure(figsize=(8, 8))
t = tf.range(20000, dtype=tf.float32)
d_model = np.stack([128, 256, 512])
warmup_steps = [2000, 4000, 6000]

for cmap, w in zip(['Reds', 'Greens', 'Blues'], warmup_steps):
  cm =
  for i, d in enumerate(d_model, 1):

    lr_wd = LearningRateScheduler(
    val = lr_wd(t).numpy()
    clr = cm(int(i * cm.N / len(d_model)))
    plt.plot(t.numpy(), val, label=f'warmup_steps={w}, d_model={d}', c = clr)

plot showing learning rate for d_model={128, 256, 512} x warmup_steps={2000, 4000, 6000}
for 20000 steps


The model is trained with a softmax loss over all the tokens in the vocabulary with label smoothing.

Label smoothing

During training, we employed label smoothing of value $\epsilon_{ls}$ = 0.1.

A very brief guide to label smoothing

  • The target class distribution $q(k\vert x)$ over $K$ classes is the one-hot distribution where $q(k\vert x_i) = \delta_{k,y_i}$.
  • To do label smoothing we subtract a small quantity $\epsilon_{ls}$ from $q(y_i\vert x)$ i.e. the 1 in the one-hot vector and divide this over all the other classes.
  • Typically it is divided uniformly over the classes as follows

    \[q'(k\vert x_i) = \delta_{k,y_i} * (1 - \epsilon_{ls}) + \epsilon_{ls}/K\]
  • The idea is that by smoothing the target distribution you encourage the model to be less confident about its predictions.

Write a function smooth_labels which implements the equation above. The inputs will be a tensor of one-hot labels of shape [..., K] and the smoothing parameter eps.

def smooth_labels(labels, eps=0.1):
    num_classes = tf.cast(tf.shape(labels)[-1], labels.dtype)
    labels = (labels - eps) + eps / num_classes
    return labels


The idea is that we should average the losses across all sequences and timesteps:

\[L = \frac{1}{N \cdot T}\sum_{n=0}^{N-1}\sum_{t=0}^{T-1}\text{loss}(\hat{y}_{nt}, y_{nt})\]

However we should remember to exclude the zero padded elements. As we shall see, at inference time as soon as the model predicts an <end> we stop predicting further. So the model should not be trained to predict anything after the <end> symbol.

Write a MaskedLoss class that wraps a loss function and when called returns a masked average of the loss.

class MaskedLoss(object):
    def __init__(self, loss_fn):
        self.loss_fn = loss_fn

    def __call__(self, y_true, y_pred, mask, **kwargs):
        loss = self.loss_fn(y_true=y_true, y_pred=y_pred, **kwargs)
        mask = tf.cast(mask, loss.dtype)
        return tf.reduce_sum(loss * mask) / tf.reduce_sum(mask)


It is convenient to store all the training and inference settings and hyperparameters in a single config object. For simplicity we will use an EasyDict, which, if you have not previously used one, behaves just like an ordinary dict except that its keys can be accessed like attributes i.e. config.key1 returns config['key1'].

Here is our config. We will add more settings to it as we go along

config = EasyDict()
config.smoothing_prob = 0.1 = EasyDict() = 0


Now that all the pieces of the training are in place we can prepare the data. We will be using the crowdsourced Many Things Spanish-English dataset which consists of sentences and phrases in these two languages. Datasets are consisting of translations between many different languages and English available from the website so feel free to swap in another language. Depending on what language you use you might need to modify the preprocessing code to handle special characters.

Download the data

Download the data from here and save it to a directory of your choice and unzip it. If using Colab I recommend that you connect it to Google Drive so that your that your files will be saved even when the runtime gets disconnected.

You can do so by running these lines

from google.colab import drive

A link will appear that you need to click to get an authentication code which you input into a box below the link.

You also download and unzip the data by running:

# Modify the path to your save path
SAVE_PATH = "/content/drive/My Drive/transformer-tutorial"
if not os.path.exists(SAVE_PATH):

DATA_DIR = os.path.join(SAVE_PATH, "data")
if not os.path.exists(DATA_DIR):

!wget -P "{DATA_DIR}"
!unzip "{DATA_DIR}/"  -d "{DATA_DIR}/spa-eng"

DATA_PATH = os.path.join(DATA_DIR, "spa-eng")


First we will add some functions that will do some basic processing to get rid of special characters, trim whitespace and separate punctuation from words.

def unicode_to_ascii(s):
  return ''.join(c for c in unicodedata.normalize('NFD', s)
    if unicodedata.category(c) != 'Mn')
def preprocess_sentence(w):
  w = unicode_to_ascii(w.lower().strip())

  # Put spaces between words and punctuation
  w = re.sub(r"([?.!¿])", r" \1 ", w)
  # Reduce several whitespaces to a single one 
  w = re.sub(r'[" "]+', " ", w)

  # Replace characters with whitespace unless in (a-z, A-Z, ".", "?", "!", ",")
  w = re.sub(r"[^a-zA-Z?.!,¿]+", " ", w)

  w = w.strip()

  # Add start / end tokens
  w = ' '.join(['<start>', w, '<end>'])
  return w 

Now we can load the data and preprocess the sequences. The data consists of tab-separated columns of text as shown below. We only ned the first two columns which contain pairs of Spanish and English sequences. The second column is English and the first column is the other language.

def load_dataset(path, num_examples):
  with open(path, encoding='UTF-8') as f:
    lines ='\n')
  word_pairs = [[preprocess_sentence(w) for w in l.split('\t')[:2]] for l in lines[:num_examples]]
  return zip(*word_pairs)

A few randomly sampled pairs from the processed dataset. A character called “Tom” seems to feature in a lot of the examples.

spanish: <start> da la impresion de que has estado llorando . <end>
english: <start> it looks like you ve been crying . <end>

spanish: <start> tom y mary hicieron lo que se les dijo . <end>
english: <start> tom and mary did what they were told . <end>

spanish: <start> cuando no queda nada por hacer, ¿ que haces ? <end>
english: <start> when there s nothing left to do, what do you do ? <end>

spanish: <start> llevo toda la semana practicando esta cancion . <end>
english: <start> i ve been practicing this song all week . <end>

spanish: <start> tom era feliz . <end>
english: <start> tom was happy . <end>

spanish: <start> llamalo, por favor . <end>
english: <start> please telephone him . <end>

spanish: <start> vere que puedo hacer . <end>
english: <start> i ll see what i can do . <end>

spanish: <start> hubo poca asistencia a la conferencia . <end>
english: <start> a few people came to the lecture . <end>

spanish: <start> mi papa no lo permitira . <end>
english: <start> my father won t allow it . <end>

spanish: <start> ¿ como fue el vuelo ? <end>
english: <start> how was the flight ? <end>

Splits and tokenisation

To create inputs that we can feed into the model we need to do the following:

  • Create train-val-test splits
  • Split the sentences into individual tokens (words and punctuation in this case)
  • Create mapping from tokens to numbers based on the training split for each language
  • Use this pair of mappings to convert all the all data into numerical sequences that will be fed in to the embedding layers that we have defined earlier

First let us write a function that creates a tokeniser, fits it on the training set and returns it to use subsequently. The validation and test sets might contain words not present in the training set and this is handled by replacing this with an <unk> token.

def get_tokenizer(lang, num_words=None):
  lang_tokenizer = tf.keras.preprocessing.text.Tokenizer(
  return lang_tokenizer

We will split the data into train/val/test for input and target, using 70% of the data for training and split the remaining examples into two equal parts for val and test.

def load_split(path, num_examples=None, seed=1234):
    tar, inp = load_dataset(path, num_examples)
    inp_trn, inp_valtest, tar_trn, tar_valtest = train_test_split(inp, tar, test_size=0.3, 
    inp_val, inp_test, tar_val, tar_test = train_test_split(inp_valtest, tar_valtest, test_size=0.5,  
     # Delete to avoid returning them by mistake 
    del inp_valtest, tar_valtest
    return (inp_trn, inp_val, inp_test), (tar_trn, tar_val, tar_test)

To build our tokenised datasets, we will create a tokeniser for each language, use it to process the data and return these along with the splits. We also return the tokeniser as we need to hold on it to use later to map the model outputs into words.

def create_dataset(inp, tar, num_inp_words=None, num_tar_words=None):
  inp_tokenizer = get_tokenizer(inp[0], num_words=num_inp_words)
  tar_tokenizer = get_tokenizer(tar[0], num_words=num_tar_words)

  inp_trn_seq, inp_val_seq, inp_test_seq = [
    inp_tokenizer.texts_to_sequences(inp_split) for inp_split in inp

  tar_trn_seq, tar_val_seq, tar_test_seq = [
    tar_tokenizer.texts_to_sequences(tar_split) for tar_split in tar

  inputs = dict(
      train=inp_trn_seq, val=inp_val_seq, test=inp_test_seq

  targets = dict(
      train=tar_trn_seq, val=tar_val_seq, test=tar_test_seq

  return inputs, targets, inp_tokenizer, tar_tokenizer

Let us try out the tokeniser:

inp, tar = load_split(os.path.join(DATA_PATH, 'spa.txt'), num_examples=None, seed=1234)
tmp_inp_tkn = get_tokenizer(inp[0])
tmp_tar_tkn = get_tokenizer(tar[0])

print(inp[0][0], tar[0][0])
tmp_inp_tkn.texts_to_sequences([inp[0][0]]), tmp_tar_tkn.texts_to_sequences([tar[0][0]])
<start> no son idiotas . <end> <start> they re not stupid . <end>
([[2, 8, 65, 3396, 4, 3]], [[2, 48, 50, 41, 448, 4, 3]])

To keep our model relatively small we will use just the top 10000 words in each language = 10000 = 10000

Let us generate the data splits and save them along with the tokenizer to be able to reuse them later without having to regenerate them.

# If path doesn't exist create dataset or let RECREATE_DATASET = True to
# force creation
if not os.path.exists(os.path.join(DATA_PATH, 'splits.pkl')) or RECREATE_DATASET:
  print('Creating dataset')
  inputs, targets, inp_tkn, tar_tkn = create_dataset(
      inp, tar ,,
  for k, v in inputs.items():
    print(k, len(v), len(targets[k])) = os.path.join(DATA_PATH, 'splits.pkl')
  with open(, 'wb') as f:

The vocabulary of num_words includes <start>, <end>, <unk> (for unknown word) and implicitly includes padding. tf.keras.preprocessing.text.Tokenizer does not have an explict token for padding but starts the the token ids from 1 so that when we use the tokeniser below to decode the predictions we need to get rid of the padding first.

print(inp_tkn.index_word.get(0), tar_tkn.index_word.get(0))
None None
['enloquecido'] ['harbored']
['<unk>'] ['<unk>']

Then we can load the splits.

with open(, 'rb') as f:
  data = pickle.load(f)

inp_trn = data['inputs']['train']
inp_val = data['inputs']['val']

tar_trn = data['targets']['train']
tar_val = data['targets']['val']

inp_tkn = data['inp_tkn']
tar_tkn = data['tar_tkn']

Here is how to decode the input sequence. Decoding can sometimes be lossy when there are words not present in the vocabulary, as they get converted to <unk>

def convert(tensor, tokenizer):
  for t in tensor:
    if t != 0:
      print(f'{t} -----> {tokenizer.index_word[t]}')

convert(inp_trn[0], inp_tkn)
convert(tar_trn[0], tar_tkn)
2 -----> <start>
8 -----> no
65 -----> son
3396 -----> idiotas
4 -----> .
3 -----> <end>

2 -----> <start>
48 -----> they
50 -----> re
41 -----> not
448 -----> stupid
4 -----> .
3 -----> <end>

Data pipeline

Now that our data is all prepared let us build the input pipeline using = 64
# - Data will be padded to the longest sequence length in the batch
# - There are ways to optimise this such that groups of sequences of similar length 
#   are batched together so that computation is not wasted on the padded elements
#   but we won't worry about that for now as we are dealing 
#   with quite a small dataset and model.
padded_shapes = ([None], [None])
buffer_size = len(inp_trn)
train_dataset =, 
train_dataset = train_dataset.shuffle(buffer_size).batch(
val_dataset =, 
val_dataset = val_dataset.batch(

Although called the target, the sequence for the target language is used both as input to the decoder and as target for the final prediction. For sequence of length $T$, the first $T-1$ elements of the sequence are input to the decoder which is trained to predicts the last $T-1$ elements. For convenience and to avoid mistakes let us write a simple function that splits the target into two parts tar_inp and tar_real. Assume that the target is of arbitrary rank but the sequence or “time” dimensions is the last dimension e.g. it could have shape [batch_size, T] or it could just have shape [T]

def split_target(target):
    tar_real = target[..., :-1]
    tar_inp = target[..., 1:]
    return tar_ral, tar_inp


In this tutorial we won’t go into the metrics like BLEU that are used for evaluating translation models as they merit a tutorial in themselves but use loss and accuracy to get a sense of how well the model is doing. The caveat is that accuracy is not the best metric to use when you don’t have a balanced dataset - and a vocabulary is not particularly well-balanced - since the model can do badly on infrequently occuring classes (here words) and that won’t really be reflected in the accuracy. Moreover iit only deals with pointwise comparisons so doesn’t take into account the sequential nature of the data.

However for this simple application, it will give us a reasonable idea of whether the model is doing well or badly. Once we have trained the model we will also make some predictions and evaluate them qualitatively.

Here is the masked accuracy function we will be using:

def accuracy_function(real, pred, pad_mask):
    pred_ids = tf.argmax(pred, axis=2)
    accuracies = tf.cast(tf.equal(tf.cast(real, pred_ids.dtype), pred_ids), tf.float32)
    mask = tf.cast(pad_mask, dtype=tf.float32)
    return tf.reduce_sum(accuracies * mask)/tf.reduce_sum(mask)

Set up the experiment

Add some more fields to config. Note that in the interests of speed we are using quite small model. Feel free with change these settings and try different experiments.

config.model = EasyDict()
config.model.model_dim = 128
config.model.ff_dim = 512
config.model.num_heads = 8
config.model.num_encoder_blocks = 4
config.model.num_decoder_blocks = 4
config.model.dropout = 0.1

config.smoothing_prob = 0.1

Let us initialise instances of the following using the settings from config:

  • LearningRateScheduler
  • MaskedLoss
  • SequenceMask
  • TargetMask
  • Transformer
  • An optimiser with the following specification

    We used the Adam optimizer [20] with $\beta_1 = 0.9, \beta_2 = 0.98 and \epsilon = 10$.

lr = LearningRateScheduler(

loss_function = MaskedLoss(
            from_logits=True, reduction="none",

pad_masker = SequenceMask(
future_masker = TargetMask(

model = Transformer(

optim = tf.optimizers.Adam(

Training loop

We will write a simple class Trainer that will hold together all these components and which will have methods for running training and validation steps.

class Trainer(tf.Module):
    def __init__(self,
                 model: tf.keras.models.Model,
        self.config = config
        self.model = model
        self.pad_masker = pad_masker
        self.future_masker = future_masker
        self.loss_function = loss_function
        self.optim = optim
        self.train_loss = tf.keras.metrics.Mean(name='train_loss')
        self.train_accuracy = tf.keras.metrics.Mean(name='train_accuracy')
        self.val_loss = tf.keras.metrics.Mean(name='val_loss')
        self.val_accuracy = tf.keras.metrics.Mean(name='val_accuracy')

        tf.TensorSpec(shape=[None, None], dtype=tf.int32),
        tf.TensorSpec(shape=[None, None], dtype=tf.int32)
    def train_step(self, inp, tar):
        # TODO: Implement this

            tf.TensorSpec(shape=[None, None], dtype=tf.int32),
            tf.TensorSpec(shape=[None, None], dtype=tf.int32)
    def valid_step(self, inp, tar):
        # TODO: Implement this

Here is what train_step should do:

  • Run the model and make predictions
  • Calculate the loss
  • Calculate the loss gradients
  • Pass the gradients to optimise and get it to update the weights
  • Update train_loss and train_accuracy (using accuracy_function from above for the latter)


  • If you have not written a custom training loop with TensorFlow before take a look at the train_step code here. You need to do something quite similar, just using the transformer as the model.
  • Remember to add label smoothing for the labels before calculating the loss
  • Use label smoothing only for the training loss, not for validation and not for calculating the accuracies either
  • Remember to squeeze the singleton dimensions of the pad mask before passing to accuracy function i.e. from ((B, 1, 1, N) to (B, N))
        tf.TensorSpec(shape=[None, None], dtype=tf.int32),
        tf.TensorSpec(shape=[None, None], dtype=tf.int32)
    def train_step(self, inp, tar):
        tar_inp, tar_real = split_target(tar)
        tar_pad_mask = self.pad_masker(tar_real)[:, 0, 0, :]

        with tf.GradientTape() as tape:
            predictions, _ = self.model(
                inp, tar_inp,

            tar_onehot = tf.one_hot(tar_real, depth=tf.shape(predictions)[-1])

            if self.config.smoothing_prob > 0:
                labels = smooth_labels(tar_onehot, self.config.smoothing_prob)
                labels = tar_onehot

            loss = self.loss_function(

        gradients = tape.gradient(loss, self.model.trainable_variables)
        self.optim.apply_gradients(zip(gradients, self.model.trainable_variables))

        self.train_accuracy(accuracy_function(tar_real, predictions, tar_pad_mask))

Now write valid_step which will be very similar to train_step but with the gradient steps. Remember to call the model with training=False.

            tf.TensorSpec(shape=[None, None], dtype=tf.int32),
            tf.TensorSpec(shape=[None, None], dtype=tf.int32)
    def valid_step(self, inp, tar):
        tar_inp, tar_real = split_target(tar)
        tar_pad_mask = self.pad_masker(tar_real)[:, 0, 0, :]

        predictions, _ = self.model(
                inp, tar_inp,

        loss = self.loss_function(
            y_true=tf.one_hot(tar_real, depth=tf.shape(predictions)[-1]),

        self.val_accuracy(accuracy_function(tar_real, predictions, tar_pad_mask))


Some code to save checkpoints. Here we are only saving a single checkpoint corresponding to the lowest validation loss but feel free to change that.

CKPT_PATH = os.path.join(SAVE_PATH, 'checkpoints') 
config.ckpt_path = CKPT_PATH
ckpt = tf.train.Checkpoint(
ckpt_manager = tf.train.CheckpointManager(
    ckpt, config.ckpt_path, max_to_keep=1
if ckpt_manager.latest_checkpoint:
    print('Latest checkpoint restored')


Now add a few more settings to config and we are ready to run.

config.epochs = 10
config.log_every = 100 # how frequently to print metrics

If you want this to run fast connect to a GPU runtime by going to Runtime > Change Runtime Type from the top menu and selecting “GPU”. The code has not been carefully optimised to run fast but it took between 15-30 minutes to run for 10 epochs (presumably depending on what GPU I got).

After 10 epochs you should observe a validation accuracy of around 79%.

best_loss = float('inf')
trainer = Trainer(config, model, pad_masker, future_masker, loss_function, optim)
for epoch in range(config.epochs):
    start = time.time()
    for (batch, (inp, tar)) in enumerate(train_dataset):
        trainer.train_step(inp.to_tensor(), tar.to_tensor())
        if batch % config.log_every == 0:
            print (f'Epoch {epoch + 1} Batch {batch}'
                  f' Loss {trainer.train_loss.result(): .4f}',
                  f' Accuracy {trainer.train_accuracy.result(): .4f}')
    for (batch, (inp, tar)) in enumerate(val_dataset):
        trainer.valid_step(inp.to_tensor(), tar.to_tensor())
    val_loss = trainer.val_loss.result()
    if val_loss < best_loss:
      print(f"Validation loss decreased from {best_loss} to {val_loss}. Saving checkpoint")
      best_loss = val_loss

    print(f'Epoch {epoch + 1}',
                      f' Loss {trainer.train_loss.result(): .4f}',
                      f' Accuracy {trainer.train_accuracy.result(): .4f}',
                      f' Val loss {val_loss: .4f}',
                      f' Val accuracy {trainer.val_accuracy.result(): .4f}')
    print(f'Time taken for 1 epoch: {time.time() - start} secs')


So our model has trained for a bit. The loss has dropped and the accuracy has risen but what do the translations actually look like. To answer that question we need to write one more function. Write predict_sequence that takes a single input sequence and returns the predicted sequence along with attention maps.


  • To start the prediction input a start_symbol i.e. inp_tkn.word_index["<start>"]
  • At each step you should select the token which has the highest softmax prediction, append that to the existing translation and pass that in as the input to decoder.
  • Remember to stop predicting for a sequence when an <end> token has been predicted.
  • Remember that we defined encode and decode methods for the model. Use these instead of call to avoid having to predict the encoder output repeatedly.
  • However the decoder must be re-run for every step.
def predict_sequence(inputs, model, max_length,
                     start_symbol, end_symbol,
                     pad_masker, future_masker):
    inputs = inputs[None]
    result = tf.ones_like(inputs[:, :1]) * start_symbol
    encoding, enc_attn = model.encode(inputs, pad_masker(inputs), training=False)
    for _ in tf.range(max_length):
        predictions, self_attn, dec_attn = model.decode(result, encoding,
        # Just select the last sequence element and the symbol
        # with highest probability
        next_symbol = tf.argmax(predictions[:, -1], axis=-1, output_type=tf.int32)
        result = tf.concat([result, next_symbol[:, None]], axis=-1)

        # If sequence is done, stop
        if tf.equal(tf.squeeze(next_symbol), end_symbol):

    attention = dict(
        enc=tf.squeeze(tf.stack(enc_attn), axis=1),
        dec_self=tf.squeeze(tf.stack(self_attn), axis=1),
        dec_memory=tf.squeeze(tf.stack(dec_attn), axis=1)

    return tf.squeeze(result, axis=0), attention

Define a shuffled version of the validation dataset and translate a few sample sentences

val_sample_dataset =, 
val_sample_dataset = val_sample_dataset.shuffle(len(inp_val))
[inp_sample, tar_sample] = next(iter(val_sample_dataset))
pred, attn = predict_sequence(inp_sample, model, 50, inp_tkn.word_index['<start>'],
                 inp_tkn.word_index['<end>'], pad_masker, future_masker)

inp_sent_full = inp_tkn.sequences_to_texts(inp_sample.numpy()[None])[0]
tar_pred_sent_full = tar_tkn.sequences_to_texts(pred.numpy()[None])[0]
tar_true_sent_full = tar_tkn.sequences_to_texts(tar_sample.numpy()[None])[0]

print("spanish:", inp_sent_full)
print('english_pred:', tar_pred_sent_full)
print('english_true:', tar_true_sent_full)
spanish: estas personas dijeron que la guerra era una guerra civil .
english_pred: these people said war was a war .
english_true: these people said the war was a civil war .

spanish: dime donde pongo estos libros .
english_pred: tell me where i put these books .
english_true: tell me where to put these books .

spanish: necesito <unk> las piernas .
english_pred: i need to brush my legs .
english_true: i need to stretch my legs .

spanish: traeme una toalla .
english_pred: bring me a towel .
english_true: get me a towel .

spanish: tom dijo que tenia prisa .
english_pred: tom said he was in a hurry .
english_true: tom said that he was in a hurry .

Attention maps

Remember that we are also returning the attention maps. Since this architecture is literally “all” about attention we should take a look at those as well. Here is some code that will plot the attention maps for an input and output sequence.

def plot_attention_head(inp_sentence, tar_sentence, attention):
  # The attention map rows correspond to the last N - 1 tokens which
  # were predicted by the model. This does not include the '<start>'
  # symbol, which the model didn't generate.
  tar_sentence = tar_sentence[1:]

  ax = plt.gca()
  ax.matshow(attention, cmap='hot', vmin=0, vmax=1)

  labels = [label for label in inp_sentence]
      labels, rotation=90)

  labels = [label for label in tar_sentence]

This plots the first memory attention map from the final decoder head. Try modifying it to plot maps from other locations as well as self-attention maps.


plot showing attention weights for a single head

Now let us plot all the memory attention maps from the final decoder head.

def plot_attention_weights(inp_sentence, tar_sentence, attention_heads):
  fig = plt.figure(figsize=(16, 8))

  for h, head in enumerate(attention_heads):
    ax = fig.add_subplot(2, 4, h+1)

    plot_attention_head(inp_sentence, tar_sentence, head)

    ax.set_xlabel('Head {}'.format(h+1))


plot showing attention weights for all heads from a single decoder attention layer

What’s next

There are many ways this model can be extended. We can modify the architecture or use a different dataset. Metrics like BLEU will give us a better idea of the model’s performance. At inference time we are using a simple greedy approach choosing just the best prediction at each timestep which might not necessary lead to the best overall translation. In the paper they use beam search which stores a small number of thee top predictions at each stage searches for the best overall translation among these and this typically yields better results.

This is only the tip of the iceberg with regard to what Transformers are and what they can do. Many extensions to the original architecture have been developed and there are many more applications outside of NLP (such as image recognition and image generation). I have just linked to a handful that came to mind but there are many more.

I hope to cover some of these extensions through future tutorials. In the meantime I encourage you to experiment with making your own changes to the model.


This tutorial like the previous was inspired by The Annotated Transformer.

It also borrows and adapts code for data preparation, saving, plotting and other utility functions from the following sources:

  1. TensorFlow’s RNN-based translation tutorial:
  2. TensorFlow’s Transformer tutorial:


  1. Attention Is All You Need
  2. The Annotated Transformer
  3. Transformer model for language understanding
  4. Neural machine translation with attention