Building a RAG Batch Inference Pipeline with Anyscale and Union

By Kevin Su and Kai-Hsun Chen   

Ray is an open-source unified compute framework that makes it easy to scale AI and Python workloads, from data processing, training, and tuning to model serving. This blog showcases Ray's versatility by demonstrating embedding generation and LLM batch inference in two Flyte pipelines.

Flyte is an open-source orchestrator that facilitates building production-grade data and machine learning pipelines. AI workflows often require custom infrastructure that evolves rapidly across use cases. A pipeline orchestrator that brings together different personas, such as data engineers, platform engineers, infrastructure engineers, and ML engineers, is essential for boosting productivity.

An ML pipeline consists of multiple workloads, such as data processing, training, tuning, and serving. As a unified compute engine, Ray is the go-to solution for ML engineers and platform engineers to manage the end-to-end lifecycle of these workloads. However, ML pipelines also present challenges outside Ray's scope, such as data engineers accessing data from data warehouses and infrastructure engineers managing task lifecycles; Flyte is a great fit for these gaps. Hence, pairing a unified distributed compute framework like Ray with a workflow orchestrator like Flyte is essential.

In this blog, we’ll review why Anyscale is the best place to run Ray workloads, and Union is the best place to orchestrate Flyte pipelines. We’ll then dive into a RAG example to show the perfect marriage between Anyscale and Union.

Anyscale, built by the creators of Ray, provides a seamless user experience for developers and AI teams to deploy AI/ML workloads at scale. Companies using Anyscale benefit from rapid time-to-market and faster iterations across the entire AI lifecycle.

The Anyscale Platform offers a streamlined interface for developers to leverage state-of-the-art open source large language models (LLMs) to power AI applications. Deploying in a private cloud environment allows teams to meet their specific privacy, control, and customization requirements.

Union, built by the technical founding team behind Flyte, abstracts away the infrastructure, providing a turnkey system that lets ML engineers and data scientists focus on what they do best without the need for a dedicated team to manage the platform.

Pipeline Architecture

[Architecture diagram: the embedding generation and batch inference pipelines]

Suppose you want to build a chatbot that responds to GitHub issues in open-source repositories. You can achieve this by creating two Flyte pipelines. The architecture diagram above illustrates these two pipelines:

Embedding Generation Pipeline

This pipeline generates embeddings for Flyte documentation and Slack data using Anyscale Jobs. Here’s how you can set it up (a minimal workflow sketch follows these steps):

  • Step 1: Load the Flyte GitHub codebase, documents, and messages from the Flyte Slack workspace, and split them into sentence-sized chunks.

  • Step 2: Save these chunks to cloud storage, such as AWS S3, which Union and Anyscale share.

  • Step 3: Launch an Anyscale Job to create embeddings with Ray Data. Save the generated embeddings back to cloud storage.
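
These steps can be wired together as a Flyte workflow. The sketch below is illustrative rather than the actual pipeline: it assumes a single loader task and reuses the load_flyte_document and embedding_generation tasks shown later in this post, whereas the real pipeline would use separate loader tasks for the flytekit codebase, the Flyte codebase, and the Slack export.

from flytekit import workflow
from flytekit.types.directory import FlyteDirectory

@workflow
def embedding_generation_pipeline(repo: FlyteDirectory, chunk_size: int = 512) -> FlyteDirectory:
    # Steps 1 & 2: load the documentation and split it into chunks.
    chunks = load_flyte_document(repo=repo, chunk_size=chunk_size)
    # Step 3: launch an Anyscale Job that builds the embeddings with Ray Data.
    # For brevity, the same chunk list is passed to all four inputs here.
    return embedding_generation(
        flytekit_code=chunks, flyte_code=chunks, flyte_document=chunks, slack=chunks
    )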

Batch Inference Pipeline

This pipeline monitors GitHub issues in Flyte repositories and uses the Anyscale Platform to serve an LLM with RAG, performing batch inference to reply to those issues. Here’s how it works:

  • Step 1: Load and preprocess GitHub issues from Flyte repositories every few hours.

  • Step 2: Launch Anyscale as the backend for LLM inference.

  • Step 3: Start an Anyscale Job to run batch inference using Ray Data with RAG. This involves:

    • Launching a vector database (e.g., FAISS) and loading embeddings from cloud storage into it.

    • Retrieving context from the vector database.

    • Sending the request with the original prompt and retrieved context to the Llama model.

  • Step 4: Reply to the GitHub issues using the results (a sketch of this step follows the list).
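
The reply step itself isn't shown in the code later in this post; a minimal sketch, assuming PyGithub and a hypothetical reply_to_issue helper, might look like this:

from github import Github  # PyGithub

def reply_to_issue(repo_name: str, issue_number: int, answer: str, token: str) -> None:
    # Hypothetical helper: post the generated answer as a comment on the issue.
    client = Github(token)
    issue = client.get_repo(repo_name).get_issue(number=issue_number)
    issue.create_comment(answer)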

Consider scheduling the embedding generation pipeline to run weekly and the batch inference pipeline to execute daily.
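
In Flyte, schedules are attached via launch plans. A minimal sketch, assuming the two workflows are named embedding_generation_pipeline and batch_inference_pipeline and that their inputs have defaults:

from flytekit import CronSchedule, LaunchPlan

weekly_embeddings = LaunchPlan.get_or_create(
    workflow=embedding_generation_pipeline,
    name="weekly_embedding_generation",
    schedule=CronSchedule(schedule="0 0 * * 0"),  # every Sunday at midnight UTC
)
daily_inference = LaunchPlan.get_or_create(
    workflow=batch_inference_pipeline,
    name="daily_batch_inference",
    schedule=CronSchedule(schedule="0 0 * * *"),  # every day at midnight UTC
)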

Embedding Generation Pipeline

Steps 1 & 2: Chunk Data

To reply to GitHub issues in the Flyte repository, the prompt should have enough context about the Flyte codebase and documentation. However, the length of each document and code file varies widely, and many are too large to index as a single chunk.

The following code snippet demonstrates how to use LangChain to load the Flyte documentation and break each document into smaller chunks to facilitate indexing and querying. Flyte then uploads these chunks to the cloud storage shared by Union and Anyscale.

@task(container_image=image_spec, cache_version="1", cache=True)
def load_flyte_document(repo: FlyteDirectory, chunk_size: int) -> List[Document]:
    # Load every .rst file from the Flyte documentation repo.
    docs = GitLoader(
        repo_path=repo,
        branch="master",
        file_filter=lambda file_path: file_path.endswith(".rst"),
    ).load()
    # Split each document into overlapping, RST-aware chunks.
    text_splitter = RecursiveCharacterTextSplitter.from_language(
        language=Language.RST, chunk_size=chunk_size, chunk_overlap=200
    )
    documents = text_splitter.split_documents(docs)
    print(f"Loaded {len(documents)} documents from Flyte documents")
    return documents

Each chunk in the documents will be represented in the following format:
page_content: max_parallelism can be used to control the number of parallel nodes to run within the workflow.
source: flyteidl/protos/docs/admin/admin.rst

Step 3: Generate Embeddings

At this point, we’ve successfully split the documents into small chunks. Next, we need to generate embeddings to enable similarity search for retrieving the most relevant chunks for a given query. To achieve this, we’ll launch an Anyscale Job to use Ray Data to generate embeddings in a distributed manner and save them to the shared S3 bucket.

To optimize the process of building the vector database, we begin by dividing the data into multiple batches. We then map the EmbedChunks class over each batch using Ray's map_batches method, which allows parallel processing across multiple workers. After building the per-batch indexes, we merge them into a unified vector database for efficient retrieval. Ray’s autoscaling helps accelerate indexing by dynamically adjusting resources for optimal performance and scalability.

Union makes it easy to run this Ray job on the Anyscale Platform. By simply specifying the AnyscaleConfig, you can submit the job to Anyscale effortlessly, without needing an in-depth understanding of the Anyscale API. This approach simplifies deployment, making powerful infrastructure accessible to more users with minimal setup.

class EmbedChunks:
    def __init__(self):
        self.embedding_model = HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-mpnet-base-v2"
        )

    def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, list]:
        # Build a FAISS index for this batch of document chunks.
        results = FAISS.from_documents(list(batch["data"]), self.embedding_model)
        return {"embeddings": [results]}


@task(task_config=AnyscaleConfig(compute_config="flyte-rag"))
def embedding_generation(
    flytekit_code: List[Document],
    flyte_code: List[Document],
    flyte_document: List[Document],
    slack: List[Document],
) -> FlyteDirectory:
    docs = flytekit_code + flyte_code + flyte_document + slack
    # Split the chunks into batches and index them in parallel on GPU workers.
    batches = np.array_split(docs, 8)
    ds = ray.data.from_numpy(batches)
    res = ds.map_batches(
        EmbedChunks,
        num_gpus=1,
        batch_size=1000,
        concurrency=2,
    ).take_all()

    # Merge the per-batch indexes into a single vector database.
    return merge_embeddings(res)
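
The merge_embeddings helper isn't shown above. A minimal sketch, assuming each row returned by take_all() holds one per-batch FAISS index, could look like this:

def merge_embeddings(rows: List[Dict[str, FAISS]]) -> FlyteDirectory:
    # Hypothetical helper: fold the per-batch FAISS indexes into a single
    # index and persist it so the batch inference pipeline can load it later.
    indexes = [row["embeddings"] for row in rows]
    merged = indexes[0]
    for index in indexes[1:]:
        merged.merge_from(index)
    local_dir = "/tmp/vector_database"
    merged.save_local(local_dir)
    return FlyteDirectory(path=local_dir)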

Batch Inference Pipeline

Llama Predictor

In this section, we'll describe how we leverage Ray actors to execute batch inference efficiently. The Ray actor is responsible for carrying out the following operations:

  • Launching a FAISS instance and loading the embeddings generated by the embedding generation pipeline into the vector database.

  • Generating an embedding for the user's prompt, which in this demo is a GitHub issue.

  • Performing a similarity search with the prompt’s embedding to retrieve the most relevant chunks from FAISS.

  • Combining the retrieved chunks with the prompt using LangChain and sending the augmented prompt to the Llama model to generate a response.

class LlamaPredictor:
    def __init__(self, vector_database: FlyteDirectory):
        # Load the FAISS index built by the embedding generation pipeline.
        db = load_database(vector_database)
        retriever = db.as_retriever()
        llm = VLLM(
            model="meta-llama/Meta-Llama-3.1-70B",
            trust_remote_code=True,  # mandatory for hf models
            max_new_tokens=512,
        )

        def format_docs(docs):
            return "\n\n".join(doc.page_content for doc in docs)

        # Ask the model for a structured answer along with its source.
        response_schemas = [
            ResponseSchema(name="answer", description="answer to the user's question"),
            ResponseSchema(
                name="source",
                description="source used to answer the user's question, should be a website.",
            ),
        ]
        output_parser = StructuredOutputParser.from_response_schemas(response_schemas)

        format_instructions = output_parser.get_format_instructions()
        prompt = PromptTemplate(
            template="You are a Flyte expert helping people resolve the github issues. "
            "Use the following pieces of retrieved context to answer the questions "
            "and propose code changes as best as possible.\n{context}"
            "\n{format_instructions}\n{question}",
            input_variables=["question", "context"],
            partial_variables={"format_instructions": format_instructions},
        )

        # Retrieve context, format the prompt, call the LLM, and parse the output.
        self.rag_chain = (
            {"context": retriever | format_docs, "question": RunnablePassthrough()}
            | prompt
            | llm
            | output_parser
        )

    def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, list]:
        questions = batch["data"]
        batch["output"] = [self.rag_chain.invoke(question) for question in questions]
        return batch
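
The load_database helper used in the constructor isn't shown above. A minimal sketch, assuming the index was saved with FAISS's save_local and the same embedding model as the embedding pipeline, could look like this:

def load_database(vector_database: FlyteDirectory) -> FAISS:
    # Hypothetical helper: the download happens here, inside the actor's
    # constructor, rather than on the Ray head node.
    local_path = vector_database.download()
    embedding_model = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-mpnet-base-v2"
    )
    return FAISS.load_local(
        local_path, embedding_model, allow_dangerous_deserialization=True
    )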

Batch Inference Using Ray Data

The batch_inference task performs large-scale inference on a set of issues using a pre-established vector database. It begins by converting a list of issues into a Ray dataset with ray.data.from_numpy for efficient parallel processing.

By leveraging Ray's data handling and parallel processing capabilities, this task efficiently processes large volumes of issues, generating predictions quickly. Running this task on Anyscale's hosted platform makes it easy to allocate GPUs, further enhancing performance and making it well-suited for complex inference tasks.

We use FlyteDirectory as the task input, which contains the document index generated in the previous step. One of the key advantages of FlyteDirectory is its support for lazy downloading: the data is only downloaded when it is actually used. In this case, the FlyteDirectory is passed to the Ray actor, so the vector database is downloaded and initialized directly within the actor. This defers the download to the actor's initialization phase rather than loading it on the Ray head node, improving efficiency and avoiding unnecessary data transfer.

@task(container_image=container_image, task_config=anyscale_config, enable_deck=True)
def batch_inference(issues: typing.List[Document], vector_database: FlyteDirectory):
    # Convert the GitHub issues into a Ray dataset for parallel processing.
    questions = [issue.page_content for issue in issues]
    ds = ray.data.from_numpy(np.asarray(questions))
    # Run RAG inference in parallel; each LlamaPredictor actor loads the
    # vector database from the FlyteDirectory in its constructor.
    predictions = ds.map_batches(
        LlamaPredictor,
        num_gpus=2,
        batch_size=10,
        concurrency=2,
        fn_constructor_kwargs={"vector_database": vector_database},
    )
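
Note that map_batches is lazy, so the task still needs to consume the predictions. A minimal sketch of how the end of batch_inference might materialize the results and render them for the Flyte Deck section below (assuming import flytekit and the variables from the snippet above):

# Inside batch_inference, after ds.map_batches(...):
results = predictions.take_all()
html = "".join(
    f"<h4>{row['data']}</h4><p>{row['output']}</p>" for row in results
)
flytekit.Deck("rag-responses", html)  # rendered because enable_deck=True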

GitHub Issue

The issue described below is one of the known issues from the Flyte repository that we use as input for our batch inference pipeline:

[Screenshot: an example GitHub issue from the Flyte repository used as input]

Flyte Deck

You can conveniently preview responses directly in Flyte Deck on Union without the need to download files from remote storage. This feature enables inspecting and validating the outputs immediately after a task completes. By providing instant access to results within Union, you can verify the accuracy and relevance of generated responses.

[Screenshot: generated responses previewed in Flyte Deck on Union]

Acknowledgement

Anyscale: Julia Martins
Union: Shalabh Chaudhri, Samhita Alla, Daniel Sola, Troy Chiu, Han-Ru Chen

Conclusion

We’ve just scratched the surface of what’s possible with Union and Anyscale in building a batch inference pipeline. With Union simplifying the submission of Ray jobs and Anyscale providing a platform for batch inference, you’re set to handle even the most complex tasks with ease. 

Curious to see how it all comes together? Dive deeper into our GitHub repository for more insights or book a demo to experience Anyscale in action.

Ready to try Anyscale?

Access Anyscale today to see how companies using Anyscale and Ray benefit from rapid time-to-market and faster iterations across the entire AI lifecycle.