3.3.1 TensorFlow for large datasets

    Since the MNIST dataset is small and quick to process, it does not raise the GPU out-of-memory (OOM) issue. However, real-world datasets are often much larger and can crash a program with a GPU OOM error.

    A general approach is to load the data onto the GPU in batches. Even though "model.fit" and "model.predict" accept a "batch_size" argument, TensorFlow tends to allocate more GPU memory than a single batch requires, which still carries the risk of GPU OOM. Hence, please add the following code to limit the GPU memory usage:

    1. At the beginning of main.py

    import tensorflow as tf
    # Enable memory growth so TensorFlow allocates GPU memory on demand
    # instead of reserving nearly all of it at startup.
    gpus = tf.config.experimental.list_physical_devices(device_type="GPU")
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
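
    To confirm the setting took effect, a quick check such as the sketch below can be run right after the loop; get_memory_growth is the read-side counterpart of set_memory_growth, and the print format is only illustrative.

    # Sketch: verify that memory growth is enabled on every visible GPU.
    for gpu in gpus:
        print(gpu.name, tf.config.experimental.get_memory_growth(gpu))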

    2. At the end of data preparation

    # after the preparation of: xTrain, yTrain, xTest, yTest
    # after setting up batch_size
    # Build the datasets under the CPU device scope so the full arrays stay
    # in host memory; only one batch at a time is copied to the GPU later.
    with tf.device("CPU"):
        trainset = tf.data.Dataset.from_tensor_slices((xTrain, yTrain)).shuffle(8192).batch(batch_size)
        testset = tf.data.Dataset.from_tensor_slices((xTest, yTest)).shuffle(8192).batch(batch_size)
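
    For reference, a minimal sketch of the preparation referred to in the comments above, using the MNIST example from earlier; the normalization step and the batch_size value of 128 are illustrative assumptions, not part of the original pipeline.

    import tensorflow as tf

    # Assumed setup for illustration: load MNIST and define batch_size
    # before building the CPU-resident datasets shown above.
    (xTrain, yTrain), (xTest, yTest) = tf.keras.datasets.mnist.load_data()
    xTrain = xTrain.astype("float32") / 255.0  # scale pixel values to [0, 1]
    xTest = xTest.astype("float32") / 255.0
    batch_size = 128  # illustrative value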

    3. Use a for loop to train and validate batch by batch

    Before:

    model.fit(xTrain, yTrain, epochs=1, batch_size=batch_size)
    model.evaluate(xTest, yTest, batch_size=batch_size)

    After:

    for xTrainBatch, yTrainBatch in trainset:
        model.fit(xTrainBatch, yTrainBatch, epochs=1, batch_size=batch_size)
    for xTestBatch, yTestBatch in testset:
        model.evaluate(xTestBatch, yTestBatch, batch_size=batch_size)
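
    Putting the three steps together, a multi-epoch version of the batch-wise loop might look like the sketch below, reusing trainset, testset, and batch_size from step 2; the small Dense classifier, the compile settings, and the epoch count are assumptions chosen only to make the example runnable end to end.

    import tensorflow as tf

    # Assumed model for illustration: a small classifier for the MNIST example.
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(28, 28)),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(128, activation="relu"),
        tf.keras.layers.Dense(10, activation="softmax"),
    ])
    model.compile(optimizer="adam",
                  loss="sparse_categorical_crossentropy",
                  metrics=["accuracy"])

    epochs = 5  # illustrative value
    for epoch in range(epochs):
        # Train on one batch at a time so only a single batch occupies GPU memory.
        for xTrainBatch, yTrainBatch in trainset:
            model.fit(xTrainBatch, yTrainBatch, epochs=1, batch_size=batch_size, verbose=0)

        # Validate batch by batch for the same reason.
        for xTestBatch, yTestBatch in testset:
            model.evaluate(xTestBatch, yTestBatch, batch_size=batch_size, verbose=0)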
