Building a GAN From Scratch With PyTorch | Theory + Implementation

AssemblyAI 10.03.2022 94,724 views 1,738 likes

Video description
Learn how to create a GAN (Generative Adversarial Network) from scratch with PyTorch and PyTorch Lightning.

Colab with starter code: https://colab.research.google.com/drive/1246XGJbtbqQ4pb86tGk64hZAF3_bV6hM?usp=sharing

Get your Free Token for AssemblyAI Speech-To-Text API: https://www.assemblyai.com/?utm_source=youtube&utm_medium=referral&utm_campaign=yt_pat_20

Resources:
https://www.assemblyai.com/blog/pytorch-lightning-for-dummies/
https://pytorch-lightning.readthedocs.io/en/stable/notebooks/lightning_examples/basic-gan.html
https://developers.google.com/machine-learning/gan/loss
https://neptune.ai/blog/gan-loss-functions

Connect:
Website: https://www.assemblyai.com
Twitter: https://twitter.com/AssemblyAI
Discord: https://discord.gg/Cd8MyVJAXd
Subscribe: https://www.youtube.com/c/AssemblyAI?sub_confirmation=1
We're hiring! Check our open roles: https://www.assemblyai.com/careers

Timestamps:
00:00 Intro
00:31 GAN Theory
03:29 PyTorch Lightning Setup
06:42 Generator + Discriminator
10:21 GAN Training
28:50 Testing

#MachineLearning #DeepLearning

Contents (6 segments)

Intro

Hi everyone, I'm Patrick from the AssemblyAI team, and today we learn about generative adversarial networks, or GANs for short. You might have seen the popular example where GANs generate fake images of humans that look incredibly real. GANs are indeed really powerful and are one of the most fascinating ideas in deep learning in recent years. So today we have a quick look at the theory behind GANs, and then we code one from scratch using PyTorch. So let's get started. All right, so let's look at the

GAN Theory

theory first, and I promise that this won't be too difficult, because the idea is actually brilliant: it's simple but super powerful. GANs learn to generate new data with the same statistics as the training set, and they consist of two networks playing an adversarial game against each other. That's where the name generative adversarial networks comes from. So the goal is to generate fake data that is as close as possible to the training data.

So how does this work exactly? The two networks are called the generator and the discriminator. The generator produces fake data and tries to trick the discriminator. The discriminator inspects the data and determines whether it's real or fake, so it is like a detective. They play against each other, and this is basically the training. First they are initialized randomly, and then they are trained simultaneously. This means we have to minimize two losses, so we also use two optimizers, and we use the binary cross-entropy loss. I'm not going into detail here about the loss formula, but I will link another resource below the video if you want to learn more about this. And yeah, this is basically the whole concept.

Now, before we jump to the code, let's look at an example. Later we use the MNIST dataset, so the generator tries to generate MNIST images, these digits from zero to nine, and tries to trick the discriminator. They play against each other and both sides get better. In the beginning they don't know anything, because they are randomly initialized, so the generator just produces noise, random data that might look like this. The discriminator looks at this, can also look at real data and compare, and it can easily say: yeah, this is fake. But then the training continues, and the generator comes up with new data that might look something like this. Again the discriminator looks at it and can still say: yes, still fake. But at some point the generator gets better, and then the discriminator might be tricked and say: yeah, this is actually real data. Obviously in this example the data is still not perfect, but of course the discriminator is still not perfect either, so it also has to improve. This continues, both sides get better, and eventually we should get generated data that is not easily distinguishable from the original data. And yeah, this is how it works. So now let's jump to the code and see how we can implement this. All right, so here
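For reference, the loss that this segment skips over can be written out. This is the standard textbook form of binary cross-entropy and of the GAN objective, not something stated at this point in the video. The symbols $D$ (discriminator), $G$ (generator), $x$ (real image) and $z$ (noise vector) match the ones used later in the training segment.

```latex
\mathrm{BCE}(\hat{y}, y) = -\bigl[\, y \log \hat{y} + (1 - y)\,\log(1 - \hat{y}) \,\bigr]

\text{Discriminator: maximize } \log D(x) + \log\bigl(1 - D(G(z))\bigr)

\text{Generator: maximize } \log D(G(z))
```

With label $y = 1$ the BCE reduces to $-\log \hat{y}$, and with $y = 0$ it reduces to $-\log(1 - \hat{y})$, so minimizing BCE with the appropriate labels is the same as maximizing the terms above.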

PyTorch Lightning Setup

I'm in a Google Colab, and I already prepared some code for the start, but later we will write the rest of the code together. So I recommend that you just copy this Colab into your own folder and then follow along. I will put the link in the description. The first thing you should do is set the runtime: click on "Change runtime type" and select a GPU, then the training will be much faster.

Then the first thing we do is pip install pytorch-lightning, because we also use PyTorch Lightning here to make the code a little bit shorter. I also want to mention that there is an official GAN tutorial on the PyTorch Lightning website that is pretty similar to my code, but they just use normal linear layers and in our case we use CNNs, so it's a little bit different. But yeah, this is also a great resource that you should check out.

So we install this here, then we do all the imports we need: torch and torchvision and the different modules, and we also use matplotlib and PyTorch Lightning. Then we set some parameters, for example a random seed and the batch size, and we check if we have GPUs and if we have multiple CPU cores.

Then the first thing we do is create a data module that inherits from pytorch_lightning.LightningDataModule. This is responsible for creating the data loaders for the training, validation, and test sets. The way this works is that we have to implement the init function, where we simply define the parameters and also the transformations: we convert from images to a tensor and then normalize it, so this is the mean and the standard deviation of the MNIST images. Then we have to implement the prepare_data function. Here we call the built-in MNIST dataset and set download=True, and this will download the data, one time with train=True for the training set and one time with train=False for the test set.

Then we also implement the setup function. Here we further split the training data into a training and a validation set by applying the random_split function, and for the test set we simply also create the MNIST dataset with train=False. In all steps we already apply the transformations. Then we can implement these three functions: train_dataloader, val_dataloader, and test_dataloader. They are all pretty simple: they just return the DataLoader with the corresponding dataset, and we can also set the batch size and the number of workers. So yeah, this is what we have to do to get the data loaders
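The split-and-load steps described above can be sketched in plain PyTorch. This is a minimal illustration, not the data module from the Colab: it uses random stand-in tensors instead of downloading MNIST, the normalization constants 0.1307 and 0.3081 are the commonly quoted MNIST mean and standard deviation (an assumption, the video does not read them out), and the batch size and split sizes are hypothetical.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

torch.manual_seed(42)

# Stand-in for MNIST: 600 random "images" in [0, 1], shape 1x28x28, with random digit labels.
images = torch.rand(600, 1, 28, 28)
labels = torch.randint(0, 10, (600,))

# Normalize with the commonly quoted MNIST mean/std (assumed values).
MNIST_MEAN, MNIST_STD = 0.1307, 0.3081
images = (images - MNIST_MEAN) / MNIST_STD

full_train = TensorDataset(images, labels)

# Split the training data into training and validation subsets, as setup() does.
train_set, val_set = random_split(full_train, [550, 50])

BATCH_SIZE = 128  # hypothetical value
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_set, batch_size=BATCH_SIZE)

# One batch: images plus labels; the GAN later ignores the labels.
real_imgs, _ = next(iter(train_loader))
print(real_imgs.shape)
```

In the Lightning version, the same three pieces (transform, split, DataLoader) live in the init, setup, and dataloader methods of the data module.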

Generator + Discriminator

and then we have to implement the two networks, so the discriminator and the generator. They are both vanilla PyTorch modules, so we can mix and match PyTorch and PyTorch Lightning.

For the discriminator we create a class that inherits from nn.Module. Like I said, the discriminator is like a detective, so its task is to detect whether the input is fake or not fake. So in the end we only need one output that must be between 0 and 1. We only have to implement the init function and the forward function, and basically here you could do whatever you want, as long as at the end there is one output between 0 and 1. You could just use linear layers, but in this example I want to show you how we can use CNNs. So we use two Conv2d layers, then a dropout layer for the 2D case, and in the end two linear layers. We always have to be careful that we have the correct input and output sizes, and at the very end we need one output. So in the init function we simply create all the layers, and in the forward function we apply all layers, and we also apply max pooling and activation functions, here the ReLU activation function. So first the first convolution with max pool and ReLU, then the second convolution with dropout, max pool, and ReLU. Then we reshape the data so we can feed it to the fully connected layers. Then we apply the first fully connected layer, so the linear layer, with an activation function and a dropout, then the second fully connected layer, and at the very end we apply the sigmoid function. This takes care that the output is between 0 and 1. So this is our discriminator.

Now the generator basically works the other way around. We also create a class that inherits from nn.Module, and this gets as a parameter the number of latent dimensions. This is a scalar value, and basically from this we upsample to an output that has the shape of the original images, so 1 by 28 by 28, and the values are between -1 and 1 here. Here we also use a linear layer, then we use a ConvTranspose2d layer, two times actually, and in the end we apply a normal convolution to put it back into this shape. In the forward pass we apply all these steps: first the linear layer with a ReLU activation, then we reshape, then we apply the first transposed convolution layer, which will upsample the data to be in this shape, then again an activation function, then the second transposed convolution layer, which will put it into this shape, then again an activation function, and at the very end a normal convolution, which will put it back into this shape, and then no activation at the end. So yeah, these are the two networks that we need, and now we put them together in one class. So let's do
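The two networks described in this segment could look roughly like the following. The layer order follows the description (two convolutions, 2D dropout and two linear layers for the discriminator; one linear layer, two transposed convolutions and a final convolution for the generator), but the channel counts, kernel sizes and strides are assumptions chosen so that the shapes work out for 1x28x28 images; the transcript does not state them.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Discriminator(nn.Module):
    """CNN 'detective': image in, probability of being real out."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 1)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))                   # 28 -> 24 -> 12
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))  # 12 -> 8 -> 4
        x = x.view(-1, 320)                                          # 20 channels * 4 * 4
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        return torch.sigmoid(self.fc2(x))                            # one value per image


class Generator(nn.Module):
    """Upsamples a latent vector to a 1x28x28 image."""

    def __init__(self, latent_dim):
        super().__init__()
        self.lin1 = nn.Linear(latent_dim, 7 * 7 * 64)
        self.ct1 = nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2)
        self.ct2 = nn.ConvTranspose2d(32, 16, kernel_size=4, stride=2)
        self.conv = nn.Conv2d(16, 1, kernel_size=7)

    def forward(self, z):
        x = F.relu(self.lin1(z))
        x = x.view(-1, 64, 7, 7)   # reshape to a 7x7 feature map with 64 channels
        x = F.relu(self.ct1(x))    # 7 -> 16
        x = F.relu(self.ct2(x))    # 16 -> 34
        return self.conv(x)        # 34 -> 28, no activation at the end


z = torch.randn(6, 100)            # six latent vectors, latent_dim = 100
fake = Generator(latent_dim=100)(z)
score = Discriminator()(fake)
print(fake.shape, score.shape)
```

The shape comments follow the usual size formulas: a convolution with kernel k and no padding maps n to n - k + 1, and a transposed convolution with stride s maps n to (n - 1) * s + k.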

GAN Training

this together. Here we want to create one class that we call GAN, and this inherits from pytorch_lightning.LightningModule. Then we have to implement a few functions so that it works with PyTorch Lightning.

First of all we need the init function with self, and we also give it the number of latent dimensions, which should be 100 by default. Then we also give it a learning rate, and by default this is 0.0002. You might actually want to play around with this a little bit: the training is very sensitive to the learning rate, so it's very important that you find a good one. By default, just try this one. Then the first thing we do is call the super initializer, so super().__init__() like this. Then we save the hyperparameters: we say self.save_hyperparameters(), and this will make them accessible everywhere later. Then we create our two networks. We say self.generator equals the Generator network, and this gets one parameter, latent_dim, which we can now access as a hyperparameter by saying self.hparams.latent_dim. This is the same name that we use here for the parameter. So this is our generator, and then we do the same thing for the discriminator: we initialize self.discriminator as a Discriminator, and this doesn't get any inputs. Then we create random noise that we want to use later to test the images: we say self.validation_z equals torch.randn, it should be six images, and we use self.hparams.latent_dim for the second size. And yeah, this is the init function.

Then we need to implement the forward function, which gets self and the input tensor, which we call z. The forward pass in the GAN is actually just the generator, so here we can say return self.generator(z). So this is the forward pass.

Then let's create a function to define the loss. We could call this adversarial_loss, also with self, and it gets y_hat, the predicted labels, and y, the actual labels. As I said in the beginning, we use the binary cross-entropy, so this is just a one-liner: return F.binary_cross_entropy(y_hat, y).

Then we need the training step, so def training_step. This is a function from PyTorch Lightning that we have to implement, and then everything else will be taken care of for us. It gets self, a batch, the batch index, and also the optimizer index. For now let's just say pass; we do this as the last step.

The next function we have to implement from PyTorch Lightning is configure_optimizers, and this only gets self. Here we can access the learning rate by saying lr = self.hparams.lr, and then we create two optimizers, one for the generator and one for the discriminator. We say opt_g equals torch.optim.Adam, for example, which is always a good choice, and we want to optimize self.generator.parameters(), and we also give it the learning rate, lr=lr. Then we copy and paste this and do the same thing for the discriminator's optimizer, opt_d, where we optimize self.discriminator.parameters() with the same learning rate. Then we return those: as a list, opt_g and opt_d, and as a second return value an empty list. This would hold the schedulers if you used any.

Then there is one more function that we want to implement for the PyTorch Lightning module, called on_epoch_end. This also only gets self, and here after each epoch I simply want to plot a few generated images to see what our fake data looks like. So here I call self.plot_imgs(), and this is a function that I simply copy and paste for you. So let me copy and paste this in here. What we do here is use the validation_z that we created up here, the random noise with six images and this number of latent dimensions. Then we also use the type_as function, which will basically move it to the GPU, or not if we don't use a GPU. Then we call self(z). Calling self(z) will execute the forward pass, so this will execute the generator and generate some images. Then we put them back on the CPU, and here I have a small code snippet to plot the images with matplotlib as a grid. So yeah, this is what we do after each epoch.

Now the only thing left to do is to implement the training step, so let's do this. First of all, I noticed I have a typo here, so this is just self.hparams. And now let's do this. First we want to unpack the batch: we call this real_imgs, and then just an underscore because we don't need the labels, and this comes from the batch. Then again we create some sample noise: we say z equals torch.randn, and it gets the shape of real_imgs, the first dimension, so index zero, and as the second number it gets the number of latent dims, so self.hparams.latent_dim. Then we again have to move it to the GPU if we use one, so we say z = z.type_as(real_imgs), because we want to use the same type as the image tensor.

Then we make an if statement for both optimizers. In the first case we want to train the generator, so the first optimizer, opt_g. Here we check if optimizer_idx equals zero, and then we train the generator. I will write the formula for you: here we want to maximize log(D(G(z))), where D is the discriminator, G is the generator, and G(z) will be the fake images. Again, I will put a reference below the video where you can read more about the loss function, but this is basically the formula, together with the binary cross-entropy loss. So let's do this. First we want to call the generator with z, and this will be the fake images, so here we can call self(z). Again, self(z) will execute the forward pass, which will execute the generator, so this is this part. Then we need the discriminator, and this will be y_hat, the predicted labels: here we call self.discriminator and put in the fake images. As the real labels we say y equals torch.ones, so a tensor of simply ones, with real_imgs.size(0) as the first size and just 1 as the second. Then again we have to move it to the GPU, so again we can use type_as: y = y.type_as(real_imgs). Then we calculate the loss: g_loss = self.adversarial_loss(y_hat, y). Then we return this so PyTorch Lightning knows what to do with the loss. We create a log dictionary, a dictionary with the key g_loss and the value g_loss, and then we return another full dictionary that has this included. First of all we need the key loss in there, and the loss will be g_loss. Then progress_bar also gets the log dictionary, and for logging, if we want to use this later, for example in TensorBoard, we can also use the key log, which also gets the log dictionary. So yeah, this is what we have to do for the generator.

Then we do a similar thing for the discriminator, so if optimizer_idx equals one. Let's first write a comment here: we want to train the discriminator. And again I will say what we have to maximize: here we want to maximize log(D(x)), where x will be the original images, plus log(1 - D(G(z))). Let's go over this, and then it might become clearer what this means. Here we want to measure the discriminator's ability to tell real from generated images. So we want to check how well it can label real images as real, and also how well it can label the generated images as fake, and these are essentially those two parts.

For the first one the code is very similar to this one. We create y_hat_real equals self.discriminator, and now we don't put in fake images but the real images. This is y_hat_real, and y_real, the actual labels, is in this case also just torch.ones like we used here. Then again we have to move it to the GPU if we use one, so y_real = y_real.type_as(real_imgs). Then we calculate real_loss = self.adversarial_loss(y_hat_real, y_real). So this is the first part, very similar to here.

Now we need this part: how well can it label fakes as fake? We call this y_hat_fake, and again self.discriminator of self(z). And now we have to be careful. self(z) will basically generate fake images, and we already did this here and want to use the same ones here, but when we execute this, it will do calculations on the graph, and we don't want to do this twice. That's why here we call self(z).detach(). This will create a new tensor that is detached from the computational graph, so be careful here. And yeah, this is our y_hat_fake. Then y_fake is very similar to here, and again we want to move this to the GPU, so y_fake = y_fake.type_as(real_imgs). And now we have to be careful: this part says 1 minus this. If we have a look at the binary cross-entropy formula, we find out that if we put in zeros here, we actually end up with this formula. So here we put in torch.zeros. Again, I recommend that you just double-check this for yourself. Then we also calculate the fake loss, again with self.adversarial_loss, and here we put in y_hat_fake and y_fake. The total loss is the average: we call this d_loss, the discriminator loss, and we say real_loss plus fake_loss, divided by two. Then, the same as we do here, we create a log dictionary and return it; in this case we call the key d_loss and use d_loss here and also here.

And now this is everything that we need for the training step. So let's create a new code cell, and now we want to set everything up. We create a data module, which is the MNIST data module that we created earlier, and then we create a model, which is our GAN.
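One iteration of the two-optimizer training described in this segment can be sketched without Lightning. This is an illustration under stated assumptions, not the training_step from the Colab: the two networks are tiny stand-ins so the block runs on its own, the batch is random data, and the optimizer steps are written out by hand because the way Lightning passes an optimizer index into training_step depends on the installed version.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
LATENT_DIM = 100

# Tiny stand-in networks (not the CNNs from the video) so the sketch is self-contained.
generator = nn.Sequential(nn.Linear(LATENT_DIM, 28 * 28), nn.Unflatten(1, (1, 28, 28)))
discriminator = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 1), nn.Sigmoid())

opt_g = torch.optim.Adam(generator.parameters(), lr=0.0002)
opt_d = torch.optim.Adam(discriminator.parameters(), lr=0.0002)

real_imgs = torch.randn(8, 1, 28, 28)             # stand-in for one MNIST batch
z = torch.randn(real_imgs.shape[0], LATENT_DIM)   # one noise vector per real image
z = z.type_as(real_imgs)                          # same dtype/device as the images

# Generator step: labels of ones, so BCE becomes -log(D(G(z))).
y_hat = discriminator(generator(z))
y = torch.ones(real_imgs.size(0), 1).type_as(real_imgs)
g_loss = F.binary_cross_entropy(y_hat, y)
opt_g.zero_grad()
g_loss.backward()
opt_g.step()

# Discriminator step, part 1: real images should be labelled 1.
y_hat_real = discriminator(real_imgs)
y_real = torch.ones(real_imgs.size(0), 1).type_as(real_imgs)
real_loss = F.binary_cross_entropy(y_hat_real, y_real)

# Part 2: generated images, detached from the generator's graph, should be labelled 0.
y_hat_fake = discriminator(generator(z).detach())
y_fake = torch.zeros(real_imgs.size(0), 1).type_as(real_imgs)
fake_loss = F.binary_cross_entropy(y_hat_fake, y_fake)

d_loss = (real_loss + fake_loss) / 2
opt_d.zero_grad()
d_loss.backward()
opt_d.step()

# The "put in zeros" check: with zero labels, BCE equals -log(1 - y_hat).
manual_fake_loss = -torch.log(1 - y_hat_fake).mean()
print(g_loss.item(), d_loss.item(), manual_fake_loss.item())
```

Because the fake images are detached, the discriminator's backward pass does not flow into the generator's parameters; only the generator step updates the generator.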

Testing

and now, before I train this, I actually want to call this plot function once, so that we see that in the beginning we simply plot noise. So let's say model.plot_imgs(), and then let's run everything and see if it works up to here. And yes, here we have the first generated images, and as we can see this is just noise.

Now we want to train our GAN. We create a pytorch_lightning.Trainer, and here we can specify the maximum number of epochs, let's use 20, and we can also set gpus to the number of available GPUs that we determined in the beginning. So let's put this into our trainer as well, and then we simply say trainer.fit, with the model and also the data module. Then we run this, and when we execute it you should see that after each epoch we plot the images. So let's do this. And we get a TypeError: in this part I forgot to call shape, so here in the training step we must say real_imgs.shape and then index zero. So now we have to run this cell again, then down here we run this cell again, let's plot again, which should be the same random noise, and now let's run the trainer again.

And training is done, so now we can scroll down and look. After the first epoch, so this is actually after epoch zero, we get these images, and it still looks like random noise. After epoch one we have these images, then this, and it still looks like noise. But then after epoch 4 it starts getting better, and it's starting to look like images. So now let's scroll down a little bit. Here it still looks a bit noisy, but for example this might be a three, a four, this might be a zero. So yeah, it's starting to get into shape, and our code works. I actually recommend that you play around with the hyperparameters a little bit, and also maybe increase the maximum number of epochs, and then test this for yourself. Again, the link to the Colab will be in the description below. And yeah, I hope you enjoyed this video. If you did, then please hit the like button and consider subscribing to the channel, and then I hope to see you next time. Bye!
