Building a GAN with tensorflow.keras (runs as-is)
Preface
Keras is one of TensorFlow's high-level APIs. Its code is concise and highly readable. This post uses tf.keras to implement a GAN. It does not explain the underlying theory in detail and is meant only as a worked example to share. (Reference: the Chinese-language Keras documentation.)
Main text
1, General steps for building a GAN with tf.keras
1. Convert all image data to TensorFlow's TFRecords format with the create_tfrecords.py script. (I originally wrote this script to produce labeled data for image classification. A GAN does not need the labels, so you can skip saving them.)
2. Build a dataset from the generated TFRecords file with tf.data.TFRecordDataset. The complete script also includes a second, queue-based way to read TFRecords data (read_and_decode). Both approaches reach the same result by different routes.
3. Build the generator network.
4. Build the discriminator network and combine it with the generator into the GAN. The discriminator must be set as non-trainable before the GAN is compiled.
5. Write a loop that trains the discriminator and the generator alternately.
6. Save the network model.
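For reference, steps 4 and 5 put into practice the two-player game from the original GAN paper (Goodfellow et al., 2014, listed in the references). With discriminator $D$, generator $G$, real images $x$ and noise vectors $z$, the objective is:

$$\min_G \max_D V(D,G)=\mathbb{E}_{x\sim p_{\text{data}}}\left[\log D(x)\right]+\mathbb{E}_{z\sim p_z}\left[\log\left(1-D(G(z))\right)\right]$$

Training the discriminator with binary cross-entropy on labels 1 (real) and 0 (generated) maximizes the $D$ side of this objective. Training the combined model on generated images with target label 1, as the loop in step 5 does, minimizes $-\mathbb{E}_{z\sim p_z}\left[\log D(G(z))\right]$ rather than $\mathbb{E}_{z\sim p_z}\left[\log\left(1-D(G(z))\right)\right]$. This "non-saturating" variant is the one the paper recommends in practice, because it gives the generator stronger gradients early in training.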
2, Use steps
1. Create the TFRecords dataset
create_tfrecords.py
By default the output file is filename_train="./data/train.tfrecords" (create the ./data directory first if it does not exist).
Run in a terminal: python create_tfrecords.py --data [dataset location]
This generates train.tfrecords. Validation and test sets can be added the same way.
```python
import os

import tensorflow as tf  # TensorFlow 1.x API (tf.python_io, tf.app.flags)
from PIL import Image

objects = ['cat', 'dog']  # 'cat' -> 0, 'dog' -> 1
filename_train = "./data/train.tfrecords"

tf.app.flags.DEFINE_string('data', None, 'Directory containing the images.')
FLAGS = tf.app.flags.FLAGS

if FLAGS.data is None:
    raise SystemExit('Usage: python create_tfrecords.py --data [dataset location]')

dim = (224, 224)
object_path = FLAGS.data
os.makedirs(os.path.dirname(filename_train), exist_ok=True)  # make sure ./data exists
writer_train = tf.python_io.TFRecordWriter(filename_train)

for index in os.listdir(object_path):
    img_path = os.path.join(object_path, index)
    # convert('RGB') guarantees 224*224*3 bytes even for grayscale or RGBA files
    img = Image.open(img_path).convert('RGB').resize(dim)
    img_raw = img.tobytes()

    # The label is the index of the class name contained in the file name
    value = None
    for i in range(len(objects)):
        if objects[i] in index:
            value = i
            break
    if value is None:
        continue  # skip files that match no class

    example = tf.train.Example(features=tf.train.Features(feature={
        'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[value])),
        'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
    }))
    print([index, value])
    writer_train.write(example.SerializeToString())  # serialize to string

writer_train.close()
```
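The labeling rule above relies only on file names. Below is a minimal stand-alone sketch of that rule, assuming file names such as cat.1.jpg and dog.42.jpg (the naming is an assumption, not something the script enforces). The helper name label_from_filename is hypothetical. When several class names match, it takes the first one, and files that match no class return None so the caller can skip them.

```python
from typing import Optional, Sequence


def label_from_filename(filename: str, objects: Sequence[str]) -> Optional[int]:
    """Hypothetical helper: index of the first class name found in filename, else None."""
    for i, name in enumerate(objects):
        if name in filename:
            return i
    return None  # no class matched; the caller should skip this file


objects = ['cat', 'dog']  # 'cat' -> 0, 'dog' -> 1
for f in ['cat.1.jpg', 'dog.42.jpg', 'bird.3.jpg']:
    print(f, label_from_filename(f, objects))
```

Returning None instead of reusing the previous file's label keeps an unrelated image from silently entering the dataset with a wrong label.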
2. Read the data
Create the dataset with tf.data.TFRecordDataset.
The code is below. load_image is passed to map() to decode the dataset, and dataset_tfrecords is called from the main function as:
```python
train_datas, iter = dataset_tfrecords(tfrecords_path, use_keras_fit=False)
```
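```python
import tensorflow as tf  # TensorFlow 1.x API

batch_size = 32  # values used in the complete script
epochs = 120


def load_image(serialized_example):
    """Decode a batch of serialized tf.train.Example protos into images and labels."""
    features = {
        'label': tf.io.FixedLenFeature([], tf.int64),
        'img_raw': tf.io.FixedLenFeature([], tf.string)}
    parsed_example = tf.io.parse_example(serialized_example, features)
    image = tf.decode_raw(parsed_example['img_raw'], tf.uint8)
    image = tf.reshape(image, [-1, 224, 224, 3])
    # Scale to [-1, 1] so real images match the generator's tanh output range
    image = tf.cast(image, tf.float32) / 127.5 - 1.0
    label = tf.cast(parsed_example['label'], tf.int32)
    label = tf.reshape(label, [-1, 1])
    return image, label


def dataset_tfrecords(tfrecords_path, use_keras_fit=True):
    # use_keras_fit: whether the dataset is passed to model.fit(), which handles epochs itself
    if use_keras_fit:
        epochs_data = 1
    else:
        epochs_data = epochs
    # The list may hold several files [tfrecords_name1, tfrecords_name2, ...];
    # os.listdir(tfrecords_dir) can be used to collect them
    dataset = tf.data.TFRecordDataset([tfrecords_path])
    dataset = dataset\
        .shuffle(1000)\
        .repeat(epochs_data)\
        .batch(batch_size, drop_remainder=True)\
        .map(load_image, num_parallel_calls=2)
    # shuffle before batch mixes individual examples; drop_remainder keeps every
    # batch at exactly batch_size, which the training loop relies on
    iter = dataset.make_one_shot_iterator()  # or make_initializable_iterator()
    train_datas = iter.get_next()  # train_datas[0]: images, train_datas[1]: labels
    return train_datas, iter
```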
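Whether shuffle() is applied before or after batch() in a tf.data pipeline changes what gets mixed. Shuffling after batching only reorders whole batches, while shuffling first mixes individual examples across batches. drop_remainder also matters here: the training loop pairs each real batch with exactly batch_size generated images and builds its label array from batch_size, so a shorter final batch would make images and labels disagree in length. The plain-Python simulation below illustrates both points. It is a sketch, not the tf.data API: a full in-memory shuffle stands in for Dataset.shuffle's buffered shuffle, and the helper names are hypothetical.

```python
import random


def batch(items, batch_size, drop_remainder=False):
    """Group items into consecutive lists of batch_size (like Dataset.batch)."""
    out = [items[i:i + batch_size] for i in range(0, len(items), batch_size)]
    if drop_remainder and out and len(out[-1]) < batch_size:
        out.pop()  # discard the short final batch
    return out


def full_shuffle(items, seed=0):
    """Shuffle a copy of items (stand-in for a buffered Dataset.shuffle)."""
    shuffled = list(items)
    random.Random(seed).shuffle(shuffled)
    return shuffled


examples = list(range(10))
# batch -> shuffle: whole batches are permuted, their contents stay together
batch_then_shuffle = full_shuffle(batch(examples, 4))
# shuffle -> batch with drop_remainder: examples mix, every batch has 4 items
shuffle_then_batch = batch(full_shuffle(examples), 4, drop_remainder=True)
print(batch_then_shuffle)
print(shuffle_then_batch)
```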
3. Build the GAN
a. Build the generator network
```python
from tensorflow import keras

coding_size = 30  # length of the noise vector (value used in the complete script)

generator = keras.models.Sequential([
    # fully connected layers
    keras.layers.Dense(256, activation='selu', input_shape=[coding_size]),
    keras.layers.Dense(64, activation='selu'),
    keras.layers.Dense(256, activation='selu'),
    keras.layers.Dense(1024, activation='selu'),
    keras.layers.Dense(7 * 7 * 64, activation='selu'),
    keras.layers.Reshape([7, 7, 64]),  # 7*7*64
    # transposed convolutions ("deconvolution"); each doubles height and width
    keras.layers.Conv2DTranspose(64, kernel_size=3, strides=2, padding='same', activation='selu'),  # 14*14*64
    keras.layers.Conv2DTranspose(64, kernel_size=3, strides=2, padding='same', activation='selu'),  # 28*28*64
    keras.layers.Conv2DTranspose(32, kernel_size=3, strides=2, padding='same', activation='selu'),  # 56*56*32
    keras.layers.Conv2DTranspose(16, kernel_size=3, strides=2, padding='same', activation='selu'),  # 112*112*16
    # use tanh instead of sigmoid: outputs lie in [-1, 1]
    keras.layers.Conv2DTranspose(3, kernel_size=3, strides=2, padding='same', activation='tanh'),  # 224*224*3
    keras.layers.Reshape([224, 224, 3])
])
```
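The shape comments in the generator can be checked with a few lines of arithmetic. The sketch below uses the Keras transposed-convolution output-size rules with output_padding left at None (for padding='same' the size is input × strides; for 'valid' it is input × strides + max(kernel_size − strides, 0)). The helper name deconv_output_size is hypothetical; it does not call Keras.

```python
def deconv_output_size(n, kernel_size, strides, padding):
    """Hypothetical helper: spatial size after a Conv2DTranspose (output_padding=None)."""
    if padding == 'same':
        return n * strides
    if padding == 'valid':
        return n * strides + max(kernel_size - strides, 0)
    raise ValueError('unsupported padding: %s' % padding)


size = 7  # after Reshape([7, 7, 64])
sizes = [size]
for _ in range(5):  # five Conv2DTranspose layers: kernel 3, stride 2, 'same'
    size = deconv_output_size(size, kernel_size=3, strides=2, padding='same')
    sizes.append(size)
print(sizes)  # [7, 14, 28, 56, 112, 224]
```

Starting from 7 × 7 and doubling five times is exactly what lands the output on 224 × 224, which is why the dense layers end in 7*7*64.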
b. Build the discriminator network
```python
from tensorflow import keras

discriminator = keras.models.Sequential([
    keras.layers.Conv2D(128, kernel_size=3, padding='same', strides=2, activation='selu',
                        input_shape=[224, 224, 3]),  # 112*112*128
    keras.layers.MaxPool2D(pool_size=2),  # 56*56*128
    keras.layers.Conv2D(64, kernel_size=3, padding='same', strides=2, activation='selu'),  # 28*28*64
    keras.layers.MaxPool2D(pool_size=2),  # 14*14*64
    keras.layers.Conv2D(32, kernel_size=3, padding='same', strides=2, activation='selu'),  # 7*7*32
    keras.layers.Flatten(),
    # dropout 0.4
    keras.layers.Dropout(0.4),
    keras.layers.Dense(512, activation='selu'),
    keras.layers.Dropout(0.4),
    keras.layers.Dense(64, activation='selu'),
    keras.layers.Dropout(0.4),
    # the last layer: probability that the input image is real
    keras.layers.Dense(1, activation='sigmoid')
])
```
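The discriminator's shape comments follow the forward convolution and pooling rules: with padding='same' the output is ceil(input / strides), and with padding='valid' it is floor((input − kernel_size) / strides) + 1. MaxPool2D(pool_size=2) uses strides equal to the pool size and 'valid' padding by default. The sketch below traces a 224 × 224 input through the layers. The helper name conv_output_size is hypothetical and only does the arithmetic.

```python
import math


def conv_output_size(n, kernel_size, strides, padding):
    """Hypothetical helper: spatial size after a Conv2D or MaxPool2D layer."""
    if padding == 'same':
        return math.ceil(n / strides)
    if padding == 'valid':
        return (n - kernel_size) // strides + 1
    raise ValueError('unsupported padding: %s' % padding)


# (layer, kernel/pool size, strides, padding), in the order used by the discriminator
layers = [
    ('Conv2D(128)', 3, 2, 'same'),
    ('MaxPool2D(2)', 2, 2, 'valid'),
    ('Conv2D(64)', 3, 2, 'same'),
    ('MaxPool2D(2)', 2, 2, 'valid'),
    ('Conv2D(32)', 3, 2, 'same'),
]
n = 224
trace = [n]
for name, k, s, p in layers:
    n = conv_output_size(n, k, s, p)
    trace.append(n)
print(trace, 'flattened features:', n * n * 32)
```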
c. Combine the generator and discriminator into the GAN
```python
gan = keras.models.Sequential([generator, discriminator])
```
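4. Compile (define the loss and optimizer)
```python
# compile the networks
discriminator.compile(loss="binary_crossentropy", optimizer='rmsprop')  # metrics=['accuracy']
# Freeze the discriminator before compiling the combined model, so that
# gan.train_on_batch only updates the generator's weights
discriminator.trainable = False
gan.compile(loss="binary_crossentropy", optimizer='rmsprop')  # metrics=['accuracy']
```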
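Both models use binary_crossentropy; only the labels differ between the two training phases. The plain-Python sketch below computes the batch-mean binary cross-entropy for the label arrays the training loop builds (y1 for the discriminator, y2 for the combined GAN). It mirrors Keras's behavior of clipping predictions by a small epsilon (assumed to be 1e-7, the Keras default) but is not Keras code, and the small batch_size is chosen only for illustration.

```python
import math


def binary_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean binary cross-entropy; predictions are clipped to [eps, 1 - eps]."""
    total = 0.0
    for t, p in zip(y_true, y_pred):
        p = min(max(p, eps), 1.0 - eps)
        total += -(t * math.log(p) + (1 - t) * math.log(1 - p))
    return total / len(y_true)


batch_size = 2  # illustrative only
# Phase 1: the discriminator sees fake images (label 0) followed by real ones (label 1)
y1 = [0.0] * batch_size + [1.0] * batch_size
# Phase 2: the combined GAN is trained so that D outputs 1 on fake images
y2 = [1.0] * batch_size

d_out = [0.5] * (2 * batch_size)  # an undecided discriminator
print(binary_crossentropy(y1, d_out))  # ln 2, about 0.693
print(binary_crossentropy(y2, [0.5] * batch_size))
```

A loss near ln 2 for both phases means the discriminator cannot tell real from fake, which is the balance point the two phases push toward.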
5. Train the network (the training loop)
Get the dataset:
```python
train_datas, iter = dataset_tfrecords(tfrecords_path, use_keras_fit=False)
```
Loop body (cv2 is used to display samples from the generator):
```python
import time

import cv2
import numpy as np
import tensorflow as tf  # TensorFlow 1.x API (tf.Session)

generator, discriminator = gan.layers
sess = tf.Session()  # used only to pull batches from the tf.data iterator
for step in range(num_steps):
    start_time = time.time()

    # phase 1 - training the discriminator
    noise = np.random.normal(size=[batch_size, coding_size]).astype(np.float32)
    generated_images = generator.predict(noise)
    train_datas_ = sess.run(train_datas)
    # Use np.concatenate here. Never create TF ops such as tf.concat inside the
    # loop: each call adds new nodes to the graph, so memory keeps growing and
    # every step gets slower than the last.
    x_fake_and_real = np.concatenate([generated_images, train_datas_[0]], axis=0)
    y1 = np.array([[0.]] * batch_size + [[1.]] * batch_size)  # fake: 0, real: 1
    discriminator.trainable = True
    # Keras's train_on_batch fits a hand-written GAN loop like this one well
    dis_loss = discriminator.train_on_batch(x_fake_and_real, y1)

    # phase 2 - training the generator
    noise = np.random.normal(size=[batch_size, coding_size]).astype(np.float32)
    y2 = np.array([[1.]] * batch_size)  # the generator wants its fakes classified as real
    discriminator.trainable = False
    ad_loss = gan.train_on_batch(noise, y2)

    duration = time.time() - start_time
    if step % 5 == 0:
        # gan.save_weights('gan.h5')
        print("The step is %d, discriminator loss: %.3f, adversarial loss: %.3f"
              % (step, dis_loss, ad_loss), end=' ')
        print('%.2f s/step' % duration)

    if step % 30 == 0 and step != 0:
        noise = np.random.normal(size=[1, coding_size]).astype(np.float32)
        fake_image = generator.predict(noise)
        # Restore the image:
        # 1. The tanh output lies in [-1, 1]; map it to [0, 255] and convert to uint8.
        # 2. Alternatively keep float32 in [0, 1], i.e. (x + 1) / 2; cv2.imshow
        #    can display such an image directly.
        arr_img = (fake_image[0] + 1.0) * 127.5
        arr_img = np.clip(arr_img, 0, 255).astype(np.uint8)
        # The images were stored in TFRecords via PIL, i.e. in RGB order,
        # while OpenCV displays BGR, so convert before showing
        arr_img = cv2.cvtColor(arr_img, cv2.COLOR_RGB2BGR)
        cv2.imshow('fake image', arr_img)
        cv2.waitKey(1500)  # show the fake image for 1.5 s
        cv2.destroyAllWindows()
```
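Because the generator's last layer uses tanh, its outputs lie in [-1, 1] rather than [0, 1]; multiplying them by 255 alone would produce negative values that cannot be stored as 8-bit pixels. The per-pixel sketch below shows the mapping back to [0, 255] and the RGB-to-BGR channel swap that cv2.cvtColor(..., cv2.COLOR_RGB2BGR) performs. It is plain Python for illustration (numpy code applies the same arithmetic to whole arrays), and the helper names are hypothetical.

```python
def tanh_to_uint8(value):
    """Hypothetical helper: map a tanh output in [-1, 1] to a pixel value in [0, 255]."""
    scaled = (value + 1.0) * 127.5
    return int(round(min(max(scaled, 0.0), 255.0)))  # clip, then round to an integer


def rgb_to_bgr(pixel):
    """Hypothetical helper: reverse the channel order of one (R, G, B) pixel."""
    r, g, b = pixel
    return (b, g, r)


fake_pixel = (-1.0, 0.0, 1.0)  # one RGB pixel as produced by the tanh layer
rgb = tuple(tanh_to_uint8(v) for v in fake_pixel)
print(rgb, rgb_to_bgr(rgb))  # (0, 128, 255) (255, 128, 0)
```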
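6. Save the network
```python
import os

import tensorflow as tf

# save the model
model_version = '0001'
model_name = 'gans'
model_path = os.path.join(model_name, model_version)  # gans/0001
tf.saved_model.save(gan, model_path)
# gan.save('gans.h5')  # alternative: a single HDF5 file
```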
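The version string suggests the layout TensorFlow Serving expects: a model directory containing numerically named version subdirectories. Assuming that is the goal, a small helper could pick the next free version instead of hard-coding one. The sketch below is only an illustration: next_version_dir is a hypothetical name, it writes nothing itself, and it uses unpadded numbers ("1", "2", ...).

```python
import os
import tempfile


def next_version_dir(base_dir):
    """Hypothetical helper: base_dir/<n+1>, where n is the largest numeric subdirectory (0 if none)."""
    versions = []
    if os.path.isdir(base_dir):
        versions = [int(d) for d in os.listdir(base_dir) if d.isdigit()]
    return os.path.join(base_dir, str(max(versions, default=0) + 1))


base = os.path.join(tempfile.mkdtemp(), 'gans')
first = next_version_dir(base)
os.makedirs(first)  # pretend a SavedModel was exported here
second = next_version_dir(base)
print(first, second)
```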
7. Complete gans.py (runnable)
```python
# -*- coding: utf-8 -*-
'''
@author: zyl
author is zouyuelin, a Master of Tianjin University (TJU)
'''
import os
import time

import cv2
import numpy as np
import tensorflow as tf  # written for the TensorFlow 1.x API (tf.Session, iterators)
from tensorflow import keras
# tf.enable_eager_execution()

batch_size = 32
epochs = 120
num_steps = 2000
coding_size = 30
tfrecords_path = 'data/train.tfrecords'


# -------------------------------------- dataset: tf.data + TFRecords ----------------
def load_image(serialized_example):
    features = {
        'label': tf.io.FixedLenFeature([], tf.int64),
        'img_raw': tf.io.FixedLenFeature([], tf.string)}
    parsed_example = tf.io.parse_example(serialized_example, features)
    image = tf.decode_raw(parsed_example['img_raw'], tf.uint8)
    image = tf.reshape(image, [-1, 224, 224, 3])
    image = tf.cast(image, tf.float32) / 127.5 - 1.0  # [-1, 1], matching tanh
    label = tf.cast(parsed_example['label'], tf.int32)
    label = tf.reshape(label, [-1, 1])
    return image, label


def dataset_tfrecords(tfrecords_path, use_keras_fit=True):
    # use_keras_fit: whether the dataset is passed to model.fit()
    epochs_data = 1 if use_keras_fit else epochs
    # The list may hold several files: [tfrecords_name1, tfrecords_name2, ...]
    dataset = tf.data.TFRecordDataset([tfrecords_path])
    dataset = dataset\
        .shuffle(1000)\
        .repeat(epochs_data)\
        .batch(batch_size, drop_remainder=True)\
        .map(load_image, num_parallel_calls=2)
    iter = dataset.make_one_shot_iterator()  # or make_initializable_iterator()
    train_datas = iter.get_next()  # train_datas[0]: images, train_datas[1]: labels
    return train_datas, iter


# ------------------------------ alternative: tf.TFRecordReader (queue-based) -------
def read_and_decode(tfrecords_path):
    # Generate a queue based on the file name
    filename_queue = tf.train.string_input_producer([tfrecords_path], shuffle=True)
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(serialized_example, features={
        'label': tf.FixedLenFeature([], tf.int64),
        'img_raw': tf.FixedLenFeature([], tf.string)})
    image = tf.decode_raw(features['img_raw'], tf.uint8)
    image = tf.reshape(image, [224, 224, 3])  # reshape to 224*224*3
    image = tf.cast(image, tf.float32) / 127.5 - 1.0  # [-1, 1], matching tanh
    label = tf.cast(features['label'], tf.int32)
    img_batch, label_batch = tf.train.shuffle_batch(
        [image, label], batch_size=batch_size, num_threads=4,
        capacity=640, min_after_dequeue=5)
    # Queue-based input needs tf.train.start_queue_runners(sess=sess) before sess.run
    return [img_batch, label_batch]


# Autoencoder (an extra model; not used by the GAN below)
def autoencode():
    encoder = keras.models.Sequential([
        keras.layers.Conv2D(32, kernel_size=3, padding='same', strides=2, activation='selu',
                            input_shape=[224, 224, 3]),  # 112*112*32
        keras.layers.MaxPool2D(pool_size=2),  # 56*56*32
        keras.layers.Conv2D(64, kernel_size=3, padding='same', strides=2, activation='selu'),  # 28*28*64
        keras.layers.MaxPool2D(pool_size=2),  # 14*14*64
        keras.layers.Conv2D(128, kernel_size=3, padding='same', strides=2, activation='selu'),  # 7*7*128
        # deconvolution
        keras.layers.Conv2DTranspose(128, kernel_size=3, strides=2, padding='same', activation='selu'),  # 14*14*128
        keras.layers.Conv2DTranspose(64, kernel_size=3, strides=2, padding='same', activation='selu'),  # 28*28*64
        keras.layers.Conv2DTranspose(32, kernel_size=3, strides=2, padding='same', activation='selu'),  # 56*56*32
        keras.layers.Conv2DTranspose(16, kernel_size=3, strides=2, padding='same', activation='selu'),  # 112*112*16
        # use tanh instead of sigmoid
        keras.layers.Conv2DTranspose(3, kernel_size=3, strides=2, padding='same', activation='tanh'),  # 224*224*3
        keras.layers.Reshape([224, 224, 3])
    ])
    return encoder


def training_keras():
    '''
    Convolution and pooling output size:
        output_size = floor((input_size - kernel_size + 2*padding) / strides) + 1
    Keras transposed-convolution (deconvolution) output size, with output_padding=None:
        1. if padding == 'valid': output_size = (input_size - 1)*strides + kernel_size
           (for kernel_size >= strides)
        2. if padding == 'same':  output_size = input_size * strides
    '''
    generator = keras.models.Sequential([
        # fully connected layers
        keras.layers.Dense(256, activation='selu', input_shape=[coding_size]),
        keras.layers.Dense(64, activation='selu'),
        keras.layers.Dense(256, activation='selu'),
        keras.layers.Dense(1024, activation='selu'),
        keras.layers.Dense(7 * 7 * 64, activation='selu'),
        keras.layers.Reshape([7, 7, 64]),  # 7*7*64
        # deconvolution
        keras.layers.Conv2DTranspose(64, kernel_size=3, strides=2, padding='same', activation='selu'),  # 14*14*64
        keras.layers.Conv2DTranspose(64, kernel_size=3, strides=2, padding='same', activation='selu'),  # 28*28*64
        keras.layers.Conv2DTranspose(32, kernel_size=3, strides=2, padding='same', activation='selu'),  # 56*56*32
        keras.layers.Conv2DTranspose(16, kernel_size=3, strides=2, padding='same', activation='selu'),  # 112*112*16
        # use tanh instead of sigmoid
        keras.layers.Conv2DTranspose(3, kernel_size=3, strides=2, padding='same', activation='tanh'),  # 224*224*3
        keras.layers.Reshape([224, 224, 3])
    ])

    discriminator = keras.models.Sequential([
        keras.layers.Conv2D(128, kernel_size=3, padding='same', strides=2, activation='selu',
                            input_shape=[224, 224, 3]),  # 112*112*128
        keras.layers.MaxPool2D(pool_size=2),  # 56*56*128
        keras.layers.Conv2D(64, kernel_size=3, padding='same', strides=2, activation='selu'),  # 28*28*64
        keras.layers.MaxPool2D(pool_size=2),  # 14*14*64
        keras.layers.Conv2D(32, kernel_size=3, padding='same', strides=2, activation='selu'),  # 7*7*32
        keras.layers.Flatten(),
        # dropout 0.4
        keras.layers.Dropout(0.4),
        keras.layers.Dense(512, activation='selu'),
        keras.layers.Dropout(0.4),
        keras.layers.Dense(64, activation='selu'),
        keras.layers.Dropout(0.4),
        # the last layer: probability that the input image is real
        keras.layers.Dense(1, activation='sigmoid')
    ])

    # GAN network
    gan = keras.models.Sequential([generator, discriminator])

    # compile the networks; freeze the discriminator before compiling the GAN
    discriminator.compile(loss="binary_crossentropy", optimizer='rmsprop')  # metrics=['accuracy']
    discriminator.trainable = False
    gan.compile(loss="binary_crossentropy", optimizer='rmsprop')  # metrics=['accuracy']

    # dataset
    # train_datas = read_and_decode(tfrecords_path)
    train_datas, iter = dataset_tfrecords(tfrecords_path, use_keras_fit=False)

    generator, discriminator = gan.layers
    print("-----------------start---------------")
    sess = tf.Session()  # used only to pull batches from the tf.data iterator
    for step in range(num_steps):
        start_time = time.time()

        # phase 1 - training the discriminator
        noise = np.random.normal(size=[batch_size, coding_size]).astype(np.float32)
        generated_images = generator.predict(noise)
        train_datas_ = sess.run(train_datas)
        # Use np.concatenate; never create TF ops such as tf.concat inside the
        # loop, or the graph grows every step: memory runs out and training
        # gets slower and slower
        x_fake_and_real = np.concatenate([generated_images, train_datas_[0]], axis=0)
        y1 = np.array([[0.]] * batch_size + [[1.]] * batch_size)  # fake: 0, real: 1
        discriminator.trainable = True
        # Keras's train_on_batch fits a hand-written GAN loop like this one well
        dis_loss = discriminator.train_on_batch(x_fake_and_real, y1)

        # phase 2 - training the generator
        noise = np.random.normal(size=[batch_size, coding_size]).astype(np.float32)
        y2 = np.array([[1.]] * batch_size)
        discriminator.trainable = False
        ad_loss = gan.train_on_batch(noise, y2)

        duration = time.time() - start_time
        if step % 5 == 0:
            # gan.save_weights('gan.h5')
            print("The step is %d, discriminator loss: %.3f, adversarial loss: %.3f"
                  % (step, dis_loss, ad_loss), end=' ')
            print('%.2f s/step' % duration)

        if step % 30 == 0 and step != 0:
            noise = np.random.normal(size=[1, coding_size]).astype(np.float32)
            fake_image = generator.predict(noise)
            # Restore the image:
            # 1. The tanh output lies in [-1, 1]; map it to [0, 255] and convert to uint8.
            # 2. Alternatively keep float32 in [0, 1], i.e. (x + 1) / 2, which
            #    cv2.imshow can display directly.
            arr_img = (fake_image[0] + 1.0) * 127.5
            arr_img = np.clip(arr_img, 0, 255).astype(np.uint8)
            # The TFRecords were written from PIL images (RGB order);
            # OpenCV displays BGR, so convert before showing
            arr_img = cv2.cvtColor(arr_img, cv2.COLOR_RGB2BGR)
            cv2.imshow('fake image', arr_img)
            cv2.waitKey(1500)  # show the fake image for 1.5 s
            cv2.destroyAllWindows()

    # save the model
    model_version = '0001'
    model_name = 'gans'
    model_path = os.path.join(model_name, model_version)  # gans/0001
    tf.saved_model.save(gan, model_path)


def main():
    training_keras()


if __name__ == '__main__':
    main()
```
This completes a simple GAN training setup.
Reference material
Paper: Generative Adversarial Networks (Goodfellow et al., 2014)
Reference source code:
https://github.com/eriklindernoren/Keras-GAN/blob/master/gan/gan.py
Reference blog:
https://blog.csdn.net/u010138055/article/details/94441812
Closing remarks
I am a beginner in deep learning and machine learning and have only just started, so please point out any shortcomings.