Mariana Distributed Parallelism Guide
Goal: this document explains the distributed-training strategies used in Mariana across its main model families — primarily VLM and LLM training, but also the other Mariana-supported multimodal, omni, audio/speech, and reward-model paths where they materially change or inherit the parallelism story. It has a beginner-friendly focus on:
- what each parallelism strategy means
- what gets computed locally versus communicated remotely
- how Mariana implements it
- which kernels, collectives, or helper libraries are involved
- how computation and communication are overlapped
- how to choose among the strategies and understand their trade-offs
- whether flux already supports it, and what would be needed if it does not
- which Mariana model families actually map onto each composition template
- what representative Mariana workload defaults look like in Python config classes and Hydra YAMLs
1. What Mariana Actually Uses
The first thing to clarify is that Mariana has two different kinds of “parallelism concepts”:
- parallel dimensions such as TP, SP, EP, PP, DP
- strategy frameworks such as FSDP2 and Megatron that combine several dimensions together
The effective Mariana-wide picture is:
| Strategy / dimension | What it splits | Representative Mariana model families | Main Mariana entry points | FLUX support |
|---|---|---|---|---|
| Data parallel (DP) | training samples | VLM, LLM, OmniPT, RM | mariana/mariana/distributed/parallel_state.py:45-232 | Very weak |
| FSDP / HSDP / FSDP2 | parameters, grads, optimizer state | VLM, LLM, RM, some multimodal trainer paths | mariana/mariana/trainer/strategy/fsdp2/strategy.py:37-195, mariana/mariana/trainer/strategy/fsdp2/_fsdp.py:551-589, mariana/mariana/trainer/strategy/fsdp2_torch.py:57-518 | Not a first-class FLUX feature |
| Legacy FSDP / LSDP | selected or full parameter shards | VLM, LLM, RM, some multimodal trainer paths | mariana/mariana/trainer/strategy/fsdp.py, mariana/mariana/trainer/strategy/lsdp.py:79-255 | Not a first-class FLUX feature |
| Tensor parallel (TP) | weight matrices / hidden dims | LLM, Megatron VLM, MegatronOmni | mariana/megatron/core/tensor_parallel/layers.py:16-45, mariana/megatron/core/tensor_parallel/layers.py:973-1390, mariana/mariana/models/stable/m11/tp.py:1-159 | Strong |
| Sequence parallel (SP / Ulysses / DSP) | sequence tokens around attention | VLM, LLM, MegatronOmni | mariana/mariana/models/multimodal/parallel/common/encoder_seq_parallel.py:22-97, mariana/mariana/models/multimodal/parallel/fsdp/fused_op/functional_ulysses.py:49-170 | Strong |
| Context parallel (CP) | long-context sequence positions | long-context LLM, MegatronOmni, limited generic VLM | mariana/mariana/models/utils/context_parallel_manager.py:17-245 | Weak / absent |
| Expert parallel (EP) | MoE experts and routed tokens | MoE VLM, MoE LLM, selected Omni/text backbones | mariana/mariana/models/multimodal/parallel/fsdp/ep_prefetch.py:1-97, mariana/tasks/gpt2/m12_pretrain/model/flux_manager.py:8-396, mariana/tasks/gpt2/m12_pretrain/model/flash_comm_manager.py:7-76 | Medium to strong, but incomplete |
| Pipeline parallel (PP) | layer stages | Megatron LLM, Megatron VLM, MegatronOmni, some FSDP2 MP paths | mariana/mariana/trainer/strategy/fsdp2/_mpmd_fsdp.py:206-345, mariana/mariana/trainer/strategy/megatron.py:318-430 | Partial |
| Parallel encoder | vision/audio encoders across ranks | VLM, Omni vision/audio encoders | mariana/mariana/models/multimodal/parallel/fsdp/parallel_encoder.py:37-255, mariana/mariana/models/multimodal/parallel/fsdp/parallel_encoder.py:943-956 | No direct support |
| Megatron strategy | composition of TP + PP + DP + optional CP/EP/SP | LLM, Megatron VLM, MegatronOmni | mariana/mariana/trainer/strategy/megatron.py:318-430 | FLUX helps selected subpaths |
| FSDP2 MP strategy | composition of FSDP + EP/OE + PP | VLM, newer text/multimodal paths, some RM-style trainer variants | mariana/mariana/trainer/strategy/fsdp2/_mpmd_fsdp.py:32-120, mariana/mariana/trainer/strategy/fsdp2/_mpmd_fsdp.py:206-345 | FLUX helps selected subpaths |
1.1 Which Model Families This Guide Covers
Mariana supports more than a simple VLM-versus-LLM split. The most useful families to keep straight are:
| Model family | Concrete evidence in tree | Best mental starting point for this guide |
|---|---|---|
| Text causal LLMs | mariana/tasks/gpt2, mariana/tasks/nanogpt, mariana/mariana/models/text, mariana/mariana/models/stable | Use the LLM sections below |
| Megatron VLMs | mariana/tasks/multimodal/train_megatron.py:56-86, mariana/mariana/models/multimodal/parallel/megatron/config.py:7-32 | Start from the VLM sections below |
| MegatronOmni / OmniPT (vision + audio + text) | mariana/tasks/seed2/megatron/ARCHITECTURE.md:1-3, mariana/tasks/seed2/megatron/model_registry.py:31-56, mariana/tasks/seed2/megatron/ARCHITECTURE.md:372-383 | Start from the Megatron LLM template, then add multimodal encoder concerns |
| Omni audio/speech and diffusion subfamilies | mariana/mariana/models/omni/usm/modeling_usm.py:20-77, mariana/mariana/models/omni/acceleration/usm/fused_usm.py:16-111, mariana/mariana/models/omni/diffusion/na_audiodit.py:25-76 | Usually inherit the omni/VLM template rather than invent a new top-level parallel dimension |
| Reward models (RM) on text or VLM backbones | mariana/tasks/rm/engine/fsdp/configuration.py:34-145, mariana/tasks/rm/engine/megatron/rm_model.py:509-560 | Pick the underlying LLM/VLM backbone strategy first, then add the reward head |
So below, most per-strategy subsections still contrast VLM and LLM because those are the two dominant templates. The extra families above are called out explicitly when they change the parallelism picture enough to matter.
2. Typical Compositions In Mariana
2.1 Typical VLM Composition
The common VLM stack in Mariana is:
FSDP2 (base) + SP / Ulysses + EP (for MoE variants) + parallel encoder
Why:
- VLM training usually wants a strong memory-saving strategy first, so FSDP2 is a natural base
- attention and multimodal encoders benefit from sequence-parallel style splits
- MoE variants want expert parallelism
- multimodal models may need extra encoder load balancing that pure LLMs do not
The main code path is:
- strategy selection in mariana/mariana/trainer/base_trainer.py:460-531
- FSDP2 strategy selection in mariana/mariana/trainer/strategy/fsdp2/strategy.py:90-95
- MPMD composition in mariana/mariana/trainer/strategy/fsdp2/_mpmd_fsdp.py:32-120
2.2 Typical LLM Compositions
Mariana LLM training appears in two major families.
A. Megatron-family LLM training
MegatronStrategy: TP + PP + DP (+ optional CP / EP / SP)
This is the classic large-scale text-model path. MegatronStrategy itself is not “one parallel dimension”; it is the framework that orchestrates several dimensions together (mariana/mariana/trainer/strategy/megatron.py:318-430).
B. FSDP-family LLM or multimodal-text training
FSDP2 / FSDP2Torch / LSDP (base) + SP / Ulysses + optional EP
This path is closer to the VLM stack, but often without parallel_encoder.
2.3 How Other Mariana-Supported Families Map To These Templates
If you want a compact rule of thumb:
| Model family | Best starting template | What is special in Mariana |
|---|---|---|
| Megatron VLM (m8, m10, m12, m13) | VLM template | uses Megatron text backbones plus multimodal encoder adaptors and encoder-specific parallel knobs (mariana/tasks/multimodal/train_megatron.py:56-86, mariana/mariana/models/multimodal/parallel/megatron/config.py:82-110) |
| MegatronOmni / OmniPT | Megatron LLM template | adds multimodal encoders plus DSP and CP; encoders stay replicated while the LLM backbone uses Megatron 3D parallelism (mariana/tasks/seed2/megatron/ARCHITECTURE.md:372-383, mariana/tasks/seed2/megatron/model_registry.py:31-56) |
| Omni audio/speech and diffusion modules | Omni/VLM template | adds modality-specific fused kernels such as the USM fused frontend, but usually does not introduce a new top-level distributed dimension (mariana/mariana/models/omni/usm/modeling_usm.py:20-77, mariana/mariana/models/omni/acceleration/usm/fused_usm.py:16-111, mariana/mariana/models/omni/diffusion/na_audiodit.py:25-76) |
| Reward models | underlying LLM or VLM template | the reward head wraps a text or VLM backbone, so the backbone’s DP/FSDP/Megatron choice still dominates (mariana/tasks/rm/engine/fsdp/configuration.py:34-145, mariana/tasks/rm/engine/megatron/rm_model.py:509-560) |
2.4 One Important Mariana Limitation
In the generic trainer setup, cp_size is still hardcoded to 1 in _setup_parallel_state() (mariana/mariana/trainer/base_trainer.py:517-530).
That means:
- CP exists in Mariana as a real concept
- but the more active CP usage today is in the Megatron/task path rather than the generic FSDP trainer mesh path
3. Beginner Mental Model
If you are new to distributed training, this is the most useful way to think about the strategies:
| If you split… | You are probably using… | Main cost |
|---|---|---|
| samples | DP | gradient synchronization |
| parameters | FSDP / HSDP / LSDP | parameter all-gather + grad reduce-scatter |
| matrix dimensions | TP | collectives around GEMMs |
| sequence tokens around attention | SP / Ulysses | all-to-all before and after attention |
| long-context positions | CP | distributed attention + output reorder |
| experts | EP | token dispatch + combine |
| layers | PP | stage-to-stage activation and gradient transport |
| modality encoders | parallel encoder | all-to-all redistribution of encoder inputs |
For each strategy below, keep three questions in mind:
- What does each rank compute locally?
- What data must move between ranks?
- Can communication happen while useful compute is still running?
3.1 Communication Schemes Primer
Most beginner confusion in distributed training comes from a small set of communication words that keep reappearing. The same collective can mean different things depending on whether the model is split by parameters, tokens, experts, or layers.
| Communication scheme | Beginner meaning | Mariana strategies that use it | Representative Mariana code paths |
|---|---|---|---|
| Process group / device mesh | a subset of ranks that cooperate on one split dimension | all strategies | mariana/mariana/distributed/parallel_state.py:45-232, mariana/mariana/distributed/parallel_state.py:162-197 |
| All-gather / unshard | every rank contributes one shard and receives the full tensor view | FSDP, TP, parallel encoder planning/output handling | mariana/mariana/trainer/strategy/fsdp2/_fsdp.py:564-585, mariana/megatron/core/tensor_parallel/layers.py:22-36, mariana/mariana/models/multimodal/parallel/fsdp/parallel_encoder.py:22, mariana/mariana/models/multimodal/parallel/fsdp/parallel_encoder.py:186-255 |
| Reduce-scatter | ranks first sum contributions together, then each rank keeps only one shard of the summed result | FSDP/HSDP, TP | mariana/mariana/trainer/strategy/fsdp2_torch.py:456-518, mariana/megatron/core/tensor_parallel/layers.py:22-36 |
| All-reduce | ranks sum values and everyone keeps the same full summed result | DP/HSDP, TP, some grad-sync paths | mariana/megatron/core/tensor_parallel/layers.py:22-36, mariana/megatron/dp_communication.py:27-96 |
| All-to-all | each rank sends a different slice to every other rank, so ownership changes rather than becoming more replicated | SP/Ulysses, EP, parallel encoder | mariana/mariana/models/multimodal/parallel/fsdp/fused_op/functional_ulysses.py:49-170, mariana/tasks/gpt2/m12_pretrain/model/bigop/helper_funcs.py:325-420, mariana/mariana/models/multimodal/parallel/fsdp/parallel_encoder.py:186-255 |
| Point-to-point send/recv | one stage or rank sends directly to one other stage or rank | PP, some shortcut transport paths | mariana/mariana/trainer/strategy/fsdp2/_mpmd_fsdp.py:206-285, mariana/tasks/gpt2/m12_pretrain/model/flux_manager.py:84-175 |
| Shuffle / reorder / metadata planning | ranks compute or store extra layout metadata so distributed kernels can reconstruct original token order later | CP, SP, EP | mariana/mariana/models/utils/context_parallel_manager.py:17-245, mariana/mariana/models/multimodal/parallel/fsdp/fused_op/functional_ulysses.py:90-170, mariana/tasks/gpt2/m12_pretrain/model/deepep_manager.py:310-392 |
| Dispatch / combine | a routed form of communication where tokens move according to the MoE router rather than a fixed tensor split | EP / MoE | mariana/tasks/gpt2/m12_pretrain/model/deepep_manager.py:290-392, mariana/tasks/gpt2/m12_pretrain/model/flash_comm_manager.py:12-76, mariana/tasks/gpt2/m12_pretrain/model/flux_manager.py:177-235 |
One useful beginner trick is to ask what the communication is trying to do:
- make a tensor more replicated: usually all-gather or all-reduce
- make a tensor more sharded: usually reduce-scatter
- change which rank owns which slice: usually all-to-all
- move data along model depth: usually send/recv
- route tokens by data-dependent decisions: usually EP dispatch/combine
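To make these distinctions concrete, here is a toy, single-process Python sketch of the main collective shapes. Plain lists stand in for tensors and list indices stand in for ranks; this is purely illustrative, not torch.distributed code.

```python
# Toy, single-process simulation of collective semantics.
# "Ranks" are list indices; each rank's data is a list of numbers.

def all_gather(shards):
    # every rank contributes its shard and receives the full concatenation
    full = [x for shard in shards for x in shard]
    return [list(full) for _ in shards]

def all_reduce(values):
    # ranks sum elementwise and everyone keeps the same full summed result
    total = [sum(col) for col in zip(*values)]
    return [list(total) for _ in values]

def reduce_scatter(values):
    # sum elementwise first, then each rank keeps only its shard of the sum
    total = [sum(col) for col in zip(*values)]
    n = len(values)
    chunk = len(total) // n
    return [total[r * chunk:(r + 1) * chunk] for r in range(n)]

def all_to_all(slices):
    # slices[src][dst] is what rank src sends to rank dst; ownership
    # changes, nothing becomes more replicated
    n = len(slices)
    return [[slices[src][dst] for src in range(n)] for dst in range(n)]

shards = [[1, 2], [3, 4]]          # 2 ranks, 2 elements each
print(all_gather(shards))           # both ranks see [1, 2, 3, 4]
print(all_reduce(shards))           # both ranks see [4, 6]
print(reduce_scatter(shards))       # rank 0 keeps [4], rank 1 keeps [6]
print(all_to_all([[["a"], ["b"]], [["c"], ["d"]]]))
```

Note how all-gather and all-reduce make data more replicated, reduce-scatter makes it more sharded, and all-to-all only changes which rank owns which slice.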
3.2 How To Choose A Parallelism Strategy
There is no universally best strategy. Mariana usually combines several dimensions because each one solves a different bottleneck.
- DP / FSDP / HSDP / LSDP: Choose this first when model-state memory is the main bottleneck and you want the most general trainer support. Extra compute overhead is usually low to medium and comes from unshard / reshard bookkeeping, mixed-precision conversions, and sometimes recompute; communication overhead is medium to high because parameters are all-gathered before use and gradients are reduce-scattered after backward. Mariana reduces this cost with FSDP2 forward/backward prefetch chaining, async unshard, inflight reduce-scatter tuning, selective LSDP wrapping, and optional manual lifecycle control through FSDPCtrl (mariana/mariana/trainer/strategy/fsdp2/_fsdp.py:564-585, mariana/mariana/trainer/strategy/fsdp2_torch.py:456-518, mariana/mariana/trainer/strategy/lsdp.py:111-114, mariana/mariana/accelerate/mfsdp.py:117-214). Main strengths: broad applicability and strong memory savings. Main drawbacks: lots of communication around wrapped modules, and it does not by itself solve very wide layers or MoE routing hotspots.
- TP: Choose this when a single layer is too wide or too expensive for one GPU, especially in Megatron-style LLM stacks. Extra compute overhead is small because each rank still runs a local GEMM, but communication is high frequency because many layers need all-gather, reduce-scatter, or all-reduce around every forward/backward. Mariana reduces this with fused TP paths, mem_efficient_column_parallel, Megatron overlap flags, and especially FLUX fused AGKernel / GemmRS operators (mariana/megatron/core/tensor_parallel/layers.py:22-36, mariana/mariana/models/stable/m11/tp.py:35-38, flux/python/flux/cpp_mod.py:107-138). Main strengths: excellent width scaling and good math efficiency. Main drawbacks: sensitive to interconnect latency and more rigid than FSDP when shapes or topology are awkward.
- SP / Ulysses / DSP: Choose this when attention activation memory or long sequences are the main problem, but you do not yet need full context parallelism. Extra compute overhead comes from tensor transpose, packing, permutation, and head/token reshaping; communication overhead usually comes from two all_to_all_single-style exchanges around attention. Mariana reduces this with async A2A handles, AsyncAll2AllSingle, post-processing overlap, functional collectives, and FLUX A2A fusion operators (mariana/mariana/models/multimodal/parallel/fsdp/fused_op/functional_ulysses.py:49-170, mariana/mariana/models/multimodal/parallel/common/encoder_seq_parallel.py:22-73, mariana/mariana/models/multimodal/parallel/common/encoder_seq_parallel.py:203-204, flux/python/flux/cpp_mod.py:162-181). Main strengths: cuts activation memory without fully sharding weights. Main drawbacks: irregular layouts, divisibility constraints, and a confusing split between Ulysses-style SP and Megatron DSP if you do not keep the two codepaths separate.
- CP: Choose this when sequence length is so large that regular SP/Ulysses is not enough, or the decomposition needs to follow context chunks instead of head swaps. Extra compute overhead comes from distributed-attention bookkeeping, rmpad/origin conversions, and output reconstruction; communication overhead is usually backend-specific and can involve ring or blockwise exchange patterns rather than one simple collective. Mariana reduces this mainly through the dist_attn.context_parallel backend and the ContextParallelManager metadata container (mariana/mariana/models/utils/context_parallel_manager.py:17-245). Main strengths: one of the few workable ways to scale extreme long-context workloads. Main drawbacks: the most complex attention layout in this guide, and the generic trainer path still leaves cp_size at 1.
- EP: Choose this when MoE experts or routed FLOPs dominate the model. Extra compute overhead comes from routing, token packing/unpacking, load-balance bookkeeping, and expert-local reordering; communication overhead is often the highest of all dimensions because tokens must be dispatched to expert-owning ranks and then combined back. Mariana reduces this with expert prefetch under FSDP/FSDP2, chunked-overlap fused MoE registration, FLUX dispatch/combine plus grouped GEMM, and specialized GB backends such as HybridEP and FlashComm with dedicated streams and non-blocking paths (mariana/mariana/models/multimodal/parallel/fsdp/ep_prefetch.py:15-77, mariana/mariana/trainer/strategy/fsdp2/_fsdp.py:337-340, mariana/tasks/gpt2/m12_pretrain/model/deepep_manager.py:246-247, mariana/tasks/gpt2/m12_pretrain/model/deepep_manager.py:290-392, mariana/tasks/gpt2/m12_pretrain/model/flash_comm_manager.py:60-69). Main strengths: scales expert count and memory very effectively. Main drawbacks: very communication-sensitive, sensitive to routing imbalance, and the backend choice matters much more than in simpler dimensions.
- PP: Choose this when model depth or per-stage memory is the key issue and you can tolerate an assembly-line schedule. Extra compute overhead comes from pipeline bubbles, stage imbalance, and microbatch scheduling; communication overhead is stage-boundary send/recv on every microbatch, but only between neighboring stages rather than the whole group. Mariana reduces this with Schedule1F1B, stage splitting helpers, Megatron pipeline groups, and optional FLUX AsyncSendRecv transport (mariana/mariana/trainer/strategy/fsdp2/_mpmd_fsdp.py:206-285, mariana/tasks/gpt2/m12_pretrain/model/flux_manager.py:84-175). Main strengths: scales depth cleanly and can lower per-rank activation memory. Main drawbacks: bubble overhead, stage-balance sensitivity, and harder debugging than purely data-parallel training.
- Parallel encoder: Choose this only when multimodal encoders are imbalanced enough that some ranks idle while others do heavy vision/audio encoding. Extra compute overhead comes from planning, batch redistribution, and rebuilding embeddings; communication overhead comes from planning all-gathers plus all-to-all redistribution of encoder inputs and outputs. Mariana reduces this not with one special kernel, but with planner and balance-policy logic that moves encoder work to the right ranks (mariana/mariana/models/multimodal/parallel/fsdp/parallel_encoder.py:37-255, mariana/tasks/multimodal/configs/model/default.yaml:11-14, mariana/tasks/multimodal/configs/model/default.yaml:42-43, mariana/tasks/multimodal/configs/model/default.yaml:86-88). Main strengths: better MFU and less multimodal imbalance. Main drawbacks: extra A2A traffic and planner complexity, and it is irrelevant for text-only workloads.
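Of the dimensions above, EP is the only one where communication is routed by a data-dependent decision rather than a fixed tensor split. The following toy, single-process sketch shows that dispatch/combine idea with hypothetical helper names; real backends (DeepEP, FlashComm, FLUX dispatch/combine) add batching, dedicated streams, and fused kernels on top.

```python
# Toy MoE dispatch/combine: tokens are routed to expert-owning "ranks"
# by a per-token router decision, processed, then combined back into
# original token order. Illustrative only; names are hypothetical.

def dispatch(tokens, expert_ids, num_experts):
    # group (original_index, token) pairs by the expert that owns them
    buckets = [[] for _ in range(num_experts)]
    for idx, (tok, eid) in enumerate(zip(tokens, expert_ids)):
        buckets[eid].append((idx, tok))
    return buckets

def combine(buckets, num_tokens):
    # scatter expert outputs back to original token order
    out = [None] * num_tokens
    for bucket in buckets:
        for idx, tok in bucket:
            out[idx] = tok
    return out

tokens     = [10, 20, 30, 40]
expert_ids = [1, 0, 1, 0]          # router decision per token
buckets = dispatch(tokens, expert_ids, num_experts=2)
# each "expert" applies its own transform to the tokens it received
processed = [[(idx, tok + 100 * eid) for idx, tok in bucket]
             for eid, bucket in enumerate(buckets)]
print(combine(processed, len(tokens)))   # [110, 20, 130, 40]
```

The load-balance sensitivity mentioned above shows up here directly: if `expert_ids` concentrates on one expert, that bucket (and its owning rank) does all the work.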
In practice, the most common Mariana choice order is:
- pick the composition framework first (Megatron or FSDP2 / FSDP2Torch / LSDP)
- solve the largest bottleneck first (state memory, layer width, sequence length, MoE routing, depth, or multimodal imbalance)
- only then add extra dimensions one by one
4. Strategy-By-Strategy Deep Dive
4.1 Data Parallel And The FSDP Family
Intuition
every rank runs the same model on its own slice of the data; the FSDP variants additionally shard the model state so no single rank has to hold all of it at once
What problem it solves
This family is mostly about memory scaling, not raw communication avoidance.
- plain DP replicates the whole model
- FSDP shards model parameters, gradients, and optimizer states
- HSDP is FSDP plus an outer replicate dimension
- LSDP only shards selected blocks instead of the whole model
Computation pattern
- each rank still runs the same layer computations on different data
- the main difference is when full parameters are materialized
- FSDP-style strategies often unshard right before a layer/block and reshard after backward
Communication pattern
Main collectives are:
- parameter all-gather or unshard before forward
- gradient reduce-scatter after backward
- optional outer-group all-reduce in HSDP
The group construction comes from Mariana’s device mesh and flattened mesh views in mariana/mariana/distributed/parallel_state.py:162-197.
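The parameter lifecycle described above can be sketched in a toy, single-process form: shard the parameter, all-gather it before forward, and reduce-scatter gradients after backward. Lists stand in for tensors; real FSDP2 does this per wrapped module, with prefetching and streams.

```python
# Toy FSDP parameter lifecycle on plain lists (illustrative only).

NUM_RANKS = 2
full_param = [1.0, 2.0, 3.0, 4.0]
chunk = len(full_param) // NUM_RANKS

# each rank owns one contiguous shard of the parameter
shards = [full_param[r * chunk:(r + 1) * chunk] for r in range(NUM_RANKS)]

# forward: "all-gather" the shards so every rank sees the full parameter
gathered = [x for s in shards for x in s]
assert gathered == full_param

# backward: each rank produced a full-size local gradient for its data
local_grads = [[0.1] * len(full_param) for _ in range(NUM_RANKS)]

# "reduce-scatter": sum gradients, then each rank keeps only its shard
summed = [sum(col) for col in zip(*local_grads)]
grad_shards = [summed[r * chunk:(r + 1) * chunk] for r in range(NUM_RANKS)]
print(grad_shards)   # each rank holds the summed gradient for its shard
```

After this step each rank updates only its optimizer-state shard, which is why optimizer memory also scales down with the FSDP group size.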
Mariana implementation
Core trainer and mesh setup:
- mariana/mariana/trainer/base_trainer.py:460-557
- mariana/mariana/distributed/parallel_state.py:45-232
Main strategy implementations:
- FSDP2 (VeScale): mariana/mariana/trainer/strategy/fsdp2/strategy.py:37-195
- FSDP2 pure wrap path: mariana/mariana/trainer/strategy/fsdp2/_fsdp.py:551-589
- FSDP2Torch: mariana/mariana/trainer/strategy/fsdp2_torch.py:57-518
- legacy FSDP1: mariana/mariana/trainer/strategy/fsdp.py
- LSDP: mariana/mariana/trainer/strategy/lsdp.py:79-255
Related manual-control helper:
- mariana/mariana/accelerate/mfsdp.py:117-214 exposes FSDPCtrl, which is not a new parallel dimension but is useful when Mariana wants explicit unshard / prefetch / reshard control
Compute kernels and communication methods
This strategy family does not introduce one Mariana-specific compute kernel. It wraps the model’s normal kernels.
Typical compute kernels still come from:
- FlashAttention / FA3 / FA4
- fused MLP kernels
- MoE kernels
- normalization kernels
Communication methods are mostly supplied by:
- vescale.parallel.fsdp2
- native torch.distributed.fsdp.fully_shard
- legacy torch.distributed.fsdp.FullyShardedDataParallel
- lsdp sharding and prefetch logic
Overlap strategy in Mariana
This is one of the most important parts of the FSDP story.
FSDP2 overlap mechanisms:
- adjacent FSDP modules are linked with set_modules_to_forward_prefetch(...) and set_modules_to_backward_prefetch(...)
- model._set_unshard_async_op(True) enables non-blocking unshard
Source:
mariana/mariana/trainer/strategy/fsdp2/_fsdp.py:564-585
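The effect of prefetch chaining is easiest to see as an event timeline: without prefetch, each module's unshard (all-gather) blocks its compute; with prefetch, the next module's unshard is issued while the current module is still computing. The following toy scheduler is a pure illustration of that ordering, not the torch FSDP2 API.

```python
# Toy timeline for FSDP2-style forward prefetch (illustrative only).

def timeline(num_modules, prefetch):
    events = []
    events.append(("unshard", 0))            # first module must wait anyway
    for i in range(num_modules):
        if prefetch and i + 1 < num_modules:
            events.append(("unshard", i + 1))  # issued early, runs async
        events.append(("compute", i))
        if not prefetch and i + 1 < num_modules:
            events.append(("unshard", i + 1))  # issued only after compute
    return events

print(timeline(3, prefetch=False))
print(timeline(3, prefetch=True))
```

With prefetch, the all-gather for module 1 is in flight before module 0 finishes, which is exactly the overlap the forward/backward prefetch chains buy.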
FSDP2Torch overlap mechanisms:
- similar forward/backward prefetch chaining
- configurable _max_num_inflight_reduce_scatters
Source:
mariana/mariana/trainer/strategy/fsdp2_torch.py:456-518
LSDP overlap mechanisms:
- lsdp.PrefetchOptions allows forward/backward prefetch tuning
Source:
mariana/mariana/trainer/strategy/lsdp.py:111-114
VLM versus LLM
- VLM: FSDP2 / FSDP2MP is the dominant pattern
- LLM: either Megatron or FSDP2/FSDP2Torch depending on model family and deployment choice
Does FLUX support this?
Short answer: not as a full strategy.
FLUX is not a general FSDP or ZeRO runtime. It does not own:
- parameter sharding
- optimizer-state sharding
- param lifecycle and unshard scheduling
- full gradient synchronization policy
What FLUX does support in this neighborhood:
- utility communication helpers such as flux.InplaceCast, which Mariana uses in mariana/megatron/dp_communication.py:27-96
- fused TP or EP operators that may live inside an otherwise FSDP-driven model
If you wanted this in FLUX
FLUX would need a new layer above its current operator set:
- a sharded-parameter runtime
- parameter all-gather / grad reduce-scatter state machines
- persistent communication buffers keyed by module/block
- prefetch and reshard policies
- optimizer-state sharding and checkpoint integration
In practice, this is closer to building a VeScale/FSDP backend than extending one operator.
4.2 Tensor Parallel (TP)
Intuition
split each large weight matrix across ranks; every rank runs a smaller local GEMM, and collectives stitch the partial results back together
TP is used when one layer is too large or too expensive to keep entirely on one rank.
Computation pattern
The usual patterns are:
- column parallel linear: split output columns across ranks
- row parallel linear: split input rows across ranks
- vocab parallel embedding: split vocabulary across ranks
Each rank performs a local GEMM on its weight shard.
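The column-parallel and row-parallel patterns can be checked end to end in a toy, single-process sketch with plain Python matrices: column parallelism shards the weight by output columns (partial outputs are concatenated, i.e. all-gathered), while row parallelism shards by input rows (partial outputs are summed, i.e. all-reduced). Illustrative only; the real paths use CUDA GEMMs and NCCL collectives.

```python
# Toy TP on plain Python matrices (illustrative only).

def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)]
            for row in a]

x = [[1.0, 2.0]]                       # one token, hidden size 2
w = [[1.0, 2.0, 3.0, 4.0],
     [5.0, 6.0, 7.0, 8.0]]             # 2 x 4 weight

# column parallel: rank r owns a slice of output columns;
# local GEMMs produce output slices that are concatenated ("all-gather")
w_cols = [[row[:2] for row in w], [row[2:] for row in w]]
col_out = [matmul(x, shard)[0] for shard in w_cols]
full_col = col_out[0] + col_out[1]

# row parallel: rank r owns one input row of the weight;
# local GEMMs produce partial sums that are added ("all-reduce")
partials = [matmul([[x[0][r]]], [w[r]])[0] for r in range(2)]
full_row = [sum(vals) for vals in zip(*partials)]

print(full_col)
print(full_row)
assert full_col == matmul(x, w)[0] == full_row
```

This is why Megatron pairs a column-parallel layer with a following row-parallel layer: the intermediate stays sharded and only one collective is needed per pair.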
Communication pattern
Because each rank only computes a partial result, TP needs collectives around the layer:
- all-gather
- reduce-scatter
- all-reduce
- sometimes all-to-all in more specialized attention paths
Megatron TP wiring is centered in:
- mariana/megatron/core/tensor_parallel/layers.py:16-45
- mariana/megatron/core/tensor_parallel/layers.py:196-260
- mariana/megatron/core/tensor_parallel/layers.py:973-1390
Mariana implementation
Megatron TP:
- ColumnParallelLinear: mariana/megatron/core/tensor_parallel/layers.py:973
- RowParallelLinear: mariana/megatron/core/tensor_parallel/layers.py:1215
Stable / big-op TP examples:
mariana/mariana/models/stable/m11/tp.py:1-159
FSDP2Torch can also create TP-like shard plans:
- mariana/mariana/trainer/strategy/fsdp2_torch.py:106-145
- mariana/mariana/trainer/strategy/fsdp2_torch.py:437-440
Compute kernels and communication methods
Compute side often uses:
- local GEMM wrappers from mason.lego.gemm
- fused TP blocks in Megatron and big-op text paths
- standard attention / MLP kernels after the local shard is made visible
Communication side uses:
- timed__all_gather_base
- timed__reduce_scatter_base
- timed_all_reduce
- timed_all_to_all
Source:
mariana/megatron/core/tensor_parallel/layers.py:22-36
Overlap strategy in Mariana
This is where FLUX is especially relevant.
The best TP overlap strategy is to fuse the collective with the GEMM instead of doing:
- collective
- wait
- GEMM
Mariana already uses this style in selected paths via FLUX:
- flux.AGKernel for all-gather + GEMM
- flux.GemmRS for GEMM + reduce-scatter
Relevant FLUX exports:
- flux/python/flux/cpp_mod.py:107-138
- flux/src/pybind/gemm_rs.cc:28-102
- flux/src/pybind/ag_gemm.cc:34-149
The m11 TP path is a good Mariana example of TP-oriented fusion:
mariana/mariana/models/stable/m11/tp.py:35-38
VLM versus LLM
- LLM: TP is one of the core scaling dimensions in Megatron-style training
- VLM: TP is less universal, but it appears in selected text towers, big-op paths, and FSDP2Torch shard-plan usage
Does FLUX support this?
Yes. This is one of FLUX’s strongest areas.
The high-level FLUX view is explicit in CommOpEnum:
- AllGather
- ReduceScatter
- AGKernel
Source:
flux/include/flux/flux.h:466-478
If you wanted more in FLUX
FLUX already has the right substrate. The next improvements would be:
- cover more Mariana TP call sites with unified wrappers
- expand inter-node TP coverage where needed
- expose more attention-specific TP fused paths through a stable Mariana adapter
4.3 Sequence Parallel (SP / Ulysses / DSP)
Intuition
keep only a slice of the sequence on each rank, and temporarily trade sequence shards for head shards around attention using all-to-all
This reduces per-rank activation memory for attention-heavy models.
Computation pattern
SP usually keeps:
- only a slice of sequence on each rank before or after attention
- but temporarily rearranges data so each rank can run attention on the required head subset
In ViT-style SP, image patches are split.
In LLM-style Ulysses, attention tensors are rearranged around head and token dimensions.
Communication pattern
The core collective is usually all_to_all_single.
Main Mariana paths:
- ViT / multimodal path:
  - mariana/mariana/models/multimodal/parallel/common/encoder_seq_parallel.py:22-97
  - mariana/mariana/models/multimodal/parallel/common/encoder_seq_parallel.py:206-308
- LLM fused-op path:
  - mariana/mariana/models/multimodal/parallel/fsdp/fused_op/functional_ulysses.py:49-170
Compute kernels and communication methods
Communication methods:
- torch.distributed.all_to_all_single
- torch.distributed._functional_collectives.all_to_all_single
Local helper kernels:
- bumi.kernel.index_copy
- bumi.kernel.index_copy_reversed
Source:
- mariana/mariana/models/multimodal/parallel/fsdp/fused_op/functional_ulysses.py:4-6
- mariana/mariana/models/multimodal/parallel/fsdp/fused_op/functional_ulysses.py:105-109
- mariana/mariana/models/multimodal/parallel/fsdp/fused_op/functional_ulysses.py:139-143
The actual attention compute is still done by the usual attention kernels after the data reshuffle.
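The core Ulysses exchange can be sketched in a toy, single-process form: before the all-to-all, each rank holds a slice of the sequence for all heads; afterwards, each rank holds the full sequence for its own subset of heads, so attention can run locally. Strings stand in for tensors here; this is purely illustrative.

```python
# Toy Ulysses all-to-all: trade sequence shards for head shards.

NUM_RANKS = 2   # also the number of sequence shards
NUM_HEADS = 4

# rank r holds sequence position r for all heads; "h{h}t{t}" labels
# the activation for head h at token t
pre = [[[f"h{h}t{r}"] for h in range(NUM_HEADS)] for r in range(NUM_RANKS)]

def ulysses_all_to_all(pre, num_ranks, num_heads):
    heads_per_rank = num_heads // num_ranks
    post = []
    for dst in range(num_ranks):
        local_heads = []
        for h in range(dst * heads_per_rank, (dst + 1) * heads_per_rank):
            # gather this head's tokens from every sequence shard
            tokens = [tok for src in range(num_ranks) for tok in pre[src][h]]
            local_heads.append(tokens)
        post.append(local_heads)
    return post

post = ulysses_all_to_all(pre, NUM_RANKS, NUM_HEADS)
print(post[0])   # rank 0: heads 0-1 with the full sequence
print(post[1])   # rank 1: heads 2-3 with the full sequence
```

A mirror-image all-to-all after attention restores the sequence-sharded layout, which is why SP pays roughly two A2As per attention layer.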
Overlap strategy in Mariana
LLM fused-op path:
- asynchronous A2A is launched
- Mariana does permutation post-processing after the handle completes
- this overlaps part of the communication with surrounding computation and staging
Source:
- mariana/mariana/models/multimodal/parallel/fsdp/fused_op/functional_ulysses.py:49-87
- mariana/mariana/models/multimodal/parallel/fsdp/fused_op/functional_ulysses.py:90-170
ViT path:
- there is explicit async support through AsyncAll2AllSingle
- but the code also contains a TODO noting that communication could be fused with QKV projections to overlap even better
Source:
- mariana/mariana/models/multimodal/parallel/common/encoder_seq_parallel.py:22-73
- mariana/mariana/models/multimodal/parallel/common/encoder_seq_parallel.py:203-204
VLM versus LLM
- VLM: used around visual encoder attention
- LLM: used in Ulysses/fused-op text paths and distributed-sequence-parallel big-op paths
Does FLUX support this?
Yes, strongly.
FLUX treats Ulysses/SP as a first-class communication family:
- PreAttnAllToAllTranspose
- PreAttnQKVPackAllToAll
- PostAttnAllToAllTranspose
- PostAttnAllToAllOnly
- AllToAllSingle
Source:
flux/include/flux/flux.h:466-478
Exported FLUX operators include:
- flux.GemmAllToAllTranspose
- flux.AllToAllTransposeGemm
- flux.All2AllSingle
- flux.AllToAllSingleGemm
Source:
- flux/python/flux/cpp_mod.py:162-181
- flux/src/pybind/flux_coll_op.cc:383-417
If you wanted more in FLUX
What FLUX still needs for cleaner Mariana integration is not raw operator support. It needs:
- a single Mariana-side SP adapter that can choose Torch A2A or FLUX A2A
- better support for ragged sequence layouts and irregular split metadata
- tighter integration with ViT-specific SP code paths
4.4 Context Parallel (CP)
Intuition
split the sequence positions themselves across ranks and run attention as a distributed computation over those context chunks
CP is for very long contexts where even SP is not enough or is not the right decomposition.
Computation pattern
- each rank holds only part of the full sequence positions
- attention or loss processing is distributed across those sequence chunks
- outputs, labels, and masks may need to be shuffled back into original order
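The shard-then-restore bookkeeping can be sketched in a toy, single-process form. The zig-zag chunk assignment below is a common CP trick for balancing causal-attention work and is an assumption for illustration; the point is that the index metadata saved at shard time is what lets outputs be shuffled back to original order, in the spirit of what a CP metadata manager tracks.

```python
# Toy context-parallel layout with stored reorder metadata.

def cp_shard(seq, num_ranks):
    # zig-zag: rank r gets chunk r and chunk (2*num_ranks - 1 - r),
    # pairing a cheap early chunk with an expensive late one
    chunk = len(seq) // (2 * num_ranks)
    shards, index_map = [], []
    for r in range(num_ranks):
        a, b = r, 2 * num_ranks - 1 - r
        idxs = list(range(a * chunk, (a + 1) * chunk)) + \
               list(range(b * chunk, (b + 1) * chunk))
        shards.append([seq[i] for i in idxs])
        index_map.append(idxs)      # metadata kept for the restore step
    return shards, index_map

def cp_restore(shards, index_map, seq_len):
    # shuffle per-rank outputs back into original sequence order
    out = [None] * seq_len
    for shard, idxs in zip(shards, index_map):
        for val, i in zip(shard, idxs):
            out[i] = val
    return out

seq = list(range(8))
shards, index_map = cp_shard(seq, num_ranks=2)
print(shards)                                    # [[0, 1, 6, 7], [2, 3, 4, 5]]
print(cp_restore(shards, index_map, len(seq)))   # back to [0..7]
```

The same metadata is needed for labels and masks, which is why CP managers carry layout state rather than just a group handle.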
Communication pattern
This is more complex than simple SP:
- shuffle from original order into distributed layout
- distributed attention or distributed output reconstruction
- shuffle back to original order
Mariana’s generic CP metadata container is:
mariana/mariana/models/utils/context_parallel_manager.py:17-177
Output reconstruction helpers include:
- get_cp_output(...)
- fsdp_cp_output_process(...)
Source:
mariana/mariana/models/utils/context_parallel_manager.py:178-245
Mariana implementation
Generic manager:
mariana/mariana/models/utils/context_parallel_manager.py:17-245
Task/model usage examples:
- mariana/tasks/gpt2/m12_pretrain/model/m12_megatron.py
- mariana/tasks/gpt2/m13_pretrain/model/m13_megatron.py
- mariana/tasks/gpt2/m14_pretrain/model/m14_megatron.py
These paths import dist_attn.context_parallel, which is the main CP compute/communication backend.
Compute kernels and communication methods
Communication and distributed-attention logic come mainly from:
- dist_attn.context_parallel
- Megatron CP process groups
The manager itself mostly stores the metadata needed by those kernels and reorder helpers.
Overlap strategy in Mariana
Compared with TP/SP/EP, Mariana exposes less CP-specific overlap policy at the trainer level.
The overlap is mostly pushed down into:
- the distributed attention backend
- ring or blockwise CP kernels
- output shuffle and rmpad routines
So the important practical point is:
CP overlap exists, but it is mostly backend-driven, not a visible top-level Mariana scheduling feature in the same way as FSDP prefetch or Flux fused TP kernels.
VLM versus LLM
- LLM: CP is most meaningful here, especially for long-context text paths
- VLM: possible in principle, but less central in the current generic VLM trainer path
Does FLUX support this?
Not as a first-class CP strategy.
FLUX has building blocks such as:
- all-to-all
- P2P transport
- shared-memory buffer bootstrap
But it does not provide:
- a CP metadata manager
- distributed attention kernels for CP
- CP-specific reorder / loss reconstruction APIs
If you wanted this in FLUX
The cleanest approach would be:
- define a `ContextParallelLayout` object
- expose CP process-group aware attention kernels
- provide shuffle/origin conversion helpers
- integrate with Mariana through `ContextParallelManager`
If the goal is only to accelerate the backend, it may be simpler to keep dist_attn as the CP engine and use FLUX only for specific transport pieces.
Attention parallelism by scale and input type
The most common beginner mistake is to think “attention parallelism” is one knob. In practice, Mariana uses different attention-parallel decompositions for different sequence scales and input structures.
As a rule of thumb:
- use TP first when the main problem is layer width
- use SP / Ulysses or DSP when the main problem is attention activation memory
- move to CP only when the problem is extreme context length or very irregular long-context handling
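The rule of thumb can be written down as a toy chooser; the thresholds below are illustrative assumptions, not Mariana defaults, and real decisions also depend on topology and memory budgets:

```python
def suggest_attention_parallelism(seq_len: int, hidden_size: int, backend: str) -> str:
    """Toy heuristic mirroring the rule of thumb above.
    The thresholds (16K, 128K tokens; 8192 hidden) are made-up
    illustration values, not tuned defaults."""
    if seq_len >= 128 * 1024:          # extreme context: split positions directly
        return "CP"
    if seq_len >= 16 * 1024:           # long context: pay two A2As around attention
        return "Ulysses SP" if backend == "fsdp" else "DSP"
    if hidden_size >= 8192:            # width-bound: shard QKV/MLP weights
        return "TP"
    return "DP only"

assert suggest_attention_parallelism(256 * 1024, 4096, "fsdp") == "CP"
assert suggest_attention_parallelism(32 * 1024, 4096, "megatron") == "DSP"
assert suggest_attention_parallelism(2048, 12288, "fsdp") == "TP"
```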
| Regime | Typical inputs | Main bottleneck | Usually best first move in Mariana | Why it helps | Main risk / implementation detail |
|---|---|---|---|---|---|
| short or moderate text with very wide hidden states | dense text, modest sequence lengths | layer width and per-layer parameter size | TP | splitting QKV and MLP weights is often cheaper than paying two A2As per attention layer | if TP spans slow links, frequent all-gather / reduce-scatter can dominate (mariana/megatron/core/tensor_parallel/layers.py:22-36, mariana/megatron/core/tensor_parallel/layers.py:973-1390) |
| long dense text | regular text batches around the long-sequence regime | attention activation memory | Ulysses SP on FSDP paths, DSP on Megatron paths | SP keeps the full sequence local only when needed and pays mostly two A2As around attention | short sequences may not amortize those A2As well (mariana/mariana/models/multimodal/parallel/fsdp/fused_op/functional_ulysses.py:49-170, mariana/mariana/utils/megatron/init.py:25-29) |
| extreme long context | very long text or packed long-context batches | total context length and distributed KV-style attention work | CP | CP splits sequence positions directly and pushes the complexity into the distributed-attention backend | highest layout complexity and weakest generic-trainer support (mariana/mariana/models/utils/context_parallel_manager.py:17-245, mariana/mariana/trainer/base_trainer.py:517-530) |
| vision patches / ViT attention | patch tokens, sometimes with variable image geometry | visual-attention memory and patch count | ViT SP adaptor | patch-aware SP is a natural fit for visual encoder attention | head divisibility or padding becomes a real implementation concern (mariana/mariana/models/multimodal/parallel/common/encoder_seq_parallel.py:206-233) |
| multimodal or ragged batches | audio+vision+text, `use_rmpad`, `dyn_bsz`, mixed sequence lengths | irregular layouts and metadata handling | backbone-dependent: SP/DSP for the backbone, parallel encoder for encoder imbalance, CP only when context forces it | keeps encoder balancing separate from attention splitting | metadata and padding logic become as important as the kernel itself (mariana/mariana/models/utils/context_parallel_manager.py:17-38, mariana/mariana/utils/megatron/init.py:144-156, mariana/tasks/seed2/megatron/ARCHITECTURE.md:360-383) |
Three implementation details matter a lot here:
1. **Ulysses and DSP are not the same code path.** Megatron DSP is enabled when `distributed_sequence_parallel_size > 1` in Megatron init (mariana/mariana/utils/megatron/init.py:25-29). Separately, `dist_seq_par_use_all2all` in `EnvConfig` means DSP may choose an all-to-all transport mode that looks Ulysses-like, but it is still a Megatron-path feature, not the same thing as FSDP `sp_size` (mariana/mariana/utils/env_config.py:250-251).
2. **Ragged inputs change the attention-parallel story.** `ContextParallelManager` stores `cu_seqlens`, `cu_seqlens_rmpad`, and `cu_seqlens_splited`, which is exactly the kind of metadata you need once batches are no longer regular (mariana/mariana/models/utils/context_parallel_manager.py:17-38). In Megatron init, `variable_seq_lengths` is tied to `use_rmpad` or `dyn_bsz`, but it is disabled when `overlap_p2p_comm` is on (mariana/mariana/utils/megatron/init.py:144-156). That is a concrete example of a real trade-off between irregular inputs and overlap policy.
3. **Vision attention has a different constraint surface from text attention.** In the fused LLM Ulysses path, the code asserts `num_head % sp_size == 0` (mariana/mariana/models/multimodal/parallel/fsdp/fused_op/functional_ulysses.py:64-76). In contrast, the ViT adaptor can optionally `allow_head_padding`, which means non-divisible head counts can still run by padding the head dimension before A2A (mariana/mariana/models/multimodal/parallel/common/encoder_seq_parallel.py:206-233).
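The divisibility contrast in that last point can be sketched as a tiny helper; this mirrors the described behavior but is not the in-tree code:

```python
def pad_heads_for_sp(num_head: int, sp_size: int, allow_head_padding: bool = True):
    """Return (effective_head_count, padding) for an SP head split.
    The strict branch mimics the fused-Ulysses assertion; the padded
    branch mimics an allow_head_padding-style adaptor."""
    if num_head % sp_size == 0:
        return num_head, 0
    if not allow_head_padding:
        raise ValueError(f"num_head={num_head} not divisible by sp_size={sp_size}")
    padded = ((num_head + sp_size - 1) // sp_size) * sp_size
    return padded, padded - num_head

assert pad_heads_for_sp(16, 4) == (16, 0)   # divisible: no padding needed
assert pad_heads_for_sp(14, 4) == (16, 2)   # pad 2 heads before the A2A
```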
4.5 Expert Parallel (EP / MoE)
Intuition
> EP is the most communication-sensitive strategy in MoE models.
Computation pattern
- each rank owns only a subset of experts
- tokens are routed to one or more experts
- each rank runs expert MLP compute only for the tokens it receives
- outputs are then returned and merged back
Communication pattern
This is not a simple all-reduce. EP needs:
- routing metadata
- per-expert or per-rank token counts
- dispatch layout
- token movement to expert owners
- combine movement back to the original token order
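The dispatch/combine round trip can be simulated in a single process; this toy sketch (not any Mariana backend) shows only the ordering problem, not the actual all-to-all kernels:

```python
def ep_dispatch_combine(tokens, expert_of, ep_size, experts_per_rank, expert_fn):
    """Route each token to the rank owning its expert, run the expert
    there, and merge outputs back into the original token order."""
    owner = [e // experts_per_rank for e in expert_of]
    # dispatch: bucket (original_index, token) by owning rank
    buckets = [[] for _ in range(ep_size)]
    for i, tok in enumerate(tokens):
        buckets[owner[i]].append((i, tok))
    # each "rank" computes only for the tokens it received; the saved
    # indices are what lets combine restore the original order
    out = [None] * len(tokens)
    for r in range(ep_size):
        for i, tok in buckets[r]:
            out[i] = expert_fn(expert_of[i], tok)
    return out

tokens = [1.0, 2.0, 3.0, 4.0]
expert_of = [0, 3, 1, 2]          # top-1 routing decisions per token
out = ep_dispatch_combine(tokens, expert_of, ep_size=2, experts_per_rank=2,
                          expert_fn=lambda e, t: t * (e + 1))
assert out == [1.0, 8.0, 6.0, 12.0]
```

The saved index list corresponds to the "routing metadata" and "dispatch layout" items above; real backends keep that state on GPU and overlap the token movement with expert GEMMs.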
Mariana uses several different EP backends and styles depending on model and path.
One concrete integration to remember is the GB/Blackwell m12 big-op path. `ParallelM12ForCausalLM` always initializes `FluxManager` first and then, when `use_blackwell` is true, switches on `megatron_config.ep_impl` to initialize either `DeepEPManager` (`hybrid_ep`) or `FlashCommManager` (`flash_comm`) (mariana/tasks/gpt2/m12_pretrain/model/m12_megatron.py:611-637). In other words, Mariana often uses FLUX as a broad substrate while HybridEP and FlashComm act as specialized EP backends.
Mariana implementation
FSDP2 MP mesh composition:
mariana/mariana/trainer/strategy/fsdp2/_mpmd_fsdp.py:83-119
FSDP/FSDP2 expert prefetch:
mariana/mariana/models/multimodal/parallel/fsdp/ep_prefetch.py:1-97
Seed-kernel EP and fused MoE registrations:
- mariana/mariana/trainer/strategy/fsdp2/_fsdp.py:330-346
- mariana/mariana/models/multimodal/parallel/common/sac_offload.py:526-529
Blackwell EP backend selection:
mariana/tasks/gpt2/m12_pretrain/model/m12_megatron.py:611-637
FLUX-backed EP manager:
mariana/tasks/gpt2/m12_pretrain/model/flux_manager.py:177-235
HybridEP-backed EP manager:
mariana/tasks/gpt2/m12_pretrain/model/deepep_manager.py:231-392
FlashComm-backed EP manager:
mariana/tasks/gpt2/m12_pretrain/model/flash_comm_manager.py:7-76
FlashComm call wrappers:
mariana/tasks/gpt2/m12_pretrain/model/bigop/helper_funcs.py:325-420
Text-model fused MoE entry points:
- mariana/mariana/models/text/m12/modeling_m12_dev.py:2764-2774
- mariana/mariana/models/text/m13/modeling_m13_dev.py:2778-2789
- mariana/mariana/models/text/m14/modeling_m14_dev.py:2582-2590
Compute kernels and communication methods
Mariana’s EP stack can call several families of kernels:
- `seed_fused_moe`
- `SeedFusedExpertParallelMoE`
- `RingExpertParallelMoE`
- `SeedChunkedOverlapFusedMoeExpertFunction`
- `HybridEPGPUFusedMoeFunction`
- `flash_comm.ep.EPKernels`
- FLUX operators such as `DisScatterForward`, `DisScatterBackward`, `All2AllSingle`, `GemmGroupedV3AGScatter`, `GemmGroupedV3GatherRS`
Relevant EP-oriented FLUX exports:
- flux/python/flux/cpp_mod.py:140-181
- flux/src/pybind/flux_coll_op.cc:83-333
On the GB m12 path, HybridEP enters Mariana through deep_ep.HybridEPBuffer plus dispatch_with_permute(...) and combine_with_unpermute(...) inside DeepEPManager (mariana/tasks/gpt2/m12_pretrain/model/deepep_manager.py:250-392). FlashComm enters through flash_comm.ep.EPKernels constructed in FlashCommManager (mariana/tasks/gpt2/m12_pretrain/model/flash_comm_manager.py:12-76).
Overlap strategy in Mariana
This is the richest overlap area in Mariana outside TP.
Mariana uses several kinds of EP overlap:
1. **Expert prefetch under FSDP/FSDP2.** The next experts can be unsharded or prepared before they are needed. Source: mariana/mariana/models/multimodal/parallel/fsdp/ep_prefetch.py:15-77
2. **Chunked overlap fused MoE.** Mariana explicitly registers `SeedChunkedOverlapFusedMoeExpertFunction`, meaning the EP backend can overlap chunks of communication and expert compute. Source: mariana/mariana/trainer/strategy/fsdp2/_fsdp.py:337-340
3. **FLUX dispatch/combine plus grouped GEMM.** FLUX lets Mariana keep token communication close to the expert GEMM path instead of treating them as fully separate layers.
4. **Specialized backend overlap.** `flash_comm` and `HybridEP` push even more of the dispatch/combine pipeline into specialized kernels. In practice, `DeepEPManager` creates a dedicated CUDA stream and exposes non-blocking dispatch/combine (mariana/tasks/gpt2/m12_pretrain/model/deepep_manager.py:246-247, mariana/tasks/gpt2/m12_pretrain/model/deepep_manager.py:290-342), while `FlashCommManager` can place `EPKernels` on either the default stream or a dedicated comm stream (mariana/tasks/gpt2/m12_pretrain/model/flash_comm_manager.py:60-69).
EP size versus topology
EP size is not just a model hyperparameter. It directly changes:
- how many experts each rank stores
- how many tokens each rank may receive
- whether the hot path stays in HBM, within one NVLink domain, across a whole node, or across nodes
Two first-order formulas from the in-tree implementations are useful:
- `num_local_experts = num_global_experts / ep_size` in Mariana's `DeepEPManager` (mariana/tasks/gpt2/m12_pretrain/model/deepep_manager.py:231-244)
- FlashComm sizes worst-case receive buffers as roughly `max_m * topk * capacity_coeff`, rounded up to a dispatch buffer size (triton-distributed-flashcomm/FlashComm/python/flash_comm/ep/ep_context.py:156-163)
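Plugging numbers into those two formulas makes the trade-off concrete; the capacity coefficient and alignment below are assumptions for illustration, not the in-tree constants:

```python
import math

def num_local_experts(num_global_experts: int, ep_size: int) -> int:
    # mirrors the DeepEPManager-style formula quoted above
    assert num_global_experts % ep_size == 0
    return num_global_experts // ep_size

def dispatch_buffer_tokens(max_m: int, topk: int,
                           capacity_coeff: float = 1.2, align: int = 128) -> int:
    # worst-case receive sizing in the FlashComm style: max_m * topk *
    # capacity_coeff, rounded up to an assumed buffer granularity
    worst = math.ceil(max_m * topk * capacity_coeff)
    return ((worst + align - 1) // align) * align

assert num_local_experts(256, ep_size=8) == 32
assert dispatch_buffer_tokens(4096, topk=2) == 9856
```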
So increasing EP size usually helps memory first, but it also tends to:
- reduce local expert count per rank
- reduce expert-local GEMM granularity
- increase token-routing fan-out and fan-in
That means increasing EP size keeps paying off only while the saved expert memory and compute outweigh the cost of smaller grouped GEMMs and larger token movement.
| EP scope | Main computation effect | Main communication effect | Usually good when | Main risk |
|---|---|---|---|---|
| EP = 1 | all experts are local; no dispatch/combine overhead; maximum expert memory on each rank | no EP communication | model is small enough or MoE is not the bottleneck | expert weights and optimizer state stay concentrated on one GPU |
| EP inside one NVLink domain | fewer local experts and lower per-rank expert-state memory; local grouped GEMMs become smaller | dispatch/combine stays on fast peer-visible intra-node transport | this is usually the first EP size increase to try | if tokens per expert become too small, grouped GEMM utilization drops |
| EP across a full node | local expert memory drops further and more experts fit globally | fan-out, barriers, and buffer pressure grow with rank count, but still stay node-local | one node has strong mutual connectivity and expert memory is the main problem | if the node is actually multiple connectivity domains, “full node EP” may already cross slower links |
| EP across nodes | smallest expert-state footprint per rank and largest global expert pool | network becomes dominant; hierarchy and overlap become mandatory | expert count or expert memory forces scale-out | flat token movement over RDMA can erase the compute win if routing is imbalanced |
HybridEP is the clearest in-tree example of topology-aware EP sizing. HybridEPBuffer detects accessible ranks, allows an override through NUM_OF_HYBRID_EP_RANKS_PER_NVLINK_DOMAIN, and then derives local_rank, node_rank, and num_of_nodes from that topology (hybrid-ep/deep_ep/hybrid_ep_buffer.py:61-90). That is the right mental model for choosing EP groups: start with one fast domain, then expand only when memory or expert count forces you to.
The backend itself also acknowledges that single-node and multi-node EP are different regimes. In Configurer, HybridEP defaults sms_dispatch and sms_combine to 24 for one node but only 8 for multi-node (hybrid-ep/csrc/hybrid_ep/config.cuh:314-316). That is a strong hint that once EP spans nodes, the bottleneck shifts away from “throw more SMs at dispatch” and toward topology-aware communication balance.
FlashComm makes the opposite trade-off explicit: it builds a strong single-node EP context around SymmetricTensor buffers, but asserts config.nnodes == 1 and leaves dispatch_internode / combine_internode unimplemented (triton-distributed-flashcomm/FlashComm/python/flash_comm/ep/ep_context.py:184-185, triton-distributed-flashcomm/FlashComm/python/flash_comm/ep/ep_kernels.py:361-367).
SOTA-style EP backend patterns you can repeat
The most advanced EP backends in this checkout all solve the same six subproblems, even though they package them differently:
- topology-aware group formation
- persistent peer-visible buffers
- GPU-side metadata/layout preprocessing
- dispatch/combine kernels with as little extra permute work as possible
- compute/communication overlap around expert GEMM
- hierarchical transport when EP spans nodes
| Design pattern | Why it matters | HybridEP | FlashComm | FLUX |
|---|---|---|---|---|
| topology-aware group sizing | keeps the hottest token traffic on the fastest links | explicit NVLink-domain detection and `node_rank` / `local_rank` (hybrid-ep/deep_ep/hybrid_ep_buffer.py:61-90) | assumes one node and stops there (triton-distributed-flashcomm/FlashComm/python/flash_comm/ep/ep_context.py:184-185) | operators accept local-world information, but there is no full EP topology runtime (mariana/tasks/gpt2/m12_pretrain/model/flux_manager.py:177-214) |
| persistent peer-visible buffers | avoids per-step allocation and enables direct remote access | `ExtendedMemoryAllocator` plus runtime buffers | `SymmetricTensor` / `EPContext` own the hot buffers (triton-distributed-flashcomm/FlashComm/python/flash_comm/ep/ep_context.py:152-215) | shared-memory / team bootstrap plus operator-owned comm buffers (flux/src/ths_op/flux_shm.cc:417-491) |
| GPU-side layout object | lets kernels consume routing metadata directly on GPU | explicit metadata preprocessing and handles (mariana/tasks/gpt2/m12_pretrain/model/deepep_manager.py:310-392) | explicit `EPCommLayoutDesc` and staged preprocessing (triton-distributed-flashcomm/FlashComm/python/flash_comm/ep/ep_kernels.py:369-389) | mostly operator-centric tensor arguments rather than one public EP context object (flux/src/pybind/flux_coll_op.cc:83-333) |
| minimal extra permutation | avoids extra global-memory passes | fused `dispatch_with_permute` / `combine_with_unpermute` (mariana/tasks/gpt2/m12_pretrain/model/deepep_manager.py:327-392) | staged `dispatch_postprocess` / `combine_preprocess` around `EPKernels` (triton-distributed-flashcomm/FlashComm/python/flash_comm/ep/ep_kernels.py:382-389) | primitives are available, but no single fused EP product shape like HybridEP (flux/python/flux/cpp_mod.py:140-181) |
| compute/communication overlap | hides dispatch/combine behind expert compute | dedicated stream + non-blocking dispatch/combine (mariana/tasks/gpt2/m12_pretrain/model/deepep_manager.py:246-247, mariana/tasks/gpt2/m12_pretrain/model/deepep_manager.py:290-342) | comm stream or default stream, plus GPU layout/transport pipeline (mariana/tasks/gpt2/m12_pretrain/model/flash_comm_manager.py:60-69) | grouped GEMM and dispatch/all-to-all operators are designed to sit close to compute (mariana/tasks/gpt2/m12_pretrain/model/flux_manager.py:177-235) |
| cross-node hierarchy | the only realistic way to scale EP beyond one fast domain | explicit NVLink + RDMA design | not implemented | possible to build from primitives, but not a first-class EP subsystem |
If you want to reproduce a strong EP backend from scratch, the safest recipe is:
- start with single-node EP inside one NVLink domain
- implement a persistent layout/context object and peer-visible buffers first
- keep routing metadata and permutation on GPU
- overlap token movement with grouped expert GEMM on separate streams or fused operators
- only after that add cross-node hierarchy with explicit topology detection
This is exactly why the three reference projects are complementary:
- FlashComm is the cleanest single-node EP teaching example
- HybridEP is the clearest topology-aware and multi-node EP reference
- FLUX is the clearest example of embedding EP primitives inside a broader fused communication/compute stack
VLM versus LLM
EP is important in both:
- VLM: often combined with FSDP2MP and multimodal routing
- LLM: often combined with Megatron or big-op text MoE paths
- Omni / multimodal audio-text: usually follows the Megatron LLM template but inherits extra encoder-side concerns from the VLM stack
- RM: inherits the EP story of the wrapped backbone; the reward head does not create a new EP pattern
Does FLUX support this?
Yes, but not completely enough to replace every EP backend Mariana uses.
What FLUX supports well:
- MoE dispatch/combine primitives: `flux.DisScatterForward`, `flux.DisScatterBackward`
- expert all-to-all: `flux.All2AllSingle`
- grouped-GEMM MoE variants: `flux.GemmGroupedV3AGScatter`, `flux.GemmGroupedV3GatherRS`
Sources:
- flux/python/flux/cpp_mod.py:140-181
- flux/src/pybind/flux_coll_op.cc:83-333
- flux/src/moe_ag_scatter/ths_op/gemm_grouped_v3_ag_scatter.cc
What FLUX still does not fully provide at the same level as specialized EP systems:
- one Mariana-wide EP abstraction
- a clear layout/handle object like FlashComm's `EPCommLayoutDesc`
- a production-oriented multi-node EP stack comparable to HybridEP
- fused permute/unpermute-style EP APIs
- the same kind of graph-friendly specialized EP runtime that HybridEP emphasizes
How FLUX, HybridEP, and FlashComm explain the EP design space
This is where flux, hybrid-ep, and triton-distributed-flashcomm/FlashComm are useful reference points.
HybridEP shows what a specialized production EP backend looks like:
- hierarchical NVLink + RDMA transport
- explicit buffer/config/runtime objects
- fused permute/unpermute
- runtime JIT specialization
FlashComm shows what a compact single-node EP backend looks like:
- symmetric-memory buffer model
- explicit `EPContext` and `EPKernels`
- GPU-side layout construction
- TMA / mbarrier dispatch-combine pipeline
FLUX shows what an operator-centric EP substrate looks like:
- `DisScatterForward` / `DisScatterBackward` / `All2AllSingle` as reusable primitives
- grouped-GEMM MoE operators that sit close to the dispatch path
- one shared stack that also serves TP and SP/Ulysses
- less EP-specific product shape than FlashComm or HybridEP
So, if FLUX needs stronger EP support, the two most realistic paths are:
1. **Borrow FlashComm's structure for single-node EP:**
   - add a first-class layout object
   - separate layout planning from transport kernels
   - make dispatch/combine handles explicit
2. **Borrow HybridEP's structure for large-scale EP:**
   - add topology-aware hierarchical transport
   - add runtime buffer/config abstractions
   - add graph-friendly non-blocking dispatch/combine
That leads to a very practical topology rule:
- if EP can stay inside one fast node-local domain, a FlashComm-style or FLUX-plus-layout design is often enough
- if EP must cross nodes, HybridEP-style hierarchy is the reference design to study
- if you want one shared stack across TP, SP, and EP, FLUX is the broadest substrate, but it still needs more EP-specific runtime structure as EP size grows
4.6 Pipeline Parallel (PP)
Intuition
> PP is about splitting the model by depth rather than by weights or tokens.
Computation pattern
- Each rank owns a stage of the model
- Activations move from one stage to the next in forward
- Gradients move back in the reverse direction
- Multiple microbatches are interleaved to keep stages busy
Communication pattern
The core communication is stage-to-stage:
- Send activations forward
- Send gradients backward
This is why PP is tightly connected to microbatch scheduling.
Mariana implementation
FSDP2 MP pipeline construction:
mariana/mariana/trainer/strategy/fsdp2/_mpmd_fsdp.py:206-285
This code uses:
- `PipelineStage`
- `pipeline_module_split`
- `Schedule1F1B`
Megatron strategy note:
- `MegatronStrategy.backward()` is effectively a no-op because forward/backward scheduling happens inside the training step
Source:
mariana/mariana/trainer/strategy/megatron.py:408-430
Compute kernels and communication methods
PP itself does not define new math kernels. It orchestrates when ordinary layer kernels run.
Communication methods are:
- PyTorch or Megatron stage-to-stage P2P
- VeScale pipeline stage wrappers
- optional FLUX shortcut / send-recv transport in some Mariana model paths
Overlap strategy in Mariana
The main overlap is the 1F1B schedule:
- One microbatch can be in forward on one stage while another is in backward on another stage
In other words, PP is itself an overlap strategy at the schedule level.
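A minimal sketch of the 1F1B ordering (illustrative, not the Mariana/VeScale scheduler) shows why one stage's forward can overlap another stage's backward:

```python
def one_f_one_b(stage: int, num_stages: int, num_microbatches: int):
    """Per-stage op order: warmup forwards, steady-state alternating
    forward/backward, then cooldown backwards."""
    warmup = min(num_stages - stage - 1, num_microbatches)
    steady = num_microbatches - warmup
    ops = [("F", m) for m in range(warmup)]
    for m in range(steady):
        ops.append(("F", warmup + m))   # forward of a later microbatch...
        ops.append(("B", m))            # ...interleaved with an earlier backward
    ops += [("B", steady + m) for m in range(warmup)]
    return ops

# the last stage has no warmup: it strictly alternates F and B
assert one_f_one_b(stage=1, num_stages=2, num_microbatches=2) == \
    [("F", 0), ("B", 0), ("F", 1), ("B", 1)]
# earlier stages run warmup forwards first before entering steady state
assert one_f_one_b(stage=0, num_stages=2, num_microbatches=2)[0] == ("F", 0)
```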
Mariana also has FLUX-backed transport helpers for shortcut or pipeline-adjacent movement:
- `flux.AsyncSendRecv` is constructed in mariana/tasks/gpt2/m12_pretrain/model/flux_manager.py:84-175
- Text model hooks also instantiate `AsyncSendRecv` in:
  - mariana/mariana/models/text/m12/pipeline_hooks_m12_dev.py
  - mariana/mariana/models/text/m13/pipeline_hooks_m13_dev.py
VLM versus LLM
- LLM: PP is common in Megatron-scale training
- VLM: possible and supported, but less universal than FSDP2 + SP + EP
Does FLUX support this?
Partially.
FLUX supports the transport primitive:
- `flux.AsyncSendRecv`
Source:
- flux/python/flux/cpp_mod.py:172-181
- flux/src/pybind/flux_coll_op.cc:421-467
But FLUX does not provide the whole PP runtime:
- Microbatch schedule
- Stage ownership
- Dependency graph
- Gradient timing rules
If you wanted this in FLUX
The clean split would be:
- Keep the PP scheduler in Mariana / Megatron / VeScale
- Let FLUX provide a standardized stage-transport backend
That is much cheaper and safer than trying to turn FLUX into a full PP framework.
4.7 Parallel Encoder (VLM / Omni Encoders)
Intuition
> This is not a generic LLM strategy. It is a multimodal load-balancing strategy for vision/audio encoders.
Computation pattern
- Different ranks can be assigned different amounts of vision/audio work
- The assignment is based on estimated FLOPs
- only the selected ranks run each encoder’s heavy compute
Communication pattern
The parallel encoder first needs a plan:
- gather FLOP or sequence-length information across ranks
Then it redistributes encoder inputs:
- all-to-all input movement
- gather/rebuild of final embeddings
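A hypothetical planner core can be sketched as a greedy longest-processing-time assignment; the real planner's FLOP model and the cross-rank gather step are not shown:

```python
import heapq

def balance_encoder_work(flops_per_item, num_ranks):
    """Assign items (e.g. images) to ranks, heaviest first, always onto
    the currently least-loaded rank. Returns a dict item -> rank."""
    heap = [(0.0, r) for r in range(num_ranks)]   # (accumulated load, rank)
    heapq.heapify(heap)
    assignment = {}
    for item in sorted(range(len(flops_per_item)),
                       key=lambda i: -flops_per_item[i]):
        load, rank = heapq.heappop(heap)
        assignment[item] = rank
        heapq.heappush(heap, (load + flops_per_item[item], rank))
    return assignment

flops = [9.0, 1.0, 4.0, 4.0]
plan = balance_encoder_work(flops, num_ranks=2)
# the heaviest item ends up alone; the lighter ones share the other rank
assert plan[0] != plan[2] and plan[1] == plan[2] == plan[3]
```

After such a plan, the all-to-all step moves each item's inputs to its assigned rank, and the final embeddings are gathered back.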
Mariana implementation
Planner and configuration:
mariana/mariana/models/multimodal/parallel/fsdp/parallel_encoder.py:37-255
Factory:
mariana/mariana/models/multimodal/parallel/fsdp/parallel_encoder.py:943-956
Model integration examples:
- mariana/mariana/models/omni/omniv1/modeling_omniv1.py
- mariana/mariana/models/multimodal/allinone/modeling_allinone.py
- mariana/tasks/seed2/train_mm.py:364-370
Compute kernels and communication methods
Communication side includes:
- `all_gather_primitives(...)`
- `all_to_all_single(...)`
Source:
- mariana/mariana/models/multimodal/parallel/fsdp/parallel_encoder.py:22
- mariana/mariana/models/multimodal/parallel/fsdp/parallel_encoder.py:186-255
The compute kernels are the normal modality-encoder kernels:
- ViT kernels for vision
- audio encoder kernels for audio
Overlap strategy in Mariana
This strategy is more about load balancing than about explicit low-level comm/compute overlap.
The major gain is:
- use communication to put work on the right ranks
- reduce total imbalance and idle time
So the main overlap effect is indirect: better balancing keeps more ranks busy.
VLM versus LLM
- VLM: important
- MegatronOmni / other multimodal encoder-heavy paths: also important when vision/audio encoders need balancing or redistribution
- LLM: not applicable in the same form
Does FLUX support this?
Not directly.
FLUX does not expose a ready-made “parallel encoder” abstraction.
However, FLUX could help with parts of it:
- `All2AllSingle` for redistribution
- `AllToAllSingleGemm` if a fused encoder-projection pattern becomes attractive
Source:
flux/python/flux/cpp_mod.py:177-181
If you wanted this in FLUX
The best design is not to move the whole strategy into FLUX. Instead:
- keep planning and FLOP accounting in Mariana
- use FLUX only as an optional high-performance redistribution backend
- expose a thin adapter for the specific all-to-all data motion
4.8 Megatron And FSDP2MP Are Composition Frameworks
It is worth stating this explicitly because many readers mix them up with “parallel dimensions.”
Megatron
Megatron is the LLM-oriented composition framework that typically manages:
- TP
- PP
- DP
- optional CP
- optional EP
Source:
mariana/mariana/trainer/strategy/megatron.py:318-430
FSDP2MP
FSDP2MP is Mariana’s VeScale-oriented composition path that combines:
- FSDP
- EP / OE
- optional PP
- module-parallel plans
Source:
- mariana/mariana/trainer/strategy/fsdp2/_mpmd_fsdp.py:32-120
- mariana/mariana/trainer/strategy/fsdp2/_mpmd_fsdp.py:206-345
This is why the right question is often not:
“Does Mariana use Megatron or EP?”
but rather:
“Which framework composes which dimensions for this model family?”
5. FLUX Support Matrix
The table below compresses the previous sections into a planning-friendly view.
| Parallelism area | FLUX support today | How FLUX supports it | Biggest missing piece |
|---|---|---|---|
| DP / DDP | Weak | utility ops like `InplaceCast` in mariana/megatron/dp_communication.py:27-96 | no DP runtime or fused gradient-sync framework |
| FSDP / HSDP / LSDP | Weak | not a first-class feature | no param sharding / all-gather / grad-RS state machine |
| TP | Strong | `AGKernel`, `GemmRS`, `AllGatherOp`, inter-node variants | broader wrapper coverage across all Mariana TP call sites |
| SP / Ulysses | Strong | `GemmAllToAllTranspose`, `AllToAllTransposeGemm`, `All2AllSingle`, `AllToAllSingleGemm` | unified adapter for irregular or modality-specific SP cases |
| CP | Weak | only low-level building blocks | no CP layout, distributed attention, or reorder runtime |
| EP | Medium to strong | `DisScatterForward`, `DisScatterBackward`, `All2AllSingle`, grouped-GEMM MoE ops | no single Mariana-wide EP runtime abstraction; weaker than specialized EP stacks |
| PP | Partial | `AsyncSendRecv` transport | no pipeline schedule / stage engine |
| Parallel encoder | None | could reuse A2A ops | no planner / modality-aware strategy |
6. Representative Workload Defaults In Mariana
There is not one global JSON file that defines Mariana parallelism. In practice, the defaults live in two places:
- Python config classes such as `MegatronConfig`
- Hydra YAML templates under `tasks/*/conf` or `tasks/*/configs`
Read these as representative templates, not universal production values. Many real jobs override them heavily.
6.1 Conservative Packaged Defaults
- mariana/mariana/models/text/config.py:1141-1205: the base `MegatronConfig` constructor defaults are intentionally conservative: `tensor_parallel_size=1`, `pipeline_parallel_size=1`, `context_parallel_size=1`, `expert_parallel_size=1`, `expert_parallel_size_in_dp=1`, `num_layers_per_virtual_pipeline_stage=0`, `micro_batch_size=1`, `sequence_parallel=True`, `sequence_data_parallel_size=-1`, and `distributed_sequence_parallel_size=-1`. This is the cleanest "framework default" starting point for Megatron-style text backbones.
- mariana/tasks/gpt2/m12_pretrain/conf/default.yaml:497-582: a representative Megatron LLM YAML also starts with `megatron_tensor_parallel_size: 1`, `megatron_pipeline_parallel_size: 1`, `megatron_context_parallel_size: 1`, `megatron_expert_parallel_size: 1`, and `megatron_expert_parallel_size_in_dp: 1`, while still enabling many engineering optimizations such as `megatron_overlap_p2p_comm: true`, `megatron_overlap_dp_grad_comm: true`, `megatron_overlap_dp_param_comm: true`, `megatron_early_prefetch_dp_allgather: true`, and `megatron_ep_impl: "flux"`. In other words, the default sizes are conservative, but overlap defaults may already be aggressive.
- mariana/tasks/gpt2/unsup/conf/fsdp2_default.yaml:124-130: a representative text FSDP2 template uses `pp_size: 1`, `fsdp_size: -1`, `ep_size: 1`, `oe_size: 1`, and `sp_size: 1`. This is Mariana's common "no extra split unless you opt in" FSDP-family baseline.
- mariana/tasks/multimodal/configs/model/default.yaml:11-14, mariana/tasks/multimodal/configs/model/default.yaml:44, mariana/tasks/multimodal/configs/model/default.yaml:79: a representative multimodal model default keeps `parallel_encoder.enable: false`, `network.context_parallel_size: -1`, and `vit.sp_size: 1`. This means the generic model config does not assume parallel encoder or SP are enabled until a job opts in.
- mariana/tasks/multimodal/configs/trainer/lsdp_m12dev2_default.yaml:38-43: a representative multimodal trainer default uses `fsdp_size: -1`, `ep_size: 1`, `oe_size: 1`, and `sp_size: 1`. This matches the general Mariana pattern of conservative packaged defaults plus task-specific overrides later.
- mariana/tasks/rm/engine/fsdp/conf/fsdp2_default.yaml:89-97: reward-model FSDP defaults again use `pp_size: 1`, `fsdp_size: -1`, `ep_size: 1`, `oe_size: 1`, and `sp_size: 1`, which reinforces that RM usually inherits the backbone-oriented FSDP baseline rather than introducing a unique distributed recipe.
6.2 What 1 And -1 Usually Mean
- `1` usually means no extra split along that dimension. Examples: `tensor_parallel_size=1`, `pipeline_parallel_size=1`, `expert_parallel_size=1`, `sp_size=1`.
- `-1` usually means unset, disabled, or "let outer logic decide" for that specific field, but it is not globally standardized. The clearest in-tree example is `fsdp_size: -1  # default -1 means no use of hsdp` in mariana/tasks/gpt2/unsup/conf/fsdp2_default.yaml:127-130.
- Because of that, do not assume all `-1` values mean the same thing. Read the inline comment or the consuming code for that field.
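A hypothetical resolver shows how the `-1` convention tends to be consumed for the `fsdp_size` case; for any other field, check its own consuming code:

```python
def resolve_fsdp_size(fsdp_size: int, world_size: int) -> int:
    """Illustrative sketch: -1 here means 'no HSDP', i.e. shard over the
    full world. An explicit value selects an HSDP shard-group size that
    must divide world_size. Not the in-tree resolution logic."""
    if fsdp_size == -1:
        return world_size
    assert world_size % fsdp_size == 0, "shard group must divide world size"
    return fsdp_size

assert resolve_fsdp_size(-1, 64) == 64   # flat FSDP over all ranks
assert resolve_fsdp_size(8, 64) == 8     # HSDP: shard within groups of 8
```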
6.3 Real Tuned Jobs Often Look Very Different
- mariana/tasks/multimodal/configs/ci/uniformloader_seed_m12v2_2b5_32k_muon_omni_runner_ci.yaml:181-190: a real multimodal/omni-style job example uses `oe_size: 8`, `ep_size: 8`, `sp_size: 4`, and `fsdp_size: -1`, plus activation offload. This is very different from the conservative `1`/`-1` packaged defaults.
- mariana/tasks/multimodal/configs/megatron_vlm/algorithmic/per_llm/m12/20b_600b_moe.yaml:40-47: a scaled Megatron VLM/MoE example uses `megatron_pipeline_parallel_size: 12`, `megatron_expert_parallel_size_in_dp: 16`, and `megatron_num_layers_per_virtual_pipeline_stage: 1`. This shows that large production-style jobs often lean hard into PP and EP-in-DP even when the framework default starts at `1`.
The main practical conclusion is:
- framework defaults in Mariana are usually conservative and readable
- real tuned workloads are usually compositions of several dimensions
- the best config to study depends on whether you want to understand the framework's baseline or a scaled production recipe
7. Practical Recommendations
If your real goal is to understand “what to use where,” this is the simplest summary.
For VLM training
Start with this mental picture:
- use FSDP2 as the base memory-scaling strategy
- add SP around attention-heavy paths
- add EP if the model uses MoE
- add parallel encoder only for multimodal encoder load balancing
- use PP only when model depth or memory pressure justifies it
For LLM training
Think in two families:
- Megatron family
  - best when TP + PP + CP style scaling is the center of gravity
- FSDP family
  - best when FSDP-style memory sharding is the center of gravity and TP is lighter or more selective
For other Mariana-supported families
Use the backbone-first rule:
- MegatronOmni / audio+vision+text
  - start from the Megatron family mental model (TP + PP + DP, plus optional DSP and CP), then treat modality encoders as extra multimodal components rather than a new global strategy (mariana/tasks/seed2/megatron/ARCHITECTURE.md:1-3, mariana/tasks/seed2/megatron/ARCHITECTURE.md:372-383)
- Omni audio/speech or diffusion modules
  - usually inherit the omni/VLM composition and then rely on modality-specific fused kernels such as the USM fused frontend (mariana/mariana/models/omni/acceleration/usm/fused_usm.py:16-111) rather than a brand-new distributed dimension
- Reward models
  - decide the backbone strategy first; RM layers wrap either text or VLM backbones (mariana/tasks/rm/engine/megatron/rm_model.py:509-560), so DP/FSDP/Megatron choices still follow the underlying model more than the reward head
For FLUX planning
Use FLUX where it is already naturally strong:
- TP fusion
- SP all-to-all + compute fusion
- selected EP token-routing and grouped-GEMM paths
- transport helpers for PP-adjacent movement
Do not assume FLUX should be forced into every distributed strategy. For FSDP and CP in particular, the better near-term plan is usually:
- keep the high-level strategy in Mariana / VeScale / Megatron
- use FLUX only where it can replace a well-defined hot collective or fused comm+compute block
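That guidance can be codified as a simple routing policy: send the hot collectives FLUX is strong at through the fused backend and keep everything else on the native stack. The pattern names and the helper below are a hypothetical sketch of that policy, not an actual Mariana or FLUX API:

```python
# Hypothetical policy sketch: which communication patterns to route
# through a fused FLUX-style backend versus the native stack.
# All names here are illustrative only.

FLUX_STRONG = {
    "tp_allgather_gemm",      # TP fusion (ag_gemm-style)
    "tp_gemm_reducescatter",  # TP fusion (gemm_rs-style)
    "sp_all_to_all",          # SP all-to-all + compute fusion
    "ep_dispatch",            # EP token routing / grouped GEMM
    "ep_combine",
}


def choose_backend(pattern: str) -> str:
    """Route well-defined hot collectives to the fused backend;
    keep FSDP, CP, and everything else on the native stack."""
    return "flux" if pattern in FLUX_STRONG else "native"


print(choose_backend("ep_dispatch"))     # -> flux
print(choose_backend("fsdp_allgather"))  # -> native
```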
8. Key Source Files To Read Next
If you want the shortest path from this doc to the code, read these in order:
- Trainer and mesh setup
  - `mariana/mariana/trainer/base_trainer.py:460-557`
  - `mariana/mariana/distributed/parallel_state.py:45-232`
- FSDP family
  - `mariana/mariana/trainer/strategy/fsdp2/strategy.py:37-195`
  - `mariana/mariana/trainer/strategy/fsdp2/_fsdp.py:551-589`
  - `mariana/mariana/trainer/strategy/fsdp2_torch.py:374-518`
  - `mariana/mariana/trainer/strategy/lsdp.py:79-255`
- Megatron family
  - `mariana/mariana/trainer/strategy/megatron.py:318-430`
  - `mariana/megatron/core/tensor_parallel/layers.py:973-1390`
- Sequence and context parallel
  - `mariana/mariana/models/multimodal/parallel/common/encoder_seq_parallel.py:22-97`
  - `mariana/mariana/models/multimodal/parallel/fsdp/fused_op/functional_ulysses.py:49-170`
  - `mariana/mariana/models/utils/context_parallel_manager.py:17-245`
- EP backends
  - `mariana/mariana/models/multimodal/parallel/fsdp/ep_prefetch.py:1-97`
  - `mariana/tasks/gpt2/m12_pretrain/model/m12_megatron.py:611-637`
  - `mariana/tasks/gpt2/m12_pretrain/model/flux_manager.py:177-235`
  - `mariana/tasks/gpt2/m12_pretrain/model/deepep_manager.py:231-392`
  - `mariana/tasks/gpt2/m12_pretrain/model/flash_comm_manager.py:7-76`
  - `mariana/mariana/models/multimodal/parallel/common/sac_offload.py:526-529`
- FLUX side
  - `flux/include/flux/flux.h:466-478`
  - `flux/python/flux/cpp_mod.py:72-181`
  - `flux/src/ths_op/flux_shm.cc:418-776`
  - `flux/src/pybind/flux_coll_op.cc:83-467`
  - `flux/src/pybind/gemm_rs.cc:28-102`
  - `flux/src/pybind/ag_gemm.cc:34-149`
9. Bottom Line
The most important takeaway is:
- Mariana does not have one single “distributed parallelism strategy”
- It has a toolbox of dimensions and composition frameworks
In practice:
- VLM is usually centered on `FSDP2 + SP + EP + parallel_encoder`
- LLM is usually centered on `Megatron (TP + PP + DP + optional CP/EP)` or a newer `FSDP2`/`FSDP2Torch` stack
- MegatronOmni and related multimodal audio-text stacks usually start from the Megatron template and then add replicated or redistributed encoders plus DSP/CP where needed
- RM usually reuses the wrapped text or VLM backbone strategy rather than introducing a new parallel dimension
- HybridEP / FlashComm matter mainly as specialized EP backends on selected MoE paths, especially GB/Blackwell, not as whole-trainer strategies
- default YAML/config-class values are usually conservative baselines; real tuned workloads often enable several dimensions together
- FLUX is best understood as a high-performance communication/fusion backend that accelerates selected dimensions, not as a replacement for the whole trainer
If your next goal is to optimize or merge communication backends, the most profitable areas are usually:
- TP fusion
- SP all-to-all fusion
- EP dispatch/combine
Those are the places where Mariana, FLUX, HybridEP, and FlashComm are closest in spirit, and where backend choices most directly affect end-to-end training speed.