The DALL-E API from OpenAI is a very powerful image generation tool. However, one limitation is that it only outputs square images. In this tutorial we will learn how to generate images of arbitrary dimensions, such as portraits and landscapes. Along the way we will also cover the basics of using DALL-E via the openai library in Python. To follow along, make sure you have an account for using the OpenAI API.
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from skimage.util import montage
import requests
import os
import openai
from io import BytesIO
openai.api_key = "YOUR_API_KEY"
Here are the key details. The API offers three endpoints: generations, variations and edits.

- generations lets you create images based on a prompt.
- edits allows you to input a mask and modify a section of the image.
- variations creates different variations of the input image.

Images come in three sizes: 256 x 256, 512 x 512 and 1024 x 1024, with cost increasing with size.

To get started let us generate a few images based on a prompt. John Singer Sargent was a portrait painter in the late 19th and early 20th century, famous for his glamorous high-society portraits. Here for instance is his famous painting of fashionable hostess Mrs Hugh Hammersley.
Most of his subjects were European or American although he occasionally depicted other nationalities. Let us see if DALL-E can give us an idea of how he might have painted an Indian woman of that period.
prompt = "Portrait of an Indian lady in a green sari, in the style of John Singer Sargent, oil on canvas"
results = openai.Image.create(
    prompt=prompt,
    n=9,
    size='256x256'
)
The API call returns a list of image URLs from which we need to download the images. Let us write a function to do that.
def get_image(image_url):
    # Send a GET request to the URL and store the response
    response = requests.get(image_url)
    # Check if the request was successful
    if response.status_code == 200:
        # Load the image content via an in-memory buffer
        image_buffer = BytesIO(response.content)
        return Image.open(image_buffer)
    else:
        raise ValueError(f'Image could not be retrieved from {image_url}')
Fetch the images and display them as a montage
# url for image_i is in result['data'][i]
images = [get_image(data['url']) for data in results['data']]
m = montage([np.array(img) for img in images], padding_width=5, channel_axis=-1)
Image.fromarray(m)
Many of these look like cropped sections from a full length portrait, as indeed the training data would have been. But we can generate a larger, rectangular image using sliding windows. The initial image is regarded as a section of a larger image and used as context to generate the rest of the image. As a simple example let us extend the image at top and bottom.
First add padding to the top and bottom.
img_id = 3
img = np.array(images[img_id])
h, w, f = img.shape
rectangle = np.zeros([int(h * 1.5), w, f], dtype='uint8')
start = h // 4
rectangle[start:start + h] = img
Image.fromarray(rectangle)
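The arithmetic here generalises to any extension ratio: the canvas is 1.5x the original height and the original sits at vertical offset h//4, leaving equal blank strips above and below. A quick standalone check, with a synthetic flat-coloured array standing in for the API output:

```python
import numpy as np

# Synthetic stand-in for a 256x256 RGB image from the API
h, w, f = 256, 256, 3
img = np.full([h, w, f], 128, dtype='uint8')

# Canvas 1.5x taller, with the original centred vertically
canvas = np.zeros([int(h * 1.5), w, f], dtype='uint8')
start = h // 4
canvas[start:start + h] = img

# Equal blank strips of height h//4 remain at the top and bottom
assert canvas.shape == (384, 256, 3)
assert (canvas[:start] == 0).all() and (canvas[-start:] == 0).all()
assert (canvas[start:start + h] == 128).all()
```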
We will now use the edits endpoint to fill in the blank regions. First let us split the image into two squares.
top = rectangle[:h]
bottom = rectangle[-h:]
Image.fromarray(montage([top, bottom], padding_width=5, grid_shape=(1, 2), channel_axis=-1))
To use the edits endpoint you need to indicate which regions of the input should be modified. There are two ways to do this:

- mask input: an additional image whose fully transparent areas (e.g. where alpha is zero) indicate where the image should be edited. Must be a valid PNG file, less than 4MB, and have the same dimensions as image.
- image: if mask is not provided, image must have transparency, which will be used as the mask.

You also need to provide a prompt. The prompt should describe the full image, not only the section you want to modify.
Let us create masks for each section. The masks are in the RGBA format.
# Top part is all zeros (transparent)
mask_top = np.concatenate([np.zeros([h//4, w, f+1]),
                           255*np.ones([h-h//4, w, f+1])]).astype('uint8')
# Bottom part is all zeros
mask_bottom = mask_top[::-1]
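As a sanity check, the two masks are vertical mirror images of each other: the zero-alpha (editable) strip covers the top quarter of mask_top and the bottom quarter of mask_bottom. A standalone check, assuming the same shapes as above (h = w = 256, f = 3):

```python
import numpy as np

h, w, f = 256, 256, 3

# RGBA masks: alpha == 0 marks the region DALL-E may repaint
mask_top = np.concatenate([np.zeros([h//4, w, f+1]),
                           255 * np.ones([h - h//4, w, f+1])]).astype('uint8')
mask_bottom = mask_top[::-1]

# Editable strip: top quarter of mask_top, bottom quarter of mask_bottom
assert (mask_top[:h//4, :, -1] == 0).all()
assert (mask_top[h//4:, :, -1] == 255).all()
assert (mask_bottom[-h//4:, :, -1] == 0).all()
```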
For the first image we will provide a separate mask. We need to write the mask and image to buffers. The prompt will be the same as for the original image.
top_buffer = BytesIO()
mask_buffer = BytesIO()
Image.fromarray(top).save(top_buffer, 'png')
Image.fromarray(mask_top).save(mask_buffer, 'png')
edit_top = openai.Image.create_edit(
    image=top_buffer.getvalue(),
    mask=mask_buffer.getvalue(),
    prompt=prompt,
    n=1,
    size='256x256'
)
image_top = get_image(edit_top['data'][0]['url'])
image_top
For the bottom part, let’s add a transparency channel to the image in place of a separate mask.
bottom_buffer = BytesIO()
Image.fromarray(np.concatenate([bottom, mask_bottom[..., -1:]], -1)).save(bottom_buffer, 'png')
edit_bottom = openai.Image.create_edit(
    image=bottom_buffer.getvalue(),
    prompt=prompt,
    n=1,
    size='256x256'
)
image_bottom = get_image(edit_bottom['data'][0]['url'])
image_bottom
img_final = Image.new('RGB', (w, rectangle.shape[0]))
img_final.paste(image_top, (0, 0))
img_final.paste(image_bottom, (0, h//2))
img_final
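Note that the two 256x256 squares overlap by h//2 rows in the 384-row canvas, and the second paste simply overwrites the overlap. A minimal NumPy sketch of the same stitch, with flat-coloured arrays standing in for the API outputs:

```python
import numpy as np

h, w = 256, 256
top = np.full([h, w, 3], 10, dtype='uint8')     # stand-in for image_top
bottom = np.full([h, w, 3], 20, dtype='uint8')  # stand-in for image_bottom

final = np.zeros([int(1.5 * h), w, 3], dtype='uint8')
final[:h] = top
final[h // 2:h // 2 + h] = bottom  # overwrites the overlapping h//2 rows

assert final.shape == (384, 256, 3)
assert (final[:h // 2] == 10).all()  # only the top square survives here
assert (final[h // 2:] == 20).all()  # the bottom square wins the overlap
```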
That looks a lot better. You could of course keep extending the image in all directions. But let us now create a landscape image.
Now we will try using the variations endpoint to create a landscape inspired by Van Gogh's Wheatfield with Crows. This was painted on a double-square canvas, so the width is approximately twice the height.
wheatfield = Image.open(
    BytesIO(
        requests.get(
            'https://upload.wikimedia.org/wikipedia/commons/thumb/d/d3/Vincent_Van_Gogh_-_Wheatfield_with_Crows.jpg/2560px-Vincent_Van_Gogh_-_Wheatfield_with_Crows.jpg'
        ).content
    )
)
wheatfield
Note that the width is a little over twice the height.
img_w, img_h = wheatfield.size
img_w, img_h
(2560, 1228)
The image variations are conditioned only on an image and don't make use of a prompt. To create a "double-square" landscape we will first crop a square section from the painting and resize it to 256x256.
left = img_h // 2
img_centre = wheatfield.crop((left, 0, left + img_h, img_h))
assert set(img_centre.size) == {img_h}
img_centre = img_centre.resize((256, 256))
img_centre
Now generate a bunch of variations.
img_centre_buffer = BytesIO()
img_centre.save(img_centre_buffer, 'png')
variations = openai.Image.create_variation(
    image=img_centre_buffer.getvalue(),
    n=9,
    size='256x256'
)
images_var = [get_image(data['url']) for data in variations['data']]
m_var = montage([np.array(img) for img in images_var], padding_width=5, channel_axis=-1)
Image.fromarray(m_var)
The prompt is inspired by the description of the painting in the Wikipedia entry. According to the DALL-E Prompt Book adding descriptors like “acclaimed” can sometimes help generate higher quality images. Let’s also make it clear that we would like the result to be a “masterpiece”.
prompt_var = 'A dramatic, cloudy sky filled with crows over a wheat field in the style of Van Gogh, oil on canvas, acclaimed, masterpiece'
A helper function to mask the image at the appropriate region.
def get_masked_image(img, pad):
    # Add a fully opaque alpha channel
    alpha = 255 * np.ones_like(img[..., -1:])
    img = np.concatenate([img, alpha], -1)
    # Zero-padding also zeroes the alpha, marking the region as editable
    img = np.pad(img, pad)
    return img
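Since np.pad fills with zeros, the padded region gets zero alpha and is therefore marked as editable, while the original pixels get an opaque alpha of 255. A small standalone check (the function is repeated so the snippet runs on its own):

```python
import numpy as np

# get_masked_image, repeated here so the snippet is self-contained
def get_masked_image(img, pad):
    alpha = 255 * np.ones_like(img[..., -1:])
    img = np.concatenate([img, alpha], -1)
    img = np.pad(img, pad)
    return img

# Keep the left half of a tiny 4x4 "image", pad the right half
img = np.full([4, 4, 3], 7, dtype='uint8')
out = get_masked_image(img[:, :2], [(0, 0), (0, 2), (0, 0)])

assert out.shape == (4, 4, 4)         # RGBA, original width restored
assert (out[:, :2, -1] == 255).all()  # kept pixels are opaque
assert (out[:, 2:, -1] == 0).all()    # padded region is transparent (editable)
```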
To fill in the left part we do the same as before.
var_id = 5
var = np.array(images_var[var_id])
hh, ww, f = var.shape
start_idx = hh // 2
img_left = get_masked_image(var[:, :start_idx], [(0, 0), (start_idx, 0), (0, 0)])
Image.fromarray(img_left)
Another helper function to get the image bytes.
def get_img_bytes(img, img_type='png'):
    bfr = BytesIO()
    if isinstance(img, np.ndarray):
        img = Image.fromarray(img)
    img.save(bfr, img_type)
    return bfr.getvalue()
edit_left = openai.Image.create_edit(
    image=get_img_bytes(img_left),
    prompt=prompt_var,
    n=1,
    size='256x256'
)
img_edit_left = get_image(edit_left['data'][0]['url'])
img_edit_left
Generating the right side is a bit more involved because we need to pass the sliding window across several times, using the output from the previous stage as input.
landscape = np.zeros([hh, hh*2, 3], dtype='uint8')
landscape[:, :ww] = np.array(img_edit_left)
inputs = []
start_idx = hh//2
while True:
    inp = get_masked_image(landscape[:, start_idx:start_idx + ww - hh//2],
                           [(0, 0), (0, hh//2), (0, 0)])
    inputs.append(inp)
    edit_right = openai.Image.create_edit(
        image=get_img_bytes(inp),
        prompt=prompt_var,
        n=1,
        size='256x256'
    )
    landscape[:, start_idx:start_idx + ww] = np.array(get_image(edit_right['data'][0]['url']))
    if start_idx == (landscape.shape[1] - ww):
        break
    start_idx += hh//2
Image.fromarray(landscape)
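Stripped of the API calls, the loop above just walks a window of width ww across the canvas in strides of hh//2, stopping when the window hits the right edge. Here is the index bookkeeping on its own (window_starts is a hypothetical helper, not part of the tutorial code):

```python
def window_starts(canvas_width, window_width, stride):
    """Start columns visited by the sliding-window loop."""
    starts = []
    start = stride
    while True:
        starts.append(start)
        if start == canvas_width - window_width:
            break
        start += stride
    return starts

# Double-square canvas (256x512), 256-wide window, half-window stride
assert window_starts(512, 256, 128) == [128, 256]
# A wider 256x768 canvas would need two more passes
assert window_starts(768, 256, 128) == [128, 256, 384, 512]
```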
In this tutorial we will learn how to use the Whisper API to generate subtitles for a video. We will generate subtitles for the opening of A Star Is Born (1937), an early colour movie that has been frequently remade. This first version is in my opinion the best and very much worth watching if you have not done so before. The film is in the public domain and may be downloaded from the Internet Archive if you want to follow along.
Import stuff. moviepy is a useful library for manipulating video files with Python. You can install it with pip:
pip install moviepy
Replace YOUR_API_KEY with your key.
import moviepy.editor as mp
import requests
import openai
openai.api_key = "YOUR_API_KEY"
Change this to the location where you have saved the video.
filename = "YOUR_FILENAME"
video_intro = 'Star-intro.mp4'
audio_intro = 'Star-intro.mp3'
Let us load the movie and save a clip of a section from a scene where the heroine Esther tries to get a job as an extra.
start, end = (12*60 + 35), (12*60 + 49)
logger = None # Turn off logging for cleaner output for blog post
# Uncomment below to see progress bar when saving
# logger='bar'
video = mp.VideoFileClip(filename)
# Clip a small section
video_clip = video.subclip(start, end)
# Save audio
video_clip.audio.write_audiofile(audio_intro, logger=None)
# Save video
# From here: https://stackoverflow.com/questions/40445885/no-audio-when-adding-mp3-to-videofileclip-moviepy
# Doesn't appear to save the audio otherwise
video_clip.write_videofile(video_intro,
                           codec='libx264',
                           audio_codec='aac',
                           temp_audiofile='temp-audio.m4a',
                           remove_temp=True,
                           logger=logger
                           )
First we will send the smaller audio file, which skips the credits, and see what happens. At this point we are going to use the functionality from the openai library. We only set two arguments:

- model, which is whisper-1
- file, which is a file buffer

result = openai.Audio.transcribe(
    model='whisper-1',
    file=open(audio_intro, 'rb')
)
print(result['text'])
I beg your pardon. I'd like to register for extra work. How long have you been in Hollywood? Well, it's about a month now. We haven't put anyone on our books for over two years.
It is an accurate transcription, but we would like this in the form of subtitles that may be added to a video. From the documentation we see that we can request alternative formats by setting the response_format field:
The format of the transcript output, in one of these options: json, text, srt, verbose_json, or vtt.
But at present openai.Audio.transcribe does not seem to be able to handle the output returned when using this setting and throws a JSONDecodeError:
openai.Audio.transcribe(
    model='whisper-1',
    file=open(audio_intro, 'rb'),
    response_format='vtt'
) #=> JSONDecodeError
So we will use the Python requests library instead. The parameters go in the data dict and the file in the files dict. Notice that we have added an extra parameter, language, which is the language of the input video. It is not strictly necessary (as we saw above, we got a good transcription without it), but the documentation states:
Supplying the input language in ISO-639-1 format will improve accuracy and latency.
def get_subtitles(file, subtitle_format='srt', **kwargs):
    url = 'https://api.openai.com/v1/audio/transcriptions'
    headers = {
        'Authorization': f'Bearer {openai.api_key}',
    }
    data = {
        'model': 'whisper-1',
        'response_format': subtitle_format,
        'language': 'en',
    }
    data.update(kwargs)
    files = {
        'file': (file, open(file, 'rb'))
    }
    response = requests.post(url, headers=headers, data=data, files=files)
    return response.text
subtitles = get_subtitles(audio_intro)
print(subtitles)
1
00:00:00,000 --> 00:00:03,000
I beg your pardon.
2
00:00:03,000 --> 00:00:06,000
I'd like to register for extra work.
3
00:00:06,000 --> 00:00:08,000
How long have you been in Hollywood?
4
00:00:08,000 --> 00:00:10,000
Well, it's about a month now.
5
00:00:10,000 --> 00:00:30,000
We haven't put anyone on our books for over two years.
We can also send the video file directly, since MP4 is also an accepted input format. It will take a bit longer as the file size is larger, but it returns a very similar result.
subtitles_from_video = get_subtitles(video_intro)
print(subtitles_from_video)
1
00:00:00,000 --> 00:00:02,000
I beg your pardon.
2
00:00:02,000 --> 00:00:05,000
I'd like to register for extra work.
3
00:00:05,000 --> 00:00:07,000
How long have you been in Hollywood?
4
00:00:07,000 --> 00:00:09,000
Well, it's about a month now.
5
00:00:09,000 --> 00:00:31,000
We haven't put anyone on our books for over two years.
Finally let us save the subtitles to an .srt file. If you save the file in the same folder as the video and open the video in a player such as VLC, the subtitles will be shown automatically.
with open('Star-intro.srt', 'w') as f:
    f.write(subtitles)
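The .srt format is simple enough to parse by hand if you ever need the cue timings programmatically: blocks separated by blank lines, each with an index, a start --> end line, and the text. A minimal sketch (parse_srt is our own helper, not part of any library):

```python
def parse_srt(srt_text):
    """Return a list of (start, end, text) tuples from an SRT string."""
    cues = []
    for block in srt_text.strip().split('\n\n'):
        lines = block.strip().split('\n')
        start, end = lines[1].split(' --> ')
        cues.append((start, end, ' '.join(lines[2:])))
    return cues

example = """1
00:00:00,000 --> 00:00:03,000
I beg your pardon.

2
00:00:03,000 --> 00:00:06,000
I'd like to register for extra work.
"""

cues = parse_srt(example)
assert len(cues) == 2
assert cues[0] == ('00:00:00,000', '00:00:03,000', 'I beg your pardon.')
```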
The code for Parts 1 and 2 of this tutorial can be found in this Colab notebook.
Note: this is a repost with minor modifications from Annotaited, where it was initially published on 19th February 2021.
We used the Adam optimizer [20] with $\beta_1 = 0.9$, $\beta_2 = 0.98$ and $\epsilon = 10^{-9}$.
We varied the learning rate over the course of training, according to the formula:
\[\text{lrate}=d_\text{model}^{−0.5} \cdot \min(\text{step_num}^{−0.5}, \text{step_num}\cdot\text{warmup_steps}^{−1.5})\]This corresponds to increasing the learning rate linearly for the first warmup_steps training steps, and decreasing it thereafter proportionally to the inverse square root of the step number. We used $\text{warmup_steps} = 4000$.
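Before wrapping this in a Keras schedule it is worth checking the formula directly: the two terms inside the min cross exactly at step = warmup_steps, where the rate peaks. A quick NumPy check:

```python
import numpy as np

def lrate(step, d_model=512, warmup_steps=4000):
    # lrate = d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5)
    return d_model ** -0.5 * np.minimum(step ** -0.5,
                                        step * warmup_steps ** -1.5)

steps = np.arange(1, 20001, dtype=np.float64)
rates = lrate(steps)
peak = int(steps[np.argmax(rates)])

assert peak == 4000                        # maximum at warmup_steps
assert np.all(np.diff(rates[:4000]) > 0)   # linear warmup: increasing
assert np.all(np.diff(rates[4000:]) < 0)   # then inverse-sqrt decay
```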
Create a custom schedule using tf.keras.optimizers.schedules.LearningRateSchedule. It should have the following methods:

__init__ | Initialise the hyperparameters
__call__ | Receives a value step as input and returns the learning rate according to the above equation
class LearningRateScheduler(tf.keras.optimizers.schedules.LearningRateSchedule):
    def __init__(self, d_model, warmup_steps):
        self._d_model = d_model
        self._warmup_steps = warmup_steps
        self.d_model_term = d_model ** (-0.5)
        self.warmup_term = warmup_steps ** (-1.5)

    def get_config(self):
        return dict(d_model=self._d_model, warmup_steps=self._warmup_steps)

    def __call__(self, step):
        step_term = tf.math.minimum(step ** (-0.5), step * self.warmup_term)
        return self.d_model_term * step_term
Try plotting the learning rate curve for a few different values of d_model and warmup_steps to get a sense of how it evolves over time.
plt.figure(figsize=(8, 8))
t = tf.range(20000, dtype=tf.float32)
d_model = np.stack([128, 256, 512])
warmup_steps = [2000, 4000, 6000]
for cmap, w in zip(['Reds', 'Greens', 'Blues'], warmup_steps):
    cm = plt.cm.get_cmap(cmap)
    for i, d in enumerate(d_model, 1):
        lr_wd = LearningRateScheduler(
            d_model=d,
            warmup_steps=w
        )
        val = lr_wd(t).numpy()
        clr = cm(int(i * cm.N / len(d_model)))
        plt.plot(t.numpy(), val, label=f'warmup_steps={w}, d_model={d}', c=clr)
plt.legend();
The model is trained with a softmax loss over all the tokens in the vocabulary with label smoothing.
During training, we employed label smoothing of value $\epsilon_{ls}$ = 0.1.
A very brief guide to label smoothing
Typically it is divided uniformly over the classes as follows:

\[q'(k\vert x_i) = \delta_{k,y_i} \cdot (1 - \epsilon_{ls}) + \epsilon_{ls}/K\]

Write a function smooth_labels which implements the equation above. The inputs will be a tensor of one-hot labels of shape [..., K] and the smoothing parameter eps.
def smooth_labels(labels, eps=0.1):
    num_classes = tf.cast(tf.shape(labels)[-1], labels.dtype)
    labels = labels * (1 - eps) + eps / num_classes
    return labels
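A NumPy version of the same equation makes the effect easy to inspect; in particular, the smoothed labels still sum to 1, so each row remains a valid distribution:

```python
import numpy as np

def smooth_labels_np(labels, eps=0.1):
    # q'(k|x) = one_hot * (1 - eps) + eps / K
    num_classes = labels.shape[-1]
    return labels * (1 - eps) + eps / num_classes

one_hot = np.eye(4)[[2]]  # one example, true class 2 out of K = 4
smoothed = smooth_labels_np(one_hot, eps=0.1)

assert np.allclose(smoothed.sum(-1), 1.0)          # still a distribution
assert np.allclose(smoothed[0, 2], 0.9 + 0.1 / 4)  # true class: 0.925
assert np.allclose(smoothed[0, 0], 0.1 / 4)        # other classes: 0.025
```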
The idea is that we should average the losses across all sequences and timesteps:
\[L = \frac{1}{N \cdot T}\sum_{n=0}^{N-1}\sum_{t=0}^{T-1}\text{loss}(\hat{y}_{nt}, y_{nt})\]

However we should remember to exclude the zero-padded elements. As we shall see, at inference time we stop predicting as soon as the model predicts an <end> symbol, so the model should not be trained to predict anything after the <end> symbol.
Write a MaskedLoss class that wraps a loss function and, when called, returns a masked average of the loss.
class MaskedLoss(object):
    def __init__(self, loss_fn):
        self.loss_fn = loss_fn

    def __call__(self, y_true, y_pred, mask, **kwargs):
        loss = self.loss_fn(y_true=y_true, y_pred=y_pred, **kwargs)
        mask = tf.cast(mask, loss.dtype)
        return tf.reduce_sum(loss * mask) / tf.reduce_sum(mask)
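The same masked average is easy to verify with plain NumPy: only unmasked positions should contribute to the mean. A sketch with a toy per-token loss (assuming, as in our setup, that the wrapped loss returns one value per token):

```python
import numpy as np

# NumPy analogue of MaskedLoss.__call__ for a precomputed per-token loss
def masked_mean_loss(per_token_loss, mask):
    mask = mask.astype(per_token_loss.dtype)
    return (per_token_loss * mask).sum() / mask.sum()

loss = np.array([1.0, 2.0, 3.0, 100.0])  # last position is padding
mask = np.array([1, 1, 1, 0])            # so it must not count

assert masked_mean_loss(loss, mask) == 2.0  # (1 + 2 + 3) / 3
```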
It is convenient to store all the training and inference settings and hyperparameters in a single config object. For simplicity we will use an EasyDict, which, if you have not previously used one, behaves just like an ordinary dict except that its keys can be accessed like attributes, i.e. config.key1 returns config['key1'].
Here is our config. We will add more settings to it as we go along.
config = EasyDict()
config.smoothing_prob = 0.1
config.data = EasyDict()
config.data.pad_symbol = 0
Now that all the pieces of the training are in place we can prepare the data. We will be using the crowdsourced Many Things Spanish-English dataset, which consists of sentences and phrases in these two languages. Datasets consisting of translations between English and many other languages are available from the website, so feel free to swap in another language. Depending on what language you use, you might need to modify the preprocessing code to handle special characters.
Download the data from here, save it to a directory of your choice and unzip it. If using Colab, I recommend that you connect it to Google Drive so that your files will be saved even when the runtime gets disconnected.
You can do so by running these lines
from google.colab import drive
drive.mount('/content/drive')
A link will appear that you need to click to get an authentication code which you input into a box below the link.
You can also download and unzip the data by running:
# Modify the path to your save path
SAVE_PATH = "/content/drive/My Drive/transformer-tutorial"
if not os.path.exists(SAVE_PATH):
    os.makedirs(SAVE_PATH)
DATA_DIR = os.path.join(SAVE_PATH, "data")
if not os.path.exists(DATA_DIR):
    os.makedirs(DATA_DIR)
!wget http://www.manythings.org/anki/spa-eng.zip -P "{DATA_DIR}"
!unzip "{DATA_DIR}/spa-eng.zip" -d "{DATA_DIR}/spa-eng"
DATA_PATH = os.path.join(DATA_DIR, "spa-eng")
First we will add some functions that will do some basic processing to get rid of special characters, trim whitespace and separate punctuation from words.
def unicode_to_ascii(s):
    return ''.join(c for c in unicodedata.normalize('NFD', s)
                   if unicodedata.category(c) != 'Mn')

def preprocess_sentence(w):
    w = unicode_to_ascii(w.lower().strip())
    # Put spaces between words and punctuation
    w = re.sub(r"([?.!¿])", r" \1 ", w)
    # Reduce several whitespaces to a single one
    w = re.sub(r'[" "]+', " ", w)
    # Replace characters with whitespace unless in (a-z, A-Z, "?", ".", "!", ",", "¿")
    w = re.sub(r"[^a-zA-Z?.!,¿]+", " ", w)
    w = w.strip()
    # Add start / end tokens
    w = ' '.join(['<start>', w, '<end>'])
    return w
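Here is what the pipeline does to a typical Spanish sentence: accents are stripped and punctuation becomes its own token. (Both functions are repeated so the snippet runs standalone.)

```python
import re
import unicodedata

def unicode_to_ascii(s):
    # Drop combining marks (accents) after NFD normalisation
    return ''.join(c for c in unicodedata.normalize('NFD', s)
                   if unicodedata.category(c) != 'Mn')

def preprocess_sentence(w):
    w = unicode_to_ascii(w.lower().strip())
    w = re.sub(r"([?.!¿])", r" \1 ", w)     # space out punctuation
    w = re.sub(r'[" "]+', " ", w)           # collapse repeated whitespace
    w = re.sub(r"[^a-zA-Z?.!,¿]+", " ", w)  # drop other special characters
    return ' '.join(['<start>', w.strip(), '<end>'])

assert preprocess_sentence('¿Cómo estás?') == '<start> ¿ como estas ? <end>'
```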
Now we can load the data and preprocess the sequences. The data consists of tab-separated columns of text as shown below. We only need the first two columns, which contain pairs of English and Spanish sentences: the first column is English and the second is the other language.
def load_dataset(path, num_examples):
    with open(path, encoding='UTF-8') as f:
        lines = f.read().strip().split('\n')
    word_pairs = [[preprocess_sentence(w) for w in l.split('\t')[:2]] for l in lines[:num_examples]]
    return zip(*word_pairs)
A few randomly sampled pairs from the processed dataset. A character called “Tom” seems to feature in a lot of the examples.
spanish: <start> da la impresion de que has estado llorando . <end>
english: <start> it looks like you ve been crying . <end>
spanish: <start> tom y mary hicieron lo que se les dijo . <end>
english: <start> tom and mary did what they were told . <end>
spanish: <start> cuando no queda nada por hacer, ¿ que haces ? <end>
english: <start> when there s nothing left to do, what do you do ? <end>
spanish: <start> llevo toda la semana practicando esta cancion . <end>
english: <start> i ve been practicing this song all week . <end>
spanish: <start> tom era feliz . <end>
english: <start> tom was happy . <end>
spanish: <start> llamalo, por favor . <end>
english: <start> please telephone him . <end>
spanish: <start> vere que puedo hacer . <end>
english: <start> i ll see what i can do . <end>
spanish: <start> hubo poca asistencia a la conferencia . <end>
english: <start> a few people came to the lecture . <end>
spanish: <start> mi papa no lo permitira . <end>
english: <start> my father won t allow it . <end>
spanish: <start> ¿ como fue el vuelo ? <end>
english: <start> how was the flight ? <end>
To create inputs that we can feed into the model we need to do the following:
First let us write a function that creates a tokeniser, fits it on the training set and returns it for subsequent use. The validation and test sets might contain words not present in the training set; this is handled by replacing them with an <unk> token.
def get_tokenizer(lang, num_words=None):
    lang_tokenizer = tf.keras.preprocessing.text.Tokenizer(
        num_words=num_words,
        filters="",
        oov_token='<unk>'
    )
    lang_tokenizer.fit_on_texts(lang)
    return lang_tokenizer
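The behaviour worth noting is the OOV handling: any word not seen during fitting maps to the <unk> id. A toy pure-Python stand-in (our own sketch, not the Keras class, and ignoring the frequency-based id ordering Keras uses) makes this concrete:

```python
class ToyTokenizer:
    """Minimal stand-in for tf.keras.preprocessing.text.Tokenizer."""
    def __init__(self, oov_token='<unk>'):
        # Keras also reserves id 1 for the OOV token
        self.word_index = {oov_token: 1}

    def fit_on_texts(self, texts):
        for text in texts:
            for word in text.split():
                self.word_index.setdefault(word, len(self.word_index) + 1)

    def texts_to_sequences(self, texts):
        return [[self.word_index.get(w, 1) for w in t.split()]
                for t in texts]

tkn = ToyTokenizer()
tkn.fit_on_texts(['they re not stupid'])
# 'clever' was never seen, so it becomes <unk> (id 1)
assert tkn.texts_to_sequences(['they re clever']) == [[2, 3, 1]]
```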
We will split the data into train/val/test for input and target, using 70% of the data for training and splitting the remaining examples into two equal parts for val and test.
def load_split(path, num_examples=None, seed=1234):
    tar, inp = load_dataset(path, num_examples)
    inp_trn, inp_valtest, tar_trn, tar_valtest = train_test_split(
        inp, tar, test_size=0.3, random_state=seed)
    inp_val, inp_test, tar_val, tar_test = train_test_split(
        inp_valtest, tar_valtest, test_size=0.5, random_state=seed)
    # Delete to avoid returning them by mistake
    del inp_valtest, tar_valtest
    return (inp_trn, inp_val, inp_test), (tar_trn, tar_val, tar_test)
To build our tokenised datasets, we will create a tokeniser for each language and use it to process the data. We also return the tokenisers, as we need to hold on to them to map the model outputs back into words later.
def create_dataset(inp, tar, num_inp_words=None, num_tar_words=None):
    inp_tokenizer = get_tokenizer(inp[0], num_words=num_inp_words)
    tar_tokenizer = get_tokenizer(tar[0], num_words=num_tar_words)
    inp_trn_seq, inp_val_seq, inp_test_seq = [
        inp_tokenizer.texts_to_sequences(inp_split) for inp_split in inp
    ]
    tar_trn_seq, tar_val_seq, tar_test_seq = [
        tar_tokenizer.texts_to_sequences(tar_split) for tar_split in tar
    ]
    inputs = dict(
        train=inp_trn_seq, val=inp_val_seq, test=inp_test_seq
    )
    targets = dict(
        train=tar_trn_seq, val=tar_val_seq, test=tar_test_seq
    )
    return inputs, targets, inp_tokenizer, tar_tokenizer
Let us try out the tokeniser:
inp, tar = load_split(os.path.join(DATA_PATH, 'spa.txt'), num_examples=None, seed=1234)
tmp_inp_tkn = get_tokenizer(inp[0])
tmp_tar_tkn = get_tokenizer(tar[0])
print(inp[0][0], tar[0][0])
tmp_inp_tkn.texts_to_sequences([inp[0][0]]), tmp_tar_tkn.texts_to_sequences([tar[0][0]])
<start> no son idiotas . <end> <start> they re not stupid . <end>
([[2, 8, 65, 3396, 4, 3]], [[2, 48, 50, 41, 448, 4, 3]])
To keep our model relatively small we will use just the top 10000 words in each language
config.data.input_vocab_size = 10000
config.data.target_vocab_size = 10000
Let us generate the data splits and save them along with the tokenizer to be able to reuse them later without having to regenerate them.
# Set the split path up front so it is defined even when we skip creation
config.data.split_path = os.path.join(DATA_PATH, 'splits.pkl')
# If the splits don't exist create the dataset, or set RECREATE_DATASET = True
# to force recreation
RECREATE_DATASET = False
if not os.path.exists(config.data.split_path) or RECREATE_DATASET:
    print('Creating dataset')
    inputs, targets, inp_tkn, tar_tkn = create_dataset(
        inp, tar, num_inp_words=config.data.input_vocab_size,
        num_tar_words=config.data.target_vocab_size)
    for k, v in inputs.items():
        print(k, len(v), len(targets[k]))
    with open(config.data.split_path, 'wb') as f:
        pickle.dump(
            file=f,
            obj=dict(
                inputs=inputs,
                targets=targets,
                inp_tkn=inp_tkn,
                tar_tkn=tar_tkn
            )
        )
The vocabulary of num_words includes <start>, <end> and <unk> (for unknown words), and implicitly includes padding. tf.keras.preprocessing.text.Tokenizer does not have an explicit token for padding but starts the token ids from 1, so when we use the tokeniser below to decode the predictions we need to get rid of the padding first.
print(inp_tkn.index_word.get(0), tar_tkn.index_word.get(0))
print(inp_tkn.sequences_to_texts([[inp_tkn.num_words-1]]),
tar_tkn.sequences_to_texts([[tar_tkn.num_words-1]]))
print(inp_tkn.sequences_to_texts([[inp_tkn.num_words]]),
tar_tkn.sequences_to_texts([[tar_tkn.num_words]]))
None None
['enloquecido'] ['harbored']
['<unk>'] ['<unk>']
Then we can load the splits.
with open(config.data.split_path, 'rb') as f:
    data = pickle.load(f)
inp_trn = data['inputs']['train']
inp_val = data['inputs']['val']
tar_trn = data['targets']['train']
tar_val = data['targets']['val']
inp_tkn = data['inp_tkn']
tar_tkn = data['tar_tkn']
Here is how to decode the input sequence. Decoding can sometimes be lossy: words not present in the vocabulary get converted to <unk>.
def convert(tensor, tokenizer):
    for t in tensor:
        if t != 0:
            print(f'{t} -----> {tokenizer.index_word[t]}')
convert(inp_trn[0], inp_tkn)
convert(tar_trn[0], tar_tkn)
2 -----> <start>
8 -----> no
65 -----> son
3396 -----> idiotas
4 -----> .
3 -----> <end>
2 -----> <start>
48 -----> they
50 -----> re
41 -----> not
448 -----> stupid
4 -----> .
3 -----> <end>
Now that our data is all prepared, let us build the input pipeline using tf.data.Dataset.
config.data.batch_size = 64
# - Data will be padded to the longest sequence length in the batch
# - There are ways to optimise this such that groups of sequences of similar length
# are batched together so that computation is not wasted on the padded elements
# but we won't worry about that for now as we are dealing
# with quite a small dataset and model.
padded_shapes = ([None], [None])
buffer_size = len(inp_trn)
train_dataset = tf.data.Dataset.from_tensor_slices((tf.ragged.constant(inp_trn),
                                                    tf.ragged.constant(tar_trn)))
train_dataset = train_dataset.shuffle(buffer_size).batch(config.data.batch_size)
val_dataset = tf.data.Dataset.from_tensor_slices((tf.ragged.constant(inp_val),
                                                  tf.ragged.constant(tar_val)))
val_dataset = val_dataset.batch(config.data.batch_size)
Although called the target, the sequence for the target language is used both as input to the decoder and as the target for the final prediction. For a sequence of length $T$, the first $T-1$ elements are input to the decoder, which is trained to predict the last $T-1$ elements. For convenience, and to avoid mistakes, let us write a simple function that splits the target into two parts, tar_inp and tar_real. Assume that the target is of arbitrary rank but that the sequence or "time" dimension is the last one, e.g. it could have shape [batch_size, T] or just [T].
def split_target(target):
    tar_inp = target[..., :-1]
    tar_real = target[..., 1:]
    return tar_inp, tar_real
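Since the split only slices along the last axis, it can be checked on a plain NumPy array (the token ids below are made up):

```python
import numpy as np

def split_target(target):
    # <start> A B <end>  ->  inp: <start> A B,  real: A B <end>
    tar_inp = target[..., :-1]
    tar_real = target[..., 1:]
    return tar_inp, tar_real

# e.g. <start> = 2, two word tokens, <end> = 3
tar_inp, tar_real = split_target(np.array([2, 48, 50, 3]))

assert tar_inp.tolist() == [2, 48, 50]   # decoder input: everything but <end>
assert tar_real.tolist() == [48, 50, 3]  # labels: the next token at each step
```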
In this tutorial we won't go into metrics like BLEU that are used for evaluating translation models, as they merit a tutorial in themselves, but will use loss and accuracy to get a sense of how well the model is doing. The caveat is that accuracy is not the best metric when you don't have a balanced dataset (and a vocabulary is not particularly well balanced), since the model can do badly on infrequently occurring classes (here, words) without that really being reflected in the accuracy. Moreover it only deals with pointwise comparisons, so it doesn't take into account the sequential nature of the data.
However for this simple application, it will give us a reasonable idea of whether the model is doing well or badly. Once we have trained the model we will also make some predictions and evaluate them qualitatively.
Here is the masked accuracy function we will be using:
def accuracy_function(real, pred, pad_mask):
    pred_ids = tf.argmax(pred, axis=2)
    accuracies = tf.cast(tf.equal(tf.cast(real, pred_ids.dtype), pred_ids), tf.float32)
    mask = tf.cast(pad_mask, dtype=tf.float32)
    return tf.reduce_sum(accuracies * mask) / tf.reduce_sum(mask)
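A NumPy sketch of the same computation shows why the mask matters: without it, padded positions would count as free correct predictions. (Toy logits and made-up ids, not real model output.)

```python
import numpy as np

def masked_accuracy_np(real, pred_logits, pad_mask):
    pred_ids = pred_logits.argmax(axis=-1)
    correct = (real == pred_ids).astype(float)
    mask = pad_mask.astype(float)
    return (correct * mask).sum() / mask.sum()

# Batch of 1, four timesteps, vocab of 3; the last position is padding
real = np.array([[1, 2, 0, 0]])
logits = np.eye(3)[[[1, 2, 2, 0]]]  # model predicts 1, 2, 2, 0
mask = np.array([[1, 1, 1, 0]])     # third position counts, fourth doesn't

assert masked_accuracy_np(real, logits, mask) == 2 / 3  # wrong only at t=2
```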
Add some more fields to config. Note that in the interests of speed we are using quite a small model. Feel free to change these settings and try different experiments.
config.model = EasyDict()
config.model.model_dim = 128
config.model.ff_dim = 512
config.model.num_heads = 8
config.model.num_encoder_blocks = 4
config.model.num_decoder_blocks = 4
config.model.dropout = 0.1
config.smoothing_prob = 0.1
Let us initialise instances of the following using the settings from config:

- LearningRateScheduler
- MaskedLoss
- SequenceMask
- TargetMask
- Transformer
We used the Adam optimizer [20] with $\beta_1 = 0.9$, $\beta_2 = 0.98$ and $\epsilon = 10^{-9}$.
lr = LearningRateScheduler(
    d_model=config.model.model_dim,
    warmup_steps=4000
)
loss_function = MaskedLoss(
    loss_fn=tf.keras.losses.CategoricalCrossentropy(
        from_logits=True, reduction="none",
        label_smoothing=config.smoothing_prob
    ),
)
pad_masker = SequenceMask(config.data.pad_symbol)
future_masker = TargetMask(config.data.pad_symbol)
model = Transformer(
    num_tokens=inp_tkn.num_words,
    num_tgt_tokens=tar_tkn.num_words,
    model_dim=config.model.model_dim,
    num_heads=config.model.num_heads,
    dropout=config.model.dropout,
    ff_dim=config.model.ff_dim,
    num_encoder_blocks=config.model.num_encoder_blocks,
    num_decoder_blocks=config.model.num_decoder_blocks
)
optim = tf.optimizers.Adam(
    learning_rate=lr,
    beta_1=0.9,
    beta_2=0.98,
    epsilon=1e-9
)
We will write a simple class Trainer that holds all these components together and has methods for running training and validation steps.
class Trainer(tf.Module):
    def __init__(self,
                 config,
                 model: tf.keras.models.Model,
                 pad_masker,
                 future_masker,
                 loss_function,
                 optim):
        self.config = config
        self.model = model
        self.pad_masker = pad_masker
        self.future_masker = future_masker
        self.loss_function = loss_function
        self.optim = optim
        self.train_loss = tf.keras.metrics.Mean(name='train_loss')
        self.train_accuracy = tf.keras.metrics.Mean(name='train_accuracy')
        self.val_loss = tf.keras.metrics.Mean(name='val_loss')
        self.val_accuracy = tf.keras.metrics.Mean(name='val_accuracy')

    @tf.function(input_signature=(
        tf.TensorSpec(shape=[None, None], dtype=tf.int32),
        tf.TensorSpec(shape=[None, None], dtype=tf.int32)
    ))
    def train_step(self, inp, tar):
        # TODO: Implement this
        pass

    @tf.function(input_signature=(
        tf.TensorSpec(shape=[None, None], dtype=tf.int32),
        tf.TensorSpec(shape=[None, None], dtype=tf.int32)
    ))
    def valid_step(self, inp, tar):
        # TODO: Implement this
        pass
Here is what train_step
should do:
train_loss
and train_accuracy
(using accuracy_function
from above for the latter)Hints:
train_step
code here. You need to do something quite similar, just using the transformer as the model.(B, 1, 1, N)
to (B, N)
) @tf.function(input_signature=(
tf.TensorSpec(shape=[None, None], dtype=tf.int32),
tf.TensorSpec(shape=[None, None], dtype=tf.int32)
))
def train_step(self, inp, tar):
tar_inp, tar_real = split_target(tar)
tar_pad_mask = self.pad_masker(tar_real)[:, 0, 0, :]
with tf.GradientTape() as tape:
predictions, _ = self.model(
inp, tar_inp,
src_mask=self.pad_masker(inp),
tgt_mask=self.future_masker(tar_inp),
training=True
)
tar_onehot = tf.one_hot(tar_real, depth=tf.shape(predictions)[-1])
if self.config.smoothing_prob > 0:
labels = smooth_labels(tar_onehot, self.config.smoothing_prob)
else:
labels = tar_onehot
loss = self.loss_function(
y_true=labels,
y_pred=predictions,
mask=tar_pad_mask
)
gradients = tape.gradient(loss, self.model.trainable_variables)
self.optim.apply_gradients(zip(gradients, self.model.trainable_variables))
self.train_loss(loss)
self.train_accuracy(accuracy_function(tar_real, predictions, tar_pad_mask))
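Note that train_step relies on smooth_labels (defined alongside the loss function). As a reminder of what label smoothing does, here is a minimal NumPy sketch; the blend formula below is the standard one and is an assumption about smooth_labels' behaviour, not its exact code:

```python
import numpy as np

def smooth_labels_np(onehot, smoothing_prob):
    # Blend the one-hot targets with a uniform distribution:
    # the true class keeps (1 - eps) of the mass, and eps is
    # spread evenly over all classes.
    num_classes = onehot.shape[-1]
    return onehot * (1.0 - smoothing_prob) + smoothing_prob / num_classes

onehot = np.eye(4)[np.array([2, 0])]      # two targets, 4 classes
smoothed = smooth_labels_np(onehot, 0.1)
# rows still sum to 1 and the true class keeps most of the mass
```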
Now write valid_step
which will be very similar to train_step
but without the gradient steps. Remember to call the model with training=False
.
@tf.function(input_signature=(
tf.TensorSpec(shape=[None, None], dtype=tf.int32),
tf.TensorSpec(shape=[None, None], dtype=tf.int32)
))
def valid_step(self, inp, tar):
tar_inp, tar_real = split_target(tar)
tar_pad_mask = self.pad_masker(tar_real)[:, 0, 0, :]
predictions, _ = self.model(
inp, tar_inp,
src_mask=self.pad_masker(inp),
tgt_mask=self.future_masker(tar_inp),
training=False
)
loss = self.loss_function(
y_true=tf.one_hot(tar_real, depth=tf.shape(predictions)[-1]),
y_pred=predictions,
mask=tar_pad_mask
)
self.val_loss(loss)
self.val_accuracy(accuracy_function(tar_real, predictions, tar_pad_mask))
Some code to save checkpoints. Here we are only saving a single checkpoint corresponding to the lowest validation loss but feel free to change that.
CKPT_PATH = os.path.join(SAVE_PATH, 'checkpoints')
config.ckpt_path = CKPT_PATH
ckpt = tf.train.Checkpoint(
transformer=model,
optimizer=optim
)
ckpt_manager = tf.train.CheckpointManager(
ckpt, config.ckpt_path, max_to_keep=1
)
if ckpt_manager.latest_checkpoint:
ckpt.restore(ckpt_manager.latest_checkpoint)
print('Latest checkpoint restored')
Now add a few more settings to config
and we are ready to run.
config.epochs = 10
config.log_every = 100 # how frequently to print metrics
If you want this to run fast, connect to a GPU runtime by going to Runtime > Change Runtime Type
from the top menu and selecting “GPU”. The code has not been carefully optimised for speed, but it took between 15 and 30 minutes to run for 10 epochs (presumably depending on what GPU I got).
After 10 epochs you should observe a validation accuracy of around 79%.
best_loss = float('inf')
trainer = Trainer(config, model, pad_masker, future_masker, loss_function, optim)
for epoch in range(config.epochs):
start = time.time()
trainer.train_loss.reset_states()
trainer.train_accuracy.reset_states()
trainer.val_loss.reset_states()
trainer.val_accuracy.reset_states()
for (batch, (inp, tar)) in enumerate(train_dataset):
trainer.train_step(inp.to_tensor(), tar.to_tensor())
if batch % config.log_every == 0:
print (f'Epoch {epoch + 1} Batch {batch}'
f' Loss {trainer.train_loss.result(): .4f}',
f' Accuracy {trainer.train_accuracy.result(): .4f}')
for (batch, (inp, tar)) in enumerate(val_dataset):
trainer.valid_step(inp.to_tensor(), tar.to_tensor())
val_loss = trainer.val_loss.result()
if val_loss < best_loss:
print(f"Validation loss decreased from {best_loss} to {val_loss}. Saving checkpoint")
best_loss = val_loss
ckpt_manager.save()
print(f'Epoch {epoch + 1}',
f' Loss {trainer.train_loss.result(): .4f}',
f' Accuracy {trainer.train_accuracy.result(): .4f}',
f' Val loss {val_loss: .4f}',
f' Val accuracy {trainer.val_accuracy.result(): .4f}')
print(f'Time taken for 1 epoch: {time.time() - start} secs')
print()
So our model has trained for a bit. The loss has dropped and the accuracy has risen, but what do the translations actually look like? To answer that question we need to write one more function. Write predict_sequence
that takes a single input sequence and returns the predicted sequence along with attention maps.
Hints:

- Begin the output sequence with start_symbol, i.e. inp_tkn.word_index["<start>"]
- Stop once the <end> token has been predicted
- We defined encode and decode methods for the model. Use these instead of call to avoid having to predict the encoder output repeatedly.

def predict_sequence(inputs, model, max_length,
                     start_symbol, end_symbol,
                     pad_masker, future_masker):
inputs = inputs[None]
result = tf.ones_like(inputs[:, :1]) * start_symbol
encoding, enc_attn = model.encode(inputs, pad_masker(inputs), training=False)
for _ in tf.range(max_length):
predictions, self_attn, dec_attn = model.decode(result, encoding,
pad_masker(inputs),
future_masker(result),
training=False)
# Just select the last sequence element and the symbol
# with highest probability
next_symbol = tf.argmax(predictions[:, -1], axis=-1, output_type=tf.int32)
result = tf.concat([result, next_symbol[:, None]], axis=-1)
# If sequence is done, stop
if tf.equal(tf.squeeze(next_symbol), end_symbol):
break
attention = dict(
enc=tf.squeeze(tf.stack(enc_attn), axis=1),
dec_self=tf.squeeze(tf.stack(self_attn), axis=1),
dec_memory=tf.squeeze(tf.stack(dec_attn), axis=1)
)
return tf.squeeze(result, axis=0), attention
Define a shuffled version of the validation dataset and translate a few sample sentences.
val_sample_dataset = tf.data.Dataset.from_tensor_slices((tf.ragged.constant(inp_val),
tf.ragged.constant(tar_val)))
val_sample_dataset = val_sample_dataset.shuffle(len(inp_val))
[inp_sample, tar_sample] = next(iter(val_sample_dataset))
pred, attn = predict_sequence(inp_sample, model, 50, inp_tkn.word_index['<start>'],
inp_tkn.word_index['<end>'], pad_masker, future_masker)
inp_sent_full = inp_tkn.sequences_to_texts(inp_sample.numpy()[None])[0]
tar_pred_sent_full = tar_tkn.sequences_to_texts(pred.numpy()[None])[0]
tar_true_sent_full = tar_tkn.sequences_to_texts(tar_sample.numpy()[None])[0]
print("spanish:", inp_sent_full)
print('english_pred:', tar_pred_sent_full)
print('english_true:', tar_true_sent_full)
spanish: estas personas dijeron que la guerra era una guerra civil .
english_pred: these people said war was a war .
english_true: these people said the war was a civil war .
spanish: dime donde pongo estos libros .
english_pred: tell me where i put these books .
english_true: tell me where to put these books .
spanish: necesito <unk> las piernas .
english_pred: i need to brush my legs .
english_true: i need to stretch my legs .
spanish: traeme una toalla .
english_pred: bring me a towel .
english_true: get me a towel .
spanish: tom dijo que tenia prisa .
english_pred: tom said he was in a hurry .
english_true: tom said that he was in a hurry .
Remember that we are also returning the attention maps. Since this architecture is literally “all” about attention we should take a look at those as well. Here is some code that will plot the attention maps for an input and output sequence.
def plot_attention_head(inp_sentence, tar_sentence, attention):
# The attention map rows correspond to the last N - 1 tokens which
# were predicted by the model. This does not include the '<start>'
# symbol, which the model didn't generate.
tar_sentence = tar_sentence[1:]
ax = plt.gca()
ax.matshow(attention, cmap='hot', vmin=0, vmax=1)
ax.set_xticks(range(len(inp_sentence)))
ax.set_yticks(range(len(tar_sentence)))
labels = [label for label in inp_sentence]
ax.set_xticklabels(
labels, rotation=90)
labels = [label for label in tar_sentence]
ax.set_yticklabels(labels)
This plots the first memory attention map from the final decoder head. Try modifying it to plot maps from other locations as well as self-attention maps.
plot_attention_head(inp_sent_full.split(),
tar_pred_sent_full.split(),
attn['dec_memory'][-1][0].numpy())
Now let us plot all the memory attention maps from the final decoder head.
def plot_attention_weights(inp_sentence, tar_sentence, attention_heads):
fig = plt.figure(figsize=(16, 8))
for h, head in enumerate(attention_heads):
ax = fig.add_subplot(2, 4, h+1)
plot_attention_head(inp_sentence, tar_sentence, head)
ax.set_xlabel('Head {}'.format(h+1))
plt.tight_layout()
plt.show()
plot_attention_weights(inp_sent_full.split(),
tar_pred_sent_full.split(),
attn['dec_memory'][-1].numpy())
There are many ways this model can be extended. We can modify the architecture or use a different dataset. Metrics like BLEU would give us a better idea of the model’s performance. At inference time we are using a simple greedy approach, choosing just the best prediction at each timestep, which might not necessarily lead to the best overall translation. In the paper they use beam search, which stores a small number of the top predictions at each stage and searches for the best overall translation among these; this typically yields better results.
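To make the beam-search idea concrete, here is a minimal sketch in plain Python over a hypothetical next_log_probs function standing in for the decoder; the toy scoring table is invented purely for illustration:

```python
import math

def beam_search(next_log_probs, start, end, beam_width=3, max_len=10):
    # Each beam is (token_sequence, cumulative_log_probability)
    beams = [([start], 0.0)]
    for _ in range(max_len):
        candidates = []
        for seq, score in beams:
            if seq[-1] == end:           # finished beams carry over unchanged
                candidates.append((seq, score))
                continue
            for tok, lp in next_log_probs(seq).items():
                candidates.append((seq + [tok], score + lp))
        # keep only the top `beam_width` partial translations
        beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:beam_width]
        if all(seq[-1] == end for seq, _ in beams):
            break
    return beams[0]

# A toy model where the greedy choice at the first step (token 1,
# probability 0.6) leads to a worse overall sequence than token 2
table = {
    (0,): {1: math.log(0.6), 2: math.log(0.4)},
    (0, 1): {3: math.log(0.1)},
    (0, 2): {3: math.log(0.9)},
    (0, 1, 3): {9: 0.0},   # 9 = end token
    (0, 2, 3): {9: 0.0},
}
best_seq, best_score = beam_search(lambda s: table[tuple(s)],
                                   start=0, end=9, beam_width=2)
```

With beam_width=2 the search keeps the initially weaker prefix (0, 2) alive and ends up preferring it, whereas greedy decoding would have committed to (0, 1).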
This is only the tip of the iceberg with regard to what Transformers are and what they can do. Many extensions to the original architecture have been developed and there are many more applications outside of NLP (such as image recognition and image generation). I have just linked to a handful that came to mind but there are many more.
I hope to cover some of these extensions through future tutorials. In the meantime I encourage you to experiment with making your own changes to the model.
This tutorial, like the previous one, was inspired by The Annotated Transformer.
It also borrows and adapts code for data preparation, saving, plotting and other utility functions from the following sources:
Since its introduction in Attention Is All You Need [1], the Transformer architecture has become very influential and has had successes in many tasks not just in NLP but in other areas like vision. This tutorial is inspired by the approach in The Annotated Transformer [2], which primarily uses text directly quoted from the paper to explain the code (or you could say that it uses code to explain the paper). Differently from [2], which uses PyTorch and adopts a top-down approach to building the model, this tutorial uses TensorFlow along with a bottom-up approach, starting with individual components and gradually putting them together. All quoted sections are from [1].
The code for Parts 1 and 2 of this tutorial can be found in this Colab notebook.
Note: this is a repost with minor modifications from Annotaited, where it was initially published on 2nd February 2021.
Attention mechanisms have become an integral part of compelling sequence modeling and transduction models in various tasks, allowing modeling of dependencies without regard to their distance in the input or output sequences [2, 19]. In all but a few cases [27], however, such attention mechanisms are used in conjunction with a recurrent network.
In this work we propose the Transformer, a model architecture eschewing recurrence and instead relying entirely on an attention mechanism to draw global dependencies between input and output.
(Figure 1 of [1])
Most competitive neural sequence transduction models have an encoder-decoder structure [5, 2, 35]. Here, the encoder maps an input sequence of symbol representations $(x_1,…,x_n)$ to a sequence of continuous representations $z = (z_1,…,z_n)$. Given $z$, the decoder then generates an output sequence $(y_1, …, y_m)$ of symbols one element at a time. At each step the model is auto-regressive [10], consuming the previously generated symbols as additional input when generating the next.
An attention function can be described as mapping a query and a set of key-value pairs to an output, where the query, keys, values, and output are all vectors. The output is computed as a weighted sum of the values, where the weight assigned to each value is computed by a compatibility function of the query with the corresponding key.
We call our particular attention “Scaled Dot-Product Attention” (Figure 2). The input consists of queries and keys of dimension $d_k$, and values of dimension $d_v$. We compute the dot products of the query with all keys, divide each by $\sqrt{d_k}$, and apply a softmax function to obtain the weights on the values.
In practice, we compute the attention function on a set of queries simultaneously, packed together into a matrix Q. The keys and values are also packed together into matrices K and V . We compute the matrix of outputs as:
\[\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V\]
(From Figure 2 of [1])
Let us implement the steps in the diagram, not worrying for now about the Mask
step. We will call this function scaled_dot_product_attention_temp
. Assume there are three inputs query
, key
and value
with final two dimensions (shape[-2:]
) of (N_q, d_k)
, (N_k, d_k)
and (N_v, d_k)
respectively, where N_k = N_v
. Return the final output as well as the attention weights, as these are useful for inspecting the model.
def scaled_dot_product_attention_temp(query, key, value, inf=1e9):
    d_k = tf.cast(tf.shape(query)[-1], tf.float32)
    # (..., N_q, d_k) @ (..., d_k, N_k) -> (..., N_q, N_k)
    qkt = tf.matmul(query, key, transpose_b=True)
    alpha = tf.nn.softmax(qkt / tf.sqrt(d_k))
    return tf.matmul(alpha, value), alpha
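As a quick sanity check of the computation, here is the same logic rewritten in NumPy for a single unbatched head (a sketch for inspection only, not a replacement for the TensorFlow version):

```python
import numpy as np

def attention_np(query, key, value):
    d_k = query.shape[-1]
    scores = query @ key.T / np.sqrt(d_k)          # (N_q, N_k)
    scores -= scores.max(axis=-1, keepdims=True)   # stabilise the softmax
    alpha = np.exp(scores)
    alpha /= alpha.sum(axis=-1, keepdims=True)     # each row is a distribution
    return alpha @ value, alpha

rng = np.random.default_rng(0)
q = rng.normal(size=(4, 8))    # N_q=4, d_k=8
k = rng.normal(size=(6, 8))    # N_k=6
v = rng.normal(size=(6, 8))    # N_v=6
out, alpha = attention_np(q, k, v)
# out: (4, 8), alpha: (4, 6) with rows summing to 1
```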
Instead of performing a single attention function with $d_\text{model}$-dimensional keys, values and queries, we found it beneficial to linearly project the queries, keys and values h times with different, learned linear projections to $d_k$, $d_k$ and $d_v$ dimensions, respectively. On each of these projected versions of queries, keys and values we then perform the attention function in parallel, yielding $d_v$-dimensional output values. These are concatenated and once again projected, resulting in the final values.
In this work we employ $h = 8$ parallel attention layers, or heads. For each of these we use $d_k = d_v = d_\text{model}/h = 64$. Due to the reduced dimension of each head, the total computational cost is similar to that of single-head attention with full dimensionality.
We can apply scaled_dot_product_attention
in parallel across heads, batches and positions. It can be helpful to track the shapes of the inputs as they get transformed via Multi-Head Attention.
Now we can implement a MultiHeadAttention
module which will apply these steps. It will be called on four inputs query, key, value, mask
and return a single output.
class MultiHeadAttention(tf.keras.models.Model):
def __init__(self, dim, num_heads):
super(MultiHeadAttention, self).__init__()
self.dim = dim
self.num_heads = num_heads
self.transform_query, self.transform_key, self.transform_value = [
*(tf.keras.layers.Dense(units=dim) for _ in range(3))
]
self.transform_out = tf.keras.layers.Dense(units=dim)
def split_heads(self, x):
# x: (B, N, d)
# (B, N, h, d//h)
x = tf.reshape(x, (tf.shape(x)[0], -1, self.num_heads, self.dim // self.num_heads))
# (B, h, N, d//h)
x = tf.transpose(x, (0, 2, 1, 3))
return x
def merge_heads(self, x):
# x: (B, h, N, d//h)
# (B, N, h, d//h)
x = tf.transpose(x, (0, 2, 1, 3))
# (B, N, d)
x = tf.reshape(x, (tf.shape(x)[0], -1, self.dim))
return x
def call(self, query, key, value, mask):
# (query=(B, N_q, d), key=(B, N_k, d), value=(B, N_v, d))
query = self.transform_query(query)
key = self.transform_key(key)
value = self.transform_value(value)
# (query=(B, h, N_q, d//h), key=(B, h, N_k, d//h), value=(B, h, N_v, d//h))
query, key, value = (self.split_heads(i) for i in [query, key, value])
# (B, h, N_q, d//h)
x, attn = scaled_dot_product_attention(query, key, value, mask)
x = self.merge_heads(x)
x = self.transform_out(x)
return x, attn
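The split_heads / merge_heads reshapes are easy to get wrong, so here is a NumPy sketch of the same transposes confirming that they are inverses of each other (illustrative sizes only):

```python
import numpy as np

def split_heads_np(x, num_heads):
    B, N, d = x.shape
    x = x.reshape(B, N, num_heads, d // num_heads)  # (B, N, h, d//h)
    return x.transpose(0, 2, 1, 3)                  # (B, h, N, d//h)

def merge_heads_np(x):
    B, h, N, dh = x.shape
    return x.transpose(0, 2, 1, 3).reshape(B, N, h * dh)  # back to (B, N, d)

x = np.arange(2 * 5 * 16, dtype=np.float32).reshape(2, 5, 16)
heads = split_heads_np(x, num_heads=8)   # (2, 8, 5, 2)
```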
Examples of self-attention and memory attention maps for a translation model (which we will build in the next tutorial):
[E]ach of the layers in our encoder and decoder contains a fully connected feed-forward network, which is applied to each position separately and identically. This consists of two linear transformations with a ReLU activation in between.

\[\text{FFN}(x) = \max(0, xW_1 + b_1)W_2 + b_2\]

While the linear transformations are the same across different positions, they use different parameters from layer to layer. Another way of describing this is as two convolutions with kernel size 1. The dimensionality of input and output is $d_\text{model} = 512$, and the inner-layer has dimensionality $d_{ff} = 2048$.
Let us implement a class FeedForward
. It should be an instance of tf.keras.models.Model
and take a single input.
Implementation details:

- Use two tf.keras.layers.Dense layers
- The first layer should have hidden_dim units and a relu activation; the second should have output_dim units and no activation
class FeedForward(tf.keras.models.Model):
def __init__(self, hidden_dim, output_dim):
super(FeedForward, self).__init__()
self.dense1 = tf.keras.layers.Dense(hidden_dim,
activation='relu')
self.dense2 = tf.keras.layers.Dense(output_dim)
def call(self, x):
x = self.dense2(self.dense1(x))
return x
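Because Dense layers act only on the last axis, the network is applied to each position separately and identically, as the paper says. A NumPy sketch of the same two-layer ReLU network (with arbitrary illustrative dimensions) makes the position-wise property easy to verify:

```python
import numpy as np

rng = np.random.default_rng(0)
W1, b1 = rng.normal(size=(8, 32)), rng.normal(size=32)
W2, b2 = rng.normal(size=(32, 8)), rng.normal(size=8)

def ffn_np(x):
    # FFN(x) = max(0, x W1 + b1) W2 + b2, applied along the last axis
    return np.maximum(0.0, x @ W1 + b1) @ W2 + b2

x = rng.normal(size=(5, 8))                  # 5 positions, d_model = 8
full = ffn_np(x)                             # whole sequence at once
per_pos = np.stack([ffn_np(p) for p in x])   # one position at a time
```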
The encoder is composed of a stack of N = 6 identical layers. Each layer has two sub-layers. The first is a multi-head self-attention mechanism, and the second is a simple, position- wise fully connected feed-forward network. We employ a residual connection [11] around each of the two sub-layers, followed by layer normalization [1]. That is, the output of each sub-layer is LayerNorm(x + Sublayer(x)), where Sublayer(x) is the function implemented by the sub-layer itself. To facilitate these residual connections, all sub-layers in the model, as well as the embedding layers, produce outputs of dimension dmodel = 512.
Let us start by building the Sublayer block shown below. One way to implement it is as a ResidualLayer
module in which the Dropout, the Add & Norm step and the residual connection are contained in a single block. The module receives both the input to and the output of the sublayer, where the sublayer might be a FeedForward
or MultiHeadAttention
block. These are the key details:

- Apply dropout to the sublayer output (only when a non-zero dropout rate is given)
- Add the skip connection and apply layer normalization, i.e. compute LayerNorm(x + Sublayer(x))

Hint: sublayers can have an arbitrary number of inputs; for example, query, key, value and mask(s) are required for multi-headed attention. Think about how to handle that.
class ResidualLayer(tf.keras.layers.Layer):
def __init__(self, dropout=0.0):
super(ResidualLayer, self).__init__()
self.use_dropout = dropout > 0
if self.use_dropout:
self.dropout_layer = tf.keras.layers.Dropout(dropout)
self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-6)
def call(self, skip, out, training=True):
if self.use_dropout:
out = self.dropout_layer(out, training=training)
return self.layer_norm(skip + out, training=training)
Now we can use this component block to construct a module EncoderBlock
, which will be the building block of the encoder:
class EncoderBlock(tf.keras.models.Model):
def __init__(self, dim, ff_dim, num_heads, dropout=0.0):
super(EncoderBlock, self).__init__()
self.attn_block = MultiHeadAttention(dim, num_heads)
self.res1 = ResidualLayer(dropout)
self.ff_block = FeedForward(hidden_dim=ff_dim, output_dim=dim)
self.res2 = ResidualLayer(dropout)
def call(self, query, key, value, mask, training=True):
x, attn = self.attn_block(query, key, value, mask, training=training)
skip = x = self.res1(skip=query, out=x, training=training)
x = self.ff_block(x)
x = self.res2(skip=skip, out=x, training=training)
return x, attn
Finally we can put together the Encoder
, which consists of a stack of N encoder blocks.
class Encoder(tf.keras.models.Model):
def __init__(self, dim, ff_dim, num_heads, num_blocks, dropout=0.0):
super(Encoder, self).__init__()
self.blocks = [
EncoderBlock(dim, ff_dim, num_heads, dropout=dropout)
for _ in range(num_blocks)
]
def call(self, query, mask, training=True):
attn_maps = []
for block in self.blocks:
query, attn = block(
query=query, key=query, value=query,
mask=mask, training=training)
attn_maps.append(attn)
return query, attn_maps
The decoder is also composed of a stack of N = 6 identical layers. In addition to the two sub-layers in each encoder layer, the decoder inserts a third sub-layer, which performs multi-head attention over the output of the encoder stack. Similar to the encoder, we employ residual connections around each of the sub-layers, followed by layer normalization.
Let us write a DecoderBlock
. The decoder block consists of the following:
Hint: it will be very similar to EncoderBlock
but remember that the inputs to the final AttentionBlock
will be different. Can you reuse EncoderBlock
?
class DecoderBlock(tf.keras.models.Model):
def __init__(self, dim, ff_dim, num_heads, dropout=0.0):
super(DecoderBlock, self).__init__()
self.self_attn_block = MultiHeadAttention(dim, num_heads)
self.res = ResidualLayer(dropout=dropout)
self.memory_block = EncoderBlock(dim, ff_dim, num_heads, dropout=dropout)
def call(self, query, key, value, decoder_mask, memory_mask, training=True):
x, self_attn = self.self_attn_block(query, query, query, decoder_mask, training=training)
x = self.res(skip=query, out=x, training=training)
x, attn = self.memory_block(x, key, value, memory_mask, training=training)
return x, self_attn, attn
The decoder network will be very similar to the encoder except that it will receive two different mask inputs, one for self-attention and the other for the encoder outputs.
class Decoder(tf.keras.models.Model):
def __init__(self, dim, ff_dim, num_heads, num_blocks, dropout=0.0):
super(Decoder, self).__init__()
self.blocks = [
DecoderBlock(dim, ff_dim, num_heads, dropout=dropout)
for i in range(num_blocks)
]
def call(self, query, memory, decoder_mask, memory_mask, training=True):
self_attn_maps = []
memory_attn_maps = []
for block in self.blocks:
query, self_attn, memory_attn = block(
query=query, key=memory, value=memory,
decoder_mask=decoder_mask,
memory_mask=memory_mask,
training=training)
self_attn_maps.append(self_attn)
memory_attn_maps.append(memory_attn)
return query, self_attn_maps, memory_attn_maps
Two kinds of masks are used to prevent information flow from some sequence positions.
You can plot the masks generated below after squeezing the dimensions of size 1, using the following code:
def plot_mask(mask):
plt.pcolormesh(mask, cmap='gray', vmin=0, vmax=1, edgecolors='gray')
plt.gca().invert_yaxis()
plt.axis('off')
This type of masking is not specific to the Transformer and is not discussed in the paper, but it is used in practice. Padding sequences to the same length allows us to batch together sequences of different lengths. However, this is only an engineering requirement and we don’t actually want the model to use the padding elements. The solution is to mask all the positions that hold a padding symbol.
Implement a SequenceMask
class that does the following:

- Returns False at a location if the value is any of the pad symbols, otherwise True
- Returns a (batch_size, 1, 1, sequence_length) tensor that is suitable for use in scaled_dot_product_attention
class SequenceMask:
def __init__(self, pad):
self.pad = pad
def __call__(self, x):
# Disregards padded elements
# x: (B, N)
if isinstance(self.pad, int):
mask = tf.not_equal(x, self.pad)
else:
mask = tf.reduce_all(tf.not_equal(x[..., None], self.pad), axis=-1)
# Same mask for every position
# (B, 1, 1, N)
return mask[:, None, None]
The sequence masks for tf.stack([[1, 2, 3, 4, 5, 0, 0], [1, 2, 0, 0, 0, 0, 0], [1, 2, 3, 4, 5, 6, 7]])
and pad=0
:
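For reference, the same masks can be computed directly in NumPy (a sketch of what SequenceMask produces, not the class itself):

```python
import numpy as np

x = np.array([[1, 2, 3, 4, 5, 0, 0],
              [1, 2, 0, 0, 0, 0, 0],
              [1, 2, 3, 4, 5, 6, 7]])
mask = (x != 0)             # True where the token is not padding
mask = mask[:, None, None]  # (B, 1, 1, N), broadcastable over heads and queries
```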
Since we train all the target positions in parallel, the model has access to elements from the “future” and we need to prevent information flowing from later to earlier positions.
We also modify the self-attention sub-layer in the decoder stack to prevent positions from attending to subsequent positions. This masking, combined with fact that the output embeddings are offset by one position, ensures that the predictions for position $i$ can depend only on the known outputs at positions less than $i$.
Implement a subsequent_mask
function that for sequence_length=N
returns an $N \times N$ tensor mask
where mask[i, j] = (j <= i)
, i.e. position $i$ may attend only to positions up to and including $i$. This will be used for self-attention only, and when applied to a query of length $N$ it will broadcast to an $N \times N$ sequence-agnostic mask shared across all the batch elements and attention heads, as shown below:
Hint: use tf.linalg.band_part
.
def subsequent_mask(seq_length):
# (N, N)
# lower_triangular matrix
future_mask = tf.linalg.band_part(
tf.ones((seq_length, seq_length)),
-1, 0)
future_mask = tf.cast(future_mask, tf.bool)
return future_mask
The result for subsequent_mask(7)
:
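For comparison, the same lower-triangular pattern can be produced in NumPy with np.tril:

```python
import numpy as np

def subsequent_mask_np(n):
    # mask[i, j] is True when position i may attend to position j (j <= i)
    return np.tril(np.ones((n, n), dtype=bool))

m = subsequent_mask_np(4)
```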
Now write a TargetMask
class that inherits from SequenceMask
and does the following when called:

- Sets combined_mask[i, j] = False if position j is padding, else combined_mask[i, j] = future_mask[i, j]
- Returns a (batch_size, 1, sequence_length, sequence_length) tensor that is suitable for use in scaled_dot_product_attention
class TargetMask(SequenceMask):
def __call__(self, x):
# Disregards "future" elements and any others
# which are padded
# x: (B, N)
# (B, 1, 1, N)
pad_mask = super().__call__(x)
seq_length = tf.shape(x)[-1]
# Mask shared for same position across batches
# (N, N)
future_mask = subsequent_mask(seq_length)
# (B, 1, 1, N) & (N, N) -> (B, 1, N, N)
mask = tf.logical_and(pad_mask, future_mask)
return mask
The target masks for tf.stack([[1, 2, 3, 4, 5, 0, 0], [1, 2, 0, 0, 0, 0, 0], [1, 2, 3, 4, 5, 6, 7]])
and pad=0
:
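For reference, combining the pad and future masks directly in NumPy for the same batch shows how broadcasting produces the (B, 1, N, N) result:

```python
import numpy as np

x = np.array([[1, 2, 3, 4, 5, 0, 0],
              [1, 2, 0, 0, 0, 0, 0],
              [1, 2, 3, 4, 5, 6, 7]])
pad_mask = (x != 0)[:, None, None]                  # (B, 1, 1, N)
future_mask = np.tril(np.ones((7, 7), dtype=bool))  # (N, N)
combined = pad_mask & future_mask                   # broadcasts to (B, 1, N, N)
```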
In attention layers, the attention weights should be 0 for the padding elements so that other elements don’t attend to these elements.
We need to prevent leftward information flow in the decoder to preserve the auto-regressive property. We implement this inside of scaled dot-product attention by masking out (setting to $-\infty$) all values in the input of the softmax which correspond to illegal connections.
One way to handle masking is to set the positions where mask=False
to a negative value of large magnitude, such that their softmax scores are almost zero and have negligible effect on the scores of the values where mask=True
.
def scaled_dot_product_attention(query, key, value, mask=None, inf=1e9):
    d_k = tf.cast(tf.shape(query)[-1], tf.float32)
    # (..., N_q, d_k) @ (..., d_k, N_k) -> (..., N_q, N_k)
    qkt = tf.matmul(query, key, transpose_b=True)
    scores = qkt / tf.sqrt(d_k)
    if mask is not None:
        # Push masked (False) positions towards -inf so that their
        # softmax weights are effectively zero
        scores += (1.0 - tf.cast(mask, tf.float32)) * -inf
    alpha = tf.nn.softmax(scores)
    return tf.matmul(alpha, value), alpha
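We can check numerically (in NumPy, outside the TensorFlow graph) that the large negative offset drives the masked weights to essentially zero while the remaining weights renormalise among themselves:

```python
import numpy as np

scores = np.array([[2.0, 1.0, 0.5]])
mask = np.array([[True, True, False]])   # the last position is illegal

masked_scores = np.where(mask, scores, -1e9)
alpha = np.exp(masked_scores - masked_scores.max(axis=-1, keepdims=True))
alpha /= alpha.sum(axis=-1, keepdims=True)
# the masked position gets ~0 weight; the legal positions sum to 1
```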
[W]e use learned embeddings to convert the input tokens and output tokens to vectors of dimension $d_\text{model}$. We also use the usual learned linear transformation and softmax function to convert the decoder output to predicted next-token probabilities. In our model, we share the same weight matrix between the two embedding layers and the pre-softmax linear transformation, similar to [30]. In the embedding layers, we multiply those weights by $\sqrt{d_\text{model}}$
A simple approach to share the weights is to implement a ScaledEmbedding
layer using tf.keras
, then get the weight matrix from this layer and matrix-multiply with it to generate the input to the softmax. Alternatively we can skip weight sharing and just use a dense layer for the output.

Accordingly, let us implement a ScaledEmbedding
layer for a vocabulary of num_tokens
tokens.
Hint: you can multiply the output by $\sqrt{d_\text{model}}$ instead of the weights.
class ScaledEmbedding(tf.keras.layers.Layer):
def __init__(self, num_tokens, dim):
super(ScaledEmbedding, self).__init__()
self.embed = tf.keras.layers.Embedding(
input_dim=num_tokens,
output_dim=dim
)
self.dim = tf.cast(dim, tf.float32)
def call(self, x):
return tf.sqrt(self.dim) * self.embed(x)
If we want to share weights, we can do as follows:
tf.matmul(x, embed_layer.weights[0], transpose_b=True)
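Shape-wise this works because the embedding matrix is (num_tokens, dim), so multiplying by its transpose maps decoder outputs back to vocabulary logits. A NumPy sketch with illustrative sizes:

```python
import numpy as np

num_tokens, d_model = 100, 16
# stand-in for the embedding layer's weight matrix: (num_tokens, d_model)
embed_weights = np.random.default_rng(0).normal(size=(num_tokens, d_model))

decoder_out = np.zeros((2, 5, d_model))   # (B, N, d_model)
logits = decoder_out @ embed_weights.T    # (B, N, num_tokens)
```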
Since our model contains no recurrence and no convolution, in order for the model to make use of the order of the sequence, we must inject some information about the relative or absolute position of the tokens in the sequence. To this end, we add “positional encodings” to the input embeddings at the bottoms of the encoder and decoder stacks. The positional encodings have the same dimension $d_\text{model}$ as the embeddings, so that the two can be summed.
In this work, we use sine and cosine functions of different frequencies:
\[PE_{(\text{pos},2i)} = \sin(\text{pos}/10000^{2i/d_\text{model}})\] \[PE_{(\text{pos},2i+1)} = \cos(\text{pos}/10000^{2i/d_\text{model}})\]
where $\text{pos}$ is the position and $i$ is the dimension.
[W]e apply dropout to the sums of the embeddings and the positional encodings in both the encoder and decoder stacks
Let us implement a PositionalEncoding
layer as follows:
class PositionalEncoding(tf.keras.models.Model):
def __init__(self, dim, dropout=0.0):
super(PositionalEncoding, self).__init__()
# (D / 2,)
self.range = tf.range(0, dim, 2)
self.dim = tf.cast( 1 / (10000 ** (self.range / dim)), tf.float32)
self.use_dropout = dropout > 0
if self.use_dropout:
self.dropout_layer = tf.keras.layers.Dropout(dropout)
def call(self, x, training=True):
# x: (B, N, D)
# (N,)
length = tf.shape(x)[-2]
pos = tf.cast(tf.range(length), tf.float32)
# (1, N) / (D / 2, 1) -> (D / 2, N)
inp = pos[None] * self.dim[:, None]
sine = tf.sin(inp)
cos = tf.cos(inp)
# (D, N)
enc = tf.dynamic_stitch(
indices=[self.range, self.range + 1],
data=[sine, cos]
)
# (N, D)
enc = tf.transpose(enc, (1, 0))[None]
if self.use_dropout:
return self.dropout_layer(x + enc, training=training)
return x + enc
To get a positional encoding of shape [length, dim]
that you can plot, call PositionalEncoding(dim)(tf.zeros((1, length, dim))).numpy().squeeze()
. (With a zeros input and zero dropout just the positional encoding is returned).
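As a cross-check, the same encoding can be computed directly from the formula in NumPy:

```python
import numpy as np

def positional_encoding_np(length, dim):
    pos = np.arange(length)[:, None]   # (length, 1)
    i = np.arange(0, dim, 2)[None]     # (1, dim/2)
    angles = pos / (10000 ** (i / dim))
    pe = np.zeros((length, dim))
    pe[:, 0::2] = np.sin(angles)       # even dimensions use sine
    pe[:, 1::2] = np.cos(angles)       # odd dimensions use cosine
    return pe

pe = positional_encoding_np(128, 32)
# at position 0 every sine channel is 0 and every cosine channel is 1
```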
Here we can see, for a few dimensions, how the positional encoding varies across positions, helping to differentiate between the positions.
plt.figure(figsize=(12, 8))
d_model = 32
length = 128
pe = PositionalEncoding(d_model)(tf.zeros([1, 128, d_model])).numpy().squeeze()
plt.plot(np.arange(length), pe[:, 8:16]);
In the plots below we plot the positional encodings as:

1/ d_model vectors of size length, showing how the value at each position varies for each dimension
Here we see the sinusoids for a few dimensions:
fig = plt.figure(figsize=(12, 8))
d_model = 16
length = 128
pe = PositionalEncoding(d_model)(tf.zeros([1, 128, d_model])).numpy().squeeze()
plt.plot(np.arange(length), pe[:, 4:8]);
plt.legend(["dim %d"%p for p in range(4, 8)])
plt.xlabel('position')
plt.ylabel('dimension')
In this figure all the positions are plotted
fig = plt.figure(figsize=(12, 6))
d_model = 16
length = 128
pe = PositionalEncoding(d_model)(tf.zeros([1, length, d_model])).numpy().squeeze()
# add an offset so that the curves for different dimensions are vertically separated
offset = 4 * np.arange(d_model)
# plot with orientation consistent with the [length, d_model] shape of the inputs
plt.plot((pe + offset), np.arange(length));
plt.xticks(offset);
fig.axes[0].set_xticklabels(offset // 4);
plt.xlabel('dimension')
plt.ylabel('position')
plt.legend(["dim %d"%p for p in range(d_model)], loc='upper right')
2/ length vectors of size d_model, which lets us see how each position can be represented as a different sinusoid
fig = plt.figure(figsize=(12, 8))
d_model = 128
length = 16
pe = PositionalEncoding(d_model)(tf.zeros([1, length, d_model])).numpy().squeeze()
offset = 4 * np.arange(length)
plt.plot(np.arange(d_model), (pe + offset[:, None]).T);
plt.yticks(offset);
fig.axes[0].set_yticklabels(offset // 4);
plt.xlabel('dimension')
plt.ylabel('position')
plt.legend(["pos %d"%p for p in range(length)])
Now using all the classes and functions that we have written we can build a transformer. Write a Transformer
class that is initialised with the following arguments:
| Argument | Description |
| --- | --- |
| num_tokens | Number of tokens in the input / source dataset |
| num_tgt_tokens | Number of tokens in the target dataset |
| model_dim | Same as $d_\text{model}$ |
| num_heads | Number of attention heads in MultiHeadAttention |
| dropout | Value between 0 and 1 indicating the fraction of units to drop in dropout layers |
| ff_dim | Number of hidden dimensions for the FeedForward block |
| num_encoder_blocks | Number of EncoderBlock modules to use in the Encoder |
| num_decoder_blocks | Number of DecoderBlock modules to use in the Decoder |
| share_embed_weights | Whether to share the embedding weights for source and target; only applicable if num_tokens = num_tgt_tokens |
| share_output_weights | Whether to share the weights between the output layer and the target embeddings |
This module will be called with the inputs x, y, src_mask, tgt_mask and training.

It will have the following methods:

- encode, which only takes the inputs needed to generate the output of the Encoder (x, src_mask and training)
- decode, which assumes that an encoded sequence is available and takes only the inputs needed to generate the final classification logits (y, the memory returned by encode, src_mask, tgt_mask and training)
- call, which calls encode and decode and returns the final logits

class Transformer(tf.keras.models.Model):
def __init__(self,
num_src_tokens,
num_tgt_tokens,
model_dim=256,
num_heads=8,
dropout=0.1,
ff_dim=2048,
num_encoder_blocks=6,
num_decoder_blocks=6,
share_embed_weights=False,
share_output_weights=False):
super(Transformer, self).__init__()
self.share_embed_weights = share_embed_weights
self.share_output_weights = share_output_weights
self.input_embedding = ScaledEmbedding(num_src_tokens, model_dim)
self.enc_pos_encoding = PositionalEncoding(model_dim, dropout=dropout)
self.dec_pos_encoding = PositionalEncoding(model_dim, dropout=dropout)
if not self.share_embed_weights:
self.target_embedding = ScaledEmbedding(num_tgt_tokens, model_dim)
self.encoder = Encoder(dim=model_dim, # 256
ff_dim=ff_dim, # 2048
num_heads=num_heads, # 8
dropout=dropout,
num_blocks=num_encoder_blocks)
self.decoder = Decoder(dim=model_dim, # 256
ff_dim=ff_dim, # 2048
num_heads=num_heads, # 8
dropout=dropout,
num_blocks=num_decoder_blocks)
if not self.share_output_weights:
# TODO: I think we need to scale this
self.output_layer = tf.keras.layers.Dense(units=num_tgt_tokens)
def encode(self, x, src_mask, training=True):
x = self.input_embedding(x)
x = self.enc_pos_encoding(x, training=training)
memory, attn = self.encoder(x, mask=src_mask, training=training)
return memory, attn
def decode(self, y, memory, src_mask, tgt_mask, training=True):
if self.share_embed_weights:
y = self.input_embedding(y)
else:
y = self.target_embedding(y)
y = self.dec_pos_encoding(y, training=training)
out, self_attn, attn = self.decoder(y, memory,
memory_mask=src_mask,
decoder_mask=tgt_mask,
training=training)
if not self.share_output_weights:
logits = self.output_layer(out)
else:
# This works because this is called only after
# target_embedding is called so the weights will
# have been created
if not self.share_embed_weights:
embed = self.target_embedding
else:
embed = self.input_embedding
logits = tf.matmul(out, embed.weights[0], transpose_b=True)
return logits, self_attn, attn
def call(self, x, y, src_mask, tgt_mask, training=True):
memory, enc_attn = self.encode(x, src_mask, training=training)
logits, dec_self_attn, dec_attn = self.decode(y, memory,
src_mask=src_mask,
tgt_mask=tgt_mask,
training=training)
attention = dict(
enc=enc_attn,
dec_self=dec_self_attn,
dec_memory=dec_attn
)
return logits, attention
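The weight-tying branch of `decode` computes the logits by multiplying the decoder output with the transpose of the embedding matrix. Stripped of the TensorFlow machinery, the idea looks like this (toy shapes, NumPy only):

```python
import numpy as np

rng = np.random.default_rng(0)
vocab, d_model = 10, 4
E = rng.normal(size=(vocab, d_model))  # embedding matrix, one row per token
h = rng.normal(size=(3, d_model))      # decoder outputs for 3 positions
# Tied output projection: score of token v at position t is <h[t], E[v]>
logits = h @ E.T                       # shape [3, vocab]
```

This is exactly what the `tf.matmul(out, embed.weights[0], transpose_b=True)` line does, saving a separate `vocab × d_model` output matrix.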
We have built a `Transformer`
but we are not done yet. The paper introduces some approaches for training the model that we need to implement, and we also need to write the code to prepare the data and to process the outputs. In Part 2 we will learn how to do all of this and train a translation model.
Generative image modeling is a central problem in unsupervised learning.
However, because images are high dimensional and highly structured, estimating the distribution of natural images is extremely challenging.
One of the most important obstacles in generative modeling is building complex and expressive models that are also tractable and scalable.
One effective approach to tractably model a joint distribution of the pixels in the image is to cast it as a product of conditional distributions.
The factorization turns the joint modeling problem into a sequence problem, where one learns to predict the next pixel given all the previously generated pixels.
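As a concrete toy illustration of this factorization, a joint distribution over two binary "pixels" can always be written as a marginal times a conditional (the joint table below is made up for the example):

```python
import numpy as np

# Hypothetical joint distribution over two binary pixels, joint[x1, x2]
joint = np.array([[0.1, 0.2],
                  [0.3, 0.4]])
p_x1 = joint.sum(axis=1)                 # marginal p(x1)
p_x2_given_x1 = joint / p_x1[:, None]    # conditional p(x2 | x1)
# The product of the factors recovers the joint exactly
reconstructed = p_x1[:, None] * p_x2_given_x1
```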
Each pixel $x_i$ is in turn factorised over its three colour channels:

\[p(x_{i,R}|x_{<i})\,p(x_{i,G}|x_{<i}, x_{i,R})\,p(x_{i,B}|x_{<i}, x_{i,R}, x_{i,G})\]

Whilst recurrent neural networks are an obvious choice for such a sequence model, the paper observes that convolutional networks work too:

We observe that Convolutional Neural Networks (CNN) can also be used as sequence model with a fixed dependency range, by using Masked convolutions. The PixelCNN architecture is a fully convolutional network of fifteen layers that preserves the spatial resolution of its input throughout the layers and outputs a conditional distribution at each location.
In this post we will focus on the PixelCNN architecture. It is simpler to implement and, once we have grasped the concepts, we can move on to the PixelRNN architecture in a later post. A repository with the code for this post can be found here.
Note: This tutorial was written a while ago, initially intended for publication on Annotaited. There are several newer pixel-level generative models based on transformers that produce higher quality images. However the basic concepts of masked autoregressive models are still essentially the same so I think the material covered here will still be useful.
[Recurrent models] have a potentially unbounded dependency range within their receptive field. This comes with a computational cost as each state needs to be computed sequentially. One simple workaround is to make the receptive field large, but not unbounded. We can use standard convolutional layers to capture a bounded receptive field and compute features for all pixel positions at once.
Can you explain how a convolutional layer can predict sequences in parallel?
Each pixel of the output (leaving out the very first pixel) can be interpreted as a prediction conditioned on the input pixels above and to the left (leaving out the very last pixel).
Note that the advantage of parallelization of the PixelCNN over the PixelRNN is only available during training or during evaluation of test images. The image generation process is sequential for both kinds of networks, as each sampled pixel needs to be given as input back into the network.
Can you think of any other issues with parallel prediction?
When predicting pixels in parallel, for each pixel the model should only have access to "past" pixels, i.e. those above and to the left. Otherwise it would be able to cheat by having access to information about the elements it is supposed to be predicting.
Masks are adopted in the convolutions to avoid seeing the future context
The $h$ features for each input position at every layer in the network are split into three parts, each corresponding to one of the RGB channels.
When predicting the R channel for the current pixel $x_i$, only the generated pixels left and above of $x_i$ can be used as context.
When predicting the G channel, the value of the R channel can also be used as context in addition to the previously generated pixels.
Likewise, for the B channel, the values of both the R and G channels can be used.
Let us consider an example with 3 feature maps. Step through the animation and try to determine which pixels from input feature maps will contribute to the prediction in each case. For clarity a single arrow is used for the above-and-left pixels to indicate they will be used.
To restrict connections in the network to these dependencies, we apply a mask …
We use two types of masks that we indicate with mask A and mask B
Mask A is applied only to the first convolutional layer in a PixelRNN and restricts the connections to those neighboring pixels and to those colors in the current pixels that have already been predicted.
On the other hand, mask B is applied to all the subsequent input-to-state convolutional transitions and relaxes the restrictions of mask A by also allowing the connection from a color to itself.
Why do we need to mask the input at position $i,j,c$ to predict the output at $i,j,c$ for the first layer but not for subsequent layers?
We only need to hide the true data. As the diagram shows, the output of the first layer at $i,j,c$, $X1[i,j,c]$ only depends on the previous pixels and not on the pixel $I[i,j,c]$ of the input image $I$, which is the target for position $i,j,c$. Since $X1[i,j,c]$ is therefore only connected to the valid context for position $i,j,c$, it can be used when predicting $X2[i,j,c]$.
The masks can be easily implemented by zeroing out the corresponding weights
This is easiest to see if we note that 3x3 convs are used in the masked conv layers and that “same” padding is added in order to preserve dimensions. Thus as the kernel slides over the image, the centre of the kernel is aligned with the pixel whose value is going to be predicted. Just as the feature maps are split into 3 groups corresponding to each colour, the $k \times k \times f_\text{in} \times f_\text{out}$ kernel can be split into 9 kernels of size $k \times k \times (f_\text{in}/3) \times (f_\text{out}/3)$ that map every colour to every other colour.
Applying the masking rules, what will the mask for a $3 \times 3$ kernel from each of these groups look like for mask A and for mask B? Assume, as noted above, that the centre is aligned with the “present” pixel.
The dark pixels are 0 and the light pixels are 1. Note for clarity the centre pixel is shown slightly lighter or darker.
If you group together just the kernel centres, which correspond to the present pixel position, you get what looks like an upper triangular matrix. For mask A everything above the diagonal is 1 and for mask B everything above and including the diagonal is 1. Strictly speaking it will be a block triangular matrix with blocks of size $f_\text{in}/3 \times f_\text{out}/3$.
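The colour-to-colour connectivity just described can be written down explicitly with `np.triu`, the same trick used in the implementation below (rows index the colour in the previous layer, columns the colour in the current layer, in R, G, B order):

```python
import numpy as np

# Mask A: strictly above the diagonal, so no colour feeds into itself
clr2clr_A = np.triu(np.ones((3, 3)), k=1)
# Mask B: diagonal included, so each colour may feed into itself
clr2clr_B = np.triu(np.ones((3, 3)), k=0)
```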
Now try implementing this, keeping in mind these points:
def get_mask(kernel_size, features_in, features_out, mode='B'):
assert (features_in % 3) == 0, features_in
assert (features_out % 3) == 0, features_out
# Assume centre of kernel corresponds to present pixel
# which will be the case if kernel dims are odd
# and "same" padding is used
h, w = kernel_size
i_cent = h // 2
j_cent = w // 2
mask = np.ones((h, w, features_in, features_out))
# all zeros to the right of the centre in the centre row
mask[i_cent, (j_cent + 1):] = 0.
# all zeros below
mask[(i_cent + 1):] = 0.
# clr2clr[i, j] indicates whether colour_i
# in the previous layer is connected to colour_j in present
# where colours are R,G,B.
# Entries above the diagonal are always 1 and below always 0
# since there is no backward flow.
# For mask B a colour feeds into itself in the next layer
# so the diagonal entries are 1 but for mask A they are 0
# meaning that the colour can't feed into itself if mask A is used
clr2clr = np.triu(np.ones((3, 3)), k=1 if mode=='A' else 0)
rgb_mask = np.repeat(np.repeat(clr2clr, features_in//3, axis=0),
features_out//3, axis=1)
mask[i_cent, j_cent] = rgb_mask
return mask.astype('float32')
Try plotting the masks from the kernels.
## Parameters to plot the image for mask A shown. Modify to plot different masks
mode = 'A'
k = 7
ksize= (k, k)
f_in = 15
f_out = 24
## Code to plot
m = get_mask(ksize, f_in, f_out, mode)
masks = [[m[:, :, i * f_in//3, j * f_out//3]
for j in range(3)]
for i in range(3)]
masks = [np.pad(i, ([(0, 1)] * 2), constant_values=1)
for i in 2 * np.stack(masks).reshape([-1, 3, 3])
]
masks = np.reshape(masks, ksize + (4, 4))
fig, axes = plt.subplots(3, 3, figsize=(8, 8))
for i, clr_in in enumerate('RGB'):
for j, clr_out in enumerate('RGB'):
ax = axes[i, j]
kernel = m[:, :, i * f_in//3, j * f_out//3] * 255
mid = np.floor_divide(kernel.shape, 2)
mid_val = kernel[mid[0], mid[1]]
kernel[mid[0], mid[1]] = 55 if mid_val == 0 else 200
ax.pcolor(kernel, edgecolor='k', vmin=0, vmax=255, cmap='gray')
if j == 0:
ax.set_ylabel('$%s_{in}$'%clr_in, color=clr_in.lower(), fontsize=16)
if i == 0:
ax.set_xlabel('$%s_{out}$'%clr_out, color=clr_out.lower(), fontsize=16)
ax.set_xticks([])
ax.set_yticks([])
ax.invert_yaxis()
ax.xaxis.set_label_position('top')
plt.suptitle(f'Mask {mode} for a {k} x {k} kernel', fontsize=24)
Let us now implement a masked conv layer. A simple approach is to implement a conv layer whose kernel is multiplied by the mask before the convolution step.
class MaskedConv2D(tf.keras.layers.Layer):
def __init__(self, kernel_size, filters_in, filters, act=None, mode='B'):
super(MaskedConv2D, self).__init__()
if isinstance(kernel_size, int):
self.kernel_size = (kernel_size, kernel_size)
else:
self.kernel_size = kernel_size
self.filters = filters
self.mode = mode
if act is not None:
self.act = tf.keras.layers.Activation(act)
self.kernel = self.add_weight(name='kernel', shape=(*self.kernel_size, filters_in, self.filters))
self.bias = self.add_weight(name='bias', shape=(self.filters,))
mask_kwargs = dict(kernel_size=self.kernel_size,
features_in=filters_in,
features_out=self.filters,
mode = self.mode)
self.mask = get_mask(**mask_kwargs)
def call(self, x):
kernel = self.kernel * self.mask
out = tf.nn.conv2d(x, kernel, [1, 1, 1, 1], padding='SAME')
out = tf.nn.bias_add(out, self.bias)
if hasattr(self, 'act'):
return self.act(out)
return out
Note: for simplicity we are passing in the number of input channels. An alternative approach would be to override the `build` method so that the mask and the weights are initialised when the module is first called, letting the number of input channels be inferred from the input rather than having to be provided in advance.
Once masking is understood, the architecture is actually rather simple and does not require many lines of code.
The first layer is a 7 × 7 convolution that uses the mask of type A … The PixelCNN [then] uses convolutions of size 3 × 3 with a mask of type B. The top feature map is then passed through a couple of layers consisting of a Rectified Linear Unit (ReLU) and a 1×1 convolution. For the CIFAR-10 and ImageNet experiments, these layers have 1024 feature maps; for the MNIST experiment, the layers have 32 feature maps. Residual and layer-to-output connections are used across the layers… [Emphasis added]
This is summarised in the following table:
The residual block is as shown below:
Let us implement a module PixelCNNResBlock
based on the diagram that makes use of the MaskedConv2D
layer.
class PixelCNNResBlock(tf.keras.models.Model):
def __init__(self, filters_in, filters):
super(PixelCNNResBlock, self).__init__()
self.conv_in = MaskedConv2D(1, filters_in, filters, act='relu')
self.conv_mid= MaskedConv2D(3, filters, filters, act='relu')
self.conv_out = MaskedConv2D(1, filters, 2 * filters, act='relu')
def __call__(self, x):
out = self.conv_in(x)
out = self.conv_mid(out)
out = self.conv_out(out)
return x + out
Now we can construct the model. Reshape the output so that it has shape `[batch_size, height, width, 3, 256]`, corresponding to the 256 possible intensities for each colour. We will call this `PixelRNN`
and pass in the residual block as an argument, making it possible to reuse the module later to also build the recurrent models from the paper.
class PixelRNN(tf.keras.models.Model):
def __init__(self, hidden_dim, out_dim, n_layers, pixel_layer):
super(PixelRNN, self).__init__()
hidden_dim = hidden_dim * 3
out_dim = out_dim * 3
self.input_conv = MaskedConv2D(kernel_size=7,
filters_in=3,
filters=2 * hidden_dim,
mode='A')
self.pixel_layer = [pixel_layer(2 * hidden_dim, hidden_dim) for _ in range(n_layers)]
self.output_conv1 = MaskedConv2D(kernel_size=1,
filters_in=2 * hidden_dim,
filters=out_dim)
self.output_conv2 = MaskedConv2D(kernel_size=1,
filters_in=out_dim,
filters=out_dim)
self.final_conv = MaskedConv2D(kernel_size=1,
filters_in=out_dim,
filters=256 * 3)
def __call__(self, x):
y = self.input_conv(x)
for layer in self.pixel_layer:
y = layer(y)
y = self.output_conv1(tf.nn.relu(y))
y = self.output_conv2(tf.nn.relu(y))
y = self.final_conv(y)
y = tf.reshape(y, tf.concat([tf.shape(y)[:-1], [3, 256]], 0))
return y
We will use the CIFAR10 dataset and load it into memory using Keras. Then we write a simple function `get_dataset` to create a dataset that does the following:

- casts the images to `int32` labels and `float32` inputs scaled to [0, 1]
- shuffles and repeats the data indefinitely in training mode
- batches the results
def get_dataset(imgs, batch_size=16, mode='train'):
def _map_fn(x):
labels = tf.cast(x, tf.int32)
imgs = tf.cast(x, tf.float32)
imgs = imgs / 255.
return imgs, labels
ds = tf.data.Dataset.from_tensor_slices(imgs)
if mode == 'train':
ds = ds.shuffle(1024)
ds = ds.repeat(-1)
ds = ds.map(_map_fn)
return ds.batch(batch_size)
All our models are trained and evaluated on the log-likelihood loss function coming from a discrete distribution. For the multinomial loss function we use the raw pixel color values as categories.
In other words, the model is trained with a softmax cross-entropy loss over the 256 classes for each of red, green and blue.
For CIFAR-10 and ImageNet we report negative log-likelihoods in bits per dimension. The total discrete log-likelihood is normalized by the dimensionality of the images (e.g., 32×32×3 = 3072 for CIFAR-10). These numbers are interpretable as the number of bits that a compression scheme based on this model would need to compress every RGB color value (van den Oord & Schrauwen, 2014b; Theis et al., 2015); in practice there is also a small overhead due to arithmetic coding.
How can you derive the negative log-likelihoods (NLL) in bits per dimension from the softmax loss?
The softmax loss for $N$ RGB images of size $H \times W$ is as follows:
\[L_\text{softmax}(I_\text{true}, I_\text{pred}) = -\frac{1}{3NHW}\ln\prod_{n=1}^{N}\prod_{i=1}^{HW}p(x_{i,R}|x_{<i})\,p(x_{i,G}|x_{<i}, x_{i,R})\,p(x_{i,B}|x_{<i}, x_{i,R}, x_{i,G})\]

where $x$ is the true intensity and the inner product runs over the $HW$ pixels of image $n$.
This is the normalized NLL in nats averaged across the images. Here the logarithm has base $e$, whereas to express it in bits per dimension we need base 2. Since $\log_2(x) = \ln(x) / \ln(2)$, simply dividing the softmax loss by $\ln(2)$, e.g. `tf.math.log(2.)`
or `np.log(2)`
, gives you the NLL in bits normalized by the dimensionality of the images and averaged across the images.
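As a quick sanity check of the conversion (toy numbers): a model that assigned a uniform probability of 1/256 to every intensity would have a per-dimension NLL of $\ln(256)$ nats, which should come out to exactly 8 bits.

```python
import math

nll_nats = math.log(256)            # per-dimension NLL of a uniform model
nll_bits = nll_nats / math.log(2)   # divide by ln(2) to convert to bits
# ≈ 8.0 bits per dimension
```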
RMSProp gives best convergence performance and is used for all experiments. The learning rate schedules were manually set for every dataset to the highest values that allowed fast convergence.
As further details were not provided here about the learning rates, I looked at OpenAI’s codebase for PixelCNN++ and found that they used Adam with a learning rate of 0.001, multiplied by 0.999995 after each step, so that is what we will use here.
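To get a feel for this schedule (using the 0.001 and 0.999995 figures above), here is the effective learning rate after a given number of steps; it decays to roughly 60% of its initial value after 100k steps.

```python
def lr_at_step(step, base_lr=0.001, decay=0.999995):
    # multiplicative decay applied once per training step
    return base_lr * decay ** step
```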
Let us create a Trainer
class that will be used to train the model. It will store the model, optimizer, loss function and metrics and have methods to train and evaluate the model. Here is some code to get you started.
class Trainer(object):
def __init__(self,
config: dict,
model: tf.keras.models.Model,
loss_fn: Union[tf.keras.losses.Loss, Callable],
optim: tf.keras.optimizers.Optimizer):
self.config = config
self.model = model
self.loss_fn = loss_fn
self.optim = optim
self.ckpt = tf.train.Checkpoint(
transformer=self.model,
optimizer=self.optim
)
self.trn_writer = tf.summary.create_file_writer(os.path.join(config.log_path, 'train'))
self.test_writer = tf.summary.create_file_writer(os.path.join(config.log_path, 'test'))
self.img_writer = tf.summary.create_file_writer(os.path.join(config.log_path, 'imgs'))
Now add the following methods to the class:
- `run_model`, which takes a batch of images and returns the model output and loss
- `train_step`, which takes a batch of images, updates the model weights and returns the loss
- `valid_step`, which takes a batch of images and returns the loss
- `evaluate`, which takes a dataset and returns the average loss

Feel free to add any other methods you think are necessary, as well as code to print metrics, add summaries, etc.
def run_model(self, images, labels):
predictions = self.model(images)
batch_size = tf.shape(images)[0]
# Model can be evaluated on all except for the very first element
# i.e. the R value of the top left corner pixel
y_true = tf.reshape(labels, [batch_size, -1])[:, 1:]
y_pred = tf.reshape(predictions, [batch_size, -1, 256])[:, 1:]
loss = tf.reduce_mean(self.loss_fn(labels=y_true, logits=y_pred))
return loss, predictions
@tf.function
def train_step(self, images, labels):
with tf.GradientTape() as tape:
loss, predictions = self.run_model(images, labels)
grads = tape.gradient(loss, self.model.trainable_variables)
self.optim.apply_gradients(zip(grads, self.model.trainable_variables))
itr = self.optim.iterations
bits_per_dim = loss / tf.math.log(2.)
with self.trn_writer.as_default():
tf.summary.scalar('loss', bits_per_dim, itr)
tf.summary.scalar('lr', self.optim.learning_rate, itr)
self.optim.learning_rate.assign(
self.optim.learning_rate * 0.999995
)
return bits_per_dim, predictions
def valid_step(self, images, labels):
loss, predictions = self.run_model(images, labels)
return loss, predictions
@tf.function
def evaluate(self, dataset: tf.data.Dataset):
loss = tf.TensorArray(tf.float32, size=tf.cast(len(dataset), tf.int32))
for idx, (images, labels) in enumerate(dataset):
loss_, _ = self.valid_step(images, labels)
idx = tf.cast(idx, tf.int32)
loss = loss.write(idx, loss_)
loss = loss.stack()
bits_per_dim = tf.reduce_mean(loss / tf.math.log(2.))
with self.test_writer.as_default():
tf.summary.scalar(
'bits_per_dim',
bits_per_dim,
self.optim.iterations
)
return bits_per_dim
Inference, as noted before, must be done sequentially. How do you think it is done?
Now implement a function generate_images
, which given a model and a number of images to generate, repeatedly calls the model to produce the images. The function should use TensorFlow so that it can be run on the GPU and as part of a graph.
Hint: use a tf.while_loop
to loop over the HWC pixels of the image and use tf.tensor_scatter_nd_update
to update the image.
def generate_images(
model: tf.keras.models.Model,
img_dims: Iterable,
n_images: int=1,
n_channels: int=3,
initial_values=None,
) -> tf.Tensor:
batch_inds = tf.expand_dims(tf.range(n_images), axis=-1)
# [0, 0, 0, 0, ...], [1, 0, 0, 0, ...], [2, 0, 0, 0, ...], ...
init_inds = tf.concat(
[
batch_inds,
tf.zeros((n_images, n_channels), dtype='int32')
], axis=-1
)
if initial_values is None:
initial_values = tf.random.uniform(shape=(n_images,),
minval=0,
maxval=1,
dtype='float32')
img = tf.scatter_nd(indices=init_inds,
updates=initial_values,
shape=(n_images, *img_dims, n_channels))
# Starting at step = 1 lets us omit the condition
# "if (clr + colm + row) > 0", which would otherwise be needed
# to avoid overwriting the seeded top-left pixel of the first channel
step = tf.constant(1)
total_steps = tf.math.reduce_prod(img_dims) * n_channels
def cond(step, img):
return tf.less(step, total_steps)
def body(step, img):
row = tf.math.floordiv(step, img.shape[2] * n_channels)
colm = tf.math.floordiv(tf.math.mod(step, img.shape[2] * n_channels), n_channels)
clr = tf.math.mod(step, n_channels)
update_inds = tf.concat(
[
batch_inds,
tf.tile([[row, colm, clr]], [n_images, 1])
], axis=-1
)
# The prediction only depends on the pixels to the left and above,
# so we can get a bit of a speedup by passing only the rows
# up to and including the current row
result = model(img[:, :row + 1, :, :])
result_rgb = tf.random.categorical(
logits=result[:, row, colm, clr],
num_samples=1
)
result_rgb = tf.cast(tf.squeeze(result_rgb, axis=-1), tf.float32)
img = tf.tensor_scatter_nd_update(img,
indices=update_inds,
updates=result_rgb / 255.)
return step + 1, img
step, img = tf.while_loop(cond, body, [step, img])
return img
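The raster-order index arithmetic inside `body` (recovering row, column and channel from a flat step counter, with channels varying fastest) can be checked in plain Python, independent of TensorFlow:

```python
def step_to_rcc(step, width, n_channels=3):
    # Inverse of step = (row * width + colm) * n_channels + clr
    row = step // (width * n_channels)
    colm = (step % (width * n_channels)) // n_channels
    clr = step % n_channels
    return row, colm, clr
```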
Now we can add a function to Trainer
to make image summaries over the course of training.
@tf.function
def create_image_samples(self):
itr = self.optim.iterations
gen_imgs = generate_images(
model=self.model,
img_dims=self.config.img_dims,
)
with self.img_writer.as_default():
tf.summary.image('gen_imgs', tf.cast(gen_imgs * 255, tf.uint8), itr)
Now that all the components have been implemented we can put them together to train the model. This has been done in the file main.py
). The settings are defined in config.yml
. Run python main.py
to train the model. A GPU is highly recommended as the model takes a long time to train on the CPU.
It took over 5 days to run ~450K steps on a single GPU for the NLL to get close to the reported values. I stopped it after that as I needed the GPU for other things. The test loss went down to about 3.22, which is a bit higher than the value reported in the paper. I didn’t measure the NLL over the entire training set at a fixed checkpoint; the plot shows the per-batch values after each training step. However, it can be seen that this converges to a similar level.
For comparison, note that the better performing PixelCNN++ model is said to take around 5 days on 8 GPUs with a batch size of 16 (so 8x the number of images per step used here) to converge to its best value. However, image quality seems to improve much sooner. The figure below shows images generated at various points between around 150k and 450k iterations (ordered from left to right, top to bottom).
Nevertheless it should be noted that the images don’t match the quality of those generated by models such as GANs, so don’t be worried or disappointed if your images don’t seem to resemble anything in the real world. We trade off sample quality for model interpretability. Here are examples of images from the models trained in the paper: you can see that they have plausible patterns of colour but lack meaningful content.
In this tutorial we will demonstrate how to implement Example 5.1: Blackjack from Reinforcement Learning (Sutton and Barto).
Blackjack is a card game where the goal is to obtain cards the sum of whose values is maximised without exceeding 21. All face cards (Jack, Queen, King) count as 10. An ace counts as 1 or 11 and we shall see below how this is determined.
The example uses first-visit Monte Carlo prediction to estimate the values $v_\pi(s)$ of each state $s$ under policy $\pi$ as the average of the returns following the first visits to $s$.
In code, for an arbitrary episode-generating function `episode_fn`:
def first_visit_mc_prediction(pi, gamma, states, num_episodes, episode_fn):
num_states = len(states)
V = dict(zip(states, np.random.normal(size=num_states)))
returns = {state: [] for state in states}
for _ in tqdm.trange(num_episodes):
episode = episode_fn(pi)
S_0Tm1, A_0Tm1, R_1T = episode
G = 0
for t, (St, At, Rtp1) in list(enumerate(zip(*episode)))[::-1]:
G = gamma * G + Rtp1
if St not in S_0Tm1[:t]:
returns[St].append(G)
V[St] = np.mean(returns[St])
return V
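The backward recursion $G \leftarrow \gamma G + R_{t+1}$ used in the loop above can be verified on a toy reward sequence:

```python
# Toy rewards and discount: G_t should equal sum_k gamma^k * R_{t+1+k}
rewards = [1.0, 2.0, 3.0]
gamma = 0.5
G = 0.0
returns = []
for r in reversed(rewards):
    G = gamma * G + r
    returns.append(G)
returns = returns[::-1]
# returns[0] = 1 + 0.5*2 + 0.25*3 = 2.75
```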
Now let us implement a function that generates blackjack episodes
import numpy as np
import matplotlib.pyplot as plt
import itertools
from tqdm import autonotebook as tqdm
The game begins with two cards dealt to both dealer and player. One of the dealer’s cards is face up and the other is face down. If the player has 21 immediately (an ace and a 10-card), it is called a natural. He then wins unless the dealer also has a natural, in which case the game is a draw. If the player does not have a natural, then he can request additional cards, one by one (hits), until he either stops (sticks) or exceeds 21 (goes bust). If he goes bust, he loses; if he sticks, then it becomes the dealer’s turn. The dealer hits or sticks according to a fixed strategy without choice: he sticks on any sum of 17 or greater, and hits otherwise. If the dealer goes bust, then the player wins; otherwise, the outcome—win, lose, or draw—is determined by whose final sum is closer to 21.
The state will be as follows
the player makes decisions on the basis of three variables: his current sum (12–21), the dealer’s one showing card (ace–10), and whether or not he holds a usable ace. This makes for a total of 200 states.
Here is how we will interpret and implement the game
Initially dealer gets two cards, one face up and another face down. Any subsequent cards given to dealer are face down. The state is (current_sum_P, D_up, usable_ace_P)
, where D_up
does not change whilst current_sum_P
and usable_ace_P
can change
If the player holds an ace that he could count as 11 without going bust, then the ace is said to be usable. In this case it is always counted as 11 because counting it as 1 would make the sum 11 or less, in which case there is no decision to be made because, obviously, the player should always hit.
def blackjack_episode(pi):
def initialise(pair):
current_sum = pair.sum()
usable_ace = False
if (pair == 1).any():
current_sum += 10
usable_ace = True
# While < 11, ace is usable
while current_sum < 11:
deal = np.random.choice(cards)
current_sum += deal
if deal == 1:
current_sum += 10
usable_ace = True
# Since < 12, won't go bust adding any card (ace counted as 1)
# since max card value is <= 10
if current_sum == 11:
current_sum += np.random.choice(cards)
return current_sum, usable_ace
def hit(current_sum, usable_ace):
current_sum += np.random.choice(cards)
if current_sum > 21 and usable_ace:
current_sum -= 10
usable_ace = False
return current_sum, usable_ace
states = []
actions = []
rewards = []
cards = np.minimum(np.arange(1, 14), 10)
D_up, D_down = cards_D = np.random.choice(cards, size=2)
cards_P = np.random.choice(cards, size=2)
if set(cards_P) == {1, 10}:
states.append((21, D_up, True)) # natural
actions.append('stick')
if set(cards_D) == {1, 10}:
rewards.append(0)
else:
rewards.append(1)
return states, actions, rewards
current_sum_P, usable_ace_P = initialise(cards_P)
states.append(
(current_sum_P, D_up, usable_ace_P)
)
while True:
current_sum_P, _, usable_ace_P = state = states[-1]
action = pi(state)
actions.append(action)
if action == 'hit':
current_sum_P, usable_ace_P = hit(current_sum_P, usable_ace_P)
if current_sum_P > 21:
rewards.append(-1)
break
else:
rewards.append(0)
states.append((current_sum_P, D_up, usable_ace_P))
else:
# Dealer's turn
current_sum_D, usable_ace_D = initialise(cards_D)
while current_sum_D < 17:
current_sum_D, usable_ace_D = hit(current_sum_D, usable_ace_D)
if current_sum_D > 21:
rewards.append(1)
elif current_sum_D > current_sum_P:
rewards.append(-1)
elif current_sum_D == current_sum_P:
rewards.append(0)
else:
assert current_sum_D < current_sum_P
rewards.append(1)
break
return states, actions, rewards
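The deck model used above (`np.minimum(np.arange(1, 14), 10)`) maps the thirteen ranks to their blackjack values: the ace is 1, number cards keep their face value, and jack, queen and king all count as 10.

```python
import numpy as np

cards = np.minimum(np.arange(1, 14), 10)
# ace, 2..10, then jack/queen/king all valued 10
```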
Consider the policy that sticks if the player’s sum is 20 or 21, and otherwise hits.
def default_policy(state):
total, *_ = state
if total < 20:
return 'hit'
return 'stick'
Let us now plot a figure similar to Figure 5.1 in the book. First we need to run first_visit_mc_prediction
for 10k and 500k episodes.
_pi = default_policy
_gamma = 1
_showing_card = range(1, 11)
_current_sum = range(12, 22)
_usable_ace = [True, False]
_states = list(itertools.product(_current_sum, _showing_card, _usable_ace))
print('Number of states:', len(_states))
value = dict()
for _num_episodes in [10, 500]:
_episode_fn = blackjack_episode
print(f'Number of episodes: {_num_episodes},000')
value[_num_episodes] = first_visit_mc_prediction(
_pi, _gamma, _states, _num_episodes * 1000, _episode_fn)
Number of states: 200
Number of episodes: 10,000
Number of episodes: 500,000
fig = plt.figure(figsize=(8, 8))
for idx, (ne, val) in enumerate(value.items(), 1):
for i, u_ace in zip(range(0, 3, 2), [True, False]):
axis = fig.add_subplot(2, 2, idx + i, projection='3d')
if i == 0:
axis.set_title(f'After {ne},000 episodes')
X, Y = np.meshgrid(_showing_card, _current_sum)
Z = np.zeros_like(X).astype('float')
for (cs, sc, ua), v in val.items():
if ua == u_ace:
Z[cs-12, sc-1] = v
axis.plot_wireframe(X, Y, Z, linewidth=0.7, color='k')
if (idx + i) == 4:
axis.set_ylabel('Player sum')
axis.set_xlabel('Dealer showing')
axis.set_zlim(-1, 1)
axis.set_xlim(1, 10)
axis.set_ylim(12, 21)
axis.set_zticks([-1, 0, 1])
if idx == 1:
axis.text(2, 12, 4, ('U' if u_ace else 'No\nu') + 'sable\nace', fontdict={'ha': 'center'})
axis.set_box_aspect([1,1,0.2])
plt.tight_layout()
A popular metric for evaluating image generation models is the Fréchet Inception Distance (FID). Like the Inception score, it is computed on the embeddings from an Inception model. But unlike the Inception score, it makes use of the true images as well as the generated ones. In this post we will learn how to implement it in PyTorch.
Let $p_w$ be the real world data distribution and $p$ the generated data distribution, with mean, covariance $\left(\mathbf{m}_w, \mathbf{C}_w\right)$ and $\left(\mathbf{m}, \mathbf{C}\right)$.
The Fréchet distance between a pair of Gaussians with parameters $\left(\mathbf{m}_w, \mathbf{C}_w\right)$ and $\left(\mathbf{m}, \mathbf{C}\right)$ is given by
\[d\left(\left(\mathbf{m}, \mathbf{C}\right), \left(\mathbf{m}_w, \mathbf{C}_w\right)\right) = \left\lVert \mathbf{m} - \mathbf{m}_w\right\rVert_2^2 + \text{Tr}\left(\mathbf{C} + \mathbf{C}_w - 2\left(\mathbf{C}\mathbf{C}_w\right)^\frac{1}{2}\right)\]

To compute the Fréchet Inception distance, Inception features are used to estimate the mean and the covariance.
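In one dimension the formula reduces to $(m - m_w)^2 + (\sqrt{C} - \sqrt{C_w})^2$, which allows a quick sanity check with toy numbers (made-up statistics, not real Inception moments):

```python
import math

m, C = 0.0, 4.0      # "generated" mean and variance (toy values)
m_w, C_w = 1.0, 1.0  # "real" mean and variance (toy values)
d = (m - m_w) ** 2 + C + C_w - 2 * math.sqrt(C * C_w)
# equals (0 - 1)^2 + (2 - 1)^2 = 2
```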
$\def\half{\frac{1}{2}}$ $\def\C{\mathbf{C}}$ $\def\Cw{\C_w}$ $\def\CCwh{\left(\C\C_w\right)^\half}$ $\newcommand\trace[1]{\text{Tr}\left({#1}\right)}$
First we need to get Inception features. In the paper that introduced FID, they write
For computing the FID, we propagated all images from the training dataset through the pretrained Inception-v3 model following the computation of the Inception Score … however, we use the last pooling layer as coding layer.
Here we will use the pretrained model from torchvision
and replace the last fully connected layer with an Identity
layer so that we get the pooled outputs from the last pooling layer.
def get_model():
model = torchvision.models.inception_v3(weights=torchvision.models.Inception_V3_Weights.IMAGENET1K_V1)
# Hack to get pooled features: replace the classification head with an identity
model.fc = torch.nn.Identity()
return model
Now let us implement a function that calculates the covariance and the mean. Given features $X_i$ for example $i$
\[E\left[X\right] \approx \hat{E}\left[X\right] = \frac{1}{N}\sum_{i=1}^N X_i \\\\\\ \text{cov}\left(X\right) = E\left[XX^T\right] - E\left[X\right]E\left[X\right]^T \approx \frac{1}{N}\sum_{i=1}^N X_i X_i^T - \hat{E}\left[X\right] \hat{E}\left[X\right] ^T\]It will do this in a batched fashion, obtaining the features, then updating the sum of the features $X$ and the sum of $XX^T$, whilst keeping track of the total number of inputs seen so far.
def get_moments(samples, model):
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)
model.eval()
with torch.no_grad():
X_sum = torch.zeros((2048, 1)).to(device)
XXT_sum = torch.zeros((2048, 2048)).to(device)
count = 0
for inp in tqdm.tqdm(samples):
# [B, F]
pred = model(inp.to(device))
# [B, F] -> [1, F] -> [F, 1]
X_sum += pred.sum(dim=0, keepdim=True).T
# [B, 1, F] x [B, F, 1] -> [B, F, F] -> [F, F]
XXT_sum += (pred[:, None] * pred[..., None]).sum(0)
count += len(inp)
# [F, 1]
mean = X_sum / count
# [F, F] - [F, F] -> [F, F]
cov = XXT_sum / count - mean @ mean.T
return mean, cov
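We can verify the streaming update against a direct computation by feeding in random features in place of Inception outputs (a small sanity check, not part of the FID pipeline):

```python
import torch

torch.manual_seed(0)
feats = torch.randn(100, 4)

# Same accumulators as in get_moments
X_sum = feats.sum(dim=0, keepdim=True).T              # [F, 1]
XXT_sum = (feats[:, None] * feats[..., None]).sum(0)  # [F, F]
count = len(feats)

mean = X_sum / count
cov = XXT_sum / count - mean @ mean.T

assert torch.allclose(mean.squeeze(), feats.mean(0), atol=1e-5)
# torch.cov uses the unbiased (N - 1) estimator, so rescale to compare
assert torch.allclose(cov, torch.cov(feats.T) * (count - 1) / count, atol=1e-5)
```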
Given the moments we can obtain the Fréchet Inception distance. Note that
\[\trace{\C + \Cw - 2\CCwh} = \trace{\C} + \trace{\Cw} - 2\trace{\CCwh}\]The first two terms are straightforward to evaluate. To obtain $\trace{\CCwh}$ note the sum of the eigenvalues of a matrix is equal to its trace. It can be shown that since both $\C$ and $\Cw$ are covariance matrices $\CCwh$ has an eigenvalue matrix $D^\half$ where $D$ is the eigenvalue matrix of $\C\Cw$. Moreover, the eigenvalues of $\C\Cw$ are real and positive which means that the eigenvalues of $\CCwh$ are real. See the Appendix for more details.
def frechet_inception_distance(m_w, C_w, m, C, debug=False):
eigenvals = torch.linalg.eigvals(C @ C_w)
trace_sqrt_CCw = eigenvals.real.clamp(min=0).sqrt().sum()
if debug:
print('Largest imaginary part magnitude:', eigenvals.imag.abs().max().item())
print('Most negative real part:', eigenvals.real.min().item())
fid = ((m - m_w)**2).sum() + C.trace() + C_w.trace() - 2 * trace_sqrt_CCw
return fid
Note that although the eigenvalues are theoretically real and non-negative, floating point errors can give them small imaginary parts, and values close to 0 might be represented as tiny negative numbers, necessitating clamping to a minimum of 0.
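As a quick check of the eigenvalue trick, for diagonal covariance matrices the trace of $\CCwh$ is available in closed form:

```python
import torch

C = torch.diag(torch.tensor([1.0, 4.0]))
C_w = torch.diag(torch.tensor([9.0, 16.0]))

# Eigenvalues of C @ C_w = diag(9, 64) are 9 and 64,
# so the trace of the square root should be 3 + 8 = 11
eigenvals = torch.linalg.eigvals(C @ C_w)
trace_sqrt_CCw = eigenvals.real.clamp(min=0).sqrt().sum()
print(trace_sqrt_CCw.item())  # 11.0
```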
TF Flowers is a dataset containing photographs of four types of flowers: roses, tulips, daisies and dandelions. We split this into two disjoint parts: a large one used as the reference dataset of “true samples” and a small one. We compare three datasets to the reference dataset: the small held-out set of photos, samples generated using GLIDE, and a set of flower paintings.
The FID of these datasets is as follows:
| Dataset | FID |
|---|---|
| Photos small | 71.61 |
| GLIDE samples | 116.71 |
| Art | 162.6 |
Since the datasets involved are one or two orders of magnitude smaller than the sizes typically used in the literature, these results should be interpreted with caution.
Unsurprisingly the small photos dataset has the best FID (lower is better), followed by the photo-realistic GLIDE samples, whilst the paintings have the highest FID. GLIDE images were generated with prompts of the form “a photo of a rose” and “a photo of roses”, with half of the images generated using each prompt type. Most of the generated images are relatively close-up and show only a small number of flowers. It is possible that more diverse prompts might result in an improved FID. It is interesting that the painting dataset often shows the flowers as part of a scene and in this respect is more similar to the photos dataset. However the appearance seems to be sufficiently different to lead to a large FID.
Here are the links to resources used to generate these results. I have not shared the datasets except for the generated samples but you should be able to generate them yourself
We wish to show that if $A$ and $B$ are covariance matrices then $\left(AB\right)^\half$ has real eigenvalues.
Preliminaries
1. $AB$ is diagonalisable
\(\forall x. \text{ }\text{ } x^T{B^\half}A{B^\half}x = x^T\left({B^\half}\right)^T A{B^\half}x = \left({B^\half}x\right)^TA\left({B^\half}x\right) \geq 0\) so $B^\half A B^\half$ is positive semi-definite and hence diagonalisable with real, non-negative eigenvalues. Taking $B$ to be invertible, $AB = B^{-\half}\left({B^\half}A{B^\half}\right)B^\half$ is similar to ${B^\half}A{B^\half}$ and is therefore also diagonalisable, with the same real, non-negative eigenvalues.
2. $\left(AB\right)^\half$ has real eigenvalues
A square root of the diagonalisable matrix $AB = VDV^{-1}$ can be written in the form $VD^\half V^{-1}$ where $D^\half$ is a diagonal matrix such that $D^\half_{ii} = \sqrt{D_{ii}}$, and these diagonal entries are the eigenvalues of $\left(AB\right)^\half$. Since the eigenvalues $D_{ii}$ of $AB$ are real and non-negative, the $\sqrt{D_{ii}}$ are real.
This is the first of a planned series of blogs covering background topics for DeepMind’s AlphaTensor paper. In this post we will cover Strassen’s algorithm for matrix multiplication.
Matrices matter to an extent unimaginable in 1969, when the paper proposing this method was published, and methods to multiply them fast are crucial. Most notably they are key to deep learning training and inference, so it is very exciting that new algorithms for fast matrix multiplication have been discovered using learning methods.
You almost certainly know the naïve method for multiplying two matrices. Restricting ourselves to square $n \times n$ matrices for simplicity, this can be written as
\[Z = XY \\Z_{ij} = \sum_{t=1}^n X_{it}Y_{tj}\]To compute each element $Z_{ij}$ you need $n$ multiplications and $n-1$ additions and there are $n^2$ such elements leading to a total of $n^3$ multiplications and $n^2(n-1)$ additions so $n^2(2n - 1) = O(n^3)$ arithmetical operations in total.
Laid out like this it seems almost obvious that matrix multiplication requires $O(n^3)$ arithmetical operations. It turns out that it can in fact be done with fewer.
Strassen’s algorithm relies on the fact that the matrix multiplication formula is valid when $X_{ij}$, $Y_{ij}$ and $Z_{ij}$ are themselves matrices rather than numbers. This allows you to divide the matrix into blocks and work with them instead, as you can see in the example below
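For instance, with a $2 \times 2$ partition into equally sized blocks the same formula applies with blocks in place of scalars:

\[\begin{bmatrix}Z_{11} & Z_{12} \\ Z_{21} & Z_{22}\end{bmatrix} = \begin{bmatrix}X_{11} & X_{12} \\ X_{21} & X_{22}\end{bmatrix}\begin{bmatrix}Y_{11} & Y_{12} \\ Y_{21} & Y_{22}\end{bmatrix} \\ Z_{ij} = \sum_{t=1}^2 X_{it}Y_{tj}\]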
However simply recasting the naïve method to a recursive form is not enough. To gain an improvement the computations need to happen in a different way and you can find out how by reading the paper yourself.
Read the paper and optionally implement the algorithm.
The paper is less than 3 pages in length and quite easy to follow. Focus on the first two pages which cover multiplication (since the method can also be extended to finding inverses and determinants). Don’t worry if you don’t understand all the details. In the subsequent exercises we will derive some of the results in greater detail.
If implementing, I recommend not using any libraries like NumPy or similar libraries in other languages but representing the matrices as arrays and defining your own matrix multiplication and addition functions.
def matmul(X, Y):
(size, _), _ = shapeX, shapeY = list(map(get_shape, [X, Y]))
# Restrict to square
assert (set(shapeX + shapeY)) == {size}
Z = zeros_square(size)
for i in range(size):
for j in range(size):
for t in range(size):
Z[i][j] += (X[i][t] * Y[t][j])
return Z
def matadd(X, Y):
(size, _), _ = shapeX, shapeY = list(map(get_shape, [X, Y]))
# Restrict to square
assert (set(shapeX + shapeY)) == {size}
Z = zeros_square(size)
for i in range(size):
for j in range(size):
Z[i][j] = X[i][j] + Y[i][j]
return Z
def matsub(X, Y):
(size, _), _ = shapeX, shapeY = list(map(get_shape, [X, Y]))
# Restrict to square
assert (set(shapeX + shapeY)) == {size}
Z = zeros_square(size)
for i in range(size):
for j in range(size):
Z[i][j] = X[i][j] - Y[i][j]
return Z
def get_shape(X):
return (len(X), len(X[0]))
def zeros_square(size):
return [[0 for _ in range(size)] for _ in range(size)]
def split(X):
size = len(X)
halfsize = size // 2
result = [[zeros_square(halfsize) for _ in range(2)] for
_ in range(2)]
assert (halfsize *2) == size
for i in range(size):
for j in range(size):
result[i//halfsize][j//halfsize][i%halfsize][j%halfsize] = X[i][j]
return result
def merge(splits):
halfsize = len(splits[0][0])
size = halfsize * 2
result = zeros_square(size)
for i in range(size):
for j in range(size):
result[i][j] = splits[i//halfsize][j//halfsize][i%halfsize][j%halfsize]
return result
def strassen_matmul(X, Y, return_params=False):
(size, _), _ = shapeX, shapeY = list(map(get_shape, [X, Y]))
# Restrict to square
assert (set(shapeX + shapeY)) == {size}
# It might be more optimal in somes cases to round up to the nearest power of 2
# and pad the matrix accordingly but we will keep things simple here
# for demo purposes and express the size in form m*2^k
k = 0
m = size
while (m % 2) == 0:
m = m // 2
k += 1
if k == 0:
return matmul(X, Y)
[[X11, X12], [X21, X22]] = split(X)
[[Y11, Y12], [Y21, Y22]] = split(Y)
I = strassen_matmul(matadd(X11, X22), matadd(Y11, Y22))
II = strassen_matmul(matadd(X21, X22), Y11)
III = strassen_matmul(X11, matsub(Y12, Y22))
IV = strassen_matmul(X22, matsub(Y21, Y11))
V = strassen_matmul(matadd(X11, X12), Y22)
VI = strassen_matmul(matsub(X21, X11), matadd(Y11, Y12))
VII = strassen_matmul(matsub(X12, X22), matadd(Y21, Y22))
Z11 = matadd(matsub(matadd(I, IV), V), VII)
Z21 = matadd(II, IV)
Z12 = matadd(III, V)
Z22 = matadd(matsub(matadd(I, III), II), VI)
Z = merge([[Z11, Z12], [Z21, Z22]])
return (Z, k, m) if return_params else Z
# Example
import numpy as np # using np just for testing
s = 80 # = 2^4 * 5
A = np.random.randint(0, 100, (s, s))
B = np.random.randint(0, 100, (s, s))
Z = A @ B
Al, Bl = map(np.ndarray.tolist, [A, B])
assert (matmul(Al, Bl) == Z).all()
assert (merge(split(A)) == A).all()
assert (matadd(Al, Bl) == (A + B)).all()
assert (matsub(Al, Bl) == (A - B)).all()
ZZ, kk, mm = strassen_matmul(Al, Bl, return_params=True)
assert (ZZ == Z).all()
print(kk, mm) # => 4 5
Let $A(k)$ denote the number of additions and $M(k)$ the number of multiplications needed to multiply two $m2^k \times m2^k$ matrices. We know that $M(0) = m^3$ and $A(0) = m^2(m - 1)$.
Show the following from Fact 1 in the paper
\(M(k) = 7^k m^3\) \(A(k) = (5 + m)m^2 7^k - 6(m2^k)^2\)
There is one matrix multiplication in each of steps $I-VII$, so 7 in total, involving the sub-matrices of size $m 2^{k-1}$, each of which contributes $M(k-1)$ multiplication operations
\[M(k) = 7M(k-1) = 7^2M(k-2)= \cdots = 7^kM(0) = 7^k m^3\]There are 18 matrix additions, 10 in $I-VII$ and 8 in the subsequent steps, involving the sub-matrices of size $m 2^{k-1}$, leading to $18(m 2^{k-1})^2$ addition operations. But there are also the additions involved in the 7 matrix multiplications in $I-VII$, each of which contributes $A(k-1)$ additions
\[A(k) = 18 m^2 4^{k-1} + 7A(k-1) \\= 18 m^2 4^{k-1} + 7\cdot18 m^2 4^{k-2} + 7^2A(k-2) \\= 18 m^2 4^{k-1} + 7\cdot18 m^2 4^{k-2} + 7^2\cdot18 m^2 4^{k-3} + 7^3A(k-3) \\= \cdots \\ = (18m^24^{k-1})\sum_{t=0}^{k-1}(7/4)^t + 7^kA(0) \\ = (18m^24^{k-1})\frac{(7/4)^k - 1}{7/4 - 1} + 7^km^2(m-1) \\ = (18m^24^{k-1})\frac{7^k/4^{k-1} - 4}{3} + 7^km^2(m-1) \\ = 6m^2\left(7^k - 4^k\right) + 7^km^2(m-1) \\ = (5 + m)m^2 7^k - 6(m 2^k)^2\]I think the notation $[x]$ here refers to the floor function or the integral part of the number, which for a positive number is equivalent to floor. In any case, from the way Fact 2 is used in the rest of the paper it would appear that
\[x - 1 \leq [x] \leq x\]Show that
\[(5 + 2m)m^27^k - 6(m2^k)^2 < 2n^3(7/8)^k + 12.03n^2(7/4)^k\]We know that
\[m = \left[n 2^{-k}\right] + 1 \\ \implies \left(n 2^{-k} - 1\right) + 1 \leq m \leq n 2^{-k} + 1 \\ \implies n 2^{-k} \leq m \leq n 2^{-k} + 1\]Hence
\[(5 + 2m)m^27^k - 6(m2^k)^2 \\ \leq \left(5 + 2\left(n 2^{-k} + 1\right)\right)\left(n 2^{-k} + 1\right)^27^k - 6\left(\left(n 2^{-k} + 1\right)2^k\right)^2 \\ \lt \left(5 + 2\left(n 2^{-k} + 1\right)\right)\left(n 2^{-k} + 1\right)^27^k\]since $6\left(\left(n 2^{-k} + 1\right)2^k\right)^2 > 0$.
\[=7^{k} \left(7 + 16 n2^{- k} + 11 n^{2}2^{- 2 k} + 2 n^{3}2^{- 3 k}\right) \\=7^{k}\left(7 + 16 n2^{- k} + 11 \left(n 2^{-k}\right)^2 + 2 \left(n 2^{-k}\right)^3\right)\]and as we have that $k \leq \log_2 n - 4 \implies 16\leq n 2^{-k}$
\[\leq 7^{k}\left((7/256)\left(n2^{- k}\right)^2 + \left(n2^{- k}\right)^2 + 11 \left(n 2^{-k}\right)^2 + 2 \left(n 2^{-k}\right)^3\right) \\= 7^{k}\left((11 + 263 / 256) \left(n 2^{-k}\right)^2 + 2 \left(n 2^{-k}\right)^3\right) \\ < 12.03 n^2(7/4)^k + 2 n^3(7/8)^k\]Show that
$2n^3(7/8)^k + 12.03n^2(7/4)^k \leq \left(2(8/7)^{\log_2 n - k} + 12.03(4/7)^{\log_2 n - k}\right)n^{\log_2 7}$
Since $n = 2^{\log_2 n}$
\[2n^3(7/8)^k + 12.03n^2(7/4)^k \\=\left(2\left(8^{\log_2 n}/8^k\right)(7^k/7^{\log_2 n}) + 12.03\left(4^{\log_2 n}\right/4^k) (7^k/7^{\log_2 n})\right)7^{\log_2 n}\]and since $7^{\log_2 n} = \left(2^{\log_2 7}\right)^{\log_2 n} = \left(2^{\log_2 n}\right)^{\log_2 7} = n^{\log_2 7}$
\[=\left(2(8/7)^{\log_2 n - k} + 12.03(4/7)^{\log_2 n - k}\right)n^{\log_2 7}\]Here we will finally prove that the number of operations needed is strictly less than $O(n^3)$.
Show that
\[\left(2(8/7)^{\log_2 n - k} + 12.03(4/7)^{\log_2 n - k}\right)n^{\log_2 7} \leq 4.7 \cdot n^{\log_2 7}\]Let $t = \log_2 n - k$. We know that
\[\left(\log_2 n - 4\right) - 1 \leq k \leq \log_2 n - 4 \\ \implies \log_2 n - 5 \leq k \leq \log_2 n - 4 \\ \implies 4 \leq \log_2 n - k \leq 5 \\ \implies 4 \leq t \leq 5\]
So the maximum value of the left hand side is given as
\[\left(2(8/7)^{\log_2 n - k} + 12.03(4/7)^{\log_2 n - k}\right)n^{\log_2 7} \\ \leq \max_{4 \leq t \leq 5}\left(2(8/7)^t + 12.03(4/7)^t\right)n^{\log_2 7}\]Let $f(t) = 2(8/7)^t + 12.03(4/7)^t$. We wish to find the maximum of this function. Each of the terms $2(8/7)^t$ and $12.03(4/7)^t$ is of the form $\alpha \cdot b^t$ with $\alpha > 0$, which is convex, so $f(t)$ is convex as a sum of convex functions.
We can see that $f(4) \approx 4.7 \gt f(5) \approx 4.6$. This suggests that $f(4)$ is the maximum value of $f(t)$ in this interval and we can confirm this mathematically.
Any value $t$ in the interval can be expressed as
\[t(\theta) = 4\theta + 5(1 - \theta) \\ 0 \leq \theta \leq 1\]Then we find that
\[\forall t \in [4, 5]: \\ f(t) = f\left(4\theta + 5(1 - \theta)\right) \\ \leq \theta \cdot f(4) + (1 - \theta)\cdot f(5)\]since $f$ is convex
\[\leq \theta \cdot f(4) + (1 - \theta)\cdot f(4) \\ \leq f(4)\]Therefore
\[\left(2(8/7)^{\log_2 n - k} + 12.03(4/7)^{\log_2 n - k}\right)n^{\log_2 7} \\ \leq 4.7\cdot n^{\log_2 7}\]Compared to the original $O(n^3)$, less than $4.7\cdot n^{\log_2 7} \approx O(n^{2.8})$ arithmetical operations does not seem a terribly impressive result. Even so, it leads to non-negligible improvements for sufficiently large matrices.
But what is important is that it raises the question of whether there are even faster algorithms. However the paper does not explain how this method comes about, nor does it give any hints about how one might go about finding more optimal algorithms. AlphaTensor uses deep reinforcement learning to search for such algorithms.
$\def\ttheta{\boldsymbol{\theta}}$
import numpy as np
import matplotlib.pyplot as plt
In the previous post we introduced the short-corridor gridworld with switched actions from Example 13.1 of Reinforcement Learning (Sutton and Barto)
Previously we found an analytical solution for the probability for moving right. Now we will use the REINFORCE algorithm to solve this problem.
This is a situation with a discrete and small action space (move left or right). Consequently a suitable policy parameterization is a softmax over action preferences
\[\pi\left(a \vert s, \ttheta\right) = \frac{e^{h\left(s, a, \ttheta\right)}}{\sum_b e^{h\left(s, b, \ttheta\right)}} \\\\\\ h\left(s, a, \ttheta\right) = \theta^T\mathbf{x}\left(s,a\right)\]where $\mathbf{x}\left(s,a\right)$ is a feature vector.
def policy(t, feature):
logits = t @ feature
logits_exp = np.exp(logits - np.max(logits, -1, keepdims=True))
# [T, [pR, pL]]
pi = logits_exp / logits_exp.sum(-1, keepdims=True)
return pi
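As a quick check, with the initial parameters used later in this post, $\ttheta = [\ln(1), \ln(19)]$, and one-hot features the policy assigns probability $1/20$ to moving right:

```python
import numpy as np

theta = np.log([[1.0, 19.0]])
X = np.eye(2)  # columns are the feature vectors for R and L

# Same softmax computation as in policy above
logits = theta @ X
pi = np.exp(logits) / np.exp(logits).sum(-1, keepdims=True)
print(pi)  # [[0.05 0.95]]
```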
We are going to implement a vectorised version of REINFORCE to enable running several trials at once, whose results will then be averaged. The states will be numbered 1 to 4 from left to right
1 2 3 4
+-------+-------+-------+-------+
| S | <---. | | |
| <-o-> | .-o-' | <-o-> | G |
| | '---> | | |
+-------+-------+-------+-------+
The feature vectors $\mathbf{x}\left(s,a\right)$ are one hot vectors and they don’t depend on the state
\[x(s, R) = [1, 0]^T \\\\\\ x(s, L) = [0, 1]^T\]x_sR = np.array([[1], [0]])
x_sL = np.array([[0], [1]])
# [2, 2]
X = np.concatenate([x_sR, x_sL], axis=-1)
For $\alpha$ we will try three values: $2^{-12}, 2^{-13}, 2^{-14}$
alpha_powers = [-12, -13, -14]
num_alpha = len(alpha_powers)
episodes = 1000
trials = 100
max_steps = 1000
result = np.zeros([num_alpha, trials, episodes, 2])
G0 = np.zeros([num_alpha, trials, episodes])
To handle having different random seeds for each trial, we can create a matrix of random values in advance from each seed and then use these to sample from the policy: if the value is less than $\pi(R \vert s)$, $R$ is chosen, otherwise $L$ is chosen.
seeds = [np.random.randint(low=0, high=2**32) for run in range(trials)]
random_vals = np.zeros([num_alpha, trials, episodes, max_steps])
for t, s in enumerate(seeds):
random_vals[:, t] = np.random.RandomState(s).uniform(size=random_vals[:, t].shape)
Now create a function that generates an episode for all the trials at once
def generate_episode(theta, X, random_values):
"""
theta: [trials, num_actions]
X: [num_actions, num_actions]
random_values: [trials, max_steps]
"""
trials, max_steps = random_values.shape
# initialise arrays
states = [np.ones(trials)]
actions = []
rewards = []
# keep track of whether the terminal state has been reached for each trial
inprog = [np.ones(trials).astype('bool')]
step = 0
while inprog[-1].any() and step < max_steps:
s = states[-1]
# Find probability of going right according to policy
pi = policy(theta, X)
pR = pi[:, 0]
# 1 for right, -1 for left
a = np.where(random_values[..., step] < pR, 1, -1)
actions.append(a)
# To flip action for state 2
coef = np.where(s==2, -1, 1)
# If in progress, update by moving left or right,
# flipping for state 2 and disallowing moving
# left from state 1
# If not in progress, remain in state 4 which is
# the terminal state
s_next = np.where(inprog[-1], np.maximum(coef * a + s, 1), 4)
states.append(s_next)
# Reward is -1 until terminal state when it is 0
rewards.append(-np.ones(trials) * inprog[-1].astype('float'))
inprog.append((s_next!=4))
step += 1
return states, actions, rewards, inprog
Then we can implement the second part. Following the “official” (and non-vectorised) implementation on the RL book website we have $\gamma = 1$.
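Concretely, at each step $t$ of an episode the parameters are updated according to the standard REINFORCE rule (with $\gamma = 1$), using the gradient of the log softmax policy from Exercise 13.3 of the book:

\[\ttheta \leftarrow \ttheta + \alpha G_t \nabla_\ttheta \ln \pi\left(A_t \vert S_t, \ttheta\right) \\\\\\ \nabla_\ttheta \ln \pi\left(a \vert s, \ttheta\right) = \mathbf{x}\left(s, a\right) - \sum_b \pi\left(b \vert s, \ttheta\right)\mathbf{x}\left(s, b\right)\]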
def update_parameters(theta, X, alpha, states, actions, rewards, inprog):
# Calculate returns in advance
returns = np.cumsum(rewards[::-1], axis=0)[::-1]
for prog_t, st, gt, at in zip(inprog, states, returns, actions):
# R=1 -> 0, L=-1 -> 1
act_idx = (at < 0).astype('int')
# [T, 2]
x_sa = X.T[act_idx]
# [T, 2]
pi = policy(theta, X)
# See Exercise 13.3
# [T, 2] - [T, 2] -> [T, 2]
ln_pi_grad = x_sa - pi @ X.T
update = alpha * gt[:, None] * ln_pi_grad
theta[prog_t] += update[prog_t]
return theta.squeeze(), returns[0]
Now we can put these together to run the example. We initialise $\ttheta = [\ln(1), \ln(19)]$ also following the official implementation.
for idx, power in enumerate(alpha_powers):
theta = np.tile(np.log([1, 19])[None], [trials, 1])
alpha = 2**power
for episode in range(episodes):
states, actions, rewards, inprog = generate_episode(theta, X, random_vals[idx, :, episode])
result[idx, :, episode], G0[idx, :, episode] = update_parameters(theta, X, alpha, states, actions, rewards, inprog)
if episode % 100 == 0:
print(f'alpha=2^{power}, episode={episode + 1}', policy(theta, X).mean(0).squeeze())
alpha=2^-12, episode=1 [0.05137732 0.94862268]
alpha=2^-12, episode=101 [0.2730593 0.7269407]
alpha=2^-12, episode=201 [0.37056455 0.62943545]
alpha=2^-12, episode=301 [0.43080733 0.56919267]
alpha=2^-12, episode=401 [0.47190531 0.52809469]
alpha=2^-12, episode=501 [0.49833314 0.50166686]
alpha=2^-12, episode=601 [0.51648553 0.48351447]
alpha=2^-12, episode=701 [0.54041319 0.45958681]
alpha=2^-12, episode=801 [0.5450163 0.4549837]
alpha=2^-12, episode=901 [0.55932724 0.44067276]
alpha=2^-13, episode=1 [0.05070895 0.94929105]
alpha=2^-13, episode=101 [0.15114146 0.84885854]
alpha=2^-13, episode=201 [0.22269873 0.77730127]
alpha=2^-13, episode=301 [0.2788509 0.7211491]
alpha=2^-13, episode=401 [0.32195895 0.67804105]
alpha=2^-13, episode=501 [0.35928062 0.64071938]
alpha=2^-13, episode=601 [0.39119684 0.60880316]
alpha=2^-13, episode=701 [0.41764465 0.58235535]
alpha=2^-13, episode=801 [0.44153935 0.55846065]
alpha=2^-13, episode=901 [0.45895968 0.54104032]
alpha=2^-14, episode=1 [0.05057726 0.94942274]
alpha=2^-14, episode=101 [0.09540479 0.90459521]
alpha=2^-14, episode=201 [0.13557795 0.86442205]
alpha=2^-14, episode=301 [0.17262241 0.82737759]
alpha=2^-14, episode=401 [0.20356134 0.79643866]
alpha=2^-14, episode=501 [0.23276498 0.76723502]
alpha=2^-14, episode=601 [0.25941758 0.74058242]
alpha=2^-14, episode=701 [0.28391618 0.71608382]
alpha=2^-14, episode=801 [0.30547963 0.69452037]
alpha=2^-14, episode=901 [0.32668467 0.67331533]
Finally we plot the averaged values, which can be compared to the value under the optimal choice of the probability of moving right that we derived earlier.
v_best = -11.65685424949238 # from before
plt.figure(figsize=(12,8))
for i in range(3):
plt.plot(np.arange(episodes) + 1, G0[i].mean(0), label=f'alpha=2^-{12 + i}')
plt.title('REINFORCE for Example 13.1 averaged across 100 trials for different alpha')
plt.ylabel('$G_0$ - total return in the episode')
plt.xlabel('Number of episodes')
plt.xlim([1, episodes])
plt.ylim([-95, -10])
plt.hlines(xmin=1, xmax=episodes, y=v_best, label='$v*(s_0)$', linestyle='--');
plt.legend();
In an earlier post we learned about diffusion models, a new and powerful type of generative model that most notably is the basis for OpenAI’s DALL-E 2. However an important disadvantage of these models is that a large number of forward passes - typically 1000 or more - are needed to generate samples, making inference slow and expensive. In this post we will look at a couple of different methods that have been developed to speed up inference.
Methods like DDIM allow you to generate 20 or more images of comparable quality in the time it would take to generate a single one using the original sampling technique
The reason inference is slow is because the sampling process reverses the forward process which has $T$ timesteps. A key insight from the paper “Denoising Diffusion Implicit Models” (DDIM) is that the objective $L_\text{simple}$ only depends on $\qrcond{\xt}{\xz}$ and not on the joint distribution $\qrcond{\xoneT}{\xz}$. This means that there could be other inference processes - including those that need fewer steps - which could have the same marginals.
One approach is to define a family of distributions indexed by $\sigma$
\[q_\sigma\left(\xoneT\vert\xz\right) = q_\sigma\left(\xtt{T} \vert\xz\right) \prod_{t=2}^Tq_\sigma\left(\xtmone \vert\xt, \xz\right)\]where $q_\sigma\left(\xtmone \vert\xt, \xz\right)$ is Gaussian
From this a new non-Markovian forward process is derived using Bayes’ rule
\[\qrsigcond{\xt}{\xtmone, \xz} = \frac{ \qrsigcond{\xtmone}{\xt, \xz}\qrsigcond{\xt}{\xz}}{ \qrsigcond{\xtmone}{\xz} }\]It is not Markovian since $\xt$ depends on $\xz$ as well as $\xtmone$. However the parameters of $q_\sigma\left(\xtmone \vert\xt, \xz\right)$ are chosen such that $q_\sigma\left(\xt \vert \xz\right)$ is the same as in the original Markovian forward process.
Now we can define a new reverse distribution based on $q_\sigma\left(\xtmone \vert \xt, \xz\left(\xt\right)\right)$. As before we can estimate $\xz$ given $\xt$. Define $f_\theta^{(t)}\left(\xt\right) := \frac{1}{\sqrt{\abt}}\left(\xt - \sqrt{1 - \abt}\etta\left(\xt, t\right)\right)$. Then the reverse process is defined as
\[p_\theta^{(t)}\left(\xtmone \vert \xt\right) = \left\{ \begin{array}{ll} \norm{f_\theta^{(1)}\left(\xtt{1}\right), \sigma_1^2\II} & \text{if }{t=1} \\ q_\sigma\left(\xtmone \vert \xt, f_\theta^{(t)}\left(\xt\right)\right) & \text{otherwise} \end{array} \right.\]To sample from the reverse process you start off with a noise sample $\xtt{T} \sim \norm{\mathbf{0}, \II}$, successively sampling
\[\xtmone = \sqrt{\frac{\abtmone}{\abt}}\left(\xt - \sqrt{1 - \abt}\etta\left(\xt, t\right)\right) + \sqrt{1 - \abtmone - \sigma_t^2}\etta\left(\xt, t\right) + \sigma_t\epsilon_t\]The hyperparameter $\sigma_t$ controls the stochasticity of the process: setting $\sigma_t = 0$ for all $t$ makes generation deterministic given $\xtt{T}$ (the DDIM case proper), whilst larger values inject fresh noise $\epsilon_t$ at each step, with a particular positive choice recovering the original DDPM sampler.
In fact we can use any forward process provided that the marginals match. This allows the use of a forward process defined only on a subset of the latent variables
\[\xtt{\tau_1} \ldots \xtt{\tau_S}\]where $\tau_1 \ldots \tau_S$ is an increasing sub-sequence of $1 \ldots T$ of length $S$, where $S$ can be much smaller than $T$. The corresponding reverse process goes backwards through the timesteps
\[\xtt{ \tau_{t - 1}} = \sqrt{\frac{\abtautmone}{\abtaut}}\left(\xtt{\tau_t} - \sqrt{1 - \abtaut}\etta\left(\xtt{\tau_t}, \tau_t\right)\right) + \sqrt{1 - \abtautmone - \sigma_{\tau_t}^2}\etta\left(\xtt{\tau_t}, \tau_t\right) + \sigma_{\tau_t}\epsilon_t \\\\\\ t = 1 \ldots S\]Another key advantage of this approach is that you can simply use an existing diffusion model trained on the DDPM objective and don’t need to do further training. The reasoning in the DDIM paper is essentially that the training objective only depends on the marginals $\qrcond{\xt}{\xz}$, which are unchanged.
However in the DDPM paper the parameters are shared across timesteps which means that $\sum_{t=1}^T\Ef{\xz, \boldeps}{L_t}$ is only an approximation to the loss. Consequently the reasoning above does not strictly hold with regard to the actual implementation. Nevertheless, as with so many things in deep learning, using a pre-trained model works in practice.
In the paper they experiment with a linear and quadratic sequence for the timesteps. Here we will focus on the linear sequence. In the paper this is defined as $\lfloor ci \rfloor$. We will use the approach in the improved_diffusion
codebase
def space_timesteps(
num_sample_steps = 50,
total_steps = 1000
):
"""
Adapted from `improved_diffusion.respace.space_timesteps`
"""
for i in range(1, total_steps):
if len(range(0, total_steps, i)) == num_sample_steps:
return list(range(0, total_steps, i))[::-1]
raise ValueError(f'cannot create exactly {num_sample_steps} steps with an integer stride')
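For example, asking for 50 steps out of 1000 yields a stride of 20 (the function is repeated here only so the snippet is self-contained):

```python
def space_timesteps(num_sample_steps=50, total_steps=1000):
    # As defined above
    for i in range(1, total_steps):
        if len(range(0, total_steps, i)) == num_sample_steps:
            return list(range(0, total_steps, i))[::-1]
    raise ValueError(f'cannot create exactly {num_sample_steps} steps with an integer stride')

steps = space_timesteps(50, 1000)
print(steps[:3], steps[-1], len(steps))  # [980, 960, 940] 0 50
```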
To sample you would loop through pairs $(\tau_t, \tau_{t-1})$
timesteps = space_timesteps(S, T)
tau = timesteps[:-1]
tau_prev = timesteps[1:]
for t, tau_t, tau_tm1 in zip(range(len(tau), 0, -1), tau, tau_prev):
## CODE FOR SAMPLING GOES HERE
When the model has been trained, we can sample new data points by first sampling random noise $\xtt{\tau_S}$, then successively sampling latents at points $\tau_t$ along the sequence, $\xtt{\tau_t}$. As you can see in the figure below, even if $\tau_S < T$, provided it is large enough $\xtt{\tau_S}$ will be very noisy and reasonably similar to $\xtt{T} \sim \norm{\mathbf{0}, \II}$, so random noise can be used as the initial value.
To start with we sample $\xtt{\tau_S} \sim \norm{\mathbf{0}, \II}$. Now assume we have $\xtt{\tau_t}$ at step $t$
def p_sample_step(self, t, tau_t, tau_tm1, xtau_t):
Predict $\etta\left(\xtt{\tau_t}, \tau_t\right)$
batch_shape = tf.shape(xtau_t)
eps_theta = self.model(
self.get_input(xtau_t, tau_t), training=False
)
Transform this to $\xz = \frac{1}{\sqrt{\abtaut}}\left(\xtt{\tau_t} - \sqrt{1 - \abtaut}\etta\left(\xtt{\tau_t}, \tau_t\right)\right)$. Clip to the range $[-1, 1]$ to get the prediction for $\xz$ that will be returned.
predicted_x0 = (
xtau_t - self.select_timestep(self.sqrt_1_m_alpha_bar, tau_t, xtau_t) * eps_theta
) / self.select_timestep(self.sqrt_alpha_bar, tau_t, xtau_t)
x0 = tf.clip_by_value(predicted_x0, -1, 1)
$\def\dirxt{\sqrt{1 - \abtautmone - \sigma_{\tau_t}^2}\cdot\etta\left(\xtt{\tau_t}, \tau_t\right)}$
Calculate the direction pointing to $\xtt{\tau_t}$
\[\dirxt\]
sigma_tau_t = self.select_timestep(self.sigma, tau_t, xtau_t)
xtau_t_dir = tf.math.sqrt(
1 - self.select_timestep(
self.alpha_bar_prev, tau_tm1 + 1, xtau_t
) - sigma_tau_t ** 2
) * eps_theta
$\def\predxz{\underbrace{\frac{\xtt{\tau_t} - \sqrt{1 - \abtaut}\etta\left(\xtt{\tau_t}, \tau_t\right)}{\sqrt{\abtaut}}}_\text{predicted $\xz$}}$
For $t > 1$, sample $\boldeps \sim \norm{\mathbf{0},\II}$ and return
\[\xtt{\tau_{t-1}} = \sqrt{\abtautmone}\predxz + \underbrace{\dirxt}_\text{direction pointing to $\xtt{\tau_t}$} +\sigma_{\tau_t}\epsilon_t\]which is a sample from $\qrsigcond{\xtt{\tau_{t-1}}}{\xtt{\tau_t}, \xz}$. At the final step, $t=1$, just return $\xz$.
z = tf.cond(
tf.greater(t, 1),
lambda: tf.random.normal(shape=batch_shape),
lambda: tf.zeros(batch_shape)
)
xtau_tm1 = self.select_timestep(
self.sqrt_alpha_bar_prev, tau_tm1 + 1, xtau_t
) * predicted_x0 + xtau_t_dir + sigma_tau_t * z
return x0, xtau_tm1
First let us look at some DDIM samples for CIFAR-10. The figure below shows samples with different initial values; visually there isn’t much of a difference in quality between 50 and 1000 steps, and even using only 10 steps still yields tolerable images.
The images below were generated using the same initial random values. Here we see that just 50 steps is enough to get results that look not too different from 1000 steps. Using just 10 steps leads to much more blurry results.
This trend is more evident when we look at a single image
where the first example is blurry after ten steps whilst the second looks like an unfinished painting, with most of the horse’s body absent except for some shadows and a slightly pale region where its front hooves should be.
What about quantitative performance? I encourage you to refer to the paper for detailed information, figures and tables. But in summary, DDIM typically outperforms DDPM when only a small number of timesteps are used, 100 or fewer. However with 1000 timesteps, the same number as used in training, DDPM does best. Datasets considered include CIFAR-10, CelebA and LSUN. For example on CIFAR-10, the Fréchet Inception Distance (FID) for DDIM and DDPM is 4.16 and 9.99 respectively with 100 sampling steps. However at 1000 steps the score for DDIM only falls a little to 4.04 whilst for DDPM it plummets to 3.17.
In the paper “Improved Denoising Diffusion Probabilistic Models” (Improved Diffusion) they show that you can actually get good results simply by sampling at points along a strided sequence $\tau$ from $1 \ldots T$ inclusive using hyperparameters indexed by $\tau_t$ and $\tau_{t-1}$ instead of $t$ and $t-1$ at step $t$. In the improved diffusion paper they introduce the following extensions to the original model (among others)
Accordingly when sampling at strided timesteps and using learned variance, $\beta_{\tau_t}$ and $\tilde{\beta}_{\tau_t}$ are derived as
\begin{align} \btaut = 1 - \frac{\abtaut}{\abtautmone}, && \tbtaut = \frac{1 - \abtaut}{1 - \abtautmone}\btaut \end{align}
They use a simple strided timestep sequence as follows
To reduce the number of sampling steps from $T$ to $K$, we use $K$ evenly spaced real numbers between 1 and $T$ (inclusive), and then round each resulting number to the nearest integer
which can be implemented as follows
timesteps = np.linspace(1, T, K).round().astype('int')
timesteps = np.unique(timesteps)[::-1]
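For example, with the illustrative values $T = 1000$ and $K = 10$:

```python
import numpy as np

T, K = 1000, 10
timesteps = np.linspace(1, T, K).round().astype('int')
timesteps = np.unique(timesteps)[::-1]
print(timesteps.tolist())  # [1000, 889, 778, 667, 556, 445, 334, 223, 112, 1]
```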
On the models developed in the Improved Diffusion paper, this method actually outperforms DDIM except when fewer than 50 steps are used. At around 100 steps strided sampling achieves an FID close to the optimal value. The paper claims that models with fixed variance
suffer much more in sample quality when using a reduced number of sampling steps
for both $\tbt$ and $\bt$. However the plots in the paper show that, except for CIFAR-10, the difference in FID after around 100 steps is not large when the fixed variance is $\tbt$ rather than $\bt$.