TF-Records (short for TensorFlow Records) is one of the data formats that brings several benefits, notably the performance (time) of TensorFlow training. Using just TF-Records, I was able to get a direct 3x reduction in training time.

Most of the example blogs on TF-Records, however, use images as the example. I was working on a use case of training on tabular data on a Spark cluster. I am posting this blog to ease TF-Records adoption for scenarios such as mine. Hope this is useful.

Note: The intention of this post is to show how to use TF-Records for better performance. We could run the training distributed across the cluster; the TensorFlow API supports distributed training. We could also try Horovod together with the Petastorm file format to train on data in the Spark cluster. But those topics are beyond the scope of this post.

Some advantages of TF-Records are that it:

  • serialises the data before storing it, which means reduced storage space and faster data reads and copies
  • uses the Protocol Buffer format to store data, which makes reading the data faster
  • loads only the required data into memory, which is especially useful for large datasets
  • keeps the data at the C++ level instead of bringing it up to the Python level, which makes training faster
  • lets TensorFlow move the data to the GPU while training is performed

But to save/retrieve records in the TF-Records format, we need to supply schema information, which makes the process of creating TF-Records different from CSV / Excel / SQL.

Note: CSV/Excel files can be read and written without even column names. SQL doesn't need data types for querying, though it might need column names if we want only specific columns. This is not the case for TF-Records: they need both the column name and its data type for both write and read operations.

Creating TF-Records

The first step is to write the data in the desired TF-Records format. The basic way to create TF-Records is the tf.io.TFRecordWriter API (tf.python_io.TFRecordWriter in TensorFlow 1.x). But there is a simpler way to create TF-Records on a Spark cluster.
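For reference, a minimal sketch of the manual approach for tabular data might look like this (the column names and values below are invented purely for illustration):

import tensorflow as tf

def row_to_example(row):
  # Wrap each column value in the appropriate tf.train.Feature type
  feature = {
      'age': tf.train.Feature(int64_list=tf.train.Int64List(value=[row['age']])),
      'income': tf.train.Feature(float_list=tf.train.FloatList(value=[row['income']])),
  }
  return tf.train.Example(features=tf.train.Features(feature=feature))

rows = [{'age': 31, 'income': 52000.0}, {'age': 45, 'income': 67000.0}]
with tf.io.TFRecordWriter('manual_records.tfrecord') as writer:
  for row in rows:
    writer.write(row_to_example(row).SerializeToString())

With the Spark connector described below, none of this per-column boilerplate is needed at write time.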

Library required

Creating TF-Records in a Spark cluster is easy, thanks to the Spark TensorFlow Connector (spark-tensorflow-connector_2.11-1.10.0.jar).

Note: We are skipping the part on how to install this library into the cluster. It is a JAR install on the cluster and is usually generic. Also note that this library is needed only to create TF-Records. Since we use TensorFlow (the tf.data API) to read the data, we might not have to bother installing this library if we are using TF-Records created elsewhere.

Writing the TF-Records

The API for creating TF-Records is neat and simple. Since the schema can be inferred from the dataframe itself, we need not provide it when writing the data. This is not the case when we use the traditional TFRecordWriter.

Suppose you have a Spark dataframe preprocessed_df. The easiest way to create the TF-Records is:

preprocessed_df.write.format('tfrecords').save(path_to_save)

We could also save the data to an actual Spark table as a backup.

Tip: We might have to read the data from TF-Records for non-TensorFlow purposes as well (like data analysis). But it is typically slow to read the data back (even using the spark.read API) compared to saving it as a Spark table (as parquet files). So I keep one copy of the data in a Spark table and one in the TF-Records format. I save to a separate Spark table that I use for data analysis, and I use the same table to infer the schema for retrieving the TF-Records. If you don't want to dump the data, all you need is the list of columns and their types to decode the data in the end.
preprocessed_df.write.format('parquet').mode('overwrite').saveAsTable(preprocessed_table_name)

Reading TF-Records

To read the TF-Records for use in TensorFlow, we can use the tf.data API.

As discussed, TF-Records come with a catch: to read/decode the files, we need to know the schema of the data. Since we have the Spark table stored, we infer the schema from the table we already saved.

Inferring Schema

To read the records we need the list of features and their types. We can use the backup Spark table's schema to get that information, so we just read the Spark table that we already stored and infer its schema.

preprocessed_df = spark.read.table(preprocessed_table_name)

We collect the schema dtypes as key/value pairs into a dictionary:

column_dtypes = {col:dtype for col,dtype in preprocessed_df.dtypes}
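For a hypothetical table with an id, an amount and a label column (names invented purely for illustration), the resulting dictionary would look something like:

column_dtypes = {'id': 'bigint', 'amount': 'double', 'label': 'int'}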

Create features from column dictionary

Now that we have the column dictionary, we have to create a features dictionary where we specify whether the data is a FixedLenFeature (typically for mandatory data) or a VarLenFeature (typically for optional variables), and its data type. I use FixedLenFeature as I don't have missing values.

import tensorflow as tf

# Since all our features are fixed length, we create a lambda helper to build them
_fixed_feature = lambda x: tf.io.FixedLenFeature([], x)

def create_features(dtype_dict):
  features = {}
  for col, dtype in dtype_dict.items():
    if dtype in ('int', 'bigint', 'integer'):
      features[col] = _fixed_feature(tf.int64)
    elif dtype in ('double', 'float', 'long'):
      features[col] = _fixed_feature(tf.float32)
    elif dtype == 'string':
      features[col] = _fixed_feature(tf.string)
  return features

We can now create the features using the convenience function above:

features = create_features(column_dtypes)
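For the hypothetical columns shown earlier, this would yield a dictionary roughly of the form {'id': tf.io.FixedLenFeature([], tf.int64), 'amount': tf.io.FixedLenFeature([], tf.float32), 'label': tf.io.FixedLenFeature([], tf.int64)}.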

Decoding

Having created the features, we can decode the TF-Records. To simplify the process we create a decode method.

Initially I used the tf.io.parse_single_example API to decode. It actually made training slower than training from the CSV file, since it decodes one record at a time.

Important: If you use the common tf.io.parse_single_example API, you might face a performance lag in decoding, because the data is decoded one record at a time. Using tf.io.parse_example instead parses an entire batch at once (refer to the documentation for more details).
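For contrast, a sketch of such a per-record decode (not necessarily the exact code used initially) could look like:

def decode_single(serial_record):
  # Parses ONE serialized record at a time, which is what makes this variant slow
  return tf.io.parse_single_example(serialized=serial_record, features=features)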

Note that while decoding the serialized data we must return a tuple of independent and dependent variables. But the stored data has no knowledge of which columns belong to which, so we have to handle that in our decode method.

output_cols = []  # Specify the list of column names that are dependent variables, e.g. ['label']
def decode(serial_data):
  # We use `parse_example` instead of `Example` for decoding in batches
  parsed_data = tf.io.parse_example(serialized=serial_data,features=features)
  # We separate out the dependent variables to return the appropriate tuple
  y = {col:parsed_data.pop(col) for col in output_cols}
  return parsed_data, tf.transpose(a=list(y.values()))

Read data using tf.data

If we created the data using Spark, the files are written into the folder that we specify in the save method. Spark creates multiple TF-Record files, but it also writes a _SUCCESS marker file alongside them. We have to ignore that file when providing the file list to the tf.data API to avoid a potential error. This can be done with a simple method as follows:

from pathlib import Path
def get_tf_records(folder_path):
  # Skip Spark's _SUCCESS marker file and return the TF-Record file paths as strings
  return [str(p) for p in Path(folder_path).iterdir() if p.name != '_SUCCESS']

Now, we can simply use this method to get all TF-Record files in the folder:

training_records = get_tf_records(path_to_save)

The next step is to feed the data into the tf.data API. The built-in method tf.data.Dataset.from_tensor_slices gives us a dataset of the file paths, which we then turn into a dataset of serialized records by interleaving tf.data.TFRecordDataset over the files.

train_dataset = tf.data.Dataset.from_tensor_slices(training_records)
train_dataset = train_dataset.interleave(tf.data.TFRecordDataset)

Note: tf.data is an excellent API. It has lots of features that can improve training performance. Parameters like num_parallel_calls, block_length and cycle_length (on interleave) were particularly useful. The shuffle method lets us specify a buffer size from which items are randomly sampled when forming batches. The details of tuning tf.data for performance deserve a post of their own.
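To make that note concrete, here is a rough sketch of what a more tuned version of the pipeline above could look like; the parameter values are illustrative and not the ones behind the numbers quoted in this post:

train_dataset = (tf.data.Dataset.from_tensor_slices(training_records)
  .interleave(tf.data.TFRecordDataset,
              cycle_length=4,       # how many files are read concurrently
              block_length=16,      # records pulled from each file per cycle
              num_parallel_calls=tf.data.experimental.AUTOTUNE)
  .shuffle(buffer_size=10000))      # buffer from which items are randomly sampled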

Training the model

To train the model we need to:

  • Get the dataset
  • Batch
  • Decode
  • Prefetch (not mandatory, just a performance tweak)

The tf.data API simplifies this pipeline through method chaining, so we can simply fit like:

... = model.fit(train_dataset.batch(batch_size).map(decode).prefetch(1),...)

Conclusion

Alongside TF-Records, once the tf.data API was tuned, we got a performance improvement of around 3.75-4x. Plain TF-Records alone gave a 3x performance improvement on tabular data.

There are blogs such as the one by Sebastian Wallkötter which claim a 7x improvement on an image dataset.

Based on my understanding, the reasons why I couldn't achieve that improvement are as follows:

  • The data (both CSV and TF-Records) was stored on an SSD, which by itself gives fast reads, so the impact of TF-Records being read faster became less pronounced
  • Training a batch of tabular data is typically quick compared to a batch of image data, since tabular data has fewer features; this makes our performance bottleneck CPU bound (the input pipeline, not the GPU, is the limiting factor)

We could try to distribute the training using one of the distribution strategies available in TensorFlow on Spark. Before distributing, we should consider whether the issue is CPU or GPU bound, i.e. whether the performance bottleneck is reading the data or training the model.

For instance, the GPU, which does the training, might be waiting for the CPU to fetch the training data for each batch if the training step finishes before the data for the next batch is ready. This makes the performance bottleneck CPU bound.

If the training step takes more time than retrieving the data, the CPU waits for the GPU to complete the training. This makes the problem GPU bound.
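One rough way to check which side we are on is to time the input pipeline by itself, for example with a sketch like the following (reusing train_dataset, decode and batch_size from above):

import time

start = time.perf_counter()
for _ in train_dataset.batch(batch_size).map(decode).prefetch(1):
  pass  # iterate the pipeline without doing any training
print('input pipeline alone took', time.perf_counter() - start, 'seconds')

If this time is close to a full training epoch, the job is input (CPU) bound; if it is much smaller, the GPU is the limiting factor.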

Knowing whether the issue is CPU or GPU bound tells us what further steps we can take to improve performance.

Kindly share your experience and perspectives on training with TF-Records and on Spark clusters.