Using TF-Records on Spark Cluster
Creating and Reading TF Records on Spark Cluster
TF-Records (short for TensorFlow Records) is one of the data formats that offers several benefits, notably better training performance (time) in TensorFlow. Just by switching to TF-Records, I was able to cut the training time by about 3x.
Most of the example blogs on TF-Records, however, use images as their example. I was working on a use case of training on tabular data on a Spark cluster. I am posting this blog to ease TF-Records adoption for scenarios such as mine. Hope this is useful.
There is also Horovod, which uses the Petastorm file format to train on data in the Spark cluster, but those topics are beyond the scope of this post.
Some advantages of TF-Records are that it:
- serialises the data before storing it. This means a reduced storage requirement and faster reads and copies
- uses the Protocol Buffers format to store data, which makes reading the data faster
- loads only the required data into memory. This is especially useful for large datasets
- never brings the data up to the Python level; it is always handled at the C++ level, which makes training faster
- lets TensorFlow move the data to the GPU while training is performed
But to save/retrieve records in TF-Records format, we need schema information for the data, which makes the process of creating TF-Records different from CSV / Excel / SQL.
Library required
Creating TF-Records in a Spark cluster is easy, thanks to the Spark Tensorflow Connector (spark-tensorflow-connector_2.11-1.10.0.jar).
Since TensorFlow itself (the tf.data API) is used to read the data, we might not have to bother installing this library if we are using TF-Records created elsewhere.
The API for creating TF-Records is neat and simple. Since the schema can be inferred from the dataframe itself, we need not provide it to write the data. This might not be the case when we use the traditional TFRecordWriter.
Suppose you have a Spark dataframe preprocessed_df. The easiest way to create the TF-Records is:
preprocessed_df.write.format('tfrecords').save(path_to_save)
We could save the data in an actual spark table as a backup.
Reading TF-Records back in Spark (through the spark.read API) is less practical compared to reading data saved as a Spark table (as parquet files). So I keep one copy of the data in a Spark table and one in TF-Records format. I save to a separate Spark table that I use for analysis of the data, and I use the same table to infer the schema for retrieving the TF-Records. If you don't want to dump the data, all you need is the list of columns and their types to decode the data in the end.
preprocessed_df.write.format('parquet').mode('overwrite').saveAsTable(preprocessed_table_name)
Reading TF-Records
To read the TF-Records for use in TensorFlow, we can use the tf.data API.
As discussed, TF-Records have a catch: to read the data from TF-Record files, we need to know the schema of the data in order to decode the files. Since we have the Spark table stored, we infer the schema from that table.
preprocessed_df = spark.read.table(preprocessed_table_name)
We collect the schema dtypes as key/value pair into a dictionary
column_dtypes = {col:dtype for col,dtype in preprocessed_df.dtypes}
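For reference, a Spark DataFrame's dtypes attribute is a list of (column name, type string) pairs. The sketch below uses a hand-written list in place of a real DataFrame, purely to show the shape of the resulting dictionary; the column names are made up for illustration.

```python
# Hypothetical stand-in for preprocessed_df.dtypes; the column names are invented
example_dtypes = [
    ("customer_id", "bigint"),
    ("age", "int"),
    ("income", "double"),
    ("segment", "string"),
]

# Same comprehension as above: column name -> Spark type string
example_column_dtypes = {col: dtype for col, dtype in example_dtypes}
print(example_column_dtypes)
```

These type strings are what the feature-mapping step has to translate into TensorFlow dtypes.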
Create features from column dictionary
Now that we have the column dictionary, we have to create a features dictionary where we specify whether each column is a FixedLenFeature (typically for mandatory data) or a VarLenFeature (typically for optional variables), along with its data type. I use FixedLenFeature as I don't have missing values.
import tensorflow as tf

# Since all our features are fixed length, a small helper builds a scalar FixedLenFeature
_fixed_feature = lambda x: tf.io.FixedLenFeature([], x)

def create_features(dtype_dict):
    """Map each column's Spark type string to a TensorFlow feature spec."""
    features = {}
    for col, dtype in dtype_dict.items():
        # Spark integer types ('long' is a 64-bit integer) are read back as int64
        if dtype in ('int', 'bigint', 'integer', 'long'):
            features[col] = _fixed_feature(tf.int64)
        elif dtype in ('double', 'float'):
            features[col] = _fixed_feature(tf.float32)
        # One-element tuple; a bare ('string') would do a substring check instead
        elif dtype in ('string',):
            features[col] = _fixed_feature(tf.string)
    return features
We can now create the features using the above convenience function:
features = create_features(column_dtypes)
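If some columns can be missing, the feature spec needs a little more care. The sketch below is not part of my pipeline; it assumes TensorFlow 2.x and uses made-up column names to show two options: a FixedLenFeature with a default_value, and a VarLenFeature, which comes back as a sparse tensor. A FixedLenFeature without a default raises an error when its column is absent.

```python
import tensorflow as tf

# Hypothetical spec: "age" is mandatory, "income" falls back to -1.0, "tags" is optional
optional_spec = {
    "age": tf.io.FixedLenFeature([], tf.int64),
    "income": tf.io.FixedLenFeature([], tf.float32, default_value=-1.0),
    "tags": tf.io.VarLenFeature(tf.string),
}

# A record that only carries "age"
record = tf.train.Example(features=tf.train.Features(feature={
    "age": tf.train.Feature(int64_list=tf.train.Int64List(value=[40])),
}))

parsed_record = tf.io.parse_single_example(record.SerializeToString(), optional_spec)

print(parsed_record["age"].numpy())      # value stored in the record
print(parsed_record["income"].numpy())   # default, since the column is missing
print(parsed_record["tags"].values)      # empty values of a SparseTensor
```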
Having created the features we could decode the TF-Records. To simplify the process we create a decode method.
Initially I decoded with the single-record API (tf.io.Example in my first attempt; tf.io.parse_single_example works the same way). Training was then slower than training from a CSV file, because the data is decoded one record at a time. If you decode this way, you might face the same performance lag. Using tf.io.parse_example instead parses the data of an entire batch at once (refer to the documentation for more detail).
Note that while decoding the serialized data we must return a tuple of independent and dependent variables. But the data we stored has no knowledge of which columns belong to which group. Hence we have to handle that in our decode method.
output_cols = []  # Specify the column or list of columns that are dependent variables

def decode(serial_data):
    # We use `parse_example` instead of single-record parsing to decode a whole batch at once
    parsed_data = tf.io.parse_example(serialized=serial_data, features=features)
    # We separate out the dependent variables so that we can return the appropriate tuple
    y = {col: parsed_data.pop(col) for col in output_cols}
    # Transpose so that the targets have shape (batch_size, number of output columns)
    return parsed_data, tf.transpose(a=list(y.values()))
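To see what the decode step returns, here is a small in-memory round trip. It is a sketch assuming TensorFlow 2.x; the columns (age, income, label), the rows, and the helper names are invented for illustration, and the records are built by hand rather than written by Spark. It uses tf.stack with axis=1 to assemble the targets, which gives the same (batch, outputs) layout as the transpose above.

```python
import tensorflow as tf

# Hypothetical schema: two input columns and one target column
feature_spec = {
    "age": tf.io.FixedLenFeature([], tf.int64),
    "income": tf.io.FixedLenFeature([], tf.float32),
    "label": tf.io.FixedLenFeature([], tf.int64),
}
target_cols = ["label"]

def make_example(age, income, label):
    """Serialise one tabular row as a tf.train.Example byte string."""
    feature = {
        "age": tf.train.Feature(int64_list=tf.train.Int64List(value=[age])),
        "income": tf.train.Feature(float_list=tf.train.FloatList(value=[income])),
        "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature)).SerializeToString()

def decode_batch(serialized_batch):
    """Parse a whole batch at once and split it into (inputs, targets)."""
    parsed = tf.io.parse_example(serialized=serialized_batch, features=feature_spec)
    targets = [parsed.pop(col) for col in target_cols]
    return parsed, tf.stack(targets, axis=1)

rows = [(31, 52000.0, 1), (45, 61000.5, 0), (27, 39000.0, 1)]
serialized = tf.constant([make_example(*row) for row in rows])

inputs, targets = decode_batch(serialized)
print(sorted(inputs.keys()))   # input columns only
print(targets.shape)           # (batch size, number of target columns)
```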
Read data using tf.data
If we have created the data using Spark, the files are put in the folder that we specify in the save method. This creates multiple TF-Record files, but the folder also contains the _SUCCESS file. We have to ignore that file while providing data to the TF Data API to avoid a potential error. This can be done using a simple method as follows:
from pathlib import Path

def get_tf_records(folder_path):
    # Skip Spark's _SUCCESS marker file and return the remaining paths as strings
    return [str(p) for p in Path(folder_path).iterdir() if p.name != '_SUCCESS']
Now, we can simply use this method to get all TF-Record files in the folder
training_records = get_tf_records(path_to_save)
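Depending on the filesystem, the output folder may also hold hidden checksum files next to _SUCCESS. Hadoop-style tools treat names starting with "_" or "." as hidden, so a slightly broader filter can be sketched as below. The helper name list_record_files and the file names are made up for the demonstration, which runs against a temporary folder rather than real Spark output.

```python
import tempfile
from pathlib import Path

def list_record_files(folder_path):
    """Return data file paths, skipping names that start with '_' or '.'."""
    return sorted(
        str(p)
        for p in Path(folder_path).iterdir()
        if p.is_file() and not p.name.startswith(("_", "."))
    )

# Simulate an output folder with two data files and two hidden marker files
with tempfile.TemporaryDirectory() as tmp:
    for name in ("part-r-00000", "part-r-00001", "_SUCCESS", "._SUCCESS.crc"):
        (Path(tmp) / name).touch()
    found = [Path(f).name for f in list_record_files(tmp)]

print(found)  # ['part-r-00000', 'part-r-00001']
```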
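The next part is to feed the data into the tf.data API. The built-in method tf.data.Dataset.from_tensor_slices turns the list of paths into a dataset of file names, and interleave with tf.data.TFRecordDataset then reads the serialized records out of those files.
train_dataset = tf.data.Dataset.from_tensor_slices(training_records).interleave(tf.data.TFRecordDataset)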
tf.data is an excellent API. It has lots of features that can improve training performance. Parameters such as num_parallel_calls, block_length and cycle_length were particularly useful. The shuffle API lets us specify the number of items to be picked up and shuffled to get batch_size items. The details of tuning tf.data for performance deserve a post of their own.
... = model.fit(train_dataset.batch(batch_size).map(decode).prefetch(1),...)
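To make the tf.data steps concrete, here is a minimal end-to-end sketch, assuming TensorFlow 2.x. It writes two tiny TF-Record files to a temporary folder (standing in for Spark output), then reads them through interleave, shuffle, batch, map and prefetch. The columns x and y, the shard names, and every parameter value are invented for illustration and are not tuned settings.

```python
import tempfile
from pathlib import Path

import tensorflow as tf

spec = {
    "x": tf.io.FixedLenFeature([], tf.float32),
    "y": tf.io.FixedLenFeature([], tf.int64),
}

def serialize_row(x, y):
    """Serialise one (x, y) row as a tf.train.Example byte string."""
    feature = {
        "x": tf.train.Feature(float_list=tf.train.FloatList(value=[x])),
        "y": tf.train.Feature(int64_list=tf.train.Int64List(value=[y])),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature)).SerializeToString()

# Write two small shards with four records each
tmp_dir = Path(tempfile.mkdtemp())
for shard in range(2):
    with tf.io.TFRecordWriter(str(tmp_dir / f"part-{shard:05d}.tfrecord")) as writer:
        for i in range(4):
            writer.write(serialize_row(float(shard * 4 + i), i % 2))

files = sorted(str(p) for p in tmp_dir.iterdir())

def parse_batch(serialized):
    """Decode a batch of records into (inputs, targets)."""
    parsed = tf.io.parse_example(serialized=serialized, features=spec)
    return {"x": parsed["x"]}, parsed["y"]

dataset = (
    tf.data.Dataset.from_tensor_slices(files)
    # Read two files at a time, taking one record from each in turn
    .interleave(tf.data.TFRecordDataset, cycle_length=2, block_length=1, num_parallel_calls=2)
    # Draw records at random from a buffer of 8 serialized records
    .shuffle(buffer_size=8)
    .batch(4)
    # Decode after batching so each call parses a whole batch
    .map(parse_batch, num_parallel_calls=2)
    .prefetch(1)
)

batches = list(dataset.as_numpy_iterator())
print(len(batches))
```

Batching before map keeps the decode call working on whole batches, which is the point of using parse_example.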
Conclusion
With TF-Records and a tuned TF Data API together, we got a performance improvement of around 3.75-4x. Plain TF-Records alone gave about a 3x improvement on tabular data.
Blogs such as the one by Sebastian Wallkötter claim a 7x improvement on an image dataset.
Based on my understanding, the reasons why I couldn't achieve that improvement are as follows:
- Data was stored on SSD (both CSV and TF-Records), which by itself gives faster reads. Hence the impact of TF-Records being read faster became less pronounced
- Training time per batch for tabular data is typically short compared to image data, as tabular data has fewer features (making our performance bottleneck CPU bound)
We could try to distribute training on Spark using one of the distribution strategies available in TensorFlow. Before distributing, we should consider whether the issue is CPU or GPU bound, i.e. whether the performance bottleneck is in reading the data or in training the model.
For instance, the GPU that does the training might be waiting for the CPU to fetch the training data for each batch, if training on a batch finishes before the data for the next batch has been retrieved. This makes the performance bottleneck CPU bound.
If training takes more time than retrieving the data, the CPU might be waiting for the GPU to complete the training. This makes the problem GPU bound.
Knowing whether an issue is CPU or GPU bound tells us what further course of action can improve the performance.
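One rough way to check which side is the bottleneck is to time the two stages separately and compare them. The sketch below is only an illustration: the helper names are made up, and time.sleep stands in for loading a batch and for running a training step. In a real run you would time iterating over the dataset on its own and compare that with the time of a training step on a batch already in memory.

```python
import time

def average_seconds(step, repeats=5):
    """Average wall-clock time of calling step() `repeats` times."""
    start = time.perf_counter()
    for _ in range(repeats):
        step()
    return (time.perf_counter() - start) / repeats

def diagnose(load_seconds, train_seconds):
    """Name the slower stage; with prefetching, the slower stage sets the pace."""
    if load_seconds > train_seconds:
        return "CPU bound (input pipeline)"
    return "GPU bound (training step)"

# Hypothetical stand-ins: sleep() simulates loading a batch and training on it
load_time = average_seconds(lambda: time.sleep(0.02))
train_time = average_seconds(lambda: time.sleep(0.005))

print(diagnose(load_time, train_time))
```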
Kindly share your experience and perspectives on training with TF-Records and on spark clusters