In collaboration with the Berkeley LMSys group, we are excited to introduce a novel routing framework. It is trained on human preference data and directs simple queries to a more cost-effective model. You can read our research paper on arXiv.
This blog post provides a comprehensive guide to building one of our most robust router models (causal LLM) at Anyscale. It covers every step, from data labeling and fine-tuning LLMs to offline evaluation. We've also linked the associated notebook to help you get started on building your own LLM Router.
We introduce a framework for training state-of-the-art LLM routers, systems that dynamically direct queries to either high-quality closed LLMs or cost-effective open-source LLMs, based on query complexity, optimizing both response quality and cost.
This tutorial provides an in-depth guide on building an LLM router based on a causal-LLM classifier, starting with generating labeled data, finetuning an LLM-based classifier with Anyscale's API, and finally running offline evaluations.
In collaboration with the Berkeley LMSys group, we release an arXiv paper presenting extensive evaluations of this model along with other models. Overall, our LLM Routers can achieve the same performance as our baselines with up to a 70% cost reduction on MT Bench, a 30% cost reduction on MMLU, and a 40% cost reduction on GSM8K.
When developing applications using Large Language Models (LLMs), achieving high-quality responses while maintaining a budget is a key challenge. Closed models like GPT-4 provide superior quality but are costly, especially with a high volume of queries. Conversely, Open Source Software (OSS) models are more economical but may not match the quality, especially for complex or domain-specific queries.
An LLM Router helps balance these aspects by deciding which queries are routed to a closed LLM and which to an OSS LLM based on the query's complexity or domain specificity. Below is a schematic representation of an LLM Router:
Given a set of user queries, an LLM router enables generating high-quality LLM responses while minimizing the overall cost.
In this tutorial, we'll demonstrate how to train a causal-LLM classifier on the Anyscale platform as an effective LLM router. We make the following design choices:
Model Choices: We’ll use GPT-4 as an example of a closed LLM and Mixtral-8x7B as the OSS LLM, so our causal LLM classifier will route between these two models.
Response Quality Rating: We'll quantify the quality of an LLM response on a scale of 1 to 5 stars, with higher scores indicating better quality. For simplicity, we'll assume that GPT-4 always achieves a 5-star rating, so it serves as a reference for Mixtral-8x7B.
Causal LLM Classifier: We'll finetune a Llama3-8B model as our causal LLM classifier using Anyscale's finetuning API. Our research shows that this model offers better routing performance than smaller architectures.
More concretely, the objective of the causal LLM classifier is to direct "simple" queries to Mixtral-8x7B, thereby maintaining high overall response quality (e.g., an average score of 4.8/5) while significantly reducing costs (e.g., by 50%).
We show that it's possible to build LLM routers that achieve outstanding performance. Below are MT Bench results from our best-performing LLM routers, the Causal LLM and a Matrix Factorization (MF) model. They show that our routers achieve higher quality at lower cost (i.e., with fewer calls to GPT-4) than both the random baseline and the public LLM routing systems from Unify AI and Martian. For more details on these and additional results, refer to our paper.
In the following sections, we discuss the steps that enable anyone to build a strong LLM router.
Prepare Labeled Data: The foundation of a robust LLM router is high-quality labeled data. In this section, we'll guide you through preparing this training data.
Finetune a Router Model: We demonstrate how to finetune a causal-LLM classifier using Anyscale's finetuning API, transforming it into an effective LLM router.
Offline Evaluation: Using the public codebase (RouteLLM), we will walk through an offline evaluation on standard benchmarks.
Time to complete: Approximately 120 minutes, including time to train on a node with 8xA10 GPUs.
```python
# Install required packages
!pip install -r requirements.txt

# Store your ANYSCALE_API_KEY and OPENAI_API_KEY in /home/ray/default/.env
from dotenv import load_dotenv
load_dotenv("/home/ray/default/.env")
```
The LLM router essentially functions as a binary classifier, deciding whether to route a query to GPT-4 or Mixtral-8x7B based on the query text. Initially, we considered labeled data in the format (query, routing_label), where routing_label is 1 if the query should be routed to Mixtral-8x7B and 0 if it should be routed to GPT-4.
However, our early experiments revealed that binary labels do not provide sufficient signal for training a robust router model. We therefore adopted a different labeling approach: a 1-5 scoring system that reflects how effectively Mixtral-8x7B can respond to the user's query. More specifically:
4-5: Mixtral-8x7B produces a very strong answer, showing deep understanding, creativity, detailed insight, and high relevance.
3: Mixtral-8x7B provides an adequate answer with moderate detail, relevance, and factual accuracy.
1-2: Mixtral-8x7B struggles to produce a strong answer due to the question's difficulty, vagueness, or the model's limitations.
We use labeled samples in the format (query, score_label). The routing_label can be derived from the score_label by setting a score threshold for quality, i.e., routing_label = 1 if score_label >= 4 else 0.
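The threshold is a tunable knob. The short sketch below uses hypothetical judge scores to show how moving the threshold changes how many samples are labeled as routable to Mixtral-8x7B. to_routing_label is an illustrative helper, not part of the tutorial's code.

```python
from collections import Counter


def to_routing_label(score_label: int, threshold: int = 4) -> int:
    """Return 1 (route to Mixtral-8x7B) if the score meets the threshold, else 0 (route to GPT-4)."""
    if not 1 <= score_label <= 5:
        raise ValueError(f"score_label must be in [1, 5], got {score_label}")
    return 1 if score_label >= threshold else 0


# Hypothetical judge scores for ten queries (not real data).
scores = [5, 4, 2, 5, 3, 1, 4, 5, 3, 5]

# A stricter threshold labels fewer queries as safe to send to the OSS model.
for threshold in (3, 4, 5):
    labels = [to_routing_label(s, threshold) for s in scores]
    print(f"threshold={threshold}: {Counter(labels)}")
```

With the tutorial's threshold of 4, a score of 3 ("adequate") counts as not good enough, which biases the router toward quality over cost.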
Next, we'll dive into the detailed process of preparing our labeled dataset.
We want our LLM router to be effective in open-ended chat domains, so our first step is to collect a set of generic queries from the Nectar dataset. We chose Nectar for two reasons. It combines queries from many different domains, including open-ended chat. It also includes responses from many models, including over 191K responses from GPT-4.
```python
from src.utils import load_and_display_nectar

nectar_df = load_and_display_nectar()
```
| | prompt | answers | turns | source | good_natured |
|---|---|---|---|---|---|
| 0 | \n\nHuman: 0.002 = 1000 \n1 = x?\n\nAssistant: | [ | 1 | [sharegpt] | True |

```
Number of queries with GPT-4 responses: 191487
```
We will use a subset of the Nectar data that includes responses from GPT-4, as these will be used to generate scores (as seen below). We will process this data by focusing on single-turn conversations, filtering for good-natured interactions, and cleaning up the prompts and responses to maintain high quality. Additionally, we will sample a small subset from the dataset for the purpose of this tutorial; however, you can skip sampling to work with the full dataset.
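The sketch below spells out the filtering logic using only the standard library. It is not the implementation of src.utils.preprocess_nectar. The field names turns, good_natured, and answers[*]["model"] / ["answer"] are assumptions about the Nectar schema, and the rows are toy data.

```python
from typing import Optional

HUMAN_PREFIX = "Human:"
ASSISTANT_SUFFIX = "Assistant:"


def clean_prompt(raw: str) -> str:
    """Strip the Human/Assistant chat markers around a single-turn prompt."""
    text = raw.strip()
    if text.startswith(HUMAN_PREFIX):
        text = text[len(HUMAN_PREFIX):]
    if text.endswith(ASSISTANT_SUFFIX):
        text = text[: -len(ASSISTANT_SUFFIX)]
    return text.strip()


def find_answer(answers: list, model: str) -> Optional[str]:
    """Return the first answer produced by `model`, or None if absent (assumed schema)."""
    for entry in answers:
        if entry.get("model") == model:
            return entry.get("answer", "").strip()
    return None


def preprocess(rows: list, model: str = "gpt-4") -> list:
    """Keep single-turn, good-natured rows that have a non-empty response from `model`."""
    out = []
    for row in rows:
        if row["turns"] != 1 or not row["good_natured"]:
            continue
        response = find_answer(row["answers"], model)
        if not response:
            continue
        out.append({"prompt": clean_prompt(row["prompt"]), "gpt4_response": response})
    return out


# Toy rows: the second one is multi-turn and gets filtered out.
rows = [
    {"prompt": "\n\nHuman: 0.002 = 1000 \n1 = x?\n\nAssistant:", "turns": 1,
     "good_natured": True,
     "answers": [{"model": "gpt-4", "answer": "(toy GPT-4 answer)"},
                 {"model": "gpt-3.5-turbo", "answer": "(toy answer)"}]},
    {"prompt": "\n\nHuman: hi\n\nAssistant: hello\n\nHuman: bye\n\nAssistant:",
     "turns": 2, "good_natured": True, "answers": []},
]
print(preprocess(rows))
```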
```python
from src.utils import preprocess_nectar

nectar_gpt4_df = preprocess_nectar(
    nectar_df, model="gpt-4", response_column="gpt4_response"
)

# Sample a small subset from the dataset for the purpose of this tutorial
N_SUBSET = 30
dataset_df = nectar_gpt4_df.sample(N_SUBSET, random_state=42)
```

```python
display(dataset_df.head())
```

| | prompt | source | gpt4_response |
|---|---|---|---|
| 6062 | Based on the features mentioned, which hotel d... | [evol_instruct] | Based on the features mentioned, Hotel A seems... |
| 113830 | Provide step-by-step instructions on how to cr... | [ultrachat] | Sure, here's a simple step-by-step guide on ho... |
| 138869 | What are the 10 largest cities in the US by po... | [lmsys-chat-1m] | As of the most recent data available, the 10 l... |
| 169249 | Write a comparison essay of at least 500 words... | [ultrachat] | Title: A Comparative Analysis of Driving a Car... |
| 116934 | Q: You are provided with an "Event", "Intent" ... | [flan_v2_niv2] | PersonX might feel satisfied or content using ... |
We don't have human labels for scores, so we will use the LLM-as-a-Judge approach. GPT-4 will act as an evaluator, reviewing the query and Mixtral's response to provide a score from 1-5. As shown in the paper, the most robust way to get labels is by providing a reference answer for comparison. Here, GPT-4's own response serves as the reference, and Mixtral's response is evaluated against it.
There are two main steps in this process:
Generate Mixtral-8x7B responses for all queries: We will use an online batch-inference method utilizing Ray and Anyscale endpoints.
Generate LLM-as-a-Judge labels: We will ask GPT-4 to evaluate the Mixtral responses against its own reference answers and provide a score from 1-5.
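The tutorial's generate_mixtral_responses uses Ray and Anyscale endpoints. The sketch below illustrates the same idea in simplified form: online batch inference with a cap on in-flight requests, using Python's concurrent.futures. call_model is a hypothetical stand-in for a real endpoint call.

```python
import time
from concurrent.futures import ThreadPoolExecutor


def call_model(prompt: str) -> str:
    """Hypothetical stand-in for an HTTP request to a chat-completions endpoint."""
    time.sleep(0.01)  # simulate network latency
    return f"response to: {prompt}"


def batch_generate(prompts, max_concurrency: int = 8):
    """Run call_model over all prompts with at most max_concurrency requests in flight.

    Executor.map returns results in the same order as the input prompts.
    """
    with ThreadPoolExecutor(max_workers=max_concurrency) as pool:
        return list(pool.map(call_model, prompts))


start = time.perf_counter()
responses = batch_generate([f"query {i}" for i in range(30)])
print(f"Done {len(responses)} queries in {time.perf_counter() - start:.2f}sec.")
```

Threads are a reasonable fit here because the work is I/O-bound. A real client would also need retries and rate-limit handling.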
```python
import os
from src.online_inference import generate_mixtral_responses

dataset_df = generate_mixtral_responses(
    dataset_df, os.getenv("ANYSCALE_API_KEY"), response_column="mixtral_response"
)
```

```
Starting batch inference on 30 queries...

# queries un-processed: 29, in-progress: 1, ready: 0
# queries un-processed: 28, in-progress: 2, ready: 0
# queries un-processed: 27, in-progress: 3, ready: 0
# queries un-processed: 26, in-progress: 4, ready: 0
# queries un-processed: 25, in-progress: 5, ready: 0
# queries un-processed: 24, in-progress: 6, ready: 0
# queries un-processed: 23, in-progress: 7, ready: 0
# queries un-processed: 22, in-progress: 8, ready: 0
# queries un-processed: 21, in-progress: 9, ready: 0
# queries un-processed: 20, in-progress: 10, ready: 0
# queries un-processed: 19, in-progress: 11, ready: 0
# queries un-processed: 18, in-progress: 12, ready: 0
# queries un-processed: 17, in-progress: 13, ready: 0
# queries un-processed: 16, in-progress: 14, ready: 0
# queries un-processed: 15, in-progress: 15, ready: 0
# queries un-processed: 14, in-progress: 16, ready: 0
# queries un-processed: 13, in-progress: 17, ready: 0
# queries un-processed: 12, in-progress: 18, ready: 0
# queries un-processed: 11, in-progress: 18, ready: 1
# queries un-processed: 10, in-progress: 18, ready: 1
# queries un-processed: 9, in-progress: 18, ready: 1
# queries un-processed: 8, in-progress: 18, ready: 1
# queries un-processed: 7, in-progress: 18, ready: 1
# queries un-processed: 6, in-progress: 19, ready: 0
...
Done in 19.21sec.
```

```python
display(dataset_df.head())
```

| | prompt | source | gpt4_response | mixtral_response |
|---|---|---|---|---|
| 6062 | Based on the features mentioned, which hotel d... | [evol_instruct] | Based on the features mentioned, Hotel A seems... | Based on the information provided, I would su... |
| 113830 | Provide step-by-step instructions on how to cr... | [ultrachat] | Sure, here's a simple step-by-step guide on ho... | Sure, I'd be happy to help you make a homemad... |
| 138869 | What are the 10 largest cities in the US by po... | [lmsys-chat-1m] | As of the most recent data available, the 10 l... | Here are the 10 largest cities in the U.S. by... |
| 169249 | Write a comparison essay of at least 500 words... | [ultrachat] | Title: A Comparative Analysis of Driving a Car... | Title: The Great Debate: Driving a Car vs. Ri... |
| 116934 | Q: You are provided with an "Event", "Intent" ... | [flan_v2_niv2] | PersonX might feel satisfied or content using ... | PersonX likely feels comfortable and focused,... |
Let's first take a look at an example query that we will send to GPT-4 for judgment:
```python
from src.utils import inspect_llm_judge_queries

inspect_llm_judge_queries(dataset_df)
```

```
[Instruction]
Evaluate the AI assistant's proficiency in answering the user question displayed below. Your evaluation should consider factors such as the helpfulness, relevance, adherence to real-world facts, depth, creativity, and level of detail of the response. You will be given a reference answer which is considered of high quality. Your assessment will have two lines: First line has a rating on a scale of 1 to 5 with a higher rating representing higher response quality. Follow strictly this format: "[[rating]]", for example: "[[3]]". Second line contains a short explanation of your rating.

[Question]
Q: You are provided with an "Event", "Intent" related to PersonX. Guess a reaction/reaction of PersonX about the given event and their intention.
Event:PersonX uses ___ in class. Intent: 1) to use his prefered writing implement
A:

[Reference Answer]
PersonX might feel satisfied or content using their preferred writing implement in class, as it aligns with their intention to utilize a comfortable and desired tool for writing.
Confidence: 85%

[Assistant Answer]
 PersonX probably feels comfortable and focused in class, as they are using their preferred writing implement. This may help them engage more effectively with the material being taught.

Guidelines for Rating:
 - High Rating (4-5): Reserved for responses that are very close to the quality of the reference or even better.
 - Medium Rating (3): Reserved for responses that have moderate quality compared to the reference.
 - Low Rating (1-2): Allocated to response that are much lower quality compared to the reference or completely wrong.

Assessment:
```
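The judge is asked to put its rating on the first line as "[[rating]]", so the label can be extracted with a regular expression. The minimal sketch below assumes that replies without a 1-5 rating in that format should be treated as missing (returned as None) rather than guessed. It is not necessarily what generate_llm_judge_labels does internally.

```python
import re
from typing import Optional

# Matches "[[1]]" through "[[5]]" and captures the digit.
RATING_RE = re.compile(r"\[\[([1-5])\]\]")


def parse_rating(judge_output: str) -> Optional[int]:
    """Return the first [[1]]-[[5]] rating in the judge reply, or None if absent."""
    match = RATING_RE.search(judge_output)
    return int(match.group(1)) if match else None


reply = "[[5]]\nThe response matches the reference in accuracy and detail."
print(parse_rating(reply))       # 5
print(parse_rating("No score"))  # None
```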
Now, we apply a similar online batch-inference method to generate our labels.
```python
import os
from src.online_inference import generate_llm_judge_labels

dataset_df = generate_llm_judge_labels(dataset_df, os.getenv('OPENAI_API_KEY'))
```

```
Starting batch inference on 30 queries...

# queries un-processed: 29, in-progress: 1, ready: 0
# queries un-processed: 28, in-progress: 2, ready: 0
# queries un-processed: 27, in-progress: 3, ready: 0
# queries un-processed: 26, in-progress: 4, ready: 0
# queries un-processed: 25, in-progress: 5, ready: 0
# queries un-processed: 24, in-progress: 6, ready: 0
# queries un-processed: 23, in-progress: 7, ready: 0
# queries un-processed: 22, in-progress: 7, ready: 1
# queries un-processed: 21, in-progress: 7, ready: 1
# queries un-processed: 20, in-progress: 8, ready: 0
# queries un-processed: 19, in-progress: 8, ready: 1
# queries un-processed: 18, in-progress: 9, ready: 0
# queries un-processed: 17, in-progress: 10, ready: 0
# queries un-processed: 17, in-progress: 9, ready: 1
# queries un-processed: 16, in-progress: 9, ready: 1
# queries un-processed: 15, in-progress: 9, ready: 1
...
Done in 16.43sec.
```

```python
display(dataset_df.head())
```

| | prompt | source | gpt4_response | mixtral_response | mixtral_score |
|---|---|---|---|---|---|
| 6062 | Based on the features mentioned, which hotel d... | [evol_instruct] | Based on the features mentioned, Hotel A seems... | Based on the information provided, I would su... | 5 |
| 113830 | Provide step-by-step instructions on how to cr... | [ultrachat] | Sure, here's a simple step-by-step guide on ho... | Sure, I'd be happy to help you make a homemad... | 3 |
| 138869 | What are the 10 largest cities in the US by po... | [lmsys-chat-1m] | As of the most recent data available, the 10 l... | Here are the 10 largest cities in the U.S. by... | 5 |
| 169249 | Write a comparison essay of at least 500 words... | [ultrachat] | Title: A Comparative Analysis of Driving a Car... | Title: The Great Debate: Driving a Car vs. Ri... | 4 |
| 116934 | Q: You are provided with an "Event", "Intent" ... | [flan_v2_niv2] | PersonX might feel satisfied or content using ... | PersonX likely feels comfortable and focused,... | 5 |
We have previously generated the full labeled dataset, created train and validation splits, and published them as the public Hugging Face dataset routellm/gpt4_dataset. Let's load the dataset and explore the score distribution.
```python
from datasets import load_dataset
from src.utils import visualize_label_distribution

full_dataset_df = load_dataset("routellm/gpt4_dataset")
train_df = full_dataset_df["train"].to_pandas()

print(f"Train size: {len(train_df)}")
display(train_df.head())
visualize_label_distribution(train_df, key="mixtral_score")
```

```
Train size: 109101
```

| | prompt | source | gpt4_response | mixtral_response | mixtral_score |
|---|---|---|---|---|---|
| 0 | I'll give you a review, can you extract the fo... | [lmsys-chat-1m] | Sure, here's the analysis of the review:\n\n1.... | Food aspects and opinion words:\n\n1. Made to ... | 4 |
| 1 | Answer the following question: Claim: "Joker m... | [flan_v2_cot] | The answer is no.\nChain of thoughts: Stan Lee... | The answer is no.\n\nChain of thoughts: While ... | 5 |
| 2 | TASK DEFINITION: In this task you will be give... | [flan_v2_niv2] | ZdoublexpropheciesS | ZdoublexpropheciesS\n\nIn this task, you are a... | 5 |
| 3 | Detailed Instructions: In this task, you need ... | [flan_v2_niv2] | Yes | No, 'station' is not the longest word in the s... | 5 |
| 4 | A guy pick up a woman Then he puts her down Ex... | [sharegpt] | This phrase could be interpreted as a joke bec... | This joke is a play on words and relies on the... | 5 |
Higher counts for 4-5 scores indicate that Mixtral-8x7B consistently produces high-quality responses, demonstrating its competitive performance compared to the June 2023 version of GPT-4, whose responses are logged in the Nectar dataset.
Let us assume that if the score is >= 4, we will route to the OSS model (indicating the response quality is good enough); otherwise, we will route to the closed model. Under this assumption, the data distribution looks like this:
```python
train_df["routing_label"] = train_df["mixtral_score"].apply(
    lambda x: 1 if x >= 4 else 0
)

visualize_label_distribution(train_df, key="routing_label")
```
In this section, we will explain how to finetune a causal LLM classifier to be an effective router. While our data contains gpt4_response and mixtral_response, we will only use the pair (query, mixtral_score) for training. The goal is for the router to rely solely on the query text to determine which model to route to. Our approach is straightforward: we train a 5-way classifier to predict the mixtral_score from the query. At inference time, we will route to Mixtral if our router predicts a high score (i.e., 4-5) and to GPT-4 otherwise.
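At inference time, the routing rule reduces to a threshold on the predicted score. The sketch below turns predicted scores into routing decisions and reports the share of queries that would still go to GPT-4. That share is the cost quantity the offline evaluation later tracks. The model names and predictions are placeholders.

```python
WEAK_MODEL = "mixtral-8x7b"  # placeholder name for the cheaper OSS model
STRONG_MODEL = "gpt-4"       # placeholder name for the closed model


def route(predicted_score: int, threshold: int = 4) -> str:
    """Route to the weak (cheaper) model when a high score is predicted."""
    return WEAK_MODEL if predicted_score >= threshold else STRONG_MODEL


def strong_call_fraction(predicted_scores, threshold: int = 4) -> float:
    """Fraction of queries routed to the strong (more expensive) model."""
    decisions = [route(s, threshold) for s in predicted_scores]
    return decisions.count(STRONG_MODEL) / len(decisions)


predicted = [5, 4, 2, 3, 5]  # made-up router predictions
print([route(s) for s in predicted])
print(strong_call_fraction(predicted))  # 0.4
```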
We will discuss a few preprocessing steps to prepare the data for finetuning an LLM classifier.
We use the instruction-following framework to finetune an LLM as a router. The task instructions guide the model to predict the score label for a given query. They ensure the model understands the evaluation criteria and can accurately assess the query's complexity and expected response quality.
```python
from src.utils import inspect_instructions

inspect_instructions()
```

```
[Instruction]
Based on the question provided below, predict the score an expert evaluator would give to an AI assistant's response, considering its helpfulness, relevance, adherence to facts, depth, creativity, and detail. Your prediction should infer the level of proficiency needed to address the question effectively. Use a scale from 1 to 5, where a higher score indicates a higher anticipated quality of response. Provide your prediction as: "[[predicted rating]]".

Score criteria:
- **4-5**: The AI assistant can produce a very strong answer, showing deep understanding, creativity, detailed insight, and high relevance.
- **3**: The AI assistant can provide an adequate answer with moderate detail, relevance, and factual accuracy.
- **1-2**: The AI assistant will struggle to produce a strong answer due to the question's difficulty, vagueness, or the assistant's limitations.

[Question]
{question}

Prediction:
```
To finetune the model, we must format the data to be compatible with Anyscale's finetuning API.
```python
from src.utils import prepare_ft_messages

train_df["messages"] = prepare_ft_messages(train_df, "mixtral_score")

# here's what the API data format looks like:
display(train_df["messages"].iloc[0])
```

```
[{'role': 'system',
  'content': '[Instruction]\nBased on the question provided below, predict the score an expert evaluator would give to an AI assistant\'s response, considering its helpfulness, relevance, adherence to facts, depth, creativity, and detail. Your prediction should infer the level of proficiency needed to address the question effectively. Use a scale from 1 to 5, where a higher score indicates a higher anticipated quality of response. Provide your prediction as: "[[predicted rating]]".\n\nScore criteria:\n- **4-5**: The AI assistant can produce a very strong answer, showing deep understanding, creativity, detailed insight, and high relevance.\n- **3**: The AI assistant can provide an adequate answer with moderate detail, relevance, and factual accuracy.\n- **1-2**: The AI assistant will struggle to produce a strong answer due to the question\'s difficulty, vagueness, or the assistant\'s limitations.\n'},
 {'role': 'user',
  'content': "[Question]\nI'll give you a review, can you extract the food aspects and the opinion words of these aspects and analyze the sentiment of these opinion from this review? the review is:They tore the old NAME_1 down then built another one...? Anyway, they sell wine and beer and snacks and have a seating area inside and outside to eat. Besides gas, the big draw is the Made to Order food. I ordered some tacos and French toast sticks both were pretty good. I think I'd like to try more snacks.And they're open 24/7.\n\nPrediction:\n"},
 {'role': 'assistant', 'content': '[[4]]'}]
```
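A small helper can reproduce the record above from a (prompt, score) pair. build_messages is hypothetical, not the tutorial's prepare_ft_messages, and the system prompt is abbreviated here. The user and assistant strings follow the format shown in the output.

```python
# Abbreviated; the full instruction text is the one printed by inspect_instructions().
SYSTEM_PROMPT = "[Instruction]\nBased on the question provided below, predict the score ..."


def build_messages(prompt: str, score: int) -> list:
    """Build a system/user/assistant chat record for one training example."""
    return [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": f"[Question]\n{prompt}\n\nPrediction:\n"},
        # The target is the score wrapped in the [[n]] label format.
        {"role": "assistant", "content": f"[[{score}]]"},
    ]


example = build_messages("What challenges did FDR face while in office", 5)
print(example[1]["content"])
print(example[2])
```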
For classification tasks, it's recommended to train on label-balanced datasets so that models are not biased toward a specific label. We will balance the dataset based on routing_label, as this is the label of primary interest.
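One common way to balance labels is to downsample every class to the size of the smallest one. The sketch below assumes that strategy. balance_dataset in src.utils may differ, for example in how it samples or shuffles.

```python
import random
from collections import defaultdict


def balance(records, key, seed=42):
    """Downsample each label group to the size of the smallest group, then shuffle."""
    by_label = defaultdict(list)
    for rec in records:
        by_label[rec[key]].append(rec)
    n_min = min(len(group) for group in by_label.values())
    rng = random.Random(seed)  # fixed seed for reproducibility
    balanced = []
    for group in by_label.values():
        balanced.extend(rng.sample(group, n_min))
    rng.shuffle(balanced)
    return balanced


# Toy, imbalanced data: seven label-1 rows and three label-0 rows.
records = [{"routing_label": 1} for _ in range(7)] + [{"routing_label": 0} for _ in range(3)]
balanced = balance(records, key="routing_label")
print(len(balanced), sum(r["routing_label"] for r in balanced))  # 6 3
```

Downsampling discards data from the majority class. Upsampling the minority class or using class weights are alternatives with different trade-offs.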
```python
from src.utils import balance_dataset

balanced_train_df = balance_dataset(train_df, key="routing_label")

print(f"Train size: {len(balanced_train_df)}")
```

```
Train size: 29504
```
To expedite the time to run this tutorial, we will subsample 1,000 examples for training. We'll store the data in JSONL format to prepare for launching the finetuning job in the next section.
```python
n_sample = 1000
output_file = "/mnt/user_storage/train_data_sample.jsonl"

subsampled_df = balanced_train_df.sample(n=n_sample, random_state=42)
subsampled_df.to_json(output_file, orient="records", lines=True)
```
We will run a fine-tuning job using Anyscale's LLM finetuning API as an isolated job, similar to our end-to-end LLM workflows guide.
For this tutorial, we will perform full-parameter finetuning of Llama3-8B on the 1,000 samples prepared above, using the same file as the validation set. The goal is to debug the training dynamics and ensure the model can fit the training set. Below, we present the training and job configurations before submitting the training job.
```python
# View the full-param finetuning configuration for llama-3-8B
!cat configs/ft_config_a10.yaml
```

```yaml
model_id: meta-llama/Meta-Llama-3-8B
train_path: /mnt/user_storage/train_data_sample.jsonl
valid_path: /mnt/user_storage/train_data_sample.jsonl
context_length: 1024
num_devices: 8
num_epochs: 5
checkpoint_every_n_epochs: 5
train_batch_size_per_device: 4
eval_batch_size_per_device: 4
lr_scheduler_type: constant
learning_rate: 1e-5
num_checkpoints_to_keep: 1
no_gradient_checkpoint: False
output_dir: /mnt/local_storage
deepspeed:
  config_path: config_files/deepspeed/zero_3_optimizer_parameter_offload.json
flash_attention_2: true
classifier_config:
  label_tokens:
    - "[[1]]"
    - "[[2]]"
    - "[[3]]"
    - "[[4]]"
    - "[[5]]"
```
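The arithmetic below estimates the number of optimizer steps implied by the config, to give a feel for the size of this run. It assumes pure data parallelism over num_devices, no gradient accumulation, and that a final partial batch is kept. llmforge may count steps differently.

```python
import math

num_samples = 1000      # rows in train_data_sample.jsonl
num_devices = 8         # num_devices
per_device_batch = 4    # train_batch_size_per_device
num_epochs = 5          # num_epochs

global_batch = num_devices * per_device_batch              # 32 examples per step
steps_per_epoch = math.ceil(num_samples / global_batch)    # 1000 / 32 = 31.25 -> 32
total_steps = steps_per_epoch * num_epochs                 # 160

print(global_batch, steps_per_epoch, total_steps)
```

With checkpoint_every_n_epochs equal to num_epochs, only the final epoch's checkpoint is written. That is consistent with the epoch-4 (zero-indexed) path in the training log.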
```python
# View job yaml config
!cat configs/ft_job.yaml
```

```yaml
name: llm-router-tutorial
entrypoint: python src/ft.py configs/ft_config_a10.yaml
image_uri: localhost:5555/anyscale/llm-forge:0.5.0.0
requirements: requirements.txt
max_retries: 0
```

```python
# Job submission
!anyscale job submit --config-file configs/ft_job.yaml --exclude assets
```

```
(anyscale +1.0s) Submitting job with config JobConfig(name='llm-router-tutorial', image_uri='localhost:5555/anyscale/llm-forge:0.5.0.0', compute_config=None, env_vars=None, py_modules=None, cloud=None, project=None, ray_version=None).
(anyscale +2.5s) Uploading local dir '.' to cloud storage.
(anyscale +3.5s) Job 'llm-router-tutorial' submitted, ID: 'prodjob_16krca7sgdjyeh2eyf81h6q9uf'.
(anyscale +3.5s) View the job in the UI: https://console.anyscale.com/jobs/prodjob_16krca7sgdjyeh2eyf81h6q9uf
(anyscale +3.5s) Use `--wait` to wait for the job to run and stream logs.
```
The job takes around 10 minutes on 4xA100-80gb and 1 hour on 8xA10-22gb to finish. Training logs will show the final model checkpoint, e.g.:
```
Best checkpoint is stored in:
storage-bucket-cld-tffbxe9ia5phqr1unxhz4f7e1e/org_4snvy99zwbmh4gbtk64jfqggmj/cld_tffbxe9ia5phqr1unxhz4f7e1e/artifact_storage/amjad__almahairi_dkaubsimoyxpiksqxqkxrfgfvzzotwtacs/llmforge-finetuning/meta-llama/Meta-Llama-3-8B/TorchTrainer_2024-06-21_17-02-52/epoch-4
With perplexity: 1.0318867739521242
```
This checkpoint can be used to run batch inference or serve the model online.
Next, we will run an offline evaluation of the model on an out-of-domain benchmark. The same model, trained on the full dataset, is available along with other router models in the RouteLLM GitHub repository: https://github.com/lm-sys/RouteLLM/.
Install the RouteLLM package:

```python
# Clone the repository under /home/ray/default/
!git clone https://github.com/lm-sys/RouteLLM.git /home/ray/default/RouteLLM

# Change to the cloned repository directory
%cd /home/ray/default/RouteLLM

# Install the package with the specified extras
!pip install -e .[eval]
```

```
...
Successfully installed routellm-0.0.1
```
Let's show an example of loading the model and running inference on a single example sampled from our data. Note that you need access to meta-llama/Meta-Llama-3-8B to run these evaluations. Let's first show what a formatted input looks like.
```python
# Store your `meta-llama` access token in /home/ray/default/.env with the name LLAMA2_HF_TOKEN
from dotenv import load_dotenv
load_dotenv("/home/ray/default/.env")

from pprint import pprint

# Sample one row from the DataFrame
sampled_row = train_df.sample(n=1, random_state=42)

# Convert the sampled row to a dictionary without the index
input_example = sampled_row.to_dict(orient='records')[0]

print("Prompt:", input_example['prompt'])
print("Label:", input_example['mixtral_score'])
print("Messages:")
pprint(input_example['messages'])
```

```
Prompt: What challenges did FDR face while in office
Label: 5
Messages:
[{'content': '[Instruction]\n'
             'Based on the question provided below, predict the score an '
             "expert evaluator would give to an AI assistant's response, "
             'considering its helpfulness, relevance, adherence to facts, '
             'depth, creativity, and detail. Your prediction should infer the '
             'level of proficiency needed to address the question effectively. '
             'Use a scale from 1 to 5, where a higher score indicates a higher '
             'anticipated quality of response. Provide your prediction as: '
             '"[[predicted rating]]".\n'
             '\n'
             'Score criteria:\n'
             '- **4-5**: The AI assistant can produce a very strong answer, '
             'showing deep understanding, creativity, detailed insight, and '
             'high relevance.\n'
             '- **3**: The AI assistant can provide an adequate answer with '
             'moderate detail, relevance, and factual accuracy.\n'
             '- **1-2**: The AI assistant will struggle to produce a strong '
             "answer due to the question's difficulty, vagueness, or the "
             "assistant's limitations.\n",
  'role': 'system'},
 {'content': '[Question]\n'
             'What challenges did FDR face while in office\n'
             '\n'
             'Prediction:\n',
  'role': 'user'},
 {'content': '[[5]]', 'role': 'assistant'}]
```
Let's run inference with this example and examine the model's output.
```python
from src.offline_inference import single_example_inference

result = single_example_inference(input_example)
pprint(result)
```

```
Loading model checkpoint from routellm/causal_llm_gpt4_augmented ...

Loading checkpoint shards: 100%|██████████| 4/4 [00:02<00:00, 1.76it/s]
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.

Done loading model in 5.628355264663696 seconds.
{'binary_prob': 0.9662781,
 'output_ids': tensor([128006, 78191, 128007, 271, 128260, 128009]),
 'output_str': '<|start_header_id|>assistant<|end_header_id|>\n'
               '\n'
               '[[5]]<|eot_id|>',
 'output_tokens': ['<|start_header_id|>',
                   'assistant',
                   '<|end_header_id|>',
                   'ĊĊ',
                   '[[5]]',
                   '<|eot_id|>'],
 'score_logits': array([10.3125, 10.9375, 11.4375, 14.4375, 15.    ], dtype=float32),
 'score_pred': 5,
 'softmax_scores': array([0.00566901, 0.0105911 , 0.01746178, 0.3507292 , 0.6155489 ],
      dtype=float32)}
```
The model outputs the predicted score as the special token [[5]], since it is trained to predict one of the 5 labels, which we add as special tokens to the vocabulary. We extract the softmax scores of each of the 5 labels in softmax_scores. We then compute the routing probability, i.e., the total probability of scores 4 and 5, as binary_prob = sum(softmax_scores[3:]).
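You can check the relationship between score_logits, softmax_scores, and binary_prob directly, using the logits printed above. The standard-library sketch below reproduces the printed values up to float32 rounding.

```python
import math

# score_logits copied from the inference output above (labels [[1]]..[[5]]).
score_logits = [10.3125, 10.9375, 11.4375, 14.4375, 15.0]


def softmax(xs):
    """Numerically stable softmax: subtract the max logit before exponentiating."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


softmax_scores = softmax(score_logits)
binary_prob = sum(softmax_scores[3:])  # P(score = 4) + P(score = 5)
score_pred = softmax_scores.index(max(softmax_scores)) + 1  # labels are 1-indexed

print([round(p, 4) for p in softmax_scores])
print(round(binary_prob, 4), score_pred)
```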
To optimize inference speed, we can append the header tokens <|start_header_id|>assistant<|end_header_id|>\n\n so that the first token the model outputs is the predicted label.
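The sketch below shows what that prompt string looks like, assuming the Llama 3 chat format whose special tokens appear in the output above. For tokenizers that ship a Llama 3 chat template, Hugging Face's tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) would normally build this. build_prompt is only illustrative. It omits <|begin_of_text|> on the assumption that the tokenizer adds the BOS token itself.

```python
def build_prompt(system: str, user: str) -> str:
    """Assemble a Llama 3-style chat prompt that ends with the assistant header.

    Because the string ends right after the assistant header, the model's first
    generated token is the label token (e.g., "[[5]]").
    """
    return (
        f"<|start_header_id|>system<|end_header_id|>\n\n{system}<|eot_id|>"
        f"<|start_header_id|>user<|end_header_id|>\n\n{user}<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
    )


prompt = build_prompt(
    "[Instruction]\n...",  # abbreviated system instruction
    "[Question]\nWhat challenges did FDR face while in office\n\nPrediction:\n",
)
print(prompt)
```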
We will use the RouteLLM evaluation framework to measure the performance of our router against a random router on GSM8K. We report the percentage of calls the router needs to send to GPT-4 to recover 20%, 50%, and 80% of the performance gap between Mixtral-8x7B and GPT-4. We also report the area under the curve (AUC) and the average performance gap recovered (APGR). See our paper for more details on the evaluation metrics.
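The minimal sketch below makes these metrics concrete, assuming the paper's gap-based definitions. PGR = (router score - Mixtral score) / (GPT-4 score - Mixtral score). CPT(x) is the smallest percentage of GPT-4 calls at which PGR reaches x. The endpoint scores are the GSM8K numbers from the evaluation log that follows. The intermediate curve points are invented for illustration.

```python
WEAK = 63.733741392501905   # Mixtral-8x7B on GSM8K (from the evaluation log)
STRONG = 85.76893649579189  # gpt-4-1106-preview on GSM8K (from the evaluation log)


def pgr(score, weak=WEAK, strong=STRONG):
    """Performance gap recovered: 0 at the weak model's score, 1 at the strong model's."""
    return (score - weak) / (strong - weak)


def cpt(curve, target):
    """Smallest % of strong-model calls on the curve whose PGR reaches `target`, else None."""
    for call_pct, score in sorted(curve):
        if pgr(score) >= target:
            return call_pct
    return None


# Hypothetical router curve: (% of queries sent to GPT-4, benchmark score).
curve = [(0, WEAK), (25, 72.0), (50, 78.0), (75, 83.0), (100, STRONG)]
for target in (0.2, 0.5, 0.8):
    print(f"CPT({target:.0%}) = {cpt(curve, target)}% GPT-4 calls")
```

Reading a discrete curve this way returns the first grid point that clears the target. Interpolating between points would give finer estimates.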
```python
!python -m routellm.evals.evaluate --config config.example.yaml --routers random causal_llm --benchmark gsm8k
```

```
Namespace(routers=['random', 'causal_llm'], benchmark='gsm8k', output='.', overwrite_cache=[], parallel=96, config='config.example.yaml', num_results=10)
...

Loading model checkpoint from routellm/causal_llm_augmented ...
Loading checkpoint shards: 100%|██████████████████| 4/4 [00:01<00:00, 2.00it/s]
...
100%|███████████████████████████████████████| 1307/1307 [06:31<00:00, 3.34it/s]
...
mistralai/Mixtral-8x7B-Instruct-v0.1 63.733741392501905
gpt-4-1106-preview 85.76893649579189
Saving plot to ./gsm8k.png

Metrics:
       method  20% qual  50% qual  80% qual        AUC      APGR
1  causal_llm    11.75%    34.06%    62.38%  77.540277  0.626567
0      random    19.69%    53.05%    83.02%  74.436777  0.485725
```

```python
from IPython.display import Image, display

# Display full plot saved in the following path
image_path = "/home/ray/default/RouteLLM/gsm8k.png"
display(Image(filename=image_path))
```
This plot illustrates that as we relax the cost constraints (i.e., increase the percentage of GPT-4 calls), the performance improves. While the performance of a random router improves linearly with cost, our router achieves significantly better results at each cost level.
In this tutorial, we built and evaluated a finetuned LLM router. We generated synthetic labeled data with the LLM-as-a-judge method and finetuned an LLM classifier on it using Anyscale's API. We then ran an offline evaluation on a standard benchmark, which showed that our model generalizes effectively to out-of-domain data.