Building a GAN with tensorflow.keras (runnable code included)


Preface

Keras is one of TensorFlow's high-level APIs; its code is concise and readable. This article uses tf.keras to implement a GAN. The underlying theory is not covered in detail here; the article is offered only as a worked example for discussion.

##### Keras Chinese reference documentation

Main text

1. General steps for building a GAN with tf.keras

1. First, convert all image data to the TFRecords format provided by TensorFlow. The create_tfrecords.py script generates the file (I originally wrote this script to store labels for image classification; for a GAN you do not need to save the labels).
2. Build the dataset from the generated TFRecords file with tf.data.TFRecordDataset. This article also provides a second way to read TFRecords data (read_and_decode in the complete script); the two approaches differ but achieve the same result.
3. Build the generator network.
4. Build the discriminator network and combine it with the generator into the GAN (the discriminator must be set to untrainable before the GAN is compiled).
5. Write a loop that trains the discriminator and the generator in turn.
6. Save the model.

2. Usage steps

1. Create the TFRecords dataset

create_tfrecords.py
The default output location for the TFRecords file is filename_train="./data/train.tfrecords"
In a terminal, run: python create_tfrecords.py --data [dataset location]
This generates train.tfrecords; you can add validation and test sets yourself in the same way

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt 
import os
from PIL import Image
import random

objects = ['cat','dog']  #label 0 = 'cat', label 1 = 'dog'

filename_train="./data/train.tfrecords"
writer_train= tf.python_io.TFRecordWriter(filename_train)

#The default must be the None object, not the string 'None', or the check below never fires
tf.app.flags.DEFINE_string(
    'data', None, 'Directory that contains the images.')
FLAGS = tf.app.flags.FLAGS

if FLAGS.data is None:
    os._exit(0)

dim = (224,224)
object_path = FLAGS.data
total = os.listdir(object_path)
for index in total:
    img_path=os.path.join(object_path,index)
    img=Image.open(img_path)
    img=img.resize(dim)
    img_raw=img.tobytes()
    value = None
    for i in range(len(objects)):
        if objects[i] in index:
            value = i
    if value is None:
        continue  #skip files whose name matches none of the classes
    example = tf.train.Example(features=tf.train.Features(feature={
                'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[value])),
                'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
                }))
    print([index,value])
    writer_train.write(example.SerializeToString())  #Serialize to string
writer_train.close()
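
The flag handling and label lookup in this script do not depend on TensorFlow. As a minimal sketch, assuming you would rather use the standard-library argparse module than tf.app.flags, the same logic could look like the code below. The helpers label_from_filename and parse_args, and the path ./images, are hypothetical names introduced only for illustration.

```python
import argparse

OBJECTS = ['cat', 'dog']  # the index in this list is the integer label


def label_from_filename(filename, objects=OBJECTS):
    """Return the label whose class name occurs in filename, or None."""
    label = None
    for i, name in enumerate(objects):
        if name in filename:
            label = i  # as in the script above, the last match wins
    return label


def parse_args(argv=None):
    """Parse --data; argparse exits with an error message if it is missing."""
    parser = argparse.ArgumentParser(description='Create train.tfrecords')
    parser.add_argument('--data', required=True,
                        help='directory that contains the images')
    return parser.parse_args(argv)


if __name__ == '__main__':
    args = parse_args(['--data', './images'])  # hypothetical path
    print(args.data)
    for fname in ['cat.0.jpg', 'dog.12.jpg', 'readme.txt']:
        print(fname, label_from_filename(fname))
```

Returning None for an unmatched file name lets the caller skip the file explicitly instead of silently reusing the label of the previous image.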

2. Read the data

The dataset is created with tf.data.TFRecordDataset.
The code is as follows (load_image is passed to map to decode the dataset). It is called from the main function like this:
train_datas,iter = dataset_tfrecords(tfrecords_path,use_keras_fit=False)

def load_image(serialized_example):   
    features={
        'label': tf.io.FixedLenFeature([], tf.int64),
        'img_raw' : tf.io.FixedLenFeature([], tf.string)}
    parsed_example = tf.io.parse_example(serialized_example,features)
    image = tf.decode_raw(parsed_example['img_raw'],tf.uint8)
    image = tf.reshape(image,[-1,224,224,3])
    image = tf.cast(image,tf.float32)*(1./255)
    label = tf.cast(parsed_example['label'], tf.int32)
    label = tf.reshape(label,[-1,1])
    return image,label
 
def dataset_tfrecords(tfrecords_path,use_keras_fit=True): 
    #Whether Keras fit() will be used; if so, the data is not repeated here
    if use_keras_fit:
        epochs_data = 1
    else:
        epochs_data = epochs
    dataset = tf.data.TFRecordDataset([tfrecords_path])#The list may hold several files [tfrecords_name1,tfrecords_name2,...]; os.listdir(tfrecords_path) can be used to collect them
    dataset = dataset\
                .repeat(epochs_data)\
                .batch(batch_size)\
                .map(load_image,num_parallel_calls = 2)\
                .shuffle(1000)

    iter = dataset.make_one_shot_iterator()#or make_initializable_iterator
    train_datas = iter.get_next() #train_datas[0] is the image batch, train_datas[1] the label batch
    return train_datas,iter
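
Why load_image needs the reshape: img.tobytes() stores only the raw pixel bytes, with no shape or dtype, so the reader has to supply both. Below is a minimal NumPy sketch of the same round trip. Here np.frombuffer stands in for tf.decode_raw, and the random array is made-up data, not an image from the dataset.

```python
import numpy as np

# Stand-in for a resized 224x224 RGB image (made-up random pixels).
rng = np.random.default_rng(0)
image = rng.integers(0, 256, size=(224, 224, 3), dtype=np.uint8)

# Writing side: like PIL's Image.tobytes(), this keeps only the raw bytes.
img_raw = image.tobytes()
print(len(img_raw))  # number of bytes, 224 * 224 * 3

# Reading side: decode to a flat uint8 vector, then restore the shape.
flat = np.frombuffer(img_raw, dtype=np.uint8)
restored = flat.reshape(224, 224, 3)

# Scale to [0, 1] as load_image does.
scaled = restored.astype(np.float32) * (1.0 / 255)

print(flat.shape, restored.shape, scaled.dtype)
```

If the images had been resized to anything other than 224x224x3 when the TFRecords file was written, the reshape would fail, which is why the dim in create_tfrecords.py and the shape in load_image have to agree.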

3. Build the GAN

a. Build the generator network

    generator = keras.models.Sequential([
            #fullyconnected nets
            keras.layers.Dense(256,activation='selu',input_shape=[coding_size]),
            keras.layers.Dense(64,activation='selu'),
            keras.layers.Dense(256,activation='selu'),
            keras.layers.Dense(1024,activation='selu'),
            keras.layers.Dense(7*7*64,activation='selu'),
            keras.layers.Reshape([7,7,64]),
            #7*7*64
            #deconvolution 
            keras.layers.Conv2DTranspose(64,kernel_size=3,strides=2,padding='same',activation='selu'),
            #14*14*64
            keras.layers.Conv2DTranspose(64,kernel_size=3,strides=2,padding='same',activation='selu'),
            #28*28*64
            keras.layers.Conv2DTranspose(32,kernel_size=3,strides=2,padding='same',activation='selu'),
            #56*56*32
            keras.layers.Conv2DTranspose(16,kernel_size=3,strides=2,padding='same',activation='selu'),
            #112*112*16
            keras.layers.Conv2DTranspose(3,kernel_size=3,strides=2,padding='same',activation='tanh'),#Use tanh instead of sigmoid
            #224*224*3
            keras.layers.Reshape([224,224,3])
            ])
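
The shape comments in the generator can be checked by hand. As a small sketch in plain Python (no TensorFlow needed), the function below applies the Conv2DTranspose size rules quoted in the complete script: with padding='same' the spatial size is multiplied by the stride. The helper names conv_transpose_out and trace_generator are hypothetical, and the 'valid' branch assumes kernel_size >= strides with output_padding left unset.

```python
def conv_transpose_out(size, kernel_size, strides, padding):
    """Spatial output size of a transposed convolution along one axis."""
    if padding == 'same':
        return size * strides
    if padding == 'valid':
        # assumes kernel_size >= strides and no output_padding
        return (size - 1) * strides + kernel_size
    raise ValueError('unknown padding: %r' % (padding,))


def trace_generator(start_size=7, start_channels=64,
                    filters=(64, 64, 32, 16, 3)):
    """Return the (height, width, channels) shape after each layer."""
    size = start_size
    shapes = [(size, size, start_channels)]
    for n_filters in filters:
        size = conv_transpose_out(size, kernel_size=3, strides=2,
                                  padding='same')
        shapes.append((size, size, n_filters))
    return shapes


if __name__ == '__main__':
    for shape in trace_generator():
        print(shape)
```

Five stride-2 layers multiply the 7x7 starting grid by 2 five times, which is how the network reaches 224x224; the final Reshape layer therefore does not change the shape.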

b. Build the discriminator network

    discriminator = keras.models.Sequential([
            keras.layers.Conv2D(128,kernel_size=3,padding='same',strides=2,activation='selu',input_shape=[224,224,3]),
            keras.layers.MaxPool2D(pool_size=2),
            #56*56*128
            keras.layers.Conv2D(64,kernel_size=3,padding='same',strides=2,activation='selu'),
            keras.layers.MaxPool2D(pool_size=2),
            #14*14*64
            keras.layers.Conv2D(32,kernel_size=3,padding='same',strides=2,activation='selu'),
            #7*7*32
            keras.layers.Flatten(),
            #dropout 0.4
            keras.layers.Dropout(0.4),
            keras.layers.Dense(512,activation='selu'),
            keras.layers.Dropout(0.4),
            keras.layers.Dense(64,activation='selu'),
            keras.layers.Dropout(0.4),
            #the last net
            keras.layers.Dense(1,activation='sigmoid')
            ])
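
The discriminator's shape comments skip the intermediate sizes, so here is a plain-Python sketch that traces every layer. It assumes the usual Keras rules: a convolution with padding='same' gives ceil(input / strides), and MaxPool2D with its default 'valid' padding and stride equal to pool_size gives floor((input - pool_size) / pool_size) + 1. The helper names conv_out, pool_out and trace_discriminator are hypothetical.

```python
import math


def conv_out(size, kernel_size, strides, padding='same'):
    """Spatial output size of a convolution along one axis."""
    if padding == 'same':
        return math.ceil(size / strides)
    return (size - kernel_size) // strides + 1  # 'valid'


def pool_out(size, pool_size):
    """Spatial output size of max pooling with stride == pool_size."""
    return (size - pool_size) // pool_size + 1


def trace_discriminator(size=224):
    """Return (layer name, spatial size, channels) for each layer."""
    trace = []
    size = conv_out(size, 3, 2); trace.append(('conv 128', size, 128))
    size = pool_out(size, 2);    trace.append(('pool', size, 128))
    size = conv_out(size, 3, 2); trace.append(('conv 64', size, 64))
    size = pool_out(size, 2);    trace.append(('pool', size, 64))
    size = conv_out(size, 3, 2); trace.append(('conv 32', size, 32))
    return trace


if __name__ == '__main__':
    for name, size, channels in trace_discriminator():
        print(name, (size, size, channels))
    _, last_size, last_channels = trace_discriminator()[-1]
    print('flattened features:', last_size * last_size * last_channels)
```

The spatial size halves at every layer (224, 112, 56, 28, 14, 7), so Flatten hands 7*7*32 features to the first Dense layer.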

c. Combine the generator and discriminator into the GAN

gan = keras.models.Sequential([generator,discriminator])

4. Compile (set the loss and optimizer)

    #compile the net
    discriminator.compile(loss="binary_crossentropy",optimizer='rmsprop')# metrics=['accuracy'])
    discriminator.trainable=False
    gan.compile(loss="binary_crossentropy",optimizer='rmsprop')# metrics=['accuracy'])
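
Both models use binary cross-entropy, but with different targets: the discriminator is trained with label 0 for generated images and 1 for real ones, while the GAN (generator phase) is trained with label 1 for generated images. The NumPy sketch below shows what that loss computes. The discriminator outputs are made-up numbers, and the clipping constant eps is an assumption for illustration rather than a statement about Keras internals.

```python
import numpy as np


def binary_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean binary cross-entropy between labels and predicted probabilities."""
    y_true = np.asarray(y_true, dtype=np.float64)
    # clip so that log(0) never occurs
    y_pred = np.clip(np.asarray(y_pred, dtype=np.float64), eps, 1.0 - eps)
    losses = -(y_true * np.log(y_pred) + (1.0 - y_true) * np.log(1.0 - y_pred))
    return float(np.mean(losses))


# Made-up discriminator outputs: probability that each image is real.
p_fake = [0.1, 0.2]   # outputs for two generated images
p_real = [0.9, 0.8]   # outputs for two real images

# Discriminator phase: fake images labelled 0, real images labelled 1.
d_loss = binary_crossentropy([0, 0, 1, 1], p_fake + p_real)

# Generator phase: the same fake images, but labelled 1.
g_loss = binary_crossentropy([1, 1], p_fake)

print('discriminator loss: %.3f' % d_loss)
print('generator loss: %.3f' % g_loss)
```

With these numbers the discriminator is doing well, so its loss is small while the generator's loss is large; the generator can only lower its loss by producing images the discriminator scores closer to 1.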

5. Train the network (training loop)

Get the dataset:

train_datas,iter = dataset_tfrecords(tfrecords_path,use_keras_fit=False)

Loop body (cv2 is used to preview the generator's output):

    generator,discriminator = gan.layers
    sess = tf.Session()
    for step in range(num_steps):
        #get the time
        start_time = time.time()
        #phase 1 - training the discriminator
        noise = np.random.normal(size=[batch_size,coding_size]).astype(np.float32)
        generated_images = generator.predict(noise)
        train_datas_ = sess.run(train_datas)
        x_fake_and_real = np.concatenate([generated_images,train_datas_[0]],axis = 0)
        #Use np.concatenate here, not tf.concat: calling tf functions inside the loop adds new ops to the graph on every step,
        #so memory is eventually exhausted and training gets slower and slower
        y1 = np.array([[0.]]*batch_size+[[1.]]*batch_size)
        discriminator.trainable = True
        dis_loss = discriminator.train_on_batch(x_fake_and_real,y1)
        #Keras's train_on_batch is a good fit for GAN training
        #phase 2 - training the generator
        noise = np.random.normal(size=[batch_size,coding_size]).astype(np.float32)
        y2 = np.array([[1.]]*batch_size)
        discriminator.trainable = False
        ad_loss = gan.train_on_batch(noise,y2)
        duration = time.time()-start_time
        if step % 5 == 0:
            #gan.save_weights('gan.h5')
            print("The step is %d,discriminator loss:%.3f,adversarial loss:%.3f"%(step,dis_loss,ad_loss),end=' ')
            print('%.2f s/step'%(duration))
        if step % 30 == 0 and step != 0:
            noise = np.random.normal(size=[1,coding_size]).astype(np.float32)
            fake_image = generator.predict(noise,steps=1)
            #Restore the image for display
            #1. After multiplying by 255, the values must be cast to uint8
            #2. Alternatively, keep float32 in [0,1], which can still be displayed directly
            arr_img = np.array([fake_image],np.float32).reshape([224,224,3])*255
            #tanh outputs can be negative, so clip to [0,255] before the cast
            arr_img = np.clip(arr_img,0,255).astype(np.uint8)
            #The tfrecords were written from images opened with PIL, i.e. in RGB order, so convert to BGR before displaying with cv2
            arr_img = cv2.cvtColor(arr_img,cv2.COLOR_RGB2BGR)
            cv2.imshow('fake image',arr_img)
            cv2.waitKey(1500)#show the fake image 1.5s
            cv2.destroyAllWindows()
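
Two steps of the loop are plain array manipulation and can be tried without TensorFlow: assembling the discriminator batch with its 0/1 labels, and turning a generated image into uint8 for display. The sketch below uses hypothetical helpers (make_discriminator_batch, to_uint8_image) and tiny made-up 8x8 images. One point worth noting: the generator ends in tanh, whose outputs lie in [-1, 1], whereas the loop multiplies by 255 as if they lay in [0, 1]; the second helper shows one possible rescaling from [-1, 1], which is an assumption and not what the article's code does.

```python
import numpy as np


def make_discriminator_batch(fake_images, real_images):
    """Stack fake then real images; label fake as 0 and real as 1."""
    x = np.concatenate([fake_images, real_images], axis=0)
    y = np.concatenate([np.zeros((len(fake_images), 1), np.float32),
                        np.ones((len(real_images), 1), np.float32)], axis=0)
    return x, y


def to_uint8_image(image, low=-1.0, high=1.0):
    """Linearly map values from [low, high] to [0, 255] and cast to uint8."""
    scaled = (np.asarray(image, np.float32) - low) / (high - low) * 255.0
    return np.clip(np.rint(scaled), 0, 255).astype(np.uint8)


if __name__ == '__main__':
    rng = np.random.default_rng(0)
    fake = np.tanh(rng.normal(size=(4, 8, 8, 3))).astype(np.float32)
    real = rng.random(size=(2, 8, 8, 3)).astype(np.float32)

    x, y = make_discriminator_batch(fake, real)
    print(x.shape, y.ravel())

    img = to_uint8_image(fake[0])
    print(img.dtype, img.min(), img.max())
```

Because the labels are built from the actual array lengths rather than from batch_size, this version also copes with a final batch that is shorter than the others.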

6. Save the network

    #save the model
    model_version = '0001'
    model_name = 'gans'
    model_path = os.path.join(model_name,model_version)
    tf.saved_model.save(gan,model_path)

7. Complete gans.py (runnable)

# -*- coding: utf-8 -*-
'''
    @author:zyl
    author is zouyuelin
    a Master of Tianjin University(TJU)
'''

import tensorflow as tf
from tensorflow import keras
#tf.enable_eager_execution()
import numpy as np
from PIL import Image
import os
import cv2
import time

batch_size = 32
epochs = 120
num_steps = 2000
coding_size = 30
tfrecords_path = 'data/train.tfrecords'

#--------------------------------------datasetTfrecord----------------   
def load_image(serialized_example):   
    features={
        'label': tf.io.FixedLenFeature([], tf.int64),
        'img_raw' : tf.io.FixedLenFeature([], tf.string)}
    parsed_example = tf.io.parse_example(serialized_example,features)
    image = tf.decode_raw(parsed_example['img_raw'],tf.uint8)
    image = tf.reshape(image,[-1,224,224,3])
    image = tf.cast(image,tf.float32)*(1./255)
    label = tf.cast(parsed_example['label'], tf.int32)
    label = tf.reshape(label,[-1,1])
    return image,label
 
def dataset_tfrecords(tfrecords_path,use_keras_fit=True): 
    #Whether Keras fit() will be used; if so, the data is not repeated here
    if use_keras_fit:
        epochs_data = 1
    else:
        epochs_data = epochs
    dataset = tf.data.TFRecordDataset([tfrecords_path])#The list may hold several files [tfrecords_name1,tfrecords_name2,...]; os.listdir(tfrecords_path) can be used to collect them
    dataset = dataset\
                .repeat(epochs_data)\
                .batch(batch_size)\
                .map(load_image,num_parallel_calls = 2)\
                .shuffle(1000)

    iter = dataset.make_one_shot_iterator()#or make_initializable_iterator
    train_datas = iter.get_next() #train_datas[0] is the image batch, train_datas[1] the label batch
    return train_datas,iter

#------------------------------------tf.TFRecordReader-----------------
def read_and_decode(tfrecords_path):
    #Generate a queue based on the file name
    filename_queue = tf.train.string_input_producer([tfrecords_path],shuffle=True) 
    reader = tf.TFRecordReader()
    _,  serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(serialized_example,features={
        'label': tf.FixedLenFeature([], tf.int64),
        'img_raw' : tf.FixedLenFeature([], tf.string)})

    image = tf.decode_raw(features['img_raw'], tf.uint8)
    image = tf.reshape(image,[224,224,3])#reshape to 224*224*3
    image = tf.cast(image,tf.float32)*(1./255)#scale the pixel values to [0,1]
    label = tf.cast(features['label'], tf.int32)
    img_batch, label_batch = tf.train.shuffle_batch([image,label],
                    batch_size=batch_size,
                    num_threads=4,
                    capacity= 640,
                    min_after_dequeue=5)
    return [img_batch,label_batch]

#Autoencoder (encoder followed by decoder)
def autoencode():
        encoder = keras.models.Sequential([
            keras.layers.Conv2D(32,kernel_size=3,padding='same',strides=2,activation='selu',input_shape=[224,224,3]),
            #112*112*32
            keras.layers.MaxPool2D(pool_size=2),
            #56*56*32
            keras.layers.Conv2D(64,kernel_size=3,padding='same',strides=2,activation='selu'),
            #28*28*64
            keras.layers.MaxPool2D(pool_size=2),
            #14*14*64
            keras.layers.Conv2D(128,kernel_size=3,padding='same',strides=2,activation='selu'),
            #7*7*128
            #deconvolution 
            keras.layers.Conv2DTranspose(128,kernel_size=3,strides=2,padding='same',activation='selu'),
            #14*14*128
            keras.layers.Conv2DTranspose(64,kernel_size=3,strides=2,padding='same',activation='selu'),
            #28*28*64
            keras.layers.Conv2DTranspose(32,kernel_size=3,strides=2,padding='same',activation='selu'),
            #56*56*32
            keras.layers.Conv2DTranspose(16,kernel_size=3,strides=2,padding='same',activation='selu'),
            #112*112*16
            keras.layers.Conv2DTranspose(3,kernel_size=3,strides=2,padding='same',activation='tanh'),#Use tanh instead of sigmoid
            #224*224*3
            keras.layers.Reshape([224,224,3])
            ])
        return encoder

def training_keras():
    '''
        Convolution and pooling output size (padding = pixels added on each side):
            output_size = floor((input_size-kernel_size+2*padding)/strides)+1
            
        Keras transposed-convolution output size, with output_padding left unset
        1.if padding = 'valid' (and kernel_size >= strides):
            output_size = (input_size - 1)*strides + kernel_size
        2.if padding = 'same':
            output_size = input_size * strides
    '''
    generator = keras.models.Sequential([
            #fullyconnected nets
            keras.layers.Dense(256,activation='selu',input_shape=[coding_size]),
            keras.layers.Dense(64,activation='selu'),
            keras.layers.Dense(256,activation='selu'),
            keras.layers.Dense(1024,activation='selu'),
            keras.layers.Dense(7*7*64,activation='selu'),
            keras.layers.Reshape([7,7,64]),
            #7*7*64
            #deconvolution 
            keras.layers.Conv2DTranspose(64,kernel_size=3,strides=2,padding='same',activation='selu'),
            #14*14*64
            keras.layers.Conv2DTranspose(64,kernel_size=3,strides=2,padding='same',activation='selu'),
            #28*28*64
            keras.layers.Conv2DTranspose(32,kernel_size=3,strides=2,padding='same',activation='selu'),
            #56*56*32
            keras.layers.Conv2DTranspose(16,kernel_size=3,strides=2,padding='same',activation='selu'),
            #112*112*16
            keras.layers.Conv2DTranspose(3,kernel_size=3,strides=2,padding='same',activation='tanh'),#Use tanh instead of sigmoid
            #224*224*3
            keras.layers.Reshape([224,224,3])
            ])
            
    discriminator = keras.models.Sequential([
            keras.layers.Conv2D(128,kernel_size=3,padding='same',strides=2,activation='selu',input_shape=[224,224,3]),
            keras.layers.MaxPool2D(pool_size=2),
            #56*56*128
            keras.layers.Conv2D(64,kernel_size=3,padding='same',strides=2,activation='selu'),
            keras.layers.MaxPool2D(pool_size=2),
            #14*14*64
            keras.layers.Conv2D(32,kernel_size=3,padding='same',strides=2,activation='selu'),
            #7*7*32
            keras.layers.Flatten(),
            #dropout 0.4
            keras.layers.Dropout(0.4),
            keras.layers.Dense(512,activation='selu'),
            keras.layers.Dropout(0.4),
            keras.layers.Dense(64,activation='selu'),
            keras.layers.Dropout(0.4),
            #the last net
            keras.layers.Dense(1,activation='sigmoid')
            ])
    #gans network        
    gan = keras.models.Sequential([generator,discriminator])
    
    #compile the net
    discriminator.compile(loss="binary_crossentropy",optimizer='rmsprop')# metrics=['accuracy'])
    discriminator.trainable=False
    gan.compile(loss="binary_crossentropy",optimizer='rmsprop')# metrics=['accuracy'])
    
    #dataset
    #train_datas = read_and_decode(tfrecords_path)
    train_datas,iter = dataset_tfrecords(tfrecords_path,use_keras_fit=False)
    
    #sess = tf.Session()
    #sess.run(iter.initializer)
    
    generator,discriminator = gan.layers
    print("-----------------start---------------")
    sess = tf.Session()
    for step in range(num_steps):
        #get the time
        start_time = time.time()
        #phase 1 - training the discriminator
        noise = np.random.normal(size=[batch_size,coding_size]).astype(np.float32)
        generated_images = generator.predict(noise)
        train_datas_ = sess.run(train_datas)
        x_fake_and_real = np.concatenate([generated_images,train_datas_[0]],axis = 0)
        #Use np.concatenate here, not tf.concat: calling tf functions inside the loop adds new ops to the graph on every step,
        #so memory is eventually exhausted and training gets slower and slower
        y1 = np.array([[0.]]*batch_size+[[1.]]*batch_size)
        discriminator.trainable = True
        dis_loss = discriminator.train_on_batch(x_fake_and_real,y1)
        #Keras's train_on_batch is a good fit for GAN training
        #phase 2 - training the generator
        noise = np.random.normal(size=[batch_size,coding_size]).astype(np.float32)
        y2 = np.array([[1.]]*batch_size)
        discriminator.trainable = False
        ad_loss = gan.train_on_batch(noise,y2)
        duration = time.time()-start_time
        if step % 5 == 0:
            #gan.save_weights('gan.h5')
            print("The step is %d,discriminator loss:%.3f,adversarial loss:%.3f"%(step,dis_loss,ad_loss),end=' ')
            print('%.2f s/step'%(duration))
        if step % 30 == 0 and step != 0:
            noise = np.random.normal(size=[1,coding_size]).astype(np.float32)
            fake_image = generator.predict(noise,steps=1)
            #Restore the image for display
            #1. After multiplying by 255, the values must be cast to uint8
            #2. Alternatively, keep float32 in [0,1], which can still be displayed directly
            arr_img = np.array([fake_image],np.float32).reshape([224,224,3])*255
            #tanh outputs can be negative, so clip to [0,255] before the cast
            arr_img = np.clip(arr_img,0,255).astype(np.uint8)
            #The tfrecords were written from images opened with PIL, i.e. in RGB order, so convert to BGR before displaying with cv2
            arr_img = cv2.cvtColor(arr_img,cv2.COLOR_RGB2BGR)
            cv2.imshow('fake image',arr_img)
            cv2.waitKey(1500)#show the fake image 1.5s
            cv2.destroyAllWindows()
            
    #save the model
    model_version = '0001'
    model_name = 'gans'
    model_path = os.path.join(model_name,model_version)
    tf.saved_model.save(gan,model_path)
    
def main():
    training_keras()
if __name__ == '__main__':
    main()

This completes a simple GAN training run.

References

Paper: Generative Adversarial Networks
Reference source code:
https://github.com/eriklindernoren/Keras-GAN/blob/master/gan/gan.py
Reference blog:
https://blog.csdn.net/u010138055/article/details/94441812

Closing remarks

I am a beginner in deep learning and machine learning and have only just started, so please point out any shortcomings.

Tags: Python TensorFlow Deep Learning

Posted by harman on Thu, 05 May 2022 17:43:27 +0300