Serving Nemotron-Super-120B with a 1M token context on a 2-node DGX Spark cluster


This is a build log. We had two NVIDIA DGX Spark workstations (GB10 / SM121, 128 GB unified memory each), 200 GbE ConnectX-7 NICs, and the goal of serving NVIDIA's Nemotron-3-Super-120B-A12B-NVFP4 with the model's full 1 million token context. The path there crossed several traps that aren't documented in any one place: a missing Ray binary in the latest NGC vLLM image, environment-variable propagation quirks across nodes, host-memory starvation that can knock a node off the network unless you keep strict host-side memory discipline, and a handful of vllm serve flags that move between releases.

The full repository is at https://github.com/TechPreacher/dgx-spark-vllm-cluster. Below is the narrative.

Hardware and topology

Each Spark has a Grace-Blackwell SoC with native FP4 tensor cores (SM121) and 128 GB of host-GPU unified memory. The two boxes are linked by two ConnectX-7 dual-port NICs on each side: four 200 GbE ports per node, ~800 Gb/s of aggregate data-plane bandwidth. The interfaces show up under predictable RoCE names (rocep1s0f0, rocep1s0f1, roceP2p1s0f0, roceP2p1s0f1) alongside the standard netdev names (enp1s0f0np0, etc.).

We split traffic deliberately:

  • Control plane — one of the four interfaces (enp1s0f1np1 on this hardware) carries Ray's GCS, PyTorch tensor-parallel rendezvous, and anything that needs a single IP per node.
  • Data plane — all four interfaces are exposed to NCCL via NCCL_IB_HCA and to UCX via UCX_NET_DEVICES. NCCL talks RoCE directly to the four HCAs; the netdev names are also exported to NCCL_SOCKET_IFNAME, GLOO_SOCKET_IFNAME, and OMPI_MCA_btl_tcp_if_include as a TCP fallback path.
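The sketch below shows one way to build those variables from a single list of device names. It uses the four RoCE names above. The `join_by` helper and the `DATA_NETDEVS` variable are hypothetical, and so is the `:1` port suffix for UCX. Check `ibv_devinfo` for the active port on your hardware. None of this is code from the repo.

```shell
#!/usr/bin/env bash
# Sketch: derive the data-plane variables from one list of device names.
# DATA_HCAS holds the four RoCE names from this post; DATA_NETDEVS is a
# hypothetical, space-separated list of the matching netdev names that the
# operator supplies.

# join_by SEP ITEM...  prints the items joined by SEP.
join_by() {
  local sep="$1" out="" item
  shift
  for item in "$@"; do
    out+="${out:+$sep}${item}"
  done
  printf '%s' "$out"
}

DATA_HCAS=(rocep1s0f0 rocep1s0f1 roceP2p1s0f0 roceP2p1s0f1)

# NCCL_IB_HCA: comma-separated RDMA device names.
NCCL_IB_HCA=$(join_by , "${DATA_HCAS[@]}")
export NCCL_IB_HCA

# UCX_NET_DEVICES: RDMA devices are written <device>:<port>; port 1 is assumed.
ucx_devs=()
for hca in "${DATA_HCAS[@]}"; do
  ucx_devs+=("${hca}:1")
done
UCX_NET_DEVICES=$(join_by , "${ucx_devs[@]}")
export UCX_NET_DEVICES

# TCP fallback: one netdev list feeds the NCCL, Gloo and Open MPI knobs.
if [[ -n "${DATA_NETDEVS:-}" ]]; then
  read -r -a netdevs <<< "$DATA_NETDEVS"
  fallback=$(join_by , "${netdevs[@]}")
  export NCCL_SOCKET_IFNAME="$fallback" GLOO_SOCKET_IFNAME="$fallback"
  export OMPI_MCA_btl_tcp_if_include="$fallback"
fi
```

Keeping one source list means the HCA set and the TCP fallback set can't drift apart when a port is added or removed.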

A subtle correctness detail that bit us early: the per-node bring-up script must pass --device=/dev/infiniband --cap-add=IPC_LOCK --ulimit memlock=-1:-1 to docker run. The head-side copy of run_cluster.sh had these flags; the worker-side copy did not. NCCL fell back to TCP on the worker without complaining, and we lost the ~800 Gb/s RoCE data plane until we synced the two copies (they are now kept byte-identical and verified at every bring-up).
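A bring-up preflight along these lines would catch both problems before NCCL ever runs. This is a sketch, not the repo's own check. `rdma_ready` and `scripts_identical` are hypothetical helper names. The memlock test assumes a container started with `--ulimit memlock=-1:-1`, for which `ulimit -l` reports `unlimited`.

```shell
#!/usr/bin/env bash
# Hypothetical preflight helpers: rdma_ready runs inside the Ray container,
# scripts_identical runs on the host before bring-up.

# rdma_ready [DEV_DIR] [MEMLOCK]  succeeds when the RDMA device directory
# exists and memlock is unlimited. Both arguments default to the live values
# and exist only so the check can be exercised in isolation.
rdma_ready() {
  local dev_dir="${1:-/dev/infiniband}"
  local memlock="${2:-$(ulimit -l)}"
  if [[ ! -d "$dev_dir" ]]; then
    echo "missing ${dev_dir}: was --device=/dev/infiniband passed?" >&2
    return 1
  fi
  if [[ "$memlock" != "unlimited" ]]; then
    echo "memlock is ${memlock}: was --ulimit memlock=-1:-1 passed?" >&2
    return 1
  fi
}

# scripts_identical A B  succeeds when the two files are byte-identical.
scripts_identical() {
  cmp -s -- "$1" "$2"
}
```

On the head, `scripts_identical cluster/head/run_cluster.sh cluster/worker/run_cluster.sh` checks the repo copies. Another quick confirmation is to run the first collective with `NCCL_DEBUG=INFO` and look for `NET/IB` rather than `NET/Socket` in the log.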

Ray topology

We run two Ray nodes (head, worker), with vLLM scheduling tensor-parallel shards across them. Both bring-up scripts call into run_cluster.sh (one copy each under cluster/head and cluster/worker). That script wraps docker run with the right flags and starts ray start --block inside the container.

# Node 1 (head)
cd cluster/head && bash run_headnode_2.sh

# Node 2 (worker)
cd cluster/worker && bash run_workernode_2.sh

Each script blocks on the container's foreground ray start, so closing the terminal tears the cluster down. The model launcher is a separate process that uses docker exec to step inside the head container and run vllm serve there. Ray schedules TP rank 1 onto the worker, and the two ranks then exchange tensor-parallel traffic over the data plane.
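Before running the launcher, it helps to confirm that the worker actually joined. This is a minimal sketch. It assumes the head container's Python can `import ray` and that each `ray.nodes()` entry carries an `Alive` flag. `count_alive_nodes` and `wait_for_nodes` are hypothetical helpers, not part of the repo.

```shell
#!/usr/bin/env bash
# Hypothetical helpers: wait until Ray reports the expected number of live
# nodes before launching vllm serve.

# count_alive_nodes CONTAINER  prints the live Ray node count seen from inside.
count_alive_nodes() {
  docker exec "$1" python3 -c \
    'import ray; ray.init(address="auto"); print(sum(1 for n in ray.nodes() if n["Alive"]))'
}

# wait_for_nodes EXPECTED TIMEOUT_S CMD...  polls CMD every 2 s until its last
# line of output equals EXPECTED, or gives up after TIMEOUT_S seconds.
wait_for_nodes() {
  local expected="$1" timeout="$2"
  shift 2
  local waited=0 seen=""
  while (( waited < timeout )); do
    seen=$("$@" 2>/dev/null | tail -n1)
    if [[ "$seen" == "$expected" ]]; then
      return 0
    fi
    sleep 2
    (( waited += 2 ))
  done
  echo "only ${seen:-0}/${expected} Ray nodes after ${timeout}s" >&2
  return 1
}
```

To use it, source the repo's `lib.sh` for `find_ray_container`. Then run `wait_for_nodes 2 120 count_alive_nodes "$(find_ray_container)"` on the head before starting the launcher.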

Picking the model

We had Qwen3 FP8 variants (30B-A3B-Thinking, 122B-A10B) serving cleanly via the same Ray topology, but neither uses the SM121 FP4 tensor cores. For Nemotron-3-Super, NVIDIA ships an NVFP4-quantized checkpoint specifically targeted at this generation of hardware. The model itself is a LatentMoE hybrid: Mamba-2 state-space layers interleaved with full attention layers, with a sparse MoE on top. It has 120 B total parameters, 12 B active per token. The Mamba layers carry no KV cache (just a fixed-size SSM state). That is why a 1 M token context is physically tractable on two 128 GB machines: only the attention layers' KV cache grows linearly with sequence length.
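The back-of-envelope math behind that claim fits in one formula. For one full-length sequence, only the full-attention layers contribute KV bytes. Each stores one K and one V vector per KV head per token, and tensor parallelism splits the KV heads across ranks. The sketch below assumes the KV heads divide evenly by TP, with no replication. The layer, head and dimension values in the example are placeholders, not Nemotron-3-Super's published config; read the real ones from the checkpoint's `config.json`.

```shell
#!/usr/bin/env bash
# Sketch: KV-cache bytes per node needed to hold one sequence of TOKENS,
# counting full-attention layers only. Mamba-2 SSM state is omitted because
# it does not grow with sequence length.

# kv_bytes_per_node TOKENS ATTN_LAYERS KV_HEADS HEAD_DIM BYTES_PER_ELEM TP
kv_bytes_per_node() {
  local tokens="$1" layers="$2" kv_heads="$3" head_dim="$4" elem="$5" tp="$6"
  # Factor 2: one K and one V vector per KV head per token.
  echo $(( 2 * layers * kv_heads * head_dim * elem * tokens / tp ))
}

# to_gib BYTES  prints GiB with one decimal, using integer math only.
to_gib() {
  local tenths=$(( $1 * 10 / 1073741824 ))
  echo "$(( tenths / 10 )).$(( tenths % 10 ))"
}
```

Try it with purely illustrative values: 8 attention layers, 2 KV heads, head_dim 128, FP8 KV (1 byte) and TP=2. Then `to_gib "$(kv_bytes_per_node 1048576 8 2 128 1 2)"` prints 2.0, and the same call with 524288 tokens prints 1.0. Doubling the context doubles only this term; the weights and the Mamba state stay flat.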

The first obstacle: NGC dropped Ray from vLLM

NVIDIA's NGC publishes a vLLM container roughly monthly. We were running nvcr.io/nvidia/vllm:25.11-py3 for the Qwen path because that's what we had pinned when the cluster was first built. The HF model card for Nemotron-3-Super recommends vllm/vllm-openai:v0.20.0 (or newer), and 25.11 was too old to recognize super_v3 as a reasoning parser, didn't have --async-scheduling, didn't accept --mamba-ssm-cache-dtype float16, and didn't expose the --reasoning-parser-plugin flag we'd have needed to side-load the parser.

So we moved to the latest NGC vLLM build at the time, nvcr.io/nvidia/vllm:26.05.post1-py3. The head container started and then immediately printed:

/bin/bash: line 1: ray: command not found

The image had no ray in $PATH. We checked further:

docker run --rm --entrypoint /bin/bash nvcr.io/nvidia/vllm:26.05.post1-py3 -c \
  'which ray; python -c "import ray"; pip show ray'
# ray binary not in PATH
# ModuleNotFoundError: No module named 'ray'
# WARNING: Package(s) not found: ray

NGC had removed Ray entirely from this build. Not a PATH issue — ray is genuinely absent. vLLM upstream installs Ray as a transitive dependency, so this looks intentional on NGC's part (smaller image, fewer CVEs to scan).

The fastest fix that keeps everything else about the cluster unchanged: layer Ray back on top of the NGC base in a thin local image. We added cluster/Dockerfile:

ARG BASE_IMAGE=nvcr.io/nvidia/vllm:26.05.post1-py3
FROM ${BASE_IMAGE}

# Restore Ray. The NGC vllm:26.05 images dropped it (verified via `pip show ray`).
# `ray[default]` pulls the dashboard + observability extras the CLI expects.
RUN pip install --no-cache-dir "ray[default]" \
 && ray --version

And a one-line builder:

#!/usr/bin/env bash
set -euo pipefail
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
BASE_IMAGE="${BASE_IMAGE:-nvcr.io/nvidia/vllm:26.05.post1-py3}"
TAG="${TAG:-local/vllm-ray:26.05.post1}"
docker build --build-arg BASE_IMAGE="${BASE_IMAGE}" -t "${TAG}" "${SCRIPT_DIR}"

The bring-up scripts default VLLM_IMAGE to local/vllm-ray:26.05.post1. We build the image once on each node:

bash cluster/build-image.sh    # on Node 1
bash cluster/build-image.sh    # on Node 2

After that, the standard bring-up flow worked. About a 50 MB delta layered on a 9 GB base.

The second obstacle: Ray doesn't propagate os.environ across nodes

Nemotron-3-Super requires four vLLM runtime environment variables to make its kernels and collectives pick consistent code paths on DGX Spark:

VLLM_NVFP4_GEMM_BACKEND=marlin
VLLM_FLASHINFER_ALLREDUCE_BACKEND=trtllm
VLLM_USE_FLASHINFER_MOE_FP4=0
VLLM_ALLOW_LONG_MAX_MODEL_LEN=1

The first three drive kernel selection inside model code (FP4 GEMM backend, the inter-rank all-reduce backend, and disabling a known-buggy FlashInfer MoE path on this hardware). The fourth lifts vLLM's sanity check that refuses to honor --max-model-len past the model's declared limit.

Our first instinct was to set them via docker exec -e VAR=val from the launcher, which is how the Qwen launchers pass VLLM_API_KEY to the container. That works for the head process (TP rank 0) — but not for rank 1, which Ray spawns inside the worker container. Ray does not propagate the driver's os.environ to remote workers across nodes; rank 1 inherits only the env that was present at the worker container's docker run time. If the four vars are missing on the worker, rank 1 picks a different FP4 GEMM kernel than rank 0 and the all-reduce backend disagrees between ranks: the first matmul or the first collective crashes the model.

We needed to forward these env vars into both containers at start time, but without baking model-specific knowledge into the generic cluster bring-up. The fix was a small extension to the bring-up scripts. They now read a space-separated VLLM_FORWARD_VARS from the parent shell and forward each named variable as -e VAR=value to docker run:

EXTRA_ENV_ARGS=()
for V in ${VLLM_FORWARD_VARS:-}; do
  [[ -n "${!V:-}" ]] && EXTRA_ENV_ARGS+=(-e "$V=${!V}")
done
# ... appended to docker run args ...

Model-specific profiles live in the model's directory. For Nemotron:

# nemotron/cluster-env.sh
export VLLM_NVFP4_GEMM_BACKEND=marlin
export VLLM_FLASHINFER_ALLREDUCE_BACKEND=trtllm
export VLLM_USE_FLASHINFER_MOE_FP4=0
export VLLM_ALLOW_LONG_MAX_MODEL_LEN=1

export VLLM_FORWARD_VARS="VLLM_NVFP4_GEMM_BACKEND VLLM_FLASHINFER_ALLREDUCE_BACKEND VLLM_USE_FLASHINFER_MOE_FP4 VLLM_ALLOW_LONG_MAX_MODEL_LEN"

The bring-up workflow then becomes:

# Node 1
source nemotron/cluster-env.sh
cd cluster/head && bash run_headnode_2.sh

# Node 2
source nemotron/cluster-env.sh
cd cluster/worker && bash run_workernode_2.sh

When the profile is not sourced, VLLM_FORWARD_VARS is unset, the loop is a no-op, and the bring-up scripts behave exactly as before. The Qwen path doesn't need a profile (its FP8 deployment requires no VLLM_* runtime overrides). We deliberately didn't create one, to remove the temptation to source a profile "just in case".
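The forwarding loop silently skips any listed name whose value is empty. A half-sourced profile on one node would therefore bring that node up without the variable. The guard below, run just before docker run, makes that case fail fast on the host. `check_forward_vars` is a hypothetical helper, not something the repo ships.

```shell
#!/usr/bin/env bash
# Hypothetical host-side guard: every name listed in VLLM_FORWARD_VARS must
# have a non-empty value in the current shell before docker run is invoked.
check_forward_vars() {
  local missing=() v
  for v in ${VLLM_FORWARD_VARS:-}; do
    # ${!v} is indirect expansion: the value of the variable named by $v.
    [[ -n "${!v:-}" ]] || missing+=("$v")
  done
  if (( ${#missing[@]} > 0 )); then
    echo "forwarded but empty: ${missing[*]}" >&2
    return 1
  fi
}
```

Chained as `source nemotron/cluster-env.sh && check_forward_vars && bash run_headnode_2.sh`, an incomplete profile never reaches docker run. With no profile sourced, the check passes trivially, matching the no-op behaviour of the loop itself.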

To make the failure mode loud rather than silent, the Nemotron launcher probes the head container's env before running vllm serve and refuses to start if the four vars are missing inside it:

MISSING_VARS=$(docker exec "${VLLM_CONTAINER}" /bin/bash -c '
  set -u
  missing=""
  for V in VLLM_NVFP4_GEMM_BACKEND VLLM_FLASHINFER_ALLREDUCE_BACKEND \
           VLLM_USE_FLASHINFER_MOE_FP4 VLLM_ALLOW_LONG_MAX_MODEL_LEN; do
    [[ -z "${!V:-}" ]] && missing="${missing} $V"
  done
  echo "${missing}"
' | xargs)

if [[ -n "${MISSING_VARS}" ]]; then
  cat >&2 <<EOF
ERROR: Required NVFP4 env vars are not set inside the Ray container:
  ${MISSING_VARS}

These must be present at container START time on BOTH nodes; they cannot be
added now via docker exec because Ray-spawned rank-1 workers on the worker
node would still be missing them. Tear the cluster down and bring it back up
after sourcing nemotron/cluster-env.sh on each node.
EOF
  exit 1
fi

The cost is one extra docker exec round-trip at launch. The benefit is that the failure prints the exact remediation instead of crashing five minutes into a model load.
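The probe only inspects the head container, but the crash it guards against comes from the worker disagreeing with the head. A complementary sketch compares a fingerprint of the four variables across both containers. It assumes you can reach both Docker daemons, for example via `ssh <worker> docker exec ...`. `nvfp4_env_fingerprint` is a hypothetical helper that reads `env` output on stdin.

```shell
#!/usr/bin/env bash
# Hypothetical helper: reduce `env` output to one hash covering only the four
# NVFP4 variables, so two containers can be compared by a single string.
NVFP4_VARS_RE='^(VLLM_NVFP4_GEMM_BACKEND|VLLM_FLASHINFER_ALLREDUCE_BACKEND|VLLM_USE_FLASHINFER_MOE_FP4|VLLM_ALLOW_LONG_MAX_MODEL_LEN)='

nvfp4_env_fingerprint() {
  # Keep only the four variables, sort for a stable order, then hash.
  grep -E "$NVFP4_VARS_RE" | LC_ALL=C sort | sha256sum | cut -d' ' -f1
}
```

Pipe `docker exec <container> env` from each node into it and compare the two strings. The head-side probe still has to run first, because two containers that are both missing the variables produce equal fingerprints.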

The launcher

The actual model-launch script doesn't run a fresh container. The Ray cluster is already up; we step into it. Skeleton (full version in nemotron/launch-nemotron-120b.sh):

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
source "${SCRIPT_DIR}/../cluster/lib.sh"   # load_env + find_ray_container helpers
load_env "${SCRIPT_DIR}"                   # source nemotron/.env (HF_TOKEN, VLLM_API_KEY)
: "${VLLM_API_KEY:?VLLM_API_KEY not set}"
: "${HF_TOKEN:?HF_TOKEN not set}"

VLLM_CONTAINER=$(find_ray_container)
# ... missing-vars probe described above ...

docker exec -it \
  -e VLLM_API_KEY="${VLLM_API_KEY}" \
  -e HF_TOKEN="${HF_TOKEN}" \
  -e MODEL_CKPT \
  -e MAX_MODEL_LEN \
  -e GPU_MEM_UTIL \
  ... \
  "${VLLM_CONTAINER}" /bin/bash -c '
    set -euo pipefail
    PARSER=/root/.cache/huggingface/super_v3_reasoning_parser.py
    if [[ ! -f "${PARSER}" ]]; then
      curl -fsSL -o "${PARSER}" \
        "https://huggingface.co/${MODEL_CKPT}/raw/main/super_v3_reasoning_parser.py"
    fi
    exec vllm serve "${MODEL_CKPT}" \
      --tensor-parallel-size 2 --pipeline-parallel-size 1 --data-parallel-size 1 \
      --quantization fp4 --moe-backend marlin \
      --dtype auto --kv-cache-dtype fp8 --mamba-ssm-cache-dtype auto \
      --max-model-len "${MAX_MODEL_LEN}" \
      --gpu-memory-utilization "${GPU_MEM_UTIL}" \
      --max-cudagraph-capture-size 128 \
      --enable-chunked-prefill --async-scheduling \
      --trust-remote-code \
      --reasoning-parser-plugin "${PARSER}" --reasoning-parser super_v3 \
      --enable-auto-tool-choice --tool-call-parser qwen3_coder
  '

The reasoning parser plugin is a Python file in the model's HF repo. We fetch it once inside the container into the HF cache (which is host-bind-mounted, so it persists across container restarts) and pass its path with --reasoning-parser-plugin. Without this, thinking tokens come back in message.content instead of being split into a reasoning_content field — fine for chat, but breaks any client that expects the structured channel.

Version-dependent flag gates

The HF model card lists flags that assume a vLLM build with every recent feature. nvcr.io/nvidia/vllm:25.11-py3 isn't one, and during the upgrade we hit (and patched out) a short series of "unrecognized argument" errors:

  • --mamba-ssm-cache-dtype float16 → 25.11 only accepts auto|float32. Set to auto and let the model self-declare.
  • --reasoning-parser super_v3 → 25.11 doesn't know the parser and has no --reasoning-parser-plugin flag to side-load it.
  • --swap-space 0 → removed in 26.05+. CPU swap is now handled implicitly; --cpu-offload-gb still exists, but it offloads model weights rather than KV blocks.
  • --async-scheduling, --moe-backend marlin, --max-cudagraph-capture-size 128 → all newer than 25.11.

Rather than freeze the launcher against a single image tag, we env-gated each of these:

ENABLE_REASONING_PARSER="${ENABLE_REASONING_PARSER:-1}"
ENABLE_ASYNC_SCHEDULING="${ENABLE_ASYNC_SCHEDULING:-1}"
MOE_BACKEND="${MOE_BACKEND-marlin}"                     # no colon: an explicit MOE_BACKEND= stays empty
CUDAGRAPH_CAPTURE_SIZE="${CUDAGRAPH_CAPTURE_SIZE-128}"  # same, so CUDAGRAPH_CAPTURE_SIZE= drops the flag
MAMBA_SSM_DTYPE="${MAMBA_SSM_DTYPE:-auto}"

Defaults match the current image (local/vllm-ray:26.05.post1). If you ever fall back to 25.11 to compare, ENABLE_REASONING_PARSER=0 ENABLE_ASYNC_SCHEDULING=0 MOE_BACKEND= CUDAGRAPH_CAPTURE_SIZE= ./launch-nemotron-120b.sh strips every newer flag without editing code.
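The gate variables still have to become `vllm serve` arguments. One way is to build an array and append each flag only when its gate is set. This is a sketch, not the launcher's code. `build_gated_args` is a hypothetical name, and an empty value is treated as "omit the flag".

```shell
#!/usr/bin/env bash
# Sketch: translate the gate variables into vllm serve arguments.
# ${VAR-default} (no colon) keeps an explicitly empty value, so MOE_BACKEND=
# really drops the flag; ${VAR:-default} would quietly restore the default.
build_gated_args() {
  local reasoning="${ENABLE_REASONING_PARSER-1}"
  local async="${ENABLE_ASYNC_SCHEDULING-1}"
  local moe="${MOE_BACKEND-marlin}"
  local graphs="${CUDAGRAPH_CAPTURE_SIZE-128}"
  local mamba="${MAMBA_SSM_DTYPE:-auto}"
  local parser="${PARSER:-/root/.cache/huggingface/super_v3_reasoning_parser.py}"

  GATED_ARGS=()   # global on purpose, so the caller can expand it
  if [[ "$reasoning" == 1 ]]; then
    GATED_ARGS+=(--reasoning-parser-plugin "$parser" --reasoning-parser super_v3)
  fi
  if [[ "$async" == 1 ]]; then
    GATED_ARGS+=(--async-scheduling)
  fi
  if [[ -n "$moe" ]]; then
    GATED_ARGS+=(--moe-backend "$moe")
  fi
  if [[ -n "$graphs" ]]; then
    GATED_ARGS+=(--max-cudagraph-capture-size "$graphs")
  fi
  # Accepted by both images (25.11 only takes auto or float32).
  GATED_ARGS+=(--mamba-ssm-cache-dtype "$mamba")
}
```

In the launcher this would live in the inner `bash -c` script, with the gate variables passed through `docker exec -e`. The result would be expanded as `"${GATED_ARGS[@]}"` after the fixed `vllm serve` arguments.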

Memory: why we're cautious, how we got to 1M

A previous attempt to run gpt-oss-120b on a single Spark starved the host of memory so badly that ICMP kept replying but sshd became unreachable. The node had to be power-cycled remotely. With headless workstations and no IPMI watchdog wired up at the time, that was an expensive incident, and it's why our Nemotron defaults are deliberately under what NVIDIA's model card recommends:

  • --gpu-memory-utilization 0.75 (NVIDIA suggests 0.9)
  • --max-model-len started at 256k, not 1M
  • An opt-in ENABLE_EAGER=1 that disables CUDA graphs in case capture spikes memory on first inference

Spreading the model across two Sparks via Ray TP=2 halves the per-node weight footprint (NVFP4 weights are ~60–80 GB total → ~30–40 GB per node), which by itself materially de-risks the host-starvation pattern. We pushed --max-model-len in steps:

  • 256k tokens: model loaded; single-stream and concurrent inference both clean. Host MemAvailable comfortably above the warning threshold.
  • 512k tokens: also stable. The memory snapshot below was taken during inference at this setting.
  • 1 M tokens: the last step. Going from 512k to 1 M doesn't double total memory: only attention-layer KV scales linearly, and roughly a quarter of Nemotron-3-Super's layers are full attention. The projected delta from 512k → 1 M is ~9–10 GB per node, landing near but still above the 6 GB warning floor.

Memory snapshot at 512k, during inference:

time      used     free   buff/cache  avail   total
17:39:42  103.5G   0.9G   19.8G       18.2G   121.7G
17:39:44  103.5G   0.9G   19.8G       18.2G   121.7G
17:39:46  103.6G   0.9G   19.8G       18.1G   121.7G

About 18 GB of headroom on a 121 GB node. Tight but stable.
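The warning floor is easiest to enforce from `/proc/meminfo`, which reports `MemAvailable` in kB. Below is a minimal sketch of a host-side check that treats the 6 GB floor as 6 GiB. `mem_available_kb` and `below_floor` are hypothetical helpers. The meminfo path is a parameter only so the check can be run against a sample file.

```shell
#!/usr/bin/env bash
# Hypothetical host-side memory check against a warning floor.

# mem_available_kb [MEMINFO_FILE]  prints MemAvailable in kB; fails if absent.
mem_available_kb() {
  awk '$1 == "MemAvailable:" { print $2; found = 1 } END { exit !found }' \
    "${1:-/proc/meminfo}"
}

# below_floor FLOOR_GIB [MEMINFO_FILE]  succeeds when MemAvailable < floor.
# Returns 2 when MemAvailable cannot be read.
below_floor() {
  local floor_gib="$1" kb
  kb=$(mem_available_kb "${2:-/proc/meminfo}") || return 2
  (( kb < floor_gib * 1048576 ))
}
```

A loop such as `while sleep 5; do below_floor 6 && logger -p user.warning "MemAvailable below 6 GiB"; done` logs the event while sshd is still reachable.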

A handful of complementary safeguards live outside the launcher: OOMScoreAdjust=-1000 on sshd via systemctl edit ssh, earlyoom installed and enabled, and an external laptop-side watchdog that hits /health every 30 seconds and power-cycles the offender via the PDU on repeated failures. We can't add a cgroup --memory cap to the running Ray container after the fact (that's a docker run flag), so host-side discipline is the safety net.
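The laptop-side watchdog lives outside the repo, so what follows is only a sketch of its shape. It assumes vLLM's `/health` endpoint on port 8000 and a three-consecutive-failures rule. `HEAD_URL`, `FAIL_LIMIT`, `INTERVAL`, `COOLDOWN` and `power_cycle_node` are placeholders. The PDU call in particular depends entirely on your hardware.

```shell
#!/usr/bin/env bash
# Hypothetical laptop-side watchdog. All settings below are placeholders;
# wire power_cycle_node to your PDU's own tooling.
HEAD_URL="${HEAD_URL:-http://<head-ip>:8000/health}"
FAIL_LIMIT="${FAIL_LIMIT:-3}"
INTERVAL="${INTERVAL:-30}"
COOLDOWN="${COOLDOWN:-600}"

# Succeeds when /health answers with a 2xx status within 5 seconds.
probe_health() {
  curl -fsS --max-time 5 -o /dev/null "$1"
}

# next_fail_count CURRENT OK  prints the new consecutive-failure count.
next_fail_count() {
  if [[ "$2" == 1 ]]; then echo 0; else echo $(( $1 + 1 )); fi
}

power_cycle_node() {
  # Placeholder only: replace with the real PDU command for the failing node.
  echo "would power-cycle the node behind ${HEAD_URL}" >&2
}

# Runs forever; start it explicitly (for example in a tmux pane).
watchdog_loop() {
  local fails=0 ok
  while true; do
    if probe_health "$HEAD_URL"; then ok=1; else ok=0; fi
    fails=$(next_fail_count "$fails" "$ok")
    if (( fails >= FAIL_LIMIT )); then
      power_cycle_node
      fails=0
      sleep "$COOLDOWN"   # give the node time to boot before probing again
    fi
    sleep "$INTERVAL"
  done
}
```

The cooldown matters: without it, a node that is still booting would fail three more probes and be power-cycled again.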

Measuring throughput

We wanted a number to compare runs across context lengths and concurrency levels. vLLM bundles a benchmarking CLI; the simplest invocation runs it inside the head container against localhost:

# The bench client sends OPENAI_API_KEY as its bearer token (as we understand it); reuse the server's key.
docker exec -it -e OPENAI_API_KEY="${VLLM_API_KEY}" \
  "$(docker ps --format '{{.Names}}' | grep '^node-' | head -n1)" \
  vllm bench serve \
    --backend openai-chat \
    --base-url http://localhost:8000 \
    --endpoint /v1/chat/completions \
    --model nvidia/nemotron-3-super \
    --tokenizer nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 \
    --dataset-name random \
    --random-input-len 1024 --random-output-len 256 \
    --num-prompts 100 --max-concurrency 8

It reports request throughput, output-token throughput, time-to-first-token, and inter-token latency with percentiles. We swept --max-concurrency from 1 to 32 to find the knee. For a single number worth quoting, concurrency 1 is the most informative for chat latency, and concurrency 8–16 is what you'd publish as aggregate throughput. For the sweeps we run the same command from a third machine (a laptop on the same network), with --base-url pointed at the head node instead of localhost. That way the benchmark's HTTP loop doesn't compete with Ray for CPU on the Spark.
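The concurrency sweep can be scripted as a loop around the same command. This sketch assumes it runs on the benchmarking machine with vLLM installed and that `HEAD_URL` points at the head node. It also assumes the server's API key is exported as `OPENAI_API_KEY`, which is how we understand `vllm bench serve` to authenticate against the OpenAI-style endpoints. `concurrency_levels` and `run_sweep` are hypothetical names.

```shell
#!/usr/bin/env bash
# Hypothetical sweep driver around `vllm bench serve`.

# concurrency_levels MIN MAX  prints powers of two from MIN up to MAX.
concurrency_levels() {
  local c="$1" max="$2" out=()
  while (( c <= max )); do
    out+=("$c")
    (( c *= 2 ))
  done
  echo "${out[*]}"
}

# run_sweep [OUTDIR]  runs one benchmark per level and keeps each report.
run_sweep() {
  local base_url="${HEAD_URL:?set HEAD_URL, e.g. http://<head-ip>:8000}"
  local outdir="${1:-bench-$(date +%Y%m%d-%H%M%S)}" c
  mkdir -p "$outdir"
  for c in $(concurrency_levels 1 32); do
    vllm bench serve \
      --backend openai-chat \
      --base-url "$base_url" \
      --endpoint /v1/chat/completions \
      --model nvidia/nemotron-3-super \
      --tokenizer nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 \
      --dataset-name random \
      --random-input-len 1024 --random-output-len 256 \
      --num-prompts 100 --max-concurrency "$c" \
      | tee "${outdir}/concurrency-${c}.log"
  done
}
```

Because `--num-prompts` stays at 100, the high-concurrency runs are short. Raising it at those levels should give steadier percentiles.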

The /metrics endpoint also exposes Prometheus-style counters — useful for a Grafana board if you want graphs without writing a benchmark loop:

curl -sf http://<head-ip>:8000/metrics \
  | grep -E '^vllm:(generation_tokens_total|prompt_tokens_total|time_per_output_token)'
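Two scrapes of the generation counter a known interval apart give an output-token rate without running a benchmark. This sketch assumes counter lines of the form `vllm:generation_tokens_total{...} <value>`. The label set varies by version, so the parser sums every series with that exact name. `counter_sum` and `tokens_per_second` are hypothetical helpers.

```shell
#!/usr/bin/env bash
# Hypothetical helpers for turning two /metrics scrapes into a token rate.

# counter_sum METRIC  reads Prometheus text on stdin and prints the sum of
# every series named exactly METRIC (with or without labels).
counter_sum() {
  awk -v m="$1" '
    index($0, m) == 1 {
      rest = substr($0, length(m) + 1)
      # Accept "METRIC{labels} value" or "METRIC value", not longer names.
      if (rest ~ /^[{ ]/) { total += $NF }
    }
    END { printf "%.0f\n", total }'
}

# tokens_per_second URL INTERVAL_S  scrapes twice and prints generated tokens/s.
tokens_per_second() {
  local url="$1" interval="$2" page before after
  page=$(curl -sf "$url") || return 1
  before=$(counter_sum vllm:generation_tokens_total <<< "$page")
  sleep "$interval"
  page=$(curl -sf "$url") || return 1
  after=$(counter_sum vllm:generation_tokens_total <<< "$page")
  echo $(( (after - before) / interval ))
}
```

For example, `tokens_per_second http://<head-ip>:8000/metrics 10` averages over ten seconds of whatever traffic the server is handling.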

What surprised us

A few things were less obvious going in:

  1. Cross-node env propagation in Ray. This was the highest-leverage gotcha. Documentation talks about Ray's runtime_env for per-actor env propagation, but for vLLM's use case it's much simpler to just bake the env into the container at docker run time and ensure both nodes match. The VLLM_FORWARD_VARS mechanism is small and general enough to handle future models too.
  2. NGC images vary in what they ship. Going from :25.11-py3 to :26.05.post1-py3 lost Ray but gained vLLM's plugin system. Going forward we expect more component movement like this, so the cluster/Dockerfile is a small ongoing maintenance cost we'll keep.
  3. Hybrid models change the KV/context math. The intuition "1 M context costs 4× more memory than 256k" is wrong for Nemotron-3-Super. Only attention layers scale; Mamba-2 layers carry constant-size SSM state. The empirical 512k → 1 M step cost ~10 GB per node, not 30.
  4. Stale Ray GCS state across container recreates. If a previous head container is somehow still bound to port 6379 on the host (network host mode + interrupted cleanup), the new head's ray start --head will refuse to overwrite the existing session, and the bring-up dies with a session-name assertion. The fix is docker rm -f $(docker ps -aq --filter 'name=^node-') on each node before retrying. We didn't add automatic cleanup to the bring-up scripts because the failure points at a real human-intent question ("did the previous run actually finish?") rather than something to paper over.
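For that manual cleanup, a slightly more careful version of the one-liner avoids calling `docker rm -f` with an empty argument list when nothing matches. This is a sketch: `stale_ray_containers` and `cleanup_stale_ray` are hypothetical helpers meant to be run by hand, in keeping with the decision not to automate this.

```shell
#!/usr/bin/env bash
# Hypothetical manual cleanup for leftover Ray containers named node-*.

# Prints IDs of all containers (running or stopped) whose name starts with node-.
stale_ray_containers() {
  docker ps -aq --filter 'name=^node-'
}

# Removes those containers, or says so when there is nothing to remove.
cleanup_stale_ray() {
  local ids
  ids=$(stale_ray_containers)
  if [[ -z "$ids" ]]; then
    echo "no stale node-* containers"
    return 0
  fi
  echo "removing: ${ids//$'\n'/ }"
  # Word splitting is intended here: one argument per container ID.
  # shellcheck disable=SC2086
  docker rm -f $ids
}
```

Before running it, `ss -ltn 'sport = :6379'` on the host shows whether something still holds the GCS port.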

Repo layout

cluster/
  Dockerfile              # FROM nvcr.io/nvidia/vllm:26.05.post1-py3 + ray[default]
  build-image.sh          # docker build wrapper
  lib.sh                  # load_env + find_ray_container helpers
  head/
    run_headnode_2.sh     # 4-port data plane bring-up
    run_cluster.sh        # generic docker run wrapper for Ray
    ray_inference_health.sh
  worker/
    run_workernode_2.sh
    run_cluster.sh        # byte-identical to head's
qwen/
  launch-qwen-30b.sh
  launch-qwen-122b.sh
  .env.example
nemotron/
  cluster-env.sh          # NVFP4 runtime env profile
  launch-nemotron-120b.sh
  .env.example

Bring-up:

# One-time, per node
bash cluster/build-image.sh

# Each session
source nemotron/cluster-env.sh
cd cluster/head && bash run_headnode_2.sh   # Node 1
# Node 2
source nemotron/cluster-env.sh
cd cluster/worker && bash run_workernode_2.sh

# Launch
cd nemotron && ./launch-nemotron-120b.sh

What's next

Several follow-ups we'll likely tackle:

  • Build pipeline for the local image. Right now cluster/build-image.sh runs on each node manually. A small Makefile target with docker save | ssh node2 docker load would remove that drift.
  • Quantitative tokens-per-second numbers. We have the harness; we haven't yet committed to a benchmark protocol we're happy to publish. Probably a bench.sh that runs the same sweep at every image bump and writes a CSV.
  • Speculative decoding (MTP). Nemotron-3-Super publishes an MTP draft head. We have it env-gated (ENABLE_MTP=1) but haven't measured whether the cross-node speculative round-trip is a net win on this topology. The ~800 Gb/s link is generous, but every extra round-trip still adds latency. This likely needs evaluation per context length.
  • Out-of-repo hardening. sshd OOM score, earlyoom, and the external watchdog are documented but live outside the repo. We'd like to either pull them into a setup script or — more honestly — write them up as a node-bootstrap document so a fresh Spark can be brought into the cluster with a single checklist.
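The `docker save | ssh node2 docker load` idea from the first item fits in a few lines. This sketch assumes passwordless SSH to the worker and the same tag on both nodes. `image_id`, `needs_sync` and `sync_image` are hypothetical helpers. Comparing image IDs first avoids streaming a 9 GB image when nothing changed.

```shell
#!/usr/bin/env bash
# Hypothetical helpers for pushing the locally built image to the worker.

# image_id TAG [HOST]  prints the image ID locally, or over SSH when HOST is
# given; prints nothing and fails if the image is absent.
image_id() {
  local tag="$1" host="${2:-}"
  if [[ -n "$host" ]]; then
    ssh "$host" "docker image inspect --format '{{.Id}}' '$tag'" 2>/dev/null
  else
    docker image inspect --format '{{.Id}}' "$tag" 2>/dev/null
  fi
}

# needs_sync LOCAL_ID REMOTE_ID  succeeds when the remote copy is missing or differs.
needs_sync() {
  [[ -z "$2" || "$1" != "$2" ]]
}

# sync_image TAG HOST  streams the image to HOST only when needed.
sync_image() {
  local tag="${1:-local/vllm-ray:26.05.post1}" host="${2:?usage: sync_image TAG HOST}"
  local lid rid
  lid=$(image_id "$tag") || { echo "no local image $tag" >&2; return 1; }
  rid=$(image_id "$tag" "$host" || true)
  if needs_sync "$lid" "$rid"; then
    docker save "$tag" | ssh "$host" docker load
  else
    echo "$host already has $tag ($lid)"
  fi
}
```

Images built separately on each node will usually get different IDs, so the first sync always transfers. After that, a loaded image keeps the sender's ID and repeat runs become no-ops.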

The full repository, including the launchers, the Dockerfile, and the bring-up scripts, is what this post is built on. If you're running the same hardware and chasing the same model, copying the scripts directly is probably the fastest path.