Text-to-image conversion has always fascinated me, and the recent AttnGAN paper caught my attention. In this post I try to provide an intuition for their work, and hopefully get you curious enough to dig further :-).
Before we get to the actual model, some prerequisites:
To avoid reinventing the wheel (and to promote my own work, of course), take a look at my previous post, where I provide a short intro to Attention in Deep Learning.
Simply put, a GAN is a combination of two networks: A Generator (the one who produces interesting data from noise), and a Discriminator (the one who detects fake data fabricated by the Generator). The duo is trained iteratively:
- The Discriminator is taught to distinguish real data (images, text, whatever) from data created by the Generator. At this step, the Generator is not being trained; only the Discriminator’s ‘detective’ skills are improved.
- The Generator is trained to produce data that can fool the (now-improved) Discriminator. The random input ensures that the Generator keeps coming up with novel data every time, essentially acting as inspiration.
The key insight is in the dual-objective: As (and because) the discriminator becomes a better detective, the generator becomes a better faking-artist. After a sufficient number of epochs, the Generator can create surprisingly realistic images!
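To make the two steps concrete, here is a minimal sketch of the two objectives, assuming the Discriminator outputs a probability that a sample is real and that both networks are scored with binary cross-entropy. The function names and the example probabilities are my own illustration, not code from the paper.

```python
import math

def bce(prob, target):
    """Binary cross-entropy for one predicted probability and a 0/1 target."""
    eps = 1e-12  # keeps log() away from zero
    return -(target * math.log(prob + eps) + (1 - target) * math.log(1 - prob + eps))

def discriminator_loss(d_real, d_fake):
    """Step 1: D should say 1 for real samples and 0 for generated ones."""
    real_term = sum(bce(p, 1) for p in d_real) / len(d_real)
    fake_term = sum(bce(p, 0) for p in d_fake) / len(d_fake)
    return real_term + fake_term

def generator_loss(d_fake):
    """Step 2: G is rewarded when D labels its samples as real."""
    return sum(bce(p, 1) for p in d_fake) / len(d_fake)

# Hypothetical Discriminator outputs (probability of "real") for a tiny batch.
sharp_detective = discriminator_loss(d_real=[0.9, 0.8], d_fake=[0.1, 0.2])
fooled_detective = discriminator_loss(d_real=[0.9, 0.8], d_fake=[0.7, 0.6])
print(sharp_detective < fooled_detective)

# The Generator's loss moves the other way: fooling D lowers it.
print(generator_loss([0.7, 0.6]) < generator_loss([0.1, 0.2]))
```

The same fake samples that raise the Discriminator's loss lower the Generator's, which is the tug-of-war described above.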
Now, coming to ‘AttnGAN: Fine-Grained Text to Image Generation with Attentional Generative Adversarial Networks’. The contributions of the paper can be divided into two parts:
Part 1: Multi-stage Image Refinement (the AttnGAN)
The Attentional Generative Adversarial Network (or AttnGAN) begins with a crude, low-res image, and then improves it over multiple steps to come up with a final image.
Let’s start with the first stage:
Like most other Text-to-Image converters, AttnGAN starts off by generating an image from (random noise + a summation of the caption’s token-embeddings):
h(0) = F(0)(z, E)
Here, z represents the noise input, and E represents the sum of the individual word-vectors. h(0) denotes the ‘hidden context’: essentially, AttnGAN’s concept of what the image should look like. Based on h(0), we generate x(0), the first image, using a GAN:
x(0) = G(0)(h(0))
(Corresponding to the Generator G(0) we also have the Discriminator D(0), which we will talk about later.)
An example of x(0) straight from the paper:
Caption: “This bird has a green crown black primaries and a white belly”
One of the issues with generating an image from a combined ‘sentence’ vector (E above), is that we lose a lot of the fine-grained details hidden in individual words.
For instance, look at the example above: When you combine (green+crown+white+belly) into a ‘bag-of-words’, you are much, much less likely to understand the actual colors of the crown & belly — hence the hazy coloring in the generated image.
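A toy sketch of that information loss, assuming one-hot vectors stand in for real word embeddings (the vocabulary and helper names are made up for illustration): two captions with the colors swapped produce exactly the same summed vector.

```python
# Toy one-hot "embeddings"; real word vectors would be learned and dense.
VOCAB = ["green", "white", "crown", "belly"]

def one_hot(word):
    return [1 if word == v else 0 for v in VOCAB]

def sentence_vector(words):
    """Element-wise sum of the word vectors, playing the role of E."""
    return [sum(column) for column in zip(*(one_hot(w) for w in words))]

caption_a = ["green", "crown", "white", "belly"]  # green crown, white belly
caption_b = ["white", "crown", "green", "belly"]  # white crown, green belly

# Both captions collapse to the same vector, so the first stage
# cannot tell which color belongs to which body part.
print(sentence_vector(caption_a) == sentence_vector(caption_b))
```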
To remedy this, AttnGAN uses a combination of Attention & GAN at every stage, to iteratively add details to the image:
h(i) = F(i)(h(i-1), Attn([e], h(i-1)))
x(i) = G(i)(h(i))
h(1), h(2), … follow the template above.
Compare these to the initial equations:
- z gets replaced by the previous context h(i-1).
- [e] denotes the set of all word-embeddings in the sentence. Using Attention based on h(i-1), we compute a weighted average of [e] ( Attn([e], h(i-1)) ) to highlight words that need more detail.
- Based on this weighted vector, F(i) alters h(i-1) to yield h(i).
- As usual, a GAN is then used to produce x(i) from h(i).
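The attention step in the list above can be sketched as follows. This is a simplification under my own assumptions: h(i-1) is treated as a single query vector and scored against each word by a dot product, whereas, as far as I understand the paper, the hidden context holds one vector per image sub-region and attention is computed for each of them. The 2-D vectors are invented for the example.

```python
import math

def softmax(scores):
    """Turn raw scores into weights that are positive and sum to 1."""
    top = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def attend(word_vecs, query):
    """Attn([e], h): weighted average of the word vectors, weighted by relevance to the query."""
    weights = softmax([dot(w, query) for w in word_vecs])
    dim = len(word_vecs[0])
    context = [sum(wt * w[d] for wt, w in zip(weights, word_vecs)) for d in range(dim)]
    return context, weights

# Hypothetical embeddings and previous hidden context.
words = ["bird", "green", "belly"]
word_vecs = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
h_prev = [0.0, 2.0]

context, weights = attend(word_vecs, h_prev)
ranked = sorted(zip(words, weights), key=lambda pair: pair[1], reverse=True)
print([w for w, _ in ranked])  # words ordered from most to least attended
```

F(i) would then take h_prev together with context and produce the next hidden context; that network is left out here.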
Continuing with the previous example:
Top attended words for h(1): bird, this, has, belly, white
Top attended words for h(2): black, green, white, this, bird
Consider the words for h(2). You can literally see x(2) being a more colorful version of x(1).
Of course, the results aren’t always so pretty, but it’s a step in the right direction for optimizing the correct objectives. Which brings us to…
Part 2: Multi-modal loss
At this point, it will be good for you to go through the high-level diagram of the system, given in the paper:
Let’s consider the parts we haven’t touched upon yet.
Looking at the equations for h & x, it is natural to wonder why we need the x’s at all, except at the last step. For example, x(0) does not appear in the equations for h(1) & x(1)!
The reason is training. In the learning phase, each D is trained with real image-caption examples (from a dataset like COCO), with the image scaled down to that stage’s resolution. This makes the G’s better at generating x’s from the h’s. Through back-propagation, the same error signal also flows into the F functions, making them better at generating the hidden contexts, which ensures that each stage adds something meaningful to the image.
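As a rough sketch of where the scaled-down real images could come from, the snippet below builds one copy of an image per stage by repeated 2x average pooling. The pooling method, the helper names, and the tiny 4x4 "image" are my assumptions for illustration; the paper's actual preprocessing may differ.

```python
def downscale_2x(img):
    """Halve height and width by averaging each 2x2 block of pixels."""
    height, width = len(img), len(img[0])
    return [
        [
            (img[r][c] + img[r][c + 1] + img[r + 1][c] + img[r + 1][c + 1]) / 4
            for c in range(0, width - 1, 2)
        ]
        for r in range(0, height - 1, 2)
    ]

def image_pyramid(img, stages):
    """Return one real image per stage, smallest first, full resolution last."""
    pyramid = [img]
    for _ in range(stages - 1):
        pyramid.append(downscale_2x(pyramid[-1]))
    return pyramid[::-1]

# A 4x4 grayscale "image" with pixel values 0..15.
real = [[r * 4 + c for c in range(4)] for r in range(4)]
pyramid = image_pyramid(real, stages=3)
print([(len(p), len(p[0])) for p in pyramid])  # sizes seen by D(0), D(1), D(2)
```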
The Deep Attentional Multimodal Similarity Model (DAMSM)
After the concept of multi-stage image refinement, I think this is the second key feature of this framework.
While the individual discriminators do make the system better, we do not yet have an objective that checks if every single word in the caption is appropriately represented in the actual image (the discriminators are trained on the overall caption E & scaled-down image pairs).
To encode this task effectively, we first train an ‘expert’ of sorts — the DAMSM. DAMSM takes as input an image and the set [e], and provides feedback on how well the two ‘fit’ together. It does this as follows:
- Using a standard Convolutional Neural Network, the image is converted into a set of feature maps. Each feature map essentially signifies some concept/sub-region in the image.
- The dimensionality of the feature maps is made equal to that of the word embeddings, so that they can be treated as equivalent entities.
- Based on each token in the caption, Attention is applied over the feature maps, to compute a weighted average of them. This attention-vector essentially represents the image’s abstraction of the token.
- Finally, DAMSM is trained to minimize the difference between the above attention-vector (visual portrayal of the word) & the word-embedding (textual meaning of the word). You are basically trying to make the ‘green’ part of the image as ‘green’ as possible.
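The matching steps above can be sketched as a scoring function. This is a toy version under my own assumptions: region features and word embeddings already share a dimension, relevance is a plain dot product, the word-to-region fit is measured with cosine similarity, and the per-word scores are simply averaged (the paper aggregates them differently). All vectors are invented.

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def cosine(a, b):
    """Cosine similarity; the small constant guards against zero vectors."""
    return dot(a, b) / (math.sqrt(dot(a, a)) * math.sqrt(dot(b, b)) + 1e-12)

def softmax(scores):
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def word_region_score(word, regions):
    """Attend over image regions for one word, then compare word and attended vector."""
    weights = softmax([dot(word, r) for r in regions])
    attended = [sum(wt * r[d] for wt, r in zip(weights, regions)) for d in range(len(word))]
    return cosine(word, attended)

def match_score(words, regions):
    """Average fit between every word in the caption and the image."""
    return sum(word_region_score(w, regions) for w in words) / len(words)

# Hypothetical 3-D embeddings for two caption words and two images.
caption = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
matching_image = [[3.0, 0.0, 0.0], [0.0, 3.0, 0.0]]    # regions line up with the words
unrelated_image = [[0.0, 0.0, 3.0], [0.0, 0.0, 3.0]]   # regions show something else

print(match_score(caption, matching_image) > match_score(caption, unrelated_image))
```

Training would then push this score up for real image-caption pairs and down for mismatched ones.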
The reason DAMSM is called ‘multimodal’, is because it defines an objective that combines two modes of understanding — visual & textual.
Once DAMSM has been sufficiently trained on a dataset, it can then be used in conjunction with the step-wise discriminators, to provide a rich target for AttnGAN to optimize.
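As I read the paper, the combined target for the generators can be written roughly as below, where m is the number of stages, each L_{G_i} is the adversarial loss of generator G(i) against its discriminator D(i), and λ is a hyperparameter that weighs the DAMSM term. Treat this as a summary sketch and check the paper for the exact definitions of each term.

```latex
\mathcal{L} = \mathcal{L}_G + \lambda\,\mathcal{L}_{\mathrm{DAMSM}},
\qquad
\mathcal{L}_G = \sum_{i=0}^{m-1} \mathcal{L}_{G_i}
```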
There are quite a few concepts from the paper that I have skipped in this post, such as Conditioning Augmentation, using BiRNNs for text encoding, etc. As always, I would suggest you read the original paper if you want to get into the finer details.
Thank you for reading 🙂
Gurupriyan is a Software Engineer and a technology enthusiast who has been working in the field for the last 6 years. He is currently focusing on mobile app development and IoT.