六. tf.data 输入模块(tf.data.Dataset)

一. from_tensor_slices

import tensorflow as tf
def _show_elements(ds, as_numpy=False):
    """Print every element of a dataset, one per line.

    Args:
        ds: an iterable dataset (a tf.data.Dataset in this section).
        as_numpy: when True, print ``element.numpy()`` instead of the
            tensor/structure repr.
    """
    for element in ds:
        print(element.numpy() if as_numpy else element)


if __name__ == '__main__':
    # from_tensor_slices slices the given tuple / list / tensor along its
    # first dimension; every slice becomes one dataset element (a tensor).
    dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4, 5])
    _show_elements(dataset)
    # Output:
    # tf.Tensor(1, shape=(), dtype=int32)
    # tf.Tensor(2, shape=(), dtype=int32)
    # tf.Tensor(3, shape=(), dtype=int32)
    # tf.Tensor(4, shape=(), dtype=int32)
    # tf.Tensor(5, shape=(), dtype=int32)

    # A nested list is sliced along its first dimension only, so each
    # element is a length-2 vector.
    dataset = tf.data.Dataset.from_tensor_slices(
        [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]
    )
    _show_elements(dataset)
    # Output:
    # tf.Tensor([1 2], shape=(2,), dtype=int32)
    # tf.Tensor([3 4], shape=(2,), dtype=int32)
    # tf.Tensor([5 6], shape=(2,), dtype=int32)
    # tf.Tensor([7 8], shape=(2,), dtype=int32)
    # tf.Tensor([ 9 10], shape=(2,), dtype=int32)

    # Same dataset, printed as plain NumPy arrays via .numpy().
    _show_elements(dataset, as_numpy=True)
    # Output:
    # [1 2]
    # [3 4]
    # [5 6]
    # [7 8]
    # [ 9 10]

    # A dict is sliced per key: element i is {key: value_list[i], ...}.
    columns = {
        'a': [1, 2, 3, 4],
        'b': [5, 6, 7, 8],
        'c': [9, 10, 11, 12],
    }
    dataset_dic = tf.data.Dataset.from_tensor_slices(columns)
    print(dataset_dic)
    # Output (repr format depends on the TensorFlow version):
    # <TensorSliceDataset shapes: {a: (), b: (), c: ()}, types: {a: tf.int32, b: tf.int32, c: tf.int32}>
    _show_elements(dataset_dic)
    # Output:
    # {'a': <tf.Tensor: shape=(), dtype=int32, numpy=1>, 'b': <tf.Tensor: shape=(), dtype=int32, numpy=5>, 'c': <tf.Tensor: shape=(), dtype=int32, numpy=9>}
    # {'a': <tf.Tensor: shape=(), dtype=int32, numpy=2>, 'b': <tf.Tensor: shape=(), dtype=int32, numpy=6>, 'c': <tf.Tensor: shape=(), dtype=int32, numpy=10>}
    # {'a': <tf.Tensor: shape=(), dtype=int32, numpy=3>, 'b': <tf.Tensor: shape=(), dtype=int32, numpy=7>, 'c': <tf.Tensor: shape=(), dtype=int32, numpy=11>}
    # {'a': <tf.Tensor: shape=(), dtype=int32, numpy=4>, 'b': <tf.Tensor: shape=(), dtype=int32, numpy=8>, 'c': <tf.Tensor: shape=(), dtype=int32, numpy=12>}

二. shuffle、repeat、batch 与 map(tf.square)

    # shuffle(buffer_size): randomly reorders elements by sampling from a
    # buffer of buffer_size elements. buffer_size=7 equals the dataset size,
    # so the buffer holds every element and the result is a full permutation.
    dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6, 7])
    dataset = dataset.shuffle(7)
    for ele in dataset:
        print(ele.numpy())
    # Sample output (random, differs between runs): 5 3 7 2 4 1 6

    # repeat(3): iterate the shuffled dataset 3 times back to back
    # (equivalent one-liner: dataset.shuffle(7).repeat(3)). shuffle()
    # reshuffles on each pass by default (reshuffle_each_iteration=True),
    # which is why every pass below has a different order.
    dataset = dataset.repeat(3)
    for ele in dataset:
        print(ele.numpy())
    # Sample output (21 values; one pass per line here for readability):
    # 2 4 6 7 5 1 3
    # 2 4 3 7 1 5 6
    # 3 5 4 2 6 7 1

    # batch(3): group consecutive elements into batches of 3. The repeated
    # dataset has 21 elements, giving 7 batches. Because batching happens
    # after repeat(), a batch can straddle two passes and contain the same
    # value twice (e.g. [3 3 7]).
    dataset = dataset.batch(3)
    for ele in dataset:
        print(ele.numpy())
    # Sample output:
    # [1 5 2]
    # [6 7 4]
    # [3 3 7]
    # [6 4 5]
    # [2 1 6]
    # [7 4 2]
    # [3 1 5]

    # map(fn): apply fn to every element; tf.square squares each value.
    dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6, 7])
    dataset = dataset.map(tf.square)
    for ele in dataset:
        print(ele.numpy())
    # Output:
    # 1
    # 4
    # 9
    # 16
    # 25
    # 36
    # 49

三. 实例:用 tf.data 构建 MNIST 输入管道并训练分类模型

import tensorflow as tf                                                                      
                                                                                             
                                                                                             
if __name__ == '__main__':
    BATCH_SIZE = 64

    # MNIST: 60000 training and 10000 test images of size 28x28.
    # load_data is a function and must be *called*; without the parentheses
    # the tuple unpacking below fails with a TypeError.
    (train_images, train_labels), (test_images, test_labels) = (
        tf.keras.datasets.mnist.load_data()
    )
    # Rescale pixel intensities by 1/255.
    train_images = train_images / 255  # (60000, 28, 28)
    test_images = test_images / 255    # (10000, 28, 28)

    # Build one dataset per array, then zip them into (image, label) pairs.
    ds_train_img = tf.data.Dataset.from_tensor_slices(train_images)
    ds_train_labels = tf.data.Dataset.from_tensor_slices(train_labels)
    ds_test_img = tf.data.Dataset.from_tensor_slices(test_images)
    ds_test_labels = tf.data.Dataset.from_tensor_slices(test_labels)

    ds_train = tf.data.Dataset.zip((ds_train_img, ds_train_labels))
    # <ZipDataset shapes: ((28, 28), ()), types: (tf.float64, tf.uint8)>
    # where () is the shape of a single scalar label.
    ds_test = tf.data.Dataset.zip((ds_test_img, ds_test_labels))

    # Training pipeline: shuffle all 60000 samples (buffer == dataset size),
    # repeat indefinitely, and batch. Because the dataset is infinite,
    # model.fit needs steps_per_epoch to know where an epoch ends.
    ds_train = (
        ds_train.shuffle(train_images.shape[0])
        .repeat()
        .batch(BATCH_SIZE)
    )
    # Test pipeline is finite (no repeat), so validation stops on its own
    # after one full pass -- including the final partial batch.
    ds_test = ds_test.batch(BATCH_SIZE)

    model = tf.keras.Sequential([
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dense(10, activation='softmax'),
    ])
    model.compile(
        optimizer='adam',
        loss='sparse_categorical_crossentropy',  # labels are integer class ids
        metrics=['accuracy'],
    )

    # Number of full batches per epoch (floor division): 60000 // 64 = 937.
    steps_per_epoch = train_images.shape[0] // BATCH_SIZE
    history = model.fit(
        ds_train,
        epochs=5,
        steps_per_epoch=steps_per_epoch,
        validation_data=ds_test,  # evaluate on the whole test set each epoch
    )
    # history.history maps each metric name to its per-epoch values.
    print(history.history)
Epoch 1/5
937/937 [==============================] - 1s 1ms/step - loss: 0.2947 - accuracy: 0.9177 - val_loss: 0.1588 - val_accuracy: 0.9527
Epoch 2/5
937/937 [==============================] - 1s 1ms/step - loss: 0.1355 - accuracy: 0.9607 - val_loss: 0.1145 - val_accuracy: 0.9649
Epoch 3/5
937/937 [==============================] - 1s 1ms/step - loss: 0.0951 - accuracy: 0.9723 - val_loss: 0.0944 - val_accuracy: 0.9719
Epoch 4/5
937/937 [==============================] - 1s 1ms/step - loss: 0.0733 - accuracy: 0.9782 - val_loss: 0.0857 - val_accuracy: 0.9738
Epoch 5/5
937/937 [==============================] - 1s 1ms/step - loss: 0.0584 - accuracy: 0.9825 - val_loss: 0.0804 - val_accuracy: 0.9743

版权声明:本文为weixin_40519901原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。