Ray Compiled Graphs: Optimized AI Workloads with Native GPU Communication

By Sang Cho, Sam Chan and Stephanie Wang   

Introduction

As AI models continue to rise in complexity and size, workloads and applications around these models have driven new requirements for the underlying software infrastructure and primitives. Unlike traditional CPU-based workloads, large AI model workloads like training and inference are overwhelmingly GPU-intensive and often require distributed computation and coordination across tens or hundreds of accelerators. 

Compiled Graphs provide minimal task submission overhead (~50 us) compared to Ray’s standard task submission overhead (1-2 ms). While these default overheads are negligible for long-running, throughput-centric workloads such as data processing or batch processing, they become untenable for sub-second workloads such as auto-regressive token generation.

Compiled Graphs also support native GPU-to-GPU transfer while automatically avoiding deadlocks and overlapping communication with computation. Without Compiled Graphs, users must set up out-of-band communication primitives such as NCCL themselves in order to achieve low-latency GPU transfer.


These improvements open up exciting new opportunities for Ray programs:

  • Ray Compiled Graphs can transfer large tensors across actors with millisecond-level latencies

  • With Compiled Graphs, you can express complex pipeline scheduling algorithms in fewer than 70 lines of code while reaching performance parity with state-of-the-art libraries like DeepSpeed for a Llama 7B model on 4xA100 GPUs

  • A Compiled Graphs-based multimodal training workload with heterogeneous parallelism and GPUs achieves a 43% improvement in token-per-dollar efficiency compared to PyTorch's standard FSDP implementation.

Introducing Compiled Graphs in Ray


Compiled Graphs is a new feature in Ray that offers an API similar to classic Ray Core, with three major advantages:

  1. Reduced system overhead for repetitive task graphs 

  2. Native support for GPU-GPU communication via NVIDIA NCCL

  3. Optimized scheduling to avoid deadlock and best utilize compute and communication resources

Unlike the standard Ray API, which is fully dynamic, Compiled Graphs express the program as a static computation graph.

By taking advantage of the static nature of the program, Ray is able to allocate resources for a task ahead of time and reuse them for future invocations. Further, Ray can allocate memory ahead of time to leverage optimized communication primitives such as NCCL, which require the sender and receiver to issue symmetric operations. As shown below, Ray Compiled Graphs can improve latencies for simple GPU tensor transfers by up to 140x and CPU data transfers by 17x.

Let’s dive into a simplified version of a common pattern in machine learning workloads: scatter-gather.

First, let’s install Compiled Graphs:

pip install -U "ray[adag]"

Scatter-gather is a common distributed computation pattern in machine learning workloads. It sends the same input to multiple workers (Ray actors) and gathers the results from all workers. For example, in tensor-parallel inference, the program sends the same CPU data to all actors; each actor then moves the data to its GPU and runs a torch model that loads only its shard of the weights. Let’s express this program with Compiled Graphs.

We start by creating normal Ray actors, in this case 3 of them:

import ray
from time import perf_counter

@ray.remote
class TensorProcessor:
  def fwd(self, tensor: str) -> str:
    # some_gpu_ops(tensor.to(torch.device()))
    # In a typical application, this actor will move the
    # tensor to GPU and run a torch module.
    # We instead do nothing and return the string input immediately for simplicity.
    return tensor

N = 3
actors = [TensorProcessor.remote() for _ in range(N)]
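For reference, here is a sketch of what fwd might look like in a real tensor-parallel inference actor. This is an illustration only: ShardedTensorProcessor and its model argument are hypothetical, and we assume each actor holds one shard of the weights on its own GPU.

import ray
import torch

@ray.remote(num_gpus=1)
class ShardedTensorProcessor:  # hypothetical, for illustration only
  def __init__(self, model: torch.nn.Module):
    # Each actor loads only its own shard of the weights onto its GPU.
    self.model = model.to("cuda")

  def fwd(self, tensor: torch.Tensor) -> torch.Tensor:
    # Move the CPU input to this actor's GPU and run the sharded model.
    tensor = tensor.to("cuda", non_blocking=True)
    with torch.no_grad():
      return self.model(tensor)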

First, let’s see how to express this with regular Ray programs.

# warmup actors
for _ in range(10):
    ray.get([actor.fwd.remote("hello") for actor in actors])

s = perf_counter()
result = ray.get([actor.fwd.remote("hello") for actor in actors])
print("Normal ray programs took", (perf_counter() - s) * 1000, "ms")
print(result)

Then, we define a graph that passes the same input placeholder to all actors; we will compile it in the next step. Here, we use Ray’s DAG API, which builds an intermediate representation that captures the computation graph statically. We use the MultiOutputNode syntax to wrap the outputs, which is necessary when there is more than one output node.

import ray.dag

# Define a DAG for lazy execution.
with ray.dag.InputNode() as inp:
    # Bind each actor task to the same input placeholder.
    outputs = [actor.fwd.bind(inp) for actor in actors]
    dag = ray.dag.MultiOutputNode(outputs)

This produces a Ray DAG in which the input node fans out to each actor’s fwd task, and the results are gathered into a single multi-output node.

Now, to use Compiled Graphs, we call the experimental_compile method on the DAG. Ray will pre-allocate all resources needed to run the graph, leading to much faster execution compared to the standard dynamic runtime:

compiled_graph = dag.experimental_compile()

# warmup actors
for _ in range(10):
    ray.get(compiled_graph.execute("hello"))

# Execute the compiled graph (you can pass a different argument on each call).
s = perf_counter()
result = ray.get(compiled_graph.execute("hello"))
print("Compiled Graphs took", (perf_counter() - s) * 1000, "ms")
print(result)
# ["hello", "hello", "hello"]

That’s it! You can use the same program to scale your workloads. The API also works across multiple nodes.

Underneath the hood, resources are pre-allocated to reduce overhead during execution:

  • The new Compiled Graphs backend statically allocates input and output buffers for each actor task upon compilation instead of dynamically allocating them each time the DAG is executed. These buffers are reused at execution time, and actors always push results directly to the process that needs them. 

  • All actors that are on the same Ray node will share the same physical input buffer, which is synchronized by the Ray Compiled Graphs backend. This helps reduce the per-task overhead from serializing the task arguments, allocating memory for the arguments, and invoking the task.

  • The backend also sets up each actor’s execution loop ahead of time. Instead of waiting for an RPC to execute its next task, each actor waits in a loop for the arguments (passed via the allocated buffers) of its next task.

Now, what if we want to pipeline the execution across different actor tasks? One example of this is pipeline-parallel inference, where intermediate outputs are passed from one actor to the next through shared memory, and the data transfers should be pipelined with the compute tasks. We can pipeline execution across different actors by executing the same DAG multiple times before retrieving the outputs. First, we define a chain of actor tasks:

# Teardown the previous dag.
compiled_graph.teardown()
with ray.dag.InputNode() as inp:
  for actor in actors:
    # Pass each actor task output as input to the next actor task.
    inp = actor.fwd.bind(inp)
  dag = inp

This produces a directed acyclic graph (DAG) in which the actor tasks form a chain: each actor’s output is passed as the input to the next actor.

Which we can compile and execute like this:

compiled_graph = dag.experimental_compile()
# Call compiled_graph.execute() several times. The executions will be pipelined across the different actors.
refs = [compiled_graph.execute(f"hello{i}") for i in range(N)]
# Get the results, flushing the pipeline.
for ref in refs:
  print(ray.get(ref))
# "hello0"
# "hello1"
# "hello2"

To demonstrate GPU-GPU communication, we can construct an actor that sends a tensor to another actor. To run this example, make sure you have at least 2 GPUs in your cluster.

import ray
import ray.dag

import torch

assert ray.cluster_resources().get("GPU", 0) >= 2, "Insufficient number of GPUs available in the cluster."

@ray.remote(num_gpus=1)
class GPUSender:
  def send(self, shape):
    return torch.zeros(shape, device="cuda")

@ray.remote(num_gpus=1)
class GPUReceiver:
  def recv(self, tensor: torch.Tensor):
    assert tensor.device.type == "cuda"
    return tensor.shape

sender = GPUSender.remote()
receiver = GPUReceiver.remote()

Next, we define and compile a Compiled Graph that passes a CUDA tensor from one actor to the other. With the TorchTensorType hint below, Ray will use NCCL under the hood to transport the tensors directly between the 2 GPUs via GPU RDMA.

from ray.experimental.channel.torch_tensor_type import TorchTensorType

with ray.dag.InputNode() as inp:
  dag = sender.send.bind(inp)
  # You have to specify transport="nccl", otherwise it uses shared memory to transfer GPU tensors.
  dag = dag.with_type_hint(TorchTensorType(transport="nccl"))
  dag = receiver.recv.bind(dag)

compiled_graph = dag.experimental_compile()
# Execute the DAG. Ray aDAG will orchestrate any NCCL ops.
assert ray.get(compiled_graph.execute((10, ))) == (10, )

You can check out the developer guide for more information.

Benchmarks

We compared Ray Core (labeled “Ray Standard Runtime”) with Ray Compiled Graphs on a variety of communication patterns. 

  • Round Trip: In this pattern, an actor repeatedly sends data to a receiver actor.

  • Scatter-Gather: In this pattern, the driver sends data to all actors and gathers the result from each actor.

  • Chain of tasks: In this pattern, we construct a pipeline of actors, where each actor receives data from the previous actor and sends it to the next.

First, we ran the benchmarks with 1 byte of CPU data, on both a single node and multiple nodes, for the above communication patterns to measure the system overheads. The following benchmark ran on m5.16xlarge instances, which have 64 CPUs.
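To make the patterns concrete, here is a minimal sketch of the round-trip pattern expressed with Compiled Graphs. This is an illustration only, not the exact benchmark harness: the Forwarder actor and the iteration counts are made up for this example.

import ray
import ray.dag
from time import perf_counter

@ray.remote
class Forwarder:  # hypothetical actor, for illustration only
    def echo(self, data: bytes) -> bytes:
        return data

sender = Forwarder.remote()
receiver = Forwarder.remote()

# The driver pushes 1 byte through the sender and the receiver and waits for the result.
with ray.dag.InputNode() as inp:
    dag = receiver.echo.bind(sender.echo.bind(inp))

compiled = dag.experimental_compile()

# Warm up, then time many round trips.
for _ in range(10):
    ray.get(compiled.execute(b"x"))
s = perf_counter()
for _ in range(1000):
    ray.get(compiled.execute(b"x"))
print("mean round-trip latency:", (perf_counter() - s), "ms")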

In the graph below, we observe that round-trip communication on a single node can be up to 17x faster, and across multiple nodes up to 2.7x faster.

For more complex workloads, such as scatter-gather or chain of tasks, we see that Ray Compiled Graphs can improve latency by up to 20x.

[Figure: Latency of the Ray standard runtime vs. Ray Compiled Graphs for 1-byte CPU data across the three communication patterns]

We also compared Ray Core with Ray Compiled Graphs for GPU-to-GPU transfer. We ran a simple round-trip benchmark with a 40MB CUDA tensor on a machine with NVLink (A100) and one without NVLink (A10G), where we transferred the tensor 10 times and measured the end-to-end latency.

[Figure: Round-trip latency of a 40MB CUDA tensor, with and without NVLink]

By default, Ray Core has no native support for zero-copy serialization of torch tensors, nor does it use NCCL for communication.

Ray Compiled Graphs, on the other hand, use NCCL under the hood for GPU-to-GPU transfer to optimize performance. On a machine with 2 A10G devices without NVLink, Ray Compiled Graphs can transfer tensors with 19x lower latency than Ray Core. On A100s with NVLink, latency can be reduced by nearly 140x.

These benchmarks show that compiled graphs are powerful primitives that can open up a new class of emerging AI applications.

Use Cases

Faster model-parallel inference

We’ve integrated Compiled Graphs into vLLM, a popular open source project for LLM inference, to enable tensor-parallel and pipeline-parallel inference using Ray. With pipeline parallelism, we see a 10-15% improvement in throughput and latency compared to the default NCCL backend.

State-of-the-art distributed training

In our distributed training experiments, we leveraged Compiled Graphs to implement more complex and efficient training strategies. For example, we significantly increased the training throughput of multimodal contrastive models, like CLIP, by using heterogeneous devices and parallelism techniques. Expressing such a program is much more difficult with vanilla PyTorch.

Specifically, by applying different parallelism methods to the text and vision encoders and placing the smaller text encoder on more cost-effective GPUs, we observed a 20% increase in training throughput and a 43% improvement in token-per-dollar efficiency compared to PyTorch's standard FSDP implementation.

We also prototyped various pipeline-parallel scheduling algorithms such as AFAB, 1F1B, and zero bubble (https://arxiv.org/pdf/2401.10241). We could express these complex pipeline scheduling algorithms in fewer than 70 lines of code while reaching performance parity with state-of-the-art libraries like DeepSpeed for a Llama 7B model on 4xA100 GPUs. We plan to further improve pipeline-parallel training performance on larger-scale workloads.
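As a rough illustration of how such a schedule can be expressed with the DAG API, here is a minimal sketch of a forward-then-backward pipeline over stage actors. This is not our actual training code: Stage and its fwd/bwd methods are hypothetical placeholders, and a real 1F1B or zero-bubble schedule would interleave the tasks differently.

import ray
import ray.dag

@ray.remote
class Stage:  # hypothetical pipeline stage, for illustration only
    def fwd(self, x):
        # Run this stage's forward pass and cache activations (omitted here).
        return x

    def bwd(self, grad):
        # Run this stage's backward pass using the cached activations.
        return grad

stages = [Stage.remote() for _ in range(4)]

# Each microbatch flows forward through all stages,
# then backward through the stages in reverse order.
with ray.dag.InputNode() as microbatch:
    x = microbatch
    for stage in stages:
        x = stage.fwd.bind(x)
    for stage in reversed(stages):
        x = stage.bwd.bind(x)
    dag = x

compiled = dag.experimental_compile()
# Submitting several microbatches keeps all stages busy; the executions
# are pipelined across the stage actors.
refs = [compiled.execute(f"microbatch{i}") for i in range(4)]
results = [ray.get(ref) for ref in refs]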

We plan to have follow-up blogs specifically for these use cases. Stay tuned!

Conclusion

Try out Ray Compiled Graphs today on the latest version of Ray via pip install "ray[adag]". The feature is in its alpha stage and under active development. In the future, we plan to share more about how to use Ray Compiled Graphs to implement distributed inference and training. We are excited to hear about more use cases and welcome contributions!

To do so:

  • Join ray.slack.com and the #ray-accelerated-dag channel

  • Check out what we’re working on in GitHub.
