博客
关于我
人工智能深度学习入门练习之(20)TensorFlow2教程-Keras函数式API
阅读量:578 次
发布时间:2019-03-11

本文共 5904 字,大约阅读时间需要 19 分钟。

TensorFlow2教程-Keras函数式API

函数式API是一种创建深度学习模型的灵活方式,相比Sequential方法,它支持更复杂的模型拓扑结构,包括非线性连接、共享层以及多输入多输出。在TensorFlow 2.x中,函数式API通过Keras层来构建模型,提供了更高的灵活性和可定制性。


1. 构建简单的网络

1.1 创建网络

在函数式API中,模型通常从InputLayer开始,然后通过多个Layer连接到最终的OutputLayer。以下是一个简单的MNIST分类网络示例:

import tensorflow as tffrom tensorflow.keras import layersimport tensorflow.keras.backend as Kinputs = tf.keras.Input(shape=(784,))  # 输入层,形状为(batch_size, 784)h1 = layers.Dense(32, activation='relu')(inputs)  # 第一层全连接层h2 = layers.Dense(32, activation='relu')(h1)  # 第二层全连接层outputs = layers.Dense(10, activation='softmax')(h2)  # 输出层model = tf.keras.Model(inputs=inputs, outputs=outputs, name='mnist_model')model.summary()  # 打印模型摘要

模型摘要显示,模型包含三个全连接层,输入形状为(784,),输出为10个类别。


1.2 训练与验证

模型训练与Sequential模型类似,使用fit方法训练,evaluate方法验证。以下是MNIST数据集的加载与训练代码:

from tensorflow.keras.datasets import mnistimport numpy as np(x_train, y_train), (x_test, y_test) = mnist.load_data()x_train = x_train.reshape(60000, 784).astype('float32') / 255x_test = x_test.reshape(10000, 784).astype('float32') / 255model.compile(optimizer=K.optimizers.RMSprop(), loss='sparse_categorical_crossentropy', metrics=['accuracy'])history = model.fit(x_train, y_train, batch_size=64, epochs=5, validation_split=0.2)test_scores = model.evaluate(x_test, y_test, verbose=0)print('Test Loss: ', test_scores[0])print('Test Accuracy: ', test_scores[1])

训练过程显示,模型在5个epoch内达到了较高的验证准确率。


2. 使用共享网络创建多个模型

函数式API允许通过共享层创建多个模型。例如,以下是一个自编码器网络的示例,通过共享编码器和解码器层来实现图像恢复。

from tensorflow.keras import layers# 编码器部分encode_input = layers.Input(shape=(28, 28, 1))h1 = layers.Conv2D(16, 3, activation='relu')(encode_input)h1 = layers.Conv2D(32, 3, activation='relu')(h1)h1 = layers.MaxPool2D(3)(h1)h1 = layers.Conv2D(32, 3, activation='relu')(h1)h1 = layers.Conv2D(16, 3, activation='relu')(h1)encode_output = layers.GlobalMaxPool2D()(h1)encode_model = layers.Model(inputs=encode_input, outputs=encode_output, name='encoder')# 解码器部分decode_input = layers.Input(shape=(16,))h2 = layers.Reshape((4, 4, 1))(decode_input)h2 = layers.Conv2DTranspose(16, 3, activation='relu')(h2)h2 = layers.Conv2DTranspose(32, 3, activation='relu')(h2)h2 = layers.UpSampling2D(3)(h2)h2 = layers.Conv2DTranspose(16, 3, activation='relu')(h2)decode_output = layers.Conv2DTranspose(1, 3, activation='relu')(h2)decode_model = layers.Model(inputs=decode_input, outputs=decode_output, name='decoder')# 自编码器模型autoencoder_input = layers.Input(shape=(28, 28, 1))h3 = encode_model(autoencoder_input)autoencoder_output = decode_model(h3)autoencoder = layers.Model(inputs=autoencoder_input, outputs=autoencoder_output, name='autoencoder')autoencoder.summary()

3. 复杂网络结构构建

3.1 多输入与多输出网络

函数式API可以处理多输入多输出的模型。例如,以下是一个定制票排序模型,接受三个输入并输出优先级和部门:

import numpy as npfrom tensorflow.keras import layersimport tensorflow as tfnum_words = 2000num_tags = 12num_departments = 4# 输入层body_input = layers.Input(shape=(None,))title_input = layers.Input(shape=(None,))tag_input = layers.Input(shape=(num_tags,))# 嵌入层body_feat = layers.Embedding(num_words, 64)(body_input)title_feat = layers.Embedding(num_words, 64)(title_input)# 特征提取层body_feat = layers.LSTM(32)(body_feat)title_feat = layers.LSTM(128)(title_feat)features = layers.Concatenate()([title_feat, body_feat, tag_input])# 分类层priority_pred = layers.Dense(1, activation='sigmoid')(features)department_pred = layers.Dense(num_departments, activation='softmax')(features)model = layers.Model(inputs=[body_input, title_input, tag_input], outputs=[priority_pred, department_pred])model.summary()

3.2 小型残差网络

函数式API还可以构建残差网络。以下是一个简单的ResNet模型示例:

from tensorflow.keras import layersinputs = layers.Input(shape=(32, 32, 3))h1 = layers.Conv2D(32, 3, activation='relu')(inputs)h1 = layers.Conv2D(64, 3, activation='relu')(h1)block1_out = layers.MaxPooling2D(3)(h1)h2 = layers.Conv2D(64, 3, activation='relu', padding='same')(block1_out)h2 = layers.Conv2D(64, 3, activation='relu', padding='same')(h2)block2_out = layers.add([h2, block1_out])  # 残差连接h3 = layers.Conv2D(64, 3, activation='relu', padding='same')(block2_out)h3 = layers.Conv2D(64, 3, activation='relu', padding='same')(h3)block3_out = layers.add([h3, block2_out])h4 = layers.Conv2D(64, 3, activation='relu')(block3_out)h4 = layers.GlobalMaxPool2D()(h4)h4 = layers.Dense(256, activation='relu')(h4)h4 = layers.Dropout(0.5)(h4)outputs = layers.Dense(10, activation='softmax')(h4)model = layers.Model(inputs=inputs, outputs=outputs)model.summary()

4. 共享网络层

函数式API的另一个优点是可以共享层。例如,以下是一个共享嵌入层的示例:

from tensorflow.keras import layers# 定义共享嵌入层share_embedding = layers.Embedding(1000, 64)# 创建两个输入input1 = layers.Input(shape=(None,))input2 = layers.Input(shape=(None,))# 通过共享嵌入层进行嵌入feat1 = share_embedding(input1)feat2 = share_embedding(input2)

5. 模型复用

函数式API支持模型复用。例如,可以通过函数式模型访问中间层输出并进行特征提取:

from tensorflow.keras.applications import VGG16vgg16 = VGG16()feature_list = [layer.output for layer in vgg16.layers]# 创建特征提取模型feat_ext_model = layers.Model(inputs=vgg16.input, outputs=feature_list)# 使用模型提取特征img = np.random.random((1, 224, 224, 3)).astype('float32')ext_features = feat_ext_model(img)

6. 自定义网络层

TF Keras允许用户定义自定义网络层。以下是一个简单的全连接层实现:

class MyDense(layers.Layer):    def __init__(self, units=32):        super(MyDense, self).__init__()        self.units = units    def build(self, input_shape):        self.w = self.add_weight(shape=(input_shape[-1], self.units),                               initializer='random_normal',                               trainable=True)        self.b = self.add_weight(shape=(self.units,),                               initializer='random_normal',                               trainable=True)    def call(self, inputs):        return tf.matmul(inputs, self.w) + self.b# 创建模型inputs = layers.Input(shape=(4,))outputs = MyDense(10)(inputs)model = layers.Model(inputs, outputs)model.summary()

何时使用函数式API

  • 简单模型:对于大多数深度学习模型,函数式API足够强大。
  • 灵活性需求:需要构建非线性拓扑结构或共享层时。
  • 可扩展性:需要对模型进行复用或序列化时。

与Sequential模型相比,函数式API提供了更高级别的构建方式,同时支持更多功能,如模型复用和动态架构。

转载地址:http://ngivz.baihongyu.com/

你可能感兴趣的文章
Objective-C实现newton_raphson牛顿拉夫森算法(附完整源码)
查看>>
Objective-C实现NLP中文分词(附完整源码)
查看>>
Objective-C实现NLP中文分词(附完整源码)
查看>>
Objective-C实现not gate非门算法(附完整源码)
查看>>
Objective-C实现number of digits解字符数算法(附完整源码)
查看>>
Objective-C实现NumberOfIslands岛屿的个数算法(附完整源码)
查看>>
Objective-C实现n皇后问题算法(附完整源码)
查看>>
Objective-C实现OCR文字识别(附完整源码)
查看>>
Objective-C实现odd even sort奇偶排序算法(附完整源码)
查看>>
Objective-C实现page rank算法(附完整源码)
查看>>
Objective-C实现PageRank算法(附完整源码)
查看>>
Objective-C实现pascalTriangle帕斯卡三角形算法(附完整源码)
查看>>
Objective-C实现perfect cube完全立方数算法(附完整源码)
查看>>
Objective-C实现PNG图片格式转换BMP图片格式(附完整源码)
查看>>
Objective-C实现pollard rho大数分解算法(附完整源码)
查看>>
Objective-C实现quick select快速选择算法(附完整源码)
查看>>
Objective-C实现recursive bubble sor递归冒泡排序算法(附完整源码)
查看>>
Objective-C实现recursive insertion sort递归插入排序算法(附完整源码)
查看>>
Objective-C实现RedBlackTree红黑树算法(附完整源码)
查看>>
Objective-C实现redis分布式锁(附完整源码)
查看>>