Unified API for Distributed Deep Network Training

The recently released TensorFlow v2.9 introduces DTensor, a new API for model-, data-, and space-parallel (i.e., spatially tiled) training of deep networks. DTensor aims to decouple partitioning directives from model code by providing higher-level utilities for partitioning model parameters and data batches across devices. The work is part of a recent line of efforts (e.g., GPipe, Mesh TensorFlow, GShard, DeepSpeed, FairScale, ColossalAI) to reduce the development time needed to build large-scale training workloads.

For large (language) models, test loss decreases as a power law (i.e., linearly on a log-log scale) with the number of network parameters, the dataset size, and the training compute. Consequently, task-level improvements in recent years have depended heavily on the size of the deep network. Scaling training across an ever-increasing number of accelerators has, in turn, required substantial engineering work because of the distributed nature of these platforms (e.g., GPUs and FPGAs have limited on-device memory, so a large model must be split across many devices). DTensor offers a device-independent abstraction over these accelerators: a mesh describes how the devices are arranged, and a layout describes how each tensor is placed on that mesh.

As an example, the illustration below shows the placement of a tensor on two different mesh configurations with three different layouts. In the second mesh, either column or row sharding can be chosen by indicating which tensor dimension should be “unsharded”, that is, replicated across devices instead of split.
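The layouts described above can be modeled in a few lines of plain Python. The toy sketch below is not the DTensor API: the `shard` helper, the 2x2 mesh, and its dimension names "x" and "y" are hypothetical. Only the convention of naming, for each tensor dimension, either a mesh dimension or "unsharded" mirrors DTensor's `Layout`. It places a 4x4 matrix with a row-sharded, a column-sharded, and a fully replicated layout.

```python
from itertools import product

UNSHARDED = "unsharded"  # same meaning as dtensor.UNSHARDED: keep the dimension whole


def shard(matrix, sharding_specs, mesh_shape):
    """Return {mesh coordinates: local block} for a 2-D list-of-lists.

    mesh_shape maps mesh dimension names to sizes, e.g. {"x": 2, "y": 2}.
    sharding_specs names, for each tensor dimension, the mesh dimension it is
    split over, or UNSHARDED to replicate that dimension on every device.
    Sizes are assumed to divide evenly.
    """
    names = list(mesh_shape)
    extents = (len(matrix), len(matrix[0]))
    blocks = {}
    for coords in product(*(range(mesh_shape[n]) for n in names)):
        position = dict(zip(names, coords))
        bounds = []
        for extent, spec in zip(extents, sharding_specs):
            if spec == UNSHARDED:
                bounds.append((0, extent))  # whole dimension on this device
            else:
                size = extent // mesh_shape[spec]
                start = position[spec] * size
                bounds.append((start, start + size))
        (r0, r1), (c0, c1) = bounds
        blocks[coords] = [row[c0:c1] for row in matrix[r0:r1]]
    return blocks


mesh_shape = {"x": 2, "y": 2}                          # hypothetical 2x2 mesh
m = [[4 * r + c for c in range(4)] for r in range(4)]  # 4x4 matrix, values 0..15

row_sharded = shard(m, ["x", UNSHARDED], mesh_shape)   # rows split over "x"
col_sharded = shard(m, [UNSHARDED, "y"], mesh_shape)   # columns split over "y"
replicated = shard(m, [UNSHARDED, UNSHARDED], mesh_shape)

print(row_sharded[(1, 0)])  # [[8, 9, 10, 11], [12, 13, 14, 15]]
print(col_sharded[(0, 1)])  # [[2, 3], [6, 7], [10, 11], [14, 15]]
```

Devices that differ only along a mesh dimension the layout does not use (here, (1, 0) and (1, 1) under row sharding) hold identical copies, which is what "duplicated between devices" means in practice.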

Because meshes and layouts are separate objects, DTensor can adopt different training topologies without hard-coding device configurations. For example, it makes it straightforward to spatially partition tensors along any dimension for computer vision applications, without a specialized API (unlike spatial partitioning in TensorFlow's TPUEstimator). The device mesh API also works with TF virtual devices (via the logical device mechanism), so different partitioning scenarios can be tried out with the DTensor API even on a single machine.
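As a minimal sketch of that workflow, assuming TensorFlow 2.9 or later, one physical CPU can be split into eight logical devices and a mesh built over them. The 2x4 mesh shape, the dimension names "batch" and "model", and the column-sharded layout are illustrative choices, not taken from the article.

```python
import tensorflow as tf
from tensorflow.experimental import dtensor

# Split the first physical CPU into 8 logical (virtual) devices. This must run
# at program start, before TensorFlow initializes its devices; afterwards
# set_logical_device_configuration raises a RuntimeError.
cpus = tf.config.list_physical_devices("CPU")
tf.config.set_logical_device_configuration(
    cpus[0], [tf.config.LogicalDeviceConfiguration()] * 8)

devices = [f"CPU:{i}" for i in range(8)]

# Illustrative 2x4 mesh: "batch" for data parallelism, "model" for model parallelism.
mesh = dtensor.create_mesh([("batch", 2), ("model", 4)], devices=devices)

# Shard a weight matrix's second (column) dimension over "model";
# the first dimension stays unsharded, i.e. replicated.
layout = dtensor.Layout([dtensor.UNSHARDED, "model"], mesh)

print(layout.sharding_specs)
```

Changing the mesh shape or the sharding specs is enough to try a different partitioning scenario; the model code itself does not need to change.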

Although the DTensor API is still experimental, it already supports direct Keras integration. The code snippet below passes fully replicated weight layouts to a dense layer:

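import tensorflow as tf
from tensorflow.experimental import dtensor
devices = [f"CPU:{i}" for i in range(8)]  # assumes 8 (logical) CPU devices are configured
mesh = dtensor.create_mesh([("batch", 8)], devices=devices)
kernel_layout = dtensor.Layout.replicated(mesh, 2)  # rank-2 kernel, copied to every device
bias_layout = dtensor.Layout.replicated(mesh, 1)  # rank-1 bias, copied to every device

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(10, kernel_layout=kernel_layout, bias_layout=bias_layout),  # size 10 is illustrative
])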

DTensor is a drop-in replacement for most tensor operations, so it also works with the tf.function and tf.GradientTape APIs. However, the current version of TensorFlow does not support Keras's built-in training loop (model.fit) with DTensor, so a custom training loop must be written to train DTensor-sharded models. DTensor supports both single-client and multi-client training jobs, so multiple processes can use the API natively.
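Conceptually, replicated weights on a "batch" mesh, as in the Keras snippet, correspond to synchronous data parallelism: each device computes gradients on its slice of the batch, the gradients are averaged across devices, and every replica applies the same update. The toy sketch below models that custom loop in plain Python for a one-parameter linear model. It is not DTensor code; the model, learning rate, batch, and 4-device count are assumptions chosen for illustration.

```python
def mean_grad(w, xs, ys):
    """Gradient w.r.t. w of the mean squared error of the model y = w * x."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def data_parallel_step(w, xs, ys, num_devices, lr):
    """One synchronous step: shard the batch, average local gradients, update w."""
    shard = len(xs) // num_devices  # assumes the batch divides evenly
    local_grads = [
        mean_grad(w, xs[i * shard:(i + 1) * shard], ys[i * shard:(i + 1) * shard])
        for i in range(num_devices)
    ]
    grad = sum(local_grads) / num_devices  # stands in for an all-reduce (mean)
    return w - lr * grad, grad             # every replica applies the same update


xs = [float(i) for i in range(1, 9)]  # a batch of 8 examples
ys = [3.0 * x for x in xs]            # generated with a true weight of 3.0

w = 0.0                               # replicated weight, identical on all devices
for step in range(50):
    w, grad = data_parallel_step(w, xs, ys, num_devices=4, lr=0.01)

print(round(w, 6))  # 3.0
```

Because the shards are equal in size, the average of the per-shard mean gradients equals the full-batch gradient, so the sharded loop takes exactly the steps a single device would.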

More information is available in the in-depth DTensor documentation. The TensorFlow website also provides tutorials on low-level distributed training and Keras training with DTensor.
