Auto Sharding¶
This document describes the automatic sharding mechanism in TorchCAP, detailing the implementation of the auto-sharding pipeline. It covers the architecture of the sharding optimizer, the formulation used to determine an optimal sharding plan, and the transformation process applied to the model graph.
Automatic Sharding Overview¶
The high-level API for automatic sharding is available via torchcap.optimize (api.py). The optimizer performs the following steps:
1. Converts the input model into an FX graph using torch.export
2. Estimates the runtime and memory consumption of each operator in the graph
3. Determines the optimal sharding strategy for the graph
4. Transforms the model into a distributed version using the selected strategy
Users may also provide custom sharding strategies for specific operations. The optimizer will compute the optimal strategy for the remaining parts of the graph accordingly.
Automatic Sharding Solver¶
The solver formulation is based on the integer linear programming (ILP) approach proposed in Alpa, with modifications to support PyTorch DTensor. The implementation is available in parallel_solver.py.
PyTorch DTensor supports three types of sharding (also referred to as placement in the PyTorch documentation):
- Shard(dim) (S(dim)): Shards the specified tensor dimension over the mesh dimension.
- Replicate (R): Replicates the tensor across the mesh dimension.
- Partial (P): Indicates the tensor is pending reduction across devices.
For an N-dimensional mesh, a vector of length N represents the sharding strategy across each dimension. For example, S(0)R indicates dimension 0 is sharded over mesh dimension 0 and replicated over mesh dimension 1. See the PyTorch DTensor documentation for further details.
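As a toy illustration (using plain strings in place of DTensor's Shard/Replicate/Partial placement objects), a strategy for a 2-D mesh can be written as a tuple with one placement per mesh dimension:

```python
from typing import Tuple

# Plain strings stand in for DTensor placements: "S(d)" = Shard(d),
# "R" = Replicate, "P" = Partial. These are illustrative, not DTensor types.
Placement = str
Strategy = Tuple[Placement, ...]  # one placement per mesh dimension

def describe(strategy: Strategy) -> str:
    """Render a strategy in the compact notation used above, e.g. S(0)R."""
    return "".join(strategy)

# Tensor dimension 0 sharded over mesh dimension 0, replicated over mesh
# dimension 1:
print(describe(("S(0)", "R")))  # S(0)R
```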
For each operator in the graph, the solver enumerates all possible sharding strategies based on the operator sharding rules defined in both PyTorch and sharding_strategy.py. For example, a linear operator may use the sharding strategy R, S(0), S(0) -> S(1), meaning that the first argument is replicated, the second and third arguments are sharded over tensor dimension 0, and the output is sharded over tensor dimension 1.
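Such a rule table can be pictured as follows; the entries below are illustrative examples over a single mesh dimension, not the actual rules from sharding_strategy.py:

```python
# Illustrative strategy table for a linear op (input, weight, bias) -> output.
# Each entry maps per-argument placements to the output placement over one
# mesh dimension. The entries are examples, not the full PyTorch rule set.
linear_strategies = [
    # (input, weight, bias) -> output
    (("R", "S(0)", "S(0)"), "S(1)"),  # shard output features (column-parallel)
    (("S(1)", "S(1)", "R"), "P"),     # shard reduction dim: output is partial,
                                      # pending an all-reduce (bias added after)
    (("R", "R", "R"), "R"),           # fully replicated fallback
]

for args, out in linear_strategies:
    print(", ".join(args), "->", out)
```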
Each operator u is assigned a one-hot vector \(s_u\), where \(s_u[i] = 1\) denotes that the i-th strategy has been selected for operator u.
When two operators require incompatible sharding for a shared tensor, communication is needed for resharding. For instance, if the output of a linear operator is S(1) but the consumer operator expects it as R, an all-gather operation is required, introducing communication overhead. To model the resharding cost between operator u and operator v, the solver constructs a resharding cost matrix \(R_{uv}\), where \(R_{uv}[i][j]\) is the cost of resharding the output of strategy i for operator u into the input of strategy j for operator v.
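A resharding cost matrix for one producer/consumer pair might be sketched as follows; the placements and cost values are made up for illustration:

```python
# Illustrative sketch (not TorchCAP's actual cost model): R_uv[i][j] is the
# cost of converting the output placement of u's strategy i into the input
# placement expected by v's strategy j. Matching placements cost nothing;
# a mismatch such as S(1) -> R pays for a collective (e.g. an all-gather).
producer_out = ["S(1)", "R"]   # output placement per strategy of u
consumer_in = ["R", "S(1)"]    # expected input placement per strategy of v

COLLECTIVE_COST = 10.0  # made-up relative cost of one collective op

def reshard_cost(src: str, dst: str) -> float:
    if src == dst:
        return 0.0
    return COLLECTIVE_COST  # simplification: any mismatch pays one collective

R_uv = [[reshard_cost(s, d) for d in consumer_in] for s in producer_out]
print(R_uv)  # [[10.0, 0.0], [0.0, 10.0]]
```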
The objective of the formulation is

\[
\min_{s} \sum_{u} s_u^\top c_u + \sum_{(u, v)} s_u^\top R_{uv} s_v
\]

where \(c_u\) is the vector of estimated per-strategy compute costs of operator \(u\), and the second sum ranges over the producer/consumer edges of the graph.
This formulation captures both the choice of strategy and the communication cost, similar to Alpa’s formulation but with communication overheads folded into the resharding costs.
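On a toy two-operator chain, the objective can be evaluated by brute force instead of an ILP solver; all cost values below are illustrative:

```python
import itertools

# Toy sketch of the objective on a two-operator chain u -> v:
# total = c_u[i] + c_v[j] + R_uv[i][j], minimized by exhaustive search
# here rather than by an ILP solver (values are illustrative).
c_u = [4.0, 5.0]                   # per-strategy compute cost of u
c_v = [3.0, 3.0]                   # per-strategy compute cost of v
R_uv = [[10.0, 0.0], [0.0, 10.0]]  # resharding cost between strategy pairs

best = min(
    (c_u[i] + c_v[j] + R_uv[i][j], i, j)
    for i, j in itertools.product(range(2), range(2))
)
print(best)  # (7.0, 0, 1): u picks strategy 0, v picks strategy 1
```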
Memory Constraint¶
To bound memory usage per device, a memory constraint is added to the formulation. Let \(u_0, u_1, \ldots, u_{n-1}\) be the operators in topological order. Let \(m_t\) be the memory consumed by the output tensor of operator \(u_t\). It is calculated as:

\[
m_t = \sum_i s_{u_t}[i] \cdot \text{output\_size}(u_t)[i]
\]
Here, \(\text{output\_size}(u_t)[i]\) is the output size of operator \(u_t\) under strategy i.
Using liveness analysis, we extract the live range of each output tensor as \([start_k, end_k]\). Define \(\delta[t]\) as the net memory change at step \(t\):

\[
\delta[t] = m_t - \sum_{k :\, end_k = t} m_k
\]
where the first term is the memory allocation of the output tensor of operator \(t\) and the second term is the memory deallocation of the output tensor last used by operator \(t\).
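A small sketch of computing \(\delta[t]\) from live ranges; the tensor sizes and ranges below are illustrative:

```python
# Net memory change per step, assuming tensor k occupies m[k] bytes and is
# live over [start[k], end[k]] (illustrative values).
m = [8, 4, 2]      # output sizes of operators u_0, u_1, u_2
end = [1, 2, 2]    # last step at which tensor k is used

def delta(t: int) -> int:
    alloc = m[t]   # output of u_t is allocated when u_t runs
    # Free every tensor whose live range ends at step t (a tensor's own
    # allocation step is excluded so it is not freed as it is created).
    freed = sum(m[k] for k in range(len(m)) if end[k] == t and k != t)
    return alloc - freed

print([delta(t) for t in range(3)])  # [8, -4, -2]
```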
The cumulative memory consumption at step \(t\), denoted as \(M_t\), is then:

\[
M_t = \sum_{\tau \le t} \delta[\tau] = M_{t-1} + \delta[t]
\]
Therefore, the memory constraint can be represented as

\[
M_t \le M_{\text{budget}} \quad \text{for all } t,
\]

where \(M_{\text{budget}}\) is the per-device memory budget.
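The prefix-sum form of \(M_t\) and the per-device budget check can be sketched as follows, with illustrative values:

```python
# Cumulative memory as a prefix sum of the per-step deltas, plus the budget
# check that the ILP encodes as a linear constraint (values illustrative).
deltas = [8, -4, -2]   # net memory change per step
BUDGET = 8             # per-device memory budget

M, ok = 0, True
profile = []
for d in deltas:
    M += d             # M_t = M_{t-1} + delta[t]
    profile.append(M)
    ok = ok and M <= BUDGET  # the constraint must hold at every step

print(profile, ok)  # [8, 4, 2] True
```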